diff --git a/Makefile b/Makefile index f62c5266..26bffeda 100644 --- a/Makefile +++ b/Makefile @@ -19,8 +19,12 @@ else CXX = g++ endif -COMMON_FLAGS = -Wall -std=c++20 -# -static-libstdc++ -static-libgcc +ifneq ($(findstring MINGW, $(HOST_OS)), MINGW) +COMMON_FLAGS = -Wall -std=c++20 -static-libstdc++ -static-libgcc +else +# For mingw-64 use this: +COMMON_FLAGS = -Wall -std=c++20 -static-libstdc++ -static-libgcc -static +endif # -fext-numeric-literals ifeq ($(HOST_OS), Darwin) diff --git a/src/Args.cpp b/src/Args.cpp index 17aa6738..041202ad 100644 --- a/src/Args.cpp +++ b/src/Args.cpp @@ -156,6 +156,7 @@ named "config.txt" in the prpll run directory. -prp : run a single PRP test and exit, ignoring worktodo.txt -ll : run a single LL test and exit, ignoring worktodo.txt -verify : verify PRP-proof contained in +-smallest : work on the smallest exponent in worktodo.txt rather than the first one -proof : generate proof of power (default: optimal depending on exponent). A lower power reduces disk space requirements but increases the verification cost. A higher power increases disk usage a lot. @@ -185,6 +186,8 @@ named "config.txt" in the prpll run directory. 2 = calculate from scratch, no memory read 1 = calculate using one complex multiply from cached memory and uncached memory 0 = read trig values from memory + -use INPLACE=n : Perform transforms in-place. Great if the reduced memory usage fits in the GPU's L2 cache. + 0 = not in-place, 1 = nVidia-friendly access pattern, 2 = AMD-friendly access pattern. -use PAD= : insert pad bytes to possibly improve memory access patterns. Val is number bytes to pad. -use MIDDLE_IN_LDS_TRANSPOSE=0|1 : Transpose values in local memory before writing to global memory -use MIDDLE_OUT_LDS_TRANSPOSE=0|1 : Transpose values in local memory before writing to global memory @@ -194,23 +197,14 @@ named "config.txt" in the prpll run directory. -use DEBUG : enable asserts in OpenCL kernels (slow, developers) --tune : measures the speed of the FFTs specified in -fft to find the best FFT for each exponent. - --ctune : finds the best configuration for each FFT specified in -fft . - Prints the results in a form that can be incorporated in config.txt - -fft 6.5M -ctune "OUT_SIZEX=32,8;OUT_WG=64,128,256" - - It is possible to specify -ctune multiple times on the same command in order to define multiple - sets of parameters to be combined, e.g.: - -ctune "IN_WG=256,128,64" -ctune "OUT_WG=256,64;OUT_SIZEX=32,16,8" - which would try only 8 combinations among those two sets. - - The tunable parameters (with the default value emphasized) are: - IN_WG, OUT_WG: 64, 128, *256* - IN_SIZEX, OUT_SIZEX: 4, 8, 16, *32* - UNROLL_W: *0*, 1 - UNROLL_H: 0, 1 - +-tune : Looks for the best settings to include in config.txt, then times many FFTs to find the fastest one for testing exponents; the results are written to tune.txt. + An -fft option can be given on the command line to limit which FFTs are timed. + Options are not required; if present, they form a comma-separated list chosen from below. + noconfig - Skip the timings that find the best config.txt settings + fp64 - Tune for settings that affect FP64 FFTs. Time FP64 FFTs for tune.txt. + ntt - Tune for settings that affect integer NTTs. Time integer NTTs for tune.txt. + minexp= - Time FFTs to find the best one for exponents greater than . + maxexp= - Time FFTs to find the best one for exponents less than .
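Editorial sketch (not part of the patch): Args::parse below simply stores the -tune option string verbatim in `tune`; assuming a split() helper like the one FFTConfig.cpp already uses, the comma-separated list documented above could be decoded roughly as follows. TuneOptions and parseTuneOptions are hypothetical names, not the shipped Tune code.

    #include <cstdint>
    #include <string>
    #include <vector>

    // Minimal stand-in for the project's own split() helper (assumed behavior).
    static std::vector<std::string> split(const std::string& s, char sep) {
      std::vector<std::string> out;
      std::string cur;
      for (char c : s) {
        if (c == sep) { out.push_back(cur); cur.clear(); } else { cur += c; }
      }
      out.push_back(cur);
      return out;
    }

    struct TuneOptions {
      bool noconfig = false;   // skip the config.txt settings pass
      bool fp64 = false;       // time/tune only FP64 FFTs
      bool ntt = false;        // time/tune only integer NTTs
      uint64_t minexp = 0;     // only look for FFTs for exponents above this
      uint64_t maxexp = 0;     // only look for FFTs for exponents below this
    };

    TuneOptions parseTuneOptions(const std::string& s) {
      TuneOptions o;
      for (const std::string& part : split(s, ',')) {
        if (part == "noconfig")                 { o.noconfig = true; }
        else if (part == "fp64")                { o.fp64 = true; }
        else if (part == "ntt")                 { o.ntt = true; }
        else if (part.rfind("minexp=", 0) == 0) { o.minexp = std::stoull(part.substr(7)); }
        else if (part.rfind("maxexp=", 0) == 0) { o.maxexp = std::stoull(part.substr(7)); }
      }
      return o;
    }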
-device : select the GPU at position N in the list of devices -uid : select the GPU with the given UID (on ROCm/AMDGPU, Linux) -pci : select the GPU with the given PCI BDF, e.g. "0c:00.0" @@ -236,31 +230,34 @@ Device selection : use one of -uid , -pci , -device , see the list ); } - printf("\nFFT Configurations (specify with -fft :: from the set below):\n" + printf("\nFFT Configurations (specify with -fft ::: from the set below):\n" " Size MaxExp BPW FFT\n"); - + vector configs = FFTShape::allShapes(); configs.push_back(configs.front()); // dummy guard for the loop below. - string variants; u32 activeSize = 0; - double maxBpw = 0; - for (auto c : configs) { - if (c.size() != activeSize) { - if (!variants.empty()) { - printf("%5s %7.2fM %.2f %s\n", - numberK(activeSize).c_str(), - // activeSize * FFTShape::MIN_BPW / 1'000'000, - activeSize * maxBpw / 1'000'000.0, - maxBpw, - variants.c_str()); - variants.clear(); + float maxBpw = 0; + string variants; + for (enum FFT_TYPES type : {FFT64, FFT3161, FFT3261, FFT61}) { + for (auto c : configs) { + if (c.fft_type != type) continue; + if (c.size() != activeSize) { + if (!variants.empty()) { + printf("%5s %7.2fM %.2f %s\n", + numberK(activeSize).c_str(), + // activeSize * FFTShape::MIN_BPW / 1'000'000, + activeSize * maxBpw / 1'000'000.0, + maxBpw, + variants.c_str()); + variants.clear(); + } + activeSize = c.size(); + maxBpw = 0; } - activeSize = c.size(); - maxBpw = 0; + maxBpw = max(maxBpw, c.maxBpw()); + if (!variants.empty()) { variants.push_back(','); } + variants += c.spec(); } - maxBpw = max(maxBpw, c.maxBpw()); - if (!variants.empty()) { variants.push_back(','); } - variants += c.spec(); } } @@ -295,9 +292,10 @@ void Args::parse(const string& line) { log("-info expects an FFT spec, e.g. -info 1K:13:256\n"); throw "-info "; } - log(" FFT | BPW | Max exp (M)\n"); + log(" FFT | BPW | Max exp (M)\n"); for (const FFTShape& shape : FFTShape::multiSpec(s)) { for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { + if (variant != LAST_VARIANT && shape.fft_type != FFT64) continue; FFTConfig fft{shape, variant, CARRY_AUTO}; log("%12s | %.2f | %5.1f\n", fft.spec().c_str(), fft.maxBpw(), fft.maxExp() / 1'000'000.0); } @@ -310,8 +308,8 @@ void Args::parse(const string& line) { assert(s.empty()); logROE = true; } else if (key == "-tune") { - assert(s.empty()); doTune = true; + if (!s.empty()) { tune = s; } } else if (key == "-ctune") { doCtune = true; if (!s.empty()) { ctune.push_back(s); } @@ -372,6 +370,7 @@ void Args::parse(const string& line) { else if (key == "-iters") { iters = stoi(s); assert(iters && (iters % 10000 == 0)); } else if (key == "-prp" || key == "-PRP") { prpExp = stoll(s); } else if (key == "-ll" || key == "-LL") { llExp = stoll(s); } + else if (key == "-smallest") { smallest = true; } else if (key == "-fft") { fftSpec = s; } else if (key == "-dump") { dump = s; } else if (key == "-user") { user = s; } diff --git a/src/Args.h b/src/Args.h index c7e92259..795cd99c 100644 --- a/src/Args.h +++ b/src/Args.h @@ -43,6 +43,7 @@ class Args { string uid; string verifyPath; + string tune; vector ctune; bool doCtune{}; @@ -53,7 +54,7 @@ class Args { std::map flags; std::map> perFftConfig; - + int device = 0; bool safeMath = true; @@ -61,6 +62,7 @@ class Args { bool verbose = false; bool useCache = false; bool profile = false; + bool smallest = false; fs::path masterDir; fs::path proofResultDir = "proof"; diff --git a/src/Background.h b/src/Background.h index 08f4d079..96cbe804 100644 --- a/src/Background.h +++ 
b/src/Background.h @@ -16,10 +16,10 @@ class Background { unsigned maxSize; std::deque > tasks; - std::jthread thread; std::mutex mut; std::condition_variable cond; - bool stopRequested{}; + bool stopRequested; + std::jthread thread; void run() { std::function task; @@ -59,6 +59,7 @@ class Background { public: Background(unsigned size = 2) : maxSize{size}, + stopRequested(false), thread{&Background::run, this} { } diff --git a/src/Buffer.h b/src/Buffer.h index c172c883..7866975d 100644 --- a/src/Buffer.h +++ b/src/Buffer.h @@ -26,7 +26,7 @@ class Buffer { Queue* queue; TimeInfo *tInfo; - + Buffer(cl_context context, TimeInfo *tInfo, Queue* queue, size_t size, unsigned flags, const T* ptr = nullptr) : ptr{size == 0 ? NULL : makeBuf_(context, flags, size * sizeof(T), ptr)} , size{size} diff --git a/src/FFTConfig.cpp b/src/FFTConfig.cpp index 669f7eea..2308a037 100644 --- a/src/FFTConfig.cpp +++ b/src/FFTConfig.cpp @@ -14,15 +14,16 @@ #include #include #include +#include using namespace std; struct FftBpw { string fft; - array bpw; + array bpw; }; -map> BPW { +map> BPW { #include "fftbpw.h" }; @@ -39,6 +40,7 @@ u32 parseInt(const string& s) { } // namespace // Accepts: +// - a prefix indicating FFT_type (if not specified, default is FP64) // - a single config: 1K:13:256 // - a size: 6.5M // - a range of sizes: 6.5M-7M @@ -49,13 +51,18 @@ vector FFTShape::multiSpec(const string& iniSpec) { vector ret; for (const string &spec : split(iniSpec, ',')) { + enum FFT_TYPES fft_type = FFT64; auto parts = split(spec, ':'); + if (parseInt(parts[0]) < 60) { // Look for a prefix specifying the FFT type + fft_type = (enum FFT_TYPES) parseInt(parts[0]); + parts = vector(next(parts.begin()), parts.end()); + } assert(parts.size() <= 3); if (parts.size() == 3) { u32 width = parseInt(parts[0]); u32 middle = parseInt(parts[1]); u32 height = parseInt(parts[2]); - ret.push_back({width, middle, height}); + ret.push_back({fft_type, width, middle, height}); continue; } assert(parts.size() == 1); @@ -76,13 +83,16 @@ vector FFTShape::multiSpec(const string& iniSpec) { vector FFTShape::allShapes(u32 sizeFrom, u32 sizeTo) { vector configs; - for (u32 width : {256, 512, 1024, 4096}) { - for (u32 height : {256, 512, 1024}) { - if (width == 256 && height == 1024) { continue; } // Skip because we prefer width >= height - for (u32 middle : {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { - u32 sz = width * height * middle * 2; - if (sizeFrom <= sz && sz <= sizeTo) { - configs.push_back({width, middle, height}); + for (enum FFT_TYPES type : {FFT64, FFT3161, FFT3261, FFT61, FFT323161}) { + for (u32 width : {256, 512, 1024, 4096}) { + for (u32 height : {256, 512, 1024}) { + if (width == 256 && height == 1024) { continue; } // Skip because we prefer width >= height + for (u32 middle : {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}) { + if (type != FFT64 && (middle & (middle - 1))) continue; // Reject non-power-of-two NTTs + u32 sz = width * height * middle * 2; + if (sizeFrom <= sz && sz <= sizeTo) { + configs.push_back({type, width, middle, height}); + } } } } @@ -101,36 +111,27 @@ vector FFTShape::allShapes(u32 sizeFrom, u32 sizeTo) { FFTShape::FFTShape(const string& spec) { assert(!spec.empty()); + enum FFT_TYPES fft_type = FFT64; vector v = split(spec, ':'); + if (parseInt(v[0]) < 60) { // Look for a prefix specifying the FFT type + fft_type = (enum FFT_TYPES) parseInt(v[0]); + for (u32 i = 1; i < v.size(); ++i) v[i-1] = v[i]; + v.resize(v.size() - 1); + } assert(v.size() == 3); - *this = FFTShape{v.at(0), 
v.at(1), v.at(2)}; + *this = FFTShape{fft_type, v.at(0), v.at(1), v.at(2)}; } -FFTShape::FFTShape(const string& w, const string& m, const string& h) : - FFTShape{parseInt(w), parseInt(m), parseInt(h)} +FFTShape::FFTShape(enum FFT_TYPES t, const string& w, const string& m, const string& h) : + FFTShape{t, parseInt(w), parseInt(m), parseInt(h)} {} -double FFTShape::carry32BPW() const { - // The formula below was validated empirically with -carryTune - - // We observe that FFT 6.5M (1024:13:256) has safe carry32 up to 18.35 BPW - // while the 0.5*log2() models the impact of FFT size changes. - // We model carry with a Gumbel distrib similar to the one used for ROE, and measure carry with - // -use STATS=1. See -carryTune - -//GW: I have no idea why this is needed. Without it, -tune fails on FFT sizes from 256K to 1M -// Perhaps it has something to do with RNDVALdoubleToLong in carryutil -if (18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())) > 19.0) return 19.0; - - return 18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())); -} - -bool FFTShape::needsLargeCarry(u32 E) const { - return E / double(size()) > carry32BPW(); +FFTShape::FFTShape(u32 w, u32 m, u32 h) : + FFTShape(FFT64, w, m, h) { } -FFTShape::FFTShape(u32 w, u32 m, u32 h) : - width{w}, middle{m}, height{h} { +FFTShape::FFTShape(enum FFT_TYPES t, u32 w, u32 m, u32 h) : + fft_type{t}, width{w}, middle{m}, height{h} { assert(w && m && h); // Un-initialized shape, don't set BPW @@ -154,17 +155,37 @@ FFTShape::FFTShape(u32 w, u32 m, u32 h) : while (w >= 4*h) { w /= 2; h *= 2; } while (w < h || w < 256 || w == 2048) { w *= 2; h /= 2; } while (h < 256) { h *= 2; m /= 2; } + if (m == 1) m = 2; bpw = FFTShape{w, m, h}.bpw; for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) bpw[j] -= 0.05; // Assume this fft spec is worse than measured fft specs if (this->isFavoredShape()) { // Don't output this warning message for non-favored shapes (we expect the BPW info to be missing) printf("BPW info for %s not found, defaults={", s.c_str()); - for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) printf("%s%.2f", j ? ", " : "", bpw[j]); - printf("}\n"); + for (u32 j = 0; j < NUM_BPW_ENTRIES; ++j) printf("%s%.2f", j ? ", " : "", (double) bpw[j]); + printf("}\n"); } } } } +float FFTShape::carry32BPW() const { + // The formula below was validated empirically with -carryTune + + // We observe that FFT 6.5M (1024:13:256) has safe carry32 up to 18.35 BPW + // while the 0.5*log2() models the impact of FFT size changes. + // We model carry with a Gumbel distrib similar to the one used for ROE, and measure carry with + // -use STATS=1. See -carryTune + +//GW: I have no idea why this is needed. Without it, -tune fails on FFT sizes from 256K to 1M +// Perhaps it has something to do with RNDVALdoubleToLong in carryutil +if (18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())) > 19.0) return 19.0; + + return 18.35 + 0.5 * (log2(13 * 1024 * 512) - log2(size())); +} + +bool FFTShape::needsLargeCarry(u32 E) const { + return E / double(size()) > carry32BPW(); +} + // Return TRUE for "favored" shapes. That is, those that are most likely to be useful. To save time in generating bpw data, only these favored // shapes have their bpw data pre-computed. Bpw for non-favored shapes is guessed from the bpw data we do have. Also. -tune will normally only // time favored shapes. 
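// Editorial worked example for carry32BPW() above (not part of the patch): for the reference
// shape 1024:13:256 the FFT size is 2*1024*13*256 = 13*1024*512, so the log2 term vanishes and
// the limit is exactly 18.35 bits/word. For a 1M FFT (size 2^20) the term is
// 0.5*(log2(13*1024*512) - 20), about 1.35, giving 19.70, which the early return clamps to 19.0;
// that clamp covers the same 256K to 1M range mentioned in the GW comment.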
These are the rules for deciding favored shapes: @@ -183,18 +204,24 @@ bool FFTShape::isFavoredShape() const { FFTConfig::FFTConfig(const string& spec) { auto v = split(spec, ':'); - // assert(v.size() == 1 || v.size() == 3 || v.size() == 4 || v.size() == 5); + + enum FFT_TYPES fft_type = FFT64; + if (parseInt(v[0]) < 60) { // Look for a prefix specifying the FFT type + fft_type = (enum FFT_TYPES) parseInt(v[0]); + for (u32 i = 1; i < v.size(); ++i) v[i-1] = v[i]; + v.resize(v.size() - 1); + } if (v.size() == 1) { *this = {FFTShape::multiSpec(spec).front(), LAST_VARIANT, CARRY_AUTO}; } if (v.size() == 3) { - *this = {FFTShape{v[0], v[1], v[2]}, LAST_VARIANT, CARRY_AUTO}; + *this = {FFTShape{fft_type, v[0], v[1], v[2]}, LAST_VARIANT, CARRY_AUTO}; } else if (v.size() == 4) { - *this = {FFTShape{v[0], v[1], v[2]}, parseInt(v[3]), CARRY_AUTO}; + *this = {FFTShape{fft_type, v[0], v[1], v[2]}, parseInt(v[3]), CARRY_AUTO}; } else if (v.size() == 5) { int c = parseInt(v[4]); assert(c == 0 || c == 1); - *this = {FFTShape{v[0], v[1], v[2]}, parseInt(v[3]), c == 0 ? CARRY_32 : CARRY_64}; + *this = {FFTShape{fft_type, v[0], v[1], v[2]}, parseInt(v[3]), c == 0 ? CARRY_32 : CARRY_64}; } else { throw "FFT spec"; } @@ -208,6 +235,17 @@ FFTConfig::FFTConfig(FFTShape shape, u32 variant, u32 carry) : assert(variant_W(variant) < N_VARIANT_W); assert(variant_M(variant) < N_VARIANT_M); assert(variant_H(variant) < N_VARIANT_H); + + if (shape.fft_type == FFT64) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT3161) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 1, WordSize = 8; + else if (shape.fft_type == FFT3261) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 1, WordSize = 8; + else if (shape.fft_type == FFT61) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 0, NTT_GF61 = 1, WordSize = 4; + else if (shape.fft_type == FFT323161) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 1, WordSize = 8; + else if (shape.fft_type == FFT3231) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT6431) FFT_FP64 = 1, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 8; + else if (shape.fft_type == FFT31) FFT_FP64 = 0, FFT_FP32 = 0, NTT_GF31 = 1, NTT_GF61 = 0, WordSize = 4; + else if (shape.fft_type == FFT32) FFT_FP64 = 0, FFT_FP32 = 1, NTT_GF31 = 0, NTT_GF61 = 0, WordSize = 4; + else throw "FFT type"; } string FFTConfig::spec() const { @@ -215,8 +253,8 @@ string FFTConfig::spec() const { return carry == CARRY_AUTO ? s : (s + (carry == CARRY_32 ? ":0" : ":1")); } -double FFTConfig::maxBpw() const { - double b; +float FFTConfig::maxBpw() const { + float b; // Look up the pre-computed maximum bpw. The lookup table contains data for variants 000, 101, 202, 010, 111, 212. // For 4K width, the lookup table contains data for variants 100, 101, 202, 110, 111, 212 since BCAST only works for width <= 1024. if (variant_W(variant) == variant_H(variant) || @@ -225,11 +263,12 @@ double FFTConfig::maxBpw() const { } // Interpolate for the maximum bpw. This might could be improved upon. However, I doubt people will use these variants often. else { - double b1 = shape.bpw[variant_M(variant) * 3 + variant_W(variant)]; - double b2 = shape.bpw[variant_M(variant) * 3 + variant_H(variant)]; + float b1 = shape.bpw[variant_M(variant) * 3 + variant_W(variant)]; + float b2 = shape.bpw[variant_M(variant) * 3 + variant_H(variant)]; b = (b1 + b2) / 2.0; } - return carry == CARRY_32 ? 
std::min(shape.carry32BPW(), b) : b; + // Only some FFTs support both 32 and 64 bit carries. + return (carry == CARRY_32 && (shape.fft_type == FFT64 || shape.fft_type == FFT3231)) ? std::min(shape.carry32BPW(), b) : b; } FFTConfig FFTConfig::bestFit(const Args& args, u32 E, const string& spec) { @@ -237,7 +276,7 @@ FFTConfig FFTConfig::bestFit(const Args& args, u32 E, const string& spec) { if (!spec.empty()) { FFTConfig fft{spec}; if (fft.maxExp() * args.fftOverdrive < E) { - log("Warning: %s (max %u) may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); + log("Warning: %s (max %" PRIu64 ") may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); } return fft; } @@ -263,7 +302,7 @@ FFTConfig FFTConfig::bestFit(const Args& args, u32 E, const string& spec) { } -string numberK(u32 n) { +string numberK(u64 n) { u32 K = 1024; u32 M = K * K; diff --git a/src/FFTConfig.h b/src/FFTConfig.h index dc65cad0..c873eb8c 100644 --- a/src/FFTConfig.h +++ b/src/FFTConfig.h @@ -17,37 +17,40 @@ class Args; // Format 'n' with a K or M suffix if multiple of 1024 or 1024*1024 -string numberK(u32 n); +string numberK(u64 n); using KeyVal = std::pair; +enum FFT_TYPES {FFT64=0, FFT3161=1, FFT3261=2, FFT61=3, FFT323161=4, FFT3231=50, FFT6431=51, FFT31=52, FFT32=53}; + class FFTShape { public: - static constexpr const float MIN_BPW = 3; - static std::vector allShapes(u32 from=0, u32 to = -1); static tuple getChainLengths(u32 fftSize, u32 exponent, u32 middle); static vector multiSpec(const string& spec); + enum FFT_TYPES fft_type; u32 width = 0; u32 middle = 0; u32 height = 0; - array bpw; + array bpw; FFTShape(u32 w = 1, u32 m = 1, u32 h = 1); - FFTShape(const string& w, const string& m, const string& h); + FFTShape(enum FFT_TYPES t, u32 w, u32 m, u32 h); + FFTShape(enum FFT_TYPES t, const string& w, const string& m, const string& h); explicit FFTShape(const string& spec); u32 size() const { return width * height * middle * 2; } u32 nW() const { return (width == 1024 || width == 256 /*|| width == 4096*/) ? 4 : 8; } u32 nH() const { return (height == 1024 || height == 256 /*|| height == 4096*/) ? 4 : 8; } - double maxBpw() const { return *max_element(bpw.begin(), bpw.end()); } - std::string spec() const { return numberK(width) + ':' + numberK(middle) + ':' + numberK(height); } + float minBpw() const { return fft_type != FFT32 ? 3.0f : 1.0f; } + float maxBpw() const { return *max_element(bpw.begin(), bpw.end()); } + std::string spec() const { return (fft_type ? to_string(fft_type) + ':' : "") + numberK(width) + ':' + numberK(middle) + ':' + numberK(height); } - double carry32BPW() const; + float carry32BPW() const; bool needsLargeCarry(u32 E) const; bool isFavoredShape() const; }; @@ -66,12 +69,22 @@ inline u32 next_variant(u32 v) { new_v = (v / 100 + 1) * 100; return (new_v); } -enum CARRY_KIND { CARRY_32=0, CARRY_64=1, CARRY_AUTO=2}; +enum CARRY_KIND {CARRY_32=0, CARRY_64=1, CARRY_AUTO=2}; struct FFTConfig { public: static FFTConfig bestFit(const Args& args, u32 E, const std::string& spec); + // Which FP and NTT primes are involved in the FFT + bool FFT_FP64; + bool FFT_FP32; + bool NTT_GF31; + bool NTT_GF61; + // bool NTT_NCW; // Nick Craig-Wood prime not supported (yet?) 
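// Editorial summary (not part of the patch) of how a spec prefix maps onto the flags that
// FFTConfig's constructor sets from this enum; GF31/GF61 presumably denote NTTs modulo the
// Mersenne primes 2^31-1 and 2^61-1:
//   "-fft 1K:13:256"   -> FFT64     FP64 FFT only                 4-byte words
//   "-fft 1:1K:13:256" -> FFT3161   GF31 + GF61 NTTs              8-byte words
//   "-fft 2:1K:13:256" -> FFT3261   FP32 FFT + GF61 NTT           8-byte words
//   "-fft 3:1K:13:256" -> FFT61     GF61 NTT only                 4-byte words
//   "-fft 4:1K:13:256" -> FFT323161 FP32 FFT + GF31 + GF61 NTTs   8-byte words
// A leading integer below 60 is read as an FFT_TYPES value; with no prefix the spec defaults
// to FFT64. Types numbered 50 and up exist here but are not generated by FFTShape::allShapes().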
+ + // Size (in bytes) of integer data passed to FFTs/NTTs on the GPU + u32 WordSize; + FFTShape shape{}; u32 variant; u32 carry; @@ -80,8 +93,9 @@ struct FFTConfig { FFTConfig(FFTShape shape, u32 variant, u32 carry); std::string spec() const; - u32 size() const { return shape.size(); } - u32 maxExp() const { return maxBpw() * shape.size(); } + u64 size() const { return shape.size(); } + u64 maxExp() const { return maxBpw() * shape.size(); } - double maxBpw() const; + float minBpw() const { return shape.minBpw(); } + float maxBpw() const; }; diff --git a/src/File.h b/src/File.h index 6049246d..1e1766e4 100644 --- a/src/File.h +++ b/src/File.h @@ -60,11 +60,17 @@ class File { _commit(fileno(f)); #elif defined(__APPLE__) fcntl(fileno(f), F_FULLFSYNC, 0); +#elif defined(__MINGW32__) || defined(__MINGW64__) + fdatasync(fileno(f)); +#elif defined(__MSYS__) // MSYS2 using CLANG64 compiler +#define fileno(__F) ((__F)->_file) + fsync(fileno(f)); // This doesn't work +#undef fileno #else fdatasync(fileno(f)); #endif } - + public: const std::string name; diff --git a/src/Gpu.cpp b/src/Gpu.cpp index 3b0f980b..44656fa1 100644 --- a/src/Gpu.cpp +++ b/src/Gpu.cpp @@ -23,6 +23,8 @@ #include #include #include +#include +#include #define _USE_MATH_DEFINES #include @@ -35,6 +37,8 @@ #define M_PI 3.141592653589793238462643383279502884 #endif +#define CARRY_LEN 8 + namespace { u32 kAt(u32 H, u32 line, u32 col) { return (line + col * H) * 2; } @@ -57,83 +61,118 @@ double invWeightM1(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { double boundUnderOne(double x) { return std::min(x, nexttoward(1, 0)); } -#define CARRY_LEN 8 +float weight32(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { + return exp2((double)(extra(N, E, kAt(H, line, col) + rep)) / N); +} + +float invWeight32(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { + return exp2(-(double)(extra(N, E, kAt(H, line, col) + rep)) / N); +} -Weights genWeights(u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { +float weightM132(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { + return exp2((double)(extra(N, E, kAt(H, line, col) + rep)) / N) - 1; +} + +float invWeightM132(u32 N, u32 E, u32 H, u32 line, u32 col, u32 rep) { + return exp2(- (double)(extra(N, E, kAt(H, line, col) + rep)) / N) - 1; +} + +float boundUnderOne(float x) { return std::min(x, nexttowardf(1, 0)); } + +Weights genWeights(FFTConfig fft, u32 E, u32 W, u32 H, u32 nW, bool AmdGpu) { u32 N = 2u * W * H; - u32 groupWidth = W / nW; - // Inverse + Forward vector weightsConstIF; vector weightsIF; - for (u32 thread = 0; thread < groupWidth; ++thread) { - auto iw = invWeight(N, E, H, 0, thread, 0); - auto w = weight(N, E, H, 0, thread, 0); - // nVidia GPUs have a constant cache that only works on buffer sizes less than 64KB. Create a smaller buffer - // that is a copy of the first part of weightsIF. There are several kernels that need the combined weightsIF - // buffer, so there is an unfortunate duplication of these weights. - if (!AmdGpu) { - weightsConstIF.push_back(2 * boundUnderOne(iw)); - weightsConstIF.push_back(2 * w); + vector bits; + + if (fft.FFT_FP64) { + // Inverse + Forward + for (u32 thread = 0; thread < groupWidth; ++thread) { + auto iw = invWeight(N, E, H, 0, thread, 0); + auto w = weight(N, E, H, 0, thread, 0); + // nVidia GPUs have a constant cache that only works on buffer sizes less than 64KB. Create a smaller buffer + // that is a copy of the first part of weightsIF. 
There are several kernels that need the combined weightsIF + // buffer, so there is an unfortunate duplication of these weights. + if (!AmdGpu) { + weightsConstIF.push_back(2 * boundUnderOne(iw)); + weightsConstIF.push_back(2 * w); + } + weightsIF.push_back(2 * boundUnderOne(iw)); + weightsIF.push_back(2 * w); } - weightsIF.push_back(2 * boundUnderOne(iw)); - weightsIF.push_back(2 * w); - } - // the group order matches CarryA/M (not fftP/CarryFused). - for (u32 gy = 0; gy < H; ++gy) { - weightsIF.push_back(invWeightM1(N, E, H, gy, 0, 0)); - weightsIF.push_back(weightM1(N, E, H, gy, 0, 0)); + // the group order matches CarryA/M (not fftP/CarryFused). + for (u32 gy = 0; gy < H; ++gy) { + weightsIF.push_back(invWeightM1(N, E, H, gy, 0, 0)); + weightsIF.push_back(weightM1(N, E, H, gy, 0, 0)); + } } - - vector bits; - - for (u32 line = 0; line < H; ++line) { - for (u32 thread = 0; thread < groupWidth; ) { - std::bitset<32> b; - for (u32 bitoffset = 0; bitoffset < 32; bitoffset += nW*2, ++thread) { - for (u32 block = 0; block < nW; ++block) { - for (u32 rep = 0; rep < 2; ++rep) { - if (isBigWord(N, E, kAt(H, line, block * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } - } - } + + else if (fft.FFT_FP32) { + vector weightsConstIF32; + vector weightsIF32; + // Inverse + Forward + for (u32 thread = 0; thread < groupWidth; ++thread) { + auto iw = invWeight32(N, E, H, 0, thread, 0); + auto w = weight32(N, E, H, 0, thread, 0); + // nVidia GPUs have a constant cache that only works on buffer sizes less than 64KB. Create a smaller buffer + // that is a copy of the first part of weightsIF. There are several kernels that need the combined weightsIF + // buffer, so there is an unfortunate duplication of these weights. + if (!AmdGpu) { + weightsConstIF32.push_back(2 * boundUnderOne(iw)); + weightsConstIF32.push_back(2 * w); } - bits.push_back(b.to_ulong()); + weightsIF32.push_back(2 * boundUnderOne(iw)); + weightsIF32.push_back(2 * w); + } + + // the group order matches CarryA/M (not fftP/CarryFused). 
+ for (u32 gy = 0; gy < H; ++gy) { + weightsIF32.push_back(invWeightM132(N, E, H, gy, 0, 0)); + weightsIF32.push_back(weightM132(N, E, H, gy, 0, 0)); } + + // Copy the float vectors to the double vectors + weightsConstIF.resize(weightsConstIF32.size() / 2); + memcpy((double *) weightsConstIF.data(), weightsConstIF32.data(), weightsConstIF32.size() * sizeof(float)); + weightsIF.resize(weightsIF32.size() / 2); + memcpy((double *) weightsIF.data(), weightsIF32.data(), weightsIF32.size() * sizeof(float)); } - assert(bits.size() == N / 32); - vector bitsC; - - for (u32 gy = 0; gy < H / CARRY_LEN; ++gy) { - for (u32 gx = 0; gx < nW; ++gx) { + if (fft.FFT_FP64 || fft.FFT_FP32) { + for (u32 line = 0; line < H; ++line) { for (u32 thread = 0; thread < groupWidth; ) { std::bitset<32> b; - for (u32 bitoffset = 0; bitoffset < 32; bitoffset += CARRY_LEN * 2, ++thread) { - for (u32 block = 0; block < CARRY_LEN; ++block) { + for (u32 bitoffset = 0; bitoffset < 32; bitoffset += nW*2, ++thread) { + for (u32 block = 0; block < nW; ++block) { for (u32 rep = 0; rep < 2; ++rep) { - if (isBigWord(N, E, kAt(H, gy * CARRY_LEN + block, gx * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } - } + if (isBigWord(N, E, kAt(H, line, block * groupWidth + thread) + rep)) { b.set(bitoffset + block * 2 + rep); } + } } } - bitsC.push_back(b.to_ulong()); + bits.push_back(b.to_ulong()); } } + assert(bits.size() == N / 32); } - assert(bitsC.size() == N / 32); - return Weights{weightsConstIF, weightsIF, bits, bitsC}; + return Weights{weightsConstIF, weightsIF, bits}; } -string toLiteral(u32 value) { return to_string(value) + 'u'; } string toLiteral(i32 value) { return to_string(value); } -[[maybe_unused]] string toLiteral(u64 value) { return to_string(value) + "ul"; } +string toLiteral(u32 value) { return to_string(value) + 'u'; } +[[maybe_unused]] string toLiteral(long value) { return to_string(value) + "l"; } +[[maybe_unused]] string toLiteral(unsigned long value) { return to_string(value) + "ul"; } +[[maybe_unused]] string toLiteral(long long value) { return to_string(value) + "l"; } // Yes, this looks wrong. The Mingw64 C compiler uses +[[maybe_unused]] string toLiteral(unsigned long long value) { return to_string(value) + "ul"; } // long long for 64-bits, while openCL uses long for 64 bits. 
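// Editorial illustration (outputs inferred from the overloads above, not captured from a run):
//   toLiteral(5u)                -> "5u"
//   toLiteral((unsigned long) 7) -> "7ul"    OpenCL treats 'long' as the 64-bit type
//   toLiteral((long long) 7)     -> "7l"     hence the deliberately "wrong" suffix noted above
//   toLiteral(0.25f)             -> "0.25f"  the template below appends 'f' to 4-byte floats
//   toDefine("PAD", 256u)        -> " -DPAD=256u"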
template string toLiteral(F value) { std::ostringstream ss; ss << std::setprecision(numeric_limits::max_digits10) << value; + if (sizeof(F) == 4) ss << "f"; string s = std::move(ss).str(); // verify exact roundtrip @@ -167,7 +206,11 @@ string toLiteral(const std::array& v) { string toLiteral(const string& s) { return s; } -string toLiteral(double2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(float2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(double2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(int2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(uint2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } +[[maybe_unused]] string toLiteral(ulong2 cs) { return "U2("s + toLiteral(cs.first) + ',' + toLiteral(cs.second) + ')'; } template string toDefine(const string& k, T v) { return " -D"s + k + '=' + toLiteral(v); } @@ -185,7 +228,7 @@ constexpr bool isInList(const string& s, initializer_list list) { } string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector& extraConf, u32 E, bool doLog, - bool &tail_single_wide, bool &tail_single_kernel, u32 &tail_trigs, u32 &pad_size) { + bool &tail_single_wide, bool &tail_single_kernel, u32 &in_place, u32 &pad_size) { map config; // Highest priority is the requested "extra" conf @@ -202,7 +245,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< // Default value for -use options that must also be parsed in C++ code tail_single_wide = 0, tail_single_kernel = 1; // Default tailSquare is double-wide in one kernel - tail_trigs = 2; // Default is calculating from scratch, no memory accesses + in_place = 0; // Default is not in-place pad_size = isAmdGpu(id) ? 
256 : 0; // Default is 256 bytes for AMD, 0 for others // Validate -use options @@ -223,12 +266,20 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< "CARRY64", "BIGLIT", "NONTEMPORAL", + "INPLACE", "PAD", "MIDDLE_IN_LDS_TRANSPOSE", "MIDDLE_OUT_LDS_TRANSPOSE", "TAIL_KERNELS", "TAIL_TRIGS", - "TABMUL_CHAIN" + "TAIL_TRIGS31", + "TAIL_TRIGS32", + "TAIL_TRIGS61", + "TABMUL_CHAIN", + "TABMUL_CHAIN31", + "TABMUL_CHAIN32", + "TABMUL_CHAIN61", + "MODM31" }); if (!isValid) { log("Warning: unrecognized -use key '%s'\n", k.c_str()); @@ -241,7 +292,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< if (atoi(v.c_str()) == 2) tail_single_wide = 0, tail_single_kernel = 1; if (atoi(v.c_str()) == 3) tail_single_wide = 0, tail_single_kernel = 0; } - if (k == "TAIL_TRIGS") tail_trigs = atoi(v.c_str()); + if (k == "INPLACE") in_place = atoi(v.c_str()); if (k == "PAD") pad_size = atoi(v.c_str()); } @@ -259,6 +310,7 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< }); if (isAmdGpu(id)) { defines += toDefine("AMDGPU", 1); } + if (isNvidiaGpu(id)) { defines += toDefine("NVIDIAGPU", 1); } if ((fft.carry == CARRY_AUTO && fft.shape.needsLargeCarry(E)) || (fft.carry == CARRY_64)) { if (doLog) { log("Using CARRY64\n"); } @@ -266,16 +318,93 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< } u32 N = fft.shape.size(); - - defines += toDefine("WEIGHT_STEP", weightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); - defines += toDefine("IWEIGHT_STEP", invWeightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); defines += toDefine("FFT_VARIANT", fft.variant); - defines += toDefine("TAILT", root1Fancy(fft.shape.height * 2, 1)); + defines += toDefine("MAXBPW", (u32)(fft.maxBpw() * 100.0f)); + + if (fft.FFT_FP64 | fft.FFT_FP32) { + defines += toDefine("WEIGHT_STEP", weightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); + defines += toDefine("IWEIGHT_STEP", invWeightM1(N, E, fft.shape.height * fft.shape.middle, 0, 0, 1)); + if (fft.FFT_FP64) defines += toDefine("TAILT", root1Fancy(fft.shape.height * 2, 1)); + else defines += toDefine("TAILT", root1FancyFP32(fft.shape.height * 2, 1)); + + TrigCoefs coefs = trigCoefs(fft.shape.size() / 4); + defines += toDefine("TRIG_SCALE", int(coefs.scale)); + defines += toDefine("TRIG_SIN", coefs.sinCoefs); + defines += toDefine("TRIG_COS", coefs.cosCoefs); + } + if (fft.NTT_GF31) { + defines += toDefine("TAILTGF31", root1GF31(fft.shape.height * 2, 1)); + } + if (fft.NTT_GF61) { + defines += toDefine("TAILTGF61", root1GF61(fft.shape.height * 2, 1)); + } - TrigCoefs coefs = trigCoefs(fft.shape.size() / 4); - defines += toDefine("TRIG_SCALE", int(coefs.scale)); - defines += toDefine("TRIG_SIN", coefs.sinCoefs); - defines += toDefine("TRIG_COS", coefs.cosCoefs); + // Send the FFT/NTT type and booleans that enable/disable code for each possible FP and NTT + defines += toDefine("FFT_TYPE", (int) fft.shape.fft_type); + defines += toDefine("FFT_FP64", (int) fft.FFT_FP64); + defines += toDefine("FFT_FP32", (int) fft.FFT_FP32); + defines += toDefine("NTT_GF31", (int) fft.NTT_GF31); + defines += toDefine("NTT_GF61", (int) fft.NTT_GF61); + defines += toDefine("WordSize", fft.WordSize); + + // When using multiple NTT primes or hybrid FFT/NTT, each FFT/NTT prime's data buffer and trig values are combined into one buffer. + // The openCL code needs to know the offset to the data and trig values. Distances are in "number of double2 values". 
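// Editorial layout sketch, inferred from the DIST* defines computed below (not part of the patch):
// each active component is packed back to back in one combined buffer, e.g. for the
// FP32 + GF31 + GF61 hybrid the data buffer looks like
//   [ FP32 data ][ GF31 data ][ GF61 data ]
//    offset 0     DISTGF31     DISTGF61 = DISTGF31 + GF31_DATA_SIZE(...) / 2
// and the width/middle/height trig tables follow the same pattern through the DISTWTRIG*,
// DISTMTRIG* and DISTHTRIG* defines. Offsets are counted in double2 units, which is presumably
// why the *_DATA_SIZE values are halved before use.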
+ if (fft.FFT_FP64 && fft.NTT_GF31) { + // GF31 data is located after the FP64 data. Compute size of the FP64 data and trigs. + defines += toDefine("DISTGF31", FP64_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); + defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP64_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.FFT_FP32 && fft.NTT_GF31 && fft.NTT_GF61) { + // GF31 and GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. + u32 sz1, sz2, sz3, sz4; + defines += toDefine("DISTGF31", sz1 = FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); + defines += toDefine("DISTWTRIGGF31", sz2 = SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF31", sz3 = MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF31", sz4 = SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTGF61", sz1 + GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); + defines += toDefine("DISTWTRIGGF61", sz2 + SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF61", sz3 + MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF61", sz4 + SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.FFT_FP32 && fft.NTT_GF31) { + // GF31 data is located after the FP32 data. Compute size of the FP32 data and trigs. + defines += toDefine("DISTGF31", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); + defines += toDefine("DISTWTRIGGF31", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF31", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF31", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.FFT_FP32 && fft.NTT_GF61) { + // GF61 data is located after the FP32 data. Compute size of the FP32 data and trigs. + defines += toDefine("DISTGF61", FP32_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); + defines += toDefine("DISTWTRIGGF61", SMALLTRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_FP32_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.NTT_GF31 && fft.NTT_GF61) { + defines += toDefine("DISTGF31", 0); + defines += toDefine("DISTWTRIGGF31", 0); + defines += toDefine("DISTMTRIGGF31", 0); + defines += toDefine("DISTHTRIGGF31", 0); + // GF61 data is located after the GF31 data. Compute size of the GF31 data and trigs. 
+ defines += toDefine("DISTGF61", GF31_DATA_SIZE(fft.shape.width, fft.shape.middle, fft.shape.height, in_place, pad_size) / 2); + defines += toDefine("DISTWTRIGGF61", SMALLTRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + defines += toDefine("DISTMTRIGGF61", MIDDLETRIG_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height)); + defines += toDefine("DISTHTRIGGF61", SMALLTRIGCOMBO_GF31_DIST(fft.shape.width, fft.shape.middle, fft.shape.height, fft.shape.nH())); + } + else if (fft.NTT_GF31) { + defines += toDefine("DISTGF31", 0); + defines += toDefine("DISTWTRIGGF31", 0); + defines += toDefine("DISTMTRIGGF31", 0); + defines += toDefine("DISTHTRIGGF31", 0); + } + else if (fft.NTT_GF61) { + defines += toDefine("DISTGF61", 0); + defines += toDefine("DISTWTRIGGF61", 0); + defines += toDefine("DISTMTRIGGF61", 0); + defines += toDefine("DISTHTRIGGF61", 0); + } // Calculate fractional bits-per-word = (E % N) / N * 2^64 u32 bpw_hi = (u64(E % N) << 32) / N; @@ -284,8 +413,6 @@ string clDefines(const Args& args, cl_device_id id, FFTConfig fft, const vector< bpw--; // bpw must not be an exact value -- it must be less than exact value to get last biglit value right defines += toDefine("FRAC_BPW_HI", (u32) (bpw >> 32)); defines += toDefine("FRAC_BPW_LO", (u32) bpw); - u32 bigstep = (bpw * (N / fft.shape.nW())) >> 32; - defines += toDefine("FRAC_BITS_BIGSTEP", bigstep); return defines; } @@ -397,92 +524,107 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& args{*shared.args}, E(E), N(fft.shape.size()), + fft(fft), WIDTH(fft.shape.width), SMALL_H(fft.shape.height), BIG_H(SMALL_H * fft.shape.middle), hN(N / 2), nW(fft.shape.nW()), nH(fft.shape.nH()), - bufSize(N * sizeof(double)), useLongCarry{args.carry == Args::CARRY_LONG}, - compiler{args, queue->context, clDefines(args, queue->context->deviceId(), fft, extraConf, E, logFftSize, - tail_single_wide, tail_single_kernel, tail_trigs, pad_size)}, + compiler{args, queue->context, clDefines(args, queue->context->deviceId(), fft, extraConf, E, logFftSize, tail_single_wide, tail_single_kernel, in_place, pad_size)}, #define K(name, ...) name(#name, &compiler, profile.make(#name), queue, __VA_ARGS__) - // W / nW - K(kCarryFused, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW), - K(kCarryFusedROE, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DROE=1"), - - K(kCarryFusedMul, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DMUL3=1"), - K(kCarryFusedMulROE, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DMUL3=1 -DROE=1"), - - K(kCarryFusedLL, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DLL=1"), - - K(kCarryA, "carry.cl", "carry", hN / CARRY_LEN), - K(kCarryAROE, "carry.cl", "carry", hN / CARRY_LEN, "-DROE=1"), - - K(kCarryM, "carry.cl", "carry", hN / CARRY_LEN, "-DMUL3=1"), - K(kCarryMROE, "carry.cl", "carry", hN / CARRY_LEN, "-DMUL3=1 -DROE=1"), - - K(kCarryLL, "carry.cl", "carry", hN / CARRY_LEN, "-DLL=1"), - K(carryB, "carryb.cl", "carryB", hN / CARRY_LEN), - - K(fftP, "fftp.cl", "fftP", hN / nW), - K(fftW, "fftw.cl", "fftW", hN / nW), - - // SMALL_H / nH - K(fftHin, "ffthin.cl", "fftHin", hN / nH), - K(tailSquareZero, "tailsquare.cl", "tailSquareZero", SMALL_H / nH * 2), - K(tailSquare, "tailsquare.cl", "tailSquare", !tail_single_wide && !tail_single_kernel ? 
hN / nH - SMALL_H / nH * 2 : // Double-wide tailSquare with two kernels + K(kfftMidIn, "fftmiddlein.cl", "fftMiddleIn", hN / (BIG_H / SMALL_H)), + K(kfftHin, "ffthin.cl", "fftHin", hN / nH), + K(ktailSquareZero, "tailsquare.cl", "tailSquareZero", SMALL_H / nH * 2), + K(ktailSquare, "tailsquare.cl", "tailSquare", + !tail_single_wide && !tail_single_kernel ? hN / nH - SMALL_H / nH * 2 : // Double-wide tailSquare with two kernels !tail_single_wide ? hN / nH : // Double-wide tailSquare with one kernel !tail_single_kernel ? hN / nH / 2 - SMALL_H / nH : // Single-wide tailSquare with two kernels hN / nH / 2), // Single-wide tailSquare with one kernel + K(ktailMul, "tailmul.cl", "tailMul", hN / nH / 2), + K(ktailMulLow, "tailmul.cl", "tailMul", hN / nH / 2, "-DMUL_LOW=1"), + K(kfftMidOut, "fftmiddleout.cl", "fftMiddleOut", hN / (BIG_H / SMALL_H)), + K(kfftW, "fftw.cl", "fftW", hN / nW), + + K(kfftMidInGF31, "fftmiddlein.cl", "fftMiddleInGF31", hN / (BIG_H / SMALL_H)), + K(kfftHinGF31, "ffthin.cl", "fftHinGF31", hN / nH), + K(ktailSquareZeroGF31, "tailsquare.cl", "tailSquareZeroGF31", SMALL_H / nH * 2), + K(ktailSquareGF31, "tailsquare.cl", "tailSquareGF31", + !tail_single_wide && !tail_single_kernel ? hN / nH - SMALL_H / nH * 2 : // Double-wide tailSquare with two kernels + !tail_single_wide ? hN / nH : // Double-wide tailSquare with one kernel + !tail_single_kernel ? hN / nH / 2 - SMALL_H / nH : // Single-wide tailSquare with two kernels + hN / nH / 2), // Single-wide tailSquare with one kernel + K(ktailMulGF31, "tailmul.cl", "tailMulGF31", hN / nH / 2), + K(ktailMulLowGF31, "tailmul.cl", "tailMulGF31", hN / nH / 2, "-DMUL_LOW=1"), + K(kfftMidOutGF31, "fftmiddleout.cl", "fftMiddleOutGF31", hN / (BIG_H / SMALL_H)), + K(kfftWGF31, "fftw.cl", "fftWGF31", hN / nW), + + K(kfftMidInGF61, "fftmiddlein.cl", "fftMiddleInGF61", hN / (BIG_H / SMALL_H)), + K(kfftHinGF61, "ffthin.cl", "fftHinGF61", hN / nH), + K(ktailSquareZeroGF61, "tailsquare.cl", "tailSquareZeroGF61", SMALL_H / nH * 2), + K(ktailSquareGF61, "tailsquare.cl", "tailSquareGF61", + !tail_single_wide && !tail_single_kernel ? hN / nH - SMALL_H / nH * 2 : // Double-wide tailSquare with two kernels + !tail_single_wide ? hN / nH : // Double-wide tailSquare with one kernel + !tail_single_kernel ? 
hN / nH / 2 - SMALL_H / nH : // Single-wide tailSquare with two kernels + hN / nH / 2), // Single-wide tailSquare with one kernel + K(ktailMulGF61, "tailmul.cl", "tailMulGF61", hN / nH / 2), + K(ktailMulLowGF61, "tailmul.cl", "tailMulGF61", hN / nH / 2, "-DMUL_LOW=1"), + K(kfftMidOutGF61, "fftmiddleout.cl", "fftMiddleOutGF61", hN / (BIG_H / SMALL_H)), + K(kfftWGF61, "fftw.cl", "fftWGF61", hN / nW), + + K(kfftP, "fftp.cl", "fftP", hN / nW), + K(kCarryA, "carry.cl", "carry", hN / CARRY_LEN), + K(kCarryAROE, "carry.cl", "carry", hN / CARRY_LEN, "-DROE=1"), + K(kCarryM, "carry.cl", "carry", hN / CARRY_LEN, "-DMUL3=1"), + K(kCarryMROE, "carry.cl", "carry", hN / CARRY_LEN, "-DMUL3=1 -DROE=1"), + K(kCarryLL, "carry.cl", "carry", hN / CARRY_LEN, "-DLL=1"), + K(kCarryFused, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW), + K(kCarryFusedROE, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DROE=1"), + K(kCarryFusedMul, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DMUL3=1"), + K(kCarryFusedMulROE, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DMUL3=1 -DROE=1"), + K(kCarryFusedLL, "carryfused.cl", "carryFused", WIDTH * (BIG_H + 1) / nW, "-DLL=1"), + + K(carryB, "carryb.cl", "carryB", hN / CARRY_LEN), - K(tailMul, "tailmul.cl", "tailMul", hN / nH / 2), - K(tailMulLow, "tailmul.cl", "tailMul", hN / nH / 2, "-DMUL_LOW=1"), - - // 256 - K(fftMidIn, "fftmiddlein.cl", "fftMiddleIn", hN / (BIG_H / SMALL_H)), - K(fftMidOut, "fftmiddleout.cl", "fftMiddleOut", hN / (BIG_H / SMALL_H)), - // 64 K(transpIn, "transpose.cl", "transposeIn", hN / 64), K(transpOut, "transpose.cl", "transposeOut", hN / 64), - + K(readResidue, "etc.cl", "readResidue", 32, "-DREADRESIDUE=1"), // 256 K(kernIsEqual, "etc.cl", "isEqual", 256 * 256, "-DISEQUAL=1"), K(sum64, "etc.cl", "sum64", 256 * 256, "-DSUM64=1"), + K(testTrig, "selftest.cl", "testTrig", 256 * 256), K(testFFT4, "selftest.cl", "testFFT4", 256), - K(testFFT, "selftest.cl", "testFFT", 256), - K(testFFT15, "selftest.cl", "testFFT15", 256), K(testFFT14, "selftest.cl", "testFFT14", 256), + K(testFFT15, "selftest.cl", "testFFT15", 256), + K(testFFT, "selftest.cl", "testFFT", 256), K(testTime, "selftest.cl", "testTime", 4096 * 64), -#undef K - bufTrigW{shared.bufCache->smallTrig(WIDTH, nW)}, - bufTrigH{shared.bufCache->smallTrigCombo(WIDTH, fft.shape.middle, SMALL_H, nH, fft.variant, tail_single_wide, tail_trigs)}, - bufTrigM{shared.bufCache->middleTrig(SMALL_H, BIG_H / SMALL_H, WIDTH)}, +#undef K - weights{genWeights(E, WIDTH, BIG_H, nW, isAmdGpu(q->context->deviceId()))}, + bufTrigH{shared.bufCache->smallTrigCombo(shared.args, fft, WIDTH, fft.shape.middle, SMALL_H, nH, tail_single_wide)}, + bufTrigM{shared.bufCache->middleTrig(shared.args, fft, SMALL_H, BIG_H / SMALL_H, WIDTH)}, + bufTrigW{shared.bufCache->smallTrig(shared.args, fft, WIDTH, nW, fft.shape.middle, SMALL_H, nH, tail_single_wide)}, + weights{genWeights(fft, E, WIDTH, BIG_H, nW, isAmdGpu(q->context->deviceId()))}, bufConstWeights{q->context, std::move(weights.weightsConstIF)}, bufWeights{q->context, std::move(weights.weightsIF)}, bufBits{q->context, std::move(weights.bitsCF)}, - bufBitsC{q->context, std::move(weights.bitsC)}, #define BUF(name, ...) name{profile.make(#name), queue, __VA_ARGS__} - BUF(bufData, N), - BUF(bufAux, N), - - BUF(bufCheck, N), - BUF(bufBase, N), + // GPU Buffers containing integer data. Since this buffer is type i64, if fft.WordSize < 8 then we need less memory allocated. 
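// Editorial note: with Word being the 8-byte i64 type, "N * fft.WordSize / sizeof(Word)" yields
// N/2 elements (N*4 bytes) when WordSize == 4 and N elements (N*8 bytes) when WordSize == 8,
// i.e. these buffers always hold N * WordSize bytes of integer data.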
+ BUF(bufData, N * fft.WordSize / sizeof(Word)), + BUF(bufAux, N * fft.WordSize / sizeof(Word)), + BUF(bufCheck, N * fft.WordSize / sizeof(Word)), // Every double-word (i.e. N/2) produces one carry. In addition we may have one extra group thus WIDTH more carries. - BUF(bufCarry, N / 2 + WIDTH), + BUF(bufCarry, N / 2 + WIDTH), BUF(bufReady, (N / 2 + WIDTH) / 32), // Every wavefront (32 or 64 lanes) needs to signal "carry is ready" BUF(bufSmallOut, 256), @@ -491,12 +633,9 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& BUF(bufROE, ROE_SIZE), BUF(bufStatsCarry, CARRY_SIZE), - // Allocate extra for padding. We can probably tighten up the amount of extra memory allocated. - // The worst case seems to be MIDDLE=4, PAD_SIZE=512 - #define total_padding (((pad_size == 0 ? 0 : (pad_size <= 128 ? N/8 : (pad_size <= 256 ? N/4 : N/2)))) * (fft.shape.middle == 4 ? 5 : 4) / 4) - BUF(buf1, N + total_padding), - BUF(buf2, N + total_padding), - BUF(buf3, N + total_padding), + BUF(buf1, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, in_place, pad_size)), + BUF(buf2, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, in_place, pad_size)), + BUF(buf3, TOTAL_DATA_SIZE(fft, WIDTH, fft.shape.middle, SMALL_H, in_place, pad_size)), #undef BUF statsBits{u32(args.value("STATS", 0))}, @@ -510,45 +649,76 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& // Sometimes we do want to run a FFT beyond a reasonable BPW (e.g. during -ztune), and these situations // coincide with logFftSize == false if (fft.maxExp() < E) { - log("Warning: %s (max %u) may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); + log("Warning: %s (max %" PRIu64 ") may be too small for %u\n", fft.spec().c_str(), fft.maxExp(), E); } } - if (bitsPerWord < FFTShape::MIN_BPW) { - log("FFT size too large for exponent (%.2f bits/word < %.2f bits/word).\n", bitsPerWord, FFTShape::MIN_BPW); + if (bitsPerWord < fft.minBpw()) { + log("FFT size too large for exponent (%.2f bits/word < %.2f bits/word).\n", bitsPerWord, fft.minBpw()); throw "FFT size too large"; } - useLongCarry = useLongCarry || (bitsPerWord < 12.0); + useLongCarry = useLongCarry || (bitsPerWord < 10.0); if (useLongCarry) { log("Using long carry!\n"); } - - for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { - k->setFixedArgs(3, bufCarry, bufReady, bufTrigW, bufBits, bufConstWeights, bufWeights); + + if (fft.FFT_FP64 || fft.FFT_FP32) { + kfftMidIn.setFixedArgs(2, bufTrigM); + kfftHin.setFixedArgs(2, bufTrigH); + ktailSquareZero.setFixedArgs(2, bufTrigH); + ktailSquare.setFixedArgs(2, bufTrigH); + ktailMulLow.setFixedArgs(3, bufTrigH); + ktailMul.setFixedArgs(3, bufTrigH); + kfftMidOut.setFixedArgs(2, bufTrigM); + kfftW.setFixedArgs(2, bufTrigW); } - for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(9, bufROE); } - for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(9, bufStatsCarry); } + if (fft.NTT_GF31) { + kfftMidInGF31.setFixedArgs(2, bufTrigM); + kfftHinGF31.setFixedArgs(2, bufTrigH); + ktailSquareZeroGF31.setFixedArgs(2, bufTrigH); + ktailSquareGF31.setFixedArgs(2, bufTrigH); + ktailMulLowGF31.setFixedArgs(3, bufTrigH); + ktailMulGF31.setFixedArgs(3, bufTrigH); + kfftMidOutGF31.setFixedArgs(2, bufTrigM); + kfftWGF31.setFixedArgs(2, bufTrigW); + } - for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { - k->setFixedArgs(3, bufCarry, bufBitsC, bufWeights); + if (fft.NTT_GF61) { + 
kfftMidInGF61.setFixedArgs(2, bufTrigM); + kfftHinGF61.setFixedArgs(2, bufTrigH); + ktailSquareZeroGF61.setFixedArgs(2, bufTrigH); + ktailSquareGF61.setFixedArgs(2, bufTrigH); + ktailMulLowGF61.setFixedArgs(3, bufTrigH); + ktailMulGF61.setFixedArgs(3, bufTrigH); + kfftMidOutGF61.setFixedArgs(2, bufTrigM); + kfftWGF61.setFixedArgs(2, bufTrigW); } - for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(6, bufROE); } - for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(6, bufStatsCarry); } + if (fft.FFT_FP64 || fft.FFT_FP32) { // The FP versions take bufWeight arguments (and bufBits which may be deleted) + kfftP.setFixedArgs(2, bufTrigW, bufWeights); + for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { k->setFixedArgs(3, bufCarry, bufWeights); } + for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(5, bufStatsCarry); } + for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(5, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { + k->setFixedArgs(3, bufCarry, bufReady, bufTrigW, bufBits, bufConstWeights, bufWeights); + } + for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(9, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(9, bufStatsCarry); } + } else { + kfftP.setFixedArgs(2, bufTrigW); + for (Kernel* k : {&kCarryA, &kCarryAROE, &kCarryM, &kCarryMROE, &kCarryLL}) { k->setFixedArgs(3, bufCarry); } + for (Kernel* k : {&kCarryA, &kCarryM, &kCarryLL}) { k->setFixedArgs(4, bufStatsCarry); } + for (Kernel* k : {&kCarryAROE, &kCarryMROE}) { k->setFixedArgs(4, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedROE, &kCarryFusedMul, &kCarryFusedMulROE, &kCarryFusedLL}) { + k->setFixedArgs(3, bufCarry, bufReady, bufTrigW); + } + for (Kernel* k : {&kCarryFusedROE, &kCarryFusedMulROE}) { k->setFixedArgs(6, bufROE); } + for (Kernel* k : {&kCarryFused, &kCarryFusedMul, &kCarryFusedLL}) { k->setFixedArgs(6, bufStatsCarry); } + } - fftP.setFixedArgs(2, bufTrigW, bufWeights); - fftW.setFixedArgs(2, bufTrigW); - fftHin.setFixedArgs(2, bufTrigH); + carryB.setFixedArgs(1, bufCarry); - fftMidIn.setFixedArgs( 2, bufTrigM); - fftMidOut.setFixedArgs(2, bufTrigM); - - carryB.setFixedArgs(1, bufCarry, bufBitsC); - tailMulLow.setFixedArgs(3, bufTrigH); - tailMul.setFixedArgs(3, bufTrigH); - tailSquareZero.setFixedArgs(2, bufTrigH); - tailSquare.setFixedArgs(2, bufTrigH); kernIsEqual.setFixedArgs(2, bufTrue); bufReady.zero(); @@ -560,9 +730,97 @@ Gpu::Gpu(Queue* q, GpuCommon shared, FFTConfig fft, u32 E, const vector& selftestTrig(); } + queue->setSquareKernels(1 + 3 * (fft.FFT_FP64 + fft.FFT_FP32 + fft.NTT_GF31 + fft.NTT_GF61)); queue->finish(); } + +// Call the appropriate kernels to support hybrid FFTs and NTTs + +void Gpu::fftP(Buffer& out, Buffer& in) { + kfftP(out, in); +} + +void Gpu::fftW(Buffer& out, Buffer& in, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) kfftW(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) kfftWGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) kfftWGF61(out, in); +} + +void Gpu::fftMidIn(Buffer& out, Buffer& in, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) kfftMidIn(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) kfftMidInGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && 
fft.NTT_GF61) kfftMidInGF61(out, in); +} + +void Gpu::fftMidOut(Buffer& out, Buffer& in, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) kfftMidOut(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) kfftMidOutGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) kfftMidOutGF61(out, in); +} + +void Gpu::fftHin(Buffer& out, Buffer& in) { + if (fft.FFT_FP64 || fft.FFT_FP32) kfftHin(out, in); + if (fft.NTT_GF31) kfftHinGF31(out, in); + if (fft.NTT_GF61) kfftHinGF61(out, in); +} + +void Gpu::tailSquare(Buffer& out, Buffer& in, int cache_group) { + if (!tail_single_kernel) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) ktailSquareZero(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) ktailSquareZeroGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) ktailSquareZeroGF61(out, in); + } + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) ktailSquare(out, in); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) ktailSquareGF31(out, in); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) ktailSquareGF61(out, in); +} + +void Gpu::tailMul(Buffer& out, Buffer& in1, Buffer& in2, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) ktailMul(out, in1, in2); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) ktailMulGF31(out, in1, in2); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) ktailMulGF61(out, in1, in2); +} + +void Gpu::tailMulLow(Buffer& out, Buffer& in1, Buffer& in2, int cache_group) { + if ((cache_group == 0 || cache_group == 1) && (fft.FFT_FP64 || fft.FFT_FP32)) ktailMulLow(out, in1, in2); + if ((cache_group == 0 || cache_group == 2) && fft.NTT_GF31) ktailMulLowGF31(out, in1, in2); + if ((cache_group == 0 || cache_group == 3) && fft.NTT_GF61) ktailMulLowGF61(out, in1, in2); +} + +void Gpu::carryA(Buffer& out, Buffer& in) { + assert(roePos <= ROE_SIZE); + roePos < wantROE ? kCarryAROE(out, in, roePos++) + : kCarryA(out, in, updateCarryPos(1 << 2)); +} + +void Gpu::carryM(Buffer& out, Buffer& in) { + assert(roePos <= ROE_SIZE); + roePos < wantROE ? kCarryMROE(out, in, roePos++) + : kCarryM(out, in, updateCarryPos(1 << 3)); +} + +void Gpu::carryLL(Buffer& out, Buffer& in) { + kCarryLL(out, in, updateCarryPos(1 << 2)); +} + +void Gpu::carryFused(Buffer& out, Buffer& in) { + assert(roePos <= ROE_SIZE); + roePos < wantROE ? kCarryFusedROE(out, in, roePos++) + : kCarryFused(out, in, updateCarryPos(1 << 0)); +} + +void Gpu::carryFusedMul(Buffer& out, Buffer& in) { + assert(roePos <= ROE_SIZE); + roePos < wantROE ? kCarryFusedMulROE(out, in, roePos++) + : kCarryFusedMul(out, in, updateCarryPos(1 << 1)); +} + +void Gpu::carryFusedLL(Buffer& out, Buffer& in) { + kCarryFusedLL(out, in, updateCarryPos(1 << 0)); +} + + #if 0 void Gpu::measureTransferSpeed() { u32 SIZE_MB = 16; @@ -589,34 +847,8 @@ u32 Gpu::updateCarryPos(u32 bit) { return (statsBits & bit) && (carryPos < CARRY_SIZE) ? carryPos++ : carryPos; } -void Gpu::carryFused(Buffer& a, Buffer& b) { - assert(roePos <= ROE_SIZE); - roePos < wantROE ? kCarryFusedROE(a, b, roePos++) - : kCarryFused(a, b, updateCarryPos(1 << 0)); -} - -void Gpu::carryFusedMul(Buffer& a, Buffer& b) { - assert(roePos <= ROE_SIZE); - roePos < wantROE ? 
kCarryFusedMulROE(a, b, roePos++) - : kCarryFusedMul(a, b, updateCarryPos(1 << 1)); -} - -void Gpu::carryA(Buffer& a, Buffer& b) { - assert(roePos <= ROE_SIZE); - roePos < wantROE ? kCarryAROE(a, b, roePos++) - : kCarryA(a, b, updateCarryPos(1 << 2)); -} - -void Gpu::carryLL(Buffer& a, Buffer& b) { kCarryLL(a, b, updateCarryPos(1 << 2)); } - -void Gpu::carryM(Buffer& a, Buffer& b) { - assert(roePos <= ROE_SIZE); - roePos < wantROE ? kCarryMROE(a, b, roePos++) - : kCarryM(a, b, updateCarryPos(1 << 3)); -} - -vector> Gpu::makeBufVector(u32 size) { - vector> r; +vector> Gpu::makeBufVector(u32 size) { + vector> r; for (u32 i = 0; i < size; ++i) { r.emplace_back(timeBufVect, queue, N); } return r; } @@ -674,20 +906,22 @@ template static bool isAllZero(vector v) { return std::all_of(v.begin(), v.end(), [](T x) { return x == 0;}); } // Read from GPU, verifying the transfer with a sum, and retry on failure. -vector Gpu::readChecked(Buffer& buf) { +vector Gpu::readChecked(Buffer& buf) { for (int nRetry = 0; nRetry < 3; ++nRetry) { - sum64(bufSumOut, u32(buf.size * sizeof(int)), buf); + sum64(bufSumOut, u32(buf.size * sizeof(Word)), buf); vector expectedVect(1); bufSumOut.readAsync(expectedVect); - vector data = readOut(buf); + vector data = readOut(buf); u64 gpuSum = expectedVect[0]; - u64 hostSum = 0; - for (auto it = data.begin(), end = data.end(); it < end; it += 2) { - hostSum += u32(*it) | (u64(*(it + 1)) << 32); + + int even = 1; + for (auto it = data.begin(), end = data.end(); it < end; ++it, even = !even) { + if (fft.WordSize == 4) hostSum += even ? u64(u32(*it)) : (u64(*it) << 32); + if (fft.WordSize == 8) hostSum += u64(*it); } if (hostSum == gpuSum) { @@ -704,63 +938,88 @@ vector Gpu::readChecked(Buffer& buf) { throw "GPU persistent read errors"; } -Words Gpu::readAndCompress(Buffer& buf) { return compactBits(readChecked(buf), E); } +Words Gpu::readAndCompress(Buffer& buf) { return compactBits(readChecked(buf), E); } vector Gpu::readCheck() { return readAndCompress(bufCheck); } vector Gpu::readData() { return readAndCompress(bufData); } -// out := inA * inB; -void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3) { +// out := inA * inB; inB is preserved +void Gpu::mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3) { + if (!in_place) { + fftP(tmp2, ioA); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + fftMidIn(tmp1, tmp2, cache_group); + tailMul(tmp2, inB, tmp1, cache_group); + fftMidOut(tmp1, tmp2, cache_group); + fftW(tmp2, tmp1, cache_group); + } + } + else { fftP(tmp1, ioA); - fftMidIn(tmp2, tmp1); - tailMul(tmp1, inB, tmp2); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + fftMidIn(tmp1, tmp1, cache_group); + tailMul(tmp1, inB, tmp1, cache_group); + fftMidOut(tmp1, tmp1, cache_group); + fftW(tmp2, tmp1, cache_group); + } + } - // Register the current ROE pos as multiplication (vs. a squaring) - if (mulRoePos.empty() || mulRoePos.back() < roePos) { mulRoePos.push_back(roePos); } + // Register the current ROE pos as multiplication (vs. 
a squaring) + if (mulRoePos.empty() || mulRoePos.back() < roePos) { mulRoePos.push_back(roePos); } - fftMidOut(tmp2, tmp1); - fftW(tmp1, tmp2); - if (mul3) { carryM(ioA, tmp1); } else { carryA(ioA, tmp1); } - carryB(ioA); + if (mul3) { carryM(ioA, tmp2); } else { carryA(ioA, tmp2); } + carryB(ioA); } -void Gpu::mul(Buffer& io, Buffer& buf1) { +void Gpu::mul(Buffer& io, Buffer& buf1) { // We know that mul() stores double output in buf1; so we're going to use buf2 & buf3 for temps. mul(io, buf1, buf2, buf3, false); } // out := inA * inB; -void Gpu::modMul(Buffer& ioA, Buffer& inB, bool mul3) { - fftP(buf2, inB); - fftMidIn(buf1, buf2); +void Gpu::modMul(Buffer& ioA, Buffer& inB, bool mul3) { + modMul(ioA, inB, LEAD_NONE, mul3); +}; + +// out := inA * inB; inB will end up in buf1 in the LEAD_MIDDLE state +void Gpu::modMul(Buffer& ioA, Buffer& inB, enum LEAD_TYPE leadInB, bool mul3) { + if (!in_place) { + if (leadInB == LEAD_NONE) fftP(buf2, inB); + if (leadInB != LEAD_MIDDLE) fftMidIn(buf1, buf2); + } else { + if (leadInB == LEAD_NONE) fftP(buf1, inB); + if (leadInB != LEAD_MIDDLE) fftMidIn(buf1, buf1); + } mul(ioA, buf1, buf2, buf3, mul3); }; -void Gpu::writeState(const vector& check, u32 blockSize) { +void Gpu::writeState(u32 k, const vector& check, u32 blockSize) { assert(blockSize > 0); writeIn(bufCheck, check); bufData << bufCheck; bufAux << bufCheck; - - u32 n = 0; - for (n = 1; blockSize % (2 * n) == 0; n *= 2) { - squareLoop(bufData, 0, n); - modMul(bufData, bufAux); - bufAux << bufData; - } - - assert((n & (n - 1)) == 0); - assert(blockSize % n == 0); - - blockSize /= n; - assert(blockSize >= 2); - - for (u32 i = 0; i < blockSize - 2; ++i) { + + if (k) { // Only verify bufData that was read in from a save file + u32 n; + for (n = 1; blockSize % (2 * n) == 0; n *= 2) { + squareLoop(bufData, 0, n); + modMul(bufData, bufAux); + bufAux << bufData; + } + + assert((n & (n - 1)) == 0); + assert(blockSize % n == 0); + + blockSize /= n; + assert(blockSize >= 2); + + for (u32 i = 0; i < blockSize - 2; ++i) { + squareLoop(bufData, 0, n); + modMul(bufData, bufAux); + } + squareLoop(bufData, 0, n); - modMul(bufData, bufAux); } - - squareLoop(bufData, 0, n); modMul(bufData, bufAux, true); } @@ -798,15 +1057,44 @@ void Gpu::logTimeKernels() { profile.reset(); } -vector Gpu::readOut(Buffer &buf) { +vector Gpu::readWords(Buffer &buf) { + // GPU is returning either 4-byte or 8-byte integers. C++ code is expecting 8-byte integers. Handle the "no conversion" case. + if (fft.WordSize == 8) return buf.read(); + // Convert 32-bit GPU Words into 64-bit C++ Words + vector GPUdata = buf.read(); + vector CPUdata; + CPUdata.resize(GPUdata.size() * 2); + for (u32 i = 0; i < GPUdata.size(); ++i) { + CPUdata[2*i] = (i32) GPUdata[i]; + CPUdata[2*i+1] = (GPUdata[i] >> 32); + } + return CPUdata; +} + +void Gpu::writeWords(Buffer& buf, vector &words) { + // GPU is expecting either 4-byte or 8-byte integers. C++ code is using 8-byte integers. Handle the "no conversion" case. 
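For reference, a minimal standalone sketch of the per-lane conversion that readWords()/writeWords() perform when fft.WordSize == 4: two 32-bit GPU words share one 64-bit lane, and the host side keeps sign-extended 64-bit words. The helper names below are illustrative only and not part of the patch.

// Illustrative sketch (not part of the patch), mirroring the WordSize == 4 path above.
#include <cassert>
#include <cstdint>

// Pack a pair of host words, each known to fit in 32 bits (balanced, possibly negative),
// into one 64-bit lane, low word first, like "((i64) words[i+1] << 32) | (u32) words[i]".
static int64_t packLane(int64_t lo, int64_t hi) {
  return (int64_t(hi) << 32) | uint32_t(lo);
}

// Unpack one lane back into two sign-extended host words,
// like "CPUdata[2*i] = (i32) GPUdata[i]; CPUdata[2*i+1] = (GPUdata[i] >> 32);".
static void unpackLane(int64_t lane, int64_t& lo, int64_t& hi) {
  lo = int32_t(lane);   // sign-extend the low half
  hi = lane >> 32;      // arithmetic shift keeps the sign of the high half
}

int main() {
  int64_t lo = -5, hi = 123456, lo2 = 0, hi2 = 0;
  unpackLane(packLane(lo, hi), lo2, hi2);
  assert(lo2 == lo && hi2 == hi);
  return 0;
}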
+ if (fft.WordSize == 8) buf.write(std::move(words)); + // Convert 64-bit C++ Words into 32-bit GPU Words + else { + vector GPUdata; + GPUdata.resize(words.size() / 2); + assert((words.size() & 1) == 0); + for (u32 i = 0; i < words.size(); i += 2) { + GPUdata[i/2] = ((i64) words[i+1] << 32) | (u32) words[i]; + } + buf.write(std::move(GPUdata)); + } +} + +vector Gpu::readOut(Buffer &buf) { transpOut(bufAux, buf); - return bufAux.read(); + return readWords(bufAux); } -void Gpu::writeIn(Buffer& buf, const vector& words) { writeIn(buf, expandBits(words, N, E)); } +void Gpu::writeIn(Buffer& buf, const vector& words) { writeIn(buf, expandBits(words, N, E)); } -void Gpu::writeIn(Buffer& buf, vector&& words) { - bufAux.write(std::move(words)); +void Gpu::writeIn(Buffer& buf, vector&& words) { + writeWords(bufAux, words); transpIn(buf, bufAux); } @@ -831,7 +1119,7 @@ Words Gpu::expExp2(const Words& A, u32 n) { } // A:= A^h * B -void Gpu::expMul(Buffer& A, u64 h, Buffer& B) { +void Gpu::expMul(Buffer& A, u64 h, Buffer& B) { exponentiate(A, h, buf1, buf2, buf3); modMul(A, B); } @@ -848,38 +1136,56 @@ Words Gpu::expMul(const Words& A, u64 h, const Words& B, bool doSquareB) { static bool testBit(u64 x, int bit) { return x & (u64(1) << bit); } -void Gpu::bottomHalf(Buffer& out, Buffer& inTmp) { - fftMidIn(out, inTmp); - if (!tail_single_kernel) tailSquareZero(inTmp, out); - tailSquare(inTmp, out); - fftMidOut(out, inTmp); -} - // See "left-to-right binary exponentiation" on wikipedia -void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3) { +void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3) { if (exp == 0) { bufInOut.set(1); } else if (exp > 1) { - fftP(buf3, bufInOut); - fftMidIn(buf2, buf3); + if (!in_place) { + fftP(buf3, bufInOut); + fftMidIn(buf2, buf3); + } else { + fftP(buf2, bufInOut); + fftMidIn(buf2, buf2); + } fftHin(buf1, buf2); // save "base" to buf1 + bool midInAlreadyDone = 1; int p = 63; while (!testBit(exp, p)) { --p; } for (--p; ; --p) { - bottomHalf(buf2, buf3); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + if (!in_place) { + if (!midInAlreadyDone) fftMidIn(buf2, buf3, cache_group); + tailSquare(buf3, buf2, cache_group); + fftMidOut(buf2, buf3, cache_group); + } else { + if (!midInAlreadyDone) fftMidIn(buf2, buf2, cache_group); + tailSquare(buf2, buf2, cache_group); + fftMidOut(buf2, buf2, cache_group); + } + } + midInAlreadyDone = 0; if (testBit(exp, p)) { - doCarry(buf3, buf2); - fftMidIn(buf2, buf3); - tailMulLow(buf3, buf2, buf1); - fftMidOut(buf2, buf3); + doCarry(buf3, buf2, bufInOut); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + if (!in_place) { + fftMidIn(buf2, buf3, cache_group); + tailMulLow(buf3, buf2, buf1, cache_group); + fftMidOut(buf2, buf3, cache_group); + } else { + fftMidIn(buf2, buf2, cache_group); + tailMulLow(buf2, buf2, buf1, cache_group); + fftMidOut(buf2, buf2, cache_group); + } + } } if (!p) { break; } - doCarry(buf3, buf2); + doCarry(buf3, buf2, bufInOut); } fftW(buf3, buf2); @@ -888,28 +1194,68 @@ void Gpu::exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buf } } -// does either carrryFused() or the expanded version depending on useLongCarry -void Gpu::doCarry(Buffer& out, Buffer& in) { - if (useLongCarry) { - fftW(out, in); - carryA(in, out); - carryB(in); - fftP(out, in); +// does either carryFused() or the expanded version depending on useLongCarry +void Gpu::doCarry(Buffer& out, Buffer& in, Buffer& tmp) 
{ + if (!in_place) { + if (useLongCarry) { + fftW(out, in); + carryA(tmp, out); + carryB(tmp); + fftP(out, tmp); + } else { + carryFused(out, in); + } } else { - carryFused(out, in); + if (useLongCarry) { + fftW(out, in); + carryA(tmp, out); + carryB(tmp); + fftP(in, tmp); + } else { + carryFused(in, in); + } } } -void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, bool doMul3, bool doLL) { +// Use buf1 and buf2 to do a single squaring. +void Gpu::square(Buffer& out, Buffer& in, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut, bool doMul3, bool doLL) { + // leadOut = LEAD_MIDDLE is not supported (slower than LEAD_WIDTH) + assert(leadOut != LEAD_MIDDLE); // LL does not do Mul3 assert(!(doMul3 && doLL)); - if (leadIn) { fftP(buf2, in); } - - bottomHalf(buf1, buf2); + // Not in place FFTs use buf1 and buf2 in a "ping pong" fashion. + // If leadIn is LEAD_NONE, in contains the input data, squaring starts at fftP + // If leadIn is LEAD_WIDTH, buf2 contains the input data, squaring starts at fftMidIn + // If leadIn is LEAD_MIDDLE, buf1 contains the input data, squaring starts at tailSquare + // If leadOut is LEAD_WIDTH, then will buf2 contain the output of carryFused -- to be used as input to the next squaring. + if (!in_place) { + if (leadIn == LEAD_NONE) fftP(buf2, in); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + if (leadIn != LEAD_MIDDLE) fftMidIn(buf1, buf2, cache_group); + tailSquare(buf2, buf1, cache_group); + fftMidOut(buf1, buf2, cache_group); + if (leadOut == LEAD_NONE) fftW(buf2, buf1, cache_group); + } + } + + // In place FFTs use buf1. + // If leadIn is LEAD_NONE, in contains the input data, squaring starts at fftP + // If leadIn is LEAD_WIDTH, buf1 contains the input data, squaring starts at fftMidIn + // If leadIn is LEAD_MIDDLE, buf1 contains the input data, squaring starts at tailSquare + // If leadOut is LEAD_WIDTH, then buf1 will contain the output of carryFused -- to be used as input to the next squaring. + else { + if (leadIn == LEAD_NONE) fftP(buf1, in); + for (int cache_group = 1; cache_group <= NUM_CACHE_GROUPS; ++cache_group) { + if (leadIn != LEAD_MIDDLE) fftMidIn(buf1, buf1, cache_group); + tailSquare(buf1, buf1, cache_group); + fftMidOut(buf1, buf1, cache_group); + if (leadOut == LEAD_NONE) fftW(buf2, buf1, cache_group); + } + } - if (leadOut) { - fftW(buf2, buf1); + // If leadOut is not allowed then we cannot use the faster carryFused kernel + if (leadOut == LEAD_NONE) { if (!doLL && !doMul3) { carryA(out, buf2); } else if (doLL) { @@ -918,57 +1264,59 @@ void Gpu::square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, b carryM(out, buf2); } carryB(out); - } else { + } + + // Use CarryFused + else { assert(!useLongCarry); assert(!doMul3); - if (doLL) { - carryFusedLL(buf2, buf1); + carryFusedLL(in_place ? buf1 : buf2, buf1); } else { - carryFused(buf2, buf1); + carryFused(in_place ? buf1 : buf2, buf1); } - // Unused: carryFusedMul(buf2, buf1); } } -void Gpu::square(Buffer& io) { square(io, io, true, true, false, false); } - -u32 Gpu::squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3) { +u32 Gpu::squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3) { assert(from < to); - bool leadIn = true; + enum LEAD_TYPE leadIn = LEAD_NONE; for (u32 k = from; k < to; ++k) { - bool leadOut = useLongCarry || (k == to - 1); + enum LEAD_TYPE leadOut = useLongCarry || (k == to - 1) ? LEAD_NONE : LEAD_WIDTH; square(out, (k==from) ? 
in : out, leadIn, leadOut, doTailMul3 && (k == to - 1)); leadIn = leadOut; } return to; } -bool Gpu::isEqual(Buffer& in1, Buffer& in2) { +bool Gpu::isEqual(Buffer& in1, Buffer& in2) { kernIsEqual(in1, in2); int isEq = 0; bufTrue.read(&isEq, 1); if (!isEq) { bufTrue.write({1}); } return isEq; } - -u64 Gpu::bufResidue(Buffer &buf) { + +u64 Gpu::bufResidue(Buffer &buf) { readResidue(bufSmallOut, buf); - int words[64]; - bufSmallOut.read(words, 64); + vector words = readWords(bufSmallOut); int carry = 0; - for (int i = 0; i < 32; ++i) { carry = (words[i] + carry < 0) ? -1 : 0; } + for (int i = 0; i < 32; ++i) { + u32 len = bitlen(N, E, N - 32 + i); + i64 w = (i64) words[i] + carry; + carry = (int) (w >> len); + } u64 res = 0; int hasBits = 0; for (int k = 0; k < 32 && hasBits < 64; ++k) { u32 len = bitlen(N, E, k); - int w = words[32 + k] + carry; - carry = (w < 0) ? -1 : 0; - if (w < 0) { w += (1 << len); } - assert(w >= 0 && w < (1 << len)); - res |= u64(w) << hasBits; + i64 tmp = (i64) words[32 + k] + carry; + carry = (int) (tmp >> len); + u64 w = tmp - ((i64) carry << len); + assert(w < (1ULL << len)); + res += w << hasBits; hasBits += len; } return res; @@ -1015,10 +1363,14 @@ void Gpu::doBigLog(u32 k, u64 res, bool checkOK, float secsPerIt, u32 nIters, u3 auto [roeSq, roeMul] = readROE(); double z = roeSq.z(); zAvg.update(z, roeSq.N); - log("%sZ=%.0f (avg %.1f)%s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), - z, zAvg.avg(), (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); - - if (roeSq.N > 2 && z < 20) { + if (roeSq.max > 0.005) + log("%sZ=%.0f (avg %.1f), ROEmax=%.3f, ROEavg=%.3f. %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), + z, zAvg.avg(), roeSq.max, roeSq.mean, (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); + else + log("%sZ=%.0f (avg %.1f) %s\n", makeLogStr(checkOK ? "OK" : "EE", k, res, secsPerIt, nIters).c_str(), + z, zAvg.avg(), (nErrors ? " "s + to_string(nErrors) + " errors"s : ""s).c_str()); + + if (roeSq.N > 2 && (z < 6 || (fft.shape.fft_type == FFT64 && z < 20))) { log("Danger ROE! 
Z=%.1f is too small, increase precision or FFT size!\n", z); } @@ -1057,6 +1409,8 @@ int ulps(double a, double b) { } void Gpu::selftestTrig() { + +#if FFT_FP64 const u32 n = hN / 8; testTrig(buf1); vector trig = buf1.read(n * 2); @@ -1081,7 +1435,7 @@ void Gpu::selftestTrig() { if (c < refCos) { ++cdown; } double norm = trigNorm(c, s); - + if (norm < 1.0) { ++oneDown; } if (norm > 1.0) { ++oneUp; } } @@ -1091,7 +1445,8 @@ void Gpu::selftestTrig() { log("TRIG cos(): imperfect %d / %d (%.2f%%), balance %d\n", cup + cdown, n, (cup + cdown) * 100.0 / n, cup - cdown); log("TRIG norm: up %d, down %d\n", oneUp, oneDown); - +#endif + if (isAmdGpu(queue->context->deviceId())) { vector WHATS {"V_NOP", "V_ADD_I32", "V_FMA_F32", "V_ADD_F64", "V_FMA_F64", "V_MUL_F64", "V_MAD_U64_U32"}; for (int w = 0; w < int(WHATS.size()); ++w) { @@ -1173,7 +1528,7 @@ PRPState Gpu::loadPRP(Saver& saver) { } PRPState state = saver.load(); - writeState(state.check, state.blockSize); + writeState(state.k, state.check, state.blockSize); u64 res = dataResidue(); if (res == state.res64) { @@ -1216,7 +1571,7 @@ tuple Gpu::measureCarry() { u32 k = 0; PRPState state{E, 0, blockSize, 3, makeWords(E, 1), 0}; - writeState(state.check, state.blockSize); + writeState(state.k, state.check, state.blockSize); { u64 res = dataResidue(); if (res != state.res64) { @@ -1225,12 +1580,18 @@ tuple Gpu::measureCarry() { assert(res == state.res64); } - modMul(bufCheck, bufData); - square(bufData, bufData, true, useLongCarry); + enum LEAD_TYPE leadIn = LEAD_NONE; + modMul(bufCheck, bufData, leadIn); + leadIn = LEAD_MIDDLE; + + enum LEAD_TYPE leadOut = useLongCarry ? LEAD_NONE : LEAD_WIDTH; + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; while (k < warmup) { - square(bufData, bufData, useLongCarry, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; } @@ -1238,20 +1599,20 @@ tuple Gpu::measureCarry() { if (Signal::stopRequested()) { throw "stop requested"; } - bool leadIn = useLongCarry; while (true) { while (k % blockSize < blockSize-1) { - square(bufData, bufData, leadIn, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; - leadIn = useLongCarry; } - square(bufData, bufData, useLongCarry, true); - leadIn = true; + square(bufData, bufData, leadIn, LEAD_NONE); + leadIn = LEAD_NONE; ++k; if (k >= iters) { break; } - modMul(bufCheck, bufData); + modMul(bufCheck, bufData, leadIn); + leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } @@ -1284,7 +1645,7 @@ tuple Gpu::measureROE(bool quick) { u32 k = 0; PRPState state{E, 0, blockSize, 3, makeWords(E, 1), 0}; - writeState(state.check, state.blockSize); + writeState(state.k, state.check, state.blockSize); { u64 res = dataResidue(); if (res != state.res64) { @@ -1293,12 +1654,18 @@ tuple Gpu::measureROE(bool quick) { assert(res == state.res64); } - modMul(bufCheck, bufData); - square(bufData, bufData, true, useLongCarry); + enum LEAD_TYPE leadIn = LEAD_NONE; + modMul(bufCheck, bufData, leadIn); + leadIn = LEAD_MIDDLE; + + enum LEAD_TYPE leadOut = useLongCarry ? 
LEAD_NONE : LEAD_WIDTH; + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; while (k < warmup) { - square(bufData, bufData, useLongCarry, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; } @@ -1306,20 +1673,20 @@ tuple Gpu::measureROE(bool quick) { if (Signal::stopRequested()) { throw "stop requested"; } - bool leadIn = useLongCarry; while (true) { while (k % blockSize < blockSize-1) { - square(bufData, bufData, leadIn, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; - leadIn = useLongCarry; } - square(bufData, bufData, useLongCarry, true); - leadIn = true; + square(bufData, bufData, leadIn, LEAD_NONE); + leadIn = LEAD_NONE; ++k; if (k >= iters) { break; } - modMul(bufCheck, bufData); + modMul(bufCheck, bufData, leadIn); + leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } @@ -1334,32 +1701,40 @@ tuple Gpu::measureROE(bool quick) { return {ok, res, roes.first, roes.second}; } -double Gpu::timePRP() { +double Gpu::timePRP(int quick) { // Quick varies from 1 (slowest, longest) to 10 (quickest, shortest) u32 blockSize{}, iters{}, warmup{}; - if (true) { - blockSize = 200; - iters = 1000; - warmup = 30; - } else { - blockSize = 1000; - iters = 10'000; - warmup = 100; - } + if (quick == 10) iters = 400, blockSize = 200; + else if (quick == 9) iters = 600, blockSize = 300; + else if (quick == 8) iters = 900, blockSize = 300; + else if (quick == 7) iters = 1200, blockSize = 400; + else if (quick == 6) iters = 1800, blockSize = 600; + else if (quick == 5) iters = 3000, blockSize = 1000; + else if (quick == 4) iters = 5000, blockSize = 1000; + else if (quick == 3) iters = 8000, blockSize = 1000; + else if (quick == 2) iters = 12000, blockSize = 1000; + else if (quick == 1) iters = 20000, blockSize = 1000; + warmup = 20; assert(iters % blockSize == 0); u32 k = 0; PRPState state{E, 0, blockSize, 3, makeWords(E, 1), 0}; - writeState(state.check, state.blockSize); + writeState(state.k, state.check, state.blockSize); assert(dataResidue() == state.res64); - modMul(bufCheck, bufData); - square(bufData, bufData, true, useLongCarry); + enum LEAD_TYPE leadIn = LEAD_NONE; + modMul(bufCheck, bufData, leadIn); + leadIn = LEAD_MIDDLE; + + enum LEAD_TYPE leadOut = useLongCarry ? 
LEAD_NONE : LEAD_WIDTH; + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; while (k < warmup) { - square(bufData, bufData, useLongCarry, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; } queue->finish(); @@ -1367,20 +1742,20 @@ double Gpu::timePRP() { Timer t; queue->setSquareTime(0); // Busy wait on nVidia to get the most accurate timings while tuning - bool leadIn = useLongCarry; while (true) { while (k % blockSize < blockSize-1) { - square(bufData, bufData, leadIn, useLongCarry); + square(bufData, bufData, leadIn, leadOut); + leadIn = leadOut; ++k; - leadIn = useLongCarry; } - square(bufData, bufData, useLongCarry, true); - leadIn = true; + square(bufData, bufData, leadIn, LEAD_NONE); + leadIn = LEAD_NONE; ++k; if (k >= iters) { break; } - modMul(bufCheck, bufData); + modMul(bufCheck, bufData, leadIn); + leadIn = LEAD_MIDDLE; if (Signal::stopRequested()) { throw "stop requested"; } } queue->finish(); @@ -1448,7 +1823,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { bool skipNextCheckUpdate = false; u32 persistK = proofSet.next(k); - bool leadIn = true; + enum LEAD_TYPE leadIn = LEAD_NONE; assert(k % blockSize == 0); assert(checkStep % blockSize == 0); @@ -1466,23 +1841,24 @@ PRPResult Gpu::isPrimePRP(const Task& task) { if (skipNextCheckUpdate) { skipNextCheckUpdate = false; } else if (k % blockSize == 0) { - assert(leadIn); - modMul(bufCheck, bufData); + modMul(bufCheck, bufData, leadIn); + leadIn = LEAD_MIDDLE; } ++k; // !! early inc bool doStop = (k % blockSize == 0) && (Signal::stopRequested() || (args.iters && k - startK >= args.iters)); - bool leadOut = (k % blockSize == 0) || k == persistK || k == kEnd || useLongCarry; + bool doCheck = doStop || (k % checkStep == 0) || (k >= kEndEnd) || (k - startK == 2 * blockSize); + bool doLog = k % logStep == 0; + enum LEAD_TYPE leadOut = doCheck || doLog || k == persistK || k == kEnd || useLongCarry ? LEAD_NONE : LEAD_WIDTH; - assert(!doStop || leadOut); if (doStop) { log("Stopping, please wait..\n"); } square(bufData, bufData, leadIn, leadOut, false); leadIn = leadOut; - + if (k == persistK) { - vector rawData = readChecked(bufData); + vector rawData = readChecked(bufData); if (rawData.empty()) { log("Data error ZERO\n"); ++nErrors; @@ -1503,18 +1879,13 @@ PRPResult Gpu::isPrimePRP(const Task& task) { log("%s %8d / %d, %s\n", isPrime ? 
"PP" : "CC", kEnd, E, hex(finalRes64).c_str()); } - bool doCheck = doStop || (k % checkStep == 0) || (k >= kEndEnd) || (k - startK == 2 * blockSize); - bool doLog = k % logStep == 0; - - if (!leadOut || (!doCheck && !doLog)) continue; - - assert(doCheck || doLog); + if (!doCheck && !doLog) continue; u64 res = dataResidue(); float secsPerIt = iterationTimer.reset(k); queue->setSquareTime((int) (secsPerIt * 1'000'000)); - vector rawCheck = readChecked(bufCheck); + vector rawCheck = readChecked(bufCheck); if (rawCheck.empty()) { ++nErrors; log("%9u %016" PRIx64 " read NULL check\n", k, res); @@ -1577,7 +1948,7 @@ PRPResult Gpu::isPrimePRP(const Task& task) { queue->finish(); throw "stop requested"; } - + iterationTimer.reset(k); } } @@ -1613,7 +1984,7 @@ LLResult Gpu::isPrimeLL(const Task& task) { u32 k = startK; u32 kEnd = E - 2; - bool leadIn = true; + enum LEAD_TYPE leadIn = LEAD_NONE; while (true) { ++k; @@ -1624,8 +1995,8 @@ LLResult Gpu::isPrimeLL(const Task& task) { log("Stopping, please wait..\n"); } - bool doLog = (k % 10'000 == 0) || doStop; - bool leadOut = doLog || useLongCarry; + bool doLog = (k % args.logStep == 0) || doStop; + enum LEAD_TYPE leadOut = doLog || useLongCarry ? LEAD_NONE : LEAD_WIDTH; squareLL(bufData, leadIn, leadOut); leadIn = leadOut; @@ -1669,14 +2040,15 @@ array Gpu::isCERT(const Task& task) { // Get CERT start value char fname[32]; sprintf(fname, "M%u.cert", E); - File fi = File::openReadThrow(fname); - -//We need to gracefully handle the CERT file missing. There is a window in primenet.py between worktodo.txt entry and starting value download. - u32 nBytes = (E - 1) / 8 + 1; - Words B = fi.readBytesLE(nBytes); +// Autoprimenet.py does not add the cert entry to worktodo.txt until it has successfully downloaded the .cert file. - writeIn(bufData, std::move(B)); + { // Enclosing this code in braces ensures the file will be closed by the File destructor. The later file deletion requires the file be closed in Windows. + File fi = File::openReadThrow(fname); + u32 nBytes = (E - 1) / 8 + 1; + Words B = fi.readBytesLE(nBytes); + writeIn(bufData, std::move(B)); + } Timer elapsedTimer; @@ -1688,7 +2060,7 @@ array Gpu::isCERT(const Task& task) { u32 k = 0; u32 kEnd = task.squarings; - bool leadIn = true; + enum LEAD_TYPE leadIn = LEAD_NONE; while (true) { ++k; @@ -1700,7 +2072,7 @@ array Gpu::isCERT(const Task& task) { } bool doLog = (k % 100'000 == 0) || doStop; - bool leadOut = doLog || useLongCarry; + enum LEAD_TYPE leadOut = doLog || useLongCarry ? 
LEAD_NONE : LEAD_WIDTH; squareCERT(bufData, leadIn, leadOut); leadIn = leadOut; diff --git a/src/Gpu.h b/src/Gpu.h index d232157b..fc5166f3 100644 --- a/src/Gpu.h +++ b/src/Gpu.h @@ -26,12 +26,9 @@ struct Task; class Signal; class ProofSet; -using double2 = pair; using TrigBuf = Buffer; using TrigPtr = shared_ptr; -namespace fs = std::filesystem; - inline u64 residue(const Words& words) { return (u64(words[1]) << 32) | words[0]; } struct PRPResult { @@ -84,7 +81,6 @@ struct Weights { vector weightsConstIF; vector weightsIF; vector bitsCF; - vector bitsC; }; class Gpu { @@ -100,54 +96,74 @@ class Gpu { u32 E; u32 N; + FFTConfig fft; u32 WIDTH; u32 SMALL_H; u32 BIG_H; - u32 hN, nW, nH, bufSize; + u32 hN, nW, nH; bool useLongCarry; u32 wantROE{}; Profile profile{}; KernelCompiler compiler; - - Kernel kCarryFused; - Kernel kCarryFusedROE; - Kernel kCarryFusedMul; - Kernel kCarryFusedMulROE; - Kernel kCarryFusedLL; + /* Kernels for FFT_FP64 or FFT_FP32 */ + Kernel kfftMidIn; + Kernel kfftHin; + Kernel ktailSquareZero; + Kernel ktailSquare; + Kernel ktailMul; + Kernel ktailMulLow; + Kernel kfftMidOut; + Kernel kfftW; + + /* Kernels for NTT_GF31 */ + Kernel kfftMidInGF31; + Kernel kfftHinGF31; + Kernel ktailSquareZeroGF31; + Kernel ktailSquareGF31; + Kernel ktailMulGF31; + Kernel ktailMulLowGF31; + Kernel kfftMidOutGF31; + Kernel kfftWGF31; + + /* Kernels for NTT_GF61 */ + Kernel kfftMidInGF61; + Kernel kfftHinGF61; + Kernel ktailSquareZeroGF61; + Kernel ktailSquareGF61; + Kernel ktailMulGF61; + Kernel ktailMulLowGF61; + Kernel kfftMidOutGF61; + Kernel kfftWGF61; + + /* Kernels dealing with the FP data and product of NTT primes */ + Kernel kfftP; Kernel kCarryA; Kernel kCarryAROE; Kernel kCarryM; Kernel kCarryMROE; Kernel kCarryLL; - Kernel carryB; - - Kernel fftP; - Kernel fftW; - - Kernel fftHin; - - Kernel tailSquareZero; - Kernel tailSquare; - Kernel tailMul; - Kernel tailMulLow; - - Kernel fftMidIn; - Kernel fftMidOut; + Kernel kCarryFused; + Kernel kCarryFusedROE; + Kernel kCarryFusedMul; + Kernel kCarryFusedMulROE; + Kernel kCarryFusedLL; + Kernel carryB; Kernel transpIn, transpOut; - Kernel readResidue; Kernel kernIsEqual; Kernel sum64; + + /* Weird test kernels */ Kernel testTrig; Kernel testFFT4; - Kernel testFFT; - Kernel testFFT15; Kernel testFFT14; + Kernel testFFT15; + Kernel testFFT; Kernel testTime; // Kernel testKernel; @@ -155,40 +171,35 @@ class Gpu { // Copy of some -use options needed for Kernel, Trig, and Weights initialization bool tail_single_wide; // TailSquare processes one line at a time bool tail_single_kernel; // TailSquare does not use a separate kernel for line zero - u32 tail_trigs; // 0,1,2. Increasing values use more DP and less memory accesses - u32 pad_size; // Pad size in bytes + u32 in_place; // Should GPU perform transform in-place. 1 = nVidia friendly memory layout, 2 = AMD friendly. + u32 pad_size; // Pad size in bytes as specified on the command line or config.txt. Maximum value is 512. // Twiddles: trigonometry constant buffers, used in FFTs. // The twiddles depend only on FFT config and do not depend on the exponent. - TrigPtr bufTrigW; + // It is important to generate the height trigs before the width trigs because width trigs can be a subset of the height trigs TrigPtr bufTrigH; TrigPtr bufTrigM; + TrigPtr bufTrigW; + // Weights and the "bigWord bits" are only needed for FP64 and FP32 FFTs Weights weights; - - // The weights and the "bigWord bits" depend on the exponent. 
Buffer bufConstWeights; Buffer bufWeights; - Buffer bufBits; // bigWord bits aligned for CarryFused/fftP - Buffer bufBitsC; // bigWord bits aligned for CarryA/M // "integer word" buffers. These are "small buffers": N x int. - Buffer bufData; // Main int buffer with the words. - Buffer bufAux; // Auxiliary int buffer, used in transposing data in/out and in check. - Buffer bufCheck; // Buffers used with the error check. - Buffer bufBase; // used in P-1 error check. + Buffer bufData; // Main int buffer with the words. + Buffer bufAux; // Auxiliary int buffer, used in transposing data in/out and in check. + Buffer bufCheck; // Buffers used with the error check. // Carry buffers, used in carry and fusedCarry. Buffer bufCarry; // Carry shuttle. - Buffer bufReady; // Per-group ready flag for stairway carry propagation. // Small aux buffers. - Buffer bufSmallOut; + Buffer bufSmallOut; Buffer bufSumOut; Buffer bufTrue; - Buffer bufROE; // The round-off error ("ROE"), one float element per iteration. Buffer bufStatsCarry; @@ -207,46 +218,68 @@ class Gpu { TimeInfo* timeBufVect; ZAvg zAvg; - vector readOut(Buffer &buf); - void writeIn(Buffer& buf, vector&& words); - - void square(Buffer& out, Buffer& in, bool leadIn, bool leadOut, bool doMul3 = false, bool doLL = false); - void squareCERT(Buffer& io, bool leadIn, bool leadOut) { square(io, io, leadIn, leadOut, false, false); } - void squareLL(Buffer& io, bool leadIn, bool leadOut) { square(io, io, leadIn, leadOut, false, true); } - - void square(Buffer& io); - - u32 squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3); - u32 squareLoop(Buffer& io, u32 from, u32 to) { return squareLoop(io, io, from, to, false); } - - bool isEqual(Buffer& bufCheck, Buffer& bufAux); - u64 bufResidue(Buffer& buf); + int NUM_CACHE_GROUPS = 3; + + void fftP(Buffer& out, Buffer& in) { fftP(out, reinterpret_cast&>(in)); } + void fftP(Buffer& out, Buffer& in); + void fftMidIn(Buffer& out, Buffer& in, int cache_group = 0); + void fftMidOut(Buffer& out, Buffer& in, int cache_group = 0); + void fftHin(Buffer& out, Buffer& in); + void tailSquare(Buffer& out, Buffer& in, int cache_group = 0); + void tailMul(Buffer& out, Buffer& in1, Buffer& in2, int cache_group = 0); + void tailMulLow(Buffer& out, Buffer& in1, Buffer& in2, int cache_group = 0); + void fftW(Buffer& out, Buffer& in, int cache_group = 0); + void carryA(Buffer& out, Buffer& in) { carryA(reinterpret_cast&>(out), in); } + void carryA(Buffer& out, Buffer& in); + void carryM(Buffer& out, Buffer& in); + void carryLL(Buffer& out, Buffer& in); + void carryFused(Buffer& out, Buffer& in); + void carryFusedMul(Buffer& out, Buffer& in); + void carryFusedLL(Buffer& out, Buffer& in); + + vector readWords(Buffer &buf); + void writeWords(Buffer& buf, vector &words); + + vector readOut(Buffer &buf); + void writeIn(Buffer& buf, vector&& words); + + enum LEAD_TYPE {LEAD_NONE = 0, LEAD_WIDTH = 1, LEAD_MIDDLE = 2}; + + void square(Buffer& out, Buffer& in, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut, bool doMul3 = false, bool doLL = false); + void square(Buffer& io) { square(io, io, LEAD_NONE, LEAD_NONE, false, false); } + void squareCERT(Buffer& io, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut) { square(io, io, leadIn, leadOut, false, false); } + void squareLL(Buffer& io, enum LEAD_TYPE leadIn, enum LEAD_TYPE leadOut) { square(io, io, leadIn, leadOut, false, true); } + + u32 squareLoop(Buffer& out, Buffer& in, u32 from, u32 to, bool doTailMul3); + u32 squareLoop(Buffer& io, u32 from, u32 to) { return squareLoop(io, 
io, from, to, false); } + + bool isEqual(Buffer& bufCheck, Buffer& bufAux); + u64 bufResidue(Buffer& buf); vector writeBase(const vector &v); - void exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3); + void exponentiate(Buffer& bufInOut, u64 exp, Buffer& buf1, Buffer& buf2, Buffer& buf3); - void bottomHalf(Buffer& out, Buffer& inTmp); + void writeState(u32 k, const vector& check, u32 blockSize); - void writeState(const vector& check, u32 blockSize); - // does either carrryFused() or the expanded version depending on useLongCarry - void doCarry(Buffer& out, Buffer& in); + void doCarry(Buffer& out, Buffer& in, Buffer& tmp); - void mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3 = false); - void mul(Buffer& io, Buffer& inB); + void mul(Buffer& ioA, Buffer& inB, Buffer& tmp1, Buffer& tmp2, bool mul3 = false); + void mul(Buffer& io, Buffer& inB); + + void modMul(Buffer& ioA, Buffer& inB, bool mul3 = false); + void modMul(Buffer& ioA, Buffer& inB, enum LEAD_TYPE leadInB, bool mul3 = false); - void modMul(Buffer& ioA, Buffer& inB, bool mul3 = false); - fs::path saveProof(const Args& args, const ProofSet& proofSet); std::pair readROE(); RoeInfo readCarryStats(); - + u32 updateCarryPos(u32 bit); PRPState loadPRP(Saver& saver); - vector readChecked(Buffer& buf); + vector readChecked(Buffer& buf); // void measureTransferSpeed(); @@ -265,37 +298,23 @@ class Gpu { LLResult isPrimeLL(const Task& task); array isCERT(const Task& task); - double timePRP(); + double timePRP(int quick = 7); tuple measureROE(bool quick); tuple measureCarry(); Saver *getSaver(); - void carryA(Buffer& a, Buffer& b) { carryA(reinterpret_cast&>(a), b); } - - void carryA(Buffer& a, Buffer& b); - - void carryM(Buffer& a, Buffer& b); - - void carryLL(Buffer& a, Buffer& b); - - void carryFused(Buffer& a, Buffer& b); - - void carryFusedMul(Buffer& a, Buffer& b); - - void carryFusedLL(Buffer& a, Buffer& b) { kCarryFusedLL(a, b, updateCarryPos(1<<0));} - - void writeIn(Buffer& buf, const vector &words); + void writeIn(Buffer& buf, const vector &words); u64 dataResidue() { return bufResidue(bufData); } u64 checkResidue() { return bufResidue(bufCheck); } - + bool doCheck(u32 blockSize); void logTimeKernels(); - Words readAndCompress(Buffer& buf); + Words readAndCompress(Buffer& buf); vector readCheck(); vector readData(); @@ -309,11 +328,11 @@ class Gpu { Words expMul2(const Words& A, u64 h, const Words& B); // A:= A^h * B - void expMul(Buffer& A, u64 h, Buffer& B); - + void expMul(Buffer& A, u64 h, Buffer& B); + // return A^(2^n) Words expExp2(const Words& A, u32 n); - vector> makeBufVector(u32 size); + vector> makeBufVector(u32 size); void clear(bool isPRP); @@ -321,3 +340,16 @@ class Gpu { u32 getProofPower(u32 k); void doBigLog(u32 k, u64 res, bool checkOK, float secsPerIt, u32 nIters, u32 nErrors); }; + +// Compute the size of an FFT/NTT data buffer depending on the FFT/NTT float/prime. Size is returned in units of sizeof(double). +// Data buffers require extra space for padding. We can probably tighten up the amount of extra memory allocated. +// The worst case seems to be !INPLACE, MIDDLE=4, PAD_SIZE=512. + +#define MID_ADJUST(size,M,pad) ((pad == 0 || M != 4) ? (size) : (size) * 5/4) +#define PAD_ADJUST(N,M,inplace,pad) (inplace ? 3*N/2 : MID_ADJUST(pad == 0 ? N : pad <= 128 ? 9*N/8 : pad <= 256 ? 
5*N/4 : 3*N/2, M, pad)) +#define FP64_DATA_SIZE(W,M,H,inplace,pad) PAD_ADJUST(W*M*H*2, M, inplace, pad) +#define FP32_DATA_SIZE(W,M,H,inplace,pad) PAD_ADJUST(W*M*H*2, M, inplace, pad) * sizeof(float) / sizeof(double) +#define GF31_DATA_SIZE(W,M,H,inplace,pad) PAD_ADJUST(W*M*H*2, M, inplace, pad) * sizeof(uint) / sizeof(double) +#define GF61_DATA_SIZE(W,M,H,inplace,pad) PAD_ADJUST(W*M*H*2, M, inplace, pad) * sizeof(ulong) / sizeof(double) +#define TOTAL_DATA_SIZE(fft,W,M,H,inplace,pad) (int)fft.FFT_FP64 * FP64_DATA_SIZE(W,M,H,inplace,pad) + (int)fft.FFT_FP32 * FP32_DATA_SIZE(W,M,H,inplace,pad) + \ + (int)fft.NTT_GF31 * GF31_DATA_SIZE(W,M,H,inplace,pad) + (int)fft.NTT_GF61 * GF61_DATA_SIZE(W,M,H,inplace,pad) diff --git a/src/Proof.cpp b/src/Proof.cpp index 4f1d01e1..31c488f8 100644 --- a/src/Proof.cpp +++ b/src/Proof.cpp @@ -268,7 +268,7 @@ std::pair> ProofSet::computeProof(Gpu *gpu) const { auto hash = proof::hashWords(E, B); - vector> bufVect = gpu->makeBufVector(power); + vector> bufVect = gpu->makeBufVector(power); for (u32 p = 0; p < power; ++p) { auto bufIt = bufVect.begin(); diff --git a/src/Queue.cpp b/src/Queue.cpp index d2c1ceac..27f461ad 100644 --- a/src/Queue.cpp +++ b/src/Queue.cpp @@ -24,11 +24,13 @@ Queue::Queue(const Context& context, bool profile) : markerEvent{}, markerQueued(false), queueCount(0), - squareTime(50) + squareTime(50), + squareKernels(4), + firstSetTime(true) { // Formerly a constant (thus the CAPS). nVidia is 3% CPU load at 400 or 500, and 35% load at 800 on my Linux machine. // AMD is just over 2% load at 1600 and 3200 on the same Linux machine. Marginally better timings(?) at 3200. - MAX_QUEUE_COUNT = isAmdGpu(context.deviceId()) ? 3200 : 500; // Queue size for 800 or 125 squarings + MAX_QUEUE_COUNT = isAmdGpu(context.deviceId()) ? 3200 : 500; // Queue size for 800 or 125 squarings (if squareKernels = 4) } void Queue::writeTE(cl_mem buf, u64 size, const void* data, TimeInfo* tInfo) { @@ -58,7 +60,6 @@ void Queue::add(EventHolder&& e, TimeInfo* ti) { } void Queue::readSync(cl_mem buf, u32 size, void* out, TimeInfo* tInfo) { - queueMarkerEvent(); add(read(get(), {}, true, buf, size, out, hasEvents), tInfo); events.synced(); } @@ -91,7 +92,7 @@ void Queue::queueMarkerEvent() { } // Enqueue a marker for nVidia GPUs else { - clEnqueueMarkerWithWaitList(get(), 0, NULL, &markerEvent); + markerEvent = enqueueMarker(get()); markerQueued = true; queueCount = 0; } @@ -102,14 +103,18 @@ void Queue::waitForMarkerEvent() { if (!markerQueued) return; // By default, nVidia finish causes a CPU busy wait. Instead, sleep for a while. Since we know how many items are enqueued after the marker we can make an // educated guess of how long to sleep to keep CPU overhead low. - while (getEventInfo(markerEvent) != CL_COMPLETE) { - // There are 4 kernels per squaring. Don't overestimate sleep time. Divide by 10 instead of 4. - std::this_thread::sleep_for(std::chrono::microseconds(1 + queueCount * squareTime / 10)); + while (getEventInfo(markerEvent.get()) != CL_COMPLETE) { + // There are 4, 7, or 10 kernels per squaring. Don't overestimate sleep time. Divide by much more than the number of kernels. + std::this_thread::sleep_for(std::chrono::microseconds(1 + queueCount * squareTime / squareKernels / 2)); } markerQueued = false; } void Queue::setSquareTime(int time) { + if (firstSetTime) { // Ignore first setSquareTime call. 
First measured times are wrong because of startup costs + firstSetTime = false; + return; + } if (time < 30) time = 30; // Assume a minimum square time of 30us if (time > 3000) time = 3000; // Assume a maximum square time of 3000us squareTime = time; diff --git a/src/Queue.h b/src/Queue.h index 06efa159..be8341d6 100644 --- a/src/Queue.h +++ b/src/Queue.h @@ -50,14 +50,17 @@ class Queue : public QueueHolder { void copyBuf(cl_mem src, cl_mem dst, u32 size, TimeInfo* tInfo); void finish(); - void setSquareTime(int); // Set the time to do one squaring (in microseconds) + void setSquareTime(int); // Update the time to do one squaring (in microseconds) + void setSquareKernels(int n) { squareKernels = n; firstSetTime = true; } private: // This replaces the "call queue->finish every 400 squarings" code in Gpu.cpp. Solves the busy wait on nVidia GPUs. int MAX_QUEUE_COUNT; // Queue size before a marker will be enqueued. Typically, 100 to 1000 squarings. - cl_event markerEvent; // Event associated with an enqueued marker placed in the queue every MAX_QUEUE_COUNT entries and before r/w operations. + EventHolder markerEvent; // Event associated with an enqueued marker placed in the queue every MAX_QUEUE_COUNT entries and before r/w operations. bool markerQueued; // TRUE if a marker and event have been queued int queueCount; // Count of items added to the queue since last marker int squareTime; // Time to do one squaring (in microseconds) + int squareKernels; // Number of kernels in one squaring + bool firstSetTime; // Flag so we can ignore first setSquareTime call (which is inaccurate because of all the initial openCL compiles) void queueMarkerEvent(); // Queue the marker event void waitForMarkerEvent(); // Wait for marker event to complete }; diff --git a/src/Task.cpp b/src/Task.cpp index db36e782..d72f4db2 100644 --- a/src/Task.cpp +++ b/src/Task.cpp @@ -8,6 +8,7 @@ #include "Worktodo.h" #include "Saver.h" #include "version.h" +#include "Primes.h" #include "Proof.h" #include "log.h" #include "timeutil.h" @@ -40,7 +41,7 @@ constexpr int platform() { const constexpr bool IS_32BIT = (sizeof(void*) == 4); -#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) || defined(__MINGW32__) || defined(__MINGW64__) || defined(__MSYS__) return IS_32BIT ? WIN_32 : WIN_64; #elif __APPLE__ @@ -83,6 +84,13 @@ OsInfo getOsInfo() { return getOsInfoMinimum(); } #endif +string ffttype(FFTConfig fft) { + return fft.shape.fft_type == FFT64 ? "FP64" : fft.shape.fft_type == FFT3161 ? "M31+M61" : + fft.shape.fft_type == FFT61 ? "M61" : fft.shape.fft_type == FFT3261 ? "FP32+M61" : + fft.shape.fft_type == FFT31 ? "M31" : fft.shape.fft_type == FFT3231 ? "FP32+M31" : + fft.shape.fft_type == FFT32 ? "FP32" : fft.shape.fft_type == FFT6431 ? "FP64+M31" : fft.shape.fft_type == FFT323161 ? 
"FP32+M31+M61" : "unknown"; +} + string json(const vector& v) { bool isFirst = true; string s = "{"; @@ -145,12 +153,13 @@ void writeResult(u32 instance, u32 E, const char *workType, const string &status } -void Task::writeResultPRP(const Args &args, u32 instance, bool isPrime, u64 res64, const string& res2048, u32 fftSize, u32 nErrors, const fs::path& proofPath) const { +void Task::writeResultPRP(FFTConfig fft, const Args &args, u32 instance, bool isPrime, u64 res64, const string& res2048, u32 nErrors, const fs::path& proofPath) const { vector fields{json("res64", hex(res64)), json("res2048", res2048), json("residue-type", 1), json("errors", vector{json("gerbicz", nErrors)}), - json("fft-length", fftSize) + json("fft-type", ffttype(fft)), + json("fft-length", fft.size()) }; // "proof":{"version":1, "power":6, "hashsize":64, "md5":"0123456789ABCDEF"}, @@ -169,9 +178,10 @@ void Task::writeResultPRP(const Args &args, u32 instance, bool isPrime, u64 res6 writeResult(instance, exponent, "PRP-3", isPrime ? "P" : "C", AID, args, fields); } -void Task::writeResultLL(const Args &args, u32 instance, bool isPrime, u64 res64, u32 fftSize) const { +void Task::writeResultLL(FFTConfig fft, const Args &args, u32 instance, bool isPrime, u64 res64) const { vector fields{json("res64", hex(res64)), - json("fft-length", fftSize), + json("fft-type", ffttype(fft)), + json("fft-length", fft.size()), json("shift-count", 0), json("error-code", "00000000"), // I don't know the meaning of this }; @@ -179,13 +189,14 @@ void Task::writeResultLL(const Args &args, u32 instance, bool isPrime, u64 res64 writeResult(instance, exponent, "LL", isPrime ? "P" : "C", AID, args, fields); } -void Task::writeResultCERT(const Args &args, u32 instance, array hash, u32 squarings, u32 fftSize) const { +void Task::writeResultCERT(FFTConfig fft, const Args &args, u32 instance, array hash, u32 squarings) const { string hexhash = hex(hash[3]) + hex(hash[2]) + hex(hash[1]) + hex(hash[0]); vector fields{json("worktype", "Cert"), - json("exponent", exponent), - json("sha3-hash", hexhash.c_str()), - json("squarings", squarings), - json("fft-length", fftSize), + json("exponent", exponent), + json("sha3-hash", hexhash.c_str()), + json("squarings", squarings), + json("fft-type", ffttype(fft)), + json("fft-length", fft.size()), json("shift-count", 0), json("error-code", "00000000"), // I don't know the meaning of this }; @@ -201,6 +212,20 @@ void Task::execute(GpuCommon shared, Queue *q, u32 instance) { assert(exponent); + // Testing exponent 140000001 using FFT 512:15:512 fails with severe round off errors. + // I'm guessing this is because bot the exponent and FFT size are divisible by 3. + // Here we make sure the exponent is prime. If not we do not raise an error because it + // is very common to use command line argument "-prp some-random-exponent" to get a quick + // timing. Instead, we output a warning and test a smaller prime exponent. + { + Primes primes; + if (!primes.isPrime(exponent)) { + u32 new_exponent = primes.prevPrime(exponent); + log("Warning: Exponent %u is not prime. 
Using exponent %u instead.\n", exponent, new_exponent); + exponent = new_exponent; + } + } + LogContext pushContext(std::to_string(exponent)); FFTConfig fft = FFTConfig::bestFit(*shared.args, exponent, shared.args->fftSpec); @@ -218,11 +243,11 @@ void Task::execute(GpuCommon shared, Queue *q, u32 instance) { if (kind == PRP) { auto [tmpIsPrime, res64, nErrors, proofPath, res2048] = gpu->isPrimePRP(*this); isPrime = tmpIsPrime; - writeResultPRP(*shared.args, instance, isPrime, res64, res2048, fft.size(), nErrors, proofPath); + writeResultPRP(fft, *shared.args, instance, isPrime, res64, res2048, nErrors, proofPath); } else { // LL auto [tmpIsPrime, res64] = gpu->isPrimeLL(*this); isPrime = tmpIsPrime; - writeResultLL(*shared.args, instance, isPrime, res64, fft.size()); + writeResultLL(fft, *shared.args, instance, isPrime, res64); } Worktodo::deleteTask(*this, instance); @@ -234,7 +259,7 @@ void Task::execute(GpuCommon shared, Queue *q, u32 instance) { } } else if (kind == CERT) { auto sha256 = gpu->isCERT(*this); - writeResultCERT(*shared.args, instance, sha256, squarings, fft.size()); + writeResultCERT(fft, *shared.args, instance, sha256, squarings); Worktodo::deleteTask(*this, instance); } else { throw "Unexpected task kind " + to_string(kind); diff --git a/src/Task.h b/src/Task.h index 6d870263..95f08024 100644 --- a/src/Task.h +++ b/src/Task.h @@ -5,6 +5,7 @@ #include "Args.h" #include "common.h" #include "GpuCommon.h" +#include "FFTConfig.h" #include @@ -27,7 +28,7 @@ class Task { string verifyPath; // For Verify void execute(GpuCommon shared, Queue* q, u32 instance); - void writeResultPRP(const Args&, u32 instance, bool isPrime, u64 res64, const std::string& res2048, u32 fftSize, u32 nErrors, const fs::path& proofPath) const; - void writeResultLL(const Args&, u32 instance, bool isPrime, u64 res64, u32 fftSize) const; - void writeResultCERT(const Args&, u32 instance, array hash, u32 squarings, u32 fftSize) const; + void writeResultPRP(FFTConfig fft, const Args&, u32 instance, bool isPrime, u64 res64, const std::string& res2048, u32 nErrors, const fs::path& proofPath) const; + void writeResultLL(FFTConfig fft, const Args&, u32 instance, bool isPrime, u64 res64) const; + void writeResultCERT(FFTConfig fft, const Args&, u32 instance, array hash, u32 squarings) const; }; diff --git a/src/TrigBufCache.cpp b/src/TrigBufCache.cpp index 7646a724..03c8e1e6 100644 --- a/src/TrigBufCache.cpp +++ b/src/TrigBufCache.cpp @@ -1,9 +1,10 @@ - // Copyright Mihai Preda +// Copyright Mihai Preda +#include #include "TrigBufCache.h" #define SAVE_ONE_MORE_WIDTH_MUL 0 // I want to make saving the only option -- but rocm optimizer is inexplicably making it slower in carryfused -#define SAVE_ONE_MORE_HEIGHT_MUL 1 // In tailSquar this is the fastest option +#define SAVE_ONE_MORE_HEIGHT_MUL 1 // In tailSquare this is the fastest option #define _USE_MATH_DEFINES #include @@ -119,7 +120,6 @@ double root1cosover(u32 N, u32 k, double over) { return double(c / over); } -namespace { static const constexpr bool LOG_TRIG_ALLOC = false; // Interleave two lines of trig values so that AMD GPUs can use global_load_dwordx4 instructions @@ -136,8 +136,8 @@ void T2shuffle(u32 size, u32 radix, u32 line, vector &tab) { } } -vector genSmallTrig(u32 size, u32 radix) { - if (LOG_TRIG_ALLOC) { log("genSmallTrig(%u, %u)\n", size, radix); } +vector genSmallTrigFP64(u32 size, u32 radix) { + if (LOG_TRIG_ALLOC) { log("genSmallTrigFP64(%u, %u)\n", size, radix); } u32 WG = size / radix; vector tab; @@ -221,10 +221,12 @@ vector 
genSmallTrig(u32 size, u32 radix) { } // Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. -vector genSmallTrigCombo(u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide, u32 tail_trigs) { - if (LOG_TRIG_ALLOC) { log("genSmallTrigCombo(%u, %u)\n", size, radix); } +vector genSmallTrigComboFP64(Args *args, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { + if (LOG_TRIG_ALLOC) { log("genSmallTrigComboFP64(%u, %u)\n", size, radix); } - vector tab = genSmallTrig(size, radix); + vector tab = genSmallTrigFP64(size, radix); + + u32 tail_trigs = args->value("TAIL_TRIGS", 2); // Default is calculating from scratch, no memory accesses // From tailSquare pre-calculate some or all of these: T2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); if (tail_trigs == 1) { // Some trig values in memory, some are computed with a complex multiply. Best option on a Radeon VII. @@ -234,14 +236,14 @@ vector genSmallTrigCombo(u32 width, u32 middle, u32 size, u32 radix, bo tab.push_back(root1(width * middle * height, width * middle * me)); } // Output the one or two T2 multipliers to be read by one u,v pair of lines - for (u32 line = 0; line < width * middle / 2; ++line) { + for (u32 line = 0; line <= width * middle / 2; ++line) { tab.push_back(root1Fancy(width * middle * height, line)); if (!tail_single_wide) tab.push_back(root1Fancy(width * middle * height, line ? width * middle - line : width * middle / 2)); } } if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with lousy DP performance. u32 height = size; - for (u32 u = 0; u < width * middle / 2; ++u) { + for (u32 u = 0; u <= width * middle / 2; ++u) { for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { u32 line = (v == 0) ? u : (u ? width * middle - u : width * middle / 2); for (u32 me = 0; me < height / radix; ++me) { @@ -258,8 +260,8 @@ vector genSmallTrigCombo(u32 width, u32 middle, u32 size, u32 radix, bo // cos-1 "fancy" trick. 
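As an aside on the "cos - 1" trick referenced above: for very small twiddle angles, cos(angle) is so close to 1 that rounding it to the working precision throws the correction away entirely, while storing cos(angle) - 1 keeps it with full relative accuracy. A minimal sketch of the effect at float precision, mirroring root1FancyFP32 below (illustrative only, not part of the patch):

// Shows cos(angle) rounding to exactly 1.0f while cos(angle) - 1 keeps roughly -angle^2/2.
#include <cmath>
#include <cstdio>

int main() {
  const double kPi = 3.14159265358979323846;
  const unsigned N = 1u << 22;        // hypothetical transform size
  const unsigned k = 1;               // smallest nonzero twiddle index
  double angle = kPi * k / (N / 2);   // same angle convention as root1FancyFP32

  float plain = float(std::cos(angle));      // correction lost: prints 1
  float fancy = float(std::cos(angle) - 1);  // correction kept, about -angle^2/2
  std::printf("cos(angle) as float     = %.9g\n", plain);
  std::printf("cos(angle) - 1 as float = %.9g\n", fancy);
  std::printf("-angle^2/2              = %.9g\n", -angle * angle / 2);
  return 0;
}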
#define SHARP_MIDDLE 5 -vector genMiddleTrig(u32 smallH, u32 middle, u32 width) { - if (LOG_TRIG_ALLOC) { log("genMiddleTrig(%u, %u, %u)\n", smallH, middle, width); } +vector genMiddleTrigFP64(u32 smallH, u32 middle, u32 width) { + if (LOG_TRIG_ALLOC) { log("genMiddleTrigFP64(%u, %u, %u)\n", smallH, middle, width); } vector tab; if (middle == 1) { tab.resize(1); @@ -267,63 +269,662 @@ vector genMiddleTrig(u32 smallH, u32 middle, u32 width) { if (middle < SHARP_MIDDLE) { for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(smallH * middle, k)); } for (u32 k = 0; k < width; ++k) { tab.push_back(root1(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(width * middle * smallH, k)); } } else { for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1Fancy(smallH * middle, k)); } for (u32 k = 0; k < width; ++k) { tab.push_back(root1Fancy(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1(width * middle * smallH, k)); } + } + } + return tab; +} + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on floats */ +/**************************************************************************/ + +// For small angles, return "fancy" cos - 1 for increased precision +float2 root1FancyFP32(u32 N, u32 k) { + assert(!(N&7)); + assert(k < N); + assert(k < N/4); + + double angle = M_PI * k / (N / 2); + return {float(cos(angle) - 1), float(sin(angle))}; +} + +static float trigNorm(float c, float s) { return c * c + s * s; } +static float trigError(float c, float s) { return abs(trigNorm(c, s) - 1.0f); } + +// Round trig double to float as to satisfy c^2 + s^2 == 1 as best as possible +static float2 roundTrig(double lc, double ls) { + float c1 = lc; + float c2 = nexttoward(c1, lc); + float s1 = ls; + float s2 = nexttoward(s1, ls); + + float c = c1; + float s = s1; + for (float tryC : {c1, c2}) { + for (float tryS : {s1, s2}) { + if (trigError(tryC, tryS) < trigError(c, s)) { + c = tryC; + s = tryS; + } + } + } + return {c, s}; +} + +// Returns the primitive root of unity of order N, to the power k. +float2 root1FP32(u32 N, u32 k) { + assert(k < N); + if (k >= N/2) { + auto [c, s] = root1FP32(N, k - N/2); + return {-c, -s}; + } else if (k > N/4) { + auto [c, s] = root1FP32(N, N/2 - k); + return {-c, s}; + } else if (k > N/8) { + auto [c, s] = root1FP32(N, N/4 - k); + return {s, c}; + } else { + assert(k <= N/8); + + double angle = M_PI * k / (N / 2); + return roundTrig(cos(angle), sin(angle)); + } +} + +vector genSmallTrigFP32(u32 size, u32 radix) { + u32 WG = size / radix; + vector tab; + +// old fft_WIDTH and fft_HEIGHT + for (u32 line = 1; line < radix; ++line) { + for (u32 col = 0; col < WG; ++col) { + tab.push_back(radix / line >= 8 ? root1FancyFP32(size, col * line) : root1FP32(size, col * line)); + } + } + tab.resize(size); + return tab; +} + +// Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. +vector genSmallTrigComboFP32(Args *args, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { + vector tab = genSmallTrigFP32(size, radix); + + u32 tail_trigs = args->value("TAIL_TRIGS32", 2); // Default is calculating from scratch, no memory accesses + + // From tailSquare pre-calculate some or all of these: F2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); + if (tail_trigs == 1) { // Some trig values in memory, some are computed with a complex multiply. 
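The tail_trigs == 1 variant above ("computed with a complex multiply") rests on the usual root-of-unity identity: the twiddle for index a + b is the product of the twiddles for a and b, which is why only the line-0 values plus one small per-line multiplier (stored in the fancy cos - 1 form) need to be kept. A self-contained check of that identity, with made-up indices (not the patch's code):

// root1(N, a) * root1(N, b) == root1(N, a + b), up to rounding.
#include <cassert>
#include <cmath>
#include <complex>

static std::complex<double> root1(unsigned N, unsigned k) {  // e^(2*pi*i*k/N); the sign convention does not matter for the identity
  const double kPi = 3.14159265358979323846;
  double angle = 2 * kPi * k / N;
  return {std::cos(angle), std::sin(angle)};
}

int main() {
  const unsigned N = 4096 * 16;   // stands in for width * middle * height
  const unsigned a = 16 * 37;     // a "line 0" style index: width * middle * me
  const unsigned b = 5;           // a per-line index
  std::complex<double> combined = root1(N, a) * root1(N, b);
  std::complex<double> direct = root1(N, a + b);
  assert(std::abs(combined - direct) < 1e-12);
  return 0;
}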
+ u32 height = size; + // Output line 0 trig values to be read by every u,v pair of lines + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1FP32(width * middle * height, width * middle * me)); + } + // Output the one or two F2 multipliers to be read by one u,v pair of lines + for (u32 line = 0; line <= width * middle / 2; ++line) { + tab.push_back(root1FancyFP32(width * middle * height, line)); + if (!tail_single_wide) tab.push_back(root1FancyFP32(width * middle * height, line ? width * middle - line : width * middle / 2)); + } + } + if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with lousy FP performance? + u32 height = size; + for (u32 u = 0; u <= width * middle / 2; ++u) { + for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { + u32 line = (v == 0) ? u : (u ? width * middle - u : width * middle / 2); + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1FP32(width * middle * height, line + width * middle * me)); + } + } + } + } + + return tab; +} + +vector genMiddleTrigFP32(u32 smallH, u32 middle, u32 width) { + vector tab; + if (middle == 1) { + tab.resize(1); + } else { + if (middle < SHARP_MIDDLE) { + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1FP32(smallH * middle, k)); } + for (u32 k = 0; k < width; ++k) { tab.push_back(root1FP32(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1FP32(width * middle * smallH, k)); } + } else { + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1FancyFP32(smallH * middle, k)); } + for (u32 k = 0; k < width; ++k) { tab.push_back(root1FancyFP32(middle * width, k)); } + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1FP32(width * middle * smallH, k)); } + } + } + return tab; +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +// Z31 and GF31 code copied from Yves Gallot's mersenne2 program + +// Z/{2^31 - 1}Z: the prime field of order p = 2^31 - 1 +class Z31 +{ +private: + static const uint32_t _p = (uint32_t(1) << 31) - 1; + uint32_t _n; // 0 <= n < p + + static uint32_t _add(const uint32_t a, const uint32_t b) + { + const uint32_t t = a + b; + return t - ((t >= _p) ? _p : 0); + } + + static uint32_t _sub(const uint32_t a, const uint32_t b) + { + const uint32_t t = a - b; + return t + ((a < b) ? _p : 0); + } + + static uint32_t _mul(const uint32_t a, const uint32_t b) + { + const uint64_t t = a * uint64_t(b); + return _add(uint32_t(t) & _p, uint32_t(t >> 31)); + } + +public: + Z31() {} + explicit Z31(const uint32_t n) : _n(n) {} + + uint32_t get() const { return _n; } + + bool operator!=(const Z31 & rhs) const { return (_n != rhs._n); } + + // Z31 neg() const { return Z31((_n == 0) ? 0 : _p - _n); } + // Z31 half() const { return Z31(((_n % 2 == 0) ? _n : (_n + _p)) / 2); } + + Z31 operator+(const Z31 & rhs) const { return Z31(_add(_n, rhs._n)); } + Z31 operator-(const Z31 & rhs) const { return Z31(_sub(_n, rhs._n)); } + Z31 operator*(const Z31 & rhs) const { return Z31(_mul(_n, rhs._n)); } + + Z31 sqr() const { return Z31(_mul(_n, _n)); } +}; + + +// GF((2^31 - 1)^2): the prime field of order p^2, p = 2^31 - 1 +class GF31 +{ +private: + Z31 _s0, _s1; + // a primitive root of order 2^32 which is a root of (0, 1). 
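A short aside on the reduction inside Z31::_mul above (the Z61 version further down works the same way, just with 61 bits): because p = 2^31 - 1 is a Mersenne prime, 2^31 mod p = 1, so the full product t = hi * 2^31 + lo reduces to hi + lo followed by at most one subtraction of p. An illustrative check against a plain % reduction (not part of the patch):

// Checks the Mersenne reduction used by Z31::_mul against a direct modulo.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t p = (uint32_t(1) << 31) - 1;
  const uint64_t samples[][2] = {{p - 1, p - 1}, {123456789u, 987654321u}, {p - 2, 3u}, {0u, p - 1}};
  for (const auto& s : samples) {
    uint64_t t = s[0] * s[1];          // fits in 62 bits since both factors are < 2^31
    uint32_t lo = uint32_t(t) & p;     // low 31 bits
    uint32_t hi = uint32_t(t >> 31);   // higher bits carry weight 2^31, which is 1 mod p
    uint32_t r = lo + hi;              // at most 2^32 - 2, so no uint32_t overflow
    if (r >= p) r -= p;
    assert(r == uint32_t(t % p));
  }
  return 0;
}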
+ static const uint64_t _h_order = uint64_t(1) << 32; + static const uint32_t _h_0 = 7735u, _h_1 = 748621u; + +public: + GF31() {} + explicit GF31(const Z31 & s0, const Z31 & s1) : _s0(s0), _s1(s1) {} + explicit GF31(const uint32_t n0, const uint32_t n1) : _s0(n0), _s1(n1) {} + + const Z31 & s0() const { return _s0; } + const Z31 & s1() const { return _s1; } + + GF31 operator+(const GF31 & rhs) const { return GF31(_s0 + rhs._s0, _s1 + rhs._s1); } + GF31 operator-(const GF31 & rhs) const { return GF31(_s0 - rhs._s0, _s1 - rhs._s1); } + + GF31 sqr() const { const Z31 t = _s0 * _s1; return GF31(_s0.sqr() - _s1.sqr(), t + t); } + GF31 mul(const GF31 & rhs) const { return GF31(_s0 * rhs._s0 - _s1 * rhs._s1, _s1 * rhs._s0 + _s0 * rhs._s1); } + + GF31 pow(const uint64_t e) const + { + if (e == 0) return GF31(1u, 0u); + GF31 r = GF31(1u, 0u), y = *this; + for (uint64_t i = e; i != 1; i /= 2) { if (i % 2 != 0) r = r.mul(y); y = y.sqr(); } + return r.mul(y); + } + + static const GF31 root_one(const size_t n) { return GF31(Z31(_h_0), Z31(_h_1)).pow(_h_order / n); } + static uint8_t log2_root_two(const size_t n) { return uint8_t(((uint64_t(1) << 30) / n) % 31); } +}; + +// Returns the primitive root of unity of order N, to the power k. +uint2 root1GF31(GF31 root1N, u32 k) { + GF31 x = root1N.pow(k); + return { x.s0().get(), x.s1().get() }; +} +uint2 root1GF31(u32 N, u32 k) { + assert(k < N); + GF31 root1N = GF31::root_one(N); + return root1GF31(root1N, k); +} + +vector genSmallTrigGF31(u32 size, u32 radix) { + u32 WG = size / radix; + vector tab; + + GF31 root1size = GF31::root_one(size); + for (u32 line = 1; line < radix; ++line) { + for (u32 col = 0; col < WG; ++col) { + tab.push_back(root1GF31(root1size, col * line)); + } + } + tab.resize(size); + return tab; +} + +// Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. +vector genSmallTrigComboGF31(Args *args, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { + vector tab = genSmallTrigGF31(size, radix); + + u32 tail_trigs = args->value("TAIL_TRIGS31", 0); // Default is reading all trigs from memory + + // From tailSquareGF31 pre-calculate some or all of these: GF31 trig = slowTrigGF31(line + H * lowMe, ND / NH * 2); + u32 height = size; + GF31 root1wmh = GF31::root_one(width * middle * height); + if (tail_trigs >= 1) { // Some trig values in memory, some are computed with a complex multiply. Best option on a Radeon VII. + // Output line 0 trig values to be read by every u,v pair of lines + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1GF31(root1wmh, width * middle * me)); + } + // Output the one or two GF31 multipliers to be read by one u,v pair of lines + for (u32 line = 0; line <= width * middle / 2; ++line) { + tab.push_back(root1GF31(root1wmh, line)); + if (!tail_single_wide) tab.push_back(root1GF31(root1wmh, line ? width * middle - line : width * middle / 2)); + } + } + if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with great memory performance. + for (u32 u = 0; u <= width * middle / 2; ++u) { + for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { + u32 line = (v == 0) ? u : (u ? 
width * middle - u : width * middle / 2); + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1GF31(root1wmh, line + width * middle * me)); + } + } + } + } + + return tab; +} + +vector genMiddleTrigGF31(u32 smallH, u32 middle, u32 width) { + vector tab; + if (middle == 1) { + tab.resize(1); + } else { + GF31 root1hm = GF31::root_one(smallH * middle); + for (u32 m = 1; m < middle; ++m) { + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF31(root1hm, k * m)); } + } + GF31 root1mw = GF31::root_one(middle * width); + for (u32 k = 0; k < width; ++k) { tab.push_back(root1GF31(root1mw, k)); } + GF31 root1wmh = GF31::root_one(width * middle * smallH); + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF31(root1wmh, k)); } + } + return tab; +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +// Z61 and GF61 code copied from Yves Gallot's mersenne2 program + +// Z/{2^61 - 1}Z: the prime field of order p = 2^61 - 1 +class Z61 +{ +private: + static const uint64_t _p = (uint64_t(1) << 61) - 1; + uint64_t _n; // 0 <= n < p + + static uint64_t _add(const uint64_t a, const uint64_t b) + { + const uint64_t t = a + b; + return t - ((t >= _p) ? _p : 0); + } + + static uint64_t _sub(const uint64_t a, const uint64_t b) + { + const uint64_t t = a - b; + return t + ((a < b) ? _p : 0); + } + + static uint64_t _mul(const uint64_t a, const uint64_t b) + { + const __uint128_t t = a * __uint128_t(b); + const uint64_t lo = uint64_t(t), hi = uint64_t(t >> 64); + const uint64_t lo61 = lo & _p, hi61 = (lo >> 61) | (hi << 3); + return _add(lo61, hi61); + } + +public: + Z61() {} + explicit Z61(const uint64_t n) : _n(n) {} + + uint64_t get() const { return _n; } + + bool operator!=(const Z61 & rhs) const { return (_n != rhs._n); } + + Z61 operator+(const Z61 & rhs) const { return Z61(_add(_n, rhs._n)); } + Z61 operator-(const Z61 & rhs) const { return Z61(_sub(_n, rhs._n)); } + Z61 operator*(const Z61 & rhs) const { return Z61(_mul(_n, rhs._n)); } + + Z61 sqr() const { return Z61(_mul(_n, _n)); } +}; + +// GF((2^61 - 1)^2): the prime field of order p^2, p = 2^61 - 1 +class GF61 +{ +private: + Z61 _s0, _s1; + // Primitive root of order 2^62 which is a root of (0, 1). This root corresponds to 2*pi*i*j/N in FFTs. PRPLL FFTs use this root. Thanks, Yves! + static const uint64_t _h_0 = 264036120304204ull, _h_1 = 4677669021635377ull; + // Primitive root of order 2^62 which is a root of (0, -1). This root corresponds to -2*pi*i*j/N in FFTs. 
+ //static const uint64_t _h_0 = 481139922016222ull, _h_1 = 814659809902011ull; + static const uint64_t _h_order = uint64_t(1) << 62; + +public: + GF61() {} + explicit GF61(const Z61 & s0, const Z61 & s1) : _s0(s0), _s1(s1) {} + explicit GF61(const uint64_t n0, const uint64_t n1) : _s0(n0), _s1(n1) {} + + const Z61 & s0() const { return _s0; } + const Z61 & s1() const { return _s1; } + + GF61 operator+(const GF61 & rhs) const { return GF61(_s0 + rhs._s0, _s1 + rhs._s1); } + GF61 operator-(const GF61 & rhs) const { return GF61(_s0 - rhs._s0, _s1 - rhs._s1); } + + GF61 sqr() const { const Z61 t = _s0 * _s1; return GF61(_s0.sqr() - _s1.sqr(), t + t); } + GF61 mul(const GF61 & rhs) const { return GF61(_s0 * rhs._s0 - _s1 * rhs._s1, _s1 * rhs._s0 + _s0 * rhs._s1); } + + GF61 pow(const uint64_t e) const + { + if (e == 0) return GF61(1u, 0u); + GF61 r = GF61(1u, 0u), y = *this; + for (uint64_t i = e; i != 1; i /= 2) { if (i % 2 != 0) r = r.mul(y); y = y.sqr(); } + return r.mul(y); + } + + static const GF61 root_one(const size_t n) { return GF61(Z61(_h_0), Z61(_h_1)).pow(_h_order / n); } + static uint8_t log2_root_two(const size_t n) { return uint8_t(((uint64_t(1) << 60) / n) % 61); } +}; + +// Returns the primitive root of unity of order N, to the power k. +ulong2 root1GF61(GF61 root1N, u32 k) { + GF61 x = root1N.pow(k); + return { x.s0().get(), x.s1().get() }; +} +ulong2 root1GF61(u32 N, u32 k) { + assert(k < N); + GF61 root1N = GF61::root_one(N); + return root1GF61(root1N, k); +} + +vector genSmallTrigGF61(u32 size, u32 radix) { + u32 WG = size / radix; + vector tab; + + GF61 root1size = GF61::root_one(size); + for (u32 line = 1; line < radix; ++line) { + for (u32 col = 0; col < WG; ++col) { + tab.push_back(root1GF61(root1size, col * line)); + } + } + tab.resize(size); + return tab; +} + +// Generate the small trig values for fft_HEIGHT plus optionally trig values used in pairSq. +vector genSmallTrigComboGF61(Args *args, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { + vector tab = genSmallTrigGF61(size, radix); + + u32 tail_trigs = args->value("TAIL_TRIGS61", 0); // Default is reading all trigs from memory + + // From tailSquareGF61 pre-calculate some or all of these: GF61 trig = slowTrigGF61(line + H * lowMe, ND / NH * 2); + u32 height = size; + GF61 root1wmh = GF61::root_one(width * middle * height); + if (tail_trigs >= 1) { // Some trig values in memory, some are computed with a complex multiply. Best option on a Radeon VII. + // Output line 0 trig values to be read by every u,v pair of lines + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1GF61(root1wmh, width * middle * me)); + } + // Output the one or two GF61 multipliers to be read by one u,v pair of lines + for (u32 line = 0; line <= width * middle / 2; ++line) { + tab.push_back(root1GF61(root1wmh, line)); + if (!tail_single_wide) tab.push_back(root1GF61(root1wmh, line ? width * middle - line : width * middle / 2)); + } + } + if (tail_trigs == 0) { // All trig values read from memory. Best option for GPUs with great memory performance. + for (u32 u = 0; u <= width * middle / 2; ++u) { + for (u32 v = 0; v < (tail_single_wide ? 1 : 2); ++v) { + u32 line = (v == 0) ? u : (u ? 
width * middle - u : width * middle / 2); + for (u32 me = 0; me < height / radix; ++me) { + tab.push_back(root1GF61(root1wmh, line + width * middle * me)); + } + } } } + + return tab; +} + +vector genMiddleTrigGF61(u32 smallH, u32 middle, u32 width) { + vector tab; + if (middle == 1) { + tab.resize(1); + } else { + GF61 root1hm = GF61::root_one(smallH * middle); + for (u32 m = 1; m < middle; ++m) { + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF61(root1hm, k * m)); } + } + GF61 root1mw = GF61::root_one(middle * width); + for (u32 k = 0; k < width; ++k) { tab.push_back(root1GF61(root1mw, k)); } + GF61 root1wmh = GF61::root_one(width * middle * smallH); + for (u32 k = 0; k < smallH; ++k) { tab.push_back(root1GF61(root1wmh, k)); } + } + return tab; +} + + +/**********************************************************/ +/* Build all the needed trig values into one big buffer */ +/**********************************************************/ + +vector genSmallTrig(FFTConfig fft, u32 size, u32 radix) { + vector tab; + u32 tabsize; + + if (fft.FFT_FP64) { + tab = genSmallTrigFP64(size, radix); + tab.resize(SMALLTRIG_FP64_SIZE(size, 0, 0, 0)); + } + + if (fft.FFT_FP32) { + vector tab1 = genSmallTrigFP32(size, radix); + tab1.resize(SMALLTRIG_FP32_SIZE(size, 0, 0, 0)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); + } + + if (fft.NTT_GF31) { + vector tab2 = genSmallTrigGF31(size, radix); + tab2.resize(SMALLTRIG_GF31_SIZE(size, 0, 0, 0)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); + } + + if (fft.NTT_GF61) { + vector tab3 = genSmallTrigGF61(size, radix); + tab3.resize(SMALLTRIG_GF61_SIZE(size, 0, 0, 0)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); + } + + return tab; +} + +vector genSmallTrigCombo(Args *args, FFTConfig fft, u32 width, u32 middle, u32 size, u32 radix, bool tail_single_wide) { + vector tab; + u32 tabsize; + + if (fft.FFT_FP64) { + tab = genSmallTrigComboFP64(args, width, middle, size, radix, tail_single_wide); + tab.resize(SMALLTRIGCOMBO_FP64_SIZE(width, middle, size, radix)); + } + + if (fft.FFT_FP32) { + vector tab1 = genSmallTrigComboFP32(args, width, middle, size, radix, tail_single_wide); + tab1.resize(SMALLTRIGCOMBO_FP32_SIZE(width, middle, size, radix)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); + } + + if (fft.NTT_GF31) { + vector tab2 = genSmallTrigComboGF31(args, width, middle, size, radix, tail_single_wide); + tab2.resize(SMALLTRIGCOMBO_GF31_SIZE(width, middle, size, radix)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); + } + + if (fft.NTT_GF61) { + vector tab3 = genSmallTrigComboGF61(args, width, middle, size, radix, tail_single_wide); + tab3.resize(SMALLTRIGCOMBO_GF61_SIZE(width, middle, size, radix)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); + } + return tab; 
} -} // namespace +vector genMiddleTrig(FFTConfig fft, u32 smallH, u32 middle, u32 width) { + vector tab; + u32 tabsize; + + if (fft.FFT_FP64) { + tab = genMiddleTrigFP64(smallH, middle, width); + tab.resize(MIDDLETRIG_FP64_SIZE(width, middle, smallH)); + } + + if (fft.FFT_FP32) { + vector tab1 = genMiddleTrigFP32(smallH, middle, width); + tab1.resize(MIDDLETRIG_FP32_SIZE(width, middle, smallH)); + // Append tab1 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab1.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab1.data(), tab1.size() * 2 * sizeof(float)); + } + + if (fft.NTT_GF31) { + vector tab2 = genMiddleTrigGF31(smallH, middle, width); + tab2.resize(MIDDLETRIG_GF31_SIZE(width, middle, smallH)); + // Append tab2 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab2.size() / 2); + memcpy((double *) tab.data() + tabsize * 2, tab2.data(), tab2.size() * 2 * sizeof(uint)); + } + + if (fft.NTT_GF61) { + vector tab3 = genMiddleTrigGF61(smallH, middle, width); + tab3.resize(MIDDLETRIG_GF61_SIZE(width, middle, smallH)); + // Append tab3 to tab + tabsize = tab.size(); + tab.resize(tabsize + tab3.size()); + memcpy((double *) tab.data() + tabsize * 2, tab3.data(), tab3.size() * 2 * sizeof(ulong)); + } + + return tab; +} + + +/********************************************************/ +/* Code to manage a cache of trigBuffers */ +/********************************************************/ + +#define make_key_part(b,tt,b31,tt31,b32,tt32,b61,tt61,tk) ((((((((b+tt) << 2) + b31+tt31) << 2) + b32+tt32) << 2) + b61+tt61) << 2) + tk TrigBufCache::~TrigBufCache() = default; -TrigPtr TrigBufCache::smallTrig(u32 W, u32 nW) { +TrigPtr TrigBufCache::smallTrig(Args *args, FFTConfig fft, u32 width, u32 nW, u32 middle, u32 height, u32 nH, bool tail_single_wide) { lock_guard lock{mut}; auto& m = small; - decay_t::key_type key{W, nW, 0, 0, 0, 0}; - TrigPtr p{}; - auto it = m.find(key); - if (it == m.end() || !(p = it->second.lock())) { - p = make_shared(context, genSmallTrig(W, nW)); - m[key] = p; - smallCache.add(p); + + u32 tail_trigs = args->value("TAIL_TRIGS", 2); // Default is calculating FP64 trigs from scratch, no memory accesses + u32 tail_trigs31 = args->value("TAIL_TRIGS31", 2); // Default is reading GF31 trigs from memory + u32 tail_trigs32 = args->value("TAIL_TRIGS32", 2); // Default is calculating FP32 trigs from scratch, no memory accesses + u32 tail_trigs61 = args->value("TAIL_TRIGS61", 2); // Default is reading GF61 trigs from memory + u32 key_part = make_key_part(fft.FFT_FP64, tail_trigs, fft.NTT_GF31, tail_trigs31, fft.FFT_FP32, tail_trigs32, fft.NTT_GF61, tail_trigs61, tail_single_wide); + + // See if there is an existing smallTrigCombo that we can return (using only a subset of the data) + // In theory, we could match any smallTrigCombo where width matches. However, SMALLTRIG_GF31_SIZE wouldn't be able to figure out the size. + // In practice, those cases will likely never arise. 
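+  // This works because genSmallTrigCombo* lays out the plain small-trig table first and only
+  // appends the extra tailSquare trig values after it, so when width == height (and the radix
+  // matches) the WIDTH kernels can read the leading portion of a HEIGHT combo buffer unchanged.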
+ if (width == height && nW == nH) { + decay_t::key_type key{height, nH, width, middle, key_part}; + auto it = m.find(key); + if (it != m.end() && (p = it->second.lock())) return p; } + + // See if there is an existing non-combo smallTrig that we can return + decay_t::key_type key{width, nW, 0, 0, key_part}; + auto it = m.find(key); + if (it != m.end() && (p = it->second.lock())) return p; + + // Create a new non-combo + p = make_shared(context, genSmallTrig(fft, width, nW)); + m[key] = p; + smallCache.add(p); return p; } -TrigPtr TrigBufCache::smallTrigCombo(u32 width, u32 middle, u32 W, u32 nW, u32 variant, bool tail_single_wide, u32 tail_trigs) { - if (tail_trigs == 2) // No pre-computed trig values. We might be able to share this trig table with fft_WIDTH - return smallTrig(W, nW); +TrigPtr TrigBufCache::smallTrigCombo(Args *args, FFTConfig fft, u32 width, u32 middle, u32 height, u32 nH, bool tail_single_wide) { + u32 tail_trigs = args->value("TAIL_TRIGS", 2); // Default is calculating FP64 trigs from scratch, no memory accesses + u32 tail_trigs31 = args->value("TAIL_TRIGS31", 2); // Default is reading GF31 trigs from memory + u32 tail_trigs32 = args->value("TAIL_TRIGS32", 2); // Default is calculating FP32 trigs from scratch, no memory accesses + u32 tail_trigs61 = args->value("TAIL_TRIGS61", 2); // Default is reading GF61 trigs from memory + u32 key_part = make_key_part(fft.FFT_FP64, tail_trigs, fft.NTT_GF31, tail_trigs31, fft.FFT_FP32, tail_trigs32, fft.NTT_GF61, tail_trigs61, tail_single_wide); + + // If there are no pre-computed trig values we might be able to share this trig table with fft_WIDTH + if (((tail_trigs == 2 && fft.FFT_FP64) || (tail_trigs32 == 2 && fft.FFT_FP32)) && !fft.NTT_GF31 && !fft.NTT_GF61) + return smallTrig(args, fft, height, nH, middle, height, nH, tail_single_wide); lock_guard lock{mut}; auto& m = small; - decay_t::key_type key1{W, nW, width, middle, tail_single_wide, tail_trigs}; - // We write the "combo" under two keys, so it can also be retrieved as non-combo by smallTrig() - decay_t::key_type key2{W, nW, 0, 0, 0, 0}; + decay_t::key_type key{height, nH, width, middle, key_part}; TrigPtr p{}; - auto it = m.find(key1); + auto it = m.find(key); if (it == m.end() || !(p = it->second.lock())) { - p = make_shared(context, genSmallTrigCombo(width, middle, W, nW, tail_single_wide, tail_trigs)); - m[key1] = p; - m[key2] = p; + p = make_shared(context, genSmallTrigCombo(args, fft, width, middle, height, nH, tail_single_wide)); + m[key] = p; smallCache.add(p); } return p; } -TrigPtr TrigBufCache::middleTrig(u32 SMALL_H, u32 MIDDLE, u32 width) { +TrigPtr TrigBufCache::middleTrig(Args *args, FFTConfig fft, u32 SMALL_H, u32 MIDDLE, u32 width) { lock_guard lock{mut}; auto& m = middle; - decay_t::key_type key{SMALL_H, MIDDLE, width}; + u32 key_part = make_key_part(fft.FFT_FP64, 0, fft.NTT_GF31, 0, fft.FFT_FP32, 0, fft.NTT_GF61, 0, 0); + decay_t::key_type key{SMALL_H, MIDDLE, width, key_part}; TrigPtr p{}; auto it = m.find(key); if (it == m.end() || !(p = it->second.lock())) { - p = make_shared(context, genMiddleTrig(SMALL_H, MIDDLE, width)); + p = make_shared(context, genMiddleTrig(fft, SMALL_H, MIDDLE, width)); m[key] = p; middleCache.add(p); } diff --git a/src/TrigBufCache.h b/src/TrigBufCache.h index 04ace546..d5e5317a 100644 --- a/src/TrigBufCache.h +++ b/src/TrigBufCache.h @@ -3,10 +3,10 @@ #pragma once #include "Buffer.h" +#include "FFTConfig.h" #include -using double2 = pair; using TrigBuf = Buffer; using TrigPtr = shared_ptr; @@ -27,8 +27,8 @@ class TrigBufCache 
{ const Context* context; std::mutex mut; - std::map, TrigPtr::weak_type> small; - std::map, TrigPtr::weak_type> middle; + std::map, TrigPtr::weak_type> small; + std::map, TrigPtr::weak_type> middle; // The shared-pointers below keep the most recent set of buffers alive even without any Gpu instance // referencing them. This allows a single worker to delete & re-create the Gpu instance and still reuse the buffers. @@ -42,12 +42,54 @@ class TrigBufCache { ~TrigBufCache(); - TrigPtr smallTrig(u32 W, u32 nW); - TrigPtr smallTrigCombo(u32 width, u32 middle, u32 W, u32 nW, u32 variant, bool tail_single_wide, u32 tail_trigs); - TrigPtr middleTrig(u32 SMALL_H, u32 MIDDLE, u32 W); + TrigPtr smallTrigCombo(Args *args, FFTConfig fft, u32 width, u32 middle, u32 height, u32 nH, bool tail_single_wide); + TrigPtr middleTrig(Args *args, FFTConfig fft, u32 SMALL_H, u32 MIDDLE, u32 W); + TrigPtr smallTrig(Args *args, FFTConfig fft, u32 width, u32 nW, u32 middle, u32 height, u32 nH, bool tail_single_wide); }; -// For small angles, return "fancy" cos - 1 for increased precision -double2 root1Fancy(u32 N, u32 k); +double2 root1Fancy(u32 N, u32 k); // For small angles, return "fancy" cos - 1 for increased precision double2 root1(u32 N, u32 k); + +float2 root1FancyFP32(u32 N, u32 k); // For small angles, return "fancy" cos - 1 for increased precision +float2 root1FP32(u32 N, u32 k); + +uint2 root1GF31(u32 N, u32 k); +ulong2 root1GF61(u32 N, u32 k); + +// Compute the size of the largest possible trig buffer given width, middle, height (in number of float2 values) +#define SMALLTRIG_FP64_SIZE(W,M,H,nH) (W != H || H == 0 ? W * 5 : SMALLTRIGCOMBO_FP64_SIZE(W,M,H,nH)) // See genSmallTrigFP64 +#define SMALLTRIGCOMBO_FP64_SIZE(W,M,H,nH) (H * 5 + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboFP64 +#define MIDDLETRIG_FP64_SIZE(W,M,H) (H + W + H) // See genMiddleTrigFP64 + +// Compute the size of the largest possible trig buffer given width, middle, height (in number of float2 values) +#define SMALLTRIG_FP32_SIZE(W,M,H,nH) (W != H || H == 0 ? W : SMALLTRIGCOMBO_FP32_SIZE(W,M,H,nH)) // See genSmallTrigFP32 +#define SMALLTRIGCOMBO_FP32_SIZE(W,M,H,nH) (H + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboFP32 +#define MIDDLETRIG_FP32_SIZE(W,M,H) (H + W + H) // See genMiddleTrigFP32 + +// Compute the size of the largest possible trig buffer given width, middle, height (in number of uint2 values) +#define SMALLTRIG_GF31_SIZE(W,M,H,nH) (W != H || H == 0 ? W : SMALLTRIGCOMBO_GF31_SIZE(W,M,H,nH)) // See genSmallTrigGF31 +#define SMALLTRIGCOMBO_GF31_SIZE(W,M,H,nH) (H + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboGF31 +#define MIDDLETRIG_GF31_SIZE(W,M,H) (H * (M - 1) + W + H) // See genMiddleTrigGF31 + +// Compute the size of the largest possible trig buffer given width, middle, height (in number of ulong2 values) +#define SMALLTRIG_GF61_SIZE(W,M,H,nH) (W != H || H == 0 ? 
W : SMALLTRIGCOMBO_GF61_SIZE(W,M,H,nH)) // See genSmallTrigGF61 +#define SMALLTRIGCOMBO_GF61_SIZE(W,M,H,nH) (H + (W * M / 2 + 1) * 2 * H / nH) // See genSmallTrigComboGF61 +#define MIDDLETRIG_GF61_SIZE(W,M,H) (H * (M - 1) + W + H) // See genMiddleTrigGF61 + +// Convert above sizes to distances (in units of double2) +#define SMALLTRIG_FP64_DIST(W,M,H,nH) SMALLTRIG_FP64_SIZE(W,M,H,nH) +#define SMALLTRIGCOMBO_FP64_DIST(W,M,H,nH) SMALLTRIGCOMBO_FP64_SIZE(W,M,H,nH) +#define MIDDLETRIG_FP64_DIST(W,M,H) MIDDLETRIG_FP64_SIZE(W,M,H) + +#define SMALLTRIG_FP32_DIST(W,M,H,nH) SMALLTRIG_FP32_SIZE(W,M,H,nH) * sizeof(float) / sizeof(double) +#define SMALLTRIGCOMBO_FP32_DIST(W,M,H,nH) SMALLTRIGCOMBO_FP32_SIZE(W,M,H,nH) * sizeof(float) / sizeof(double) +#define MIDDLETRIG_FP32_DIST(W,M,H) MIDDLETRIG_FP32_SIZE(W,M,H) * sizeof(float) / sizeof(double) + +#define SMALLTRIG_GF31_DIST(W,M,H,nH) SMALLTRIG_GF31_SIZE(W,M,H,nH) * sizeof(uint) / sizeof(double) +#define SMALLTRIGCOMBO_GF31_DIST(W,M,H,nH) SMALLTRIGCOMBO_GF31_SIZE(W,M,H,nH) * sizeof(uint) / sizeof(double) +#define MIDDLETRIG_GF31_DIST(W,M,H) MIDDLETRIG_GF31_SIZE(W,M,H) * sizeof(uint) / sizeof(double) + +#define SMALLTRIG_GF61_DIST(W,M,H,nH) SMALLTRIG_GF61_SIZE(W,M,H,nH) * sizeof(ulong) / sizeof(double) +#define SMALLTRIGCOMBO_GF61_DIST(W,M,H,nH) SMALLTRIGCOMBO_GF61_SIZE(W,M,H,nH) * sizeof(ulong) / sizeof(double) +#define MIDDLETRIG_GF61_DIST(W,M,H) MIDDLETRIG_GF61_SIZE(W,M,H) * sizeof(ulong) / sizeof(double) diff --git a/src/Worktodo.cpp b/src/Worktodo.cpp index bc3e4d82..0a981a39 100644 --- a/src/Worktodo.cpp +++ b/src/Worktodo.cpp @@ -93,13 +93,13 @@ std::optional parse(const std::string& line) { } // Among the valid tasks from fileName, return the "best" which means the smallest CERT, or otherwise the exponent PRP/LL -static std::optional bestTask(const fs::path& fileName) { +static std::optional bestTask(const fs::path& fileName, bool smallest) { optional best; for (const string& line : File::openRead(fileName)) { optional task = parse(line); if (task && (!best || (best->kind != Task::CERT && task->kind == Task::CERT) - || ((best->kind != Task::CERT || task->kind == Task::CERT) && task->exponent < best->exponent))) { + || ((best->kind != Task::CERT || task->kind == Task::CERT) && smallest && task->exponent < best->exponent))) { best = task; } } @@ -112,7 +112,7 @@ optional getWork(Args& args, i32 instance) { fs::path localWork = workName(instance); // Try to get a task from the local worktodo- file. 
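+  // With -smallest, bestTask() scans the whole file for the lowest exponent; without it the
+  // first valid entry wins.  CERT tasks are preferred over PRP/LL entries either way.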
- if (optional task = bestTask(localWork)) { return task; } + if (optional task = bestTask(localWork, args.smallest)) { return task; } if (args.masterDir.empty()) { return {}; } @@ -140,7 +140,7 @@ optional getWork(Args& args, i32 instance) { u64 initialSize = fileSize(worktodo); if (!initialSize) { return {}; } - optional task = bestTask(worktodo); + optional task = bestTask(worktodo, args.smallest); if (!task) { return {}; } string workLine = task->line; diff --git a/src/cl/base.cl b/src/cl/base.cl index 8190406c..df1ef02b 100644 --- a/src/cl/base.cl +++ b/src/cl/base.cl @@ -19,7 +19,9 @@ CARRY_LEN NW NH AMDGPU : if this is an AMD GPU -HAS_ASM : set if we believe __asm() can be used +NVIDIAGPU : if this is an nVidia GPU +HAS_ASM : set if we believe __asm() can be used for AMD GCN +HAS_PTX : set if we believe __asm() can be used for nVidia PTX -- Derived from above: BIG_HEIGHT == SMALL_HEIGHT * MIDDLE @@ -56,12 +58,24 @@ G_H "group height" == SMALL_HEIGHT / NH //__builtin_assume(condition) #endif // DEBUG -#if AMDGPU -// On AMDGPU the default is HAS_ASM -#if !NO_ASM +#if NO_ASM +#define HAS_ASM 0 +#define HAS_PTX 0 +#elif AMDGPU #define HAS_ASM 1 +#define HAS_PTX 0 +#elif NVIDIAGPU +#define HAS_ASM 0 +#define HAS_PTX 1200 // Assume CUDA 12.00 support until we can figure out how to automatically determine this at runtime +#else +#define HAS_ASM 0 +#define HAS_PTX 0 +#endif + +// Default is not adding -2 to results for LL +#if !defined(LL) +#define LL 0 #endif -#endif // AMDGPU // On Nvidia we need the old sync between groups in carryFused #if !defined(OLD_FENCE) && !AMDGPU @@ -97,10 +111,46 @@ G_H "group height" == SMALL_HEIGHT / NH #if FFT_VARIANT_H > 2 #error FFT_VARIANT_H must be between 0 and 2 #endif +// C code ensures that only AMD GPUs use FFT_VARIANT_W=0 and FFT_VARIANT_H=0. However, this does not guarantee that the OpenCL compiler supports +// the necessary amdgcn builtins. If those builtins are not present convert to variant one. +#if AMDGPU +#if !defined(__has_builtin) || !__has_builtin(__builtin_amdgcn_mov_dpp) || !__has_builtin(__builtin_amdgcn_ds_swizzle) || !__has_builtin(__builtin_amdgcn_readfirstlane) +#if FFT_VARIANT_W == 0 +#warning Missing builtins for FFT_VARIANT_W=0, switching to FFT_VARIANT_W=1 +#undef FFT_VARIANT_W +#define FFT_VARIANT_W 1 +#endif +#if FFT_VARIANT_H == 0 +#warning Missing builtins for FFT_VARIANT_H=0, switching to FFT_VARIANT_H=1 +#undef FFT_VARIANT_H +#define FFT_VARIANT_H 1 +#endif +#endif +#endif + +#if !defined(BIGLIT) +#define BIGLIT 1 +#endif #if !defined(TABMUL_CHAIN) #define TABMUL_CHAIN 0 #endif +#if !defined(TABMUL_CHAIN31) +#define TABMUL_CHAIN31 0 +#endif +#if !defined(TABMUL_CHAIN32) +#define TABMUL_CHAIN32 0 +#endif +#if !defined(TABMUL_CHAIN61) +#define TABMUL_CHAIN61 0 +#endif +#if !defined(MODM31) +#define MODM31 0 +#endif + +#if !defined(MIDDLE_CHAIN) +#define MIDDLE_CHAIN 0 +#endif #if !defined(UNROLL_W) #if AMDGPU @@ -145,59 +195,159 @@ typedef uint u32; typedef long i64; typedef ulong u64; +// Data types for data stored in FFTs and NTTs during the transform +typedef double T; // For historical reasons, classic FFTs using doubles call their data T and T2. +typedef double2 T2; // A complex value using doubles in a classic FFT. +typedef float F; // A classic FFT using floats. Use typedefs F and F2. +typedef float2 F2; +typedef uint Z31; // A value calculated mod M31. For a GF(M31^2) NTT. +typedef uint2 GF31; // A complex value using two Z31s. For a GF(M31^2) NTT. +typedef ulong Z61; // A value calculated mod M61. 
For a GF(M61^2) NTT. +typedef ulong2 GF61; // A complex value using two Z61s. For a GF(M61^2) NTT. +//typedef ulong NCW; // A value calculated mod 2^64 - 2^32 + 1. +//typedef ulong2 NCW2; // A complex value using NCWs. For a Nick Craig-Wood's insipred NTT using prime 2^64 - 2^32 + 1. + +// Defines for the various supported FFTs/NTTs. These match the enumeration in FFTConfig.h. Sanity check for supported FFT/NTT. +#define FFT64 0 +#define FFT3161 1 +#define FFT3261 2 +#define FFT61 3 +#define FFT323161 4 +#define FFT3231 50 +#define FFT6431 51 +#define FFT31 52 +#define FFT32 53 +#if FFT_TYPE < 0 || (FFT_TYPE > 4 && FFT_TYPE < 50) || FFT_TYPE > 53 +#error - unsupported FFT/NTT +#endif +// Word and Word2 define the data type for FFT integers passed between the CPU and GPU. +#if WordSize == 8 +typedef i64 Word; +typedef long2 Word2; +#elif WordSize == 4 typedef i32 Word; typedef int2 Word2; +#else +error - unsupported integer WordSize +#endif -typedef double T; -typedef double2 T2; +// Routine to create a pair +double2 OVERLOAD U2(double a, double b) { return (double2) (a, b); } +float2 OVERLOAD U2(float a, float b) { return (float2) (a, b); } +int2 OVERLOAD U2(int a, int b) { return (int2) (a, b); } +long2 OVERLOAD U2(long a, long b) { return (long2) (a, b); } +uint2 OVERLOAD U2(uint a, uint b) { return (uint2) (a, b); } +ulong2 OVERLOAD U2(ulong a, ulong b) { return (ulong2) (a, b); } +// Other handy macros #define RE(a) (a.x) #define IM(a) (a.y) #define P(x) global x * restrict #define CP(x) const P(x) -// Macros for non-temporal load and store (in case we later want to provide a -use option to turn this off) -#if NONTEMPORAL && defined(__has_builtin) && __has_builtin(__builtin_nontemporal_load) && __has_builtin(__builtin_nontemporal_store) +// Macros for non-temporal load and store. The theory behind only non-temporal reads (option 2) is that with alternating buffers, +// read buffers will not be needed for quite a while, but write buffers will be needed soon. +#if NONTEMPORAL == 1 && defined(__has_builtin) && __has_builtin(__builtin_nontemporal_load) && __has_builtin(__builtin_nontemporal_store) #define NTLOAD(mem) __builtin_nontemporal_load(&(mem)) #define NTSTORE(mem,val) __builtin_nontemporal_store(val, &(mem)) +#elif NONTEMPORAL == 2 && defined(__has_builtin) && __has_builtin(__builtin_nontemporal_load) +#define NTLOAD(mem) __builtin_nontemporal_load(&(mem)) +#define NTSTORE(mem,val) (mem) = val #else #define NTLOAD(mem) (mem) #define NTSTORE(mem,val) (mem) = val #endif +// Prefetch macros. Unused at present, I tried using them in fftMiddleInGF61 on a 5080 with no benefit. 
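+// A possible use (illustrative only) is prefetching the next loop iteration's global address
+// ahead of the current loads, e.g.
+//   PREFETCHL2(&in[(i + 1) * WG]);
+// On targets without PTX support the bodies are empty, so such calls optimize away to nothing.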
+void PREFETCHL1(const __global void *addr) { +#if HAS_PTX >= 200 // Prefetch instruction requires sm_20 support or higher + __asm("prefetch.global.L1 [%0];" : : "l"(addr)); +#endif +} +void PREFETCHL2(const __global void *addr) { +#if HAS_PTX >= 200 // Prefetch instruction requires sm_20 support or higher + __asm("prefetch.global.L2 [%0];" : : "l"(addr)); +#endif +} + // For reasons unknown, loading trig values into nVidia's constant cache has terrible performance #if AMDGPU typedef constant const T2* Trig; typedef constant const T* TrigSingle; +typedef constant const F2* TrigFP32; +typedef constant const GF31* TrigGF31; +typedef constant const GF61* TrigGF61; #else typedef global const T2* Trig; typedef global const T* TrigSingle; +typedef global const F2* TrigFP32; +typedef global const GF31* TrigGF31; +typedef global const GF61* TrigGF61; #endif // However, caching weights in nVidia's constant cache improves performance. // Even better is to not pollute the constant cache with weights that are used only once. // This requires two typedefs depending on how we want to use the BigTab pointer. // For AMD we can declare BigTab as constant or global - it doesn't really matter. typedef constant const double2* ConstBigTab; +typedef constant const float2* ConstBigTabFP32; #if AMDGPU typedef constant const double2* BigTab; +typedef constant const float2* BigTabFP32; #else typedef global const double2* BigTab; +typedef global const float2* BigTabFP32; #endif #define KERNEL(x) kernel __attribute__((reqd_work_group_size(x, 1, 1))) void - -void read(u32 WG, u32 N, T2 *u, const global T2 *in, u32 base) { + +#if FFT_FP64 +void OVERLOAD read(u32 WG, u32 N, T2 *u, const global T2 *in, u32 base) { in += base + (u32) get_local_id(0); for (u32 i = 0; i < N; ++i) { u[i] = in[i * WG]; } } -void write(u32 WG, u32 N, T2 *u, global T2 *out, u32 base) { +void OVERLOAD write(u32 WG, u32 N, T2 *u, global T2 *out, u32 base) { out += base + (u32) get_local_id(0); for (u32 i = 0; i < N; ++i) { out[i * WG] = u[i]; } } +#endif + +#if FFT_FP32 +void OVERLOAD read(u32 WG, u32 N, F2 *u, const global F2 *in, u32 base) { + in += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { u[i] = in[i * WG]; } +} + +void OVERLOAD write(u32 WG, u32 N, F2 *u, global F2 *out, u32 base) { + out += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { out[i * WG] = u[i]; } +} +#endif + +#if NTT_GF31 +void OVERLOAD read(u32 WG, u32 N, GF31 *u, const global GF31 *in, u32 base) { + in += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { u[i] = in[i * WG]; } +} -T2 U2(T a, T b) { return (T2) (a, b); } +void OVERLOAD write(u32 WG, u32 N, GF31 *u, global GF31 *out, u32 base) { + out += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { out[i * WG] = u[i]; } +} +#endif + +#if NTT_GF61 +void OVERLOAD read(u32 WG, u32 N, GF61 *u, const global GF61 *in, u32 base) { + in += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { u[i] = in[i * WG]; } +} + +void OVERLOAD write(u32 WG, u32 N, GF61 *u, global GF61 *out, u32 base) { + out += base + (u32) get_local_id(0); + for (u32 i = 0; i < N; ++i) { out[i * WG] = u[i]; } +} +#endif // On "classic" AMD GCN GPUs such as Radeon VII, the wavefront size was always 64. On RDNA GPUs the wavefront can // be configured to be either 64 or 32. We use the FAST_BARRIER define as an indicator for GCN GPUs. 
diff --git a/src/cl/carry.cl b/src/cl/carry.cl index 0a613c73..28863dd8 100644 --- a/src/cl/carry.cl +++ b/src/cl/carry.cl @@ -3,33 +3,82 @@ #include "carryutil.cl" #include "weight.cl" +#if FFT_TYPE == FFT64 + // Carry propagation with optional MUL-3, over CARRY_LEN words. -// Input arrives conjugated and inverse-weighted. +// Input arrives with real and imaginary values swapped and weighted. -KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, CP(u32) bits, - BigTab THREAD_WEIGHTS, P(uint) bufROE) { +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTab THREAD_WEIGHTS, P(uint) bufROE) { u32 g = get_group_id(0); u32 me = get_local_id(0); u32 gx = g % NW; u32 gy = g / NW; + u32 H = BIG_HEIGHT; // & vs. && to workaround spurious warning CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; float roundMax = 0; float carryMax = 0; - // Split 32 bits into CARRY_LEN groups of 2 bits. -#define GPW (16 / CARRY_LEN) - u32 b = bits[(G_W * g + me) / GPW] >> (me % GPW * (2 * CARRY_LEN)); -#undef GPW + // Calculate the most significant 32-bits of FRAC_BPW * the index of the FFT word. Also add FRAC_BPW_HI to test first biglit flag. + u32 line = gy * CARRY_LEN; + u32 fft_word_index = (gx * G_W * H + me * H + line) * 2; + u32 frac_bits = fft_word_index * FRAC_BPW_HI + mad_hi (fft_word_index, FRAC_BPW_LO, FRAC_BPW_HI); T base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); for (i32 i = 0; i < CARRY_LEN; ++i) { u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; - double w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); - double w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); - out[p] = weightAndCarryPair(conjugate(in[p]), U2(w1, w2), carry, &roundMax, &carry, test(b, 2 * i), test(b, 2 * i + 1), &carryMax); + T w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + T w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + bool biglit0 = frac_bits + (2*i) * FRAC_BPW_HI <= FRAC_BPW_HI; + bool biglit1 = frac_bits + (2*i) * FRAC_BPW_HI >= -FRAC_BPW_HI; // Same as frac_bits + (2*i) * FRAC_BPW_HI + FRAC_BPW_HI <= FRAC_BPW_HI; + out[p] = weightAndCarryPair(SWAP_XY(in[p]), U2(w1, w2), carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT32 + +// Carry propagation with optional MUL-3, over CARRY_LEN words. +// Input arrives with real and imaginary values swapped and weighted. + +KERNEL(G_W) carry(P(Word2) out, CP(F2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + // Calculate the most significant 32-bits of FRAC_BPW * the index of the FFT word. Also add FRAC_BPW_HI to test first biglit flag. 
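+  // In effect frac_bits holds the fractional part of bits-per-word accumulated through this word,
+  // in 32-bit fixed point; the biglit tests below detect when adding one more word's fraction
+  // wraps that accumulator, i.e. when the word must carry the extra bit (a "big" word).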
+ u32 line = gy * CARRY_LEN; + u32 fft_word_index = (gx * G_W * H + me * H + line) * 2; + u32 frac_bits = fft_word_index * FRAC_BPW_HI + mad_hi (fft_word_index, FRAC_BPW_LO, FRAC_BPW_HI); + + F base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + F w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + bool biglit0 = frac_bits + (2*i) * FRAC_BPW_HI <= FRAC_BPW_HI; + bool biglit1 = frac_bits + (2*i) * FRAC_BPW_HI >= -FRAC_BPW_HI; // Same as frac_bits + (2*i) * FRAC_BPW_HI + FRAC_BPW_HI <= FRAC_BPW_HI; + out[p] = weightAndCarryPair(SWAP_XY(in[p]), U2(w1, w2), carry, biglit0, biglit1, &carry, &roundMax, &carryMax); } carryOut[G_W * g + me] = carry; @@ -39,3 +88,578 @@ KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, CP( updateStats(bufROE, posROE, carryMax); #endif } + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT31 + +KERNEL(G_W) carry(P(Word2) out, CP(GF31) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + // & vs. && to workaround spurious warning + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 30th root GF31. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 
10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 31; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(in[p]), weight_shift0, weight_shift1, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + } + carryOut[G_W * g + me] = carry; + +#if ROE + float fltRoundMax = (float) roundMax / (float) M31; // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT61 + +KERNEL(G_W) carry(P(Word2) out, CP(GF61) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + // & vs. && to workaround spurious warning + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 
10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 61; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(in[p]), weight_shift0, weight_shift1, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + } + carryOut[G_W * g + me] = carry; + +#if ROE + float fltRoundMax = (float) roundMax / (float) (M61 >> 32); // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT6431 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTab THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + T base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 
10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 31; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the FP64 and second GF31 weight shift + T w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + T w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(in[p]), SWAP_XY(in31[p]), w1, w2, weight_shift0, weight_shift1, + LL != 0 || i != 0, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3231 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(F2) inF2 = (CP(F2)) in; + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + F base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 
10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 31; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the FP32 and second GF31 weight shift + F w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(inF2[p]), SWAP_XY(in31[p]), w1, w2, weight_shift0, weight_shift1, + carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3261 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(F2) inF2 = (CP(F2)) in; + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + F base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 
10 : 12) + 1; + weight_shift = (weight_shift + log2_NWORDS + 1) % 61; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the FP32 and second GF61 weight shift + F w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(inF2[p]), SWAP_XY(in61[p]), w1, w2, weight_shift0, weight_shift1, + LL != 0 || i != 0, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3161 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + + // & vs. && to workaround spurious warning + CarryABM carry = (LL & (me == 0) & (g == 0)) ? -2 : 0; + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). 
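+  // One 64-bit add per word advances both halves at once: the low half accumulates frac_bits, and
+  // the carry it produces on "big" words spills into the high half, turning the per-word shift
+  // increment of (bigword_weight_shift - 1) into bigword_weight_shift exactly when ceil() rounds up.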
+ union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + m31_weight_shift = (m31_weight_shift + log2_NWORDS + 1) % 31; + m61_weight_shift = (m61_weight_shift + log2_NWORDS + 1) % 61; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + u32 m61_weight_shift1 = m61_weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(in31[p]), SWAP_XY(in61[p]), m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, + LL != 0 || i != 0, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + } + carryOut[G_W * g + me] = carry; + +#if ROE + float fltRoundMax = (float) roundMax / (float) 0x1FFFFFFF; // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_TYPE == FFT323161 + +KERNEL(G_W) carry(P(Word2) out, CP(T2) in, u32 posROE, P(CarryABM) carryOut, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + u32 g = get_group_id(0); + u32 me = get_local_id(0); + u32 gx = g % NW; + u32 gy = g / NW; + u32 H = BIG_HEIGHT; + u32 line = gy * CARRY_LEN; + + CP(F2) inF2 = (CP(F2)) in; + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + + // & vs. && to workaround spurious warning + CarryABM carry = (LL & (me == 0) & (g == 0)) ? 
-2 : 0; + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (gx * G_W * H + me * H + line) * 2; + + F base = optionalDouble(fancyMul(THREAD_WEIGHTS[me].x, iweightStep(gx))); + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 
10 : 12) + 1; + m31_weight_shift = (m31_weight_shift + log2_NWORDS + 1) % 31; + m61_weight_shift = (m61_weight_shift + log2_NWORDS + 1) % 61; + + for (i32 i = 0; i < CARRY_LEN; ++i) { + u32 p = G_W * gx + WIDTH * (CARRY_LEN * gy + i) + me; + + // Generate the FP32 and second GF31 and GF61 weight shift + F w1 = optionalDouble(fancyMul(base, THREAD_WEIGHTS[G_W + gy * CARRY_LEN + i].x)); + F w2 = optionalDouble(fancyMul(w1, IWEIGHT_STEP)); + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + u32 m61_weight_shift1 = m61_weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Compute result + out[p] = weightAndCarryPair(SWAP_XY(inF2[p]), SWAP_XY(in31[p]), SWAP_XY(in61[p]), w1, w2, m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, + LL != 0 || i != 0, carry, biglit0, biglit1, &carry, &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + } + carryOut[G_W * g + me] = carry; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif (STATS & (1 << (2 + MUL3))) + updateStats(bufROE, posROE, carryMax); +#endif +} + + +#else +error - missing Carry kernel implementation +#endif diff --git a/src/cl/carryb.cl b/src/cl/carryb.cl index 1f424c9f..d6033e19 100644 --- a/src/cl/carryb.cl +++ b/src/cl/carryb.cl @@ -2,19 +2,20 @@ #include "carryutil.cl" -KERNEL(G_W) carryB(P(Word2) io, CP(CarryABM) carryIn, CP(u32) bits) { +KERNEL(G_W) carryB(P(Word2) io, CP(CarryABM) carryIn) { u32 g = get_group_id(0); - u32 me = get_local_id(0); + u32 me = get_local_id(0); u32 gx = g % NW; u32 gy = g / NW; + u32 H = BIG_HEIGHT; - // Split 32 bits into CARRY_LEN groups of 2 bits. -#define GPW (16 / CARRY_LEN) - u32 b = bits[(G_W * g + me) / GPW] >> (me % GPW * (2 * CARRY_LEN)); -#undef GPW + // Derive the big vs. little flags from the fractional number of bits in each FFT word rather read the flags from memory. + // Calculate the most significant 32-bits of FRAC_BPW * the index of the FFT word. Also add FRAC_BPW_HI to test first biglit flag. 
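+  // The big-word/little-word flag for a pair's first word is then just the test frac_bits <= FRAC_BPW_HI on this
+  // fixed-point accumulator, which removes the old per-group read of the "bits" table from global memory.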
+ u32 line = gy * CARRY_LEN; + u32 fft_word_index = (gx * G_W * H + me * H + line) * 2; + u32 frac_bits = fft_word_index * FRAC_BPW_HI + mad_hi (fft_word_index, FRAC_BPW_LO, FRAC_BPW_HI); - u32 step = G_W * gx + WIDTH * CARRY_LEN * gy; - io += step; + io += G_W * gx + WIDTH * CARRY_LEN * gy; u32 HB = BIG_HEIGHT / CARRY_LEN; @@ -26,7 +27,9 @@ KERNEL(G_W) carryB(P(Word2) io, CP(CarryABM) carryIn, CP(u32) bits) { for (i32 i = 0; i < CARRY_LEN; ++i) { u32 p = i * WIDTH + me; - io[p] = carryWord(io[p], &carry, test(b, 2 * i), test(b, 2 * i + 1)); + bool biglit0 = frac_bits + (2*i) * FRAC_BPW_HI <= FRAC_BPW_HI; + bool biglit1 = frac_bits + (2*i) * FRAC_BPW_HI >= -FRAC_BPW_HI; // Same as frac_bits + (2*i) * FRAC_BPW_HI + FRAC_BPW_HI <= FRAC_BPW_HI; + io[p] = carryWord(io[p], &carry, biglit0, biglit1); if (!carry) { return; } } } diff --git a/src/cl/carryfused.cl b/src/cl/carryfused.cl index bf905597..05e4ca4c 100644 --- a/src/cl/carryfused.cl +++ b/src/cl/carryfused.cl @@ -16,24 +16,27 @@ void spin() { #endif } +#if FFT_TYPE == FFT64 + // The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. // It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, - CP(u32) bits, ConstBigTab CONST_THREAD_WEIGHTS, BigTab THREAD_WEIGHTS, P(uint) bufROE) { + CP(u32) bits, ConstBigTab CONST_THREAD_WEIGHTS, BigTab THREAD_WEIGHTS, P(uint) bufROE) { #if 0 // fft_WIDTH uses shufl_int instead of shufl local T2 lds[WIDTH / 4]; #else local T2 lds[WIDTH / 2]; #endif + + T2 u[NW]; + u32 gr = get_group_id(0); u32 me = get_local_id(0); u32 H = BIG_HEIGHT; u32 line = gr % H; - T2 u[NW]; - #if HAS_ASM __asm("s_setprio 3"); #endif @@ -50,14 +53,9 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding // common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs // which causes a terrible reduction in occupancy. -// fft_WIDTH(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072)); - -// A temporary hack until we figure out which combinations we want to finally offer: -// UNROLL_W=0: old fft_WIDTH, no loop unrolling -// UNROLL_W=1: old fft_WIDTH, loop unrolling -// UNROLL_W=3: new fft_WIDTH if applicable. Slightly better on Radeon VII -- more study needed as to why results weren't better. #if ZEROHACK_W - new_fft_WIDTH1(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072)); + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); #else new_fft_WIDTH1(lds, u, smallTrig); #endif @@ -87,28 +85,27 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( float roundMax = 0; float carryMax = 0; - // On Titan V it is faster to derive the big vs. little flags from the fractional number of bits in each FFT word rather read the flags from memory. + // On Titan V it is faster to derive the big vs. little flags from the fractional number of bits in each FFT word rather than read the flags from memory. // On Radeon VII this code is about the same speed. Not sure which is better on other GPUs. #if BIGLIT - // Calculate the most significant 32-bits of FRAC_BPW * the index of the FFT word. Also add FRAC_BPW_HI to test first biglit flag. 
- u32 fft_word_index = (me * H + line) * 2; - u32 frac_bits = fft_word_index * FRAC_BPW_HI + mad_hi (fft_word_index, FRAC_BPW_LO, FRAC_BPW_HI); + // Calculate the most significant 32-bits of FRAC_BPW * the word index. Also add FRAC_BPW_HI to test first biglit flag. + u32 word_index = (me * H + line) * 2; + u32 frac_bits = word_index * FRAC_BPW_HI + mad_hi (word_index, FRAC_BPW_LO, FRAC_BPW_HI); + const u32 frac_bits_bigstep = ((G_W * H * 2) * FRAC_BPW_HI + (u32)(((u64)(G_W * H * 2) * FRAC_BPW_LO) >> 32)); #endif // Apply the inverse weights and carry propagate pairs to generate the output carries T invBase = optionalDouble(weights.x); - + for (u32 i = 0; i < NW; ++i) { T invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); T invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); // Generate big-word/little-word flags #if BIGLIT -// bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; -// bool biglit1 = frac_bits + i * FRAC_BITS_BIGSTEP + FRAC_BPW_HI <= FRAC_BPW_HI; - bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; - bool biglit1 = frac_bits + i * FRAC_BITS_BIGSTEP >= -FRAC_BPW_HI; + bool biglit0 = frac_bits + i * frac_bits_bigstep <= FRAC_BPW_HI; + bool biglit1 = frac_bits + i * frac_bits_bigstep >= -FRAC_BPW_HI; // Same as frac_bits + i * frac_bits_bigstep + FRAC_BPW_HI <= FRAC_BPW_HI; #else bool biglit0 = test(b, 2 * i); bool biglit1 = test(b, 2 * i + 1); @@ -117,10 +114,10 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. - wu[i] = weightAndCarryPairSloppy(conjugate(u[i]), U2(invWeight1, invWeight2), + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), U2(invWeight1, invWeight2), // For an LL test, add -2 as the very initial "carry in" // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it - (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, &roundMax, &carry[i], biglit0, biglit1, &carryMax); + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); } #if ROE @@ -214,12 +211,203 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( // Apply each 32 or 64 bit carry to the 2 words for (i32 i = 0; i < NW; ++i) { #if BIGLIT - bool biglit0 = frac_bits + i * FRAC_BITS_BIGSTEP <= FRAC_BPW_HI; + bool biglit0 = frac_bits + i * frac_bits_bigstep <= FRAC_BPW_HI; #else bool biglit0 = test(b, 2 * i); #endif wu[i] = carryFinal(wu[i], carry[i], biglit0); - u[i] *= U2(wu[i].x, wu[i].y); + u[i] = U2(u[i].x * wu[i].x, u[i].y * wu[i].y); + } + + bar(); + +// fft_WIDTH(lds, u, smallTrig); + new_fft_WIDTH2(lds, u, smallTrig); + + writeCarryFusedLine(u, out, line); +} + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT32 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. 
+// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(F2) out, CP(F2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, TrigFP32 smallTrig, + CP(u32) bits, ConstBigTabFP32 CONST_THREAD_WEIGHTS, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local F2 lds[WIDTH / 4]; +#else + local F2 lds[WIDTH / 2]; +#endif + + F2 u[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in, u, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); +#else + new_fft_WIDTH1(lds, u, smallTrig); +#endif + + Word2 wu[NW]; +#if AMDGPU + F2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + F2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + + P(CFcarry) carryShuttlePtr = (P(CFcarry)) carryShuttle; + CFcarry carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + // Calculate the most significant 32-bits of FRAC_BPW * the word index. Also add FRAC_BPW_HI to test first biglit flag. + u32 word_index = (me * H + line) * 2; + u32 frac_bits = word_index * FRAC_BPW_HI + mad_hi (word_index, FRAC_BPW_LO, FRAC_BPW_HI); + const u32 frac_bits_bigstep = ((G_W * H * 2) * FRAC_BPW_HI + (u32)(((u64)(G_W * H * 2) * FRAC_BPW_LO) >> 32)); + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + F invBase = optionalDouble(weights.x); + + for (u32 i = 0; i < NW; ++i) { + F invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + + // Generate big-word/little-word flags + bool biglit0 = frac_bits + i * frac_bits_bigstep <= FRAC_BPW_HI; + bool biglit1 = frac_bits + i * frac_bits_bigstep >= -FRAC_BPW_HI; // Same as frac_bits + i * frac_bits_bigstep + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), U2(invWeight1, invWeight2), + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? 
-2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + } + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + F base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + F weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + u[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. 
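+    // (With the shifted read below, lane 0 consumes the carry written by lane G_W-1, which belongs to a wavefront
+    // whose ready flag this wavefront did not wait on, so a full workgroup barrier is required first.)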
+ bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words + for (i32 i = 0; i < NW; ++i) { + bool biglit0 = frac_bits + i * frac_bits_bigstep <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u[i] = U2(u[i].x * wu[i].x, u[i].y * wu[i].y); } bar(); @@ -229,3 +417,1726 @@ KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P( writeCarryFusedLine(u, out, line); } + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT31 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(GF31) out, CP(GF31) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, TrigGF31 smallTrig, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local GF31 lds[WIDTH / 4]; +#else + local GF31 lds[WIDTH / 2]; +#endif + + GF31 u[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in, u, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); +#else + new_fft_WIDTH1(lds, u, smallTrig); +#endif + + Word2 wu[NW]; + + P(CFcarry) carryShuttlePtr = (P(CFcarry)) carryShuttle; + CFcarry carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF31. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). 
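+  // Multiplying by 2^k modulo M31 = 2^31-1 is a circular left shift of the 31-bit residue by k bits, which is why
+  // the power-of-two weights can be tracked as shift amounts reduced mod 31.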
+ union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 31) weight_shift -= 31; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + for (u32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + float fltRoundMax = (float) roundMax / (float) M31; // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. 
+ if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u[i] = U2(shl(make_Z31(wu[i].x), weight_shift0), shl(make_Z31(wu[i].y), weight_shift1)); + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + bar(); + + new_fft_WIDTH2(lds, u, smallTrig); + + writeCarryFusedLine(u, out, line); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT61 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. 
+// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(GF61) out, CP(GF61) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, TrigGF61 smallTrig, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local GF61 lds[WIDTH / 4]; +#else + local GF61 lds[WIDTH / 2]; +#endif + + GF61 u[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in, u, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); +#else + new_fft_WIDTH1(lds, u, smallTrig); +#endif + + Word2 wu[NW]; + +#if MUL3 + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; +#else + P(CFcarry) carryShuttlePtr = (P(CFcarry)) carryShuttle; + CFcarry carry[NW+1]; +#endif + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 61; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 
10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 61) weight_shift -= 61; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + for (u32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + float fltRoundMax = (float) roundMax / (float) (M61 >> 32); // For speed, roundoff was computed as 32-bit integer. Convert to float. + updateStats(bufROE, posROE, fltRoundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. 
+ if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u[i] = U2(shl(make_Z61(wu[i].x), weight_shift0), shl(make_Z61(wu[i].y), weight_shift1)); + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + + bar(); + + new_fft_WIDTH2(lds, u, smallTrig); + + writeCarryFusedLine(u, out, line); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT6431 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. 
+// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, + CP(u32) bits, ConstBigTab CONST_THREAD_WEIGHTS, BigTab THREAD_WEIGHTS, P(uint) bufROE) { + + local T2 lds[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) lds; + + T2 u[NW]; + GF31 u31[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in, u, line); + readCarryFusedLine(in31, u31, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds + zerohack, u, smallTrig + zerohack); + bar(); + new_fft_WIDTH1(lds31 + zerohack, u31, smallTrig31 + zerohack); +#else + new_fft_WIDTH1(lds, u, smallTrig); + bar(); + new_fft_WIDTH1(lds31, u31, smallTrig31); +#endif + + Word2 wu[NW]; +#if AMDGPU + T2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + T2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). 
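+  // For example, with q = 13 and n = 4, ceil(q*j/n) - q*j/n for j = 0..3 is 0, 3/4, 1/2, 1/4, so the per-word
+  // weights would be 2^0, 2^(3/4), 2^(1/2), 2^(1/4) (a toy illustration of the formula above, not a real FFT size).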
+ union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 31) weight_shift -= 31; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + T invBase = optionalDouble(weights.x); + for (u32 i = 0; i < NW; ++i) { + // Generate the FP64 weights and second GF31 weight shift + T invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + T invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u[i]), SWAP_XY(u31[i]), invWeight1, invWeight2, weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + LL != 0, (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. 
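+  // With the AMD CarryShuttleAccess layout each lane owns NW consecutive shuttle slots, so its NW stores below can
+  // coalesce into dwordx4 instructions; the nVidia layout instead gives unit stride across lanes for each i.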
+ if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + T base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + T weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + T weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + u[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. 
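+  // The FP64 halves are re-weighted with the weights staged in u[i] while the carries were in flight, and the
+  // GF31 halves are re-weighted by left-shifting the integer words by weight_shift0/weight_shift1.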
+ for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u[i] = U2(u[i].x * wu[i].x, u[i].y * wu[i].y); + u31[i] = U2(shl(make_Z31(wu[i].x), weight_shift0), shl(make_Z31(wu[i].y), weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + bar(); + + new_fft_WIDTH2(lds, u, smallTrig); + writeCarryFusedLine(u, out, line); + + bar(); + + new_fft_WIDTH2(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, line); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3231 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, + CP(u32) bits, ConstBigTabFP32 CONST_THREAD_WEIGHTS, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + + local F2 ldsF2[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) ldsF2; + + F2 uF2[NW]; + GF31 u31[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(inF2, uF2, line); + readCarryFusedLine(in31, u31, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. 
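+// (The "hidden zero" is get_group_id(0) / 131072, which should always be 0 for realistic group counts but is opaque
+// to the compiler, so the two fft_WIDTH calls no longer look identical and their temporaries are not kept live.)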
+#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(ldsF2 + zerohack, uF2, smallTrigF2 + zerohack); + bar(); + new_fft_WIDTH1(lds31 + zerohack, u31, smallTrig31 + zerohack); +#else + new_fft_WIDTH1(ldsF2, uF2, smallTrigF2); + bar(); + new_fft_WIDTH1(lds31, u31, smallTrig31); +#endif + + Word2 wu[NW]; +#if AMDGPU + F2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + F2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + P(i32) carryShuttlePtr = (P(i32)) carryShuttle; + i32 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 31) weight_shift -= 31; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + F invBase = optionalDouble(weights.x); + for (u32 i = 0; i < NW; ++i) { + // Generate the FP32 weights and second GF31 weight shift + F invWeight1 = i == 0 ? 
invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(uF2[i]), SWAP_XY(u31[i]), invWeight1, invWeight2, weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + F base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + F weight1 = i == 0 ? 
base : optionalHalve(fancyMul(base, fweightStep(i))); + F weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + uF2[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + uF2[i] = U2(uF2[i].x * wu[i].x, uF2[i].y * wu[i].y); + u31[i] = U2(shl(make_Z31(wu[i].x), weight_shift0), shl(make_Z31(wu[i].y), weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + bar(); + + new_fft_WIDTH2(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, line); + + bar(); + + new_fft_WIDTH2(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, line); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3261 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. 
+// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, + CP(u32) bits, ConstBigTabFP32 CONST_THREAD_WEIGHTS, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + + local GF61 lds61[WIDTH / 2]; + local F2 *ldsF2 = (local F2 *) lds61; + + F2 uF2[NW]; + GF61 u61[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(inF2, uF2, line); + readCarryFusedLine(in61, u61, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(ldsF2 + zerohack, uF2, smallTrigF2 + zerohack); + bar(); + new_fft_WIDTH1(lds61 + zerohack, u61, smallTrig61 + zerohack); +#else + new_fft_WIDTH1(ldsF2, uF2, smallTrigF2); + bar(); + new_fft_WIDTH1(lds61, u61, smallTrig61); +#endif + + Word2 wu[NW]; +#if AMDGPU + F2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + F2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). 
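+  // Multiplying by 2^k modulo M61 = 2^61-1 is a circular left shift of the 61-bit residue by k bits, which is why
+  // the power-of-two weights can be tracked as shift amounts reduced mod 61.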
+ union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * H * 2 - 1) * combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 61; + u64 starting_combo_counter = combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + weight_shift = weight_shift + log2_NWORDS + 1; + if (weight_shift > 61) weight_shift -= 61; + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + F invBase = optionalDouble(weights.x); + for (u32 i = 0; i < NW; ++i) { + // Generate the FP32 weights and second GF61 weight shift + F invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(uF2[i]), SWAP_XY(u61[i]), invWeight1, invWeight2, weight_shift0, weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + LL != 0, (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + combo_counter = starting_combo_counter; // Restore starting counter for applying weights after carry propagation + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. 
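+  // Hand-off summary: this group writes its NW carries to the shuttle, issues a global write fence, then
+  // raises the ready flag(s) for its line; the group handling the next line spins on those flags before it
+  // reads the shuttle. Without OLD_FENCE there is one flag per wavefront (G_W / WAVEFRONT flags per line),
+  // so wavefronts are released independently instead of waiting behind a whole-group barrier.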
+ if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + F base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + F weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + uF2[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. + if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. 
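+  // Reminder: each word holds either EXP / NWORDS bits ("little") or one bit more ("big"). The biglit flag
+  // below is derived on the fly from the frac_bits accumulator, and carryFinal uses it to decide how many
+  // bits stay in the word and how many spill into the carry.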
+ for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + uF2[i] = U2(uF2[i].x * wu[i].x, uF2[i].y * wu[i].y); + u61[i] = U2(shl(make_Z61(wu[i].x), weight_shift0), shl(make_Z61(wu[i].y), weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + + bar(); + + new_fft_WIDTH2(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, line); + + bar(); + + new_fft_WIDTH2(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, line); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3161 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. +// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local GF61 lds61[WIDTH / 4]; +#else + local GF61 lds61[WIDTH / 2]; +#endif + local GF31 *lds31 = (local GF31 *) lds61; + + GF31 u31[NW]; + GF61 u61[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(in31, u31, line); + readCarryFusedLine(in61, u61, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(lds31 + zerohack, u31, smallTrig31 + zerohack); + bar(); + new_fft_WIDTH1(lds61 + zerohack, u61, smallTrig61 + zerohack); +#else + new_fft_WIDTH1(lds31, u31, smallTrig31); + bar(); + new_fft_WIDTH1(lds61, u61, smallTrig61); +#endif + + Word2 wu[NW]; + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + u32 roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. 
The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m31_combo_bigstep = ((G_W * H * 2 - 1) * m31_combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m31_weight_shift = m31_weight_shift % 31; + u64 m31_starting_combo_counter = m31_combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m61_combo_bigstep = ((G_W * H * 2 - 1) * m61_combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m61_weight_shift = m61_weight_shift % 61; + u64 m61_starting_combo_counter = m61_combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift + log2_NWORDS + 1); + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift + log2_NWORDS + 1); + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + for (u32 i = 0; i < NW; ++i) { + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + u32 m61_weight_shift1 = m61_weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. 
+ // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. + wu[i] = weightAndCarryPairSloppy(SWAP_XY(u31[i]), SWAP_XY(u61[i]), m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + LL != 0, (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + m61_combo_counter += m61_combo_bigstep; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + } + m31_combo_counter = m31_starting_combo_counter; // Restore starting counter for applying weights after carry propagation + m61_combo_counter = m61_starting_combo_counter; + +#if ROE + float fltRoundMax = (float) roundMax / (float) 0x1FFFFFFF; // For speed, roundoff was computed as 32-bit integer. Convert to float - divide by M61. + updateStats(bufROE, posROE, fltRoundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. 
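+  // Note on the gr == H case below: line 0 is redone using carries from the last WIDTH row, read one lane
+  // over and rotated by one position, so the carry produced by the final word wraps back into word 0
+  // (this is the mod 2^EXP - 1 reduction, since 2^EXP is congruent to 1 there).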
+ if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + u32 m61_weight_shift1 = m61_weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + u31[i] = U2(shl(make_Z31(wu[i].x), m31_weight_shift0), shl(make_Z31(wu[i].y), m31_weight_shift1)); + u61[i] = U2(shl(make_Z61(wu[i].x), m61_weight_shift0), shl(make_Z61(wu[i].y), m61_weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + m61_combo_counter += m61_combo_bigstep; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + } + + bar(); + + new_fft_WIDTH2(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, line); + + bar(); + + new_fft_WIDTH2(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, line); +} + + +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_TYPE == FFT323161 + +// The "carryFused" is equivalent to the sequence: fftW, carryA, carryB, fftPremul. 
+// It uses "stairway forwarding" (forwarding carry data from one workgroup to the next) +KERNEL(G_W) carryFused(P(T2) out, CP(T2) in, u32 posROE, P(i64) carryShuttle, P(u32) ready, Trig smallTrig, + CP(u32) bits, ConstBigTabFP32 CONST_THREAD_WEIGHTS, BigTabFP32 THREAD_WEIGHTS, P(uint) bufROE) { + +#if 0 // fft_WIDTH uses shufl_int instead of shufl + local GF61 lds61[WIDTH / 4]; +#else + local GF61 lds61[WIDTH / 2]; +#endif + local F2 *ldsF2 = (local F2 *) lds61; + local GF31 *lds31 = (local GF31 *) lds61; + + F2 uF2[NW]; + GF31 u31[NW]; + GF61 u61[NW]; + + u32 gr = get_group_id(0); + u32 me = get_local_id(0); + + u32 H = BIG_HEIGHT; + u32 line = gr % H; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + +#if HAS_ASM + __asm("s_setprio 3"); +#endif + + readCarryFusedLine(inF2, uF2, line); + readCarryFusedLine(in31, u31, line); + readCarryFusedLine(in61, u61, line); + +// Try this weird FFT_width call that adds a "hidden zero" when unrolling. This prevents the compiler from finding +// common sub-expressions to re-use in the second fft_WIDTH call. Re-using this data requires dozens of VGPRs +// which causes a terrible reduction in occupancy. +#if ZEROHACK_W + u32 zerohack = get_group_id(0) / 131072; + new_fft_WIDTH1(ldsF2 + zerohack, uF2, smallTrigF2 + zerohack); + bar(); + new_fft_WIDTH1(lds31 + zerohack, u31, smallTrig31 + zerohack); + bar(); + new_fft_WIDTH1(lds61 + zerohack, u61, smallTrig61 + zerohack); +#else + new_fft_WIDTH1(ldsF2, uF2, smallTrigF2); + bar(); + new_fft_WIDTH1(lds31, u31, smallTrig31); + bar(); + new_fft_WIDTH1(lds61, u61, smallTrig61); +#endif + + Word2 wu[NW]; +#if AMDGPU + F2 weights = fancyMul(THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); +#else + F2 weights = fancyMul(CONST_THREAD_WEIGHTS[me], THREAD_WEIGHTS[G_W + line]); // On nVidia, don't pollute the constant cache with line weights +#endif + P(i64) carryShuttlePtr = (P(i64)) carryShuttle; + i64 carry[NW+1]; + +#if AMDGPU +#define CarryShuttleAccess(me,i) ((me) * NW + (i)) // Generates denser global_load_dwordx4 instructions +//#define CarryShuttleAccess(me,i) ((me) * 4 + (i)%4 + (i)/4 * 4*G_W) // Also generates global_load_dwordx4 instructions and unit stride when NW=8 +#else +#define CarryShuttleAccess(me,i) ((me) + (i) * G_W) // nVidia likes this unit stride better +#endif + + float roundMax = 0; + float carryMax = 0; + + u32 word_index = (me * H + line) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. 
little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m31_combo_bigstep = ((G_W * H * 2 - 1) * m31_combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m31_weight_shift = m31_weight_shift % 31; + u64 m31_starting_combo_counter = m31_combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m61_combo_bigstep = ((G_W * H * 2 - 1) * m61_combo_step + (((u64) (G_W * H * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m61_weight_shift = m61_weight_shift % 61; + u64 m61_starting_combo_counter = m61_combo_counter; // Save starting counter before adding log2_NWORDS+1 for applying weights after carry propagation + + // We also adjust shift amount for the fact that NTT returns results multiplied by 2*NWORDS. + const u32 log2_NWORDS = (WIDTH == 256 ? 8 : WIDTH == 512 ? 9 : WIDTH == 1024 ? 10 : 12) + + (MIDDLE == 1 ? 0 : MIDDLE == 2 ? 1 : MIDDLE == 4 ? 2 : MIDDLE == 8 ? 3 : 4) + + (SMALL_HEIGHT == 256 ? 8 : SMALL_HEIGHT == 512 ? 9 : SMALL_HEIGHT == 1024 ? 10 : 12) + 1; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift + log2_NWORDS + 1); + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift + log2_NWORDS + 1); + + // Apply the inverse weights and carry propagate pairs to generate the output carries + + F invBase = optionalDouble(weights.x); + for (u32 i = 0; i < NW; ++i) { + // Generate the FP32 weights and second GF31 and GF61 weight shift + F invWeight1 = i == 0 ? invBase : optionalDouble(fancyMul(invBase, iweightStep(i))); + F invWeight2 = optionalDouble(fancyMul(invWeight1, IWEIGHT_STEP)); + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + u32 m61_weight_shift1 = m61_weight_shift; + + // Generate big-word/little-word flags + bool biglit0 = frac_bits <= FRAC_BPW_HI; + bool biglit1 = frac_bits >= -FRAC_BPW_HI; // Same as frac_bits + FRAC_BPW_HI <= FRAC_BPW_HI; + + // Apply the inverse weights, optionally compute roundoff error, and convert to integer. Also apply MUL3 here. + // Then propagate carries through two words (the first carry does not have to be accurately calculated because it will + // be accurately calculated by carryFinal later on). The second carry must be accurate for output to the carry shuttle. 
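+    // (For LL, the -2 injected below is the "- 2" of the Lucas-Lehmer step s -> s^2 - 2, folded into the
+    // carry of word 0 on the first line so that no separate subtraction pass is needed.)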
+ wu[i] = weightAndCarryPairSloppy(SWAP_XY(uF2[i]), SWAP_XY(u31[i]), SWAP_XY(u61[i]), invWeight1, invWeight2, m31_weight_shift0, m31_weight_shift1, m61_weight_shift0, m61_weight_shift1, + // For an LL test, add -2 as the very initial "carry in" + // We'd normally use logical &&, but the compiler whines with warning and bitwise fixes it + LL != 0, (LL & (i == 0) & (line==0) & (me == 0)) ? -2 : 0, biglit0, biglit1, &carry[i], &roundMax, &carryMax); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + m61_combo_counter += m61_combo_bigstep; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + } + m31_combo_counter = m31_starting_combo_counter; // Restore starting counter for applying weights after carry propagation + m61_combo_counter = m61_starting_combo_counter; + +#if ROE + updateStats(bufROE, posROE, roundMax); +#elif STATS & (1 << MUL3) + updateStats(bufROE, posROE, carryMax); +#endif + + // Write out our carries. Only groups 0 to H-1 need to write carries out. + // Group H is a duplicate of group 0 (producing the same results) so we don't care about group H writing out, + // but it's fine either way. + if (gr < H) { for (i32 i = 0; i < NW; ++i) { carryShuttlePtr[gr * WIDTH + CarryShuttleAccess(me, i)] = carry[i]; } } + + // Tell next line that its carries are ready + if (gr < H) { +#if OLD_FENCE + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + bar(); + if (me == 0) { atomic_store((atomic_uint *) &ready[gr], 1); } +#else + write_mem_fence(CLK_GLOBAL_MEM_FENCE); + if (me % WAVEFRONT == 0) { + u32 pos = gr * (G_W / WAVEFRONT) + me / WAVEFRONT; + atomic_store((atomic_uint *) &ready[pos], 1); + } +#endif + } + + // Line zero will be redone when gr == H + if (gr == 0) { return; } + + // Do some work while our carries may not be ready +#if HAS_ASM + __asm("s_setprio 0"); +#endif + + // Calculate inverse weights + F base = optionalHalve(weights.y); + for (u32 i = 0; i < NW; ++i) { + F weight1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F weight2 = optionalHalve(fancyMul(weight1, WEIGHT_STEP)); + uF2[i] = U2(weight1, weight2); + } + + // Wait until our carries are ready +#if OLD_FENCE + if (me == 0) { do { spin(); } while(!atomic_load_explicit((atomic_uint *) &ready[gr - 1], memory_order_relaxed, memory_scope_device)); } + // work_group_barrier(CLK_GLOBAL_MEM_FENCE, memory_scope_device); + bar(); + read_mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me == 0) ready[gr - 1] = 0; +#else + u32 pos = (gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT; + if (me % WAVEFRONT == 0) { + do { spin(); } while(atomic_load_explicit((atomic_uint *) &ready[pos], memory_order_relaxed, memory_scope_device) == 0); + } + mem_fence(CLK_GLOBAL_MEM_FENCE); + // Clear carry ready flag for next iteration + if (me % WAVEFRONT == 0) ready[(gr - 1) * (G_W / WAVEFRONT) + me / WAVEFRONT] = 0; +#endif +#if HAS_ASM + __asm("s_setprio 1"); +#endif + + // Read from the carryShuttle carries produced by the previous WIDTH row. Rotate carries from the last WIDTH row. + // The new carry layout lets the compiler generate global_load_dwordx4 instructions. 
+ if (gr < H) { + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess(me, i)]; + } + } else { + +#if !OLD_FENCE + // For gr==H we need the barrier since the carry reading is shifted, thus the per-wavefront trick does not apply. + bar(); +#endif + + for (i32 i = 0; i < NW; ++i) { + carry[i] = carryShuttlePtr[(gr - 1) * WIDTH + CarryShuttleAccess((me + G_W - 1) % G_W, i) /* ((me!=0) + NW - 1 + i) % NW*/]; + } + + if (me == 0) { + carry[NW] = carry[NW-1]; + for (i32 i = NW-1; i; --i) { carry[i] = carry[i-1]; } + carry[0] = carry[NW]; + } + } + + // Apply each 32 or 64 bit carry to the 2 words. Apply weights. + for (i32 i = 0; i < NW; ++i) { + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + u32 m61_weight_shift1 = m61_weight_shift; + // Generate big-word/little-word flag, propagate final carry + bool biglit0 = frac_bits <= FRAC_BPW_HI; + wu[i] = carryFinal(wu[i], carry[i], biglit0); + uF2[i] = U2(uF2[i].x * wu[i].x, uF2[i].y * wu[i].y); + u31[i] = U2(shl(make_Z31(wu[i].x), m31_weight_shift0), shl(make_Z31(wu[i].y), m31_weight_shift1)); + u61[i] = U2(shl(make_Z61(wu[i].x), m61_weight_shift0), shl(make_Z61(wu[i].y), m61_weight_shift1)); + + // Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + m61_combo_counter += m61_combo_bigstep; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + } + + bar(); + + new_fft_WIDTH2(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, line); + + bar(); + + new_fft_WIDTH2(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, line); + + bar(); + + new_fft_WIDTH2(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, line); +} + + +#else +error - missing CarryFused kernel implementation +#endif diff --git a/src/cl/carryinc.cl b/src/cl/carryinc.cl index 1ee8e440..459476da 100644 --- a/src/cl/carryinc.cl +++ b/src/cl/carryinc.cl @@ -2,27 +2,22 @@ // This file is included with different definitions for iCARRY -Word2 OVERLOAD carryPair(long2 u, iCARRY *outCarry, bool b1, bool b2, float* carryMax) { - iCARRY midCarry; - Word a = carryStep(u.x, &midCarry, b1); - Word b = carryStep(u.y + midCarry, outCarry, b2); -// #if STATS & 0x5 - *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); -// #endif - return (Word2) (a, b); -} - Word2 OVERLOAD carryFinal(Word2 u, iCARRY inCarry, bool b1) { i32 tmpCarry; - u.x = carryStep(u.x + inCarry, &tmpCarry, b1); + u.x = carryStepSignedSloppy(u.x + inCarry, &tmpCarry, b1); u.y += tmpCarry; return u; } +/*******************************************************************************************/ +/* Original FP64 version to start the carry propagation process for a pair of FFT values */ +/*******************************************************************************************/ + +#if FFT_TYPE == FFT64 // Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. // Then propagate carries through two words. Generate the output carry. 
-Word2 OVERLOAD weightAndCarryPair(T2 u, T2 invWeight, i64 inCarry, float* maxROE, iCARRY *outCarry, bool b1, bool b2, float* carryMax) { +Word2 OVERLOAD weightAndCarryPair(T2 u, T2 invWeight, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); Word a = carryStep(tmp1, &midCarry, b1); @@ -32,13 +27,266 @@ Word2 OVERLOAD weightAndCarryPair(T2 u, T2 invWeight, i64 inCarry, float* maxROE return (Word2) (a, b); } -// Like weightAndCarryPair except that a strictly accuracy calculation of the first carry is not required. -Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, T2 invWeight, i64 inCarry, float* maxROE, iCARRY *outCarry, bool b1, bool b2, float* carryMax) { +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. +Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, T2 invWeight, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { iCARRY midCarry; i64 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); - Word a = carryStepSloppy(tmp1, &midCarry, b1); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); i64 tmp2 = weightAndCarryOne(u.y, invWeight.y, midCarry, maxROE, sizeof(midCarry) == 4); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT32 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(F2 u, F2 invWeight, iCARRY inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i32 midCarry; + i32 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); + Word a = carryStep(tmp1, &midCarry, b1); + i32 tmp2 = weightAndCarryOne(u.y, invWeight.y, midCarry, maxROE, sizeof(midCarry) == 4); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. 
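+// (As in the fused kernels, the sloppy first step is acceptable because carryFinal later re-adds the true
+// incoming carry from the shuttle and re-propagates it through the first word; only the second carry, which
+// feeds the shuttle, has to be exact here.)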
+Word2 OVERLOAD weightAndCarryPairSloppy(F2 u, F2 invWeight, iCARRY inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i32 midCarry; + i32 tmp1 = weightAndCarryOne(u.x, invWeight.x, inCarry, maxROE, sizeof(midCarry) == 4); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i32 tmp2 = weightAndCarryOne(u.y, invWeight.y, midCarry, maxROE, sizeof(midCarry) == 4); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT31 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(GF31 u, u32 invWeight1, u32 invWeight2, i32 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); Word b = carryStep(tmp2, outCarry, b2); *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); return (Word2) (a, b); } + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. +Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u, u32 invWeight1, u32 invWeight2, i32 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT61 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(GF61 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. 
+Word2 OVERLOAD weightAndCarryPairSloppy(GF61 u, u32 invWeight1, u32 invWeight2, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i64 tmp1 = weightAndCarryOne(u.x, invWeight1, inCarry, maxROE); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(u.y, invWeight2, midCarry, maxROE); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT6431 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(T2 u, GF31 u31, T invWeight1, T invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i64 midCarry; + i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, hasInCarry, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, true, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. +Word2 OVERLOAD weightAndCarryPairSloppy(T2 u, GF31 u31, T invWeight1, T invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i64 midCarry; + i96 tmp1 = weightAndCarryOne(u.x, u31.x, invWeight1, m31_invWeight1, hasInCarry, inCarry, maxROE); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(u.y, u31.y, invWeight2, m31_invWeight2, true, midCarry, maxROE); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3231 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. 
+Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF31 u31, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + i32 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i32 midCarry; + i64 tmp1 = weightAndCarryOne(uF2.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(uF2.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. +Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + i32 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i32 midCarry; + i64 tmp1 = weightAndCarryOne(uF2.x, u31.x, invWeight1, m31_invWeight1, inCarry, maxROE); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i64 tmp2 = weightAndCarryOne(uF2.y, u31.y, invWeight2, m31_invWeight2, midCarry, maxROE); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3261 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF61 u61, F invWeight1, F invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i64 midCarry; + i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, true, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. 
+Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF61 u61, F invWeight1, F invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + i64 midCarry; + i96 tmp1 = weightAndCarryOne(uF2.x, u61.x, invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(uF2.y, u61.y, invWeight2, m61_invWeight2, true, midCarry, maxROE); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3161 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. +Word2 OVERLOAD weightAndCarryPair(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m31_invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, true, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. +Word2 OVERLOAD weightAndCarryPairSloppy(GF31 u31, GF61 u61, u32 m31_invWeight1, u32 m31_invWeight2, u32 m61_invWeight1, u32 m61_invWeight2, + bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, u32* maxROE, float* carryMax) { + iCARRY midCarry; + i96 tmp1 = weightAndCarryOne(u31.x, u61.x, m31_invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i96 tmp2 = weightAndCarryOne(u31.y, u61.y, m31_invWeight2, m61_invWeight2, true, midCarry, maxROE); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_TYPE == FFT323161 + +// Apply inverse weights, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +// Then propagate carries through two words. Generate the output carry. 
+Word2 OVERLOAD weightAndCarryPair(F2 uF2, GF31 u31, GF61 u61, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + u32 m61_invWeight1, u32 m61_invWeight2, bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + iCARRY midCarry; + i128 tmp1 = weightAndCarryOne(uF2.x, u31.x, u61.x, invWeight1, m31_invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); + Word a = carryStep(tmp1, &midCarry, b1); + i128 tmp2 = weightAndCarryOne(uF2.y, u31.y, u61.y, invWeight2, m31_invWeight2, m61_invWeight2, true, midCarry, maxROE); + Word b = carryStep(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +// Like weightAndCarryPair except that a strictly accurate calculation of the first Word and carry is not required. Second word may also be sloppy. +Word2 OVERLOAD weightAndCarryPairSloppy(F2 uF2, GF31 u31, GF61 u61, F invWeight1, F invWeight2, u32 m31_invWeight1, u32 m31_invWeight2, + u32 m61_invWeight1, u32 m61_invWeight2, bool hasInCarry, i64 inCarry, bool b1, bool b2, iCARRY *outCarry, float* maxROE, float* carryMax) { + iCARRY midCarry; + i128 tmp1 = weightAndCarryOne(uF2.x, u31.x, u61.x, invWeight1, m31_invWeight1, m61_invWeight1, hasInCarry, inCarry, maxROE); + Word a = carryStepUnsignedSloppy(tmp1, &midCarry, b1); + i128 tmp2 = weightAndCarryOne(uF2.y, u31.y, u61.y, invWeight2, m31_invWeight2, m61_invWeight2, true, midCarry, maxROE); + Word b = carryStepSignedSloppy(tmp2, outCarry, b2); + *carryMax = max(*carryMax, max(boundCarry(midCarry), boundCarry(*outCarry))); + return (Word2) (a, b); +} + +#else +error - missing weightAndCarryPair implementation +#endif diff --git a/src/cl/carryutil.cl b/src/cl/carryutil.cl index ed2545f7..6cdff372 100644 --- a/src/cl/carryutil.cl +++ b/src/cl/carryutil.cl @@ -3,56 +3,99 @@ #include "base.cl" #include "math.cl" -#if STATS || ROE -void updateStats(global uint *bufROE, u32 posROE, float roundMax) { - assert(roundMax >= 0); - // work_group_reduce_max() allocates an additional 256Bytes LDS for a 64lane workgroup, so avoid it. - // u32 groupRound = work_group_reduce_max(as_uint(roundMax)); - // if (get_local_id(0) == 0) { atomic_max(bufROE + posROE, groupRound); } +#if CARRY64 +typedef i64 CFcarry; +#else +typedef i32 CFcarry; +#endif - // Do the reduction directly over global mem. - atomic_max(bufROE + posROE, as_uint(roundMax)); -} +// The carry for the non-fused CarryA, CarryB, CarryM kernels. +// Simply use largest possible carry always as the split kernels are slow anyway (and seldomly used normally). 
+#if FFT_TYPE != FFT32 && FFT_TYPE != FFT31 +typedef i64 CarryABM; +#else +typedef i32 CarryABM; #endif +/********************************/ +/* Helper routines */ +/********************************/ + +// Return unsigned low bits (number of bits must be between 1 and 31) +#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_ubfe) +u32 OVERLOAD ulowBits(i32 u, u32 bits) { return __builtin_amdgcn_ubfe(u, 0, bits); } +#elif HAS_PTX >= 700 // szext instruction requires sm_70 support or higher +u32 OVERLOAD ulowBits(i32 u, u32 bits) { u32 res; __asm("szext.clamp.u32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } +#else +u32 OVERLOAD ulowBits(i32 u, u32 bits) { return (((u32) u << (32 - bits)) >> (32 - bits)); } +#endif +u32 OVERLOAD ulowBits(u32 u, u32 bits) { return ulowBits((i32) u, bits); } +// Return unsigned low bits (number of bits must be between 1 and 63) +u64 OVERLOAD ulowBits(i64 u, u32 bits) { return (((u64) u << (64 - bits)) >> (64 - bits)); } +u64 OVERLOAD ulowBits(u64 u, u32 bits) { return ulowBits((i64) u, bits); } + +// Return unsigned low bits where number of bits is known at compile time (number of bits can be 0 to 32) +u32 OVERLOAD ulowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return u & ((1 << bits) - 1); } +u32 OVERLOAD ulowFixedBits(u32 u, const u32 bits) { return ulowFixedBits((i32) u, bits); } +// Return unsigned low bits where number of bits is known at compile time (number of bits can be 0 to 64) +u64 OVERLOAD ulowFixedBits(i64 u, const u32 bits) { return u & ((1LL << bits) - 1); } +u64 OVERLOAD ulowFixedBits(u64 u, const u32 bits) { return ulowFixedBits((i64) u, bits); } + +// Return signed low bits (number of bits must be between 1 and 31) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) -i32 lowBits(i32 u, u32 bits) { return __builtin_amdgcn_sbfe(u, 0, bits); } +i32 OVERLOAD lowBits(i32 u, u32 bits) { return __builtin_amdgcn_sbfe(u, 0, bits); } +#elif HAS_PTX >= 700 // szext instruction requires sm_70 support or higher +i32 OVERLOAD lowBits(i32 u, u32 bits) { i32 res; __asm("szext.clamp.s32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } +#else +i32 OVERLOAD lowBits(i32 u, u32 bits) { return ((u << (32 - bits)) >> (32 - bits)); } +#endif +i32 OVERLOAD lowBits(u32 u, u32 bits) { return lowBits((i32)u, bits); } +// Return signed low bits (number of bits must be between 1 and 63) +i64 OVERLOAD lowBits(i64 u, u32 bits) { return ((u << (64 - bits)) >> (64 - bits)); } +i64 OVERLOAD lowBits(u64 u, u32 bits) { return lowBits((i64)u, bits); } + +// Return signed low bits (number of bits must be between 1 and 32) +#if HAS_PTX // szext does not return result we are looking for if bits = 32 +i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits(u, bits); } #else -i32 lowBits(i32 u, u32 bits) { return ((u << (32 - bits)) >> (32 - bits)); } +i32 OVERLOAD lowBitsSafe32(i32 u, u32 bits) { return lowBits((u64)u, bits); } #endif +i32 OVERLOAD lowBitsSafe32(u32 u, u32 bits) { return lowBitsSafe32((i32)u, bits); } -#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_ubfe) -i32 ulowBits(i32 u, u32 bits) { return __builtin_amdgcn_ubfe(u, 0, bits); } +// Return signed low bits where number of bits is known at compile time (number of bits can be 0 to 32) +#if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_sbfe) +i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return __builtin_amdgcn_sbfe(u, 0, bits); } +#elif HAS_PTX >= 700 // szext instruction requires sm_70 
support or higher +i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; i32 res; __asm("szext.clamp.s32 %0, %1, %2;" : "=r"(res) : "r"(u), "r"(bits)); return res; } #else -i32 ulowBits(i32 u, u32 bits) { u32 uu = (u32) u; return ((uu << (32 - bits)) >> (32 - bits)); } +i32 OVERLOAD lowFixedBits(i32 u, const u32 bits) { if (bits == 32) return u; return (u << (32 - bits)) >> (32 - bits); } #endif +i32 OVERLOAD lowFixedBits(u32 u, const u32 bits) { return lowFixedBits((i32)u, bits); } +// Return signed low bits where number of bits is known at compile time (number of bits can be 1 to 63). The two versions are the same speed on TitanV. +i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return ((u << (64 - bits)) >> (64 - bits)); } +//i64 OVERLOAD lowFixedBits(i64 u, const u32 bits) { if (bits <= 32) return lowFixedBits((i32) u, bits); return (i64) ulowFixedBits(u, bits - 1) - (u & (1LL << (bits - 1))); } +i64 OVERLOAD lowFixedBits(u64 u, const u32 bits) { return lowFixedBits((i64)u, bits); } +// Extract 32 bits from a 64-bit value (starting bit offset can be 0 to 31) #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_alignbit) i32 xtract32(i64 x, u32 bits) { return __builtin_amdgcn_alignbit(as_int2(x).y, as_int2(x).x, bits); } +#elif HAS_PTX >= 320 // shf instruction requires sm_32 support or higher +i32 xtract32(i64 x, u32 bits) { i32 res; __asm("shf.r.clamp.b32 %0, %1, %2, %3;" : "=r"(res) : "r"(as_uint2(x).x), "r"(as_uint2(x).y), "r"(bits)); return res; } #else i32 xtract32(i64 x, u32 bits) { return x >> bits; } #endif -#if !defined(LL) -#define LL 0 +// Extract 32 bits from a 64-bit value (starting bit offset can be 0 to 32) +#if HAS_PTX >= 320 // shf instruction requires sm_32 support or higher +i32 xtractSafe32(i64 x, u32 bits) { i32 res; __asm("shf.r.clamp.b32 %0, %1, %2, %3;" : "=r"(res) : "r"(as_uint2(x).x), "r"(as_uint2(x).y), "r"(bits)); return res; } +#else +i32 xtractSafe32(i64 x, u32 bits) { return x >> bits; } #endif u32 bitlen(bool b) { return EXP / NWORDS + b; } bool test(u32 bits, u32 pos) { return (bits >> pos) & 1; } -#if 0 -// Check for round off errors above a threshold (default is 0.43) -void ROUNDOFF_CHECK(double x) { -#if DEBUG -#ifndef ROUNDOFF_LIMIT -#define ROUNDOFF_LIMIT 0.43 -#endif - float error = fabs(x - rint(x)); - if (error > ROUNDOFF_LIMIT) printf("Roundoff: %g %30.2f\n", error, x); -#endif -} -#endif - +#if FFT_FP64 // Rounding constant: 3 * 2^51, See https://stackoverflow.com/questions/17035464 #define RNDVAL (3.0 * (1l << 51)) @@ -70,6 +113,62 @@ i64 RNDVALdoubleToLong(double d) { return as_long(words); } +#elif FFT_FP32 +// Rounding constant: 3 * 2^22 +#define RNDVAL (3.0f * (1 << 22)) + +// Convert a float to int efficiently. Float must be in RNDVAL+integer format. +i32 RNDVALfloatToInt(float d) { + int w = as_int(d); +//#if 0 +// We extend the range to 23 bits instead of 22 by taking the sign from the negation of bit 22 +// w ^= 0x00800000u; +// w = lowBits(words.y, 23); +//#else +// // Take the sign from bit 21 (i.e. use lower 22 bits). 
+ w = lowBits(w, 22); +//#endif + return w; +} +#endif + +// map abs(carry) to floats, with 2^32 corresponding to 1.0 +// So that the maximum CARRY32 abs(carry), 2^31, is mapped to 0.5 (the same as the maximum ROE) +float OVERLOAD boundCarry(i32 c) { return ldexp(fabs((float) c), -32); } +float OVERLOAD boundCarry(i64 c) { return ldexp(fabs((float) (i32) (c >> 8)), -24); } + +#if STATS || ROE +void updateStats(global uint *bufROE, u32 posROE, float roundMax) { + assert(roundMax >= 0); + // work_group_reduce_max() allocates an additional 256Bytes LDS for a 64lane workgroup, so avoid it. + // u32 groupRound = work_group_reduce_max(as_uint(roundMax)); + // if (get_local_id(0) == 0) { atomic_max(bufROE + posROE, groupRound); } + + // Do the reduction directly over global mem. + atomic_max(bufROE + posROE, as_uint(roundMax)); +} +#endif + +#if 0 +// Check for round off errors above a threshold (default is 0.43) +void ROUNDOFF_CHECK(double x) { +#if DEBUG +#ifndef ROUNDOFF_LIMIT +#define ROUNDOFF_LIMIT 0.43 +#endif + float error = fabs(x - rint(x)); + if (error > ROUNDOFF_LIMIT) printf("Roundoff: %g %30.2f\n", error, x); +#endif +} +#endif + + +/***************************************************************************/ +/* From the FFT data, construct a value to normalize and carry propagate */ +/***************************************************************************/ + +#if FFT_TYPE == FFT64 + // Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_result_is_acceptable) { @@ -83,7 +182,7 @@ i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_r double d = fma(u, invWeight, RNDVALCarry); // Optionally calculate roundoff error - float roundoff = fabs((float) fma(u, -invWeight, d - RNDVALCarry)); + float roundoff = fabs((float) fma(u, invWeight, RNDVALCarry - d)); *maxROE = max(*maxROE, roundoff); // Convert to long (for CARRY32 case we don't need to strip off the RNDVAL bits) @@ -105,19 +204,375 @@ i64 weightAndCarryOne(T u, T invWeight, i64 inCarry, float* maxROE, int sloppy_r #endif } +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT32 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i32 weightAndCarryOne(F u, F invWeight, i32 inCarry, float* maxROE, int sloppy_result_is_acceptable) { + +#if !MUL3 + + // Convert carry into RNDVAL + carry. + float RNDVALCarry = as_float(as_int(RNDVAL) + inCarry); // GWBUG - just the float arithmetic? s.b. 
fast + + // Apply inverse weight and RNDVAL+carry + float d = fma(u, invWeight, RNDVALCarry); + + // Optionally calculate roundoff error + float roundoff = fabs(fma(u, invWeight, RNDVALCarry - d)); + *maxROE = max(*maxROE, roundoff); + + // Convert to int + return RNDVALfloatToInt(d); + +#else // We cannot add in the carry until after the mul by 3 + + // Apply inverse weight and RNDVAL + float d = fma(u, invWeight, RNDVAL); + + // Optionally calculate roundoff error + float roundoff = fabs(fma(u, -invWeight, d - RNDVAL)); + *maxROE = max(*maxROE, roundoff); + + // Convert to int, mul by 3, and add carry + return RNDVALfloatToInt(d) * 3 + inCarry; + +#endif +} + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT31 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i64 weightAndCarryOne(Z61 u, u32 invWeight, i32 inCarry, u32* maxROE) { + + // Apply inverse weight + u = shr(u, invWeight); + + // Convert input to balanced representation + i32 value = get_balanced_Z31(u); + + // Optionally calculate roundoff error as proximity to M31/2. + u32 roundoff = (u32) abs(value); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + return (i64)value * 3 + inCarry; +#endif + return value + inCarry; +} + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT61 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i64 weightAndCarryOne(Z61 u, u32 invWeight, i64 inCarry, u32* maxROE) { + + // Apply inverse weight + u = shr(u, invWeight); + + // Convert input to balanced representation + i64 value = get_balanced_Z61(u); + + // Optionally calculate roundoff error as proximity to M61/2. 28 bits of accuracy should be sufficient. + u32 roundoff = (u32) abs((i32) (value >> 32)); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + value *= 3; +#endif + return value + inCarry; +} + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT6431 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i96 weightAndCarryOne(T u, Z31 u31, T invWeight, u32 m31_invWeight, bool hasInCarry, i64 inCarry, float* maxROE) { + + // Apply inverse weight and get the Z31 data + u31 = shr(u31, m31_invWeight); + u32 n31 = get_Z31(u31); + + // The final result must be n31 mod M31. Use FP64 data to calculate this value. 
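+  // Rough sketch of the reconstruction: let X be the exact integer result. X == n31 (mod M31), so
+  // X = n64 * M31 + n31 with n64 = (X - n31) / M31, recovered by rounding the FP64 estimate.
+  // The FP64 value only has to land on the correct multiple of M31 (error under M31/2), which is
+  // why limited FP64 precision suffices here. Note n64 * M31 = (n64 << 31) - n64; that identity is
+  // how the i96 value is assembled below.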
+ u = fma(u, invWeight, - (double) n31); // This should be close to a multiple of M31 + double uInt = fma(u, 4.656612875245796924105750827168e-10, RNDVAL); // Divide by M31 and round to int + i64 n64 = RNDVALdoubleToLong(uInt); + + // Optionally calculate roundoff error + float roundoff = (float) fabs(fma(u, 4.656612875245796924105750827168e-10, RNDVAL - uInt)); + *maxROE = max(*maxROE, roundoff); + + // Compute the value using i96 math + i64 vhi = n64 >> 33; + u64 vlo = ((u64)n64 << 31) | n31; + i96 value = make_i96(vhi, vlo); // (n64 << 31) + n31 + value = sub(value, n64); // n64 * M31 + n31 + + // Mul by 3 and add carry +#if MUL3 + value = add(value, add(value, value)); +#endif + if (hasInCarry) value = add(value, inCarry); + return value; +} + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3231 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i64 weightAndCarryOne(float uF2, Z31 u31, float F2_invWeight, u32 m31_invWeight, i32 inCarry, float* maxROE) { + + // Apply inverse weight and get the Z31 data + u31 = shr(u31, m31_invWeight); + u32 n31 = get_Z31(u31); + + // The final result must be n31 mod M31. Use FP32 data to calculate this value. + uF2 = fma(uF2, F2_invWeight, - (float) n31); // This should be close to a multiple of M31 + float uF2int = fma(uF2, 4.656612875245796924105750827168e-10f, RNDVAL); // Divide by M31 and round to int + i32 nF2 = RNDVALfloatToInt(uF2int); + + i64 v = (((i64) nF2 << 31) | n31) - nF2; // nF2 * M31 + n31 + + // Optionally calculate roundoff error + float roundoff = fabs(fma(uF2, 4.656612875245796924105750827168e-10f, RNDVAL - uF2int)); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + v = v * 3; +#endif + return v + inCarry; +} + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3261 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i96 weightAndCarryOne(float uF2, Z61 u61, float F2_invWeight, u32 m61_invWeight, bool hasInCarry, i64 inCarry, float* maxROE) { + + // Apply inverse weight and get the Z61 data + u61 = shr(u61, m61_invWeight); + u64 n61 = get_Z61(u61); + + // The final result must be n61 mod M61. Use FP32 data to calculate this value. 
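+  // Same reconstruction idea as the FP64/M31 hybrid above: X == n61 (mod M61), so
+  // X = nF2 * M61 + n61 with nF2 = (X - n61) / M61, recovered by rounding the FP32 estimate.
+  // Dropping the low 32 bits of n61 is fine -- the error introduced (< 2^32) is tiny compared with
+  // the M61/2 slack available when picking the correct multiple of M61.
+  // nF2 * M61 = (nF2 << 61) - nF2, assembled below with i96 math.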
+ float n61f = (float)((u32)(n61 >> 32)) * -4294967296.0f; // Conversion from u64 to float might be slow, this might be faster + uF2 = fma(uF2, F2_invWeight, n61f); // This should be close to a multiple of M61 + float uF2int = fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL); // Divide by M61 and round to int + i32 nF2 = RNDVALfloatToInt(uF2int); + + // Optionally calculate roundoff error + float roundoff = fabs(fma(uF2, 4.3368086899420177360298112034798e-19f, RNDVAL - uF2int)); + *maxROE = max(*maxROE, roundoff); + + // Compute the value using i96 math + i32 vhi = nF2 >> 3; + u64 vlo = ((u64)nF2 << 61) | n61; + i96 value = make_i96(vhi, vlo); // (nF2 << 61) + n61 + value = sub(value, nF2); // nF2 * M61 + n61 + + // Mul by 3 and add carry +#if MUL3 + value = add(value, add(value, value)); +#endif + if (hasInCarry) value = add(value, inCarry); + return value; +} + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3161 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i96 weightAndCarryOne(Z31 u31, Z61 u61, u32 m31_invWeight, u32 m61_invWeight, bool hasInCarry, i64 inCarry, u32* maxROE) { + + // Apply inverse weights + u31 = shr(u31, m31_invWeight); + u61 = shr(u61, m61_invWeight); + + // Use chinese remainder theorem to create a 92-bit result. Loosely copied from Yves Gallot's mersenne2 program. + u32 n31 = get_Z31(u31); + u61 = subq(u61, make_Z61(n31), 2); // u61 - u31 + u61 = add(u61, shl(u61, 31)); // u61 + (u61 << 31) + + // The resulting value will be get_Z61(u61) * M31 + n31 and if larger than ~M31*M61/2 return a negative value by subtracting M31 * M61. + // We can save a little work by determining if the result will be large using just u61 and returning (get_Z61(u61) - M61) * M31 + n31. + // This simplifies to get_balanced_Z61(u61) * M31 + n31. + i64 n61 = get_balanced_Z61(u61); + + // Optionally calculate roundoff error as proximity to M61/2. 28 bits of accuracy should be sufficient. + u32 roundoff = (u32) abs((i32)(n61 >> 32)); + *maxROE = max(*maxROE, roundoff); + + // Compute the value using i96 math + i64 vhi = n61 >> 33; + u64 vlo = ((u64)n61 << 31) | n31; + i96 value = make_i96(vhi, vlo); // (n61 << 31) + n31 + value = sub(value, n61); // n61 * M31 + n31 + + // Mul by 3 and add carry +#if MUL3 + value = add(value, add(value, value)); +#endif + if (hasInCarry) value = add(value, inCarry); + return value; +} + +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_TYPE == FFT323161 + +// Apply inverse weight, add in optional carry, calculate roundoff error, convert to integer. Handle MUL3. +i128 weightAndCarryOne(float uF2, Z31 u31, Z61 u61, float F2_invWeight, u32 m31_invWeight, u32 m61_invWeight, bool hasInCarry, i64 inCarry, float* maxROE) { + + // Apply inverse weights + u31 = shr(u31, m31_invWeight); + u61 = shr(u61, m61_invWeight); + + // Use chinese remainder theorem to create a 92-bit result. Loosely copied from Yves Gallot's mersenne2 program. 
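+  // Why the shift-and-add below inverts M31 mod M61: (2^31 - 1) * (2^31 + 1) = 2^62 - 1
+  // = 2*(2^61 - 1) + 1 == 1 (mod M61), so M31^-1 mod M61 is 2^31 + 1 and multiplying by it is just
+  // u61 + (u61 << 31). The CRT lift is then n31 + M31 * ((u61 - n31) * (2^31 + 1) mod M61), which is
+  // what the next few lines compute.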
+ u32 n31 = get_Z31(u31); + u61 = subq(u61, make_Z61(n31), 2); // u61 - u31 + u61 = add(u61, shl(u61, 31)); // u61 + (u61 << 31) + u64 n61 = get_Z61(u61); + + i128 n3161 = make_i128(n61 >> 33, (n61 << 31) | n31); // n61 << 31 + n31 + n3161 = sub(n3161, n61); // n61 * M31 + n31 + + // The final result must be n3161 mod M31*M61. Use FP32 data to calculate this value. + float n3161f = (float)((u32)(n61 >> 32)) * -9223372036854775808.0f; // Converting n3161 from i128 to float might be slow, this might be faster + uF2 = fma(uF2, F2_invWeight, n3161f); // This should be close to a multiple of M31*M61 + float uF2int = fma(uF2, 2.0194839183061857038255724444152e-28f, RNDVAL); // Divide by M31*M61 and round to int + i32 nF2 = RNDVALfloatToInt(uF2int); + + i64 nF2m31 = ((i64)nF2 << 31) - nF2; // nF2 * M31 + i128 v = make_i128(nF2m31 >> 3, (u64)nF2m31 << 61); // nF2m31 << 61 + v = sub(v, nF2m31); // nF2m31 * M61 + v = add(v, n3161); // nF2m31 * M61 + n3161 + + // Optionally calculate roundoff error + float roundoff = fabs(fma(uF2, 2.0194839183061857038255724444152e-28f, RNDVAL - uF2int)); + *maxROE = max(*maxROE, roundoff); + + // Mul by 3 and add carry +#if MUL3 + v = add(v, add(v, v)); +#endif + if (hasInCarry) v = add(v, inCarry); + return v; +} + +#else +error - missing weightAndCarryOne implementation +#endif + + +/************************************************************************/ +/* Split a value + carryIn into a big-or-little word and a carryOut */ +/************************************************************************/ + +Word OVERLOAD carryStep(i128 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + i64 w = lowBits(i128_lo64(x), nBits); + *outCarry = i128_shrlo64(x, nBits) + (w < 0); + return w; +} + +Word OVERLOAD carryStep(i96 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + u32 nBitsLess32 = bitlen(isBigWord) - 32; + +// This code can be tricky because we must not shift i32 or u32 variables by 32. 
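+// (Shifting a 32-bit value by 32 or more is undefined behavior in OpenCL C, just as in C.) The cases
+// below are therefore split on EXP / NWORDS: where nBits can reach 32 the extraction works on 64-bit
+// halves or goes through the *Safe32 helpers, and plain 32-bit shifts are only used where nBits is
+// provably below 32.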
+#if EXP / NWORDS >= 33 + i32 whi = lowBits(i96_mid32(x), nBitsLess32); + *outCarry = ((i64)i96_hi64(x) - (i64)whi) >> nBitsLess32; + return as_ulong((uint2)(i96_lo32(x), (u32)whi)); +#elif EXP / NWORDS == 32 + i32 whi = xtract32(i96_lo64(x), nBitsLess32) >> 31; + *outCarry = ((i64)i96_hi64(x) - (i64)whi) >> nBitsLess32; + return as_ulong((uint2)(i96_lo32(x), (u32)whi)); +#elif EXP / NWORDS == 31 + i32 w = lowBitsSafe32(i96_lo32(x), nBits); + *outCarry = as_long((int2)(xtractSafe32(i96_lo64(x), nBits), xtractSafe32(i96_hi64(x), nBits))) + (w < 0); + return w; +// i64 w = lowBits(i96_lo64(x), nBits); +// *outCarry = ((i96_hi64(x) << (32 - nBits)) | ((i96_lo32(x) >> 16) >> (nBits - 16))) + (w < 0); +// return w; +#else + i32 w = lowBits(i96_lo32(x), nBits); + *outCarry = as_long((int2)(xtract32(i96_lo64(x), nBits), xtract32(i96_hi64(x), nBits))) + (w < 0); + return w; +#endif +} + Word OVERLOAD carryStep(i64 x, i64 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); - Word w = lowBits(x, nBits); - x -= w; - *outCarry = x >> nBits; +#if EXP / NWORDS >= 33 + i32 xhi = hi32(x); + i32 whi = lowBits(xhi, nBits - 32); + *outCarry = (xhi - whi) >> (nBits - 32); + return (Word) as_long((int2)(lo32(x), whi)); +#elif EXP / NWORDS == 32 + i32 xhi = hi32(x); + i64 w = lowBits(x, nBits); + xhi -= (i32)(w >> 32); + *outCarry = xhi >> (nBits - 32); + return w; +#elif EXP / NWORDS == 31 + i32 w = lowBitsSafe32(lo32(x), nBits); + *outCarry = (x - w) >> nBits; return w; +#else + Word w = lowBits(lo32(x), nBits); + *outCarry = (x - w) >> nBits; + return w; +#endif } Word OVERLOAD carryStep(i64 x, i32 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); - Word w = lowBits(x, nBits); +#if EXP / NWORDS >= 33 + i32 xhi = hi32(x); + i32 w = lowBits(xhi, nBits - 32); + *outCarry = (xhi >> (nBits - 32)) + (w < 0); + return as_long((int2)(lo32(x), w)); +#elif EXP / NWORDS == 32 + i32 xhi = hi32(x); + i64 w = lowBits(x, nBits); + *outCarry = (xhi >> (nBits - 32)) + (w < 0); + return w; +#elif EXP / NWORDS == 31 + i32 w = lowBitsSafe32(lo32(x), nBits); + *outCarry = xtractSafe32(x, nBits) + (w < 0); + return w; +#else + i32 w = lowBits(x, nBits); *outCarry = xtract32(x, nBits) + (w < 0); return w; +#endif } Word OVERLOAD carryStep(i32 x, i32 *outCarry, bool isBigWord) { @@ -127,54 +582,188 @@ Word OVERLOAD carryStep(i32 x, i32 *outCarry, bool isBigWord) { return w; } -Word OVERLOAD carryStepSloppy(i64 x, i64 *outCarry, bool isBigWord) { +/*****************************************************************/ +/* Same as CarryStep but returns a faster unsigned result. */ +/* Used on first word of pair in carryFused. */ +/* CarryFinal will later turn this into a balanced signed value. */ +/*****************************************************************/ + +Word OVERLOAD carryStepUnsignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { + const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); - Word w = ulowBits(x, nBits); - *outCarry = x >> nBits; + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. 
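+// The returned word always spans bigwordBits bits (a compile-time constant, so ulowFixedBits turns
+// into a fixed mask). When this position is a small word the top bit of that window still ends up in
+// the returned word, so it is cleared out of x before the carry shift; word + (carry << nBits)
+// remains exact.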
+ u64 w = ulowFixedBits(i128_lo64(x), bigwordBits); + x = i128_masklo64(x, ~((u64)1 << (bigwordBits - 1))); + *outCarry = i128_shrlo64(x, nBits); return w; } -Word OVERLOAD carryStepSloppy(i64 x, i32 *outCarry, bool isBigWord) { +Word OVERLOAD carryStepUnsignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { + const u32 bigwordBits = EXP / NWORDS + 1; u32 nBits = bitlen(isBigWord); - Word w = ulowBits(x, nBits); - *outCarry = xtract32(x, nBits); + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. +#if EXP / NWORDS >= 32 // nBits is 32 or more + i64 xhi = as_ulong((uint2)(i96_mid32(x) & ~((1 << (bigwordBits - 32)) - 1), i96_hi32(x))); + *outCarry = xhi >> (nBits - 32); + return as_ulong((uint2)(i96_lo32(x), ulowFixedBits(i96_mid32(x), bigwordBits - 32))); +#elif EXP / NWORDS == 31 || EXP / NWORDS >= 22 // nBits = 31 or 32, fastest version. Should also work on smaller nBits. + *outCarry = i96_hi64(x) << (32 - nBits); + return i96_lo32(x); // ulowBits(x, bigwordBits = 32); +#else // nBits less than 32 + u32 w = ulowFixedBits(i96_lo32(x), bigwordBits); + *outCarry = as_long((int2)(xtract32(as_long((int2)(i96_lo32(x) - w, i96_mid32(x))), nBits), xtract32(i96_hi64(x), nBits))); return w; +#endif +} + +Word OVERLOAD carryStepUnsignedSloppy(i64 x, i64 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + *outCarry = x >> nBits; + return ulowBits(x, nBits); +} + +Word OVERLOAD carryStepUnsignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { + u32 nBits = bitlen(isBigWord); + *outCarry = xtract32(x, nBits); + return ulowBits(x, nBits); } -Word OVERLOAD carryStepSloppy(i32 x, i32 *outCarry, bool isBigWord) { +Word OVERLOAD carryStepUnsignedSloppy(i32 x, i32 *outCarry, bool isBigWord) { u32 nBits = bitlen(isBigWord); - Word w = ulowBits(x, nBits); *outCarry = x >> nBits; + return ulowBits(x, nBits); +} + +/**********************************************************************/ +/* Same as CarryStep but may return a faster big word signed result. */ +/* Used on second word of pair in carryFused when not near max BPW. */ +/* Also used on first word in carryFinal when not near max BPW. */ +/**********************************************************************/ + +// We only allow sloppy results when not near the maximum bits-per-word. For now, this is defined as 1.1 bits below maxbpw. +// No studies have been done on reducing this 1,1 value since this is a rather minor optimization. Since the preprocessor can't +// handle floats, the MAXBPW value passed in is 100 * maxbpw. +#define SLOPPY_MAXBPW (MAXBPW - 110) +#define ACTUAL_BPW (EXP / (NWORDS / 100)) + +Word OVERLOAD carryStepSignedSloppy(i128 x, i64 *outCarry, bool isBigWord) { +#if ACTUAL_BPW > SLOPPY_MAXBPW + return carryStep(x, outCarry, isBigWord); +#else + +//GW: Need to compare to simple carryStep + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. 
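+// "Sloppy" means the word is balanced over the big-word range [-2^(bigwordBits-1), 2^(bigwordBits-1)),
+// even when this position is a small word. The top bit subtracted from the word is added back into x
+// before the carry shift, so word + (carry << nBits) stays exact; the word is merely allowed to be one
+// bit wider than a strict small word, which is why this path is reserved for BPW comfortably below the
+// maximum.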
+ const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); + u64 xlo = i128_lo64(x); + u64 xlo_topbit = xlo & ((u64)1 << (bigwordBits - 1)); + i64 w = ulowFixedBits(xlo, bigwordBits - 1) - xlo_topbit; + *outCarry = i128_shrlo64(add(x, xlo_topbit), nBits); return w; +#endif } -// map abs(carry) to floats, with 2^32 corresponding to 1.0 -// So that the maximum CARRY32 abs(carry), 2^31, is mapped to 0.5 (the same as the maximum ROE) -float OVERLOAD boundCarry(i32 c) { return ldexp(fabs((float) c), -32); } -float OVERLOAD boundCarry(i64 c) { return ldexp(fabs((float) (i32) (c >> 8)), -24); } +Word OVERLOAD carryStepSignedSloppy(i96 x, i64 *outCarry, bool isBigWord) { +#if ACTUAL_BPW > SLOPPY_MAXBPW + return carryStep(x, outCarry, isBigWord); +#else -#define iCARRY i32 -#include "carryinc.cl" -#undef iCARRY +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); +#if EXP / NWORDS >= 32 // nBits is 32 or more + return carryStep(x, outCarry, isBigWord); // Should be just as fast as code below +// u32 xmid_topbit = i96_mid32(x) & (1 << (bigwordBits - 32 - 1)); +// i32 whi = ulowFixedBits(i96_mid32(x), bigwordBits - 32 - 1) - xmid_topbit; +// i64 xhi = i96_hi64(x) + xmid_topbit; +// *outCarry = xhi >> (nBits - 32); +// return as_long((int2)(i96_lo32(x), whi)); +#elif EXP / NWORDS == 31 || SLOPPY_MAXBPW >= 3200 // nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance) + i32 w = i96_lo32(x); // lowBits(x, bigwordBits = 32); + *outCarry = (i96_hi64(x) + (w < 0)) << (32 - nBits); + return w; +#else // nBits less than 32 + return carryStep(x, outCarry, isBigWord); // Should be faster than code below +// i32 w = lowFixedBits(i96_lo32(x), bigwordBits); +// *outCarry = (as_long((int2)(xtract32(i96_lo64(x), bigwordBits), xtract32(i96_hi64(x), bigwordBits))) + (w < 0)) << (bigwordBits - nBits); +// return w; +#endif +#endif +} -#define iCARRY i64 -#include "carryinc.cl" -#undef iCARRY +Word OVERLOAD carryStepSignedSloppy(i64 x, i64 *outCarry, bool isBigWord) { +#if ACTUAL_BPW > SLOPPY_MAXBPW + return carryStep(x, outCarry, isBigWord); +#else -#if CARRY64 -typedef i64 CFcarry; + // We're unlikely to find code that is better than carryStep + return carryStep(x, outCarry, isBigWord); +#endif +} + +Word OVERLOAD carryStepSignedSloppy(i64 x, i32 *outCarry, bool isBigWord) { +#if ACTUAL_BPW > SLOPPY_MAXBPW + return carryStep(x, outCarry, isBigWord); #else -typedef i32 CFcarry; + +//GW: I need to look at PTX code generated by the code below vs. carryStep + +// Return a Word using the big word size. Big word size is a constant which allows for more optimization. + const u32 bigwordBits = EXP / NWORDS + 1; + u32 nBits = bitlen(isBigWord); +#if EXP / NWORDS >= 32 // nBits is 32 or more + u64 x_topbit = x & ((u64)1 << (bigwordBits - 1)); + i64 w = ulowFixedBits(x, bigwordBits - 1) - x_topbit; + i32 xhi = (i32)(x >> 32) + (i32)(x_topbit >> 32); + *outCarry = xhi >> (nBits - 32); + return w; +// nBits = 31 or 32, bigwordBits = 32 (or allowed to create 32-bit word for better performance). For reasons I don't fully understand the sloppy +// case fails if BPW is too low. Probably something to do with a small BPW with sloppy 32-bit values would require CARRY_LONG to work properly. +// Not a major concern as end users should avoid small BPW as there is probably a more efficient NTT that could be used. 
+#elif EXP / NWORDS == 31 || (EXP / NWORDS >= 23 && SLOPPY_MAXBPW >= 3200) + i32 w = x; // lowBits(x, bigwordBits = 32); + *outCarry = ((i32)(x >> 32) + (w < 0)) << (32 - nBits); + return w; +#else // nBits less than 32 //GWBUG - is there a faster version? Is this faster than plain old carryStep? No +// u32 x_topbit = (u32) x & (1 << (bigwordBits - 1)); +// i32 w = ulowFixedBits((u32) x, bigwordBits - 1) - x_topbit; +// *outCarry = (i64)(x + x_topbit) >> nBits; +// return w; + return carryStep(x, outCarry, isBigWord); +#endif #endif +} + +Word OVERLOAD carryStepSignedSloppy(i32 x, i32 *outCarry, bool isBigWord) { + return carryStep(x, outCarry, isBigWord); +} -// The carry for the non-fused CarryA, CarryB, CarryM kernels. -// Simply use large carry always as the split kernels are slow anyway (and seldomly used normally). -typedef i64 CarryABM; -// Carry propagation from word and carry. + +// Carry propagation from word and carry. Used by carryB.cl. Word2 carryWord(Word2 a, CarryABM* carry, bool b1, bool b2) { a.x = carryStep(a.x + *carry, carry, b1); a.y = carryStep(a.y + *carry, carry, b2); return a; } +/**************************************************************************/ +/* Do this last, it depends on weightAndCarryOne defined above */ +/**************************************************************************/ + +/* Support both 32-bit and 64-bit carries */ + +#if WordSize <= 4 +#define iCARRY i32 +#include "carryinc.cl" +#undef iCARRY +#endif + +#if FFT_TYPE != FFT32 && FFT_TYPE != FFT31 +#define iCARRY i64 +#include "carryinc.cl" +#undef iCARRY +#endif diff --git a/src/cl/etc.cl b/src/cl/etc.cl index 89fcaf7b..ae4fd857 100644 --- a/src/cl/etc.cl +++ b/src/cl/etc.cl @@ -18,7 +18,7 @@ KERNEL(32) readResidue(P(Word2) out, CP(Word2) in) { #if SUM64 KERNEL(64) sum64(global ulong* out, u32 sizeBytes, global ulong* in) { if (get_global_id(0) == 0) { out[0] = 0; } - + ulong sum = 0; for (i32 p = get_global_id(0); p < sizeBytes / sizeof(u64); p += get_global_size(0)) { sum += in[p]; @@ -32,7 +32,7 @@ KERNEL(64) sum64(global ulong* out, u32 sizeBytes, global ulong* in) { #if ISEQUAL // outEqual must be "true" on entry. KERNEL(256) isEqual(global i64 *in1, global i64 *in2, P(int) outEqual) { - for (i32 p = get_global_id(0); p < ND; p += get_global_size(0)) { + for (i32 p = get_global_id(0); p < NWORDS * sizeof(Word) / sizeof(i64); p += get_global_size(0)) { if (in1[p] != in2[p]) { *outEqual = 0; return; @@ -50,5 +50,4 @@ kernel void testKernel(global int* in, global double* out) { int p = me * in[me] % 8; // % 15; out[me] = TAB[p]; } - #endif diff --git a/src/cl/fft-middle.cl b/src/cl/fft-middle.cl index bf61d555..31db8bfc 100644 --- a/src/cl/fft-middle.cl +++ b/src/cl/fft-middle.cl @@ -2,8 +2,6 @@ #include "trig.cl" -void fft2(T2* u) { X2(u[0], u[1]); } - #if MIDDLE == 3 #include "fft3.cl" #elif MIDDLE == 4 @@ -34,7 +32,29 @@ void fft2(T2* u) { X2(u[0], u[1]); } #include "fft16.cl" #endif -void fft_MIDDLE(T2 *u) { +#if !defined(MM_CHAIN) && !defined(MM2_CHAIN) && FFT_VARIANT_M == 0 +#define MM_CHAIN 0 +#define MM2_CHAIN 0 +#endif + +#if !defined(MM_CHAIN) && !defined(MM2_CHAIN) && FFT_VARIANT_M == 1 +#define MM_CHAIN 1 +#define MM2_CHAIN 2 +#endif + +// Apply the twiddles needed after fft_MIDDLE and before fft_HEIGHT in forward FFT. +// Also used after fft_HEIGHT and before fft_MIDDLE in inverse FFT. 
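+// WADD/WSUB multiply u[i] by a twiddle w (resp. its conjugate). The *F variants expect w in "Fancy"
+// format (real part stored as cos - 1 for accuracy near 1.0, hence the base.x += 1 fix-ups before
+// dropping back to plain cmul) and use cmulFancy.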
+ +#define WADD(i, w) u[i] = cmul(u[i], w) +#define WSUB(i, w) u[i] = cmul_by_conjugate(u[i], w) +#define WADDF(i, w) u[i] = cmulFancy(u[i], w) +#define WSUBF(i, w) u[i] = cmulFancy(u[i], conjugate(w)) + +#if FFT_FP64 + +void OVERLOAD fft2(T2* u) { X2(u[0], u[1]); } + +void OVERLOAD fft_MIDDLE(T2 *u) { #if MIDDLE == 1 // Do nothing #elif MIDDLE == 2 @@ -72,28 +92,15 @@ void fft_MIDDLE(T2 *u) { #endif } -// Apply the twiddles needed after fft_MIDDLE and before fft_HEIGHT in forward FFT. -// Also used after fft_HEIGHT and before fft_MIDDLE in inverse FFT. - -#define WADD(i, w) u[i] = cmul(u[i], w) -#define WSUB(i, w) u[i] = cmul_by_conjugate(u[i], w) - -#define WADDF(i, w) u[i] = cmulFancy(u[i], w) -#define WSUBF(i, w) u[i] = cmulFancy(u[i], conjugate(w)) - // Keep in sync with TrigBufCache.cpp, see comment there. #define SHARP_MIDDLE 5 -#if !defined(MM_CHAIN) && !defined(MM2_CHAIN) && FFT_VARIANT_M == 1 -#define MM_CHAIN 1 -#define MM2_CHAIN 2 -#endif - -void middleMul(T2 *u, u32 s, Trig trig) { +void OVERLOAD middleMul(T2 *u, u32 s, Trig trig) { assert(s < SMALL_HEIGHT); - if (MIDDLE == 1) { return; } + if (MIDDLE == 1) return; - T2 w = trig[s]; // s / BIG_HEIGHT + if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. + T2 w = trig[s]; // s / BIG_HEIGHT if (MIDDLE < SHARP_MIDDLE) { WADD(1, w); @@ -175,7 +182,7 @@ void middleMul(T2 *u, u32 s, Trig trig) { } } -void middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { +void OVERLOAD middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { assert(x < WIDTH); assert(y < SMALL_HEIGHT); @@ -184,7 +191,8 @@ void middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { return; } - T2 w = trig[SMALL_HEIGHT + x]; // x / (MIDDLE * WIDTH) + trig += SMALL_HEIGHT; // Skip over the MiddleMul trig table + T2 w = trig[x]; // x / (MIDDLE * WIDTH) if (MIDDLE < SHARP_MIDDLE) { T2 base = slowTrig_N(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2) * factor; @@ -196,7 +204,19 @@ void middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { } else { // MIDDLE >= 5 // T2 w = slowTrig_N(x * SMALL_HEIGHT, ND / MIDDLE); -#if AMDGPU && MM2_CHAIN == 0 // Oddly, Radeon 7 is faster with this version that uses more F64 ops +#if 0 // Slower on Radeon 7, but proves the concept for use in GF61. 
Might be worthwhile on poor FP64 GPUs + + Trig trig2 = trig + WIDTH; // Skip over the fist MiddleMul2 trig table + u32 desired_root = x * y; + T2 base = cmulFancy(trig2[desired_root % SMALL_HEIGHT], trig[desired_root / SMALL_HEIGHT]) * factor; //Optimization to do: put multiply by factor in trig2 table + + WADD(0, base); + for (u32 k = 1; k < MIDDLE; ++k) { + base = cmulFancy(base, w); + WADD(k, base); + } + +#elif AMDGPU && MM2_CHAIN == 0 // Oddly, Radeon 7 is faster with this version that uses more F64 ops T2 base = slowTrig_N(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2) * factor; WADD(0, base); @@ -281,13 +301,8 @@ void middleMul2(T2 *u, u32 x, u32 y, double factor, Trig trig) { } } -#undef WADD -#undef WADDF -#undef WSUB -#undef WSUBF - // Do a partial transpose during fftMiddleIn/Out -void middleShuffle(local T *lds, T2 *u, u32 workgroupSize, u32 blockSize) { +void OVERLOAD middleShuffle(local T *lds, T2 *u, u32 workgroupSize, u32 blockSize) { u32 me = get_local_id(0); if (MIDDLE <= 8) { local T *p1 = lds + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; @@ -323,10 +338,567 @@ void middleShuffle(local T *lds, T2 *u, u32 workgroupSize, u32 blockSize) { } } +// Do a partial transpose during fftMiddleIn/Out and write the results to global memory +void OVERLOAD middleShuffleWrite(global T2 *out, T2 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + out += (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } +} + +// Do an in-place 16x16 transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local T2 *lds, T2 *u) { + u32 me = get_local_id(0); + u32 y = me / 16; + u32 x = me % 16; + + for (int i = 0; i < MIDDLE; ++i) { +// lds[x * 16 + y] = u[i]; + lds[x * 16 + y ^ x] = u[i]; // Swizzling with XOR should reduce LDS bank conflicts + bar(); +// u[i] = lds[me]; + u[i] = lds[y * 16 + x ^ y]; + bar(); + } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD fft2(F2* u) { X2(u[0], u[1]); } + +void OVERLOAD fft_MIDDLE(F2 *u) { +#if MIDDLE == 1 + // Do nothing +#elif MIDDLE == 2 + fft2(u); +#elif MIDDLE == 4 + fft4(u); +#elif MIDDLE == 8 + fft8(u); +#elif MIDDLE == 16 + fft16(u); +#else +#error UNRECOGNIZED MIDDLE +#endif +} + +// Keep in sync with TrigBufCache.cpp, see comment there. +#define SHARP_MIDDLE 5 + +void OVERLOAD middleMul(F2 *u, u32 s, TrigFP32 trig) { + assert(s < SMALL_HEIGHT); + if (MIDDLE == 1) return; + + if (WIDTH == SMALL_HEIGHT) trig += SMALL_HEIGHT; // In this case we can share the MiddleMul2 trig table. Skip over the MiddleMul trig table. 
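+  // w is the twiddle for angle s / BIG_HEIGHT of a full turn; u[k] needs w^k. MM_CHAIN == 0 builds
+  // the powers by chained multiplies from w, MM_CHAIN == 1 re-anchors them with slowTrig_N to limit
+  // error growth.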
+ F2 w = trig[s]; // s / BIG_HEIGHT + + if (MIDDLE < SHARP_MIDDLE) { + WADD(1, w); +#if MM_CHAIN == 0 + F2 base = csqTrig(w); + for (u32 k = 2; k < MIDDLE; ++k) { + WADD(k, base); + base = cmul(base, w); + } +#elif MM_CHAIN == 1 + for (u32 k = 2; k < MIDDLE; ++k) { WADD(k, slowTrig_N(WIDTH * k * s, WIDTH * k * SMALL_HEIGHT)); } +#else +#error MM_CHAIN must be 0 or 1 +#endif + + } else { // MIDDLE >= 5 + +#if MM_CHAIN == 0 + WADDF(1, w); + F2 base; + base = csqTrigFancy(w); + WADDF(2, base); + base = ccubeTrigFancy(base, w); + WADDF(3, base); + base.x += 1; + + for (u32 k = 4; k < MIDDLE; ++k) { + base = cmulFancy(base, w); + WADD(k, base); + } + +#elif 0 && MM_CHAIN == 1 // This is fewer F64 ops, but may be slower on Radeon 7 -- probably the optimizer being weird. It also has somewhat worse Z. + for (u32 k = 3 + (MIDDLE - 2) % 3; k < MIDDLE; k += 3) { + F2 base, base_minus1, base_plus1; + base = slowTrig_N(WIDTH * k * s, WIDTH * SMALL_HEIGHT * k); + cmul_a_by_fancyb_and_conjfancyb(&base_plus1, &base_minus1, base, w); + WADD(k-1, base_minus1); + WADD(k, base); + WADD(k+1, base_plus1); + } + + WADDF(1, w); + + F2 w2; + if ((MIDDLE - 2) % 3 > 0) { + w2 = csqTrigFancy(w); + WADDF(2, w2); + } + + if ((MIDDLE - 2) % 3 == 2) { + F2 w3 = ccubeTrigFancy(w2, w); + WADDF(3, w3); + } + +#elif MM_CHAIN == 1 + for (u32 k = 3 + (MIDDLE - 2) % 3; k < MIDDLE; k += 3) { + F2 base, base_minus1, base_plus1; + base = slowTrig_N(WIDTH * k * s, WIDTH * SMALL_HEIGHT * k); + cmul_a_by_fancyb_and_conjfancyb(&base_plus1, &base_minus1, base, w); + WADD(k-1, base_minus1); + WADD(k, base); + WADD(k+1, base_plus1); + } + + WADDF(1, w); + + if ((MIDDLE - 2) % 3 > 0) { + WADDF(2, w); + WADDF(2, w); + } + + if ((MIDDLE - 2) % 3 == 2) { + WADDF(3, w); + WADDF(3, csqTrigFancy(w)); + } +#else +#error MM_CHAIN must be 0 or 1. +#endif + } +} + +void OVERLOAD middleMul2(F2 *u, u32 x, u32 y, float factor, TrigFP32 trig) { + assert(x < WIDTH); + assert(y < SMALL_HEIGHT); + + if (MIDDLE == 1) { + WADD(0, slowTrig_N(x * y, ND) * factor); + return; + } + + trig += SMALL_HEIGHT; // Skip over the MiddleMul trig table + F2 w = trig[x]; // x / (MIDDLE * WIDTH) + + if (MIDDLE < SHARP_MIDDLE) { + F2 base = slowTrig_N(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2) * factor; + for (u32 k = 0; k < MIDDLE; ++k) { WADD(k, base); } + WSUB(0, w); + if (MIDDLE > 2) { WADD(2, w); } + if (MIDDLE > 3) { WADD(3, w); WADD(3, w); } + + } else { // MIDDLE >= 5 + // F2 w = slowTrig_N(x * SMALL_HEIGHT, ND / MIDDLE); + +#if 0 // Slower on Radeon 7, but proves the concept for use in GF61. 
Might be worthwhile on poor FP64 GPUs + + TrigFP32 trig2 = trig + WIDTH; // Skip over the fist MiddleMul2 trig table + u32 desired_root = x * y; + F2 base = cmulFancy(trig2[desired_root % SMALL_HEIGHT], trig[desired_root / SMALL_HEIGHT]) * factor; //Optimization to do: put multiply by factor in trig2 table + + WADD(0, base); + for (u32 k = 1; k < MIDDLE; ++k) { + base = cmulFancy(base, w); + WADD(k, base); + } + +#elif AMDGPU && MM2_CHAIN == 0 // Oddly, Radeon 7 is faster with this version that uses more F64 ops + + F2 base = slowTrig_N(x * y + x * SMALL_HEIGHT, ND / MIDDLE * 2) * factor; + WADD(0, base); + WADD(1, base); + + for (u32 k = 2; k < MIDDLE; ++k) { + base = cmulFancy(base, w); + WADD(k, base); + } + WSUBF(0, w); + +#elif MM2_CHAIN == 0 + + u32 mid = MIDDLE / 2; + F2 base = slowTrig_N(x * y + x * SMALL_HEIGHT * mid, ND / MIDDLE * (mid + 1)) * factor; + WADD(mid, base); + + F2 basehi, baselo; + cmul_a_by_fancyb_and_conjfancyb(&basehi, &baselo, base, w); + WADD(mid-1, baselo); + WADD(mid+1, basehi); + + for (int i = mid-2; i >= 0; --i) { + baselo = cmulFancy(baselo, conjugate(w)); + WADD(i, baselo); + } + + for (int i = mid+2; i < MIDDLE; ++i) { + basehi = cmulFancy(basehi, w); + WADD(i, basehi); + } + +#elif MM2_CHAIN == 1 + u32 cnt = 1; + for (u32 start = 0, sz = (MIDDLE - start + cnt - 1) / cnt; cnt > 0; --cnt, start += sz) { + if (start + sz > MIDDLE) { --sz; } + u32 n = (sz - 1) / 2; + u32 mid = start + n; + + F2 base1 = slowTrig_N(x * y + x * SMALL_HEIGHT * mid, ND / MIDDLE * (mid + 1)) * factor; + WADD(mid, base1); + + F2 base2 = base1; + for (u32 i = 1; i <= n; ++i) { + base1 = cmulFancy(base1, conjugate(w)); + WADD(mid - i, base1); + + base2 = cmulFancy(base2, w); + WADD(mid + i, base2); + } + if (!(sz & 1)) { + base2 = cmulFancy(base2, w); + WADD(mid + n + 1, base2); + } + } + +#elif MM2_CHAIN == 2 + F2 base, base_minus1, base_plus1; + for (u32 i = 1; ; i += 3) { + if (i-1 == MIDDLE-1) { + base_minus1 = slowTrig_N(x * y + x * SMALL_HEIGHT * (i - 1), ND / MIDDLE * i) * factor; + WADD(i-1, base_minus1); + break; + } else if (i == MIDDLE-1) { + base_minus1 = slowTrig_N(x * y + x * SMALL_HEIGHT * (i - 1), ND / MIDDLE * i) * factor; + base = cmulFancy(base_minus1, w); + WADD(i-1, base_minus1); + WADD(i, base); + break; + } else { + base = slowTrig_N(x * y + x * SMALL_HEIGHT * i, ND / MIDDLE * (i + 1)) * factor; + cmul_a_by_fancyb_and_conjfancyb(&base_plus1, &base_minus1, base, w); + WADD(i-1, base_minus1); + WADD(i, base); + WADD(i+1, base_plus1); + if (i+1 == MIDDLE-1) break; + } + } +#else +#error MM2_CHAIN must be 0, 1 or 2. 
+#endif + } +} + +// Do a partial transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local F *lds, F2 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + if (MIDDLE <= 16) { + local F *p1 = lds + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + local F *p2 = lds + me; + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].x; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].x = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].y; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].y = p2[workgroupSize * i]; } + } +} // Do a partial transpose during fftMiddleIn/Out and write the results to global memory -void middleShuffleWrite(global T2 *out, T2 *u, u32 workgroupSize, u32 blockSize) { +void OVERLOAD middleShuffleWrite(global F2 *out, F2 *u, u32 workgroupSize, u32 blockSize) { u32 me = get_local_id(0); out += (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } } + +// Do an in-place 16x16 transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local F2 *lds, F2 *u) { + u32 me = get_local_id(0); + u32 y = me / 16; + u32 x = me % 16; + for (int i = 0; i < MIDDLE; ++i) { +// lds[x * 16 + y] = u[i]; + lds[x * 16 + y ^ x] = u[i]; // Swizzling with XOR should reduce LDS bank conflicts + bar(); +// u[i] = lds[me]; + u[i] = lds[y * 16 + x ^ y]; + bar(); + } +} +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft2(GF31* u) { X2(u[0], u[1]); } + +void OVERLOAD fft_MIDDLE(GF31 *u) { +#if MIDDLE == 1 + // Do nothing +#elif MIDDLE == 2 + fft2(u); +#elif MIDDLE == 4 + fft4(u); +#elif MIDDLE == 8 + fft8(u); +#elif MIDDLE == 16 + fft16(u); +#else +#error UNRECOGNIZED MIDDLE +#endif +} + +void OVERLOAD middleMul(GF31 *u, u32 s, TrigGF31 trig) { + assert(s < SMALL_HEIGHT); + if (MIDDLE == 1) return; + +#if !MIDDLE_CHAIN // Read all trig values from memory + + for (u32 k = 1; k < MIDDLE; ++k) { + WADD(k, trig[s]); + s += SMALL_HEIGHT; + } + +#else + + GF31 w = trig[s]; // s / BIG_HEIGHT + WADD(1, w); + if (MIDDLE == 2) return; + +#if SHOULD_BE_FASTER + GF31 sq = csqTrig(w); + WADD(2, sq); + GF31 base = ccubeTrig(sq, w); // GWBUG: compute w^4 as csqTriq(sq), w^6 as ccubeTrig(w2, w4), and w^5 and w^7 as cmul_a_by_b_and_conjb + for (u32 k = 3; k < MIDDLE; ++k) { +#else + GF31 base = csq(w); + for (u32 k = 2; k < MIDDLE; ++k) { +#endif + WADD(k, base); + base = cmul(base, w); + } + +#endif + +} + +void OVERLOAD middleMul2(GF31 *u, u32 x, u32 y, TrigGF31 trig) { + assert(x < WIDTH); + assert(y < SMALL_HEIGHT); + + // First trig table comes after the MiddleMul trig table. Second trig table comes after the first MiddleMul2 trig table. + TrigGF31 trig1 = trig + SMALL_HEIGHT * (MIDDLE - 1); + TrigGF31 trig2 = trig1 + WIDTH; + // The first trig table can be shared with MiddleMul trig table if WIDTH = HEIGHT. 
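+  // trig2 holds the fine-grained roots and trig1 the coarse ones, so the root for exponent x*y is
+  // built with a single cmul as trig2[(x*y) % SMALL_HEIGHT] * trig1[(x*y) / SMALL_HEIGHT] rather than
+  // storing an entry for every possible x*y.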
+ if (WIDTH == SMALL_HEIGHT) trig1 = trig; + + GF31 w = trig1[x]; // x / (MIDDLE * WIDTH) + u32 desired_root = x * y; + GF31 base = cmul(trig2[desired_root % SMALL_HEIGHT], trig1[desired_root / SMALL_HEIGHT]); + + WADD(0, base); + for (u32 k = 1; k < MIDDLE; ++k) { + base = cmul(base, w); + WADD(k, base); + } +} + +// Do a partial transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local Z31 *lds, GF31 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + if (MIDDLE <= 16) { + local Z31 *p1 = lds + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + local Z31 *p2 = lds + me; + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].x; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].x = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].y; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].y = p2[workgroupSize * i]; } + } +} + +// Do a partial transpose during fftMiddleIn/Out and write the results to global memory +void OVERLOAD middleShuffleWrite(global GF31 *out, GF31 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + out += (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } +} + +// Do an in-place 16x16 transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local GF31 *lds, GF31 *u) { + u32 me = get_local_id(0); + u32 y = me / 16; + u32 x = me % 16; + for (int i = 0; i < MIDDLE; ++i) { +// lds[x * 16 + y] = u[i]; + lds[x * 16 + y ^ x] = u[i]; // Swizzling with XOR should reduce LDS bank conflicts + bar(); +// u[i] = lds[me]; + u[i] = lds[y * 16 + x ^ y]; + bar(); + } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD fft2(GF61* u) { X2(u[0], u[1]); } + +void OVERLOAD fft_MIDDLE(GF61 *u) { +#if MIDDLE == 1 + // Do nothing +#elif MIDDLE == 2 + fft2(u); +#elif MIDDLE == 4 + fft4(u); +#elif MIDDLE == 8 + fft8(u); +#elif MIDDLE == 16 + fft16(u); +#else +#error UNRECOGNIZED MIDDLE +#endif +} + +void OVERLOAD middleMul(GF61 *u, u32 s, TrigGF61 trig) { + assert(s < SMALL_HEIGHT); + if (MIDDLE == 1) return; + +#if !MIDDLE_CHAIN // Read all trig values from memory + + for (u32 k = 1; k < MIDDLE; ++k) { + WADD(k, trig[s]); + s += SMALL_HEIGHT; + } + +#else + + GF61 w = trig[s]; // s / BIG_HEIGHT + WADD(1, w); + if (MIDDLE == 2) return; + +#if SHOULD_BE_FASTER + GF61 sq = csqTrig(w); + WADD(2, sq); + GF61 base = ccubeTrig(sq, w); // GWBUG: compute w^4 as csqTriq(sq), w^6 as ccubeTrig(w2, w4), and w^5 and w^7 as cmul_a_by_b_and_conjb + for (u32 k = 3; k < MIDDLE; ++k) { +#else + GF61 base = csq(w); + for (u32 k = 2; k < MIDDLE; ++k) { +#endif + WADD(k, base); + base = cmul(base, w); + } + +#endif + +} + +void OVERLOAD middleMul2(GF61 *u, u32 x, u32 y, TrigGF61 trig) { + assert(x < WIDTH); + assert(y < SMALL_HEIGHT); + + // First trig table comes after the MiddleMul trig table. Second trig table comes after the first MiddleMul2 trig table. + TrigGF61 trig1 = trig + SMALL_HEIGHT * (MIDDLE - 1); + TrigGF61 trig2 = trig1 + WIDTH; + // The first trig table can be shared with MiddleMul trig table if WIDTH = HEIGHT. 
+ if (WIDTH == SMALL_HEIGHT) trig1 = trig; + + GF61 w = trig1[x]; // x / (MIDDLE * WIDTH) + u32 desired_root = x * y; + GF61 base = cmul(trig2[desired_root % SMALL_HEIGHT], trig1[desired_root / SMALL_HEIGHT]); + + WADD(0, base); + for (u32 k = 1; k < MIDDLE; ++k) { + base = cmul(base, w); + WADD(k, base); + } +} + +// Do a partial transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local Z61 *lds, GF61 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + if (MIDDLE <= 8) { + local Z61 *p1 = lds + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + local Z61 *p2 = lds + me; + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].x; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].x = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = u[i].y; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { u[i].y = p2[workgroupSize * i]; } + } else { + local int *p1 = ((local int*) lds) + (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + local int *p2 = (local int*) lds + me; + int4 *pu = (int4 *)u; + + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = pu[i].x; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { pu[i].x = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = pu[i].y; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { pu[i].y = p2[workgroupSize * i]; } + bar(); + + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = pu[i].z; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { pu[i].z = p2[workgroupSize * i]; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { p1[i * workgroupSize] = pu[i].w; } + bar(); + for (int i = 0; i < MIDDLE; ++i) { pu[i].w = p2[workgroupSize * i]; } + } +} + +// Do a partial transpose during fftMiddleIn/Out and write the results to global memory +void OVERLOAD middleShuffleWrite(global GF61 *out, GF61 *u, u32 workgroupSize, u32 blockSize) { + u32 me = get_local_id(0); + out += (me % blockSize) * (workgroupSize / blockSize) + me / blockSize; + for (int i = 0; i < MIDDLE; ++i) { out[i * workgroupSize] = u[i]; } +} + +// Do an in-place 16x16 transpose during fftMiddleIn/Out +void OVERLOAD middleShuffle(local GF61 *lds, GF61 *u) { + u32 me = get_local_id(0); + u32 y = me / 16; + u32 x = me % 16; + for (int i = 0; i < MIDDLE; ++i) { +// lds[x * 16 + y] = u[i]; + lds[x * 16 + y ^ x] = u[i]; // Swizzling with XOR should reduce LDS bank conflicts + bar(); +// u[i] = lds[me]; + u[i] = lds[y * 16 + x ^ y]; + bar(); + } +} + +#endif + + +#undef WADD +#undef WADDF +#undef WSUB +#undef WSUBF diff --git a/src/cl/fft10.cl b/src/cl/fft10.cl index 0ede9c75..fe23a0b5 100644 --- a/src/cl/fft10.cl +++ b/src/cl/fft10.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda +#if FFT_FP64 + #include "fft5.cl" // PFA(5*2): 24 FMA + 68 ADD @@ -14,3 +16,5 @@ void fft10(T2 *u) { CYCLE(2, 4, 8, 6); #undef CYCLE } + +#endif diff --git a/src/cl/fft11.cl b/src/cl/fft11.cl index 5ac4f982..9bc12b13 100644 --- a/src/cl/fft11.cl +++ b/src/cl/fft11.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda & George Woltman +#if FFT_FP64 + #if 0 // Adapted from https://web.archive.org/web/20101126231320/http://cnx.org/content/col10569/1.7/pdf // 40 FMA + 150 ADD @@ -178,3 +180,5 @@ void fft11(T2 *u) { } #endif + +#endif diff --git a/src/cl/fft12.cl b/src/cl/fft12.cl index 5f74cf8e..6f400edc 100644 --- a/src/cl/fft12.cl +++ b/src/cl/fft12.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + #if 1 #include "fft3.cl" #include "fft4.cl" @@ 
-75,3 +77,5 @@ void fft12(T2 *u) { } #endif + +#endif diff --git a/src/cl/fft13.cl b/src/cl/fft13.cl index 6cc27d9c..528fdc04 100644 --- a/src/cl/fft13.cl +++ b/src/cl/fft13.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + // To calculate a 13-complex FFT in a brute force way (using a shorthand notation): // The sin/cos values (w = 13th root of unity) are: // w^1 = .885 - .465i @@ -155,3 +157,5 @@ void fft13(T2 *u) { fma_addsub(u[5], u[8], SIN1, tmp69a, tmp69b); fma_addsub(u[6], u[7], SIN1, tmp78a, tmp78b); } + +#endif diff --git a/src/cl/fft14.cl b/src/cl/fft14.cl index f786820c..8690cc74 100644 --- a/src/cl/fft14.cl +++ b/src/cl/fft14.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + #if 1 #include "fft7.cl" @@ -96,3 +98,5 @@ void fft14(T2 *u) { fma_addsub(u[6], u[8], SIN1, tmp79a, tmp79b); } #endif + +#endif diff --git a/src/cl/fft15.cl b/src/cl/fft15.cl index 3052f06c..beb8c895 100644 --- a/src/cl/fft15.cl +++ b/src/cl/fft15.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + // The fft3by() and fft5by() below use a different "output map" relative to fft3.cl and fft5.cl // This way fft15() does not need a "fix order" step at the end. // See "An In-Place, In-Order Prime Factor Algorithm" by Burrus & Eschenbacher (1981) @@ -73,3 +75,5 @@ void fft15(T2 *u) { fft5_15(u, 5); fft5_15(u, 10); } + +#endif diff --git a/src/cl/fft16.cl b/src/cl/fft16.cl index a07599ba..7cbbb24b 100644 --- a/src/cl/fft16.cl +++ b/src/cl/fft16.cl @@ -1,5 +1,6 @@ // Copyright (C) Mihai Preda +#if FFT_FP64 #if 0 @@ -8,7 +9,7 @@ #include "fft4.cl" // 24 FMA (of which 16 MUL) + 136 ADD -void fft16(T2 *u) { +void OVERLOAD fft16(T2 *u) { double C1 = 0.92387953251128674, // cos(tau/16) S1 = 0.38268343236508978; // sin(tau/16) @@ -43,7 +44,7 @@ void fft16(T2 *u) { #include "fft8.cl" -void fft16(T2 *u) { +void OVERLOAD fft16(T2 *u) { double C1 = 0.92387953251128674, // cos(tau/16) S1 = 0.38268343236508978; // sin(tau/16) @@ -77,7 +78,7 @@ void fft16(T2 *u) { // FFT-16 Adapted from Nussbaumer, "Fast Fourier Transform and Convolution Algorithms" // 28 FMA + 124 ADD -void fft16(T2 *u) { +void OVERLOAD fft16(T2 *u) { double C1 = 0.70710678118654757, // cos(2t/16) C2 = 0.38268343236508978, // cos(3t/16) @@ -152,3 +153,128 @@ void fft16(T2 *u) { } #endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +#include "fft8.cl" + +void OVERLOAD fft16(F2 *u) { + float + C1 = 0.92387953251128674, // cos(tau/16) + S1 = 0.38268343236508978; // sin(tau/16) + + for (int i = 0; i < 8; ++i) { X2(u[i], u[i + 8]); } + u[ 9] = cmul(u[ 9], U2( C1, S1)); // 1t16 + u[11] = cmul(u[11], U2( S1, C1)); // 3t16 + u[13] = cmul(u[13], U2(-S1, C1)); // 5t16 + u[15] = cmul(u[15], U2(-C1, S1)); // 7t16 + + u[10] = mul_t8(u[10]); + u[12] = mul_t4(u[12]); + u[14] = mul_3t8(u[14]); + + fft8Core(u); + fft8Core(u + 8); + + // revbin fix order + // 0 8 4 12 2 10 6 14 1 9 5 13 3 11 7 15 + SWAP(u[1], u[8]); + SWAP(u[2], u[4]); + SWAP(u[3], u[12]); + SWAP(u[5], u[10]); + SWAP(u[7], u[14]); + SWAP(u[11], u[13]); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +#include "fft8.cl" + 
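+// C1 and S1 below are intended as the GF(M31^2) analogues of cos(tau/16) and sin(tau/16):
+// C1 + S1*i should be a primitive 16th root of unity (16 divides M31 + 1 = 2^31, so such a root
+// exists, with C1^2 + S1^2 == 1 mod M31).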
+void OVERLOAD fft16(GF31 *u) { + const Z31 C1 = 1556715293; + const Z31 S1 = 978592373; + + X2(u[0], u[8]); + X2(u[1], u[9]); + X2_mul_t8(u[2], u[10]); + X2(u[3], u[11]); + X2_mul_t4(u[4], u[12]); + X2(u[5], u[13]); + X2_mul_3t8(u[6], u[14]); + X2(u[7], u[15]); + + u[ 9] = cmul(u[ 9], U2( C1, S1)); // 1t16 + u[11] = cmul(u[11], U2( S1, C1)); // 3t16 + u[13] = cmul(u[13], U2(neg(S1), C1)); // 5t16 //GWBUG - check if optimizer is eliminating the neg (or better yet perhaps tweak follow up code to expect a negative) + u[15] = cmul(u[15], U2(neg(C1), S1)); // 7t16 + + fft8Core(u); + fft8Core(u + 8); + + // revbin fix order + // 0 8 4 12 2 10 6 14 1 9 5 13 3 11 7 15 + SWAP(u[1], u[8]); + SWAP(u[2], u[4]); + SWAP(u[3], u[12]); + SWAP(u[5], u[10]); + SWAP(u[7], u[14]); + SWAP(u[11], u[13]); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +#include "fft8.cl" + +void OVERLOAD fft16(GF61 *u) { + const Z61 C1 = 22027337052962166ULL; + const Z61 S1 = 1693317751237720973ULL; + + X2(u[0], u[8]); + X2(u[1], u[9]); + X2_mul_t8(u[2], u[10]); + X2(u[3], u[11]); + X2_mul_t4(u[4], u[12]); + X2(u[5], u[13]); + X2_mul_3t8(u[6], u[14]); + X2(u[7], u[15]); + + u[ 9] = cmul(u[ 9], U2( C1, S1)); // 1t16 + u[11] = cmul(u[11], U2( S1, C1)); // 3t16 + u[13] = cmul(u[13], U2(neg(S1), C1)); // 5t16 //GWBUG - check if optimizer is eliminating the neg (or better yet perhaps tweak follow up code to expect a negative) + u[15] = cmul(u[15], U2(neg(C1), S1)); // 7t16 + + fft8Core(u); + fft8Core(u + 8); + + // revbin fix order + // 0 8 4 12 2 10 6 14 1 9 5 13 3 11 7 15 + SWAP(u[1], u[8]); + SWAP(u[2], u[4]); + SWAP(u[3], u[12]); + SWAP(u[5], u[10]); + SWAP(u[7], u[14]); + SWAP(u[11], u[13]); +} + +#endif diff --git a/src/cl/fft3.cl b/src/cl/fft3.cl index cc501705..ee5b7af3 100644 --- a/src/cl/fft3.cl +++ b/src/cl/fft3.cl @@ -2,6 +2,8 @@ #pragma once +#if FFT_FP64 + // 6 FMA + 6 ADD void fft3by(T2 *u, u32 base, u32 step, u32 M) { #define A(k) u[(base + k * step) % M] @@ -28,3 +30,5 @@ void fft3by(T2 *u, u32 base, u32 step, u32 M) { } void fft3(T2 *u) { fft3by(u, 0, 1, 3); } + +#endif diff --git a/src/cl/fft4.cl b/src/cl/fft4.cl index 02d740f0..8b422ca9 100644 --- a/src/cl/fft4.cl +++ b/src/cl/fft4.cl @@ -2,7 +2,9 @@ #pragma once -void fft4Core(T2 *u) { +#if FFT_FP64 + +void OVERLOAD fft4Core(T2 *u) { X2(u[0], u[2]); X2(u[1], u[3]); u[3] = mul_t4(u[3]); @@ -11,7 +13,7 @@ void fft4Core(T2 *u) { } // 16 ADD -void fft4by(T2 *u, u32 base, u32 step, u32 M) { +void OVERLOAD fft4by(T2 *u, u32 base, u32 step, u32 M) { #define A(k) u[(base + step * k) % M] @@ -59,4 +61,182 @@ void fft4by(T2 *u, u32 base, u32 step, u32 M) { } -void fft4(T2 *u) { fft4by(u, 0, 1, 4); } +void OVERLOAD fft4(T2 *u) { fft4by(u, 0, 1, 4); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD fft4Core(F2 *u) { + X2(u[0], u[2]); + X2(u[1], u[3]); u[3] = mul_t4(u[3]); + + X2(u[0], u[1]); + X2(u[2], u[3]); +} + +// 16 ADD +void OVERLOAD fft4by(F2 *u, u32 base, u32 step, u32 M) { + +#define A(k) u[(base + step * k) % M] + +#if 1 + float x0 = A(0).x + A(2).x; + float x2 = A(0).x - A(2).x; + float y0 = A(0).y + A(2).y; + float y2 = A(0).y - A(2).y; + + float x1 = A(1).x 
+ A(3).x; + float y3 = A(1).x - A(3).x; + float y1 = A(1).y + A(3).y; + float x3 = -(A(1).y - A(3).y); + + float a0 = x0 + x1; + float a1 = x0 - x1; + + float b0 = y0 + y1; + float b1 = y0 - y1; + + float a2 = x2 + x3; + float a3 = x2 - x3; + + float b2 = y2 + y3; + float b3 = y2 - y3; + + A(0) = U2(a0, b0); + A(1) = U2(a2, b2); + A(2) = U2(a1, b1); + A(3) = U2(a3, b3); + +#else + + X2(A(0), A(2)); + X2(A(1), A(3)); + X2(A(0), A(1)); + + A(3) = mul_t4(A(3)); + X2(A(2), A(3)); + SWAP(A(1), A(2)); + +#endif + +#undef A + +} + +void OVERLOAD fft4(F2 *u) { fft4by(u, 0, 1, 4); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft4Core(GF31 *u) { + X2(u[0], u[2]); + X2_mul_t4(u[1], u[3]); + X2(u[0], u[1]); + X2(u[2], u[3]); +} + +// 16 ADD +void OVERLOAD fft4by(GF31 *u, u32 base, u32 step, u32 M) { + +#define A(k) u[(base + step * k) % M] + + Z31 x0 = add(A(0).x, A(2).x); //GWBUG: Delay some of the mods using 64 bit temps? + Z31 x2 = sub(A(0).x, A(2).x); + Z31 y0 = add(A(0).y, A(2).y); + Z31 y2 = sub(A(0).y, A(2).y); + + Z31 x1 = add(A(1).x, A(3).x); + Z31 y3 = sub(A(1).x, A(3).x); + Z31 y1 = add(A(1).y, A(3).y); + Z31 x3 = sub(A(3).y, A(1).y); + + Z31 a0 = add(x0, x1); + Z31 a1 = sub(x0, x1); + + Z31 b0 = add(y0, y1); + Z31 b1 = sub(y0, y1); + + Z31 a2 = add(x2, x3); + Z31 a3 = sub(x2, x3); + + Z31 b2 = add(y2, y3); + Z31 b3 = sub(y2, y3); + + A(0) = U2(a0, b0); + A(1) = U2(a2, b2); + A(2) = U2(a1, b1); + A(3) = U2(a3, b3); + +#undef A + +} + +void OVERLOAD fft4(GF31 *u) { fft4by(u, 0, 1, 4); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD fft4Core(GF61 *u) { // Starts with all u[i] having maximum values of M61+epsilon. + X2q(&u[0], &u[2], 2); // X2(u[0], u[2]); No reductions mod M61. Will require 3 M61s additions to make positives. + X2q_mul_t4(&u[1], &u[3], 2); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); + X2s(&u[0], &u[1], 3); + X2s(&u[2], &u[3], 3); +} + +// 16 ADD +void OVERLOAD fft4by(GF61 *u, u32 base, u32 step, u32 M) { + +#define A(k) u[(base + step * k) % M] + + Z61 x0 = addq(A(0).x, A(2).x); // Max value is 2*M61+epsilon + Z61 x2 = subq(A(0).x, A(2).x, 2); // Max value is 3*M61+epsilon + Z61 y0 = addq(A(0).y, A(2).y); + Z61 y2 = subq(A(0).y, A(2).y, 2); + + Z61 x1 = addq(A(1).x, A(3).x); + Z61 y3 = subq(A(1).x, A(3).x, 2); + Z61 y1 = addq(A(1).y, A(3).y); + Z61 x3 = subq(A(3).y, A(1).y, 2); + + Z61 a0 = add(x0, x1); + Z61 a1 = subs(x0, x1, 3); + + Z61 b0 = add(y0, y1); + Z61 b1 = subs(y0, y1, 3); + + Z61 a2 = add(x2, x3); + Z61 a3 = subs(x2, x3, 4); + + Z61 b2 = add(y2, y3); + Z61 b3 = subs(y2, y3, 4); + + A(0) = U2(a0, b0); + A(1) = U2(a2, b2); + A(2) = U2(a1, b1); + A(3) = U2(a3, b3); + +#undef A + +} + +void OVERLOAD fft4(GF61 *u) { fft4by(u, 0, 1, 4); } + +#endif diff --git a/src/cl/fft5.cl b/src/cl/fft5.cl index 2cf446b2..31e2ba9d 100644 --- a/src/cl/fft5.cl +++ b/src/cl/fft5.cl @@ -2,6 +2,8 @@ #pragma once +#if FFT_FP64 + // Adapted from: Nussbaumer, "Fast Fourier Transform and Convolution Algorithms", 5.5.4 "5-Point DFT". 
// 12 FMA + 24 ADD (or 10 FMA + 28 ADD) void fft5by(T2 *u, u32 base, u32 step, u32 m) { @@ -45,3 +47,5 @@ void fft5by(T2 *u, u32 base, u32 step, u32 m) { } void fft5(T2 *u) { return fft5by(u, 0, 1, 5); } + +#endif diff --git a/src/cl/fft6.cl b/src/cl/fft6.cl index d25466fe..54299286 100644 --- a/src/cl/fft6.cl +++ b/src/cl/fft6.cl @@ -2,6 +2,8 @@ #include "fft3.cl" +#if FFT_FP64 + // 12 FMA + 24 ADD void fft6(T2 *u) { #if 1 @@ -36,3 +38,5 @@ void fft6(T2 *u) { fma_addsub(u[2], u[4], -SIN1, tmp35a, u[2]); #endif } + +#endif diff --git a/src/cl/fft7.cl b/src/cl/fft7.cl index 1424148e..450cfa7d 100644 --- a/src/cl/fft7.cl +++ b/src/cl/fft7.cl @@ -4,6 +4,8 @@ #include "base.cl" +#if FFT_FP64 + #define A(i) u[(base + i * step) % M] #if 1 @@ -108,3 +110,5 @@ void fft7by(T2 *u, u32 base, u32 step, u32 M) { #undef A void fft7(T2 *u) { return fft7by(u, 0, 1, 7); } + +#endif diff --git a/src/cl/fft8.cl b/src/cl/fft8.cl index 9b113795..56e2fc94 100644 --- a/src/cl/fft8.cl +++ b/src/cl/fft8.cl @@ -4,31 +4,162 @@ #include "fft4.cl" +#if FFT_FP64 + T2 mul_t8_delayed(T2 a) { return U2(a.x - a.y, a.x + a.y); } T2 mul_3t8_delayed(T2 a) { return U2(-(a.x + a.y), a.x - a.y); } //#define X2_apply_delay(a, b) { T2 t = a; a = t + M_SQRT1_2 * b; b = t - M_SQRT1_2 * b; } #define X2_apply_delay(a, b) { T2 t = a; a.x = fma(b.x, M_SQRT1_2, a.x); a.y = fma(b.y, M_SQRT1_2, a.y); b.x = fma(-M_SQRT1_2, b.x, t.x); b.y = fma(-M_SQRT1_2, b.y, t.y); } -void fft4CoreSpecial(T2 *u) { +void OVERLOAD fft4CoreSpecial(T2 *u) { X2(u[0], u[2]); - X2(u[1], u[3]); u[3] = mul_t4(u[3]); + X2_mul_t4(u[1], u[3]); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); X2_apply_delay(u[0], u[1]); X2_apply_delay(u[2], u[3]); } -void fft8Core(T2 *u) { +void OVERLOAD fft8Core(T2 *u) { X2(u[0], u[4]); X2(u[1], u[5]); u[5] = mul_t8_delayed(u[5]); - X2(u[2], u[6]); u[6] = mul_t4(u[6]); + X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); X2(u[3], u[7]); u[7] = mul_3t8_delayed(u[7]); fft4Core(u); fft4CoreSpecial(u + 4); } // 4 MUL + 52 ADD -void fft8(T2 *u) { +void OVERLOAD fft8(T2 *u) { fft8Core(u); // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo SWAP(u[1], u[4]); SWAP(u[3], u[6]); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +F2 mul_t8_delayed(F2 a) { return U2(a.x - a.y, a.x + a.y); } +F2 mul_3t8_delayed(F2 a) { return U2(-(a.x + a.y), a.x - a.y); } +//#define X2_apply_delay(a, b) { F2 t = a; a = t + M_SQRT1_2 * b; b = t - M_SQRT1_2 * b; } +#define X2_apply_delay(a, b) { F2 t = a; a.x = fma(b.x, (float) M_SQRT1_2, a.x); a.y = fma(b.y, (float) M_SQRT1_2, a.y); b.x = fma((float) -M_SQRT1_2, b.x, t.x); b.y = fma((float) -M_SQRT1_2, b.y, t.y); } + +void OVERLOAD fft4CoreSpecial(F2 *u) { + X2(u[0], u[2]); + X2_mul_t4(u[1], u[3]); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); + X2_apply_delay(u[0], u[1]); + X2_apply_delay(u[2], u[3]); +} + +void OVERLOAD fft8Core(F2 *u) { + X2(u[0], u[4]); + X2(u[1], u[5]); u[5] = mul_t8_delayed(u[5]); + X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); + X2(u[3], u[7]); u[7] = mul_3t8_delayed(u[7]); + fft4Core(u); + fft4CoreSpecial(u + 4); +} + +// 4 MUL + 52 ADD +void OVERLOAD fft8(F2 *u) { + fft8Core(u); + // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo + SWAP(u[1], u[4]); + SWAP(u[3], u[6]); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based 
on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft8Core(GF31 *u) { + X2(u[0], u[4]); + X2_mul_t8(u[1], u[5]); + X2_mul_t4(u[2], u[6]); + X2_mul_3t8(u[3], u[7]); + fft4Core(u); + fft4Core(u + 4); +} + +// 4 MUL + 52 ADD +void OVERLOAD fft8(GF31 *u) { + fft8Core(u); + // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo + SWAP(u[1], u[4]); + SWAP(u[3], u[6]); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +#if 0 // Working code. + +void OVERLOAD fft8Core(GF61 *u) { + X2(u[0], u[4]); //GWBUG: Delay some mods using extra 3 bits of Z61 + X2_mul_t8(u[1], u[5]); // X2(u[1], u[5]); u[5] = mul_t8(u[5]); + X2_mul_t4(u[2], u[6]); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); + X2_mul_3t8(u[3], u[7]); // X2(u[3], u[7]); u[7] = mul_3t8(u[7]); + fft4Core(u); + fft4Core(u + 4); +} + +void OVERLOAD fft8(GF61 *u) { + fft8Core(u); + // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo + SWAP(u[1], u[4]); + SWAP(u[3], u[6]); +} + +#else // Carefully track the size of numbers to reduce the number of mod M61 reductions + +void OVERLOAD fft4CoreSpecial1(GF61 *u) { // Starts with u[0,1,2,3] having maximum values of (2,2,3,2)*M61+epsilon. + X2q(&u[0], &u[2], 4); // X2(u[0], u[2]); No reductions mod M61. u[0,2] max value is 5,6*M61+epsilon. + X2q_mul_t4(&u[1], &u[3], 3); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); u[1,3] max value is 5,4*M61+epsilon. + u[1] = modM61(u[1]); u[2] = modM61(u[2]); // Reduce the worst offenders. u[0,1,2,3] have maximum values of (5,1,1,4)*M61+epsilon. + X2s(&u[0], &u[1], 2); // u[0,1] max value before reduction is 6,7*M61+epsilon + X2s(&u[2], &u[3], 5); // u[2,3] max value before reduction is 5,6*M61+epsilon +} + +void OVERLOAD fft4CoreSpecial2(GF61 *u) { // Similar to above. Starts with u[0,1,2,3] having maximum values of (3,1,2,1)*M61+epsilon. + X2q(&u[0], &u[2], 3); // u[0,2] max value is 5,6*M61+epsilon. + X2q_mul_t4(&u[1], &u[3], 2); // X2(u[1], u[3]); u[3] = mul_t4(u[3]); u[1,3] max value is 3,2*M61+epsilon. + u[0] = modM61(u[0]); u[2] = modM61(u[2]); // Reduce the worst offenders u[0,1,2,3] have maximum values of (1,3,1,2)*M61+epsilon. + X2s(&u[0], &u[1], 4); // u[0,1] max value before reduction is 4,5*M61+epsilon + X2s(&u[2], &u[3], 3); // u[2,3] max value before reduction is 3,4*M61+epsilon +} + +void OVERLOAD fft8Core(GF61 *u) { // Starts with all u[i] having maximum values of M61+epsilon. + X2q(&u[0], &u[4], 2); // X2(u[0], u[4]); No reductions mod M61. u[0,4] max value is 2,3*M61+epsilon. + X2q_mul_t8(&u[1], &u[5], 2); // X2(u[1], u[5]); u[5] = mul_t8(u[5]); u[1,5] max value is 2,1*M61+epsilon. + X2q_mul_t4(&u[2], &u[6], 2); // X2(u[2], u[6]); u[6] = mul_t4(u[6]); u[2,6] max value is 3,2*M61+epsilon. + X2q_mul_3t8(&u[3], &u[7], 2); // X2(u[3], u[7]); u[7] = mul_3t8(u[7]); u[3,7] max value is 2,1*M61+epsilon. + fft4CoreSpecial1(u); + fft4CoreSpecial2(u + 4); +} + +void OVERLOAD fft8(GF61 *u) { + fft8Core(u); + // revbin [0, 4, 2, 6, 1, 5, 3, 7] undo + SWAP(u[1], u[4]); + SWAP(u[3], u[6]); +} + +#endif + +#endif diff --git a/src/cl/fft9.cl b/src/cl/fft9.cl index a077d551..c980ec11 100644 --- a/src/cl/fft9.cl +++ b/src/cl/fft9.cl @@ -1,5 +1,7 @@ // Copyright (C) Mihai Preda and George Woltman +#if FFT_FP64 + // Adapted from: Nussbaumer, "Fast Fourier Transform and Convolution Algorithms", 5.5.7 "9-Point DFT". 
// 12 FMA + 8 MUL, 72 ADD void fft9(T2 *u) { @@ -61,3 +63,5 @@ void fft9(T2 *u) { X2(u[2], u[7]); X2(u[1], u[8]); } + +#endif diff --git a/src/cl/fftbase.cl b/src/cl/fftbase.cl index 4e5f9584..c955e803 100644 --- a/src/cl/fftbase.cl +++ b/src/cl/fftbase.cl @@ -5,7 +5,9 @@ #include "trig.cl" // #include "math.cl" -void chainMul4(T2 *u, T2 w) { +#if FFT_FP64 + +void OVERLOAD chainMul4(T2 *u, T2 w) { u[1] = cmul(u[1], w); T2 base = csqTrig(w); @@ -19,7 +21,7 @@ void chainMul4(T2 *u, T2 w) { #if 1 // This version of chainMul8 tries to minimize roundoff error even if more F64 ops are used. // Trial and error looking at Z values on a WIDTH=512 FFT was used to determine when to switch from fancy to non-fancy powers of w. -void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { +void OVERLOAD chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { u[1] = cmulFancy(u[1], w); T2 w2; @@ -57,7 +59,7 @@ void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { // This version is slower on R7Pro due to a rocm optimizer issue in double-wide single-kernel tailSquare using BCAST. I could not find a work-around. // Other GPUs??? This version might be useful. If we decide to make this available, it will need a new width and height fft spec number. // Consequently, an increase in the BPW table and increase work for -ztune and -tune. -void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { +void OVERLOAD chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { u[1] = cmulFancy(u[1], w); T2 w2 = csqTrigFancy(w); @@ -77,7 +79,7 @@ void chainMul8(T2 *u, T2 w, u32 tailSquareBcast) { } #endif -void chainMul(u32 len, T2 *u, T2 w, u32 tailSquareBcast) { +void OVERLOAD chainMul(u32 len, T2 *u, T2 w, u32 tailSquareBcast) { // Do a length 4 chain mul, w must not be in Fancy format if (len == 4) chainMul4(u, w); // Do a length 8 chain mul, w must be in Fancy format @@ -85,7 +87,7 @@ void chainMul(u32 len, T2 *u, T2 w, u32 tailSquareBcast) { } -#if AMDGPU +#if AMDGPU && (FFT_VARIANT_W == 0 || FFT_VARIANT_H == 0) int bcast4(int x) { return __builtin_amdgcn_mov_dpp(x, 0, 0xf, 0xf, false); } int bcast8(int x) { return __builtin_amdgcn_ds_swizzle(x, 0x0018); } @@ -104,7 +106,7 @@ T2 bcast(T2 src, u32 span) { #endif -void shuflBigLDS(u32 WG, local T2 *lds, T2 *u, u32 n, u32 f) { +void OVERLOAD shuflBigLDS(u32 WG, local T2 *lds, T2 *u, u32 n, u32 f) { u32 me = get_local_id(0); u32 mask = f - 1; assert((mask & (mask + 1)) == 0); @@ -114,7 +116,7 @@ void shuflBigLDS(u32 WG, local T2 *lds, T2 *u, u32 n, u32 f) { for (u32 i = 0; i < n; ++i) { u[i] = lds[i * WG + me]; } } -void shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { +void OVERLOAD shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { u32 me = get_local_id(0); local T* lds = (local T*) lds2; @@ -133,7 +135,7 @@ void shufl(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { // Lower LDS requirements should let the optimizer use fewer VGPRs and increase occupancy for WIDTHs >= 1024. // Alas, the increased occupancy does not offset extra code needed for shufl_int (the assembly // code generated is not pretty). This might not be true for nVidia or future ROCm optimizers. -void shufl_int(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { +void OVERLOAD shufl_int(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { u32 me = get_local_id(0); local int* lds = (local int*) lds2; @@ -161,7 +163,7 @@ void shufl_int(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { // NOTE: It is very important for this routine to use lds memory in coordination with reverseLine2 and unreverseLine2. 
// Failure to do so would result in the need for more bar() calls. Specifically, the u values are stored in the upper half // of lds memory (first SMALL_HEIGHT T2 values). The v values are stored in the lower half of lds memory (next SMALL_HEIGHT T2 values). -void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { +void OVERLOAD shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { u32 me = get_local_id(0); // Partition lds memory into upper and lower halves @@ -173,7 +175,7 @@ void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { me = me % WG; u32 mask = f - 1; assert((mask & (mask + 1)) == 0); - + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } bar(WG); for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } @@ -183,7 +185,7 @@ void shufl2(u32 WG, local T2 *lds2, T2 *u, u32 n, u32 f) { for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } } -void tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me) { +void OVERLOAD tabMul(u32 WG, Trig trig, T2 *u, u32 n, u32 f, u32 me) { #if 0 u32 p = me / f * f; #else @@ -470,3 +472,373 @@ void finish_tabMul8_fft8(u32 WG, local T2 *lds, Trig trig, T *preloads, T2 *u, u SWAP(u[1], u[4]); SWAP(u[3], u[6]); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD chainMul4(F2 *u, F2 w) { + u[1] = cmul(u[1], w); + + F2 base = csqTrig(w); + u[2] = cmul(u[2], base); + + F a = mul2(base.y); + base = U2(fma(a, -w.y, w.x), fma(a, w.x, -w.y)); + u[3] = cmul(u[3], base); +} + +void OVERLOAD chainMul8(F2 *u, F2 w, u32 tailSquareBcast) { + u[1] = cmulFancy(u[1], w); + //GWBUG - see FP64 version for many possible optimizations + F2 w2 = csqTrigFancy(w); + u[2] = cmulFancy(u[2], w2); + + F2 w3 = ccubeTrigFancy(w2, w); + u[3] = cmulFancy(u[3], w3); + + w3.x += 1; + F2 base = cmulFancy (w3, w); + for (int i = 4; i < 8; ++i) { + u[i] = cmul(u[i], base); + base = cmulFancy(base, w); + } +} + +void OVERLOAD chainMul(u32 len, F2 *u, F2 w, u32 tailSquareBcast) { + // Do a length 4 chain mul + if (len == 4) chainMul4(u, w); + // Do a length 8 chain mul + if (len == 8) chainMul8(u, w, tailSquareBcast); +} + +void OVERLOAD shuflBigLDS(u32 WG, local F2 *lds, F2 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i]; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD shufl(u32 WG, local F2 *lds2, F2 *u, u32 n, u32 f) { //GWBUG - is shufl of int2 faster (BigLDS)? + u32 me = get_local_id(0); + local F* lds = (local F*) lds2; + + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +// Shufl two simultaneous FFT_HEIGHTs. Needed for tailSquared where u and v are computed simultaneously in different threads. +// NOTE: It is very important for this routine to use lds memory in coordination with reverseLine2 and unreverseLine2. +// Failure to do so would result in the need for more bar() calls. 
Specifically, the u values are stored in the upper half +// of lds memory (first SMALL_HEIGHT GF31 values). The v values are stored in the lower half of lds memory (next SMALL_HEIGHT GF31 values). +void OVERLOAD shufl2(u32 WG, local F2 *lds2, F2 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + + // Partition lds memory into upper and lower halves + assert(WG == G_H); + + // Accessing lds memory as F is faster than F2 accesses //GWBUG??? + local F* lds = ((local F*) lds2) + (me / WG) * SMALL_HEIGHT; + + me = me % WG; + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(WG); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +void OVERLOAD tabMul(u32 WG, TrigFP32 trig, F2 *u, u32 n, u32 f, u32 me) { + u32 p = me & ~(f - 1); + +// This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. + + if (TABMUL_CHAIN32) { + chainMul (n, u, trig[p], 0); + return; + } + +// Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. + + if (!TABMUL_CHAIN32) { + if (n >= 8) { + u[1] = cmulFancy(u[1], trig[p]); + } else { + u[1] = cmul(u[1], trig[p]); + } + for (u32 i = 2; i < n; ++i) { + u[i] = cmul(u[i], trig[(i-1)*WG + p]); + } + return; + } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD chainMul4(GF31 *u, GF31 w) { + u[1] = cmul(u[1], w); + + GF31 base = csqTrig(w); + u[2] = cmul(u[2], base); + + base = ccubeTrig(base, w); + u[3] = cmul(u[3], base); +} + +void OVERLOAD chainMul8(GF31 *u, GF31 w) { + u[1] = cmul(u[1], w); + + GF31 base = csqTrig(w); + u[2] = cmul(u[2], base); + + base = ccubeTrig(base, w); + for (int i = 3; i < 8; ++i) { + u[i] = cmul(u[i], base); + base = cmul(base, w); + } +} + +void OVERLOAD chainMul(u32 len, GF31 *u, GF31 w) { + // Do a length 4 chain mul + if (len == 4) chainMul4(u, w); + // Do a length 8 chain mul + if (len == 8) chainMul8(u, w); +} + +void OVERLOAD shuflBigLDS(u32 WG, local GF31 *lds, GF31 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i]; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD shufl(u32 WG, local GF31 *lds2, GF31 *u, u32 n, u32 f) { //GWBUG - is shufl of int2 faster (BigLDS)? + u32 me = get_local_id(0); + local Z31* lds = (local Z31*) lds2; + + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +// Shufl two simultaneous FFT_HEIGHTs. Needed for tailSquared where u and v are computed simultaneously in different threads. 
+// NOTE: It is very important for this routine to use lds memory in coordination with reverseLine2 and unreverseLine2. +// Failure to do so would result in the need for more bar() calls. Specifically, the u values are stored in the upper half +// of lds memory (first SMALL_HEIGHT GF31 values). The v values are stored in the lower half of lds memory (next SMALL_HEIGHT GF31 values). +void OVERLOAD shufl2(u32 WG, local GF31 *lds2, GF31 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + + // Partition lds memory into upper and lower halves + assert(WG == G_H); + + // Accessing lds memory as Z31s is faster than GF31 accesses //GWBUG??? + local Z31* lds = ((local Z31*) lds2) + (me / WG) * SMALL_HEIGHT; + + me = me % WG; + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(WG); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +void OVERLOAD tabMul(u32 WG, TrigGF31 trig, GF31 *u, u32 n, u32 f, u32 me) { + u32 p = me & ~(f - 1); + +// This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. + + if (TABMUL_CHAIN31) { + chainMul (n, u, trig[p]); + return; + } + +// Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. + + if (!TABMUL_CHAIN31) { + for (u32 i = 1; i < n; ++i) { + u[i] = cmul(u[i], trig[(i-1)*WG + p]); + } + return; + } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD chainMul4(GF61 *u, GF61 w) { + u[1] = cmul(u[1], w); + + GF61 base = csq(w); + u[2] = cmul(u[2], base); + + base = cmul(base, w); //GWBUG - see FP64 version for possible optimization + u[3] = cmul(u[3], base); +} + +void OVERLOAD chainMul8(GF61 *u, GF61 w, u32 tailSquareBcast) { + u[1] = cmul(u[1], w); + + GF61 w2 = csq(w); + u[2] = cmul(u[2], w2); + + GF61 base = cmul (w2, w); //GWBUG - see FP64 version for many possible optimizations + for (int i = 3; i < 8; ++i) { + u[i] = cmul(u[i], base); + base = cmul(base, w); + } +} + +void OVERLOAD chainMul(u32 len, GF61 *u, GF61 w, u32 tailSquareBcast) { + // Do a length 4 chain mul + if (len == 4) chainMul4(u, w); + // Do a length 8 chain mul + if (len == 8) chainMul8(u, w, tailSquareBcast); +} + +void OVERLOAD shuflBigLDS(u32 WG, local GF61 *lds, GF61 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i]; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD shufl(u32 WG, local GF61 *lds2, GF61 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + local Z61* lds = (local Z61*) lds2; + + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +// Same as shufl 
but use ints instead of Z61s to reduce LDS memory requirements. +// Lower LDS requirements should let the optimizer use fewer VGPRs and increase occupancy for WIDTHs >= 1024. +// Alas, the increased occupancy does not offset extra code needed for shufl_int (the assembly +// code generated is not pretty). This might not be true for nVidia or future ROCm optimizers. +void OVERLOAD shufl_int(u32 WG, local GF61 *lds2, GF61 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + local int* lds = (local int*) lds2; + + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = as_int4(u[i]).x; } + bar(); + for (u32 i = 0; i < n; ++i) { int4 tmp = as_int4(u[i]); tmp.x = lds[i * WG + me]; u[i] = as_ulong2(tmp); } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = as_int4(u[i]).y; } + bar(); + for (u32 i = 0; i < n; ++i) { int4 tmp = as_int4(u[i]); tmp.y = lds[i * WG + me]; u[i] = as_ulong2(tmp); } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = as_int4(u[i]).z; } + bar(); + for (u32 i = 0; i < n; ++i) { int4 tmp = as_int4(u[i]); tmp.z = lds[i * WG + me]; u[i] = as_ulong2(tmp); } + bar(); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = as_int4(u[i]).w; } + bar(); + for (u32 i = 0; i < n; ++i) { int4 tmp = as_int4(u[i]); tmp.w = lds[i * WG + me]; u[i] = as_ulong2(tmp); } + bar(); // I'm not sure why this barrier call is needed +} + +// Shufl two simultaneous FFT_HEIGHTs. Needed for tailSquared where u and v are computed simultaneously in different threads. +// NOTE: It is very important for this routine to use lds memory in coordination with reverseLine2 and unreverseLine2. +// Failure to do so would result in the need for more bar() calls. Specifically, the u values are stored in the upper half +// of lds memory (first SMALL_HEIGHT GF61 values). The v values are stored in the lower half of lds memory (next SMALL_HEIGHT GF61 values). +void OVERLOAD shufl2(u32 WG, local GF61 *lds2, GF61 *u, u32 n, u32 f) { + u32 me = get_local_id(0); + + // Partition lds memory into upper and lower halves + assert(WG == G_H); + + // Accessing lds memory as Z61s is faster than GF61 accesses + local Z61* lds = ((local Z61*) lds2) + (me / WG) * SMALL_HEIGHT; + + me = me % WG; + u32 mask = f - 1; + assert((mask & (mask + 1)) == 0); + + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].x; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].x = lds[i * WG + me]; } + bar(WG); + for (u32 i = 0; i < n; ++i) { lds[i * f + (me & ~mask) * n + (me & mask)] = u[i].y; } + bar(WG); + for (u32 i = 0; i < n; ++i) { u[i].y = lds[i * WG + me]; } +} + +void OVERLOAD tabMul(u32 WG, TrigGF61 trig, GF61 *u, u32 n, u32 f, u32 me) { + u32 p = me & ~(f - 1); + +// This code uses chained complex multiplies which could be faster on GPUs with great mul throughput or poor memory bandwidth or caching. + + if (TABMUL_CHAIN61) { + chainMul (n, u, trig[p], 0); + return; + } + +// Use memory accesses (probably cached) to reduce complex muls. Beneficial when memory bandwidth is not the bottleneck. 
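+// Rough cost sketch for n == 8 (illustrative): this path does 7 trig loads and 7 cmuls per call, while the
+// chained path above does a single trig load but roughly twice as many cmul/csq operations. Which one wins
+// depends on mul throughput versus cache behaviour, hence the TABMUL_CHAIN61 switch.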
+ + if (!TABMUL_CHAIN61) { + for (u32 i = 1; i < n; ++i) { + u[i] = cmul(u[i], trig[(i-1)*WG + p]); + } + return; + } +} + +#endif diff --git a/src/cl/fftheight.cl b/src/cl/fftheight.cl index 2ceb9326..69fa75b8 100644 --- a/src/cl/fftheight.cl +++ b/src/cl/fftheight.cl @@ -8,9 +8,15 @@ #error SMALL_HEIGHT must be one of: 256, 512, 1024, 4096 #endif +#if !INPLACE u32 transPos(u32 k, u32 middle, u32 width) { return k / width + k % width * middle; } +#else +u32 transPos(u32 k, u32 middle, u32 width) { return k; } +#endif + +#if FFT_FP64 -void fft_NH(T2 *u) { +void OVERLOAD fft_NH(T2 *u) { #if NH == 4 fft4(u); #elif NH == 8 @@ -29,7 +35,7 @@ void fft_NH(T2 *u) { #error FFT_VARIANT_H == 0 only supported by AMD GPUs #endif -void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { +void OVERLOAD fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { if (s > 1) { bar(); } fft_NH(u); @@ -42,7 +48,7 @@ void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { fft_NH(u); } -void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { +void OVERLOAD fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { u32 WG = SMALL_HEIGHT / NH; for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { if (s > 1) { bar(WG); } @@ -58,7 +64,7 @@ void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { #else -void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { +void OVERLOAD fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { u32 me = get_local_id(0); #if !UNROLL_H @@ -74,7 +80,7 @@ void fft_HEIGHT(local T2 *lds, T2 *u, Trig trig, T2 w) { fft_NH(u); } -void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { +void OVERLOAD fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { u32 me = get_local_id(0); u32 WG = SMALL_HEIGHT / NH; @@ -93,9 +99,7 @@ void fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w) { #endif - - -void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { +void OVERLOAD new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { u32 WG = SMALL_HEIGHT / NH; u32 me = get_local_id(0); // This line mimics shufl2 -- partition lds into halves @@ -205,3 +209,169 @@ void new_fft_HEIGHT2(local T2 *lds, T2 *u, Trig trig, T2 w, int callnum) { void new_fft_HEIGHT2_1(local T2 *lds, T2 *u, Trig trig, T2 w) { new_fft_HEIGHT2(lds, u, trig, w, 1); } void new_fft_HEIGHT2_2(local T2 *lds, T2 *u, Trig trig, T2 w) { new_fft_HEIGHT2(lds, u, trig, w, 2); } +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD fft_NH(F2 *u) { +#if NH == 4 + fft4(u); +#elif NH == 8 + fft8(u); +#else +#error NH +#endif +} + +void OVERLOAD fft_HEIGHT(local F2 *lds, F2 *u, TrigFP32 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { + if (s > 1) { bar(); } + fft_NH(u); + tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me); + shufl(SMALL_HEIGHT / NH, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD fft_HEIGHT2(local F2 *lds, F2 *u, TrigFP32 trig) { + u32 me = get_local_id(0); + u32 WG = SMALL_HEIGHT / NH; + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < WG; s *= NH) { + if (s > 1) { bar(WG); } + fft_NH(u); + tabMul(WG, trig, u, NH, s, me % WG); + shufl2(WG, lds, u, NH, s); + } + fft_NH(u); +} + +void new_fft_HEIGHT2_1(local F2 *lds, F2 *u, 
TrigFP32 trig) { fft_HEIGHT2(lds, u, trig); } +void new_fft_HEIGHT2_2(local F2 *lds, F2 *u, TrigFP32 trig) { fft_HEIGHT2(lds, u, trig); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft_NH(GF31 *u) { +#if NH == 4 + fft4(u); +#elif NH == 8 + fft8(u); +#else +#error NH +#endif +} + +void OVERLOAD fft_HEIGHT(local GF31 *lds, GF31 *u, TrigGF31 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { + if (s > 1) { bar(); } + fft_NH(u); + tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me); + shufl(SMALL_HEIGHT / NH, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD fft_HEIGHT2(local GF31 *lds, GF31 *u, TrigGF31 trig) { + u32 me = get_local_id(0); + u32 WG = SMALL_HEIGHT / NH; + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < WG; s *= NH) { + if (s > 1) { bar(WG); } + fft_NH(u); + tabMul(WG, trig, u, NH, s, me % WG); + shufl2(WG, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD new_fft_HEIGHT2_1(local GF31 *lds, GF31 *u, TrigGF31 trig) { fft_HEIGHT2(lds, u, trig); } +void OVERLOAD new_fft_HEIGHT2_2(local GF31 *lds, GF31 *u, TrigGF31 trig) { fft_HEIGHT2(lds, u, trig); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD fft_NH(GF61 *u) { +#if NH == 4 + fft4(u); +#elif NH == 8 + fft8(u); +#else +#error NH +#endif +} + +void OVERLOAD fft_HEIGHT(local GF61 *lds, GF61 *u, TrigGF61 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < SMALL_HEIGHT / NH; s *= NH) { + if (s > 1) { bar(); } + fft_NH(u); + tabMul(SMALL_HEIGHT / NH, trig, u, NH, s, me); + shufl(SMALL_HEIGHT / NH, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD fft_HEIGHT2(local GF61 *lds, GF61 *u, TrigGF61 trig) { + u32 me = get_local_id(0); + u32 WG = SMALL_HEIGHT / NH; + +#if !UNROLL_H + __attribute__((opencl_unroll_hint(1))) +#endif + + for (u32 s = 1; s < WG; s *= NH) { + if (s > 1) { bar(WG); } + fft_NH(u); + tabMul(WG, trig, u, NH, s, me % WG); + shufl2(WG, lds, u, NH, s); + } + fft_NH(u); +} + +void OVERLOAD new_fft_HEIGHT2_1(local GF61 *lds, GF61 *u, TrigGF61 trig) { fft_HEIGHT2(lds, u, trig); } +void OVERLOAD new_fft_HEIGHT2_2(local GF61 *lds, GF61 *u, TrigGF61 trig) { fft_HEIGHT2(lds, u, trig); } + +#endif diff --git a/src/cl/ffthin.cl b/src/cl/ffthin.cl index a869a2a0..ae14ec53 100644 --- a/src/cl/ffthin.cl +++ b/src/cl/ffthin.cl @@ -4,7 +4,9 @@ #include "math.cl" #include "fftheight.cl" -// Do an FFT Height after a transposeW (which may not have fully transposed data, leading to non-sequential input) +#if FFT_FP64 + +// Do an FFT Height after an fftMiddleIn (which may not have fully transposed data, leading to non-sequential input) KERNEL(G_H) fftHin(P(T2) out, CP(T2) in, Trig smallTrig) { local T2 lds[SMALL_HEIGHT / 2]; @@ -25,3 +27,97 @@ KERNEL(G_H) fftHin(P(T2) out, CP(T2) in, Trig smallTrig) { write(G_H, NH, u, out, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT 
based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +// Do an FFT Height after an fftMiddleIn (which may not have fully transposed data, leading to non-sequential input) +KERNEL(G_H) fftHin(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT / 2]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NH]; + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + readTailFusedLine(inF2, u, g, me); + +#if NH == 8 + F2 w = fancyTrig_N(ND / SMALL_HEIGHT * me); +#else + F2 w = slowTrig_N(ND / SMALL_HEIGHT * me, ND / NH); +#endif + + fft_HEIGHT(lds, u, smallTrigF2); + + write(G_H, NH, u, outF2, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// Do an FFT Height after an fftMiddleIn (which may not have fully transposed data, leading to non-sequential input) +KERNEL(G_H) fftHinGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT / 2]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH]; + u32 g = get_group_id(0); + + u32 me = get_local_id(0); + + readTailFusedLine(in31, u, g, me); + + fft_HEIGHT(lds, u, smallTrig31); + + write(G_H, NH, u, out31, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// Do an FFT Height after an fftMiddleIn (which may not have fully transposed data, leading to non-sequential input) +KERNEL(G_H) fftHinGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT / 2]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH]; + u32 g = get_group_id(0); + + u32 me = get_local_id(0); + + readTailFusedLine(in61, u, g, me); + + fft_HEIGHT(lds, u, smallTrig61); + + write(G_H, NH, u, out61, SMALL_HEIGHT * transPos(g, MIDDLE, WIDTH)); +} + +#endif diff --git a/src/cl/fftmiddlein.cl b/src/cl/fftmiddlein.cl index 830fa2b3..d1788799 100644 --- a/src/cl/fftmiddlein.cl +++ b/src/cl/fftmiddlein.cl @@ -5,13 +5,17 @@ #include "fft-middle.cl" #include "middle.cl" +#if !INPLACE // Original implementation (not in place) + +#if FFT_FP64 + KERNEL(IN_WG) fftMiddleIn(P(T2) out, CP(T2) in, Trig trig) { T2 u[MIDDLE]; - + u32 SIZEY = IN_WG / IN_SIZEX; u32 N = WIDTH / IN_SIZEX; - + u32 g = get_group_id(0); u32 gx = g % N; u32 gy = g / N; @@ -46,3 +50,366 @@ KERNEL(IN_WG) fftMiddleIn(P(T2) out, CP(T2) in, Trig trig) { writeMiddleInLine(out, u, gy, gx); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +KERNEL(IN_WG) fftMiddleIn(P(T2) out, CP(T2) in, Trig trig) { + F2 u[MIDDLE]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) trig; + + u32 SIZEY = IN_WG / IN_SIZEX; + + u32 N = 
WIDTH / IN_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % IN_SIZEX; + u32 my = me / IN_SIZEX; + + u32 startx = gx * IN_SIZEX; + u32 starty = gy * SIZEY; + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleInLine(u, inF2, y, x); + + middleMul2(u, x, y, 1, trigF2); + + fft_MIDDLE(u); + + middleMul(u, y, trigF2); + +#if MIDDLE_IN_LDS_TRANSPOSE + // Transpose the x and y values + local F lds[IN_WG / 2 * (MIDDLE <= 16 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, IN_WG, IN_SIZEX); + outF2 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + outF2 += mx * SIZEY + my; +#endif + + writeMiddleInLine(outF2, u, gy, gx); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(IN_WG) fftMiddleInGF31(P(T2) out, CP(T2) in, Trig trig) { + GF31 u[MIDDLE]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 trig31 = (TrigGF31) (trig + DISTMTRIGGF31); + + u32 SIZEY = IN_WG / IN_SIZEX; + + u32 N = WIDTH / IN_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % IN_SIZEX; + u32 my = me / IN_SIZEX; + + u32 startx = gx * IN_SIZEX; + u32 starty = gy * SIZEY; + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleInLine(u, in31, y, x); + + middleMul2(u, x, y, trig31); + + fft_MIDDLE(u); + + middleMul(u, y, trig31); + +#if MIDDLE_IN_LDS_TRANSPOSE + // Transpose the x and y values + local Z31 lds[IN_WG / 2 * (MIDDLE <= 16 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, IN_WG, IN_SIZEX); + out31 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + out31 += mx * SIZEY + my; +#endif + + writeMiddleInLine(out31, u, gy, gx); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(IN_WG) fftMiddleInGF61(P(T2) out, CP(T2) in, Trig trig) { + GF61 u[MIDDLE]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 trig61 = (TrigGF61) (trig + DISTMTRIGGF61); + + u32 SIZEY = IN_WG / IN_SIZEX; + + u32 N = WIDTH / IN_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % IN_SIZEX; + u32 my = me / IN_SIZEX; + + u32 startx = gx * IN_SIZEX; + u32 starty = gy * SIZEY; + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleInLine(u, in61, y, x); + + middleMul2(u, x, y, trig61); + + fft_MIDDLE(u); + + middleMul(u, y, trig61); + +#if MIDDLE_IN_LDS_TRANSPOSE + // Transpose the x and y values + local Z61 lds[IN_WG / 2 * (MIDDLE <= 8 ? 
2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, IN_WG, IN_SIZEX); + out61 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + out61 += mx * SIZEY + my; +#endif + + writeMiddleInLine(out61, u, gy, gx); +} + +#endif + + + + + + +#else // in place transpose + +#if FFT_FP64 + +KERNEL(256) fftMiddleIn(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + T2 u[MIDDLE]; + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; + u32 zerohack = g / 131072; // A super tiny benefit (much smaller than margin of error) on TitanV +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; + u32 zerohack = (MIDDLE >= 16) ? 0 : g / 131072; // Rocm optimizer goes bonkers if zerohack used when MIDDLE=16 +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleInLine(u, in, y, x); + + middleMul2(u, x, y, 1, trig); + + fft_MIDDLE(u); + + middleMul(u, y, trig); + + // Transpose the x and y values + local T2 lds[256]; + middleShuffle(lds, u); + + writeMiddleInLine(in + zerohack, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +KERNEL(256) fftMiddleIn(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + F2 u[MIDDLE]; + + P(F2) inF2 = (P(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) trig; + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; + u32 zerohack = 0; // Need to test if g / 131072 is of any benefit +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; + u32 zerohack = (MIDDLE >= 16) ? 0 : g / 131072; // Rocm optimizer goes bonkers if zerohack used when MIDDLE=16 (for FP64, FP32 untimed) +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleInLine(u, inF2, y, x); + + middleMul2(u, x, y, 1, trigF2); + + fft_MIDDLE(u); + + middleMul(u, y, trigF2); + + // Transpose the x and y values + local F2 lds[256]; + middleShuffle(lds, u); + + writeMiddleInLine(inF2 + zerohack, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(256) fftMiddleInGF31(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + GF31 u[MIDDLE]; + + P(GF31) in31 = (P(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 trig31 = (TrigGF31) (trig + DISTMTRIGGF31); + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; + u32 zerohack = 0; // Need to test if g / 131072 is of any benefit +#else // AMD friendly padding, vary x fast? 
I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; + u32 zerohack = (MIDDLE >= 16) ? 0 : g / 131072; // Rocm optimizer goes bonkers if zerohack used when MIDDLE=16 (for FP64, GF31 untimed) +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleInLine(u, in31, y, x); + + middleMul2(u, x, y, trig31); + + fft_MIDDLE(u); + + middleMul(u, y, trig31); + + // Transpose the x and y values + local GF31 lds[256]; + middleShuffle(lds, u); + + writeMiddleInLine(in31 + zerohack, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(256) fftMiddleInGF61(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + GF61 u[MIDDLE]; + + P(GF61) in61 = (P(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 trig61 = (TrigGF61) (trig + DISTMTRIGGF61); + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; + u32 zerohack = 0; // Need to test if g / 131072 is of any benefit +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; + u32 zerohack = (MIDDLE >= 16) ? 0 : g / 131072; // Rocm optimizer goes bonkers if zerohack used when MIDDLE=16 (for FP64, GF61 untimed) +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleInLine(u, in61, y, x); + + middleMul2(u, x, y, trig61); + + fft_MIDDLE(u); + + middleMul(u, y, trig61); + + // Transpose the x and y values + local GF61 lds[256]; + middleShuffle(lds, u); + + writeMiddleInLine(in61 + zerohack, u, y, x); +} + +#endif + +#endif diff --git a/src/cl/fftmiddleout.cl b/src/cl/fftmiddleout.cl index 6404c7ae..a7263dcd 100644 --- a/src/cl/fftmiddleout.cl +++ b/src/cl/fftmiddleout.cl @@ -5,7 +5,11 @@ #include "fft-middle.cl" #include "middle.cl" -KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { +#if !INPLACE // Original implementation (not in place) + +#if FFT_FP64 + +KERNEL(OUT_WG) fftMiddleOut(P(T2) out, CP(T2) in, Trig trig) { T2 u[MIDDLE]; u32 SIZEY = OUT_WG / OUT_SIZEX; @@ -56,3 +60,385 @@ KERNEL(OUT_WG) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { writeMiddleOutLine(out, u, gy, gx); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if FFT_FP32 + +KERNEL(OUT_WG) fftMiddleOut(P(T2) out, CP(T2) in, Trig trig) { + F2 u[MIDDLE]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) trig; + + u32 SIZEY = OUT_WG / OUT_SIZEX; + + u32 N = SMALL_HEIGHT / OUT_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % OUT_SIZEX; + u32 my = me / OUT_SIZEX; + + // Kernels read OUT_SIZEX consecutive T2. 
+ // Each WG-thread kernel processes OUT_SIZEX columns from a needed SMALL_HEIGHT columns + // Each WG-thread kernel processes SIZEY rows out of a needed WIDTH rows + + u32 startx = gx * OUT_SIZEX; // Each input column increases FFT element by one + u32 starty = gy * SIZEY; // Each input row increases FFT element by BIG_HEIGHT + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleOutLine(u, inF2, y, x); + + middleMul(u, x, trigF2); + + fft_MIDDLE(u); + + // FFT results come out multiplied by the FFT length (NWORDS). Also, for performance reasons + // weights and invweights are doubled meaning we need to divide by another 2^2 and 2^2. + // Finally, roundoff errors are sometimes improved if we use the next lower double precision + // number. This may be due to roundoff errors introduced by applying inexact TWO_TO_N_8TH weights. + double factor = 1.0 / (4 * 4 * NWORDS); + + middleMul2(u, y, x, factor, trigF2); + +#if MIDDLE_OUT_LDS_TRANSPOSE + // Transpose the x and y values + local F lds[OUT_WG / 2 * (MIDDLE <= 16 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, OUT_WG, OUT_SIZEX); + outF2 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + outF2 += mx * SIZEY + my; +#endif + + writeMiddleOutLine(outF2, u, gy, gx); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(OUT_WG) fftMiddleOutGF31(P(T2) out, CP(T2) in, Trig trig) { + GF31 u[MIDDLE]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 trig31 = (TrigGF31) (trig + DISTMTRIGGF31); + + u32 SIZEY = OUT_WG / OUT_SIZEX; + + u32 N = SMALL_HEIGHT / OUT_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % OUT_SIZEX; + u32 my = me / OUT_SIZEX; + + // Kernels read OUT_SIZEX consecutive T2. + // Each WG-thread kernel processes OUT_SIZEX columns from a needed SMALL_HEIGHT columns + // Each WG-thread kernel processes SIZEY rows out of a needed WIDTH rows + + u32 startx = gx * OUT_SIZEX; // Each input column increases FFT element by one + u32 starty = gy * SIZEY; // Each input row increases FFT element by BIG_HEIGHT + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleOutLine(u, in31, y, x); + + middleMul(u, x, trig31); + + fft_MIDDLE(u); + + middleMul2(u, y, x, trig31); + +#if MIDDLE_OUT_LDS_TRANSPOSE + // Transpose the x and y values + local Z31 lds[OUT_WG / 2 * (MIDDLE <= 16 ? 
2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, OUT_WG, OUT_SIZEX); + out31 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + out31 += mx * SIZEY + my; +#endif + + writeMiddleOutLine(out31, u, gy, gx); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(OUT_WG) fftMiddleOutGF61(P(T2) out, CP(T2) in, Trig trig) { + GF61 u[MIDDLE]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 trig61 = (TrigGF61) (trig + DISTMTRIGGF61); + + u32 SIZEY = OUT_WG / OUT_SIZEX; + + u32 N = SMALL_HEIGHT / OUT_SIZEX; + + u32 g = get_group_id(0); + u32 gx = g % N; + u32 gy = g / N; + + u32 me = get_local_id(0); + u32 mx = me % OUT_SIZEX; + u32 my = me / OUT_SIZEX; + + // Kernels read OUT_SIZEX consecutive T2. + // Each WG-thread kernel processes OUT_SIZEX columns from a needed SMALL_HEIGHT columns + // Each WG-thread kernel processes SIZEY rows out of a needed WIDTH rows + + u32 startx = gx * OUT_SIZEX; // Each input column increases FFT element by one + u32 starty = gy * SIZEY; // Each input row increases FFT element by BIG_HEIGHT + + u32 x = startx + mx; + u32 y = starty + my; + + readMiddleOutLine(u, in61, y, x); + + middleMul(u, x, trig61); + + fft_MIDDLE(u); + + middleMul2(u, y, x, trig61); + +#if MIDDLE_OUT_LDS_TRANSPOSE + // Transpose the x and y values + local Z61 lds[OUT_WG / 2 * (MIDDLE <= 8 ? 2 * MIDDLE : MIDDLE)]; + middleShuffle(lds, u, OUT_WG, OUT_SIZEX); + out61 += me; // Threads write sequentially to memory since x and y values are already transposed +#else + // Adjust out pointer to effect a transpose of x and y values + out61 += mx * SIZEY + my; +#endif + + writeMiddleOutLine(out61, u, gy, gx); +} + +#endif + + + +#else // in place transpose + +#if FFT_FP64 + +KERNEL(256) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + T2 u[MIDDLE]; + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleOutLine(u, in, y, x); + + middleMul(u, x, trig); + + fft_MIDDLE(u); + + // FFT results come out multiplied by the FFT length (NWORDS). Also, for performance reasons + // weights and invweights are doubled meaning we need to divide by another 2^2 and 2^2. + // Finally, roundoff errors are sometimes improved if we use the next lower double precision + // number. This may be due to roundoff errors introduced by applying inexact TWO_TO_N_8TH weights. 
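+  // For example (illustrative numbers only): with NWORDS = 8*1024*1024 the combined scale below is
+  // 1 / (4 * 4 * 8388608) = 1 / 134217728, roughly 7.45e-9.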
+ double factor = 1.0 / (4 * 4 * NWORDS); + + middleMul2(u, y, x, factor, trig); + + // Transpose the x and y values + local T2 lds[256]; + middleShuffle(lds, u); + + writeMiddleOutLine(out, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if FFT_FP32 + +KERNEL(256) fftMiddleOut(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + F2 u[MIDDLE]; + + P(F2) inF2 = (P(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 trigF2 = (TrigFP32) trig; + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleOutLine(u, inF2, y, x); + + middleMul(u, x, trigF2); + + fft_MIDDLE(u); + + // FFT results come out multiplied by the FFT length (NWORDS). Also, for performance reasons + // weights and invweights are doubled meaning we need to divide by another 2^2 and 2^2. + // Finally, roundoff errors are sometimes improved if we use the next lower double precision + // number. This may be due to roundoff errors introduced by applying inexact TWO_TO_N_8TH weights. + double factor = 1.0 / (4 * 4 * NWORDS); + + middleMul2(u, y, x, factor, trigF2); + + // Transpose the x and y values + local F2 lds[256]; + middleShuffle(lds, u); + + writeMiddleOutLine(outF2, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(256) fftMiddleOutGF31(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + GF31 u[MIDDLE]; + + P(GF31) in31 = (P(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 trig31 = (TrigGF31) (trig + DISTMTRIGGF31); + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; +#else // AMD friendly padding, vary x fast? 
I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleOutLine(u, in31, y, x); + + middleMul(u, x, trig31); + + fft_MIDDLE(u); + + middleMul2(u, y, x, trig31); + + // Transpose the x and y values + local GF31 lds[256]; + middleShuffle(lds, u); + + writeMiddleOutLine(out31, u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(256) fftMiddleOutGF61(P(T2) out, P(T2) in, Trig trig) { + assert(out == in); + GF61 u[MIDDLE]; + + P(GF61) in61 = (P(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 trig61 = (TrigGF61) (trig + DISTMTRIGGF61); + + u32 g = get_group_id(0); +#if INPLACE == 1 // nVidia friendly padding + u32 N = SMALL_HEIGHT / 16; + u32 startx = g % N * 16; + u32 starty = g / N * 16; +#else // AMD friendly padding, vary x fast? I've no explanation for why that would be better + u32 N = WIDTH / 16; + u32 starty = g % N * 16; + u32 startx = g / N * 16; +#endif + + u32 me = get_local_id(0); + u32 x = startx + me % 16; + u32 y = starty + me / 16; + + readMiddleOutLine(u, in61, y, x); + + middleMul(u, x, trig61); + + fft_MIDDLE(u); + + middleMul2(u, y, x, trig61); + + // Transpose the x and y values + local GF61 lds[256]; + middleShuffle(lds, u); + + writeMiddleOutLine(out61, u, y, x); +} + +#endif + +#endif diff --git a/src/cl/fftp.cl b/src/cl/fftp.cl index 70a707d8..2fe0d7d4 100644 --- a/src/cl/fftp.cl +++ b/src/cl/fftp.cl @@ -6,28 +6,562 @@ #include "fftwidth.cl" #include "middle.cl" +#if FFT_TYPE == FFT64 + // fftPremul: weight words with IBDWT weights followed by FFT-width. KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) { local T2 lds[WIDTH / 2]; - T2 u[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + in += g * WIDTH; + + T base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + for (u32 i = 0; i < NW; ++i) { + T w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + T w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 p = G_W * i + me; + u[i] = U2(in[p].x * w1, in[p].y * w2); + } + + fft_WIDTH(lds, u, smallTrig); + + writeCarryFusedLine(u, out, g); +} + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT32 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(F2) out, CP(Word2) in, TrigFP32 smallTrig, BigTabFP32 THREAD_WEIGHTS) { + local F2 lds[WIDTH / 2]; + F2 u[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + in += g * WIDTH; + + F base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + for (u32 i = 0; i < NW; ++i) { + F w1 = i == 0 ? 
base : optionalHalve(fancyMul(base, fweightStep(i))); + F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 p = G_W * i + me; + u[i] = U2(in[p].x * w1, in[p].y * w2); + } + + fft_WIDTH(lds, u, smallTrig); + + writeCarryFusedLine(u, out, g); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT31 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(GF31) out, CP(Word2) in, TrigGF31 smallTrig) { + local GF31 lds[WIDTH / 2]; + GF31 u[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + in += g * WIDTH; + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Convert and weight inputs + u[i] = U2(shl(make_Z31(in[p].x), weight_shift0), shl(make_Z31(in[p].y), weight_shift1)); // Form a GF31 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + fft_WIDTH(lds, u, smallTrig); + + writeCarryFusedLine(u, out, g); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT61 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(GF61) out, CP(Word2) in, TrigGF61 smallTrig) { + local GF61 lds[WIDTH / 2]; + GF61 u[NW]; + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + in += g * WIDTH; + + u32 word_index = (me * BIG_HEIGHT + g) * 2; - u32 step = WIDTH * g; - in += step; + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. 
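+  // Note: 2^61 = M61 + 1, so 2 is a 61st root of unity in GF(M61). Multiplying by a power of two is thus a
+  // cyclic rotation within 61 bits, which is why every shift amount below is reduced mod 61.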
+ const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 61; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the second weight shift + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + // Convert and weight input + u[i] = U2(shl(make_Z61(in[p].x), weight_shift0), shl(make_Z61(in[p].y), weight_shift1)); // Form a GF61 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + + fft_WIDTH(lds, u, smallTrig); + + writeCarryFusedLine(u, out, g); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP64 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT6431 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTab THREAD_WEIGHTS) { + local T2 lds[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) lds; + T2 u[NW]; + GF31 u31[NW]; + + u32 g = get_group_id(0); u32 me = get_local_id(0); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + + in += g * WIDTH; + T base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). 
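+  // How the counter works (same pattern as the GF31/GF61 kernels above): the low 32 bits accumulate the
+  // fractional bits-per-word and the high 32 bits accumulate the weight shift. Whenever the low half wraps
+  // (a "big" word), the carry bumps the shift by one -- that carry is the num_big_words / ceil() adjustment
+  // described above, obtained from a single 64-bit add per word.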
+ union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the FP64 weights and the second GF31 weight shift T w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); T w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); - u32 p = G_W * i + me; - u[i] = U2(in[p].x, in[p].y) * U2(w1, w2); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Convert and weight input + u[i] = U2(in[p].x * w1, in[p].y * w2); + u31[i] = U2(shl(make_Z31(in[p].x), weight_shift0), shl(make_Z31(in[p].y), weight_shift1)); // Form a GF31 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; } fft_WIDTH(lds, u, smallTrig); - writeCarryFusedLine(u, out, g); + bar(); + fft_WIDTH(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, g); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M31^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3231 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { + local F2 ldsF2[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) ldsF2; + F2 uF2[NW]; + GF31 u31[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + + in += g * WIDTH; + + F base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 31. + const u32 log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 31; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 30) % 31; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). 
+ union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 31; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the FP32 weights and the second GF31 weight shift + F w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 31) weight_shift -= 31; + u32 weight_shift1 = weight_shift; + // Convert and weight input + uF2[i] = U2(in[p].x * w1, in[p].y * w2); + u31[i] = U2(shl(make_Z31(in[p].x), weight_shift0), shl(make_Z31(in[p].y), weight_shift1)); // Form a GF31 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 31) weight_shift -= 31; + } + + fft_WIDTH(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, g); + bar(); + fft_WIDTH(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, g); +} + + +/**************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32 & GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3261 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { + local GF61 lds61[WIDTH / 2]; + local F2 *ldsF2 = (local F2 *) lds61; + F2 uF2[NW]; + GF61 u61[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + + in += g * WIDTH; + + F base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 bigword_weight_shift = (NWORDS - EXP % NWORDS) * log2_root_two % 61; + const u32 bigword_weight_shift_minus1 = (bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). 
+ union { uint2 a; u64 b; } combo; +#define frac_bits combo.a[0] +#define weight_shift combo.a[1] +#define combo_counter combo.b + + const u64 combo_step = ((u64) bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + combo_counter = word_index * combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + weight_shift = weight_shift % 61; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the FP32 weights and the second GF61 weight shift + F w1 = i == 0 ? base : optionalHalve(fancyMul(base, fweightStep(i))); + F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 weight_shift0 = weight_shift; + combo_counter += combo_step; + if (weight_shift > 61) weight_shift -= 61; + u32 weight_shift1 = weight_shift; + // Convert and weight input + uF2[i] = U2(in[p].x * w1, in[p].y * w2); + u61[i] = U2(shl(make_Z61(in[p].x), weight_shift0), shl(make_Z61(in[p].y), weight_shift1)); // Form a GF61 from each pair of input words + // Generate weight shifts and frac_bits for next pair + combo_counter += combo_bigstep; + if (weight_shift > 61) weight_shift -= 61; + } + + fft_WIDTH(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, g); + bar(); + fft_WIDTH(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, g); +} + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2)*GF(M61^2) */ +/**************************************************************************/ + +#elif FFT_TYPE == FFT3161 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig) { + local GF61 lds61[WIDTH / 2]; + local GF31 *lds31 = (local GF31 *) lds61; + GF31 u31[NW]; + GF61 u61[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + + in += g * WIDTH; + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. + // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). 
+ union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m31_combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * m31_combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m31_weight_shift = m31_weight_shift % 31; + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m61_combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * m61_combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m61_weight_shift = m61_weight_shift % 61; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the second weight shifts + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + u32 m61_weight_shift1 = m61_weight_shift; + // Convert and weight input + u31[i] = U2(shl(make_Z31(in[p].x), m31_weight_shift0), shl(make_Z31(in[p].y), m31_weight_shift1)); // Form a GF31 from each pair of input words + u61[i] = U2(shl(make_Z61(in[p].x), m61_weight_shift0), shl(make_Z61(in[p].y), m61_weight_shift1)); // Form a GF61 from each pair of input words + +// Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + m61_combo_counter += m61_combo_bigstep; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + } + + fft_WIDTH(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, g); + bar(); + fft_WIDTH(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, g); +} + + +/******************************************************************************/ +/* Similar to above, but for a hybrid FFT based on FP32*GF(M31^2)*GF(M61^2) */ +/******************************************************************************/ + +#elif FFT_TYPE == FFT323161 + +// fftPremul: weight words with IBDWT weights followed by FFT-width. +KERNEL(G_W) fftP(P(T2) out, CP(Word2) in, Trig smallTrig, BigTabFP32 THREAD_WEIGHTS) { + local GF61 lds61[WIDTH / 2]; + local F2 *ldsF2 = (local F2 *) lds61; + local GF31 *lds31 = (local GF31 *) lds61; + F2 uF2[NW]; + GF31 u31[NW]; + GF61 u61[NW]; + + u32 g = get_group_id(0); + u32 me = get_local_id(0); + + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + + in += g * WIDTH; + + F base = optionalHalve(fancyMul(THREAD_WEIGHTS[me].y, THREAD_WEIGHTS[G_W + g].y)); + + u32 word_index = (me * BIG_HEIGHT + g) * 2; + + // Weight is 2^[ceil(qj / n) - qj/n] where j is the word index, q is the Mersenne exponent, and n is the number of words. 
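To see why the log2_root_two expression gives a valid shift, note that the shift assigned to word j must be an n-th root of 2^m, where m/n is the fractional exponent in the weight formula above. A standalone C check using M31 and toy values (q = 89, n = 8, n a power of two); the flat word indexing here is a simplification of the kernel's (me, g) layout.

#include <assert.h>
#include <stdint.h>

static uint32_t mulmod_m31(uint32_t a, uint32_t b) {
  const uint32_t M31 = (1u << 31) - 1;
  uint64_t t = (uint64_t)a * b;
  uint32_t r = (uint32_t)((t & M31) + (t >> 31));
  return r >= M31 ? r - M31 : r;
}

int main(void) {
  const uint32_t q = 89, n = 8;                                        // toy Mersenne exponent and word count
  const uint32_t log2_root_two = (uint32_t)(((1ULL << 30) / n) % 31);  // same expression as the kernel uses
  for (uint32_t j = 0; j < n; ++j) {
    uint32_t m = (n - (q * j) % n) % n;       // numerator of ceil(q*j/n) - q*j/n, i.e. weight = 2^(m/n)
    uint32_t shift = m * log2_root_two % 31;  // the shift applied to word j
    uint32_t w = 1u << shift;                 // 2^shift mod M31 (shift < 31)
    uint32_t wn = 1;
    for (uint32_t k = 0; k < n; ++k) wn = mulmod_m31(wn, w);
    assert(wn == (1u << m));                  // w really is an n-th root of 2^m, i.e. 2^(m/n)
  }
  return 0;
}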
+ // Weights can be applied with shifts because 2 is the 60th root GF61. + // Let s be the shift amount for word 1. The shift amount for word x is ceil(x * (s - 1) + num_big_words_less_than_x) % 61. + const u32 m31_log2_root_two = (u32) (((1ULL << 30) / NWORDS) % 31); + const u32 m31_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m31_log2_root_two % 31; + const u32 m31_bigword_weight_shift_minus1 = (m31_bigword_weight_shift + 30) % 31; + const u32 m61_log2_root_two = (u32) (((1ULL << 60) / NWORDS) % 61); + const u32 m61_bigword_weight_shift = (NWORDS - EXP % NWORDS) * m61_log2_root_two % 61; + const u32 m61_bigword_weight_shift_minus1 = (m61_bigword_weight_shift + 60) % 61; + + // Derive the big vs. little flags from the fractional number of bits in each word. + // Create a 64-bit counter that tracks both weight shifts and frac_bits (adding 0xFFFFFFFF to effect the ceil operation required for weight shift). + union { uint2 a; u64 b; } m31_combo, m61_combo; +#define frac_bits m31_combo.a[0] +#define m31_weight_shift m31_combo.a[1] +#define m31_combo_counter m31_combo.b +#define m61_weight_shift m61_combo.a[1] +#define m61_combo_counter m61_combo.b + + const u64 m31_combo_step = ((u64) m31_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m31_combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * m31_combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (31ULL << 32); + m31_combo_counter = word_index * m31_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m31_weight_shift = m31_weight_shift % 31; + const u64 m61_combo_step = ((u64) m61_bigword_weight_shift_minus1 << 32) + FRAC_BPW_HI; + const u64 m61_combo_bigstep = ((G_W * BIG_HEIGHT * 2 - 1) * m61_combo_step + (((u64) (G_W * BIG_HEIGHT * 2 - 1) * FRAC_BPW_LO) >> 32)) % (61ULL << 32); + m61_combo_counter = word_index * m61_combo_step + mul_hi(word_index, FRAC_BPW_LO) + 0xFFFFFFFFULL; + m61_weight_shift = m61_weight_shift % 61; + + for (u32 i = 0; i < NW; ++i) { + u32 p = G_W * i + me; + // Generate the FP32 weights and the second GF31 and GF61 weight shift + F w1 = i == 0 ? 
base : optionalHalve(fancyMul(base, fweightStep(i))); + F w2 = optionalHalve(fancyMul(w1, WEIGHT_STEP)); + u32 m31_weight_shift0 = m31_weight_shift; + m31_combo_counter += m31_combo_step; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + u32 m31_weight_shift1 = m31_weight_shift; + u32 m61_weight_shift0 = m61_weight_shift; + m61_combo_counter += m61_combo_step; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + u32 m61_weight_shift1 = m61_weight_shift; + // Convert and weight input + uF2[i] = U2(in[p].x * w1, in[p].y * w2); + u31[i] = U2(shl(make_Z31(in[p].x), m31_weight_shift0), shl(make_Z31(in[p].y), m31_weight_shift1)); // Form a GF31 from each pair of input words + u61[i] = U2(shl(make_Z61(in[p].x), m61_weight_shift0), shl(make_Z61(in[p].y), m61_weight_shift1)); // Form a GF61 from each pair of input words + +// Generate weight shifts and frac_bits for next pair + m31_combo_counter += m31_combo_bigstep; + m31_weight_shift = adjust_m31_weight_shift(m31_weight_shift); + m61_combo_counter += m61_combo_bigstep; + m61_weight_shift = adjust_m61_weight_shift(m61_weight_shift); + } + + fft_WIDTH(ldsF2, uF2, smallTrigF2); + writeCarryFusedLine(uF2, outF2, g); + bar(); + fft_WIDTH(lds31, u31, smallTrig31); + writeCarryFusedLine(u31, out31, g); + bar(); + fft_WIDTH(lds61, u61, smallTrig61); + writeCarryFusedLine(u61, out61, g); } + + +#else +error - missing FFTp kernel implementation +#endif diff --git a/src/cl/fftw.cl b/src/cl/fftw.cl index afefc547..a19b26d0 100644 --- a/src/cl/fftw.cl +++ b/src/cl/fftw.cl @@ -5,10 +5,12 @@ #include "fftwidth.cl" #include "middle.cl" -// Do an fft_WIDTH after a transposeH (which may not have fully transposed data, leading to non-sequential input) +#if FFT_FP64 + +// Do the ending fft_WIDTH after an fftMiddleOut. This is the same as the first half of carryFused. KERNEL(G_W) fftW(P(T2) out, CP(T2) in, Trig smallTrig) { local T2 lds[WIDTH / 2]; - + T2 u[NW]; u32 g = get_group_id(0); @@ -17,3 +19,81 @@ KERNEL(G_W) fftW(P(T2) out, CP(T2) in, Trig smallTrig) { out += WIDTH * g; write(G_W, NW, u, out, 0); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +// Do the ending fft_WIDTH after an fftMiddleOut. This is the same as the first half of carryFused. 
+KERNEL(G_W) fftW(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[WIDTH / 2]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NW]; + u32 g = get_group_id(0); + + readCarryFusedLine(inF2, u, g); + fft_WIDTH(lds, u, smallTrigF2); + outF2 += WIDTH * g; + write(G_W, NW, u, outF2, 0); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +KERNEL(G_W) fftWGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[WIDTH / 2]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTWTRIGGF31); + + GF31 u[NW]; + u32 g = get_group_id(0); + + readCarryFusedLine(in31, u, g); + fft_WIDTH(lds, u, smallTrig31); + out31 += WIDTH * g; + write(G_W, NW, u, out31, 0); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +KERNEL(G_W) fftWGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[WIDTH / 2]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTWTRIGGF61); + + GF61 u[NW]; + u32 g = get_group_id(0); + + readCarryFusedLine(in61, u, g); + fft_WIDTH(lds, u, smallTrig61); + out61 += WIDTH * g; + write(G_W, NW, u, out61, 0); +} + +#endif diff --git a/src/cl/fftwidth.cl b/src/cl/fftwidth.cl index 92cfc47d..d0589ab0 100644 --- a/src/cl/fftwidth.cl +++ b/src/cl/fftwidth.cl @@ -6,7 +6,9 @@ #error WIDTH must be one of: 256, 512, 1024, 4096, 625 #endif -void fft_NW(T2 *u) { +#if FFT_FP64 + +void OVERLOAD fft_NW(T2 *u) { #if NW == 4 fft4(u); #elif NW == 5 @@ -27,7 +29,7 @@ void fft_NW(T2 *u) { #error FFT_VARIANT_W == 0 only supported by AMD GPUs #endif -void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { +void OVERLOAD fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { u32 me = get_local_id(0); #if NW == 8 T2 w = fancyTrig_N(ND / WIDTH * me); @@ -49,7 +51,7 @@ void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { #else -void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { +void OVERLOAD fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { u32 me = get_local_id(0); #if !UNROLL_W @@ -59,7 +61,7 @@ void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { if (s > 1) { bar(); } fft_NW(u); tabMul(WIDTH / NW, trig, u, NW, s, me); - shufl( WIDTH / NW, lds, u, NW, s); + shufl(WIDTH / NW, lds, u, NW, s); } fft_NW(u); } @@ -67,15 +69,12 @@ void fft_WIDTH(local T2 *lds, T2 *u, Trig trig) { #endif - - - // New fft_WIDTH that uses more FMA instructions than the old fft_WIDTH. // The tabMul after fft8 only does a partial complex multiply, saving a mul-by-cosine for the next fft8 using FMA instructions. -// To maximize FMA opportunities we precompute tig values as cosine and sine/cosine rather than cosine and sine. +// To maximize FMA opportunities we precompute trig values as cosine and sine/cosine rather than cosine and sine. // The downside is sine/cosine cannot be computed with chained multiplies. 
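Concretely, storing the twiddle as (cos, sin/cos) turns the complex multiply into two FMAs plus one deferred multiply by the cosine, which is what lets the next radix-8 pass absorb it. A standalone C sketch of the identity, with arbitrary sample values:

#include <assert.h>
#include <math.h>

int main(void) {
  double x = 0.371, y = -1.25, angle = 0.6;
  double c = cos(angle), s = sin(angle), t = s / c;   // trig stored as cosine and sine/cosine
  // Reference: full complex multiply (x + i*y) * (c + i*s).
  double re_ref = x * c - y * s, im_ref = y * c + x * s;
  // Partial multiply: two FMAs now; the common factor c is deferred and can be
  // folded into the FMAs of the following butterfly.
  double re_part = fma(-y, t, x), im_part = fma(x, t, y);
  assert(fabs(re_part * c - re_ref) < 1e-12 && fabs(im_part * c - im_ref) < 1e-12);
  return 0;
}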
-void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { +void OVERLOAD new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { u32 WG = WIDTH / NW; u32 me = get_local_id(0); @@ -209,6 +208,119 @@ void new_fft_WIDTH(local T2 *lds, T2 *u, Trig trig, int callnum) { #endif } -void new_fft_WIDTH1(local T2 *lds, T2 *u, Trig trig) { new_fft_WIDTH(lds, u, trig, 1); } -void new_fft_WIDTH2(local T2 *lds, T2 *u, Trig trig) { new_fft_WIDTH(lds, u, trig, 2); } +// There are two version of new_fft_WIDTH in case we want to try saving some trig values from new_fft_WIDTH1 in LDS memory for later use in new_fft_WIDTH2. +void OVERLOAD new_fft_WIDTH1(local T2 *lds, T2 *u, Trig trig) { new_fft_WIDTH(lds, u, trig, 1); } +void OVERLOAD new_fft_WIDTH2(local T2 *lds, T2 *u, Trig trig) { new_fft_WIDTH(lds, u, trig, 2); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD fft_NW(F2 *u) { +#if NW == 4 + fft4(u); +#elif NW == 8 + fft8(u); +#else +#error NW +#endif +} + +void OVERLOAD fft_WIDTH(local F2 *lds, F2 *u, TrigFP32 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_W + __attribute__((opencl_unroll_hint(1))) +#endif + for (u32 s = 1; s < WIDTH / NW; s *= NW) { + if (s > 1) { bar(); } + fft_NW(u); + tabMul(WIDTH / NW, trig, u, NW, s, me); + shufl(WIDTH / NW, lds, u, NW, s); + } + fft_NW(u); +} + +void OVERLOAD new_fft_WIDTH1(local F2 *lds, F2 *u, TrigFP32 trig) { fft_WIDTH(lds, u, trig); } +void OVERLOAD new_fft_WIDTH2(local F2 *lds, F2 *u, TrigFP32 trig) { fft_WIDTH(lds, u, trig); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD fft_NW(GF31 *u) { +#if NW == 4 + fft4(u); +#elif NW == 8 + fft8(u); +#else +#error NW +#endif +} + +void OVERLOAD fft_WIDTH(local GF31 *lds, GF31 *u, TrigGF31 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_W + __attribute__((opencl_unroll_hint(1))) +#endif + for (u32 s = 1; s < WIDTH / NW; s *= NW) { + if (s > 1) { bar(); } + fft_NW(u); + tabMul(WIDTH / NW, trig, u, NW, s, me); + shufl(WIDTH / NW, lds, u, NW, s); + } + fft_NW(u); +} + +void OVERLOAD new_fft_WIDTH1(local GF31 *lds, GF31 *u, TrigGF31 trig) { fft_WIDTH(lds, u, trig); } +void OVERLOAD new_fft_WIDTH2(local GF31 *lds, GF31 *u, TrigGF31 trig) { fft_WIDTH(lds, u, trig); } + +#endif + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD fft_NW(GF61 *u) { +#if NW == 4 + fft4(u); +#elif NW == 8 + fft8(u); +#else +#error NW +#endif +} + +void OVERLOAD fft_WIDTH(local GF61 *lds, GF61 *u, TrigGF61 trig) { + u32 me = get_local_id(0); + +#if !UNROLL_W + __attribute__((opencl_unroll_hint(1))) +#endif + for (u32 s = 1; s < WIDTH / NW; s *= NW) { + if (s > 1) { bar(); } + fft_NW(u); + tabMul(WIDTH / NW, trig, u, NW, s, me); + shufl(WIDTH / NW, lds, u, NW, s); + } + fft_NW(u); +} + +void OVERLOAD new_fft_WIDTH1(local GF61 *lds, GF61 *u, TrigGF61 trig) { fft_WIDTH(lds, u, trig); } +void OVERLOAD new_fft_WIDTH2(local GF61 *lds, GF61 *u, TrigGF61 trig) { fft_WIDTH(lds, u, trig); } + +#endif diff --git 
a/src/cl/math.cl b/src/cl/math.cl index a97ea506..479238c6 100644 --- a/src/cl/math.cl +++ b/src/cl/math.cl @@ -4,6 +4,283 @@ #include "base.cl" +// Access parts of a 64-bit value + +u32 OVERLOAD lo32(u64 x) { uint2 x2 = as_uint2(x); return (u32)x2.x; } +u32 OVERLOAD hi32(u64 x) { uint2 x2 = as_uint2(x); return (u32)x2.y; } +u32 OVERLOAD lo32(i64 x) { uint2 x2 = as_uint2(x); return (u32)x2.x; } +i32 OVERLOAD hi32(i64 x) { uint2 x2 = as_uint2(x); return (i32)x2.y; } + +// A primitive partial implementation of an i96 integer type +#if 1 +// An all 32-bit implementation. The add and subtract routines desperately need to use ASM with add.cc and sub.cc PTX instructions. +// This version might be best on AMD and Intel if we can generate add-with-carry instructions. +typedef struct { i32 hi32; u32 mid32; u32 lo32; } i96; +i96 OVERLOAD make_i96(i64 v) { i96 val; val.lo32 = lo32(v); val.mid32 = hi32(v); val.hi32 = (i32)val.mid32 >> 31; return val; } +i96 OVERLOAD make_i96(i32 v) { i96 val; val.lo32 = v; val.mid32 = val.hi32 = (i32)val.lo32 >> 31; return val; } +i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.hi32 = hi; val.mid32 = hi32(lo); val.lo32 = lo32(lo); return val; } +i96 OVERLOAD make_i96(i32 hi, u64 lo) { i96 val; val.hi32 = hi; val.mid32 = hi32(lo); val.lo32 = lo32(lo); return val; } +u32 i96_hi32(i96 val) { return val.hi32; } +u32 i96_mid32(i96 val) { return val.mid32; } +u32 i96_lo32(i96 val) { return val.lo32; } +u64 i96_lo64(i96 val) { return as_ulong((uint2)(val.lo32, val.mid32)); } +u64 i96_hi64(i96 val) { return as_ulong((uint2)(val.mid32, val.hi32)); } +i96 OVERLOAD add(i96 a, i96 b) { + i96 val; +#if HAS_PTX >= 200 // Carry instructions requires sm_20 support or higher + __asm("add.cc.u32 %0, %3, %6;\n\t" + "addc.cc.u32 %1, %4, %7;\n\t" + "addc.u32 %2, %5, %8;" + : "=r"(val.lo32), "=r"(val.mid32), "=r"(val.hi32) + : "r"(a.lo32), "r"(a.mid32), "r"(a.hi32), "r"(b.lo32), "r"(b.mid32), "r"(b.hi32)); +#else + u64 alo64 = as_ulong((uint2)(a.lo32, a.mid32)); u64 blo64 = as_ulong((uint2)(b.lo32, b.mid32)); u64 lo64 = alo64 + blo64; + val.lo32 = lo32(lo64); val.mid32 = hi32(lo64); val.hi32 = a.hi32 + b.hi32 + (lo64 < alo64); +#endif + return val; +} +i96 OVERLOAD add(i96 a, i64 b) { return add(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i96 b) { + i96 val; +#if HAS_PTX >= 200 // Carry instructions requires sm_20 support or higher + __asm("sub.cc.u32 %0, %3, %6;\n\t" + "subc.cc.u32 %1, %4, %7;\n\t" + "subc.u32 %2, %5, %8;" + : "=r"(val.lo32), "=r"(val.mid32), "=r"(val.hi32) + : "r"(a.lo32), "r"(a.mid32), "r"(a.hi32), "r"(b.lo32), "r"(b.mid32), "r"(b.hi32)); +#else + u64 alo64 = as_ulong((uint2)(a.lo32, a.mid32)); u64 blo64 = as_ulong((uint2)(b.lo32, b.mid32)); u64 lo64 = alo64 - blo64; + val.lo32 = lo32(lo64); val.mid32 = hi32(lo64); val.hi32 = a.hi32 - b.hi32 - (lo64 > alo64); +#endif + return val; +} +i96 OVERLOAD sub(i96 a, i64 b) { return sub(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i32 b) { return sub(a, make_i96(b)); } +#elif defined(__SIZEOF_INT128__) +// An i128 implementation. This might use more GPU registers. nVidia likes this version. 
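Before the alternative representations that follow, a quick host-side check of the portable carry handling used by the all-32-bit fallback path above. This is plain C with the GCC/Clang __int128 extension; the struct and helper names are local to the sketch.

#include <assert.h>
#include <stdint.h>

typedef struct { int32_t hi32; uint32_t mid32, lo32; } i96_t;   // local stand-in for the i96 struct

static i96_t add96(i96_t a, i96_t b) {
  i96_t r;
  uint64_t alo = ((uint64_t)a.mid32 << 32) | a.lo32;
  uint64_t blo = ((uint64_t)b.mid32 << 32) | b.lo32;
  uint64_t lo = alo + blo;                       // low 64 bits
  r.lo32 = (uint32_t)lo; r.mid32 = (uint32_t)(lo >> 32);
  r.hi32 = a.hi32 + b.hi32 + (lo < alo);         // carry out of the low 64 bits
  return r;
}

static __int128 to128(i96_t v) {
  __int128 lo = ((uint64_t)v.mid32 << 32) | v.lo32;
  return (__int128)v.hi32 * (((__int128)1) << 64) + lo;
}

int main(void) {
  i96_t a = { 5, 0xFFFFFFFFu, 0xFFFFFFFFu };     // 5*2^64 + (2^64 - 1): forces a carry
  i96_t b = { -2, 0, 1 };
  assert(to128(add96(a, b)) == to128(a) + to128(b));
  return 0;
}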
+typedef struct { __int128 x; } i96; +i96 OVERLOAD make_i96(i64 v) { i96 val; val.x = v; return val; } +i96 OVERLOAD make_i96(i32 v) { i96 val; val.x = v; return val; } +i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.x = ((unsigned __int128)hi << 64) + lo; return val; } +i96 OVERLOAD make_i96(i32 hi, u64 lo) { return make_i96((i64)hi, lo); } +u32 i96_hi32(i96 val) { return (unsigned __int128)val.x >> 64; } +u32 i96_mid32(i96 val) { return (u64)val.x >> 32; } +u32 i96_lo32(i96 val) { return val.x; } +u64 i96_lo64(i96 val) { return val.x; } +u64 i96_hi64(i96 val) { return (unsigned __int128)val.x >> 32; } +i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.x = a.x + b.x; return val; } +i96 OVERLOAD add(i96 a, i64 b) { return add(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.x = a.x - b.x; return val; } +i96 OVERLOAD sub(i96 a, i64 b) { return sub(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i32 b) { return sub(a, make_i96(b)); } +#elif 1 +// A u64 lo32, u32 hi32 implementation. This too would benefit from add with carry instructions. +// On nVidia, the clang optimizer kept the hi32 value as 64-bits! +typedef struct { u64 lo64; u32 hi32; } i96; +i96 OVERLOAD make_i96(i64 v) { i96 val; val.hi32 = v >> 63, val.lo64 = v; return val; } +i96 OVERLOAD make_i96(i32 v) { return make_i96((i64)v); } +i96 OVERLOAD make_i96(i64 hi, u64 lo) { i96 val; val.hi32 = hi, val.lo64 = lo; return val; } +i96 OVERLOAD make_i96(i32 hi, u64 lo) { i96 val; val.hi32 = hi, val.lo64 = lo; return val; } +u32 i96_hi32(i96 val) { return val.hi32; } +u32 i96_mid32(i96 val) { return hi32(val.lo64); } +u32 i96_lo32(i96 val) { return val.lo64; } +u64 i96_lo64(i96 val) { return val.lo64; } +u64 i96_hi64(i96 val) { return ((u64) val.hi32 << 32) | i96_mid32(val); } +i96 OVERLOAD add(i96 a, i96 b) { i96 val; val.lo64 = a.lo64 + b.lo64; val.hi32 = a.hi32 + b.hi32 + (val.lo64 < a.lo64); return val; } +i96 OVERLOAD add(i96 a, i64 b) { return add(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i96 b) { i96 val; val.lo64 = a.lo64 - b.lo64; val.hi32 = a.hi32 - b.hi32 - (val.lo64 > a.lo64); return val; } +i96 OVERLOAD sub(i96 a, i64 b) { return sub(a, make_i96(b)); } +i96 OVERLOAD sub(i96 a, i32 b) { return sub(a, make_i96(b)); } +#endif + +// A primitive partial implementation of an i128 and u128 integer type +#if defined(__SIZEOF_INT128__) +typedef struct { __int128 x; } i128; +typedef struct { unsigned __int128 x; } u128; +i128 OVERLOAD make_i128(i64 hi, u64 lo) { i128 val; val.x = ((__int128)hi << 64) | lo; return val; } +i128 OVERLOAD make_i128(u64 hi, u64 lo) { i128 val; val.x = ((__int128)hi << 64) | lo; return val; } +u64 i128_lo64(i128 val) { return val.x; } +u64 i128_shrlo64(i128 val, u32 bits) { return val.x >> bits; } +i128 OVERLOAD i128_masklo64(i128 a, u64 m) { i128 val; val.x = a.x & (((__int128)0xFFFFFFFFFFFFFFFFULL << 64) | m); return val; } +i128 OVERLOAD add(i128 a, i128 b) { i128 val; val.x = a.x + b.x; return val; } +i128 OVERLOAD add(i128 a, u64 b) { i128 val; val.x = a.x + (__int128)b; return val; } +i128 OVERLOAD add(i128 a, i64 b) { i128 val; val.x = a.x + (__int128)b; return val; } +i128 OVERLOAD sub(i128 a, u64 b) { i128 val; val.x = a.x - (__int128)b; return val; } +i128 OVERLOAD sub(i128 a, i64 b) { i128 val; val.x = a.x - (__int128)b; return val; } +u128 OVERLOAD make_u128(u64 hi, u64 lo) { u128 val; val.x = ((unsigned __int128)hi << 64) | lo; return val; } +u64 u128_lo64(u128 val) { return val.x; } +u64 u128_hi64(u128 val) { return val.x >> 64; } +u128 mul64(u64 a, u64 b) { u128 val; val.x = 
(unsigned __int128)a * (unsigned __int128)b; return val; } +u128 OVERLOAD add(u128 a, u128 b) { u128 val; val.x = a.x + b.x; return val; } +#else // UNTESTED! The mul64 macro causes clang to hang! +typedef struct { i64 hi64; u64 lo64; } i128; +typedef struct { u64 hi64; u64 lo64; } u128; +i128 OVERLOAD make_i128(i64 hi, u64 lo) { i128 val; val.hi64 = hi; val.lo64 = lo; return val; } +i128 OVERLOAD make_i128(u64 hi, u64 lo) { i128 val; val.hi64 = hi; val.lo64 = lo; return val; } +u64 i128_lo64(i128 val) { return val.lo64; } +u64 i128_shrlo64(i128 val, u32 bits) { return (val.hi64 << (64 - bits)) | (val.lo64 >> bits); } +i128 OVERLOAD i128_masklo64(i128 a, u64 m) { i128 val; val.lo64 = a.lo64 & m; val.hi64 = a.hi64; return val; } +i128 OVERLOAD add(i128 a, i128 b) { i128 val; val.lo64 = a.lo64 + b.lo64; val.hi64 = a.hi64 + b.hi64 + (val.lo64 < a.lo64); return val; } +i128 OVERLOAD add(i128 a, u64 b) { i128 val; val.lo64 = a.lo64 + b; val.hi64 = a.hi64 + (val.lo64 < a.lo64); return val; } +i128 OVERLOAD add(i128 a, i64 b) { i128 val; val.lo64 = a.lo64 + b; val.hi64 = a.hi64 + (b >> 63) + (val.lo64 < a.lo64); return val; } +i128 OVERLOAD sub(i128 a, u64 b) { i128 val; val.lo64 = a.lo64 - b; val.hi64 = a.hi64 - (val.lo64 > a.lo64); return val; } +i128 OVERLOAD sub(i128 a, i64 b) { i128 val; val.lo64 = a.lo64 - (u64)b; val.hi64 = a.hi64 - (b >> 63) - (val.lo64 > a.lo64); return val; } +u128 OVERLOAD make_u128(u64 hi, u64 lo) { u128 val; val.hi64 = hi; val.lo64 = lo; return val; } +u64 u128_lo64(u128 val) { return val.lo64; } +u64 u128_hi64(u128 val) { return val.hi64; } +u128 mul64(u64 a, u64 b) { u128 val; val.lo64 = a * b; val.hi64 = mul_hi(a, b); return val; } +u128 OVERLOAD add(u128 a, u128 b) { u128 val; val.lo64 = a.lo64 + b.lo64; val.hi64 = a.hi64 + b.hi64 + (val.lo64 < a.lo64); return val; } +#endif + +// Select based on sign of first argument. This generates less PTX code, but is no faster on 5xxx GPUs +i32 select32(i32 a, i32 b, i32 c) { +#if HAS_PTX >= 100 // slct instruction requires sm_10 support or higher + i32 res; + __asm("slct.s32.s32 %0, %2, %3, %1;" : "=r"(res) : "r"(a), "r"(b), "r"(c)); + return res; +#else + return a >= 0 ? b : c; +#endif +} + +// Optionally add a value if first arg is negative. +i32 optional_add(i32 a, const i32 b) { +#if HAS_PTX >= 100 // setp/add instruction requires sm_10 support or higher + __asm("{.reg .pred %%p;\n\t" + " setp.lt.s32 %%p, %0, 0;\n\t" // a < 0 + " @%%p add.s32 %0, %0, %1;}" // if (a < 0) a = a + b + : "+r"(a) : "n"(b)); +#else + if (a < 0) a = a + b; +#endif + return a; +} + +// Optionally subtract a value if first arg is negative. +i32 optional_sub(i32 a, const i32 b) { +#if HAS_PTX >= 100 // setp/sub instruction requires sm_10 support or higher + __asm("{.reg .pred %%p;\n\t" + " setp.lt.s32 %%p, %0, 0;\n\t" // a < 0 + " @%%p sub.s32 %0, %0, %1;}" // if (a < 0) a = a - b + : "+r"(a) : "n"(b)); +#else + if (a < 0) a = a - b; +#endif + return a; +} + +// Optionally subtract a value if first arg is greater than value. 
+i32 optional_mod(i32 a, const i32 b) { +#if 0 //HAS_PTX >= 100 // setp/sub instruction requires sm_10 support or higher // Not faster on 5xxx GPUs (not sure why) + __asm("{.reg .pred %%p;\n\t" + " setp.ge.s32 %%p, %0, %1;\n\t" // a > b + " @%%p sub.s32 %0, %0, %1;}" // if (a > b) a = a - b + : "+r"(a) : "n"(b)); +#else + if (a >= b) a = a - b; +#endif + return a; +} + +// Multiply and add primitives + +u64 OVERLOAD mad32(u32 a, u32 b, u32 c) { +#if HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Same speed on TitanV, any gain may be too small to measure + u32 reslo, reshi; + __asm("mad.lo.cc.u32 %0, %2, %3, %4;\n\t" + "madc.hi.u32 %1, %2, %3, 0;" : "=r"(reslo), "=r"(reshi) : "r"(a), "r"(b), "r"(c)); + return as_ulong((uint2)(reslo, reshi)); +#else + return (u64)a * (u64)b + c; +#endif +} + +u64 OVERLOAD mad32(u32 a, u32 b, u64 c) { +#if HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Same speed on TitanV, any gain may be too small to measure + u32 reslo, reshi; + __asm("mad.lo.cc.u32 %0, %2, %3, %4;\n\t" + "madc.hi.u32 %1, %2, %3, %5;" : "=r"(reslo), "=r"(reshi) : "r"(a), "r"(b), "r"(lo32(c)), "r"(hi32(c))); + return as_ulong((uint2)(reslo, reshi)); +#else + return (u64)a * (u64)b + c; +#endif +} + +u128 OVERLOAD mad64(u64 a, u64 b, u64 c) { +#if 0 && HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Slower on TitanV and mobile 4070, don't understand why + u64 reslo, reshi; + __asm("mad.lo.cc.u64 %0, %2, %3, %4;\n\t" + "madc.hi.u64 %1, %2, %3, 0;" : "=l"(reslo), "=l"(reshi) : "l"(a), "l"(b), "l"(u128_lo64(c))); + return make_u128(reshi, reslo); +#elif HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Faster on TitanV. No difference on mobile 4070. Much cleaner PTX code generated. + uint2 a2 = as_uint2(a); + uint2 b2 = as_uint2(b); + uint2 c2 = as_uint2(c); + uint2 rlo2, rhi2; + __asm("mad.lo.cc.u32 %0, %4, %6, %8;\n\t" + "madc.hi.cc.u32 %1, %4, %6, %9;\n\t" + "madc.lo.cc.u32 %2, %5, %7, 0;\n\t" + "madc.hi.u32 %3, %5, %7, 0;\n\t" + "mad.lo.cc.u32 %1, %5, %6, %1;\n\t" + "madc.hi.cc.u32 %2, %5, %6, %2;\n\t" + "addc.u32 %3, %3, 0;\n\t" + "mad.lo.cc.u32 %1, %4, %7, %1;\n\t" + "madc.hi.cc.u32 %2, %4, %7, %2;\n\t" + "addc.u32 %3, %3, 0;" + : "=r"(rlo2.x), "=r"(rlo2.y), "=r"(rhi2.x), "=r"(rhi2.y) + : "r"(a2.x), "r"(a2.y), "r"(b2.x), "r"(b2.y), "r"(c2.x), "r"(c2.y)); + return make_u128((u64)as_ulong(rhi2), (u64)as_ulong(rlo2)); +#else + return add(mul64(a, b), make_u128(0, c)); +#endif +} + +u128 OVERLOAD mad64(u64 a, u64 b, u128 c) { +#if 0 && HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Slower on TitanV and mobile 4070, don't understand why + u64 reslo, reshi; + __asm("mad.lo.cc.u64 %0, %2, %3, %4;\n\t" + "madc.hi.u64 %1, %2, %3, %5;" : "=l"(reslo), "=l"(reshi) : "l"(a), "l"(b), "l"(u128_lo64(c)), "l"(u128_hi64(c))); + return make_u128(reshi, reslo); +#elif HAS_PTX >= 200 // mad instruction requires sm_20 support or higher // Faster on TitanV. No difference on mobile 4070. Much cleaner PTX code generated. 
+ uint2 a2 = as_uint2(a); + uint2 b2 = as_uint2(b); + uint2 clo2 = as_uint2(u128_lo64(c)); + uint2 chi2 = as_uint2(u128_hi64(c)); + uint2 rlo2, rhi2; + __asm("mad.lo.cc.u32 %0, %4, %6, %8;\n\t" + "madc.hi.cc.u32 %1, %4, %6, %9;\n\t" + "madc.lo.cc.u32 %2, %5, %7, %10;\n\t" + "madc.hi.u32 %3, %5, %7, %11;\n\t" + "mad.lo.cc.u32 %1, %5, %6, %1;\n\t" + "madc.hi.cc.u32 %2, %5, %6, %2;\n\t" + "addc.u32 %3, %3, 0;\n\t" + "mad.lo.cc.u32 %1, %4, %7, %1;\n\t" + "madc.hi.cc.u32 %2, %4, %7, %2;\n\t" + "addc.u32 %3, %3, 0;" + : "=r"(rlo2.x), "=r"(rlo2.y), "=r"(rhi2.x), "=r"(rhi2.y) + : "r"(a2.x), "r"(a2.y), "r"(b2.x), "r"(b2.y), "r"(clo2.x), "r"(clo2.y), "r"(chi2.x), "r"(chi2.y)); + return make_u128((u64)as_ulong(rhi2), (u64)as_ulong(rlo2)); +#else + return add(mul64(a, b), c); +#endif +} + + +// The X2 family of macros and SWAP are #defines because OpenCL does not allow pass by reference. +// With NTT support added, we need to turn these macros into overloaded routines. +#define X2(a, b) X2_internal(&(a), &(b)) // a = a + b, b = a - b +#define X2conjb(a, b) X2conjb_internal(&(a), &(b)) // X2(a, conjugate(b)) +#define X2_mul_t4(a, b) X2_mul_t4_internal(&(a), &(b)) // X2(a, b), b = mul_t4(b) +#define X2_mul_t8(a, b) X2_mul_t8_internal(&(a), &(b)) // X2(a, b), b = mul_t8(b) +#define X2_mul_3t8(a, b) X2_mul_3t8_internal(&(a), &(b)) // X2(a, b), b = mul_3t8(b) +#define X2_conjb(a, b) X2_conjb_internal(&(a), &(b)) // X2(a, b), b = conjugate(b) +#define SWAP(a, b) SWAP_internal(&(a), &(b)) // a = b, b = a +#define SWAP_XY(a) U2((a).y, (a).x) // Swap real and imaginary components of a + +#if FFT_FP64 + +T2 OVERLOAD conjugate(T2 a) { return U2(a.x, -a.y); } + // Multiply by 2 without using floating point instructions. This is a little sloppy as an input of zero returns 2^-1022. T OVERLOAD mul2(T a) { int2 tmp = as_int2(a); tmp.y += 0x00100000; /* Bump exponent by 1 */ return (as_double(tmp)); } T2 OVERLOAD mul2(T2 a) { return U2(mul2(a.x), mul2(a.y)); } @@ -16,51 +293,39 @@ T2 OVERLOAD mulminus2(T2 a) { return U2(mulminus2(a.x), mulminus2(a.y)); } T OVERLOAD fancyMul(T a, T b) { return fma(a, b, a); } T2 OVERLOAD fancyMul(T2 a, T2 b) { return U2(fancyMul(a.x, b.x), fancyMul(a.y, b.y)); } -T2 cmul(T2 a, T2 b) { -#if 1 - return U2(fma(a.x, b.x, -a.y * b.y), fma(a.x, b.y, a.y * b.x)); -#else - return U2(fma(a.y, -b.y, a.x * b.x), fma(a.x, b.y, a.y * b.x)); -#endif -} +// Square a complex number +T2 OVERLOAD csq(T2 a) { return U2(fma(a.x, a.x, - a.y * a.y), mul2(a.x) * a.y); } +// a^2 + c +T2 OVERLOAD csqa(T2 a, T2 c) { return U2(fma(a.x, a.x, fma(a.y, -a.y, c.x)), fma(mul2(a.x), a.y, c.y)); } +// Same as csq(a), -a +T2 OVERLOAD csq_neg(T2 a) { return U2(fma(-a.x, a.x, a.y * a.y), mulminus2(a.x) * a.y); } // NOT USED + +// Complex multiply +T2 OVERLOAD cmul(T2 a, T2 b) { return U2(fma(a.x, b.x, -a.y * b.y), fma(a.x, b.y, a.y * b.x)); } -T2 conjugate(T2 a) { return U2(a.x, -a.y); } +T2 OVERLOAD cfma(T2 a, T2 b, T2 c) { return U2(fma(a.x, b.x, fma(a.y, -b.y, c.x)), fma(a.y, b.x, fma(a.x, b.y, c.y))); } -T2 cmul_by_conjugate(T2 a, T2 b) { return cmul(a, conjugate(b)); } +T2 OVERLOAD cmul_by_conjugate(T2 a, T2 b) { return cmul(a, conjugate(b)); } // Multiply a by b and conjugate(b). This saves 2 multiplies. 
-void cmul_a_by_b_and_conjb(T2 *res1, T2 *res2, T2 a, T2 b) { +void OVERLOAD cmul_a_by_b_and_conjb(T2 *res1, T2 *res2, T2 a, T2 b) { T axbx = a.x * b.x; T aybx = a.y * b.x; res1->x = fma(a.y, -b.y, axbx), res1->y = fma(a.x, b.y, aybx); res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, -b.y, aybx); } -T2 cfma(T2 a, T2 b, T2 c) { -#if 1 - return U2(fma(a.x, b.x, fma(a.y, -b.y, c.x)), fma(a.y, b.x, fma(a.x, b.y, c.y))); -#else - return U2(fma(a.y, -b.y, fma(a.x, b.x, c.x)), fma(a.x, b.y, fma(a.y, b.x, c.y))); -#endif -} - -// Square a complex number -T2 csq(T2 a) { return U2(fma(a.x, a.x, - a.y * a.y), mul2(a.x) * a.y); } - // Square a (cos,sin) complex number. Fancy squaring returns a fancy value. Defancy squares a fancy number returning a non-fancy number. -T2 csqTrig(T2 a) { T two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } +T2 OVERLOAD csqTrig(T2 a) { T two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } T2 csqTrigFancy(T2 a) { T two_ay = mul2(a.y); return U2(-two_ay * a.y, fma(a.x, two_ay, two_ay)); } T2 csqTrigDefancy(T2 a) { T two_ay = mul2(a.y); return U2(fma (-two_ay, a.y, 1), fma(a.x, two_ay, two_ay)); } // Cube a complex number w (cos,sin) given w^2 and w. The squared input can be either fancy or not fancy. // Fancy cCube takes a fancy w argument and returns a fancy value. Defancy takes a fancy w argument and returns a non-fancy value. -T2 ccubeTrig(T2 sq, T2 w) { T tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, -w.y)); } +T2 OVERLOAD ccubeTrig(T2 sq, T2 w) { T tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, -w.y)); } T2 ccubeTrigFancy(T2 sq, T2 w) { T tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, tmp - w.y)); } T2 ccubeTrigDefancy(T2 sq, T2 w) { T tmp = mul2(sq.y); T wx = w.x + 1; return U2(fma(tmp, -w.y, wx), fma(tmp, wx, -w.y)); } -// a^2 + c -T2 csqa(T2 a, T2 c) { return U2(fma(a.x, a.x, fma(a.y, -a.y, c.x)), fma(mul2(a.x), a.y, c.y)); } - // Complex a * (b + 1) // Useful for mul with twiddles of small angles, where the real part is stored with the -1 trick for increased precision T2 cmulFancy(T2 a, T2 b) { return cfma(a, b, a); } @@ -73,10 +338,9 @@ void cmul_a_by_fancyb_and_conjfancyb(T2 *res1, T2 *res2, T2 a, T2 b) { res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, -b.y, aybx); } +T2 OVERLOAD mul_t4(T2 a) { return U2(-a.y, a.x); } // i.e. a * i -T2 mul_t4(T2 a) { return U2(-a.y, a.x); } // i.e. 
a * i - -T2 mul_t8(T2 a) { // mul(a, U2( 1, 1)) * (T)(M_SQRT1_2); } +T2 OVERLOAD mul_t8(T2 a) { // mul(a, U2(1, 1)) * (T)(M_SQRT1_2); } // One mul, two FMAs T ay = a.y * M_SQRT1_2; return U2(fma(a.x, M_SQRT1_2, -ay), fma(a.x, M_SQRT1_2, ay)); @@ -84,7 +348,7 @@ T2 mul_t8(T2 a) { // mul(a, U2( 1, 1)) * (T)(M_SQRT1_2); } // return U2(a.x - a.y, a.x + a.y) * M_SQRT1_2; } -T2 mul_3t8(T2 a) { // mul(a, U2(-1, 1)) * (T)(M_SQRT1_2); } +T2 OVERLOAD mul_3t8(T2 a) { // mul(a, U2(-1, 1)) * (T)(M_SQRT1_2); } // One mul, two FMAs T ay = a.y * M_SQRT1_2; return U2(fma(-a.x, M_SQRT1_2, -ay), fma(a.x, M_SQRT1_2, -ay)); @@ -92,24 +356,733 @@ T2 mul_3t8(T2 a) { // mul(a, U2(-1, 1)) * (T)(M_SQRT1_2); } // return U2(-(a.x + a.y), a.x - a.y) * M_SQRT1_2; } -T2 swap(T2 a) { return U2(a.y, a.x); } -T2 addsub(T2 a) { return U2(a.x + a.y, a.x - a.y); } - -#define X2(a, b) { T2 t = a; a = t + b; b = t - b; } +// Return a+b and a-b +void OVERLOAD X2_internal(T2 *a, T2 *b) { T2 t = *a; *a = t + *b; *b = t - *b; } // Same as X2(a, b), b = mul_t4(b) -#define X2_mul_t4(a, b) { X2(a, b); b = mul_t4(b); } -// { T2 t = a; a = a + b; t.x = t.x - b.x; b.x = b.y - t.y; b.y = t.x; } +void OVERLOAD X2_mul_t4_internal(T2 *a, T2 *b) { T2 t = *a; *a = *a + *b; t.x = t.x - b->x; b->x = b->y - t.y; b->y = t.x; } // Same as X2(a, conjugate(b)) -#define X2conjb(a, b) { T2 t = a; a.x = a.x + b.x; a.y = a.y - b.y; b.x = t.x - b.x; b.y = t.y + b.y; } +void OVERLOAD X2conjb_internal(T2 *a, T2 *b) { T2 t = *a; a->x = a->x + b->x; a->y = a->y - b->y; b->x = t.x - b->x; b->y = t.y + b->y; } -// Same as X2(a, b), a = conjugate(a) -#define X2conja(a, b) { T2 t = a; (a).x = (a).x + (b).x; (a).y = -(a).y - (b).y; b = t - b; } +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(T2 *a, T2 *b) { T2 t = *a; *a = t + *b; b->x = t.x - b->x; b->y = b->y - t.y; } -#define SWAP(a, b) { T2 t = a; a = b; b = t; } +void OVERLOAD SWAP_internal(T2 *a, T2 *b) { T2 t = *a; *a = *b; *b = t; } T2 fmaT2(T a, T2 b, T2 c) { return fma(U2(a, a), b, c); } // a = c + sin * d; b = c - sin * d; #define fma_addsub(a, b, sin, c, d) { T2 t = c + sin * d; b = c - sin * d; a = t; } + +T2 OVERLOAD addsub(T2 a) { return U2(a.x + a.y, a.x - a.y); } + +// computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x) +// which happens to be the cyclical convolution (a.x, a.y)x(b.x, b.y) * 2 +T2 OVERLOAD foo2(T2 a, T2 b) { a = addsub(a); b = addsub(b); return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); } + +// computes 2*[x^2+y^2 + i*(2*x*y)]. i.e. 2 * cyclical autoconvolution of (x, y) +T2 OVERLOAD foo(T2 a) { return foo2(a, a); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +F2 OVERLOAD conjugate(F2 a) { return U2(a.x, -a.y); } + +// Multiply by 2 without using floating point instructions. This is a little sloppy as an input of zero returns 2^-126. +F OVERLOAD mul2(F a) { return a + a; } //{ int tmp = as_int(a); tmp += 0x00800000; /* Bump exponent by 1 */ return (as_float(tmp)); } +F2 OVERLOAD mul2(F2 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Multiply by -2 without using floating point instructions. This is a little sloppy as an input of zero returns -2^-126. 
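The exponent-bump doubling referenced here, used directly by the FP64 mul2 earlier in this hunk and kept only as a comment in the FP32 versions, adds 1 to the biased exponent field, which scales any normal value by 2. A standalone C check of the double-precision form, with memcpy standing in for OpenCL's as_int2/as_double:

#include <assert.h>
#include <stdint.h>
#include <string.h>

static double mul2_bits(double a) {
  uint64_t bits;
  memcpy(&bits, &a, sizeof bits);
  bits += 0x0010000000000000ULL;      // same as bumping the high word by 0x00100000
  memcpy(&a, &bits, sizeof bits);
  return a;
}

int main(void) {
  assert(mul2_bits(1.5) == 3.0);
  assert(mul2_bits(-0.375) == -0.75); // sign bit is untouched for normal inputs
  return 0;
}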
+F OVERLOAD mulminus2(F a) { return -2.0f * a; } //{ int tmp = as_int(a); tmp += 0x80800000; /* Bump exponent by 1, flip sign bit */ return (as_float(tmp)); } +F2 OVERLOAD mulminus2(F2 a) { return U2(mulminus2(a.x), mulminus2(a.y)); } + +// a * (b + 1) == a * b + a +F OVERLOAD fancyMul(F a, F b) { return fma(a, b, a); } +F2 OVERLOAD fancyMul(F2 a, F2 b) { return U2(fancyMul(a.x, b.x), fancyMul(a.y, b.y)); } + +// Square a complex number +F2 OVERLOAD csq(F2 a) { return U2(fma(a.x, a.x, - a.y * a.y), mul2(a.x) * a.y); } +// a^2 + c +F2 OVERLOAD csqa(F2 a, F2 c) { return U2(fma(a.x, a.x, fma(a.y, -a.y, c.x)), fma(mul2(a.x), a.y, c.y)); } +// Same as csq(a), -a +F2 OVERLOAD csq_neg(F2 a) { return U2(fma(-a.x, a.x, a.y * a.y), mulminus2(a.x) * a.y); } // NOT USED + +// Complex multiply +F2 OVERLOAD cmul(F2 a, F2 b) { return U2(fma(a.x, b.x, -a.y * b.y), fma(a.x, b.y, a.y * b.x)); } + +F2 OVERLOAD cfma(F2 a, F2 b, F2 c) { return U2(fma(a.x, b.x, fma(a.y, -b.y, c.x)), fma(a.y, b.x, fma(a.x, b.y, c.y))); } + +F2 OVERLOAD cmul_by_conjugate(F2 a, F2 b) { return cmul(a, conjugate(b)); } + +// Multiply a by b and conjugate(b). This saves 2 multiplies. +void OVERLOAD cmul_a_by_b_and_conjb(F2 *res1, F2 *res2, F2 a, F2 b) { + F axbx = a.x * b.x; + F aybx = a.y * b.x; + res1->x = fma(a.y, -b.y, axbx), res1->y = fma(a.x, b.y, aybx); + res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, -b.y, aybx); +} + +// Square a (cos,sin) complex number. Fancy squaring returns a fancy value. Defancy squares a fancy number returning a non-fancy number. +F2 OVERLOAD csqTrig(F2 a) { F two_ay = mul2(a.y); return U2(fma(-two_ay, a.y, 1), a.x * two_ay); } +F2 csqTrigFancy(F2 a) { F two_ay = mul2(a.y); return U2(-two_ay * a.y, fma(a.x, two_ay, two_ay)); } +F2 csqTrigDefancy(F2 a) { F two_ay = mul2(a.y); return U2(fma (-two_ay, a.y, 1), fma(a.x, two_ay, two_ay)); } + +// Cube a complex number w (cos,sin) given w^2 and w. The squared input can be either fancy or not fancy. +// Fancy cCube takes a fancy w argument and returns a fancy value. Defancy takes a fancy w argument and returns a non-fancy value. +F2 OVERLOAD ccubeTrig(F2 sq, F2 w) { F tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, -w.y)); } +F2 ccubeTrigFancy(F2 sq, F2 w) { F tmp = mul2(sq.y); return U2(fma(tmp, -w.y, w.x), fma(tmp, w.x, tmp - w.y)); } +F2 ccubeTrigDefancy(F2 sq, F2 w) { F tmp = mul2(sq.y); F wx = w.x + 1; return U2(fma(tmp, -w.y, wx), fma(tmp, wx, -w.y)); } + +// Complex a * (b + 1) +// Useful for mul with twiddles of small angles, where the real part is stored with the -1 trick for increased precision +F2 cmulFancy(F2 a, F2 b) { return cfma(a, b, a); } + +// Multiply a by fancy b and conjugate(fancy b). This saves 2 FMAs. +void cmul_a_by_fancyb_and_conjfancyb(F2 *res1, F2 *res2, F2 a, F2 b) { + F axbx = fma(a.x, b.x, a.x); + F aybx = fma(a.y, b.x, a.y); + res1->x = fma(a.y, -b.y, axbx), res1->y = fma(a.x, b.y, aybx); + res2->x = fma(a.y, b.y, axbx), res2->y = fma(a.x, -b.y, aybx); +} + +F2 OVERLOAD mul_t4(F2 a) { return U2(-a.y, a.x); } // i.e. 
a * i + +F2 OVERLOAD mul_t8(F2 a) { // mul(a, U2(1, 1)) * (T)(M_SQRT1_2); } + // One mul, two FMAs + F ay = a.y * (float) M_SQRT1_2; + return U2(fma(a.x, (float) M_SQRT1_2, -ay), fma(a.x, (float) M_SQRT1_2, ay)); +// Two adds, two muls +// return U2(a.x - a.y, a.x + a.y) * M_SQRT1_2; +} + +F2 OVERLOAD mul_3t8(F2 a) { // mul(a, U2(-1, 1)) * (T)(M_SQRT1_2); } + // One mul, two FMAs + F ay = a.y * (float) M_SQRT1_2; + return U2(fma(-a.x, (float) M_SQRT1_2, -ay), fma(a.x, (float) M_SQRT1_2, -ay)); +// Two adds, two muls +// return U2(-(a.x + a.y), a.x - a.y) * M_SQRT1_2; +} + +// Return a+b and a-b +void OVERLOAD X2_internal(F2 *a, F2 *b) { F2 t = *a; *a = t + *b; *b = t - *b; } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(F2 *a, F2 *b) { F2 t = *a; *a = *a + *b; t.x = t.x - b->x; b->x = b->y - t.y; b->y = t.x; } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(F2 *a, F2 *b) { F2 t = *a; a->x = a->x + b->x; a->y = a->y - b->y; b->x = t.x - b->x; b->y = t.y + b->y; } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(F2 *a, F2 *b) { F2 t = *a; *a = t + *b; b->x = t.x - b->x; b->y = b->y - t.y; } + +void OVERLOAD SWAP_internal(F2 *a, F2 *b) { F2 t = *a; *a = *b; *b = t; } + +F2 fmaT2(F a, F2 b, F2 c) { return fma(U2(a, a), b, c); } + +// a = c + sin * d; b = c - sin * d; +#define fma_addsub(a, b, sin, c, d) { F2 t = c + sin * d; b = c - sin * d; a = t; } + +F2 OVERLOAD addsub(F2 a) { return U2(a.x + a.y, a.x - a.y); } + +// computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x) +// which happens to be the cyclical convolution (a.x, a.y)x(b.x, b.y) * 2 +F2 OVERLOAD foo2(F2 a, F2 b) { a = addsub(a); b = addsub(b); return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); } + +// computes 2*[x^2+y^2 + i*(2*x*y)]. i.e. 2 * cyclical autoconvolution of (x, y) +F2 OVERLOAD foo(F2 a) { return foo2(a, a); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// bits in reduced mod M. +#define M31 ((((Z31) 1) << 31) - 1) + +#if 0 // Version that keeps results strictly in the 0..M31-1 range + +Z31 OVERLOAD mod(Z31 a) { return (a & M31) + (a >> 31); } // GWBUG: This could be larger than M31 (unless a is result of an add), need a wesk and strong mod + +Z31 OVERLOAD add(Z31 a, Z31 b) { Z31 t = a + b; return t - (t >= M31 ? M31 : 0); } //GWBUG - an if stmt may be faster +GF31 OVERLOAD add(GF31 a, GF31 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } + +Z31 OVERLOAD sub(Z31 a, Z31 b) { Z31 t = a - b; return t + (t >= M31 ? M31 : 0); } //GWBUG - an if stmt may be faster. So might "(i64) t < 0". +GF31 OVERLOAD sub(GF31 a, GF31 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } + +Z31 OVERLOAD neg(Z31 a) { return a == 0 ? 0 : M31 - a; } // GWBUG: Examine all callers to see if neg call can be avoided +GF31 OVERLOAD neg(GF31 a) { return U2(neg(a.x), neg(a.y)); } + +Z31 OVERLOAD make_Z31(i32 a) { return (Z31) (a < 0 ? a + M31 : a); } // Handles signed values of a +Z31 OVERLOAD make_Z31(u32 a) { return (Z31) (a); } // a must be in range of 0 .. M31-1 +Z31 OVERLOAD make_Z31(i64 a) { if (a < 0) a += (((i64) M31 << 31) + M31); return add((Z31) (a & M31), (Z31) (a >> 31)); } // Handles 62-bit a values + +u32 get_Z31(Z31 a) { return a; } // Get balanced value in range 0 to M31-1 +i32 get_balanced_Z31(Z31 a) { return (a & 0xC0000000) ? 
(i32) a - M31 : (i32) a; } // Get balanced value in range -M31/2 to M31/2 + +// Assumes k reduced mod 31. +Z31 OVERLOAD shl(Z31 a, u32 k) { return ((a << k) + (a >> (31 - k))) & M31; } +GF31 OVERLOAD shl(GF31 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } +Z31 OVERLOAD shr(Z31 a, u32 k) { return ((a >> k) + (a << (31 - k))) & M31; } +GF31 OVERLOAD shr(GF31 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } + +Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return add((Z31) (t & M31), (Z31) (t >> 31)); } + +// Multiply by 2 +Z31 OVERLOAD mul2(Z31 a) { return add(a, a); } +GF31 OVERLOAD mul2(GF31 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Return conjugate of a +GF31 OVERLOAD conjugate(GF31 a) { return U2(a.x, neg(a.y)); } + +// Complex square. input, output 31 bits. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). +GF31 OVERLOAD csq(GF31 a) { return U2(mul(add(a.x, a.y), sub(a.x, a.y)), mul2(mul(a.x, a.y))); } //GWBUG: Probably faster to double a.y and have a mul that takes non-normalized inputs + +// a^2 + c +GF31 OVERLOAD csqa(GF31 a, GF31 c) { return add(csq(a), c); } // GWBUG: inline csq so we only "mod" after adding c?? Find a way to use fma instructions + +// Complex mul +//GF31 OVERLOAD cmul(GF31 a, GF31 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} // GWBUG: Is a 3 multiply complex mul faster? See above +GF31 OVERLOAD cmul(GF31 a, GF31 b) { + Z31 k1 = mul(b.x, add(a.x, a.y)); + Z31 k2 = mul(a.x, sub(b.y, b.x)); + Z31 k3 = mul(a.y, add(b.y, b.x)); + return U2(sub(k1, k3), add(k1, k2)); +} + +// mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). +GF31 OVERLOAD mul_t4(GF31 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (2^15, 2^15). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^15)^2 == 1 (mod M31). +GF31 OVERLOAD mul_t8(GF31 a) { return U2(shl(sub(a.x, a.y), 15), shl(add(a.x, a.y), 15)); } // GWBUG: Can caller use a version that does not negate real? is shl(neg) same as shr??? + +// mul with (-2^15, 2^15). (twiddle of 3*tau/8). +GF31 OVERLOAD mul_3t8(GF31 a) { return U2(shl(neg(add(a.x, a.y)), 15), shl(sub(a.x, a.y), 15)); } + +// Return a+b and a-b +void OVERLOAD X2_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); *b = sub(t, *b); } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; a->x = add(a->x, b->x); a->y = sub(a->y, b->y); b->x = sub(t.x, b->x); b->y = add(t.y, b->y); } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } + +// Same as X2(a, b), b = mul_t8(b) +void OVERLOAD X2_mul_t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_t8(*b); } + +// Same as X2(a, b), b = mul_3t8(b) +void OVERLOAD X2_mul_3t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_3t8(*b); } //GWBUG: can we do better (elim a negate)? + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } + +void OVERLOAD SWAP_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = *b; *b = t; } + +GF31 OVERLOAD addsub(GF31 a) { return U2(add(a.x, a.y), sub(a.x, a.y)); } +GF31 OVERLOAD foo2(GF31 a, GF31 b) { a = addsub(a); b = addsub(b); return addsub(U2(mul(RE(a), RE(b)), mul(IM(a), IM(b)))); } +GF31 OVERLOAD foo(GF31 a) { return foo2(a, a); } + + + + + +#elif 1 // This version is a little sloppy. 
Returns values in 0..M31 range. + +// Internal routines to return value in 0..M31 range +#if MODM31 == 0 +Z31 OVERLOAD modM31(Z31 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +Z31 OVERLOAD modM31(i32 a) { return (a & M31) + (a >> 31); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +#elif MODM31 == 1 +Z31 OVERLOAD modM31(Z31 a) { i32 alt = a + 0x80000001; return select32(a, a, alt); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +Z31 OVERLOAD modM31(i32 a) { i32 alt = a - 0x80000001; return select32(a, a, alt); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +#else +Z31 OVERLOAD modM31(Z31 a) { return optional_add(a, 0x80000001); } // Assumes a is not 0xFFFFFFFF (which would return 0x80000000) +Z31 OVERLOAD modM31(i32 a) { return optional_sub(a, 0x80000001); } // Assumes a is not 0x80000000 (which would return 0xFFFFFFFF) +#endif + +Z31 OVERLOAD modM31(u64 a) { // a must be less than 0xFFFFFFFF7FFFFFFF + u32 alo = a & M31; + u32 amid = (a >> 31) & M31; + u32 ahi = a >> 62; + return modM31(ahi + amid + alo); // 32-bit overflow does not occur due to restrictions on input +} +Z31 OVERLOAD modM31(i64 a) { // abs(a) must be less than 0x7FFFFFFF80000000 + u32 alo = a & M31; + u32 amid = ((u64) a >> 31) & M31; // Unsigned shift might be faster than signed shift + u32 ahi = a >> 62; // Sign extend the top bits + return modM31(ahi + amid + alo); // This is where caller must assure a 32-bit overflow does not occur +} +Z31 OVERLOAD modM31q(u64 a) { // Quick version, a < 2^62 + u32 alo = a & M31; + u32 ahi = a >> 31; + return modM31(ahi + alo); +} +#if 0 // GWBUG - which is faster? +Z31 OVERLOAD modM31q(i64 a) { // Quick version, abs(a) must be 61 bits + u32 alo = a & M31; + i32 ahi = a >> 31; // Sign extend the top bits + if (ahi < 0) ahi = ahi + M31; + return modM31((u32) ahi + alo); +} +#else +Z31 OVERLOAD modM31q(i64 a) { return modM31(a); } // Quick version, abs(a) must be 61 bits +#endif + +Z31 OVERLOAD neg(Z31 a) { return M31 - a; } // GWBUG: Examine all callers to see if neg call can be avoided +GF31 OVERLOAD neg(GF31 a) { return U2(neg(a.x), neg(a.y)); } + +Z31 OVERLOAD add(Z31 a, Z31 b) { return modM31(a + b); } +GF31 OVERLOAD add(GF31 a, GF31 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } + +Z31 OVERLOAD sub(Z31 a, Z31 b) { i32 t = a - b; return (t & M31) + (t >> 31); } +GF31 OVERLOAD sub(GF31 a, GF31 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } + +Z31 OVERLOAD make_Z31(i32 a) { return (Z31) (a < 0 ? a + M31 : a); } // Handles signed values of a +Z31 OVERLOAD make_Z31(u32 a) { return (Z31) (a); } // a must be in range of 0 .. M31-1 +Z31 OVERLOAD make_Z31(i64 a) { return modM31q(a); } // Handles range -2^61..2^61 + +u32 get_Z31(Z31 a) { return a == M31 ? 0 : a; } // Get value in range 0 to M31-1 +i32 get_balanced_Z31(Z31 a) { return (a & 0xC0000000) ? (i32) a - M31 : (i32) a; } // Get balanced value in range -M31/2 to M31/2 + +// Assumes k reduced mod 31. 
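The modM31 family above reduces a wide value by folding 31-bit chunks, using 2^31 == 1 (mod M31). A standalone C check of the folding against an ordinary % reduction; the sample value stays below the routine's stated 0xFFFFFFFF7FFFFFFF limit.

#include <assert.h>
#include <stdint.h>

int main(void) {
  const uint32_t M31 = (1u << 31) - 1;
  uint64_t a = 0xFEDCBA9876543210ULL;                         // exercises all three chunks
  uint64_t folded = (a & M31) + ((a >> 31) & M31) + (a >> 62);
  // One more fold brings the sum under 2^32; a final conditional subtract finishes the reduction.
  uint32_t r = (uint32_t)((folded & M31) + (folded >> 31));
  if (r >= M31) r -= M31;
  assert(r == (uint32_t)(a % M31));
  return 0;
}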
+Z31 OVERLOAD shr(Z31 a, u32 k) { return (a >> k) + ((a << (31 - k)) & M31); } +GF31 OVERLOAD shr(GF31 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } +Z31 OVERLOAD shl(Z31 a, u32 k) { return shr(a, 31 - k); } +GF31 OVERLOAD shl(GF31 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } + +Z31 OVERLOAD mul(Z31 a, Z31 b) { u64 t = a * (u64) b; return modM31(add((Z31)(t & M31), (Z31)(t >> 31))); } + +// Multiply by 2 +Z31 OVERLOAD mul2(Z31 a) { return add(a, a); } +GF31 OVERLOAD mul2(GF31 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Return conjugate of a +GF31 OVERLOAD conjugate(GF31 a) { return U2(a.x, neg(a.y)); } + +// Complex square. input, output 31 bits. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). +GF31 OVERLOAD csq(GF31 a) { + u64 r = (a.x + a.y) * (u64) (a.x + neg(a.y)); // 64-bit value, max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = (a.x + a.x) * (u64) a.y; // 63-bit value, max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); +} + +// a^2 + c +GF31 OVERLOAD csq_add(GF31 a, GF31 c) { + u64 r = mad32(a.x + a.y, a.x + neg(a.y), c.x); // 64-bit value, mul max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = mad32(a.x + a.x, a.y, c.y); // 63-bit value, mul max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); +} + +// a^2 - c +GF31 OVERLOAD csq_sub(GF31 a, GF31 c) { + u64 r = mad32(a.x + a.y, a.x + neg(a.y), neg(c.x)); // 64-bit value, mul max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = mad32(a.x + a.x, a.y, neg(c.y)); // 63-bit value, mul max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); +} + +// a^2 + i*c +GF31 OVERLOAD csq_addi(GF31 a, GF31 c) { + u64 r = mad32(a.x + a.y, a.x + neg(a.y), neg(c.y)); // 64-bit value, mul max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = mad32(a.x + a.x, a.y, c.x); // 63-bit value, mul max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); +} + +// a^2 - i*c +GF31 OVERLOAD csq_subi(GF31 a, GF31 c) { + u64 r = mad32(a.x + a.y, a.x + neg(a.y), c.y); // 64-bit value, mul max = FFFF FFFE 0000 0004 (actually cannot exceed 9000 0000 0000 0000) + u64 i = mad32(a.x + a.x, a.y, neg(c.x)); // 63-bit value, max = 7FFF FFFE 0000 0002 + return U2(modM31(r), modM31(i)); +} + +// Complex mul +#if 0 // One less negation, requires signed shifts. Seems microscopically faster on TitanV. 
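// The cmul variants below (and the GF61 versions later in this file) all rely on the same
// 3-multiply form of the complex product; a worked check of the identity:
//   k1 = bx*(ax + ay) = ax*bx + ay*bx
//   k2 = ax*(by - bx) = ax*by - ax*bx
//   k3 = ay*(by + bx) = ay*by + ay*bx
//   k1 - k3 = ax*bx - ay*by  = Re((ax + i*ay)*(bx + i*by))
//   k1 + k2 = ax*by + ay*bx  = Im((ax + i*ay)*(bx + i*by))
// so three multiplies plus a few adds replace the four-multiply schoolbook product.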
+GF31 OVERLOAD cmul(GF31 a, GF31 b) { + u64 k1 = b.x * (u64) (a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 + u64 k2 = a.x * (u64) (b.y + neg(b.x)); + u64 k3 = a.y * (u64) (b.y + b.x); + i64 k1k3 = k1 - k3; // signed 63-bit value, absolute value <= 7FFF FFFE 0000 0002 + u64 k1k2 = k1 + k2; // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + return U2(modM31(k1k3), modM31(k1k2)); +} +#else +GF31 OVERLOAD cmul(GF31 a, GF31 b) { + u32 negbx = neg(b.x); // Negate and add b values as much as possible in case b is used several times (as in a chainmul) + u64 k1 = b.x * (u64)(a.x + a.y); // 63-bit value, max = 7FFF FFFE 0000 0002 + u64 k1k2 = mad32(a.x, b.y + negbx, k1); // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + u64 k1k3 = mad32(a.y, neg(b.y) + negbx, k1); // unsigned 64-bit value, max = FFFF FFFC 0000 0004 + return U2(modM31(k1k3), modM31(k1k2)); +} +#endif + +// Square a root of unity complex number +GF31 OVERLOAD csqTrig(GF31 a) { u32 two_ay = a.y + a.y; return U2(modM31(mad32(two_ay, neg(a.y), (u32)1)), modM31(a.x * (u64)two_ay)); } + +// Cube w, a root of unity complex number, given w^2 and w +GF31 OVERLOAD ccubeTrig(GF31 sq, GF31 w) { u32 tmp = sq.y + sq.y; return U2(modM31(mad32(tmp, neg(w.y), w.x)), modM31(mad32(tmp, w.x, neg(w.y)))); } + +// mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). +GF31 OVERLOAD mul_t4(GF31 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (2^15, 2^15). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^15)^2 == 1 (mod M31). +GF31 OVERLOAD mul_t8(GF31 a) { return U2(shl(sub(a.x, a.y), 15), shl(add(a.x, a.y), 15)); } + +// mul with (-2^15, 2^15). (twiddle of 3*tau/8). +GF31 OVERLOAD mul_3t8(GF31 a) { return U2(shl(neg(add(a.x, a.y)), 15), shl(sub(a.x, a.y), 15)); } + +// Return a+b and a-b +void OVERLOAD X2_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); *b = sub(t, *b); } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; a->x = add(a->x, b->x); a->y = sub(a->y, b->y); b->x = sub(t.x, b->x); b->y = add(t.y, b->y); } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } + +// Same as X2(a, b), b = mul_t8(b) +void OVERLOAD X2_mul_t8_internal(GF31 *a, GF31 *b) { X2(*a, *b); *b = mul_t8(*b); } + +// Same as X2(a, b), b = mul_3t8(b) +void OVERLOAD X2_mul_3t8_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); t = sub(*b, t); *b = shl(U2(add(t.x, t.y), sub(t.y, t.x)), 15); } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } + +void OVERLOAD SWAP_internal(GF31 *a, GF31 *b) { GF31 t = *a; *a = *b; *b = t; } + +GF31 OVERLOAD addsub(GF31 a) { return U2(add(a.x, a.y), sub(a.x, a.y)); } +GF31 OVERLOAD foo2(GF31 a, GF31 b) { a = addsub(a); b = addsub(b); return addsub(U2(mul(RE(a), RE(b)), mul(IM(a), IM(b)))); } +GF31 OVERLOAD foo(GF31 a) { return foo2(a, a); } + +#endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// bits in reduced mod M. +#define M61 ((((Z61) 1) << 61) - 1) + +Z61 OVERLOAD make_Z61(i32 a) { return (Z61) (a < 0 ? 
(i64) a + M61 : (i64) a); } // Handles all values of a +Z61 OVERLOAD make_Z61(i64 a) { return (Z61) (a < 0 ? a + M61 : a); } // a must be in range of -M61 .. M61-1 +Z61 OVERLOAD make_Z61(u32 a) { return (Z61) (a); } // Handles all values of a +Z61 OVERLOAD make_Z61(u64 a) { return (Z61) (a); } // a must be in range of 0 .. M61-1 + +#if 0 // Slower version that keeps results strictly in the range 0 .. M61-1 + +u64 OVERLOAD get_Z61(Z61 a) { return a; } // Get value in range 0 to M61-1 +i64 OVERLOAD get_balanced_Z61(Z61 a) { return (hi32(a) & 0xF0000000) ? (i64) a - (i64) M61 : (i64) a; } // Get balanced value in range -M61/2 to M61/2 + +Z61 OVERLOAD neg(Z61 a) { return a == 0 ? 0 : M61 - a; } // GWBUG: Examine all callers to see if neg call can be avoided +GF61 OVERLOAD neg(GF61 a) { return U2(neg(a.x), neg(a.y)); } + +Z61 OVERLOAD add(Z61 a, Z61 b) { Z61 t = a + b; Z61 m = t - M61; return (m & 0x8000000000000000ULL) ? t : m; } +//Z61 OVERLOAD add(Z61 a, Z61 b) { Z61 t = a + b; Z61 m = t - M61; return t < m ? t : m; } // Slower on TitanV +//Z61 OVERLOAD add(Z61 a, Z61 b) { Z61 t = a + b; return t - (t >= M61 ? M61 : 0); } // Slower on TitanV +GF61 OVERLOAD add(GF61 a, GF61 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } + +Z61 OVERLOAD sub(Z61 a, Z61 b) { Z61 t = a - b; return t + (((i64) t >> 63) & 0x1FFFFFFFFFFFFFFFULL); } +//Z61 OVERLOAD sub(Z61 a, Z61 b) { Z61 t = a - b; Z61 p = t + M61; return (t & 0x8000000000000000ULL) ? p : t; } // Better??? +//Z61 OVERLOAD sub(Z61 a, Z61 b) { Z61 t = a - b; return t + (t >= M61 ? M61 : 0); } // Slower on TitanV +// BETTER???: t = a - b; carry_mask = sbb x, x; (generates 32 bits of 0 or 1; return t + make_carry_mask_64_bits +GF61 OVERLOAD sub(GF61 a, GF61 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } + +// Assumes k reduced mod 61. +Z61 OVERLOAD shl(Z61 a, u32 k) { return ((a << k) + (a >> (61 - k))) & M61; } //GWBUG: Make sure & M61 operates on just one u32 +GF61 OVERLOAD shl(GF61 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } +Z61 OVERLOAD shr(Z61 a, u32 k) { return ((a >> k) + (a << (61 - k))) & M61; } //GWBUG: Make sure & M61 operates on just one u32. & M61 not needed? +GF61 OVERLOAD shr(GF61 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } + +ulong2 wideMul(u64 ab, u64 cd) { + u128 r = mul64(ab, cd); + return U2(u128_lo64(r), u128_hi64(r)); +} + +Z61 OVERLOAD mul(Z61 a, Z61 b) { + ulong2 ab = wideMul(a, b); + u64 lo = ab.x, hi = ab.y; + u64 lo61 = lo & M61, hi61 = (hi << 3) + (lo >> 61); + return add(lo61, hi61); +} + +Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return add(mul(a, b), c); } // GWBUG: Can we do better? + +// Multiply by 2 +Z61 OVERLOAD mul2(Z61 a) { return ((a + a) + (a >> 60)) & M61; } // GWBUG: Make sure "+ a>>60" does an add to lower u32 without a followup adc. +GF61 OVERLOAD mul2(GF61 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Return conjugate of a +GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } + +// Complex square. input, output 61 bits. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). +GF61 OVERLOAD csq(GF61 a) { return U2(mul(add(a.x, a.y), sub(a.x, a.y)), mul2(mul(a.x, a.y))); } //GWBUG: Probably faster to double a.y and have a mul that takes non-normalized inputs + +// a^2 + c +GF61 OVERLOAD csqa(GF61 a, GF61 c) { return add(csq(a), c); } // GWBUG: inline csq so we only "mod" after adding c?? 
Find a way to use fma instructions + +// Complex mul +//GF61 OVERLOAD cmul(GF61 a, GF61 b) { return U2(sub(mul(a.x, b.x), mul(a.y, b.y)), add(mul(a.x, b.y), mul(a.y, b.x)));} // GWBUG: Is a 3 multiply complex mul faster? See above +GF61 OVERLOAD cmul(GF61 a, GF61 b) { + Z61 k1 = mul(b.x, add(a.x, a.y)); + Z61 k2 = mul(a.x, sub(b.y, b.x)); + Z61 k3 = mul(a.y, add(b.y, b.x)); + return U2(sub(k1, k3), add(k1, k2)); +} + +// mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). +GF61 OVERLOAD mul_t4(GF61 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (-2^30, -2^30). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^30)^2 == 1 (mod M61). +GF61 OVERLOAD mul_t8(GF61 a) { return shl(U2(sub(a.y, a.x), neg(add(a.x, a.y))), 30); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (2^30, -2^30). (twiddle of 3*tau/8). +GF61 OVERLOAD mul_3t8(GF61 a) { return shl(U2(add(a.x, a.y), sub(a.y, a.x)), 30); } + +// Return a+b and a-b +void OVERLOAD X2_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = sub(t, *b); } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = sub(a->y, b->y); b->x = sub(t.x, b->x); b->y = add(t.y, b->y); } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } + +// Same as X2(a, b), b = mul_t8(b) +void OVERLOAD X2_mul_t8_internal(GF61 *a, GF61 *b) { X2(*a, *b); *b = mul_t8(*b); } + +// Same as X2(a, b), b = mul_3t8(b) +void OVERLOAD X2_mul_3t8_internal(GF61 *a, GF61 *b) { X2(*a, *b); *b = mul_3t8(*b); } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } + +void OVERLOAD SWAP_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = *b; *b = t; } + +GF61 OVERLOAD addsub(GF61 a) { return U2(add(a.x, a.y), sub(a.x, a.y)); } +GF61 OVERLOAD foo2(GF61 a, GF61 b) { a = addsub(a); b = addsub(b); return addsub(U2(mul(RE(a), RE(b)), mul(IM(a), IM(b)))); } +GF61 OVERLOAD foo(GF61 a) { return foo2(a, a); } + +// The following routines can be used to reduce mod M61 operations (in the other Z61 implementations). +// Caller must track how many M61s need to be added to make positive values for subtractions. +// In function names, "q" stands for quick, "s" stands for slow (i.e. does mod). +// These functions are untested with this strict Z61 implementation. Callers need to eliminate all uses of + or - operators. 
+ +Z61 OVERLOAD modM61(Z61 a) { return a; } +GF61 OVERLOAD modM61(GF61 a) { return a; } +Z61 OVERLOAD neg(Z61 a, u32 m61_count) { return neg(a); } +GF61 OVERLOAD neg(GF61 a, u32 m61_count) { return neg(a); } +Z61 OVERLOAD addq(Z61 a, Z61 b) { return add(a, b); } +GF61 OVERLOAD addq(GF61 a, GF61 b) { return add(a, b); } +Z61 OVERLOAD subq(Z61 a, Z61 b, u32 m61_count) { return sub(a, b); } +GF61 OVERLOAD subq(GF61 a, GF61 b, u32 m61_count) { return sub(a, b); } +Z61 OVERLOAD subs(Z61 a, Z61 b, u32 m61_count) { return sub(a, b); } +GF61 OVERLOAD subs(GF61 a, GF61 b, u32 m61_count) { return sub(a, b); } +void OVERLOAD X2q(GF61 *a, GF61 *b, u32 m61_count) { X2_internal(a, b); } +void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, u32 m61_count) { X2_mul_t4_internal(a, b); } +void OVERLOAD X2s(GF61 *a, GF61 *b, u32 m61_count) { X2_internal(a, b); } +void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, u32 m61_count) { X2_conjb_internal(a, b); } + + + + +// Philosophy: This Z61/GF61 implementation uses faster, sloppier mod M61 reduction where the end result is in the range 0..M61+epsilon. +// This implementation also handles subtractions by adding enough M61s to make a value positive. This allows us to always deal with positive +// intermediate results. The downside is that a caller using the sloppy/quick routines must keep track of how large unreduced values can get. +// An alternative implementation is to have Z61 be an i64 (costs us a precious bit of precision) but is surprisingly slower (at least on TitanV) because +// mod(a - b), where the mod routinue uses a signed right shift is slower than +// mod(a + (M61*2 - b)) where the mod routine uses an unsigned shift right. +// However, a long string of subtracts (example, fft8 does 3 subtracts before mod M61 might be better off using negative intermediate results. +// The mul routine (and obviously csq and cmul) must use only positive values as __int128 multiply is very slow. + +#elif 1 // Faster version that keeps results in the range 0 .. M61+epsilon + +u64 OVERLOAD get_Z61(Z61 a) { Z61 m = a - M61; return (m & 0x8000000000000000ULL) ? a : m; } // Get value in range 0 to M61-1 +i64 OVERLOAD get_balanced_Z61(Z61 a) { return (a >= 0x1000000000000000ULL) ? (i64) a - (i64) M61 : (i64) a; } // Get balanced value in range -M61/2 to M61/2 + +// Internal routine to bring Z61 value into the range 0..M61+epsilon +Z61 OVERLOAD modM61(Z61 a) { return (a & M61) + (a >> 61); } +GF61 OVERLOAD modM61(GF61 a) { return U2(modM61(a.x), modM61(a.y)); } +// Internal routine to negate a value by adding the specified number of M61s -- no mod M61 reduction +Z61 OVERLOAD neg(Z61 a, const u32 m61_count) { return m61_count * M61 - a; } +GF61 OVERLOAD neg(GF61 a, const u32 m61_count) { return U2(neg(a.x, m61_count), neg(a.y, m61_count)); } + +Z61 OVERLOAD add(Z61 a, Z61 b) { return modM61(a + b); } +GF61 OVERLOAD add(GF61 a, GF61 b) { return U2(add(a.x, b.x), add(a.y, b.y)); } + +Z61 OVERLOAD sub(Z61 a, Z61 b) { return modM61(a + neg(b, 2)); } +GF61 OVERLOAD sub(GF61 a, GF61 b) { return U2(sub(a.x, b.x), sub(a.y, b.y)); } + +Z61 OVERLOAD neg(Z61 a) { return modM61(neg(a, 2)); } // GWBUG: Examine all callers to see if neg call can be avoided +GF61 OVERLOAD neg(GF61 a) { return U2(neg(a.x), neg(a.y)); } + +// Assumes k reduced mod 61. 
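// Why shifts implement multiplication and division by powers of two (a sketch): 2^61 == 1 (mod M61),
// so writing a = hi*2^(61-k) + lo with lo < 2^(61-k) gives
//   a * 2^k = hi*2^61 + lo*2^k == hi + lo*2^k (mod M61),
// i.e. the k bits shifted out the top re-enter at the bottom (a 61-bit rotate). Likewise
// 2^-k == 2^(61-k) (mod M61), so shl and shr are the same rotate in opposite directions.
// The versions below trade exact reduction for speed and leave the small epsilon noted in their comments.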
+Z61 OVERLOAD shr(Z61 a, u32 k) { return (a >> k) + ((a << (61 - k)) & M61); } // Return range 0..M61+2^(61-k), can handle 64-bit inputs but small k is big epsilon +GF61 OVERLOAD shr(GF61 a, u32 k) { return U2(shr(a.x, k), shr(a.y, k)); } +Z61 OVERLOAD shl(Z61 a, u32 k) { return shr(a, 61 - k); } // Return range 0..M61+2^k, can handle 64-bit inputs but large k yields big epsilon +//Z61 OVERLOAD shl(Z61 a, u32 k) { return modM61(a << k) + ((a >> (64 - k)) << 3); } // Return range 0..M61+2^k, can handle 64-bit inputs but large k is big epsilon +//Z61 OVERLOAD shl(Z61 a, u32 k) { return modM61((a << k) + ((a >> (64 - k)) << 3)); } // Return range 0..M61+epsilon, input must be M61+epsilon a full 62-bit value can overflow +GF61 OVERLOAD shl(GF61 a, u32 k) { return U2(shl(a.x, k), shl(a.y, k)); } + +ulong2 wideMul(u64 ab, u64 cd) { + u128 r = mul64(ab, cd); + return U2(u128_lo64(r), u128_hi64(r)); +} + +// Returns a * b not modded by M61. Max value of result depends on the m61_counts of the inputs. +// Let n = (a_m61_count - 1) * (b_m61_count - 1). This is the maximum value in the highest 6 bits of a * b. +// If n <= 4 result will be at most (n+1)*M61+epsilon. +// If n > 4 result will be at most 2*M61+epsilon. +Z61 OVERLOAD weakMul(Z61 a, Z61 b, const u32 a_m61_count, const u32 b_m61_count) { + ulong2 ab = wideMul(a, b); + u64 lo = ab.x, hi = ab.y; + u64 lo61 = lo & M61; // Max value is M61 + if ((a_m61_count - 1) * (b_m61_count - 1) <= 4) { + hi = (hi << 3) + (lo >> 61); // Max value is (a_m61_count - 1) * (b_m61_count - 1) * M61 + epsilon + return lo61 + hi; // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 1) * M61 + epsilon + } else { + u64 hi61 = ((hi << 3) + (lo >> 61)) & M61; // Max value is M61 + return lo61 + hi61 + (hi >> 58); // Max value is 2*M61 + epsilon + } +} + +Z61 OVERLOAD mul(Z61 a, Z61 b) { return modM61(weakMul(a, b, 2, 2)); } + +Z61 OVERLOAD fma(Z61 a, Z61 b, Z61 c) { return modM61(weakMul(a, b, 2, 2) + c); } // GWBUG: Can we do better? + +// Multiply by 2 +Z61 OVERLOAD mul2(Z61 a) { return add(a, a); } +GF61 OVERLOAD mul2(GF61 a) { return U2(mul2(a.x), mul2(a.y)); } + +// Return conjugate of a +GF61 OVERLOAD conjugate(GF61 a) { return U2(a.x, neg(a.y)); } + +// Complex square. Uses (a + i*b)^2 == ((a+b)*(a-b) + i*2*a*b). 
+GF61 OVERLOAD csqq(GF61 a, const u32 m61_count) { + if (m61_count > 4) return csqq(modM61(a), 2); + Z61 re = weakMul(a.x + a.y, a.x + neg(a.y, m61_count), 2 * m61_count - 1, 2 * m61_count); + Z61 im = weakMul(a.x + a.x, a.y, 2 * m61_count - 1, m61_count); + return U2(re, im); +} +GF61 OVERLOAD csqs(GF61 a, const u32 m61_count) { return modM61(csqq(a, m61_count)); } +GF61 OVERLOAD csq(GF61 a) { return csqs(a, 2); } + +// a^2 + c +GF61 OVERLOAD csqa(GF61 a, GF61 c) { return U2(modM61(weakMul(a.x + a.y, a.x + neg(a.y, 2), 3, 4) + c.x), modM61(weakMul(a.x + a.x, a.y, 3, 2) + c.y)); } + +// Complex mul +#if 0 +GF61 OVERLOAD cmul(GF61 a, GF61 b) { // Use 3-epsilon extra bits in u64 + Z61 k1 = weakMul(b.x, a.x + a.y, 2, 3); // max value is 3*M61+epsilon + Z61 k2 = weakMul(a.x, b.y + neg(b.x, 2), 2, 3); // max value is 3*M61+epsilon + Z61 k3 = weakMul(a.y, b.y + b.x, 2, 3); // max value is 3*M61+epsilon + return U2(modM61(k1 + neg(k3, 4)), modM61(k1 + k2)); +} +#else +Z61 OVERLOAD weakMulAdd(Z61 a, Z61 b, u128 c, const u32 a_m61_count, const u32 b_m61_count) { // Max c value assumed to be 2*M61^2+epsilon + u128 ab = mad64(a, b, c); // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 2) * M61^2 + epsilon + u64 lo = u128_lo64(ab), hi = u128_hi64(ab); + u64 lo61 = lo & M61; // Max value is M61 + if ((a_m61_count - 1) * (b_m61_count - 1) + 2 <= 6) { + hi = (hi << 3) + (lo >> 61); // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 2) * M61 + epsilon + return lo61 + hi; // Max value is ((a_m61_count - 1) * (b_m61_count - 1) + 3) * M61 + epsilon + } else { + u64 hi61 = ((hi << 3) + (lo >> 61)) & M61; // Max value is M61 + return lo61 + hi61 + (hi >> 58); // Max value is 2*M61 + epsilon + } +} +GF61 OVERLOAD cmul(GF61 a, GF61 b) { + u128 k1 = mul64(b.x, a.x + a.y); // max value is 2*M61^2+epsilon + Z61 k1k2 = weakMulAdd(a.x, b.y + neg(b.x, 2), k1, 2, 4); // max value is 6*M61+epsilon + Z61 k1k3 = weakMulAdd(a.y, neg(b.y + b.x, 3), k1, 2, 4); // max value is 6*M61+epsilon + return U2(modM61(k1k3), modM61(k1k2)); +} +#endif + +// Square a root of unity complex number (the second version may be faster if the compiler optimizes the u128 squaring). +//GF61 OVERLOAD csqTrig(GF61 a) { Z61 two_ay = a.y + a.y; return U2(modM61(1 + weakMul(two_ay, neg(a.y, 2))), mul(a.x, two_ay)); } +GF61 OVERLOAD csqTrig(GF61 a) { Z61 ay_sq = weakMul(a.y, a.y, 2, 2); return U2(modM61(1 + neg(ay_sq + ay_sq, 4)), mul2(weakMul(a.x, a.y, 2, 2))); } + +// Cube w, a root of unity complex number, given w^2 and w +GF61 OVERLOAD ccubeTrig(GF61 sq, GF61 w) { Z61 tmp = sq.y + sq.y; return U2(modM61(weakMul(tmp, neg(w.y, 2), 3, 3) + w.x), modM61(weakMul(tmp, w.x, 3, 2) + neg(w.y, 2))); } + +// a + i*b +GF61 OVERLOAD addi(GF61 a, GF61 b) { return U2(sub(a.x, b.y), add(a.y, b.x)); } +// a - i*b +GF61 OVERLOAD subi(GF61 a, GF61 b) { return U2(add(a.x, b.y), sub(a.y, b.x)); } + +// mul with (0, 1). (twiddle of tau/4, sqrt(-1) aka "i"). +GF61 OVERLOAD mul_t4(GF61 a) { return U2(neg(a.y), a.x); } // GWBUG: Can caller use a version that does not negate real? + +// mul with (-2^30, -2^30). (twiddle of tau/8 aka sqrt(i)). Note: 2 * (+/-2^30)^2 == 1 (mod M61). +GF61 OVERLOAD mul_t8(GF61 a, const u32 m61_count) { return shl(U2(a.y + neg(a.x, m61_count), neg(a.x + a.y, 2 * m61_count - 1)), 30); } +GF61 OVERLOAD mul_t8(GF61 a) { return mul_t8(a, 2); } + +// mul with (2^30, -2^30). (twiddle of 3*tau/8). 
+GF61 OVERLOAD mul_3t8(GF61 a, const u32 m61_count) { return shl(U2(a.x + a.y, a.y + neg(a.x, m61_count)), 30); } +GF61 OVERLOAD mul_3t8(GF61 a) { return mul_3t8(a, 2); } + +// Return a+b and a-b +void OVERLOAD X2_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = sub(t, *b); } + +// Same as X2(a, conjugate(b)) +void OVERLOAD X2conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; a->x = add(a->x, b->x); a->y = sub(a->y, b->y); b->x = sub(t.x, b->x); b->y = add(t.y, b->y); } + +// Same as X2(a, b), b = mul_t4(b) +void OVERLOAD X2_mul_t4_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(*a, *b); t.x = sub(t.x, b->x); b->x = sub(b->y, t.y); b->y = t.x; } + +// Same as X2(a, b), b = mul_t8(b) +void OVERLOAD X2_mul_t8_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); t = *b + neg(t, 2); *b = shl(U2(t.x + neg(t.y, 4), t.x + t.y), 30); } + +// Same as X2(a, b), b = mul_3t8(b) +void OVERLOAD X2_mul_3t8_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); *b = t + neg(*b, 2); *b = mul_3t8(*b, 4); } + +// Same as X2(a, b), b = conjugate(b) +void OVERLOAD X2_conjb_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = add(t, *b); b->x = sub(t.x, b->x); b->y = sub(b->y, t.y); } + +void OVERLOAD SWAP_internal(GF61 *a, GF61 *b) { GF61 t = *a; *a = *b; *b = t; } + +GF61 OVERLOAD addsub(GF61 a) { return U2(add(a.x, a.y), sub(a.x, a.y)); } +GF61 OVERLOAD foo2(GF61 a, GF61 b) { a = addsub(a); b = addsub(b); return addsub(U2(mul(RE(a), RE(b)), mul(IM(a), IM(b)))); } +GF61 OVERLOAD foo(GF61 a) { return foo2(a, a); } + +// The following routines can be used to reduce mod M61 operations. Caller must track how many M61s need to be added to make positive +// values for subtractions. In function names, "q" stands for quick (no modM61), "s" stands for slow (i.e. does modM61). 
+ +Z61 OVERLOAD addq(Z61 a, Z61 b) { return a + b; } +GF61 OVERLOAD addq(GF61 a, GF61 b) { return U2(addq(a.x, b.x), addq(a.y, b.y)); } + +Z61 OVERLOAD subq(Z61 a, Z61 b, const u32 m61_count) { return a + neg(b, m61_count); } +GF61 OVERLOAD subq(GF61 a, GF61 b, const u32 m61_count) { return U2(subq(a.x, b.x, m61_count), subq(a.y, b.y, m61_count)); } + +Z61 OVERLOAD subs(Z61 a, Z61 b, const u32 m61_count) { return modM61(a + neg(b, m61_count)); } +GF61 OVERLOAD subs(GF61 a, GF61 b, const u32 m61_count) { return U2(subs(a.x, b.x, m61_count), subs(a.y, b.y, m61_count)); } + +GF61 OVERLOAD addiq(GF61 a, GF61 b, const u32 m61_count) { return U2(subq(a.x, b.y, m61_count), addq(a.y, b.x)); } +GF61 OVERLOAD subiq(GF61 a, GF61 b, const u32 m61_count) { return U2(addq(a.x, b.y), subq(a.y, b.x, m61_count)); } + +void OVERLOAD X2q(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; *b = t + neg(*b, m61_count); } +void OVERLOAD X2q_mul_t4(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t.x = t.x + neg(b->x, m61_count); b->x = b->y + neg(t.y, m61_count); b->y = t.x; } +void OVERLOAD X2q_mul_t8(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t = *b + neg(t, m61_count); *b = shl(U2(t.x + neg(t.y, m61_count * 2), t.x + t.y), 30); } +void OVERLOAD X2q_mul_3t8(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = t + *b; t = t + neg(*b, m61_count); *b = mul_3t8(t, m61_count * 2); } + +void OVERLOAD X2s(GF61 *a, GF61 *b, const u32 m61_count) { GF61 t = *a; *a = add(t, *b); *b = subs(t, *b, m61_count); } +void OVERLOAD X2s_conjb(GF61 *a, GF61 *b, const u32 a_m61_count, const u32 b_m61_count) { GF61 t = *a; *a = add(t, *b); b->x = subs(t.x, b->x, b_m61_count); b->y = subs(b->y, t.y, a_m61_count); } + +#endif + +#endif diff --git a/src/cl/middle.cl b/src/cl/middle.cl index f1a40b75..263937f3 100644 --- a/src/cl/middle.cl +++ b/src/cl/middle.cl @@ -34,6 +34,10 @@ #define MIDDLE_OUT_LDS_TRANSPOSE 1 #endif +#if !INPLACE // Original implementation (not in place) + +#if FFT_FP64 || NTT_GF61 + //**************************************************************************************** // Pair of routines to write data from carryFused and read data into fftMiddleIn //**************************************************************************************** @@ -50,7 +54,7 @@ // u[i] i ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) // y ranges 0...SMALL_HEIGHT-1 (multiples of one) -void writeCarryFusedLine(T2 *u, P(T2) out, u32 line) { +void OVERLOAD writeCarryFusedLine(T2 *u, P(T2) out, u32 line) { #if PAD_SIZE > 0 u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; out += line * WIDTH + line * PAD_SIZE + line / SMALL_HEIGHT * BIG_PAD_SIZE + (u32) get_local_id(0); // One pad every line + a big pad every SMALL_HEIGHT lines @@ -61,7 +65,7 @@ void writeCarryFusedLine(T2 *u, P(T2) out, u32 line) { #endif } -void readMiddleInLine(T2 *u, CP(T2) in, u32 y, u32 x) { +void OVERLOAD readMiddleInLine(T2 *u, CP(T2) in, u32 y, u32 x) { #if PAD_SIZE > 0 // Each work group reads successive y's which increments by one pad size. 
// Rather than having u[i] also increment by one, we choose a larger pad increment @@ -86,7 +90,7 @@ void readMiddleInLine(T2 *u, CP(T2) in, u32 y, u32 x) { // x ranges 0...SMALL_HEIGHT-1 (multiples of one) (also known as 0...G_H-1 and 0...NH-1) // y ranges 0...MIDDLE*WIDTH-1 (multiples of SMALL_HEIGHT) -void writeMiddleInLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) +void OVERLOAD writeMiddleInLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) { //u32 SIZEY = IN_WG / IN_SIZEX; //u32 num_x_chunks = WIDTH / IN_SIZEX; // Number of x chunks @@ -121,7 +125,7 @@ void writeMiddleInLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) // Read a line for tailFused or fftHin // This reads partially transposed data as written by fftMiddleIn -void readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { +void OVERLOAD readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { u32 SIZEY = IN_WG / IN_SIZEX; #if PAD_SIZE > 0 @@ -145,8 +149,8 @@ void readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { u32 fftMiddleIn_y_incr = G_H; // The increment to next fftMiddleIn y value u32 chunk_y_incr = fftMiddleIn_y_incr / SIZEY; // The increment to next fftMiddleIn chunk_y value for (i32 i = 0; i < NH; ++i) { -// u32 fftMiddleIn_y = i * G_H + me; // The fftMiddleIn y value -// u32 chunk_y = fftMiddleIn_y / SIZEY; // The fftMiddleIn chunk_y value +// u32 fftMiddleIn_y = i * G_H + me; // The fftMiddleIn y value +// u32 chunk_y = fftMiddleIn_y / SIZEY; // The fftMiddleIn chunk_y value u[i] = NTLOAD(in[chunk_y * (MIDDLE * IN_WG + PAD_SIZE)]); // Adjust in pointer the same way writeMiddleInLine did chunk_y += chunk_y_incr; } @@ -196,7 +200,7 @@ void readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { // i in u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) // y ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) (processed in batches of OUT_WG/OUT_SIZEX) -void writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) { +void OVERLOAD writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) { #if PAD_SIZE > 0 #if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16 u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; @@ -211,7 +215,7 @@ void writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) { #endif } -void readMiddleOutLine(T2 *u, CP(T2) in, u32 y, u32 x) { +void OVERLOAD readMiddleOutLine(T2 *u, CP(T2) in, u32 y, u32 x) { #if PAD_SIZE > 0 #if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16 // Each u[i] increments by one pad size. @@ -274,7 +278,7 @@ void readMiddleOutLine(T2 *u, CP(T2) in, u32 y, u32 x) { // adjusted to effect a transpose. Or caller must transpose the x and y values and send us an out pointer with thread_id added in. // In other words, caller is responsible for deciding the best way to transpose x and y values. -void writeMiddleOutLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) +void OVERLOAD writeMiddleOutLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) { //u32 SIZEY = OUT_WG / OUT_SIZEX; //u32 num_x_chunks = SMALL_HEIGHT / OUT_SIZEX; // Number of x chunks @@ -307,7 +311,7 @@ void writeMiddleOutLine (P(T2) out, T2 *u, u32 chunk_y, u32 chunk_x) } // Read a line for carryFused or FFTW. This line was written by writeMiddleOutLine above. 
-void readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { +void OVERLOAD readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { u32 me = get_local_id(0); u32 SIZEY = OUT_WG / OUT_SIZEX; @@ -332,8 +336,8 @@ void readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { u32 fftMiddleOut_y_incr = G_W; // The increment to next fftMiddleOut y value u32 chunk_y_incr = fftMiddleOut_y_incr / SIZEY; // The increment to next fftMiddleOut chunk_y value for (i32 i = 0; i < NW; ++i) { -// u32 fftMiddleOut_y = i * G_W + me; // The fftMiddleOut y value -// u32 chunk_y = fftMiddleOut_y / SIZEY; // The fftMiddleOut chunk_y value +// u32 fftMiddleOut_y = i * G_W + me; // The fftMiddleOut y value +// u32 chunk_y = fftMiddleOut_y / SIZEY; // The fftMiddleOut chunk_y value u[i] = NTLOAD(in[chunk_y * (MIDDLE * OUT_WG + PAD_SIZE)]); // Adjust in pointer the same way writeMiddleOutLine did chunk_y += chunk_y_incr; } @@ -367,3 +371,700 @@ void readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { #endif } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 || NTT_GF31 + +void OVERLOAD writeCarryFusedLine(F2 *u, P(F2) out, u32 line) { +#if PAD_SIZE > 0 + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + out += line * WIDTH + line * PAD_SIZE + line / SMALL_HEIGHT * BIG_PAD_SIZE + (u32) get_local_id(0); // One pad every line + a big pad every SMALL_HEIGHT lines + for (u32 i = 0; i < NW; ++i) { NTSTORE(out[i * G_W], u[i]); } +#else + out += line * WIDTH + (u32) get_local_id(0); + for (u32 i = 0; i < NW; ++i) { NTSTORE(out[i * G_W], u[i]); } +#endif +} + +void OVERLOAD readMiddleInLine(F2 *u, CP(F2) in, u32 y, u32 x) { +#if PAD_SIZE > 0 + // Each work group reads successive y's which increments by one pad size. + // Rather than having u[i] also increment by one, we choose a larger pad increment + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + in += y * WIDTH + y * PAD_SIZE + (y / SMALL_HEIGHT) * BIG_PAD_SIZE + x; + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * (SMALL_HEIGHT * (WIDTH + PAD_SIZE) + BIG_PAD_SIZE)]); } +#else + in += y * WIDTH + x; + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SMALL_HEIGHT * WIDTH]); } +#endif +} + +void OVERLOAD writeMiddleInLine (P(F2) out, F2 *u, u32 chunk_y, u32 chunk_x) +{ +#if PAD_SIZE > 0 + u32 SIZEY = IN_WG / IN_SIZEX; + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + + out += chunk_y * (MIDDLE * IN_WG + PAD_SIZE) + // Write y chunks after middle chunks and a pad + chunk_x * (SMALL_HEIGHT * MIDDLE * IN_SIZEX + // num_y_chunks * (MIDDLE * IN_WG + PAD_SIZE) + SMALL_HEIGHT / SIZEY * PAD_SIZE + BIG_PAD_SIZE); + // = SMALL_HEIGHT / SIZEY * (MIDDLE * IN_WG + PAD_SIZE) + // = SMALL_HEIGHT / (IN_WG / IN_SIZEX) * (MIDDLE * IN_WG + PAD_SIZE) + // = SMALL_HEIGHT * MIDDLE * IN_SIZEX + SMALL_HEIGHT / SIZEY * PAD_SIZE + // Write each u[i] sequentially + for (int i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * IN_WG], u[i]); } +#else + // Output data such that readCarryFused lines are packed tightly together. No padding. 
+ out += chunk_y * MIDDLE * IN_WG + // Write y chunks after middles + chunk_x * MIDDLE * SMALL_HEIGHT * IN_SIZEX; // num_y_chunks * IN_WG = SMALL_HEIGHT / SIZEY * MIDDLE * IN_WG + // = MIDDLE * SMALL_HEIGHT / (IN_WG / IN_SIZEX) * IN_WG + // = MIDDLE * SMALL_HEIGHT * IN_SIZEX + // Write each u[i] sequentially + for (int i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * IN_WG], u[i]); } +#endif +} + +// Read a line for tailFused or fftHin +// This reads partially transposed data as written by fftMiddleIn +void OVERLOAD readTailFusedLine(CP(F2) in, F2 *u, u32 line, u32 me) { + u32 SIZEY = IN_WG / IN_SIZEX; +#if PAD_SIZE > 0 + // Adjust in pointer based on the x value used in writeMiddleInLine + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + u32 fftMiddleIn_x = line % WIDTH; // The fftMiddleIn x value + u32 chunk_x = fftMiddleIn_x / IN_SIZEX; // The fftMiddleIn chunk_x value + in += chunk_x * (SMALL_HEIGHT * MIDDLE * IN_SIZEX + SMALL_HEIGHT / SIZEY * PAD_SIZE + BIG_PAD_SIZE); // Adjust in pointer the same way writeMiddleInLine did + u32 x_within_in_wg = fftMiddleIn_x % IN_SIZEX; // There were IN_SIZEX x values within IN_WG + in += x_within_in_wg * SIZEY; // Adjust in pointer the same way writeMiddleInLine wrote x values within IN_WG + + // Adjust in pointer based on the i value used in writeMiddleInLine + u32 fftMiddleIn_i = line / WIDTH; // The i in fftMiddleIn's u[i] + in += fftMiddleIn_i * IN_WG; // Adjust in pointer the same way writeMiddleInLine did + // Adjust in pointer based on the y value used in writeMiddleInLine. This code is a little obscure as rocm compiler has trouble optimizing commented out code. + in += me % SIZEY; // Adjust in pointer to read SIZEY consecutive values + u32 fftMiddleIn_y = me; // The i=0 fftMiddleIn y value + u32 chunk_y = fftMiddleIn_y / SIZEY; // The i=0 fftMiddleIn chunk_y value + u32 fftMiddleIn_y_incr = G_H; // The increment to next fftMiddleIn y value + u32 chunk_y_incr = fftMiddleIn_y_incr / SIZEY; // The increment to next fftMiddleIn chunk_y value + for (i32 i = 0; i < NH; ++i) { + u[i] = NTLOAD(in[chunk_y * (MIDDLE * IN_WG + PAD_SIZE)]); // Adjust in pointer the same way writeMiddleInLine did + chunk_y += chunk_y_incr; + } +#else // Read data that was not rotated or padded + // Adjust in pointer based on the x value used in writeMiddleInLine + u32 fftMiddleIn_x = line % WIDTH; // The fftMiddleIn x value + u32 chunk_x = fftMiddleIn_x / IN_SIZEX; // The fftMiddleIn chunk_x value + in += chunk_x * (SMALL_HEIGHT * MIDDLE * IN_SIZEX); // Adjust in pointer the same way writeMiddleInLine did + u32 x_within_in_wg = fftMiddleIn_x % IN_SIZEX; // There were IN_SIZEX x values within IN_WG + in += x_within_in_wg * SIZEY; // Adjust in pointer the same way writeMiddleInLine wrote x values within IN_WG + // Adjust in pointer based on the i value used in writeMiddleInLine + u32 fftMiddleIn_i = line / WIDTH; // The i in fftMiddleIn's u[i] + in += fftMiddleIn_i * IN_WG; // Adjust in pointer the same way writeMiddleInLine did + // Adjust in pointer based on the y value used in writeMiddleInLine. This code is a little obscure as rocm compiler has trouble optimizing commented out code. 
+ in += me % SIZEY; // Adjust in pointer to read SIZEY consecutive values + u32 fftMiddleIn_y = me; // The i=0 fftMiddleIn y value + u32 chunk_y = fftMiddleIn_y / SIZEY; // The i=0 fftMiddleIn chunk_y value + u32 fftMiddleIn_y_incr = G_H; // The increment to next fftMiddleIn y value + u32 chunk_y_incr = fftMiddleIn_y_incr / SIZEY; // The increment to next fftMiddleIn chunk_y value + for (i32 i = 0; i < NH; ++i) { + u32 fftMiddleIn_y = i * G_H + me; // The fftMiddleIn y value + u32 chunk_y = fftMiddleIn_y / SIZEY; // The fftMiddleIn chunk_y value + u[i] = NTLOAD(in[chunk_y * (MIDDLE * IN_WG)]); // Adjust in pointer the same way writeMiddleInLine did + chunk_y += chunk_y_incr; + } +#endif +} + +void OVERLOAD writeTailFusedLine(F2 *u, P(F2) out, u32 line, u32 me) { +#if PAD_SIZE > 0 +#if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16 + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + out += line * (SMALL_HEIGHT + PAD_SIZE) + line / MIDDLE * BIG_PAD_SIZE + me; // Pad every output line plus every MIDDLE +#else + out += line * (SMALL_HEIGHT + PAD_SIZE) + me; // Pad every output line +#endif + for (u32 i = 0; i < NH; ++i) { NTSTORE(out[i * G_H], u[i]); } +#else // No padding + out += line * SMALL_HEIGHT + me; + for (u32 i = 0; i < NH; ++i) { NTSTORE(out[i * G_H], u[i]); } +#endif +} + +void OVERLOAD readMiddleOutLine(F2 *u, CP(F2) in, u32 y, u32 x) { +#if PAD_SIZE > 0 +#if MIDDLE == 4 || MIDDLE == 8 || MIDDLE == 16 + // Each u[i] increments by one pad size. + // Rather than each work group reading successive y's also increment by one, we choose a larger pad increment. + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + in += y * MIDDLE * (SMALL_HEIGHT + PAD_SIZE) + y * BIG_PAD_SIZE + x; +#else + in += y * MIDDLE * (SMALL_HEIGHT + PAD_SIZE) + x; +#endif + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * (SMALL_HEIGHT + PAD_SIZE)]); } +#else // No rotation, might be better on nVidia cards + in += y * MIDDLE * SMALL_HEIGHT + x; + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SMALL_HEIGHT]); } +#endif +} + +void OVERLOAD writeMiddleOutLine (P(F2) out, F2 *u, u32 chunk_y, u32 chunk_x) +{ +#if PAD_SIZE > 0 + u32 SIZEY = OUT_WG / OUT_SIZEX; + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + + out += chunk_y * (MIDDLE * OUT_WG + PAD_SIZE) + // Write y chunks after middle chunks and a pad + chunk_x * (WIDTH * MIDDLE * OUT_SIZEX + // num_y_chunks * (MIDDLE * OUT_WG + PAD_SIZE) + WIDTH / SIZEY * PAD_SIZE + BIG_PAD_SIZE);// = WIDTH / SIZEY * (MIDDLE * OUT_WG + PAD_SIZE) + // = WIDTH / (OUT_WG / OUT_SIZEX) * (MIDDLE * OUT_WG + PAD_SIZE) + // = WIDTH * MIDDLE * OUT_SIZEX + WIDTH / SIZEY * PAD_SIZE + // Write each u[i] sequentially + for (int i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * OUT_WG], u[i]); } +#else + // Output data such that readCarryFused lines are packed tightly together. No padding. 
+ out += chunk_y * MIDDLE * OUT_WG + // Write y chunks after middles + chunk_x * MIDDLE * WIDTH * OUT_SIZEX; // num_y_chunks * OUT_WG = WIDTH / SIZEY * MIDDLE * OUT_WG + // = MIDDLE * WIDTH / (OUT_WG / OUT_SIZEX) * OUT_WG + // = MIDDLE * WIDTH * OUT_SIZEX + // Write each u[i] sequentially + for (int i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * OUT_WG], u[i]); } +#endif +} + +void OVERLOAD readCarryFusedLine(CP(F2) in, F2 *u, u32 line) { + u32 me = get_local_id(0); + u32 SIZEY = OUT_WG / OUT_SIZEX; +#if PAD_SIZE > 0 + // Adjust in pointer based on the x value used in writeMiddleOutLine + u32 BIG_PAD_SIZE = (PAD_SIZE/2+1)*PAD_SIZE; + u32 fftMiddleOut_x = line % SMALL_HEIGHT; // The fftMiddleOut x value + u32 chunk_x = fftMiddleOut_x / OUT_SIZEX; // The fftMiddleOut chunk_x value + in += chunk_x * (WIDTH * MIDDLE * OUT_SIZEX + WIDTH / SIZEY * PAD_SIZE + BIG_PAD_SIZE); // Adjust in pointer the same way writeMiddleOutLine did + u32 x_within_out_wg = fftMiddleOut_x % OUT_SIZEX; // There were OUT_SIZEX x values within OUT_WG + in += x_within_out_wg * SIZEY; // Adjust in pointer the same way writeMiddleOutLine wrote x values within OUT_WG + // Adjust in pointer based on the i value used in writeMiddleOutLine + u32 fftMiddleOut_i = line / SMALL_HEIGHT; // The i in fftMiddleOut's u[i] + in += fftMiddleOut_i * OUT_WG; // Adjust in pointer the same way writeMiddleOutLine did + // Adjust in pointer based on the y value used in writeMiddleOutLine. This code is a little obscure as rocm compiler has trouble optimizing commented out code. + in += me % SIZEY; // Adjust in pointer to read SIZEY consecutive values + u32 fftMiddleOut_y = me; // The i=0 fftMiddleOut y value + u32 chunk_y = fftMiddleOut_y / SIZEY; // The i=0 fftMiddleOut chunk_y value + u32 fftMiddleOut_y_incr = G_W; // The increment to next fftMiddleOut y value + u32 chunk_y_incr = fftMiddleOut_y_incr / SIZEY; // The increment to next fftMiddleOut chunk_y value + for (i32 i = 0; i < NW; ++i) { + u[i] = NTLOAD(in[chunk_y * (MIDDLE * OUT_WG + PAD_SIZE)]); // Adjust in pointer the same way writeMiddleOutLine did + chunk_y += chunk_y_incr; + } +#else // Read data that was not rotated or padded + // Adjust in pointer based on the x value used in writeMiddleOutLine + u32 fftMiddleOut_x = line % SMALL_HEIGHT; // The fftMiddleOut x value + u32 chunk_x = fftMiddleOut_x / OUT_SIZEX; // The fftMiddleOut chunk_x value + in += chunk_x * MIDDLE * WIDTH * OUT_SIZEX; // Adjust in pointer the same way writeMiddleOutLine did + u32 x_within_out_wg = fftMiddleOut_x % OUT_SIZEX; // There were OUT_SIZEX x values within OUT_WG + in += x_within_out_wg * SIZEY; // Adjust in pointer the same way writeMiddleOutLine wrote x values with OUT_WG + // Adjust in pointer based on the i value used in writeMiddleOutLine + u32 fftMiddleOut_i = line / SMALL_HEIGHT; // The i in fftMiddleOut's u[i] + in += fftMiddleOut_i * OUT_WG; // Adjust in pointer the same way writeMiddleOutLine did + // Adjust in pointer based on the y value used in writeMiddleOutLine. This code is a little obscure as rocm compiler has trouble optimizing commented out code. 
+ in += me % SIZEY; // Adjust in pointer to read SIZEY consecutive values + u32 fftMiddleOut_y = me; // The i=0 fftMiddleOut y value + u32 chunk_y = fftMiddleOut_y / SIZEY; // The i=0 fftMiddleOut chunk_y value + u32 fftMiddleOut_y_incr = G_W; // The increment to next fftMiddleOut y value + u32 chunk_y_incr = fftMiddleOut_y_incr / SIZEY; // The increment to next fftMiddleOut chunk_y value + for (i32 i = 0; i < NW; ++i) { + u[i] = NTLOAD(in[chunk_y * MIDDLE * OUT_WG]); // Adjust in pointer the same way writeMiddleOutLine did + chunk_y += chunk_y_incr; + } +#endif +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// Since F2 and GF31 are the same size we can simply call the floats based code + +void OVERLOAD writeCarryFusedLine(GF31 *u, P(GF31) out, u32 line) { + writeCarryFusedLine((F2 *) u, (P(F2)) out, line); +} + +void OVERLOAD readMiddleInLine(GF31 *u, CP(GF31) in, u32 y, u32 x) { + readMiddleInLine((F2 *) u, (CP(F2)) in, y, x); +} + +void OVERLOAD writeMiddleInLine (P(GF31) out, GF31 *u, u32 chunk_y, u32 chunk_x) { + writeMiddleInLine ((P(F2)) out, (F2 *) u, chunk_y, chunk_x); +} + +void OVERLOAD readTailFusedLine(CP(GF31) in, GF31 *u, u32 line, u32 me) { + readTailFusedLine((CP(F2)) in, (F2 *) u, line, me); +} + +void OVERLOAD writeTailFusedLine(GF31 *u, P(GF31) out, u32 line, u32 me) { + writeTailFusedLine((F2 *) u, (P(F2)) out, line, me); +} + +void OVERLOAD readMiddleOutLine(GF31 *u, CP(GF31) in, u32 y, u32 x) { + readMiddleOutLine((F2 *) u, (CP(F2)) in, y, x); +} + +void OVERLOAD writeMiddleOutLine (P(GF31) out, GF31 *u, u32 chunk_y, u32 chunk_x) { + writeMiddleOutLine ((P(F2)) out, (F2 *) u, chunk_y, chunk_x); +} + +void OVERLOAD readCarryFusedLine(CP(GF31) in, GF31 *u, u32 line) { + readCarryFusedLine((CP(F2)) in, (F2 *) u, line); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// Since T2 and GF61 are the same size we can simply call the doubles based code + +void OVERLOAD writeCarryFusedLine(GF61 *u, P(GF61) out, u32 line) { + writeCarryFusedLine((T2 *) u, (P(T2)) out, line); +} + +void OVERLOAD readMiddleInLine(GF61 *u, CP(GF61) in, u32 y, u32 x) { + readMiddleInLine((T2 *) u, (CP(T2)) in, y, x); +} + +void OVERLOAD writeMiddleInLine (P(GF61) out, GF61 *u, u32 chunk_y, u32 chunk_x) { + writeMiddleInLine ((P(T2)) out, (T2 *) u, chunk_y, chunk_x); +} + +void OVERLOAD readTailFusedLine(CP(GF61) in, GF61 *u, u32 line, u32 me) { + readTailFusedLine((CP(T2)) in, (T2 *) u, line, me); +} + +void OVERLOAD writeTailFusedLine(GF61 *u, P(GF61) out, u32 line, u32 me) { + writeTailFusedLine((T2 *) u, (P(T2)) out, line, me); +} + +void OVERLOAD readMiddleOutLine(GF61 *u, CP(GF61) in, u32 y, u32 x) { + readMiddleOutLine((T2 *) u, (CP(T2)) in, y, x); +} + +void OVERLOAD writeMiddleOutLine (P(GF61) out, GF61 *u, u32 chunk_y, u32 chunk_x) { + writeMiddleOutLine ((P(T2)) out, (T2 *) u, chunk_y, chunk_x); +} + +void OVERLOAD readCarryFusedLine(CP(GF61) in, GF61 *u, u32 line) { + readCarryFusedLine((CP(T2)) in, (T2 *) u, line); +} + +#endif + + + + + + +#else // New implementation (in-place) + +// Goals: +// 1) In-place transpose. 
Rather than "ping-pong"ing buffers, an in-place transpose uses half as much memory. This may allow +// the entire FFT/NTT data set to reside in the L2 cache on upper end consumer GPUs (circa 2025) which can have 64MB or larger L2 caches. +// 2) We want to have distribute the carryFused and/or tailSquare memory in the L2 cache with minimal cache line collisions. The hope is to (one day) do +// fftMiddleOut/carryFused/fftMiddleIn or fftMiddleIn/tailSquare/fftMiddleOut in L2 cache-sized chunks to minimize the slowest memory accesses. +// The cost of extra kernel launches may negate any L2 cache benefits. +// 3) We use swizzling and/or modest padding to reduce carryFused L2 cache line collisions. Several different memory layouts and padding were tried +// on nVidia Titan V and AMD Radeon VII to find the fastest in-place layout and padding scheme. Hopefully, these will schemes will work well +// on later generation GPUs with different L2 cache dimensions (size and "number-of-ways"). +// 4) Apparently cache line collisions in the L1 cache also adversely affect timings. The L1 cache may have a different cache line size and number-of-ways +// which makes padding tuning a bit difficult. This is especially true on AMD which has a very strange channel & banks partitioning of memory accesses. +// 5) Manufacturers are not very good about documenting L1 & L2 cache configurations. nVidia GPUs seem to have a L2 cache line size of 128 bytes. +// AMD documentation seems to indicate 64 or 128 byte cache lines. However, other documentation indicates 256 byte reads are to be preferred. +// Clinfo on an rx9070 says the L2 cache line size is 256 bytes. Thus, our goal is to target cache line size of 256 bytes. +// +// Here is the proposed memory layout for a 512 x 8 x 512 = 2M complex FFT using two doubles (16 bytes). PRPLL calls this a 4M FFT for the end user. +// The in-place transpose done by fftMiddleIn and fftMiddleOut works on a grid of 16 values from carryFused (multiples of 4K) and 16 values +// from tailSquare (multiples of 1). Sixteen 16-byte values = one 256 byte cache line or two 128 byte cache lines. +// +// tailSquare memory layout (also fftMiddleIn output layout and fftMiddleOut input layout). +// A HEIGHT=512 tail line is 512*16 bytes = 8KB. There are 2M/512=4K lines. Lines k and N-k are processed together. +// A 2MB L2 cache can fit 256 tail lines. So the first group of tail lines output by fftMiddleIn can be (mostly) paired with the last group of +// tail lines output by fftMiddleIn for tailSquare to process. 128 tail lines = 64K FFT values = 1MB. +// Here is the memory layout of FFT data. +// 0..511 The first tailSquare line (tail line 0). SMALL_HEIGHT values * 16 bytes (8KB). +// 4K The tailSquare line starting with FFT data element 4K (tail line 1). +// .. +// 60K These 16 lines will form 16x16 "transpose blocks". +// 64K Follow the same pattern starting at FFT data element 64K (tail line 16). Note: We have tried placing the "middle" lines next. +// ... +// 2M-4K The last "width" line (tail line MIDDLE*(WIDTH-1)). +// 512.. The tailSquare line starting with the first "middle" value (tail line WIDTH, i.e. 512), followed again by WIDTH lines 4K apart. +// ... +// 3.5K The tailSquare line containing the last middle value that fftMiddleOut will need (tail line (MIDDLE-1)*WIDTH) +// ... +// +// fftMiddleOut uses 16 sets of 16 threads to transpose 16x16 blocks of 16-byte values. Those 256 threads also process the MIDDLE blocks. 
For example, +// the 256 threads read 0,1..15, 4K+0,4K+1..4K+15, ..., 60K+0..60K+15 to form one 16x16 block. For MIDDLE processing, those 256 threads also read +// the seven 16x16 blocks beginning at 512, 1024, ... 3584. +// +// After transposing, the memory layout of fftMiddleOut output (also carryFused input & output and fftMiddleIn layout) is: +// 0,4K..60K 16..16+60K, 32..32+60K, ..., 496..496+60K (SMALL_HEIGHT/16=32) groups of 16 values * 16 bytes (8KB) +// ... +// 15.. +// 64K After the above 16 lines (128KB), follow the same pattern starting at FFT data element 64K +// ... +// 2.5M-64K +// 512.. After WIDTH lines above, follow the same pattern starting at the first "middle" value +// ... +// 3.5K The last "middle" value +// ... +// +// carryFused reads FFT data values that are 4K apart. The first 16 are in a single 256 byte cache line. The next 16 (starting at 64K) occurs 128KB later. +// This large power of two stride could cause cache collisions depending on the cache layout. If so, padding or swizzling MAY be very useful. +// carryFused does not reuse any read in data, so in theory its no big deal if a cache line is evicted before writing the result back out. +// I'm not sure if cache hardware cares about having multiple reads to the same cache line in-flight. NOTE: One minor advantage to eliminating cache +// conflicts is that carryFused likely processes lines in order or close to it. If we next process MiddleIn/tailSquare/MiddleOut in a 2MB cache as +// described above, it will appreciate the last 128 lines being in the L2 cache from carryfused. +// +// Back to carryfused. What might be an optimal padding scheme? Say that 64 carryfuseds are active simultaneously (probably more). +// The +1s are +8KB +// The +16s are +256B +// The +64Ks are at +128KB +// This leaves the "columns" starting at +1KB unused - suggesting a pad of +1KB before the 64Ks would yield a better distribution in the L2 cache. +// However, if we have say an 8-way 16MB L2 cache then each way contains 2MB. If so, we'd want to pad 1KB before the 16th 64K FFT data value. 
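// The XOR swizzle defined below (SWIZ(a,m) == (m) ^ (a)) is what allows SIZEBLK to carry no per-row
// pad: without it, reading one column group from all 16 rows of a 16x16 block would touch the same
// 256-byte offset in every row, with rows a power-of-two stride apart; XOR-ing in the row index
// spreads those 16 accesses over 16 different groups. A host-side sanity sketch of that permutation
// property (plain C, illustrative only -- not kernel code):
#include <assert.h>

static int swiz(int a, int m) { return m ^ a; }  // mirrors SWIZ(a,m): row a stores its group m at slot m^a

int main(void) {
  for (int a = 0; a < 16; ++a) {       // fixed row: its 16 groups land in 16 distinct slots
    unsigned used = 0;
    for (int m = 0; m < 16; ++m) used |= 1u << swiz(a, m);
    assert(used == 0xFFFF);
  }
  for (int m = 0; m < 16; ++m) {       // fixed column group: the 16 rows use 16 distinct slots,
    unsigned used = 0;                 // so a column walk never repeats a 256-byte group
    for (int a = 0; a < 16; ++a) used |= 1u << swiz(a, m);
    assert(used == 0xFFFF);
  }
  return 0;
}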
+ +#if FFT_FP64 || NTT_GF61 + +//**************************************************************************************** +// Pair of routines to read/write data to/from carryFused +//**************************************************************************************** + +#if INPLACE == 1 // nVidia friendly padding +// Place middle rows after first 16 rows +//#define SIZEBLK (SMALL_HEIGHT + 16) // Pad 256 bytes +//#define SIZEW (16 * SIZEA + 16) // Pad 256 bytes +//#define SIZEM (MIDDLE * SIZEB + (1 - (MIDDLE & 1)) * 16) // Pad 256 bytes if MIDDLE is even +// Place middle rows after all width rows +#define SIZEBLK (SMALL_HEIGHT + 0) // No pad needed when swizzling +#define SIZEW (16 * SIZEBLK + 16) // Pad 256 bytes +#define SIZEM (WIDTH / 16 * SIZEW + 16) // Pad 256 bytes +#define SWIZ(a,m) ((m) ^ (a)) // Swizzle 16 rows (remove "^ (a)" to turn swizzling off) +#else // AMD friendly padding +// Place middle rows after first 16 rows +//#define SIZEBLK (SMALL_HEIGHT + 16) // Pad 256 bytes +//#define SIZEM (16 * SIZEBLK + 16) // Pad 256 bytes +//#define SIZEW (MIDDLE * SIZEM + (1 - (MIDDLE & 1)) * 16) // Pad 256 bytes if MIDDLE is even +// Place middle rows after all width rows +#define SIZEBLK (SMALL_HEIGHT + 0) // No pad needed when swizzling +#define SIZEW (16 * SIZEBLK + 16) // Pad 256 bytes +#define SIZEM (WIDTH / 16 * SIZEW + 0) // Pad 0 bytes +#define SWIZ(a,m) ((m) ^ (a)) // Swizzle 16 rows (remove "^ (a)" to turn swizzling off) +#endif + +// me ranges 0...WIDTH/NW-1 (multiples of BIG_HEIGHT) +// u[i] ranges 0..NW-1 (big multiples of BIG_HEIGHT) +// line ranges 0...BIG_HEIGHT-1 (multiples of one) + +// Read a line for carryFused or FFTW. This line was written by writeMiddleOutLine above. +void OVERLOAD readCarryFusedLine(CP(T2) in, T2 *u, u32 line) { + u32 me = get_local_id(0); // Multiples of BIG_HEIGHT + u32 middle = line / SMALL_HEIGHT; // Multiples of SMALL_HEIGHT + line = line % SMALL_HEIGHT; // Multiples of one + in += (me / 16 * SIZEW) + (middle * SIZEM) + (line % 16 * SIZEBLK) + SWIZ(line % 16, line / 16) * 16 + (me % 16); + for (u32 i = 0; i < NW; ++i) { u[i] = NTLOAD(in[i * G_W / 16 * SIZEW]); } +} + +// Write a line from carryFused. This data will be read by fftMiddleIn. +void OVERLOAD writeCarryFusedLine(T2 *u, P(T2) out, u32 line) { + u32 me = get_local_id(0); // Multiples of BIG_HEIGHT + u32 middle = line / SMALL_HEIGHT; // Multiples of SMALL_HEIGHT + line = line % SMALL_HEIGHT; // Multiples of one + out += (me / 16 * SIZEW) + (middle * SIZEM) + (line % 16 * SIZEBLK) + SWIZ(line % 16, line / 16) * 16 + (me % 16); + for (i32 i = 0; i < NW; ++i) { NTSTORE(out[i * G_W / 16 * SIZEW], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from fftMiddleIn +//**************************************************************************************** + +// x ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) +// u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) +// y ranges 0...SMALL_HEIGHT-1 (multiples of one) + +void OVERLOAD readMiddleInLine(T2 *u, CP(T2) in, u32 y, u32 x) { + in += (x / 16 * SIZEW) + (y % 16 * SIZEBLK) + (SWIZ(y % 16, y / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SIZEM]); } +} + +// NOTE: writeMiddleInLine uses the same definition of x,y as readMiddleInLine. Caller transposes 16x16 blocks of FFT data before calling writeMiddleInLine. 
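// A sketch of the intended call pattern (illustrative only; the fftMiddleIn kernel itself is not
// shown in this diff and may differ in detail): read a lane's values, transpose the 16x16 block,
// then write back in place.
//   T2 u[MIDDLE];
//   readMiddleInLine(u, in, y, x);    // gather this work-item's MIDDLE values
//   ... per-middle FFT work plus a 16x16 transpose across the work-group (e.g. via local memory) ...
//   writeMiddleInLine(out, u, y, x);  // same x,y convention; for the in-place transform out aliases in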
+void OVERLOAD writeMiddleInLine (P(T2) out, T2 *u, u32 y, u32 x) +{ + out += (x / 16 * SIZEW) + (y % 16 * SIZEBLK) + (SWIZ(y % 16, y / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * SIZEM], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from tailSquare/Mul +//**************************************************************************************** + +// me ranges 0...SMALL_HEIGHT/NH-1 (multiples of one) +// u[i] ranges 0...NH-1 (big multiples of one) +// line ranges 0...MIDDLE*WIDTH-1 (multiples of SMALL_HEIGHT) + +// Read a line for tailSquare/Mul or fftHin +void OVERLOAD readTailFusedLine(CP(T2) in, T2 *u, u32 line, u32 me) { + u32 width = line % WIDTH; // Multiples of BIG_HEIGHT + u32 middle = line / WIDTH; // Multiples of SMALL_HEIGHT + in += (width / 16 * SIZEW) + (middle * SIZEM) + (width % 16 * SIZEBLK) + (me % 16); + for (i32 i = 0; i < NH; ++i) { u[i] = NTLOAD(in[SWIZ(width % 16, (i * SMALL_HEIGHT / NH + me) / 16) * 16]); } +} + +void OVERLOAD writeTailFusedLine(T2 *u, P(T2) out, u32 line, u32 me) { + u32 width = line % WIDTH; // Multiples of BIG_HEIGHT + u32 middle = line / WIDTH; // Multiples of SMALL_HEIGHT + out += (width / 16 * SIZEW) + (middle * SIZEM) + (width % 16 * SIZEBLK) + (me % 16); + for (i32 i = 0; i < NH; ++i) { NTSTORE(out[SWIZ(width % 16, (i * SMALL_HEIGHT / NH + me) / 16) * 16], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from fftMiddleOut +//**************************************************************************************** + +// x ranges 0...SMALL_HEIGHT-1 (multiples of one) +// u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) +// y ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) + +void OVERLOAD readMiddleOutLine(T2 *u, CP(T2) in, u32 y, u32 x) { + in += (y / 16 * SIZEW) + (y % 16 * SIZEBLK) + (SWIZ(y % 16, x / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SIZEM]); } +} + +// NOTE: writeMiddleOutLine uses the same definition of x,y as readMiddleOutLine. Caller transposes 16x16 blocks of FFT data before calling writeMiddleOutLine. +void OVERLOAD writeMiddleOutLine (P(T2) out, T2 *u, u32 y, u32 x) +{ + out += (y / 16 * SIZEW) + (y % 16 * SIZEBLK) + (SWIZ(y % 16, x / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * SIZEM], u[i]); } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 || NTT_GF31 + +//**************************************************************************************** +// Pair of routines to read/write data to/from carryFused +//**************************************************************************************** + +// NOTE: I hven't studied best padding/swizzling/memory layout for 32-bit values. I'm assuming the 64-bit values scheme will be pretty good. 
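// For scale, using the 512 x 8 x 512 example from the design notes above (so SMALL_HEIGHT = WIDTH = 512;
// figures are illustrative only): F2/GF31 elements are 8 bytes, so a 16-element group is 128 bytes and
// each 16-element pad below costs 128 bytes, e.g. SIZEW32 = 16*512 + 16 = 8208 elements, i.e. 64KB + 128B
// per 16-line block -- roughly 0.2% padding overhead.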
+#if INPLACE == 1 // nVidia friendly padding +// Place middle rows after first 16 rows +//#define SIZEBLK32 (SMALL_HEIGHT + 16) // Pad 128 bytes +//#define SIZEW32 (16 * SIZEA + 16) // Pad 128 bytes +//#define SIZEM32 (MIDDLE * SIZEB + (1 - (MIDDLE & 1)) * 16) // Pad 128 bytes if MIDDLE is even +// Place middle rows after all width rows +#define SIZEBLK32 (SMALL_HEIGHT + 0) // No pad needed when swizzling +#define SIZEW32 (16 * SIZEBLK32 + 16) // Pad 128 bytes +#define SIZEM32 (WIDTH / 16 * SIZEW32 + 16) // Pad 128 bytes +#define SWIZ32(a,m) ((m) ^ (a)) // Swizzle 16 rows (remove "^ (a)" to turn swizzling off) +#else // AMD friendly padding +// Place middle rows after first 16 rows +//#define SIZEBLK32 (SMALL_HEIGHT + 16) // Pad 128 bytes +//#define SIZEM32 (16 * SIZEBLK + 16) // Pad 128 bytes +//#define SIZEW32 (MIDDLE * SIZEM + (1 - (MIDDLE & 1)) * 16) // Pad 128 bytes if MIDDLE is even +// Place middle rows after all width rows +#define SIZEBLK32 (SMALL_HEIGHT + 0) // No pad needed when swizzling +#define SIZEW32 (16 * SIZEBLK32 + 16) // Pad 128 bytes +#define SIZEM32 (WIDTH / 16 * SIZEW32 + 0) // Pad 0 bytes +#define SWIZ32(a,m) ((m) ^ (a)) // Swizzle 16 rows (remove "^ (a)" to turn swizzling off) +#endif + +// me ranges 0...WIDTH/NW-1 (multiples of BIG_HEIGHT) +// u[i] ranges 0..NW-1 (big multiples of BIG_HEIGHT) +// line ranges 0...BIG_HEIGHT-1 (multiples of one) + +// Read a line for carryFused or FFTW. This line was written by writeMiddleOutLine above. +void OVERLOAD readCarryFusedLine(CP(F2) in, F2 *u, u32 line) { + u32 me = get_local_id(0); // Multiples of BIG_HEIGHT + u32 middle = line / SMALL_HEIGHT; // Multiples of SMALL_HEIGHT + line = line % SMALL_HEIGHT; // Multiples of one + in += (me / 16 * SIZEW32) + (middle * SIZEM32) + (line % 16 * SIZEBLK32) + SWIZ32(line % 16, line / 16) * 16 + (me % 16); + for (u32 i = 0; i < NW; ++i) { u[i] = NTLOAD(in[i * G_W / 16 * SIZEW32]); } +} + +// Write a line from carryFused. This data will be read by fftMiddleIn. +void OVERLOAD writeCarryFusedLine(F2 *u, P(F2) out, u32 line) { + u32 me = get_local_id(0); // Multiples of BIG_HEIGHT + u32 middle = line / SMALL_HEIGHT; // Multiples of SMALL_HEIGHT + line = line % SMALL_HEIGHT; // Multiples of one + out += (me / 16 * SIZEW32) + (middle * SIZEM32) + (line % 16 * SIZEBLK32) + SWIZ32(line % 16, line / 16) * 16 + (me % 16); + for (i32 i = 0; i < NW; ++i) { NTSTORE(out[i * G_W / 16 * SIZEW32], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from fftMiddleIn +//**************************************************************************************** + +// x ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) +// u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) +// y ranges 0...SMALL_HEIGHT-1 (multiples of one) + +void OVERLOAD readMiddleInLine(F2 *u, CP(F2) in, u32 y, u32 x) { + in += (x / 16 * SIZEW32) + (y % 16 * SIZEBLK32) + (SWIZ32(y % 16, y / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SIZEM32]); } +} + +// NOTE: writeMiddleInLine uses the same definition of x,y as readMiddleInLine. Caller transposes 16x16 blocks of FFT data before calling writeMiddleInLine.
+void OVERLOAD writeMiddleInLine (P(F2) out, F2 *u, u32 y, u32 x) +{ + out += (x / 16 * SIZEW32) + (y % 16 * SIZEBLK32) + (SWIZ32(y % 16, y / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * SIZEM32], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from tailSquare/Mul +//**************************************************************************************** + +// me ranges 0...SMALL_HEIGHT/NH-1 (multiples of one) +// u[i] ranges 0...NH-1 (big multiples of one) +// line ranges 0...MIDDLE*WIDTH-1 (multiples of SMALL_HEIGHT) + +// Read a line for tailSquare/Mul or fftHin +void OVERLOAD readTailFusedLine(CP(F2) in, F2 *u, u32 line, u32 me) { + u32 width = line % WIDTH; // Multiples of BIG_HEIGHT + u32 middle = line / WIDTH; // Multiples of SMALL_HEIGHT + in += (width / 16 * SIZEW32) + (middle * SIZEM32) + (width % 16 * SIZEBLK32) + (me % 16); + for (i32 i = 0; i < NH; ++i) { u[i] = NTLOAD(in[SWIZ32(width % 16, (i * SMALL_HEIGHT / NH + me) / 16) * 16]); } +} + +void OVERLOAD writeTailFusedLine(F2 *u, P(F2) out, u32 line, u32 me) { + u32 width = line % WIDTH; // Multiples of BIG_HEIGHT + u32 middle = line / WIDTH; // Multiples of SMALL_HEIGHT + out += (width / 16 * SIZEW32) + (middle * SIZEM32) + (width % 16 * SIZEBLK32) + (me % 16); + for (i32 i = 0; i < NH; ++i) { NTSTORE(out[SWIZ32(width % 16, (i * SMALL_HEIGHT / NH + me) / 16) * 16], u[i]); } +} + +//**************************************************************************************** +// Pair of routines to read/write data to/from fftMiddleOut +//**************************************************************************************** + +// x ranges 0...SMALL_HEIGHT-1 (multiples of one) +// u[i] ranges 0...MIDDLE-1 (multiples of SMALL_HEIGHT) +// y ranges 0...WIDTH-1 (multiples of BIG_HEIGHT) + +void OVERLOAD readMiddleOutLine(F2 *u, CP(F2) in, u32 y, u32 x) { + in += (y / 16 * SIZEW32) + (y % 16 * SIZEBLK32) + (SWIZ32(y % 16, x / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { u[i] = NTLOAD(in[i * SIZEM32]); } +} + +// NOTE: writeMiddleOutLine uses the same definition of x,y as readMiddleOutLine. Caller transposes 16x16 blocks of FFT data before calling writeMiddleOutLine. 
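+// (Hedged layout note: compared with readMiddleInLine above, the index roles are
+// redistributed -- here the WIDTH index y picks both the SIZEW32 tile group and the row
+// inside the 16x16 tile, while the SMALL_HEIGHT index x picks the swizzled block column
+// and the lane.  What writeMiddleOutLine stores in this layout is what
+// readCarryFusedLine above reads back.)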
+void OVERLOAD writeMiddleOutLine (P(F2) out, F2 *u, u32 y, u32 x) +{ + out += (y / 16 * SIZEW32) + (y % 16 * SIZEBLK32) + (SWIZ32(y % 16, x / 16) * 16) + (x % 16); + for (i32 i = 0; i < MIDDLE; ++i) { NTSTORE(out[i * SIZEM32], u[i]); } +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +// Since F2 and GF31 are the same size we can simply call the floats based code + +void OVERLOAD readCarryFusedLine(CP(GF31) in, GF31 *u, u32 line) { + readCarryFusedLine((CP(F2)) in, (F2 *) u, line); +} + +void OVERLOAD writeCarryFusedLine(GF31 *u, P(GF31) out, u32 line) { + writeCarryFusedLine((F2 *) u, (P(F2)) out, line); +} + +void OVERLOAD readMiddleInLine(GF31 *u, CP(GF31) in, u32 y, u32 x) { + readMiddleInLine((F2 *) u, (CP(F2)) in, y, x); +} + +void OVERLOAD writeMiddleInLine (P(GF31) out, GF31 *u, u32 y, u32 x) { + writeMiddleInLine ((P(F2)) out, (F2 *) u, y, x); +} + +void OVERLOAD readTailFusedLine(CP(GF31) in, GF31 *u, u32 line, u32 me) { + readTailFusedLine((CP(F2)) in, (F2 *) u, line, me); +} + +void OVERLOAD writeTailFusedLine(GF31 *u, P(GF31) out, u32 line, u32 me) { + writeTailFusedLine((F2 *) u, (P(F2)) out, line, me); +} + +void OVERLOAD readMiddleOutLine(GF31 *u, CP(GF31) in, u32 y, u32 x) { + readMiddleOutLine((F2 *) u, (CP(F2)) in, y, x); +} + +void OVERLOAD writeMiddleOutLine (P(GF31) out, GF31 *u, u32 y, u32 x) { + writeMiddleOutLine ((P(F2)) out, (F2 *) u, y, x); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +// Since T2 and GF61 are the same size we can simply call the doubles based code + +void OVERLOAD readCarryFusedLine(CP(GF61) in, GF61 *u, u32 line) { + readCarryFusedLine((CP(T2)) in, (T2 *) u, line); +} + +void OVERLOAD writeCarryFusedLine(GF61 *u, P(GF61) out, u32 line) { + writeCarryFusedLine((T2 *) u, (P(T2)) out, line); +} + +void OVERLOAD readMiddleInLine(GF61 *u, CP(GF61) in, u32 y, u32 x) { + readMiddleInLine((T2 *) u, (CP(T2)) in, y, x); +} + +void OVERLOAD writeMiddleInLine (P(GF61) out, GF61 *u, u32 y, u32 x) { + writeMiddleInLine ((P(T2)) out, (T2 *) u, y, x); +} + +void OVERLOAD readTailFusedLine(CP(GF61) in, GF61 *u, u32 line, u32 me) { + readTailFusedLine((CP(T2)) in, (T2 *) u, line, me); +} + +void OVERLOAD writeTailFusedLine(GF61 *u, P(GF61) out, u32 line, u32 me) { + writeTailFusedLine((T2 *) u, (P(T2)) out, line, me); +} + +void OVERLOAD readMiddleOutLine(GF61 *u, CP(GF61) in, u32 y, u32 x) { + readMiddleOutLine((T2 *) u, (CP(T2)) in, y, x); +} + +void OVERLOAD writeMiddleOutLine (P(GF61) out, GF61 *u, u32 y, u32 x) { + writeMiddleOutLine ((P(T2)) out, (T2 *) u, y, x); +} + +#endif + +#endif diff --git a/src/cl/selftest.cl b/src/cl/selftest.cl index 7e42b217..3564f131 100644 --- a/src/cl/selftest.cl +++ b/src/cl/selftest.cl @@ -152,6 +152,7 @@ KERNEL(32) testTime(int what, global i64* io) { #endif } +#if FFT_FP64 KERNEL(256) testFFT3(global double2* io) { T2 u[4]; @@ -243,24 +244,12 @@ KERNEL(256) testFFT13(global double2* io) { } } -KERNEL(256) testTrig(global double2* out) { - for (i32 k = get_global_id(0); k < ND / 8; k += get_global_size(0)) { -#if 0 - double angle = M_PI / (ND / 2) * k; - out[k] = U2(cos(angle), -sin(angle)); -#else - out[k] = slowTrig_N(k, 
ND/8); -#endif - } -} - -KERNEL(256) testFFT(global double2* io) { -#define SIZE 16 - double2 u[SIZE]; +KERNEL(256) testFFT14(global double2* io) { + double2 u[14]; if (get_global_id(0) == 0) { - for (int i = 0; i < SIZE; ++i) { u[i] = io[i]; } - fft16(u); - for (int i = 0; i < SIZE; ++i) { io[i] = u[i]; } + for (int i = 0; i < 14; ++i) { u[i] = io[i]; } + fft14(u); + for (int i = 0; i < 14; ++i) { io[i] = u[i]; } } } @@ -273,11 +262,25 @@ KERNEL(256) testFFT15(global double2* io) { } } -KERNEL(256) testFFT14(global double2* io) { - double2 u[14]; +KERNEL(256) testFFT(global double2* io) { +#define SIZE 16 + double2 u[SIZE]; if (get_global_id(0) == 0) { - for (int i = 0; i < 14; ++i) { u[i] = io[i]; } - fft14(u); - for (int i = 0; i < 14; ++i) { io[i] = u[i]; } + for (int i = 0; i < SIZE; ++i) { u[i] = io[i]; } + fft16(u); + for (int i = 0; i < SIZE; ++i) { io[i] = u[i]; } } } + +KERNEL(256) testTrig(global double2* out) { + for (i32 k = get_global_id(0); k < ND / 8; k += get_global_size(0)) { +#if 0 + double angle = M_PI / (ND / 2) * k; + out[k] = U2(cos(angle), -sin(angle)); +#else + out[k] = slowTrig_N(k, ND/8); +#endif + } +} + +#endif diff --git a/src/cl/tailmul.cl b/src/cl/tailmul.cl index 8266f411..1cdd5db0 100644 --- a/src/cl/tailmul.cl +++ b/src/cl/tailmul.cl @@ -5,54 +5,45 @@ #include "trig.cl" #include "fftheight.cl" -// Why does this alternate implementation work? Let t' be the conjugate of t and note that t*t' = 1. -// Now consider these lines from the original implementation (comments appear alongside): -// b = mul_by_conjugate(b, t); -// X2(a, b); a + bt', a - bt' -// d = mul_by_conjugate(d, t); -// X2(c, d); c + dt', c - dt' -// a = mul(a, c); (a+bt')(c+dt') = ac + bct' + adt' + bdt'^2 -// b = mul(b, d); (a-bt')(c-dt') = ac - bct' - adt' + bdt'^2 -// X2(a, b); 2ac + 2bdt'^2, 2bct' + 2adt' -// b = mul(b, t); 2bc + 2ad - -void onePairMul(T2* pa, T2* pb, T2* pc, T2* pd, T2 conjugate_t_squared) { +#if FFT_FP64 + +// Handle the final multiplication step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. +// We used to conjugate the results, but swapping real and imaginary can save some negations in carry propagation. 
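+// Informal sketch of what the code below computes (reading cfma(x, y, z) as x*y + z):
+// after the X2conjb folds, the two FMAs produce
+//     pa = a*c - b*d*t^2
+//     pb = b*c + a*d
+// and X2_conjb recombines the pair.  The final SWAP_XY presumably returns (im, re),
+// i.e. i*conj(z), so no explicit conjugation is needed and the leftover sign fixes
+// are absorbed during carry propagation, as the comment above notes.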
+ +void OVERLOAD onePairMul(T2* pa, T2* pb, T2* pc, T2* pd, T2 t_squared) { T2 a = *pa, b = *pb, c = *pc, d = *pd; X2conjb(a, b); X2conjb(c, d); - T2 tmp = a; + *pa = cfma(a, c, cmul(cmul(b, d), -t_squared)); + *pb = cfma(b, c, cmul(a, d)); - a = cfma(a, c, cmul(cmul(b, d), conjugate_t_squared)); - b = cfma(b, c, cmul(tmp, d)); + X2_conjb(*pa, *pb); - X2conja(a, b); - - *pa = a; - *pb = b; + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); } -void pairMul(u32 N, T2 *u, T2 *v, T2 *p, T2 *q, T2 base_squared, bool special) { +void OVERLOAD pairMul(u32 N, T2 *u, T2 *v, T2 *p, T2 *q, T2 base_squared, bool special) { u32 me = get_local_id(0); for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { if (special && i == 0 && me == 0) { - u[i] = conjugate(2 * foo2(u[i], p[i])); - v[i] = 4 * cmul(conjugate(v[i]), conjugate(q[i])); + u[i] = SWAP_XY(2 * foo2(u[i], p[i])); + v[i] = SWAP_XY(4 * cmul(v[i], q[i])); } else { - onePairMul(&u[i], &v[i], &p[i], &q[i], -base_squared); + onePairMul(&u[i], &v[i], &p[i], &q[i], base_squared); } if (N == NH) { - onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], base_squared); + onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], -base_squared); } T2 new_base_squared = mul_t4(base_squared); - onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], -new_base_squared); + onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], new_base_squared); if (N == NH) { - onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], new_base_squared); + onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], -new_base_squared); } } } @@ -100,12 +91,7 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { T2 trig = slowTrig_N(line1 + me * H, ND / NH); - if (line1) { - reverseLine(G_H, lds, v); - reverseLine(G_H, lds, q); - pairMul(NH, u, v, p, q, trig, false); - reverseLine(G_H, lds, v); - } else { + if (line1 == 0) { reverse(G_H, lds, u + NH/2, true); reverse(G_H, lds, p + NH/2, true); pairMul(NH/2, u, u + NH/2, p, p + NH/2, trig, true); @@ -116,6 +102,11 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { reverse(G_H, lds, q + NH/2, false); pairMul(NH/2, v, v + NH/2, q, q + NH/2, trig2, false); reverse(G_H, lds, v + NH/2, false); + } else { + reverseLine(G_H, lds, v); + reverseLine(G_H, lds, q); + pairMul(NH, u, v, p, q, trig, false); + reverseLine(G_H, lds, v); } bar(); @@ -125,3 +116,379 @@ KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { writeTailFusedLine(v, out, memline2, me); writeTailFusedLine(u, out, memline1, me); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +// Handle the final multiplication step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. +// We used to conjugate the results, but swapping real and imaginary can save some negations in carry propagation. 
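+// Note on the twiddle pattern (hedged, matches the FP64 pairMul above): each loop
+// iteration of pairMul below reuses one twiddle for up to four pairings by passing
+// base_squared, -base_squared, mul_t4(base_squared) and its negation, i.e. t^2 times
+// 1, -1, i and -i (mul_t4 is assumed to be a quarter turn); only the mul_t8 rotation
+// at the top of the loop advances to a genuinely new angle.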
+ +void OVERLOAD onePairMul(F2* pa, F2* pb, F2* pc, F2* pd, F2 t_squared) { + F2 a = *pa, b = *pb, c = *pc, d = *pd; + X2conjb(a, b); + X2conjb(c, d); + *pa = cfma(a, c, cmul(cmul(b, d), -t_squared)); + *pb = cfma(b, c, cmul(a, d)); + X2_conjb(*pa, *pb); + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); +} + +void OVERLOAD pairMul(u32 N, F2 *u, F2 *v, F2 *p, F2 *q, F2 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(2 * foo2(u[i], p[i])); + v[i] = SWAP_XY(4 * cmul(v[i], q[i])); + } else { + onePairMul(&u[i], &v[i], &p[i], &q[i], base_squared); + } + + if (N == NH) { + onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], -base_squared); + } + + F2 new_base_squared = mul_t4(base_squared); + onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], new_base_squared); + + if (N == NH) { + onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], -new_base_squared); + } + } +} + +KERNEL(G_H) tailMul(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT]; + + CP(F2) inF2 = (CP(F2)) in; + CP(F2) aF2 = (CP(F2)) a; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NH], v[NH]; + F2 p[NH], q[NH]; + + u32 H = ND / SMALL_HEIGHT; + + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(inF2, u, line1, me); + readTailFusedLine(inF2, v, line2, me); + +#if MUL_LOW + read(G_H, NH, p, aF2, memline1 * SMALL_HEIGHT); + read(G_H, NH, q, aF2, memline2 * SMALL_HEIGHT); + fft_HEIGHT(lds, u, smallTrigF2); + bar(); + fft_HEIGHT(lds, v, smallTrigF2); +#else + readTailFusedLine(aF2, p, line1, me); + readTailFusedLine(aF2, q, line2, me); + fft_HEIGHT(lds, u, smallTrigF2); + bar(); + fft_HEIGHT(lds, v, smallTrigF2); + bar(); + fft_HEIGHT(lds, p, smallTrigF2); + bar(); + fft_HEIGHT(lds, q, smallTrigF2); +#endif + + F2 trig = slowTrig_N(line1 + me * H, ND / NH); + + if (line1 == 0) { + reverse(G_H, lds, u + NH/2, true); + reverse(G_H, lds, p + NH/2, true); + pairMul(NH/2, u, u + NH/2, p, p + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + F2 trig2 = cmulFancy(trig, TAILT); + reverse(G_H, lds, v + NH/2, false); + reverse(G_H, lds, q + NH/2, false); + pairMul(NH/2, v, v + NH/2, q, q + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } else { + reverseLine(G_H, lds, v); + reverseLine(G_H, lds, q); + pairMul(NH, u, v, p, q, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrigF2); + bar(); + fft_HEIGHT(lds, u, smallTrigF2); + writeTailFusedLine(v, outF2, memline2, me); + writeTailFusedLine(u, outF2, memline1, me); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD onePairMul(GF31* pa, GF31* pb, GF31* pc, GF31* pd, GF31 t_squared) { + GF31 a = *pa, b = *pb, c = *pc, d = *pd; + + X2conjb(a, b); + X2conjb(c, d); + + *pa = sub(cmul(a, c), cmul(cmul(b, d), t_squared)); + *pb = add(cmul(b, c), cmul(a, d)); + + X2_conjb(*pa, *pb); + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); +} + +void OVERLOAD pairMul(u32 N, GF31 *u, GF31 *v, GF31 *p, GF31 *q, GF31 base_squared, bool special) { + u32 me = 
get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(mul2(foo2(u[i], p[i]))); + v[i] = SWAP_XY(shl(cmul(v[i], q[i]), 2)); + } else { + onePairMul(&u[i], &v[i], &p[i], &q[i], base_squared); + } + + if (N == NH) { + onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], neg(base_squared)); + } + + GF31 new_base_squared = mul_t4(base_squared); + onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], new_base_squared); + + if (N == NH) { + onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], neg(new_base_squared)); + } + } +} + +KERNEL(G_H) tailMulGF31(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + CP(GF31) a31 = (CP(GF31)) (a + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH], v[NH]; + GF31 p[NH], q[NH]; + + u32 H = ND / SMALL_HEIGHT; + + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(in31, u, line1, me); + readTailFusedLine(in31, v, line2, me); + +#if MUL_LOW + read(G_H, NH, p, a31, memline1 * SMALL_HEIGHT); + read(G_H, NH, q, a31, memline2 * SMALL_HEIGHT); + fft_HEIGHT(lds, u, smallTrig31); + bar(); + fft_HEIGHT(lds, v, smallTrig31); +#else + readTailFusedLine(a31, p, line1, me); + readTailFusedLine(a31, q, line2, me); + fft_HEIGHT(lds, u, smallTrig31); + bar(); + fft_HEIGHT(lds, v, smallTrig31); + bar(); + fft_HEIGHT(lds, p, smallTrig31); + bar(); + fft_HEIGHT(lds, q, smallTrig31); +#endif + + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. 
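+  // In other words (informal summary of the #if ladder below): with TAIL_TRIGS31 >= 1
+  // each work-item reads one hopefully-cached line-zero value plus one per-line
+  // multiplier and does a single cmul; with TAIL_TRIGS31 == 0 it streams the fully
+  // pre-computed per-line value via NTLOAD and does no extra math.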
+ u32 height_trigs = SMALL_HEIGHT*1; +#if TAIL_TRIGS31 >= 1 + GF31 trig = smallTrig31[height_trigs + me]; // Trig values for line zero, should be cached +#if SINGLE_WIDE + GF31 mult = smallTrig31[height_trigs + G_H + line1]; +#else + GF31 mult = smallTrig31[height_trigs + G_H + line1 * 2]; +#endif + trig = cmul(trig, mult); +#else +#if SINGLE_WIDE + GF31 trig = NTLOAD(smallTrig31[height_trigs + line1*G_H + me]); +#else + GF31 trig = NTLOAD(smallTrig31[height_trigs + line1*2*G_H + me]); +#endif +#endif + + if (line1 == 0) { + reverse(G_H, lds, u + NH/2, true); + reverse(G_H, lds, p + NH/2, true); + pairMul(NH/2, u, u + NH/2, p, p + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + GF31 trig2 = cmul(trig, TAILTGF31); + reverse(G_H, lds, v + NH/2, false); + reverse(G_H, lds, q + NH/2, false); + pairMul(NH/2, v, v + NH/2, q, q + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } else { + reverseLine(G_H, lds, v); + reverseLine(G_H, lds, q); + pairMul(NH, u, v, p, q, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrig31); + bar(); + fft_HEIGHT(lds, u, smallTrig31); + writeTailFusedLine(v, out31, memline2, me); + writeTailFusedLine(u, out31, memline1, me); +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD onePairMul(GF61* pa, GF61* pb, GF61* pc, GF61* pd, GF61 t_squared) { + GF61 a = *pa, b = *pb, c = *pc, d = *pd; + + X2conjb(a, b); + X2conjb(c, d); + GF61 e = subq(cmul(a, c), cmul(cmul(b, d), t_squared), 2); // Max value is 3*M61+epsilon + GF61 f = addq(cmul(b, c), cmul(a, d)); // Max value is 2*M61+epsilon + X2s_conjb(&e, &f, 4, 3); + *pa = SWAP_XY(e), *pb = SWAP_XY(f); +} + +void OVERLOAD pairMul(u32 N, GF61 *u, GF61 *v, GF61 *p, GF61 *q, GF61 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(mul2(foo2(u[i], p[i]))); + v[i] = SWAP_XY(shl(cmul(v[i], q[i]), 2)); + } else { + onePairMul(&u[i], &v[i], &p[i], &q[i], base_squared); + } + + if (N == NH) { + onePairMul(&u[i+NH/2], &v[i+NH/2], &p[i+NH/2], &q[i+NH/2], neg(base_squared)); + } + + GF61 new_base_squared = mul_t4(base_squared); + onePairMul(&u[i+NH/4], &v[i+NH/4], &p[i+NH/4], &q[i+NH/4], new_base_squared); + + if (N == NH) { + onePairMul(&u[i+3*NH/4], &v[i+3*NH/4], &p[i+3*NH/4], &q[i+3*NH/4], neg(new_base_squared)); + } + } +} + +KERNEL(G_H) tailMulGF61(P(T2) out, CP(T2) in, CP(T2) a, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + CP(GF61) a61 = (CP(GF61)) (a + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH], v[NH]; + GF61 p[NH], q[NH]; + + u32 H = ND / SMALL_HEIGHT; + + u32 line1 = get_group_id(0); + u32 line2 = line1 ? 
H - line1 : (H / 2); + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(in61, u, line1, me); + readTailFusedLine(in61, v, line2, me); + +#if MUL_LOW + read(G_H, NH, p, a61, memline1 * SMALL_HEIGHT); + read(G_H, NH, q, a61, memline2 * SMALL_HEIGHT); + fft_HEIGHT(lds, u, smallTrig61); + bar(); + fft_HEIGHT(lds, v, smallTrig61); +#else + readTailFusedLine(a61, p, line1, me); + readTailFusedLine(a61, q, line2, me); + fft_HEIGHT(lds, u, smallTrig61); + bar(); + fft_HEIGHT(lds, v, smallTrig61); + bar(); + fft_HEIGHT(lds, p, smallTrig61); + bar(); + fft_HEIGHT(lds, q, smallTrig61); +#endif + + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; +#if TAIL_TRIGS61 >= 1 + GF61 trig = smallTrig61[height_trigs + me]; // Trig values for line zero, should be cached +#if SINGLE_WIDE + GF61 mult = smallTrig61[height_trigs + G_H + line1]; +#else + GF61 mult = smallTrig61[height_trigs + G_H + line1 * 2]; +#endif + trig = cmul(trig, mult); +#else +#if SINGLE_WIDE + GF61 trig = NTLOAD(smallTrig61[height_trigs + line1*G_H + me]); +#else + GF61 trig = NTLOAD(smallTrig61[height_trigs + line1*2*G_H + me]); +#endif +#endif + + if (line1 == 0) { + reverse(G_H, lds, u + NH/2, true); + reverse(G_H, lds, p + NH/2, true); + pairMul(NH/2, u, u + NH/2, p, p + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + GF61 trig2 = cmul(trig, TAILTGF61); + reverse(G_H, lds, v + NH/2, false); + reverse(G_H, lds, q + NH/2, false); + pairMul(NH/2, v, v + NH/2, q, q + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } else { + reverseLine(G_H, lds, v); + reverseLine(G_H, lds, q); + pairMul(NH, u, v, p, q, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrig61); + bar(); + fft_HEIGHT(lds, u, smallTrig61); + writeTailFusedLine(v, out61, memline2, me); + writeTailFusedLine(u, out61, memline1, me); +} + +#endif diff --git a/src/cl/tailsquare.cl b/src/cl/tailsquare.cl index 30105619..bf960f1c 100644 --- a/src/cl/tailsquare.cl +++ b/src/cl/tailsquare.cl @@ -4,71 +4,49 @@ #include "trig.cl" #include "fftheight.cl" -// TAIL_TRIGS setting: -// 2 = No memory accesses, trig values computed from scratch. Good for excellent DP GPUs such as Titan V or Radeon VII Pro. -// 1 = Limited memory accesses and some DP computation. Tuned for Radeon VII a GPU with good DP performance. -// 0 = No DP computation. Trig vaules read from memory. Good for GPUs with poor DP performance (a typical consumer grade GPU). -#if !defined(TAIL_TRIGS) -#define TAIL_TRIGS 2 // Default is compute trig values from scratch -#endif - -// TAIL_KERNELS setting: -// 0 = single wide, single kernel -// 1 = single wide, two kernels -// 2 = double wide, single kernel -// 3 = double wide, two kernels -#if !defined(TAIL_KERNELS) -#define TAIL_KERNELS 2 // Default is double-wide tailSquare with two kernels -#endif -#define SINGLE_WIDE TAIL_KERNELS < 2 // Old single-wide tailSquare vs. new double-wide tailSquare -#define SINGLE_KERNEL (TAIL_KERNELS & 1) == 0 // TailSquare uses a single kernel vs. two kernels - -// Why does this alternate implementation work? Let t' be the conjugate of t and note that t*t' = 1. 
-// Now consider these lines from the original implementation (comments appear alongside): -// b = mul_by_conjugate(b, t); bt' -// X2(a, b); a + bt', a - bt' -// a = sq(a); a^2 + 2abt' + (bt')^2 -// b = sq(b); a^2 - 2abt' + (bt')^2 -// X2(a, b); 2a^2 + 2(bt')^2, 4abt' -// b = mul(b, t); 4ab - -void onePairSq(T2* pa, T2* pb, T2 conjugate_t_squared) { +#if FFT_FP64 + +// Handle the final squaring step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. +// We used to conjugate the results, but swapping real and imaginary can save some negations in carry propagation. +void OVERLOAD onePairSq(T2* pa, T2* pb, T2 t_squared) { T2 a = *pa; T2 b = *pb; // X2conjb(a, b); // *pb = mul2(cmul(a, b)); -// *pa = csqa(a, cmul(csq(b), conjugate_t_squared)); -// X2conja(*pa, *pb); +// *pa = csqa(a, cmul(csq(b), -t_squared)); +// X2_conjb(*pa, *pb); +// *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb) // Less readable version of the above that saves one complex add by using FMA instructions X2conjb(a, b); - T2 minusnewb = mulminus2(cmul(a, b)); // -newb = -2ab - *pb = csqa(a, cfma(csq(b), conjugate_t_squared, minusnewb)); // final b = newa - newb = a^2 + (bt')^2 - newb - (*pa).x = fma(-2.0, minusnewb.x, (*pb).x); // final a = newa + newb = finalb + 2 * newb - (*pa).y = fma(2.0, minusnewb.y, -(*pb).y); // conjugate(final a) + T2 twoab = mul2(cmul(a, b)); // 2ab + *pa = csqa(a, cfma(csq(b), -t_squared, twoab)); // final a = a^2 + 2ab - (bt)^2 + (*pb).x = fma(-2.0, twoab.x, (*pa).x); // final b = a^2 - 2ab - (bt)^2 + (*pb).y = fma(2.0, twoab.y, -(*pa).y); // conjugate(final b) + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); } -void pairSq(u32 N, T2 *u, T2 *v, T2 base_squared, bool special) { +void OVERLOAD pairSq(u32 N, T2 *u, T2 *v, T2 base_squared, bool special) { u32 me = get_local_id(0); for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { if (special && i == 0 && me == 0) { - u[i] = 2 * foo(conjugate(u[i])); - v[i] = 4 * csq(conjugate(v[i])); + u[i] = SWAP_XY(2 * foo(u[i])); + v[i] = SWAP_XY(4 * csq(v[i])); } else { - onePairSq(&u[i], &v[i], -base_squared); + onePairSq(&u[i], &v[i], base_squared); } if (N == NH) { - onePairSq(&u[i+NH/2], &v[i+NH/2], base_squared); + onePairSq(&u[i+NH/2], &v[i+NH/2], -base_squared); } T2 new_base_squared = mul_t4(base_squared); - onePairSq(&u[i+NH/4], &v[i+NH/4], -new_base_squared); + onePairSq(&u[i+NH/4], &v[i+NH/4], new_base_squared); if (N == NH) { - onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], new_base_squared); + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], -new_base_squared); } } } @@ -136,9 +114,10 @@ KERNEL(G_H) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #endif #if ZEROHACK_H - fft_HEIGHT(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072), w); + u32 zerohack = (u32) get_group_id(0) / 131072; + fft_HEIGHT(lds + zerohack, u, smallTrig + zerohack, w); bar(); - fft_HEIGHT(lds + (get_group_id(0) / 131072), v, smallTrig + (get_group_id(0) / 131072), w); + fft_HEIGHT(lds + zerohack, v, smallTrig + zerohack, w); #else fft_HEIGHT(lds, u, smallTrig, w); bar(); @@ -208,17 +187,17 @@ KERNEL(G_H) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #else // Special pairSq for double-wide line 0 -void pairSq2_special(T2 *u, T2 base_squared) { +void OVERLOAD pairSq2_special(T2 *u, T2 base_squared) { u32 me = get_local_id(0); for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { if (i == 0 && me == 0) { - u[0] = 2 * foo(conjugate(u[0])); - u[NH/2] = 4 * csq(conjugate(u[NH/2])); + u[0] = SWAP_XY(2 * foo(u[0])); + 
u[NH/2] = SWAP_XY(4 * csq(u[NH/2])); } else { - onePairSq(&u[i], &u[NH/2+i], -base_squared); + onePairSq(&u[i], &u[NH/2+i], base_squared); } T2 new_base_squared = mul_t4(base_squared); - onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], -new_base_squared); + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); } } @@ -239,10 +218,10 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { u32 me = get_local_id(0); u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups). - + // We're going to call the halves "first-half" and "second-half". bool isSecondHalf = me >= G_H; - + u32 line = !isSecondHalf ? line_u : line_v; // Read lines u and v @@ -255,7 +234,8 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #endif #if ZEROHACK_H - new_fft_HEIGHT2_1(lds + (get_group_id(0) / 131072), u, smallTrig + (get_group_id(0) / 131072), w); + u32 zerohack = (u32) get_group_id(0) / 131072; + new_fft_HEIGHT2_1(lds + zerohack, u, smallTrig + zerohack, w); #else new_fft_HEIGHT2_1(lds, u, smallTrig, w); #endif @@ -298,9 +278,7 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { #endif revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf); pairSq(NH/2, u, u + NH/2, trig, false); - bar(G_H); - // We change the LDS halves we're using in order to enable half-bars revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf); } @@ -313,3 +291,873 @@ KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { } #endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +// Handle the final squaring step on a pair of complex numbers. Swap real and imaginary results for the inverse FFT. +// We used to conjugate the results, but swapping real and imaginary can save some negations in carry propagation. +void OVERLOAD onePairSq(F2* pa, F2* pb, F2 t_squared) { + F2 a = *pa; + F2 b = *pb; + + X2conjb(a, b); + F2 twoab = mul2(cmul(a, b)); // 2ab + *pa = csqa(a, cfma(csq(b), -t_squared, twoab)); // final a = a^2 + 2ab - (bt)^2 + (*pb).x = fma(-2.0f, twoab.x, (*pa).x); // final b = a^2 - 2ab - (bt)^2 + (*pb).y = fma(2.0f, twoab.y, -(*pa).y); // conjugate(final b) + *pa = SWAP_XY(*pa), *pb = SWAP_XY(*pb); +} + +void OVERLOAD pairSq(u32 N, F2 *u, F2 *v, F2 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(2 * foo(u[i])); + v[i] = SWAP_XY(4 * csq(v[i])); + } else { + onePairSq(&u[i], &v[i], base_squared); + } + + if (N == NH) { + onePairSq(&u[i+NH/2], &v[i+NH/2], -base_squared); + } + + F2 new_base_squared = mul_t4(base_squared); + onePairSq(&u[i+NH/4], &v[i+NH/4], new_base_squared); + + if (N == NH) { + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], -new_base_squared); + } + } +} + +// The kernel tailSquareZero handles the special cases in tailSquare, i.e. the lines 0 and H/2 +// This kernel is launched with 2 workgroups (handling line 0, resp. H/2) +KERNEL(G_H) tailSquareZero(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT / 2]; + F2 u[NH]; + u32 H = ND / SMALL_HEIGHT; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + // This kernel in executed in two workgroups. + u32 which = get_group_id(0); + assert(which < 2); + + u32 line = which ? 
(H/2) : 0; + u32 me = get_local_id(0); + readTailFusedLine(inF2, u, line, me); + + F2 trig = slowTrig_N(line + me * H, ND / NH); + + fft_HEIGHT(lds, u, smallTrigF2); + reverse(G_H, lds, u + NH/2, !which); + pairSq(NH/2, u, u + NH/2, trig, !which); + reverse(G_H, lds, u + NH/2, !which); + + bar(); + fft_HEIGHT(lds, u, smallTrigF2); + writeTailFusedLine(u, outF2, transPos(line, MIDDLE, WIDTH), me); +} + +#if SINGLE_WIDE + +KERNEL(G_H) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NH], v[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); +#else + u32 line1 = get_group_id(0) + 1; + u32 line2 = H - line1; +#endif + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(inF2, u, line1, me); + readTailFusedLine(inF2, v, line2, me); + +#if ZEROHACK_H + u32 zerohack = get_group_id(0) / 131072; + fft_HEIGHT(lds + zerohack, u, smallTrigF2 + zerohack); + bar(); + fft_HEIGHT(lds + zerohack, v, smallTrigF2 + zerohack); +#else + fft_HEIGHT(lds, u, smallTrigF2); + bar(); + fft_HEIGHT(lds, v, smallTrigF2); +#endif + + // Compute trig values from scratch. Good on GPUs with high DP throughput. +#if TAIL_TRIGS32 == 2 + F2 trig = slowTrig_N(line1 + me * H, ND / NH); + + // Do a little bit of memory access and a little bit of DP math. Good on a Radeon VII. +#elif TAIL_TRIGS32 == 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached F2 per line + F2 trig = smallTrigF2[height_trigs + me]; // Trig values for line zero, should be cached + F2 mult = smallTrigF2[height_trigs + G_H + line1]; // Line multiplier + trig = cmulFancy(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + F2 trig = NTLOAD(smallTrigF2[height_trigs + line1*G_H + me]); +#endif + +#if SINGLE_KERNEL + if (line1 == 0) { + // Line 0 is special: it pairs with itself, offseted by 1. + reverse(G_H, lds, u + NH/2, true); + pairSq(NH/2, u, u + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + // Line H/2 also pairs with itself (but without offset). 
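+    // (TAILT is taken to be the fixed twiddle that advances a line-0 trig value to the
+    // matching line-H/2 value, so trig2 below covers the second special line.)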
+ F2 trig2 = cmulFancy(trig, TAILT); + reverse(G_H, lds, v + NH/2, false); + pairSq(NH/2, v, v + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } + else { +#else + if (1) { +#endif + reverseLine(G_H, lds, v); + pairSq(NH, u, v, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrigF2); + bar(); + fft_HEIGHT(lds, u, smallTrigF2); + + writeTailFusedLine(v, outF2, memline2, me); + writeTailFusedLine(u, outF2, memline1, me); +} + + +// +// Create a kernel that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// We hope to get better occupancy with the reduced register usage +// + +#else + +// Special pairSq for double-wide line 0 +void OVERLOAD pairSq2_special(F2 *u, F2 base_squared) { + u32 me = get_local_id(0); + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (i == 0 && me == 0) { + u[0] = SWAP_XY(2 * foo(u[0])); + u[NH/2] = SWAP_XY(4 * csq(u[NH/2])); + } else { + onePairSq(&u[i], &u[NH/2+i], base_squared); + } + F2 new_base_squared = mul_t4(base_squared); + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], new_base_squared); + } +} + +KERNEL(G_H * 2) tailSquare(P(T2) out, CP(T2) in, Trig smallTrig) { + local F2 lds[SMALL_HEIGHT]; + + CP(F2) inF2 = (CP(F2)) in; + P(F2) outF2 = (P(F2)) out; + TrigFP32 smallTrigF2 = (TrigFP32) smallTrig; + + F2 u[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line_u = get_group_id(0); + u32 line_v = line_u ? H - line_u : (H / 2); +#else + u32 line_u = get_group_id(0) + 1; + u32 line_v = H - line_u; +#endif + + u32 me = get_local_id(0); + u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups). + + // We're going to call the halves "first-half" and "second-half". + bool isSecondHalf = me >= G_H; + + u32 line = !isSecondHalf ? line_u : line_v; + + // Read lines u and v + readTailFusedLine(inF2, u, line, lowMe); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + new_fft_HEIGHT2_1(lds + zerohack, u, smallTrigF2 + zerohack); +#else + new_fft_HEIGHT2_1(lds, u, smallTrigF2); +#endif + + // Compute trig values from scratch. Good on GPUs with high DP throughput. +#if TAIL_TRIGS32 == 2 + F2 trig = slowTrig_N(line + H * lowMe, ND / NH * 2); + + // Do a little bit of memory access and a little bit of DP math. Good on a Radeon VII. +#elif TAIL_TRIGS32 == 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached F2 per line + F2 trig = smallTrigF2[height_trigs + lowMe]; // Trig values for line zero, should be cached + F2 mult = smallTrigF2[height_trigs + G_H + line_u*2 + isSecondHalf]; // Two multipliers. One for line u, one for line v. + trig = cmulFancy(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + F2 trig = NTLOAD(smallTrigF2[height_trigs + line_u*G_H*2 + me]); +#endif + + bar(G_H); + +#if SINGLE_KERNEL + // Line 0 and H/2 are special: they pair with themselves, line 0 is offseted by 1. 
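+  // (Informal rationale: line k is paired with line H - k, so k = 0 and k = H/2 are
+  // their own partners; line 0 additionally pairs element j with element -j, which is
+  // the offset-by-one the comment above refers to.)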
+ if (line_u == 0) { + reverse2(lds, u); + pairSq2_special(u, trig); + reverse2(lds, u); + } + else { +#else + if (1) { +#endif + revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf); + pairSq(NH/2, u, u + NH/2, trig, false); + bar(G_H); + revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf); + } + + bar(G_H); + + new_fft_HEIGHT2_2(lds, u, smallTrigF2); + + // Write lines u and v + writeTailFusedLine(u, outF2, transPos(line, MIDDLE, WIDTH), lowMe); +} + +#endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD onePairSq(GF31* pa, GF31* pb, GF31 t_squared, const u32 t_squared_type) { + GF31 a = *pa, b = *pb; + GF31 c, d; + + X2conjb(a, b); + if (t_squared_type == 0) // mul t_squared by 1 + c = csq_sub(a, cmul(csq(b), t_squared)); // a^2 - (b^2 * t_squared) + if (t_squared_type == 1) // mul t_squared by i + c = csq_subi(a, cmul(csq(b), t_squared)); // a^2 - i*(b^2 * t_squared) + if (t_squared_type == 2) // mul t_squared by -1 + c = csq_add(a, cmul(csq(b), t_squared)); // a^2 - -1*(b^2 * t_squared) + if (t_squared_type == 3) // mul t_squared by -i + c = csq_addi(a, cmul(csq(b), t_squared)); // a^2 - -i*(b^2 * t_squared) + d = mul2(cmul(a, b)); + X2_conjb(c, d); + *pa = SWAP_XY(c), *pb = SWAP_XY(d); +} + +void OVERLOAD pairSq(u32 N, GF31 *u, GF31 *v, GF31 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(mul2(foo(u[i]))); + v[i] = SWAP_XY(shl(csq(v[i]), 2)); + } else { + onePairSq(&u[i], &v[i], base_squared, 0); + } + + if (N == NH) { + onePairSq(&u[i+NH/2], &v[i+NH/2], base_squared, 2); + } + + onePairSq(&u[i+NH/4], &v[i+NH/4], base_squared, 1); + + if (N == NH) { + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], base_squared, 3); + } + } +} + +// The kernel tailSquareZero handles the special cases in tailSquare, i.e. the lines 0 and H/2 +// This kernel is launched with 2 workgroups (handling line 0, resp. H/2) +KERNEL(G_H) tailSquareZeroGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT / 2]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH]; + u32 H = ND / SMALL_HEIGHT; + + // This kernel in executed in two workgroups. + u32 which = get_group_id(0); + assert(which < 2); + + u32 line = which ? (H/2) : 0; + u32 me = get_local_id(0); + readTailFusedLine(in31, u, line, me); + + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. 
+ u32 height_trigs = SMALL_HEIGHT*1; +#if TAIL_TRIGS31 >= 1 + GF31 trig = smallTrig31[height_trigs + me]; +#if SINGLE_WIDE + GF31 mult = smallTrig31[height_trigs + G_H + line]; +#else + GF31 mult = smallTrig31[height_trigs + G_H + which]; +#endif + trig = cmul(trig, mult); +#else +#if SINGLE_WIDE + GF31 trig = NTLOAD(smallTrig31[height_trigs + line*G_H + me]); +#else + GF31 trig = NTLOAD(smallTrig31[height_trigs + which*G_H + me]); +#endif +#endif + + fft_HEIGHT(lds, u, smallTrig31); + reverse(G_H, lds, u + NH/2, !which); + pairSq(NH/2, u, u + NH/2, trig, !which); + reverse(G_H, lds, u + NH/2, !which); + bar(); + fft_HEIGHT(lds, u, smallTrig31); + writeTailFusedLine(u, out31, transPos(line, MIDDLE, WIDTH), me); +} + +#if SINGLE_WIDE + +KERNEL(G_H) tailSquareGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH], v[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); +#else + u32 line1 = get_group_id(0) + 1; + u32 line2 = H - line1; +#endif + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(in31, u, line1, me); + readTailFusedLine(in31, v, line2, me); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + fft_HEIGHT(lds + zerohack, u, smallTrig31 + zerohack); + bar(); + fft_HEIGHT(lds + zerohack, v, smallTrig31 + zerohack); +#else + fft_HEIGHT(lds, u, smallTrig31); + bar(); + fft_HEIGHT(lds, v, smallTrig31); +#endif + + // Do a little bit of memory access and a little bit of math. +#if TAIL_TRIGS31 >= 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached GF31 per line + GF31 trig = smallTrig31[height_trigs + me]; // Trig values for line zero, should be cached + GF31 mult = smallTrig31[height_trigs + G_H + line1]; // Line multiplier + trig = cmul(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + GF31 trig = NTLOAD(smallTrig31[height_trigs + line1*G_H + me]); +#endif + +#if SINGLE_KERNEL + if (line1 == 0) { + // Line 0 is special: it pairs with itself, offseted by 1. + reverse(G_H, lds, u + NH/2, true); + pairSq(NH/2, u, u + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + // Line H/2 also pairs with itself (but without offset). 
+ GF31 trig2 = cmul(trig, TAILTGF31); + reverse(G_H, lds, v + NH/2, false); + pairSq(NH/2, v, v + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } + else { +#else + if (1) { +#endif + reverseLine(G_H, lds, v); + pairSq(NH, u, v, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrig31); + bar(); + fft_HEIGHT(lds, u, smallTrig31); + + writeTailFusedLine(v, out31, memline2, me); + writeTailFusedLine(u, out31, memline1, me); +} + + +// +// Create a kernel that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// We hope to get better occupancy with the reduced register usage +// + +#else + +// Special pairSq for double-wide line 0 +void OVERLOAD pairSq2_special(GF31 *u, GF31 base_squared) { + u32 me = get_local_id(0); + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (i == 0 && me == 0) { + u[0] = SWAP_XY(mul2(foo(u[0]))); + u[NH/2] = SWAP_XY(shl(csq(u[NH/2]), 2)); + } else { + onePairSq(&u[i], &u[NH/2+i], base_squared, 0); + } + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], base_squared, 1); + } +} + +KERNEL(G_H * 2) tailSquareGF31(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF31 lds[SMALL_HEIGHT]; + + CP(GF31) in31 = (CP(GF31)) (in + DISTGF31); + P(GF31) out31 = (P(GF31)) (out + DISTGF31); + TrigGF31 smallTrig31 = (TrigGF31) (smallTrig + DISTHTRIGGF31); + + GF31 u[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line_u = get_group_id(0); + u32 line_v = line_u ? H - line_u : (H / 2); +#else + u32 line_u = get_group_id(0) + 1; + u32 line_v = H - line_u; +#endif + + u32 me = get_local_id(0); + u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups). + + // We're going to call the halves "first-half" and "second-half". + bool isSecondHalf = me >= G_H; + + u32 line = !isSecondHalf ? line_u : line_v; + + // Read lines u and v + readTailFusedLine(in31, u, line, lowMe); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + new_fft_HEIGHT2_1(lds + zerohack, u, smallTrig31 + zerohack); +#else + new_fft_HEIGHT2_1(lds, u, smallTrig31); +#endif + + // Do a little bit of memory access and a little bit of math. Good on a Radeon VII. +#if TAIL_TRIGS31 >= 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached GF31 per line + GF31 trig = smallTrig31[height_trigs + lowMe]; // Trig values for line zero, should be cached + GF31 mult = smallTrig31[height_trigs + G_H + line_u*2 + isSecondHalf]; // Two multipliers. One for line u, one for line v. + trig = cmul(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + GF31 trig = NTLOAD(smallTrig31[height_trigs + line_u*G_H*2 + me]); +#endif + + bar(G_H); + +#if SINGLE_KERNEL + // Line 0 and H/2 are special: they pair with themselves, line 0 is offseted by 1. 
+ if (line_u == 0) { + reverse2(lds, u); + pairSq2_special(u, trig); + reverse2(lds, u); + } + else { +#else + if (1) { +#endif + revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf); + pairSq(NH/2, u, u + NH/2, trig, false); + bar(G_H); + revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf); + } + + bar(G_H); + + new_fft_HEIGHT2_2(lds, u, smallTrig31); + + // Write lines u and v + writeTailFusedLine(u, out31, transPos(line, MIDDLE, WIDTH), lowMe); +} + +#endif + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD onePairSq(GF61* pa, GF61* pb, GF61 t_squared, const u32 t_squared_type) { + GF61 a = *pa, b = *pb; + GF61 c, d; + + X2conjb(a, b); + if (t_squared_type == 0) // mul t_squared by 1 + c = subq(csqq(a, 2), cmul(csq(b), t_squared), 2); // max c value is 4*M61+epsilon + if (t_squared_type == 1) // mul t_squared by i + c = subiq(csqq(a, 2), cmul(csq(b), t_squared), 2); // max c value is 4*M61+epsilon + if (t_squared_type == 2) // mul t_squared by -1 + c = addq(csqq(a, 2), cmul(csq(b), t_squared)); // max c value is 3*M61+epsilon + if (t_squared_type == 3) // mul t_squared by -i + c = addiq(csqq(a, 2), cmul(csq(b), t_squared), 2); // max c value is 4*M61+epsilon + d = 2 * cmul(a, b); // max d value is 2*M61+epsilon + X2s_conjb(&c, &d, 5, 3); + *pa = SWAP_XY(c), *pb = SWAP_XY(d); +} + +void OVERLOAD pairSq(u32 N, GF61 *u, GF61 *v, GF61 base_squared, bool special) { + u32 me = get_local_id(0); + + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (special && i == 0 && me == 0) { + u[i] = SWAP_XY(mul2(foo(u[i]))); + v[i] = SWAP_XY(shl(csq(v[i]), 2)); + } else { + onePairSq(&u[i], &v[i], base_squared, 0); + } + + if (N == NH) { + onePairSq(&u[i+NH/2], &v[i+NH/2], base_squared, 2); + } + + onePairSq(&u[i+NH/4], &v[i+NH/4], base_squared, 1); + + if (N == NH) { + onePairSq(&u[i+3*NH/4], &v[i+3*NH/4], base_squared, 3); + } + } +} + +// The kernel tailSquareZero handles the special cases in tailSquare, i.e. the lines 0 and H/2 +// This kernel is launched with 2 workgroups (handling line 0, resp. H/2) +KERNEL(G_H) tailSquareZeroGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT / 2]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH]; + u32 H = ND / SMALL_HEIGHT; + + // This kernel in executed in two workgroups. + u32 which = get_group_id(0); + assert(which < 2); + + u32 line = which ? (H/2) : 0; + u32 me = get_local_id(0); + readTailFusedLine(in61, u, line, me); + + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. 
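+  // (Hedged aside: there is no compute-from-scratch option in the ladder below, unlike
+  // the FP64 TAIL_TRIGS == 2 path; per the defaults in tailutil.cl, GF61 uses
+  // TAIL_TRIGS61 = 0 and simply streams the pre-computed roots, while TAIL_TRIGS61 >= 1
+  // trades one extra cmul for a much smaller, mostly cached table.)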
+ u32 height_trigs = SMALL_HEIGHT*1; +#if TAIL_TRIGS61 >= 1 + GF61 trig = smallTrig61[height_trigs + me]; +#if SINGLE_WIDE + GF61 mult = smallTrig61[height_trigs + G_H + line]; +#else + GF61 mult = smallTrig61[height_trigs + G_H + which]; +#endif + trig = cmul(trig, mult); +#else +#if SINGLE_WIDE + GF61 trig = NTLOAD(smallTrig61[height_trigs + line*G_H + me]); +#else + GF61 trig = NTLOAD(smallTrig61[height_trigs + which*G_H + me]); +#endif +#endif + + fft_HEIGHT(lds, u, smallTrig61); + reverse(G_H, lds, u + NH/2, !which); + pairSq(NH/2, u, u + NH/2, trig, !which); + reverse(G_H, lds, u + NH/2, !which); + bar(); + fft_HEIGHT(lds, u, smallTrig61); + writeTailFusedLine(u, out61, transPos(line, MIDDLE, WIDTH), me); +} + +#if SINGLE_WIDE + +KERNEL(G_H) tailSquareGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH], v[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line1 = get_group_id(0); + u32 line2 = line1 ? H - line1 : (H / 2); +#else + u32 line1 = get_group_id(0) + 1; + u32 line2 = H - line1; +#endif + u32 memline1 = transPos(line1, MIDDLE, WIDTH); + u32 memline2 = transPos(line2, MIDDLE, WIDTH); + + u32 me = get_local_id(0); + readTailFusedLine(in61, u, line1, me); + readTailFusedLine(in61, v, line2, me); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + fft_HEIGHT(lds + zerohack, u, smallTrig61 + zerohack); + bar(); + fft_HEIGHT(lds + zerohack, v, smallTrig61 + zerohack); +#else + fft_HEIGHT(lds, u, smallTrig61); + bar(); + fft_HEIGHT(lds, v, smallTrig61); +#endif + + // Do a little bit of memory access and a little bit of math. +#if TAIL_TRIGS61 >= 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached GF61 per line + GF61 trig = smallTrig61[height_trigs + me]; // Trig values for line zero, should be cached + GF61 mult = smallTrig61[height_trigs + G_H + line1]; // Line multiplier + trig = cmul(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + GF61 trig = NTLOAD(smallTrig61[height_trigs + line1*G_H + me]); +#endif + +#if SINGLE_KERNEL + if (line1 == 0) { + // Line 0 is special: it pairs with itself, offseted by 1. + reverse(G_H, lds, u + NH/2, true); + pairSq(NH/2, u, u + NH/2, trig, true); + reverse(G_H, lds, u + NH/2, true); + + // Line H/2 also pairs with itself (but without offset). 
+ GF61 trig2 = cmul(trig, TAILTGF61); + reverse(G_H, lds, v + NH/2, false); + pairSq(NH/2, v, v + NH/2, trig2, false); + reverse(G_H, lds, v + NH/2, false); + } + else { +#else + if (1) { +#endif + reverseLine(G_H, lds, v); + pairSq(NH, u, v, trig, false); + reverseLine(G_H, lds, v); + } + + bar(); + fft_HEIGHT(lds, v, smallTrig61); + bar(); + fft_HEIGHT(lds, u, smallTrig61); + + writeTailFusedLine(v, out61, memline2, me); + writeTailFusedLine(u, out61, memline1, me); +} + + +// +// Create a kernel that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// We hope to get better occupancy with the reduced register usage +// + +#else + +// Special pairSq for double-wide line 0 +void OVERLOAD pairSq2_special(GF61 *u, GF61 base_squared) { + u32 me = get_local_id(0); + for (i32 i = 0; i < NH / 4; ++i, base_squared = mul_t8(base_squared)) { + if (i == 0 && me == 0) { + u[0] = SWAP_XY(mul2(foo(u[0]))); + u[NH/2] = SWAP_XY(shl(csq(u[NH/2]), 2)); + } else { + onePairSq(&u[i], &u[NH/2+i], base_squared, 0); + } + onePairSq(&u[i+NH/4], &u[NH/2+i+NH/4], base_squared, 1); + } +} + +KERNEL(G_H * 2) tailSquareGF61(P(T2) out, CP(T2) in, Trig smallTrig) { + local GF61 lds[SMALL_HEIGHT]; + + CP(GF61) in61 = (CP(GF61)) (in + DISTGF61); + P(GF61) out61 = (P(GF61)) (out + DISTGF61); + TrigGF61 smallTrig61 = (TrigGF61) (smallTrig + DISTHTRIGGF61); + + GF61 u[NH]; + + u32 H = ND / SMALL_HEIGHT; + +#if SINGLE_KERNEL + u32 line_u = get_group_id(0); + u32 line_v = line_u ? H - line_u : (H / 2); +#else + u32 line_u = get_group_id(0) + 1; + u32 line_v = H - line_u; +#endif + + u32 me = get_local_id(0); + u32 lowMe = me % G_H; // lane-id in one of the two halves (half-workgroups). + + // We're going to call the halves "first-half" and "second-half". + bool isSecondHalf = me >= G_H; + + u32 line = !isSecondHalf ? line_u : line_v; + + // Read lines u and v + readTailFusedLine(in61, u, line, lowMe); + +#if ZEROHACK_H + u32 zerohack = (u32) get_group_id(0) / 131072; + new_fft_HEIGHT2_1(lds + zerohack, u, smallTrig61 + zerohack); +#else + new_fft_HEIGHT2_1(lds, u, smallTrig61); +#endif + + // Do a little bit of memory access and a little bit of math. Good on a Radeon VII. +#if TAIL_TRIGS61 >= 1 + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read a hopefully cached line of data and one non-cached GF61 per line + GF61 trig = smallTrig61[height_trigs + lowMe]; // Trig values for line zero, should be cached + GF61 mult = smallTrig61[height_trigs + G_H + line_u*2 + isSecondHalf]; // Two multipliers. One for line u, one for line v. + trig = cmul(trig, mult); + + // On consumer-grade GPUs, it is likely beneficial to read all trig values. +#else + // Calculate number of trig values used by fft_HEIGHT (see genSmallTrigCombo in trigBufCache.cpp) + // The trig values used here are pre-computed and stored after the fft_HEIGHT trig values. + u32 height_trigs = SMALL_HEIGHT*1; + // Read pre-computed trig values + GF61 trig = NTLOAD(smallTrig61[height_trigs + line_u*G_H*2 + me]); +#endif + + bar(G_H); + +#if SINGLE_KERNEL + // Line 0 and H/2 are special: they pair with themselves, line 0 is offseted by 1. 
+ if (line_u == 0) { + reverse2(lds, u); + pairSq2_special(u, trig); + reverse2(lds, u); + } + else { +#else + if (1) { +#endif + revCrossLine(G_H, lds, u + NH/2, NH/2, isSecondHalf); + pairSq(NH/2, u, u + NH/2, trig, false); + bar(G_H); + revCrossLine(G_H, lds, u + NH/2, NH/2, !isSecondHalf); + } + + bar(G_H); + + new_fft_HEIGHT2_2(lds, u, smallTrig61); + + // Write lines u and v + writeTailFusedLine(u, out61, transPos(line, MIDDLE, WIDTH), lowMe); +} + +#endif + +#endif diff --git a/src/cl/tailutil.cl b/src/cl/tailutil.cl index a33f280a..01710cf1 100644 --- a/src/cl/tailutil.cl +++ b/src/cl/tailutil.cl @@ -2,7 +2,37 @@ #include "math.cl" -void reverse(u32 WG, local T2 *lds, T2 *u, bool bump) { +// TAIL_TRIGS setting: +// 2 = No memory accesses, trig values computed from scratch. Good for excellent DP GPUs such as Titan V or Radeon VII Pro. +// 1 = Limited memory accesses and some DP computation. Tuned for Radeon VII a GPU with good DP performance. +// 0 = No DP computation. Trig vaules read from memory. Good for GPUs with poor DP performance (a typical consumer grade GPU). +#if !defined(TAIL_TRIGS) +#define TAIL_TRIGS 2 // Default is compute trig values from scratch for FP64 +#endif +#if !defined(TAIL_TRIGS31) +#define TAIL_TRIGS31 0 // Default is read all trig values from memory for GF31 +#endif +#if !defined(TAIL_TRIGS32) +#define TAIL_TRIGS32 2 // Default is compute trig values from scratch for FP32 +#endif +#if !defined(TAIL_TRIGS61) +#define TAIL_TRIGS61 0 // Default is read all trig values from memory for GF61 +#endif + +// TAIL_KERNELS setting: +// 0 = single wide, single kernel +// 1 = single wide, two kernels +// 2 = double wide, single kernel +// 3 = double wide, two kernels +#if !defined(TAIL_KERNELS) +#define TAIL_KERNELS 2 // Default is double-wide tailSquare with two kernels +#endif +#define SINGLE_WIDE TAIL_KERNELS < 2 // Old single-wide tailSquare vs. new double-wide tailSquare +#define SINGLE_KERNEL (TAIL_KERNELS & 1) == 0 // TailSquare uses a single kernel vs. two kernels + +#if FFT_FP64 + +void OVERLOAD reverse(u32 WG, local T2 *lds, T2 *u, bool bump) { u32 me = get_local_id(0); u32 revMe = WG - 1 - me + bump; @@ -24,7 +54,7 @@ void reverse(u32 WG, local T2 *lds, T2 *u, bool bump) { for (i32 i = 0; i < NH/2; ++i) { u[i] = lds[i * WG + me]; } } -void reverseLine(u32 WG, local T2 *lds2, T2 *u) { +void OVERLOAD reverseLine(u32 WG, local T2 *lds2, T2 *u) { u32 me = get_local_id(0); u32 revMe = WG - 1 - me; @@ -38,7 +68,7 @@ void reverseLine(u32 WG, local T2 *lds2, T2 *u) { } // This is used to reverse the second part of a line, and cross the reversed parts between the halves. -void revCrossLine(u32 WG, local T2* lds2, T2 *u, u32 n, bool writeSecondHalf) { +void OVERLOAD revCrossLine(u32 WG, local T2* lds2, T2 *u, u32 n, bool writeSecondHalf) { u32 me = get_local_id(0); u32 lowMe = me % WG; @@ -51,23 +81,11 @@ void revCrossLine(u32 WG, local T2* lds2, T2 *u, u32 n, bool writeSecondHalf) { for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; } } -// computes 2*(a.x*b.x+a.y*b.y) + i*2*(a.x*b.y+a.y*b.x) -// which happens to be the cyclical convolution (a.x, a.y)x(b.x, b.y) * 2 -T2 foo2(T2 a, T2 b) { - a = addsub(a); - b = addsub(b); - return addsub(U2(RE(a) * RE(b), IM(a) * IM(b))); -} - -// computes 2*[x^2+y^2 + i*(2*x*y)]. i.e. 
2 * cyclical autoconvolution of (x, y) -T2 foo(T2 a) { return foo2(a, a); } - - // // These versions are for the kernel(s) that uses a double-wide workgroup (u in half the workgroup, v in the other half) // -void reverse2(local T2 *lds, T2 *u) { +void OVERLOAD reverse2(local T2 *lds, T2 *u) { u32 me = get_local_id(0); // For NH=8, u[0] to u[3] are left unchanged. Write to lds: @@ -98,7 +116,7 @@ void reverse2(local T2 *lds, T2 *u) { // u[2] u[3] // Returned in u[1] // v[3]rev v[2]rev // Returned in u[2] // v[1]rev v[0]rev // Returned in u[3] -void reverseLine2(local T2 *lds, T2 *u) { +void OVERLOAD reverseLine2(local T2 *lds, T2 *u) { u32 me = get_local_id(0); // NOTE: It is important that this routine use lds memory in coordination with shufl2. Failure to do so would require an @@ -134,7 +152,7 @@ void reverseLine2(local T2 *lds, T2 *u) { } // Undo a reverseLine2 -void unreverseLine2(local T2 *lds, T2 *u) { +void OVERLOAD unreverseLine2(local T2 *lds, T2 *u) { u32 me = get_local_id(0); // NOTE: It is important that this routine use lds memory in coordination with reverseLine2 and shufl2. By initially @@ -170,3 +188,503 @@ void unreverseLine2(local T2 *lds, T2 *u) { for (u32 i = 0; i < NH; ++i, ldsIn += ldsInc) { u[i].x = ldsIn[0]; u[i].y = ldsIn[NH*2*G_H]; } #endif } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +void OVERLOAD reverse(u32 WG, local F2 *lds, F2 *u, bool bump) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me + bump; + + bar(); + +#if NH == 8 + lds[revMe + 0 * WG] = u[3]; + lds[revMe + 1 * WG] = u[2]; + lds[revMe + 2 * WG] = u[1]; + lds[bump ? ((revMe + 3 * WG) % (4 * WG)) : (revMe + 3 * WG)] = u[0]; +#elif NH == 4 + lds[revMe + 0 * WG] = u[1]; + lds[bump ? ((revMe + WG) % (2 * WG)) : (revMe + WG)] = u[0]; +#else +#error +#endif + + bar(); + for (i32 i = 0; i < NH/2; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD reverseLine(u32 WG, local F2 *lds2, F2 *u) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me; + + local F2 *lds = lds2 + revMe; + bar(); + for (u32 i = 0; i < NH; ++i) { lds[WG * (NH - 1 - i)] = u[i]; } + + lds = lds2 + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[WG * i]; } +} + +// This is used to reverse the second part of a line, and cross the reversed parts between the halves. +void OVERLOAD revCrossLine(u32 WG, local F2* lds2, F2 *u, u32 n, bool writeSecondHalf) { + u32 me = get_local_id(0); + u32 lowMe = me % WG; + + u32 revLowMe = WG - 1 - lowMe; + + for (u32 i = 0; i < n; ++i) { lds2[WG * n * writeSecondHalf + WG * (n - 1 - i) + revLowMe] = u[i]; } + + bar(); // we need a full bar because we're crossing halves + + for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; } +} + +// +// These versions are for the kernel(s) that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// + +void OVERLOAD reverse2(local F2 *lds, F2 *u) { + u32 me = get_local_id(0); + + // For NH=8, u[0] to u[3] are left unchanged. Write to lds: + // u[7]rev u[6]rev + // u[5]rev u[4]rev + // v[7]rev v[6]rev + // v[5]rev v[4]rev + bar(); + for (u32 i = 0; i < NH / 2; ++i) { + u32 j = (i * G_H + me % G_H); + lds[me < G_H ? 
((NH/2)*G_H - j) % ((NH/2)*G_H) : NH*G_H-1 - j] = u[NH/2 + i]; + } + // For NH=8, read from lds into u[i]: + // u[4] = u[7]rev v[7]rev + // u[5] = u[6]rev v[6]rev + // u[6] = u[5]rev v[5]rev + // u[7] = u[4]rev v[4]rev + bar(); + lds += me % G_H + (me / G_H) * NH/2 * G_H; + for (u32 i = 0; i < NH / 2; ++i) { u[NH/2 + i] = lds[i * G_H]; } +} + +// Somewhat similar to reverseLine. +// The u values are in threads < G_H, the v values to reverse in threads >= G_H. +// Whereas reverseLine leaves u values alone. This reverseLine moves u values around +// so that pairSq2 can easily operate on pairs. This means for NH = 4, web output: +// u[0] u[1] // Returned in u[0] +// u[2] u[3] // Returned in u[1] +// v[3]rev v[2]rev // Returned in u[2] +// v[1]rev v[0]rev // Returned in u[3] +void OVERLOAD reverseLine2(local F2 *lds, F2 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with shufl2. Failure to do so would require an +// unqualified bar() call here. Specifically, the u values are stored in the upper half of lds memory (SMALL_HEIGHT F2 values). +// The v values are stored in the lower half of lds memory (the next SMALL_HEIGHT F2 values). + + if (G_H > WAVEFRONT) bar(); + +// For NH=4, the lds indices (where to write each incoming u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H +// That means saving to lds using index: me < G_H ? me % G_H + i * G_H : 8*G_H-1 - me % G_H - i * G_H + +#if 1 + local F2 *ldsOut = lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { *ldsOut = u[i]; } + + lds += me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[i * 2*G_H]; } +#else + local F *ldsOut = (local F *) lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { ldsOut[0] = u[i].x; ldsOut[NH*2*G_H] = u[i].y; } + + local F *ldsIn = (local F *) lds + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i].x = ldsIn[i * 2*G_H]; u[i].y = ldsIn[NH*2*G_H + i * 2*G_H]; } +#endif +} + +// Undo a reverseLine2 +void OVERLOAD unreverseLine2(local F2 *lds, F2 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with reverseLine2 and shufl2. By initially +// writing to the lds locations that reverseLine2 read from we do not need an initial bar() call here. Also, by reading +// from the lds locations that shufl2 will use (u values in the upper half of lds memory, v values in the lower half of +// lds memory) we can issue a qualified bar() call before calling FFT_HEIGHT2. + +#if 1 + local F2 *ldsOut = lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i]; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + lds += (me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H; + i32 ldsInc = (me < G_H) ? 
G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, lds += ldsInc) { u[i] = *lds; } +#else + local F *ldsOut = (local F *) lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i].x; ldsOut[NH*2*G_H + i * 2*G_H] = u[i].y; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + local F *ldsIn = (local F *) lds + ((me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, ldsIn += ldsInc) { u[i].x = ldsIn[0]; u[i].y = ldsIn[NH*2*G_H]; } +#endif +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M31^2) */ +/**************************************************************************/ + +#if NTT_GF31 + +void OVERLOAD reverse(u32 WG, local GF31 *lds, GF31 *u, bool bump) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me + bump; + + bar(); + +#if NH == 8 + lds[revMe + 0 * WG] = u[3]; + lds[revMe + 1 * WG] = u[2]; + lds[revMe + 2 * WG] = u[1]; + lds[bump ? ((revMe + 3 * WG) % (4 * WG)) : (revMe + 3 * WG)] = u[0]; +#elif NH == 4 + lds[revMe + 0 * WG] = u[1]; + lds[bump ? ((revMe + WG) % (2 * WG)) : (revMe + WG)] = u[0]; +#else +#error +#endif + + bar(); + for (i32 i = 0; i < NH/2; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD reverseLine(u32 WG, local GF31 *lds2, GF31 *u) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me; + + local GF31 *lds = lds2 + revMe; + bar(); + for (u32 i = 0; i < NH; ++i) { lds[WG * (NH - 1 - i)] = u[i]; } + + lds = lds2 + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[WG * i]; } +} + +// This is used to reverse the second part of a line, and cross the reversed parts between the halves. +void OVERLOAD revCrossLine(u32 WG, local GF31* lds2, GF31 *u, u32 n, bool writeSecondHalf) { + u32 me = get_local_id(0); + u32 lowMe = me % WG; + + u32 revLowMe = WG - 1 - lowMe; + + for (u32 i = 0; i < n; ++i) { lds2[WG * n * writeSecondHalf + WG * (n - 1 - i) + revLowMe] = u[i]; } + + bar(); // we need a full bar because we're crossing halves + + for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; } +} + +// +// These versions are for the kernel(s) that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// + +void OVERLOAD reverse2(local GF31 *lds, GF31 *u) { + u32 me = get_local_id(0); + + // For NH=8, u[0] to u[3] are left unchanged. Write to lds: + // u[7]rev u[6]rev + // u[5]rev u[4]rev + // v[7]rev v[6]rev + // v[5]rev v[4]rev + bar(); + for (u32 i = 0; i < NH / 2; ++i) { + u32 j = (i * G_H + me % G_H); + lds[me < G_H ? ((NH/2)*G_H - j) % ((NH/2)*G_H) : NH*G_H-1 - j] = u[NH/2 + i]; + } + // For NH=8, read from lds into u[i]: + // u[4] = u[7]rev v[7]rev + // u[5] = u[6]rev v[6]rev + // u[6] = u[5]rev v[5]rev + // u[7] = u[4]rev v[4]rev + bar(); + lds += me % G_H + (me / G_H) * NH/2 * G_H; + for (u32 i = 0; i < NH / 2; ++i) { u[NH/2 + i] = lds[i * G_H]; } +} + +// Somewhat similar to reverseLine. +// The u values are in threads < G_H, the v values to reverse in threads >= G_H. +// Whereas reverseLine leaves u values alone. This reverseLine moves u values around +// so that pairSq2 can easily operate on pairs. 
This means for NH = 4, web output: +// u[0] u[1] // Returned in u[0] +// u[2] u[3] // Returned in u[1] +// v[3]rev v[2]rev // Returned in u[2] +// v[1]rev v[0]rev // Returned in u[3] +void OVERLOAD reverseLine2(local GF31 *lds, GF31 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with shufl2. Failure to do so would require an +// unqualified bar() call here. Specifically, the u values are stored in the upper half of lds memory (SMALL_HEIGHT GF31 values). +// The v values are stored in the lower half of lds memory (the next SMALL_HEIGHT GF31 values). + + if (G_H > WAVEFRONT) bar(); + +// For NH=4, the lds indices (where to write each incoming u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H +// That means saving to lds using index: me < G_H ? me % G_H + i * G_H : 8*G_H-1 - me % G_H - i * G_H + +#if 1 + local GF31 *ldsOut = lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { *ldsOut = u[i]; } + + lds += me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[i * 2*G_H]; } +#else + local Z61 *ldsOut = (local Z61 *) lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { ldsOut[0] = u[i].x; ldsOut[NH*2*G_H] = u[i].y; } + + local ZF61 *ldsIn = (local T *) lds + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i].x = ldsIn[i * 2*G_H]; u[i].y = ldsIn[NH*2*G_H + i * 2*G_H]; } +#endif +} + +// Undo a reverseLine2 +void OVERLOAD unreverseLine2(local GF31 *lds, GF31 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with reverseLine2 and shufl2. By initially +// writing to the lds locations that reverseLine2 read from we do not need an initial bar() call here. Also, by reading +// from the lds locations that shufl2 will use (u values in the upper half of lds memory, v values in the lower half of +// lds memory) we can issue a qualified bar() call before calling FFT_HEIGHT2. + +#if 1 + local GF31 *ldsOut = lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i]; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + lds += (me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H; + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, lds += ldsInc) { u[i] = *lds; } +#else + local Z61 *ldsOut = (local T *) lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i].x; ldsOut[NH*2*G_H + i * 2*G_H] = u[i].y; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + local Z61 *ldsIn = (local T *) lds + ((me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsInc = (me < G_H) ? 
G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, ldsIn += ldsInc) { u[i].x = ldsIn[0]; u[i].y = ldsIn[NH*2*G_H]; } +#endif +} + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an NTT based on GF(M61^2) */ +/**************************************************************************/ + +#if NTT_GF61 + +void OVERLOAD reverse(u32 WG, local GF61 *lds, GF61 *u, bool bump) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me + bump; + + bar(); + +#if NH == 8 + lds[revMe + 0 * WG] = u[3]; + lds[revMe + 1 * WG] = u[2]; + lds[revMe + 2 * WG] = u[1]; + lds[bump ? ((revMe + 3 * WG) % (4 * WG)) : (revMe + 3 * WG)] = u[0]; +#elif NH == 4 + lds[revMe + 0 * WG] = u[1]; + lds[bump ? ((revMe + WG) % (2 * WG)) : (revMe + WG)] = u[0]; +#else +#error +#endif + + bar(); + for (i32 i = 0; i < NH/2; ++i) { u[i] = lds[i * WG + me]; } +} + +void OVERLOAD reverseLine(u32 WG, local GF61 *lds2, GF61 *u) { + u32 me = get_local_id(0); + u32 revMe = WG - 1 - me; + + local GF61 *lds = lds2 + revMe; + bar(); + for (u32 i = 0; i < NH; ++i) { lds[WG * (NH - 1 - i)] = u[i]; } + + lds = lds2 + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[WG * i]; } +} + +// This is used to reverse the second part of a line, and cross the reversed parts between the halves. +void OVERLOAD revCrossLine(u32 WG, local GF61* lds2, GF61 *u, u32 n, bool writeSecondHalf) { + u32 me = get_local_id(0); + u32 lowMe = me % WG; + + u32 revLowMe = WG - 1 - lowMe; + + for (u32 i = 0; i < n; ++i) { lds2[WG * n * writeSecondHalf + WG * (n - 1 - i) + revLowMe] = u[i]; } + + bar(); // we need a full bar because we're crossing halves + + for (u32 i = 0; i < n; ++i) { u[i] = lds2[WG * n * !writeSecondHalf + WG * i + lowMe]; } +} + +// +// These versions are for the kernel(s) that uses a double-wide workgroup (u in half the workgroup, v in the other half) +// + +void OVERLOAD reverse2(local GF61 *lds, GF61 *u) { + u32 me = get_local_id(0); + + // For NH=8, u[0] to u[3] are left unchanged. Write to lds: + // u[7]rev u[6]rev + // u[5]rev u[4]rev + // v[7]rev v[6]rev + // v[5]rev v[4]rev + bar(); + for (u32 i = 0; i < NH / 2; ++i) { + u32 j = (i * G_H + me % G_H); + lds[me < G_H ? ((NH/2)*G_H - j) % ((NH/2)*G_H) : NH*G_H-1 - j] = u[NH/2 + i]; + } + // For NH=8, read from lds into u[i]: + // u[4] = u[7]rev v[7]rev + // u[5] = u[6]rev v[6]rev + // u[6] = u[5]rev v[5]rev + // u[7] = u[4]rev v[4]rev + bar(); + lds += me % G_H + (me / G_H) * NH/2 * G_H; + for (u32 i = 0; i < NH / 2; ++i) { u[NH/2 + i] = lds[i * G_H]; } +} + +// Somewhat similar to reverseLine. +// The u values are in threads < G_H, the v values to reverse in threads >= G_H. +// Whereas reverseLine leaves u values alone. This reverseLine moves u values around +// so that pairSq2 can easily operate on pairs. This means for NH = 4, web output: +// u[0] u[1] // Returned in u[0] +// u[2] u[3] // Returned in u[1] +// v[3]rev v[2]rev // Returned in u[2] +// v[1]rev v[0]rev // Returned in u[3] +void OVERLOAD reverseLine2(local GF61 *lds, GF61 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with shufl2. Failure to do so would require an +// unqualified bar() call here. Specifically, the u values are stored in the upper half of lds memory (SMALL_HEIGHT GF61 values). +// The v values are stored in the lower half of lds memory (the next SMALL_HEIGHT GF61 values). 
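+// As a concrete illustration of the write pattern used below (illustrative sizes only,
+// assuming G_H = 256 and NH = 4, so the routine addresses 2*NH*G_H = 2048 GF61 slots as per
+// the note above): thread me = 44 (a "u" thread) writes u[0..3] to lds[44], lds[300],
+// lds[556], lds[812] (start me % G_H, stride +G_H), while thread me = 300 (a "v" thread,
+// lowMe = 44) writes its u[0..3] to lds[2003], lds[1747], lds[1491], lds[1235]
+// (start 2*NH*G_H - 1 - lowMe, stride -G_H), laying the reversed v line down next to the u line.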
+ + if (G_H > WAVEFRONT) bar(); + +// For NH=4, the lds indices (where to write each incoming u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H +// That means saving to lds using index: me < G_H ? me % G_H + i * G_H : 8*G_H-1 - me % G_H - i * G_H + +#if 1 + local GF61 *ldsOut = lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { *ldsOut = u[i]; } + + lds += me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i] = lds[i * 2*G_H]; } +#else + local Z61 *ldsOut = (local Z61 *) lds + (me < G_H ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsOutInc = (me < G_H) ? G_H : -G_H; + for (u32 i = 0; i < NH; ++i, ldsOut += ldsOutInc) { ldsOut[0] = u[i].x; ldsOut[NH*2*G_H] = u[i].y; } + + local ZF61 *ldsIn = (local T *) lds + me; + bar(); + for (u32 i = 0; i < NH; ++i) { u[i].x = ldsIn[i * 2*G_H]; u[i].y = ldsIn[NH*2*G_H + i * 2*G_H]; } +#endif +} + +// Undo a reverseLine2 +void OVERLOAD unreverseLine2(local GF61 *lds, GF61 *u) { + u32 me = get_local_id(0); + +// NOTE: It is important that this routine use lds memory in coordination with reverseLine2 and shufl2. By initially +// writing to the lds locations that reverseLine2 read from we do not need an initial bar() call here. Also, by reading +// from the lds locations that shufl2 will use (u values in the upper half of lds memory, v values in the lower half of +// lds memory) we can issue a qualified bar() call before calling FFT_HEIGHT2. + +#if 1 + local GF61 *ldsOut = lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i]; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + lds += (me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H; + i32 ldsInc = (me < G_H) ? G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, lds += ldsInc) { u[i] = *lds; } +#else + local Z61 *ldsOut = (local T *) lds + me; + for (u32 i = 0; i < NH; ++i) { ldsOut[i * 2*G_H] = u[i].x; ldsOut[NH*2*G_H + i * 2*G_H] = u[i].y; } + +// For NH=4, the lds indices (where to read each outgoing u[i] which has v[i] in the upper threads) looks like this: +// 0..GH-1 +0*G_H GH-1..0 +7*G_H +// 0..GH-1 +1*G_H GH-1..0 +6*G_H +// 0..GH-1 +2*G_H GH-1..0 +5*G_H +// 0..GH-1 +3*G_H GH-1..0 +4*G_H + local Z61 *ldsIn = (local T *) lds + ((me < G_H) ? me % G_H : (NH*2)*G_H-1 - me % G_H); + i32 ldsInc = (me < G_H) ? 
G_H : -G_H; + bar(); + for (u32 i = 0; i < NH; ++i, ldsIn += ldsInc) { u[i].x = ldsIn[0]; u[i].y = ldsIn[NH*2*G_H]; } +#endif +} + +#endif diff --git a/src/cl/transpose.cl b/src/cl/transpose.cl index be999393..8a1c9dae 100644 --- a/src/cl/transpose.cl +++ b/src/cl/transpose.cl @@ -2,8 +2,7 @@ #include "base.cl" -// Prototypes -void transposeWords(u32 W, u32 H, local Word2 *lds, global const Word2 *restrict in, global Word2 *restrict out); +#if WordSize <= 4 void transposeWords(u32 W, u32 H, local Word2 *lds, global const Word2 *restrict in, global Word2 *restrict out) { u32 GPW = W / 64, GPH = H / 64; @@ -38,3 +37,51 @@ KERNEL(64) transposeIn(P(Word2) out, CP(Word2) in) { local Word2 lds[4096]; transposeWords(BIG_HEIGHT, WIDTH, lds, in, out); } + +#else + +void transposeWords(u32 W, u32 H, local Word *lds, global const Word2 *restrict in, global Word2 *restrict out) { + u32 GPW = W / 64, GPH = H / 64; + + u32 g = get_group_id(0); + u32 gy = g % GPH; + u32 gx = g / GPH; + gx = (gy + gx) % GPW; + + in += 64 * W * gy + 64 * gx; + out += 64 * gy + 64 * H * gx; + u32 me = get_local_id(0); + #pragma unroll 1 + for (i32 i = 0; i < 64; ++i) { + lds[i * 64 + me] = in[i * W + me].x; + } + bar(); + #pragma unroll 1 + for (i32 i = 0; i < 64; ++i) { + out[i * H + me].x = lds[me * 64 + i]; + } + bar(); + #pragma unroll 1 + for (i32 i = 0; i < 64; ++i) { + lds[i * 64 + me] = in[i * W + me].y; + } + bar(); + #pragma unroll 1 + for (i32 i = 0; i < 64; ++i) { + out[i * H + me].y = lds[me * 64 + i]; + } +} + +// from transposed to sequential. +KERNEL(64) transposeOut(P(Word2) out, CP(Word2) in) { + local Word lds[4096]; + transposeWords(WIDTH, BIG_HEIGHT, lds, in, out); +} + +// from sequential to transposed. +KERNEL(64) transposeIn(P(Word2) out, CP(Word2) in) { + local Word lds[4096]; + transposeWords(BIG_HEIGHT, WIDTH, lds, in, out); +} + +#endif diff --git a/src/cl/trig.cl b/src/cl/trig.cl index 69d32a2c..ebdd2af9 100644 --- a/src/cl/trig.cl +++ b/src/cl/trig.cl @@ -4,7 +4,9 @@ #include "math.cl" -double2 reducedCosSin(int k, double cosBase) { +#if FFT_FP64 + +T2 reducedCosSin(int k, double cosBase) { const double S[] = TRIG_SIN; const double C[] = TRIG_COS; @@ -36,12 +38,12 @@ double2 reducedCosSin(int k, double cosBase) { return U2(c, s); } -double2 fancyTrig_N(u32 k) { +T2 fancyTrig_N(u32 k) { return reducedCosSin(k, 0); } // Returns e^(i * tau * k / n), (tau == 2*pi represents a full circle). So k/n is the ratio of a full circle. 
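+// Worked example of the octant reduction performed below (illustrative values: n = 1024,
+// k = 900, kBound = 2*n): k < n, so no initial wrap; k >= n/2, so negate and k -= n/2 -> 388;
+// k >= n/4, so negateCos and k = n/2 - k -> 124; k <= n/8, so no flip. reducedCosSin(124, 1)
+// yields e^(i*tau*124/1024); negating the cosine recovers the angle 388/1024, and negating
+// both parts recovers the requested angle 900/1024. The polynomial is thus only ever
+// evaluated on the first octant [0, n/8].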
-double2 slowTrig_N(u32 k, u32 kBound) { +T2 OVERLOAD slowTrig_N(u32 k, u32 kBound) { u32 n = ND; assert(n % 8 == 0); assert(k < kBound); // kBound actually bounds k @@ -61,11 +63,88 @@ double2 slowTrig_N(u32 k, u32 kBound) { assert(k <= n / 8); - double2 r = reducedCosSin(k, 1); + T2 r = reducedCosSin(k, 1); - if (flip) { r = swap(r); } + if (flip) { r = SWAP_XY(r); } if (negateCos) { r.x = -r.x; } if (negate) { r = -r; } return r; } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +F2 reducedCosSin(int k, double cosBase) { + const float S[] = TRIG_SIN; + const float C[] = TRIG_COS; + + float x = k * TRIG_SCALE; + float z = x * x; + + float r1 = fma(S[7], z, S[6]); + float r2 = fma(C[7], z, C[6]); + + r1 = fma(r1, z, S[5]); + r2 = fma(r2, z, C[5]); + + r1 = fma(r1, z, S[4]); + r2 = fma(r2, z, C[4]); + + r1 = fma(r1, z, S[3]); + r2 = fma(r2, z, C[3]); + + r1 = fma(r1, z, S[2]); + r2 = fma(r2, z, C[2]); + + r1 = fma(r1, z, S[1]); + r2 = fma(r2, z, C[1]); + + r1 = r1 * x; + float c = fma(r2, z, (float) cosBase); + float s = fma(x, S[0], r1); + + return U2(c, s); +} + +F2 fancyTrig_N(u32 k) { + return reducedCosSin(k, 0); +} + +// Returns e^(i * tau * k / n), (tau == 2*pi represents a full circle). So k/n is the ratio of a full circle. +F2 OVERLOAD slowTrig_N(u32 k, u32 kBound) { + u32 n = ND; + assert(n % 8 == 0); + assert(k < kBound); // kBound actually bounds k + assert(kBound <= 2 * n); // angle <= 2 tau + + if (kBound > n && k >= n) { k -= n; } + assert(k < n); + + bool negate = kBound > n/2 && k >= n/2; + if (negate) { k -= n/2; } + + bool negateCos = kBound > n / 4 && k >= n / 4; + if (negateCos) { k = n/2 - k; } + + bool flip = kBound > n / 8 + 1 && k > n / 8; + if (flip) { k = n / 4 - k; } + + assert(k <= n / 8); + + F2 r = reducedCosSin(k, 1); + + if (flip) { r = SWAP_XY(r); } + if (negateCos) { r.x = -r.x; } + if (negate) { r = -r; } + + return r; +} + +#endif diff --git a/src/cl/weight.cl b/src/cl/weight.cl index 9967c936..66075bb7 100644 --- a/src/cl/weight.cl +++ b/src/cl/weight.cl @@ -3,6 +3,8 @@ #define STEP (NWORDS - (EXP % NWORDS)) // bool isBigWord(u32 extra) { return extra < NWORDS - STEP; } +#if FFT_FP64 + T fweightStep(u32 i) { const T TWO_TO_NTH[8] = { // 2^(k/8) -1 for k in [0..8) @@ -69,3 +71,92 @@ T optionalHalve(T w) { // return w >= 4 ? w / 2 : w; //u.y = bfi(u.y, 0xffefffff, 0); return as_double(u); } + +#endif + + +/**************************************************************************/ +/* Similar to above, but for an FFT based on FP32 */ +/**************************************************************************/ + +#if FFT_FP32 + +F fweightStep(u32 i) { + const F TWO_TO_NTH[8] = { + // 2^(k/8) -1 for k in [0..8) + 0, + 0.090507732665257662, + 0.18920711500272105, + 0.29683955465100964, + 0.41421356237309503, + 0.54221082540794086, + 0.68179283050742912, + 0.83400808640934243, + }; + return TWO_TO_NTH[i * STEP % NW * (8 / NW)]; +} + +F iweightStep(u32 i) { + const F TWO_TO_MINUS_NTH[8] = { + // 2^-(k/8) - 1 for k in [0..8) + 0, + -0.082995956795328771, + -0.15910358474628547, + -0.2288945872960296, + -0.29289321881345248, + -0.35158022267449518, + -0.40539644249863949, + -0.45474613366737116, + }; + return TWO_TO_MINUS_NTH[i * STEP % NW * (8 / NW)]; +} + +F optionalDouble(F iw) { + // In a straightforward implementation, inverse weights are between 0.5 and 1.0. 
We use inverse weights between 1.0 and 2.0 + // because it allows us to implement this routine with a single OR instruction on the exponent. The original implementation + // where this routine took as input values from 0.25 to 1.0 required both an AND and an OR instruction on the exponent. + // return iw <= 1.0 ? iw * 2 : iw; + assert(iw > 0.5 && iw < 2); + uint u = as_uint(iw); + u |= 0x00800000; + return as_float(u); +} + +F optionalHalve(F w) { // return w >= 4 ? w / 2 : w; + // In a straightforward implementation, weights are between 1.0 and 2.0. We use weights between 2.0 and 4.0 because + // it allows us to implement this routine with a single AND instruction on the exponent. The original implementation + // where this routine took as input values from 1.0 to 4.0 required both an AND and an OR instruction on the exponent. + assert(w >= 2 && w < 8); + uint u = as_uint(w); + u &= 0xFF7FFFFF; + return as_float(u); +} + +#endif + + +/**************************************************************************/ +/* Helper routines for NTT weight calculations */ +/**************************************************************************/ + +#if NTT_GF31 + +// if (weight_shift > 31) weight_shift -= 31; +// This version uses PTX instructions which may be faster on nVidia GPUs +u32 adjust_m31_weight_shift (u32 weight_shift) { + return optional_mod(weight_shift, 31); +} + +#endif + + +#if NTT_GF61 + +// if (weight_shift > 61) weight_shift -= 61; +// This version uses PTX instructions which may be faster on nVidia GPUs +u32 adjust_m61_weight_shift (u32 weight_shift) { + return optional_mod(weight_shift, 61); +} + +#endif + diff --git a/src/clwrap.cpp b/src/clwrap.cpp index 29cdab5d..93385e64 100644 --- a/src/clwrap.cpp +++ b/src/clwrap.cpp @@ -172,6 +172,12 @@ bool isAmdGpu(cl_device_id id) { return pcieId == 0x1002; } +bool isNvidiaGpu(cl_device_id id) { + u32 pcieId = 0; + GET_INFO(id, CL_DEVICE_VENDOR_ID, pcieId); + return pcieId == 0x10DE; +} + /* static string getFreq(cl_device_id device) { unsigned computeUnits, frequency; @@ -357,6 +363,12 @@ EventHolder fillBuf(cl_queue q, vector&& waits, return genEvent ? 
EventHolder{event} : EventHolder{}; } +EventHolder enqueueMarker(cl_queue q) { + cl_event event{}; + CHECK1(clEnqueueMarkerWithWaitList(q, 0, 0, &event)); + return EventHolder{event}; +} + void waitForEvents(vector&& waits) { if (!waits.empty()) { CHECK1(clWaitForEvents(waits.size(), waits.data())); diff --git a/src/clwrap.h b/src/clwrap.h index 919f3a4c..8b9bf9e4 100644 --- a/src/clwrap.h +++ b/src/clwrap.h @@ -62,6 +62,7 @@ float getGpuRamGB(cl_device_id id); u64 getFreeMem(cl_device_id id); bool hasFreeMemInfo(cl_device_id id); bool isAmdGpu(cl_device_id id); +bool isNvidiaGpu(cl_device_id id); string getDriverVersion(cl_device_id id); string getDriverVersionByPos(int pos); @@ -105,6 +106,8 @@ EventHolder copyBuf(cl_queue queue, vector&& waits, const cl_mem src, EventHolder fillBuf(cl_queue q, vector&& waits, cl_mem buf, const void *pat, size_t patSize, size_t size, bool genEvent); +EventHolder enqueueMarker(cl_queue q); + void waitForEvents(vector&& waits); diff --git a/src/common.h b/src/common.h index fbf5deda..516b099d 100644 --- a/src/common.h +++ b/src/common.h @@ -11,6 +11,8 @@ using i32 = int32_t; using u32 = uint32_t; using i64 = int64_t; using u64 = uint64_t; +using i128 = __int128; +using u128 = unsigned __int128; using f128 = __float128; static_assert(sizeof(u8) == 1, "size u8"); @@ -21,6 +23,19 @@ using namespace std; namespace std::filesystem{}; namespace fs = std::filesystem; +// When using multiple primes in an NTT the size of an integer FFT "word" can be 64 bits. Original FP64 FFT needs only 32 bits. +// C code will use i64 integer data. The code that reads and writes GPU buffers will downsize the integers to 32 bits when required. +typedef i64 Word; + +// Create datatype names that mimic the ones used in OpenCL code +using double2 = pair; +using float2 = pair; +using int2 = pair; +using uint = u32; +using uint2 = pair; +using ulong = u64; +using ulong2 = pair; + std::vector split(const string& s, char delim); string hex(u64 x); diff --git a/src/fftbpw.h b/src/fftbpw.h index 1e8712d9..c62db530 100644 --- a/src/fftbpw.h +++ b/src/fftbpw.h @@ -1,3 +1,4 @@ +// FFT64 - Computed by targeting Z=28 { "256:2:256", {19.204, 19.547, 19.636, 19.204, 19.547, 19.636}}, { "256:3:256", {19.106, 19.386, 19.369, 19.106, 19.399, 19.361}}, { "256:4:256", {18.928, 19.236, 19.322, 18.954, 19.272, 19.367}}, @@ -91,3 +92,117 @@ { "4K:14:1K", {16.942, 17.055, 17.160, 17.027, 17.127, 17.306}}, { "4K:15:1K", {17.021, 17.007, 17.137, 17.104, 17.087, 17.282}}, { "4K:16:1K", {16.744, 16.887, 16.966, 16.921, 17.048, 17.208}}, +// FFT3161 - Computed by targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "1:256:2:256", {40.54, 40.54, 40.54, 40.54, 40.54, 40.54}}, +{ "1:256:4:256", {40.19, 40.19, 40.19, 40.19, 40.19, 40.19}}, +{ "1:256:8:256", {39.98, 39.98, 39.98, 39.98, 39.98, 39.98}}, +{ "1:512:4:256", {39.98, 39.98, 39.98, 39.98, 39.98, 39.98}}, +{"1:256:16:256", {39.67, 39.67, 39.67, 39.67, 39.67, 39.67}}, +{ "1:512:8:256", {39.67, 39.67, 39.67, 39.67, 39.67, 39.67}}, +{ "1:512:4:512", {39.67, 39.67, 39.67, 39.67, 39.67, 39.67}}, +{ "1:1K:8:256", {39.46, 39.46, 39.46, 39.46, 39.46, 39.46}}, +{"1:512:16:256", {39.46, 39.46, 39.46, 39.46, 39.46, 39.46}}, +{ "1:512:8:512", {39.46, 39.46, 39.46, 39.46, 39.46, 39.46}}, +{ "1:1K:16:256", {39.15, 39.15, 39.15, 39.15, 39.15, 39.15}}, +{ "1:1K:8:512", {39.15, 39.15, 39.15, 39.15, 39.15, 39.15}}, +{"1:512:16:512", {39.15, 39.15, 39.15, 39.15, 39.15, 39.15}}, +{ "1:1K:16:512", {38.97, 38.97, 38.97, 38.97, 38.97, 38.97}}, +{ 
"1:1K:8:1K", {38.97, 38.97, 38.97, 38.97, 38.97, 38.97}}, +{ "1:1K:16:1K", {38.62, 38.62, 38.62, 38.62, 38.62, 38.62}}, +{ "1:4K:16:512", {38.37, 38.37, 38.37, 38.37, 38.37, 38.37}}, +{ "1:4K:16:1K", {38.12, 38.12, 38.12, 38.12, 38.12, 38.12}}, // Estimated +// FFT3261 - Computed with -use TABMUL_CHAIN32=0,TAIL_TRIGS32=0 and targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "2:256:2:256", {34.57, 34.57, 34.57, 34.57, 34.57, 34.57}}, +{ "2:256:4:256", {34.24, 34.24, 34.24, 34.24, 34.24, 34.24}}, +{ "2:256:8:256", {34.06, 34.06, 34.06, 34.06, 34.06, 34.06}}, +{ "2:512:4:256", {34.06, 34.06, 34.06, 34.06, 34.06, 34.06}}, +{"2:256:16:256", {32.07, 32.07, 32.07, 32.07, 32.07, 32.07}}, +{ "2:512:8:256", {32.07, 32.07, 32.07, 32.07, 32.07, 32.07}}, +{ "2:512:4:512", {32.07, 32.07, 32.07, 32.07, 32.07, 32.07}}, +{ "2:1K:8:256", {31.81, 31.81, 31.81, 31.81, 31.81, 31.81}}, +{"2:512:16:256", {31.81, 31.81, 31.81, 31.81, 31.81, 31.81}}, +{ "2:512:8:512", {31.81, 31.81, 31.81, 31.81, 31.81, 31.81}}, +{ "2:1K:16:256", {31.50, 31.50, 31.50, 31.50, 31.50, 31.50}}, +{ "2:1K:8:512", {31.50, 31.50, 31.50, 31.50, 31.50, 31.50}}, +{"2:512:16:512", {31.50, 31.50, 31.50, 31.50, 31.50, 31.50}}, +{ "2:1K:16:512", {28.68, 28.68, 28.68, 28.68, 28.68, 28.68}}, // Very strange. 481421001 has a maxROE of 0.180, 481422001 has a maxROE of 0.5 +{ "2:1K:8:1K", {28.68, 28.68, 28.68, 28.68, 28.68, 28.68}}, +{ "2:1K:16:1K", {25.37, 25.37, 25.37, 25.37, 25.37, 25.37}}, // Also strange. 851422001 ROEmax=0.273, ROEavg=0.003 +{ "2:4K:16:512", {23.27, 23.27, 23.27, 23.27, 23.27, 23.27}}, +{ "2:4K:16:1K", {21.15, 21.15, 21.15, 21.15, 21.15, 21.15}}, // Estimated +// FFT61 - Computed by targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "3:256:2:256", {25.02, 25.02, 25.02, 25.02, 25.02, 25.02}}, +{ "3:256:4:256", {24.72, 24.10, 24.10, 24.10, 24.10, 24.10}}, +{ "3:256:8:256", {24.46, 24.46, 24.46, 24.46, 24.46, 24.46}}, +{ "3:512:4:256", {24.46, 24.46, 24.46, 24.46, 24.46, 24.46}}, +{"3:256:16:256", {24.15, 24.15, 24.15, 24.15, 24.15, 24.15}}, +{ "3:512:8:256", {24.15, 24.15, 24.15, 24.15, 24.15, 24.15}}, +{ "3:512:4:512", {24.15, 24.15, 24.15, 24.15, 24.15, 24.15}}, +{ "3:1K:8:256", {23.94, 23.94, 23.94, 23.94, 23.94, 23.94}}, +{"3:512:16:256", {23.94, 23.94, 23.94, 23.94, 23.94, 23.94}}, +{ "3:512:8:512", {23.94, 23.94, 23.94, 23.94, 23.94, 23.94}}, +{ "3:1K:16:256", {23.65, 23.65, 23.65, 23.65, 23.65, 23.65}}, +{ "3:1K:8:512", {23.65, 23.65, 23.65, 23.65, 23.65, 23.65}}, +{"3:512:16:512", {23.65, 23.65, 23.65, 23.65, 23.65, 23.65}}, +{ "3:1K:16:512", {23.42, 23.42, 23.42, 23.42, 23.42, 23.42}}, +{ "3:1K:8:1K", {23.42, 23.42, 23.42, 23.42, 23.42, 23.42}}, +{ "3:1K:16:1K", {23.13, 23.13, 23.13, 23.13, 23.13, 23.13}}, +{ "3:4K:16:512", {22.92, 22.92, 22.92, 22.92, 22.92, 22.92}}, +{ "3:4K:16:1K", {22.72, 22.72, 22.72, 22.72, 22.72, 22.72}}, // Estimated +// FFT323161 - Computed with -use TABMUL_CHAIN32=0,TAIL_TRIGS32=0 and targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "4:256:2:256", {50.05, 50.05, 50.05, 50.05, 50.05, 50.05}}, +{ "4:256:4:256", {49.76, 49.76, 49.76, 49.76, 49.76, 49.76}}, +{ "4:256:8:256", {49.59, 49.59, 49.59, 49.59, 49.59, 49.59}}, +{ "4:512:4:256", {49.59, 49.59, 49.59, 49.59, 49.59, 49.59}}, +{"4:256:16:256", {47.59, 47.59, 47.59, 47.59, 47.59, 47.59}}, +{ "4:512:8:256", {47.59, 47.59, 47.59, 47.59, 47.59, 47.59}}, +{ "4:512:4:512", {47.59, 47.59, 47.59, 47.59, 47.59, 47.59}}, +{ "4:1K:8:256", {47.33, 47.33, 47.33, 47.33, 47.33, 
47.33}}, +{"4:512:16:256", {47.33, 47.33, 47.33, 47.33, 47.33, 47.33}}, +{ "4:512:8:512", {47.33, 47.33, 47.33, 47.33, 47.33, 47.33}}, +{ "4:1K:16:256", {47.00, 47.00, 47.00, 47.00, 47.00, 47.00}}, +{ "4:1K:8:512", {47.00, 47.00, 47.00, 47.00, 47.00, 47.00}}, +{"4:512:16:512", {47.00, 47.00, 47.00, 47.00, 47.00, 47.00}}, +{ "4:1K:16:512", {44.52, 44.52, 44.52, 44.52, 44.52, 44.52}}, +{ "4:1K:8:1K", {44.52, 44.52, 44.52, 44.52, 44.52, 44.52}}, +{ "4:1K:16:1K", {41.72, 41.72, 41.72, 41.72, 41.72, 41.72}}, // Strange 41.72 has tiny error, 41.75 is 0.5 +{ "4:4K:16:512", {39.50, 39.50, 39.50, 39.50, 39.50, 39.50}}, // Estimated +{ "4:4K:16:1K", {37.50, 37.50, 37.50, 37.50, 37.50, 37.50}}, // Estimated +// FFT3231 - Computed with -use TABMUL_CHAIN32=0,TAIL_TRIGS32=0 and targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "50:256:2:256", {19.57, 19.57, 19.57, 19.57, 19.57, 19.57}}, +{ "50:256:4:256", {19.23, 19.23, 19.23, 19.23, 19.23, 19.23}}, +{ "50:256:8:256", {19.07, 19.07, 19.07, 19.07, 19.07, 19.07}}, +{ "50:512:4:256", {19.07, 19.07, 19.07, 19.07, 19.07, 19.07}}, +{"50:256:16:256", {17.07, 17.07, 17.07, 17.07, 17.07, 17.07}}, +{ "50:512:8:256", {17.07, 17.07, 17.07, 17.07, 17.07, 17.07}}, +{ "50:512:4:512", {17.07, 17.07, 17.07, 17.07, 17.07, 17.07}}, +{ "50:1K:8:256", {16.78, 16.78, 16.78, 16.78, 16.78, 16.78}}, +{"50:512:16:256", {16.78, 16.78, 16.78, 16.78, 16.78, 16.78}}, +{ "50:512:8:512", {16.78, 16.78, 16.78, 16.78, 16.78, 16.78}}, +{ "50:1K:16:256", {16.52, 16.52, 16.52, 16.52, 16.52, 16.52}}, +{ "50:1K:8:512", {16.52, 16.52, 16.52, 16.52, 16.52, 16.52}}, +{"50:512:16:512", {16.52, 16.52, 16.52, 16.52, 16.52, 16.52}}, +{ "50:1K:16:512", {14.01, 14.01, 14.01, 14.01, 14.01, 14.01}}, +{ "50:1K:8:1K", {14.01, 14.01, 14.01, 14.01, 14.01, 14.01}}, +{ "50:1K:16:1K", {11.21, 11.21, 11.21, 11.21, 11.21, 11.21}}, // Estimated +{ "50:4K:16:512", {9.15, 9.15, 9.15, 9.15, 9.15, 9.15}}, // Estimated +{ "50:4K:16:1K", {7.05, 7.05, 7.05, 7.05, 7.05, 7.05}}, // Estimated +// FFT6431 - Computed with variant 202 and targeting maxROE of ~0.35 over 1000 iterations, probably could go higher +{ "51:256:2:256", {35.27, 35.27, 35.27, 35.27, 35.27, 35.27}}, +{ "51:256:4:256", {34.98, 34.98, 34.98, 34.98, 34.98, 34.98}}, +{ "51:256:8:256", {34.61, 34.61, 34.61, 34.61, 34.61, 34.61}}, +{ "51:512:4:256", {34.61, 34.61, 34.61, 34.61, 34.61, 34.61}}, +{"51:256:16:256", {34.33, 34.33, 34.33, 34.33, 34.33, 34.33}}, +{ "51:512:8:256", {34.33, 34.33, 34.33, 34.33, 34.33, 34.33}}, +{ "51:512:4:512", {34.33, 34.33, 34.33, 34.33, 34.33, 34.33}}, +{ "51:1K:8:256", {33.97, 33.97, 33.97, 33.97, 33.97, 33.97}}, +{"51:512:16:256", {33.97, 33.97, 33.97, 33.97, 33.97, 33.97}}, +{ "51:512:8:512", {33.97, 33.97, 33.97, 33.97, 33.97, 33.97}}, +{ "51:1K:16:256", {33.64, 33.64, 33.64, 33.64, 33.64, 33.64}}, +{ "51:1K:8:512", {33.64, 33.64, 33.64, 33.64, 33.64, 33.64}}, +{"51:512:16:512", {33.64, 33.64, 33.64, 33.64, 33.64, 33.64}}, +{ "51:1K:16:512", {33.45, 33.45, 33.45, 33.45, 33.45, 33.45}}, +{ "51:1K:8:1K", {33.45, 33.45, 33.45, 33.45, 33.45, 33.45}}, +{ "51:1K:16:1K", {33.23, 33.23, 33.23, 33.23, 33.23, 33.23}}, +{ "51:4K:16:512", {32.75, 32.75, 32.75, 32.75, 32.75, 32.75}}, +{ "51:4K:16:1K", {32.25, 32.25, 32.25, 32.25, 32.25, 32.25}}, // Estimated diff --git a/src/main.cpp b/src/main.cpp index e1b00588..b62f1b32 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -20,8 +20,6 @@ #include // #include from GCC-13 onwards -namespace fs = std::filesystem; - void gpuWorker(GpuCommon shared, Queue *q, i32 
instance) { // LogContext context{(instance ? shared.args->tailDir() : ""s) + to_string(instance) + ' '}; // log("Starting worker %d\n", instance); @@ -42,13 +40,15 @@ void gpuWorker(GpuCommon shared, Queue *q, i32 instance) { } -#ifdef __MINGW32__ // for Windows -extern int putenv(const char *); +#if defined(__MINGW32__) || defined(__MINGW64__) || defined(__MSYS__) // for Windows +extern int putenv(char *); #endif int main(int argc, char **argv) { -#ifdef __MINGW32__ +#if defined(__MSYS__) + // I was unable to get putenv to link in MSYS2 +#elif defined(__MINGW32__) || defined(__MINGW64__) putenv("ROC_SIGNAL_POOL_SIZE=32"); #else // Required to work around a ROCm bug when using multiple queues @@ -66,7 +66,7 @@ int main(int argc, char **argv) { fs::current_path(args.dir); } } - + fs::path poolDir; { Args args{true}; @@ -74,24 +74,27 @@ int main(int argc, char **argv) { args.parse(mainLine); poolDir = args.masterDir; } - + initLog("gpuowl-0.log"); log("PRPLL %s starting\n", VERSION); - + Args args; if (!poolDir.empty()) { args.readConfig(poolDir / "config.txt"); } args.readConfig("config.txt"); args.parse(mainLine); args.setDefaults(); - + if (args.maxAlloc) { AllocTrac::setMaxAlloc(args.maxAlloc); } Context context(getDevice(args.device)); - TrigBufCache bufCache{&context}; Signal signal; Background background; - GpuCommon shared{&args, &bufCache, &background}; + GpuCommon shared; + shared.args = &args; + TrigBufCache bufCache{&context}; + shared.bufCache = &bufCache; + shared.background = &background; if (args.doCtune || args.doTune || args.doZtune || args.carryTune) { Queue q(context, args.profile); diff --git a/src/state.cpp b/src/state.cpp index 617a2e20..4ac159d8 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -8,50 +8,47 @@ #include #include -static int lowBits(int u, int bits) { return (u << (32 - bits)) >> (32 - bits); } - -static u32 unbalance(int w, int nBits, int *carry) { - assert(*carry == 0 || *carry == -1); - w += *carry; - *carry = 0; - if (w < 0) { - w += (1 << nBits); - *carry = -1; - } - if (!(0 <= w && w < (1 << nBits))) { log("w=%d, nBits=%d\n", w, nBits); } - assert(0 <= w && w < (1 << nBits)); - return w; -} +static i64 lowBits(i64 u, int bits) { return (u << (64 - bits)) >> (64 - bits); } -std::vector compactBits(const vector &dataVect, u32 E) { +std::vector compactBits(const vector &dataVect, u32 E) { if (dataVect.empty()) { return {}; } // Indicating all zero + u32 N = dataVect.size(); + const Word *data = dataVect.data(); + std::vector out; out.reserve((E - 1) / 32 + 1); - u32 N = dataVect.size(); - const int *data = dataVect.data(); - int carry = 0; u32 outWord = 0; int haveBits = 0; + // Convert to compact form for (u32 p = 0; p < N; ++p) { int nBits = bitlen(N, E, p); - u32 w = unbalance(data[p], nBits, &carry); - assert(nBits > 0); - assert(w < (1u << nBits)); + + // Be careful adding in the carry -- it could overflow a 32-bit word. Convert value into desired unsigned range. 
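+    // Worked example of the balanced -> unsigned conversion below (illustrative values):
+    // with nBits = 18, carry = 0 and a balanced word data[p] = -5, we get tmp = -5,
+    // carry = tmp >> 18 = -1, and w = -5 - (-1 << 18) = 262139 < 2^18; the -1 carry is then
+    // folded into the next word. For a word already in [0, 2^nBits) the shift gives carry = 0
+    // and w == tmp.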
+ i64 tmp = (i64) data[p] + carry; + carry = (int) (tmp >> nBits); + u64 w = (u64) (tmp - ((i64) carry << nBits)); + assert(w < (1ULL << nBits)); assert(haveBits < 32); - int topBits = 32 - haveBits; - outWord |= w << haveBits; - if (nBits >= topBits) { - out.push_back(outWord); - outWord = w >> topBits; - haveBits = nBits - topBits; - } else { - haveBits += nBits; + while (nBits) { + int needBits = 32 - haveBits; + outWord |= w << haveBits; + if (nBits >= needBits) { + w >>= needBits; + nBits -= needBits; + out.push_back(outWord); + outWord = 0; + haveBits = 0; + } else { + haveBits += nBits; + w >>= nBits; + break; + } } } @@ -69,20 +66,20 @@ std::vector compactBits(const vector &dataVect, u32 E) { } struct BitBucket { - u64 bits; + u128 bits; u32 size; BitBucket() : bits(0), size(0) {} void put32(u32 b) { - assert(size <= 32); - bits += (u64(b) << size); + assert(size <= 96); + bits += (u128(b) << size); size += 32; } - int popSigned(u32 n) { + i64 popSigned(u32 n) { assert(size >= n); - int b = lowBits(bits, n); + i64 b = lowBits((i64) bits, n); size -= n; bits >>= n; bits += (b < 0); // carry fixup. @@ -90,22 +87,22 @@ struct BitBucket { } }; -vector expandBits(const vector &compactBits, u32 N, u32 E) { +vector expandBits(const vector &compactBits, u32 N, u32 E) { assert(E % 32 != 0); - std::vector out(N); - int *data = out.data(); + std::vector out(N); + Word *data = out.data(); BitBucket bucket; auto it = compactBits.cbegin(); [[maybe_unused]] auto itEnd = compactBits.cend(); for (u32 p = 0; p < N; ++p) { - u32 len = bitlen(N, E, p); - if (bucket.size < len) { + u32 len = bitlen(N, E, p); + while (bucket.size < len) { assert(it != itEnd); bucket.put32(*it++); } - data[p] = bucket.popSigned(len); + data[p] = (Word) bucket.popSigned(len); } assert(it == itEnd); assert(bucket.size == 32 - E % 32); diff --git a/src/state.h b/src/state.h index 4b32be04..9b37c0fd 100644 --- a/src/state.h +++ b/src/state.h @@ -8,8 +8,8 @@ #include #include -vector compactBits(const vector &dataVect, u32 E); -vector expandBits(const vector &compactBits, u32 N, u32 E); +vector compactBits(const vector &dataVect, u32 E); +vector expandBits(const vector &compactBits, u32 N, u32 E); constexpr u32 step(u32 N, u32 E) { return N - (E % N); } constexpr u32 extra(u32 N, u32 E, u32 k) { return u64(step(N, E)) * k % N; } diff --git a/src/tune.cpp b/src/tune.cpp index ff1b76d6..3b773437 100644 --- a/src/tune.cpp +++ b/src/tune.cpp @@ -14,6 +14,7 @@ #include #include #include +#include using std::accumulate; @@ -103,26 +104,26 @@ string formatConfigResults(const vector& results) { } // namespace -double Tune::maxBpw(FFTConfig fft) { +float Tune::maxBpw(FFTConfig fft) { -// double bpw = oldBpw; +// float bpw = oldBpw; - const double TARGET = 28; + const float TARGET = 28; const u32 sample_size = 5; // Estimate how much bpw needs to change to increase/decrease Z by 1. // This doesn't need to be a very accurate estimate. // This estimate comes from analyzing a 4M FFT and a 7.5M FFT. // The 4M FFT needed a .015 step, the 7.5M FFT needed a .012 step. 
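+  // The estimate below is linear in log2(size): the fraction
+  // (log2(size) - log2(4M)) / (log2(7.5M) - log2(4M)) is 0 for a 4M FFT and 1 for a 7.5M FFT,
+  // so bpw_step runs from .015 down to .012 over that range and keeps shrinking (by
+  // extrapolation) for larger FFTs.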
- double bpw_step = .015 + (log2(fft.size()) - log2(4.0*1024*1024)) / (log2(7.5*1024*1024) - log2(4.0*1024*1024)) * (.012 - .015); + float bpw_step = .015 + (log2(fft.size()) - log2(4.0*1024*1024)) / (log2(7.5*1024*1024) - log2(4.0*1024*1024)) * (.012 - .015); // Pick a bpw that might be close to Z=34, it is best to err on the high side of Z=34 - double bpw1 = fft.maxBpw() - 9 * bpw_step; // Old bpw gave Z=28, we want Z=34 (or more) + float bpw1 = fft.maxBpw() - 9 * bpw_step; // Old bpw gave Z=28, we want Z=34 (or more) // The code below was used when building the maxBpw table from scratch // u32 non_best_width = N_VARIANT_W - 1 - variant_W(fft.variant); // Number of notches below best-Z width variant // u32 non_best_middle = N_VARIANT_M - 1 - variant_M(fft.variant); // Number of notches below best-Z middle variant -// double bpw1 = 18.3 - 0.275 * (log2(fft.size()) - log2(256 * 13 * 1024 * 2)) - // Default max bpw from an old gpuowl version +// float bpw1 = 18.3 - 0.275 * (log2(fft.size()) - log2(256 * 13 * 1024 * 2)) - // Default max bpw from an old gpuowl version // 9 * bpw_step - // Default above should give Z=28, we want Z=34 (or more) // (.08/.012 * bpw_step) * non_best_width - // 7.5M FFT has ~.08 bpw difference for each width variant below best variant // (.06 + .04 * (fft.shape.middle - 4) / 11) * non_best_middle; // Assume .1 bpw difference MIDDLE=15 and .06 for MIDDLE=4 @@ -130,11 +131,11 @@ double Tune::maxBpw(FFTConfig fft) { //if (fft.size() < 512000) bpw1 = 19, bpw_step = .02; // Fine tune our estimate for Z=34 - double z1 = zForBpw(bpw1, fft, 1); + float z1 = zForBpw(bpw1, fft, 1); printf ("Guess bpw for %s is %.2f first Z34 is %.2f\n", fft.spec().c_str(), bpw1, z1); while (z1 < 31.0 || z1 > 37.0) { - double prev_bpw1 = bpw1; - double prev_z1 = z1; + float prev_bpw1 = bpw1; + float prev_z1 = z1; bpw1 = bpw1 + (z1 - 34) * bpw_step; z1 = zForBpw(bpw1, fft, 1); printf ("Reguess bpw for %s is %.2f first Z34 is %.2f\n", fft.spec().c_str(), bpw1, z1); @@ -147,12 +148,12 @@ printf ("Reguess bpw for %s is %.2f first Z34 is %.2f\n", fft.spec().c_str(), bp z1 = (z1 + (sample_size - 1) * zForBpw(bpw1, fft, sample_size - 1)) / sample_size; // Pick a bpw somewhere near Z=22 then fine tune the guess - double bpw2 = bpw1 + (z1 - 22) * bpw_step; - double z2 = zForBpw(bpw2, fft, 1); + float bpw2 = bpw1 + (z1 - 22) * bpw_step; + float z2 = zForBpw(bpw2, fft, 1); printf ("Guess bpw for %s is %.2f first Z22 is %.2f\n", fft.spec().c_str(), bpw2, z2); while (z2 < 20.0 || z2 > 25.0) { - double prev_bpw2 = bpw2; - double prev_z2 = z2; + float prev_bpw2 = bpw2; + float prev_z2 = z2; // bool error_recovery = (z2 <= 0.0); // if (error_recovery) bpw2 -= bpw_step; else bpw2 = bpw2 + (z2 - 21) * bpw_step; @@ -171,12 +172,12 @@ printf ("Reguess bpw for %s is %.2f first Z22 is %.2f\n", fft.spec().c_str(), bp return bpw2 + (bpw1 - bpw2) * (TARGET - z2) / (z1 - z2); } -double Tune::zForBpw(double bpw, FFTConfig fft, u32 count) { +float Tune::zForBpw(float bpw, FFTConfig fft, u32 count) { u32 exponent = (count == 1) ? 
primes.prevPrime(fft.size() * bpw) : primes.nextPrime(fft.size() * bpw); - double total_z = 0.0; + float total_z = 0.0; for (u32 i = 0; i < count; i++, exponent = primes.nextPrime (exponent + 1)) { auto [ok, res, roeSq, roeMul] = Gpu::make(q, exponent, shared, fft, {}, false)->measureROE(true); - double z = roeSq.z(); + float z = roeSq.z(); total_z += z; log("Zforbpw %.2f (z %.2f) : %s\n", bpw, z, fft.spec().c_str()); if (!ok) { log("Error at bpw %.2f (z %.2f) : %s\n", bpw, z, fft.spec().c_str()); continue; } @@ -195,8 +196,8 @@ void Tune::ztune() { u32 variant = 202; u32 sample_size = 5; FFTConfig fft{shape, variant, CARRY_AUTO}; - for (double bpw = 18.18; bpw < 18.305; bpw += 0.02) { - double z = zForBpw(bpw, fft, sample_size); + for (float bpw = 18.18; bpw < 18.305; bpw += 0.02) { + float z = zForBpw(bpw, fft, sample_size); log ("Avg zForBpw %s %.2f %.2f\n", fft.spec().c_str(), bpw, z); } } @@ -213,7 +214,7 @@ void Tune::ztune() { if (shape.width > 1024) bpw_variants[0] = 100, bpw_variants[3] = 110; // Copy the existing bpw array (in case we're replacing only some of the entries) - array bpw; + array bpw; bpw = shape.bpw; // Not all shapes have their maximum bpw per-computed. But one can work on a non-favored shape by specifying it on the command line. @@ -244,10 +245,10 @@ void Tune::carryTune() { if (prevSize == fft.size()) { continue; } prevSize = fft.size(); - vector zv; + vector zv; double m = 0; - const double mid = fft.shape.carry32BPW(); - for (double bpw : {mid - 0.05, mid + 0.05}) { + const float mid = fft.shape.carry32BPW(); + for (float bpw : {mid - 0.05, mid + 0.05}) { u32 exponent = primes.nearestPrime(fft.size() * bpw); auto [ok, carry] = Gpu::make(q, exponent, shared, fft, {}, false)->measureCarry(); m = carry.max; @@ -255,7 +256,7 @@ void Tune::carryTune() { zv.push_back(carry.z()); } - double avg = (zv[0] + zv[1]) / 2; + float avg = (zv[0] + zv[1]) / 2; u32 exponent = fft.shape.carry32BPW() * fft.size(); double pErr100 = -expm1(-exp(-avg) * exponent * 100); log("%14s %.3f : %.3f (%.3f %.3f) %f %.0f%%\n", fft.spec().c_str(), mid, avg, zv[0], zv[1], m, pErr100 * 100); @@ -290,7 +291,7 @@ void Tune::ctune() { } for (FFTShape shape : shapes) { - FFTConfig fft{shape, 0, CARRY_32}; + FFTConfig fft{shape, 101, CARRY_32}; u32 exponent = primes.prevPrime(fft.maxExp()); // log("tuning %10s with exponent %u\n", fft.shape.spec().c_str(), exponent); @@ -325,288 +326,632 @@ void Tune::ctune() { log("\nBest configs (lines can be copied to config.txt):\n%s", formatConfigResults(results).c_str()); } +// Add better -use settings to list of changes to be made to config.txt +void configsUpdate(double current_cost, double best_cost, double threshold, const char *key, u32 value, vector> &newConfigKeyVals, vector> &suggestedConfigKeyVals) { + if (best_cost == current_cost) return; + // If best cost is better than current cost by a substantial margin (the threshold) then add the key value pair to suggestedConfigKeyVals + if (best_cost < (1.0 - threshold) * current_cost) + newConfigKeyVals.push_back({key, value}); + // Otherwise, add the key value pair to newConfigKeyVals + else + suggestedConfigKeyVals.push_back({key, value}); +} + void Tune::tune() { Args *args = shared.args; vector shapes = FFTShape::multiSpec(args->fftSpec); - + // There are some options and variants that are different based on GPU manufacturer bool AMDGPU = isAmdGpu(q->context->deviceId()); - // Look for best settings of various options + bool tune_config = 1; + bool time_FFTs = 0; + bool time_NTTs = 0; + bool time_FP32 = 
1; + int quick = 7; // Run config from slowest (quick=1) to fastest (quick=10) + u64 min_exponent = 75000000; + u64 max_exponent = 350000000; + if (!args->fftSpec.empty()) { min_exponent = 0; max_exponent = 1000000000000ull; } + + // Parse input args + for (const string& s : split(args->tune, ',')) { + if (s.empty()) continue; + if (s == "noconfig") tune_config = 0; + if (s == "fp64") time_FFTs = 1; + if (s == "ntt") time_NTTs = 1; + if (s == "nofp32") time_FP32 = 0; + auto keyVal = split(s, '='); + if (keyVal.size() == 2) { + if (keyVal.front() == "quick") quick = stod(keyVal.back()); + if (keyVal.front() == "minexp") min_exponent = stoull(keyVal.back()); + if (keyVal.front() == "maxexp") max_exponent = stoull(keyVal.back()); + } + } + if (quick < 1) quick = 1; + if (quick > 10) quick = 10; + + // Look for best settings of various options. Append best settings to config.txt. + if (tune_config) { + vector> newConfigKeyVals; + vector> suggestedConfigKeyVals; + + // Select/init the default FFTshape(s) and FFTConfig(s) for optimal -use options testing + FFTShape defaultFFTShape, defaultNTTShape, *defaultShape; + + // If user gave us an fft-spec, use that to time options + if (!args->fftSpec.empty()) { + defaultShape = &shapes[0]; + if (shapes[0].fft_type == FFT64) { + defaultFFTShape = shapes[0]; + time_FFTs = 1; + } else { + defaultNTTShape = shapes[0]; + time_NTTs = 1; + } + } + // If user specified FP64-timings, time a wavefront exponent using an 7.5M FFT + // If user specified NTT-timings, time a wavefront exponent using an 4M M31+M61 NTT + else if (time_FFTs || time_NTTs) { + if (time_FFTs) { + defaultFFTShape = FFTShape(FFT64, 512, 15, 512); + defaultShape = &defaultFFTShape; + } + if (time_NTTs) { + defaultNTTShape = FFTShape(FFT3161, 512, 8, 512); + defaultShape = &defaultNTTShape; + } + } + // No user specifications. Time an FP64 FFT and a GF31*GF61 NTT to see if the GPU is more suited for FP64 work or NTT work. + else { + log("Checking whether this GPU is better suited for double-precision FFTs or integer NTTs.\n"); + defaultFFTShape = FFTShape(FFT64, 512, 16, 512); + FFTConfig fft{defaultFFTShape, 101, CARRY_32}; + double fp64_time = Gpu::make(q, 141000001, shared, fft, {}, false)->timePRP(quick); + log("Time for FP64 FFT %12s is %6.1f\n", fft.spec().c_str(), fp64_time); + defaultNTTShape = FFTShape(FFT3161, 512, 8, 512); + FFTConfig ntt{defaultNTTShape, 202, CARRY_AUTO}; + double ntt_time = Gpu::make(q, 141000001, shared, ntt, {}, false)->timePRP(quick); + log("Time for M31*M61 NTT %12s is %6.1f\n", ntt.spec().c_str(), ntt_time); + if (fp64_time < ntt_time) { + defaultShape = &defaultFFTShape; + time_FFTs = 1; + if (fp64_time < 0.80 * ntt_time) { + log("FP64 FFTs are significantly faster than integer NTTs. No NTT tuning will be performed.\n"); + } else { + log("FP64 FFTs are not significantly faster than integer NTTs. NTT tuning will be performed.\n"); + time_NTTs = 1; + } + } else { + defaultShape = &defaultNTTShape; + time_NTTs = 1; + if (fp64_time > 1.20 * ntt_time) { + log("FP64 FFTs are significantly slower than integer NTTs. No FP64 tuning will be performed.\n"); + } else { + log("FP64 FFTs are not significantly slower than integer NTTs. FP64 tuning will be performed.\n"); + time_FFTs = 1; + } + } + } + + log("\n"); + log("Beginning timing of various options. These settings will be appended to config.txt.\n"); + log("Please read config.txt after -tune completes.\n"); + log("\n"); - if (1) { - u32 variant = 101; + u32 variant = (defaultShape == &defaultFFTShape) ? 
101 : 202; //GW: if fft spec on the command line specifies a variant then we should use that variant (I get some interesting results with 000 vs 101 vs 201 vs 202 likely due to rocm optimizer) - // Find best IN_WG,IN_SIZEX setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + // IN_WG/SIZEX, OUT_WG/SIZEX, PAD, MIDDLE_IN/OUT_LDS_TRANSPOSE apply only if INPLACE=0 + u32 current_inplace = args->value("INPLACE", 0); + args->flags["INPLACE"] = to_string(0); + + // Find best IN_WG,IN_SIZEX,OUT_WG,OUT_SIZEX settings + if (1/*option to time IN/OUT settings*/) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_in_wg = 0; u32 best_in_sizex = 0; + u32 current_in_wg = args->value("IN_WG", 128); + u32 current_in_sizex = args->value("IN_SIZEX", 16); double best_cost = -1.0; + double current_cost = -1.0; for (u32 in_wg : {64, 128, 256}) { for (u32 in_sizex : {8, 16, 32}) { - shared.args->flags["IN_WG"] = to_string(in_wg); - shared.args->flags["IN_SIZEX"] = to_string(in_sizex); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["IN_WG"] = to_string(in_wg); + args->flags["IN_SIZEX"] = to_string(in_sizex); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using IN_WG=%u, IN_SIZEX=%u is %6.1f\n", fft.spec().c_str(), in_wg, in_sizex, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_in_wg = in_wg; best_in_sizex = in_sizex; } - } + if (in_wg == current_in_wg && in_sizex == current_in_sizex) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_in_wg = in_wg; best_in_sizex = in_sizex; } + } } log("Best IN_WG, IN_SIZEX is %u, %u. 
Default is 128, 16.\n", best_in_wg, best_in_sizex); - shared.args->flags["IN_WG"] = to_string(best_in_wg); - shared.args->flags["IN_SIZEX"] = to_string(best_in_sizex); - } + configsUpdate(current_cost, best_cost, 0.003, "IN_WG", best_in_wg, newConfigKeyVals, suggestedConfigKeyVals); + configsUpdate(current_cost, best_cost, 0.003, "IN_SIZEX", best_in_sizex, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["IN_WG"] = to_string(best_in_wg); + args->flags["IN_SIZEX"] = to_string(best_in_sizex); - // Find best OUT_WG,OUT_SIZEX setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; - u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_out_wg = 0; u32 best_out_sizex = 0; - double best_cost = -1.0; + u32 current_out_wg = args->value("OUT_WG", 128); + u32 current_out_sizex = args->value("OUT_SIZEX", 16); + best_cost = -1.0; + current_cost = -1.0; for (u32 out_wg : {64, 128, 256}) { for (u32 out_sizex : {8, 16, 32}) { - shared.args->flags["OUT_WG"] = to_string(out_wg); - shared.args->flags["OUT_SIZEX"] = to_string(out_sizex); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["OUT_WG"] = to_string(out_wg); + args->flags["OUT_SIZEX"] = to_string(out_sizex); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using OUT_WG=%u, OUT_SIZEX=%u is %6.1f\n", fft.spec().c_str(), out_wg, out_sizex, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_out_wg = out_wg; best_out_sizex = out_sizex; } - } + if (out_wg == current_out_wg && out_sizex == current_out_sizex) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_out_wg = out_wg; best_out_sizex = out_sizex; } + } } log("Best OUT_WG, OUT_SIZEX is %u, %u. Default is 128, 16.\n", best_out_wg, best_out_sizex); - shared.args->flags["OUT_WG"] = to_string(best_out_wg); - shared.args->flags["OUT_SIZEX"] = to_string(best_out_sizex); + configsUpdate(current_cost, best_cost, 0.003, "OUT_WG", best_out_wg, newConfigKeyVals, suggestedConfigKeyVals); + configsUpdate(current_cost, best_cost, 0.003, "OUT_SIZEX", best_out_sizex, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["OUT_WG"] = to_string(best_out_wg); + args->flags["OUT_SIZEX"] = to_string(best_out_sizex); + } + + // Find best PAD setting. Default is 256 bytes for AMD, 0 for all others. + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_pad = 0; + u32 current_pad = args->value("PAD", AMDGPU ? 256 : 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 pad : {0, 64, 128, 256, 512}) { + args->flags["PAD"] = to_string(pad); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using PAD=%u is %6.1f\n", fft.spec().c_str(), pad, cost); + if (pad == current_pad) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_pad = pad; } + } + log("Best PAD is %u bytes. Default PAD is %u bytes.\n", best_pad, AMDGPU ? 
256 : 0); + configsUpdate(current_cost, best_cost, 0.000, "PAD", best_pad, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["PAD"] = to_string(best_pad); + } + + // Find best MIDDLE_IN_LDS_TRANSPOSE setting + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_middle_in_lds_transpose = 0; + u32 current_middle_in_lds_transpose = args->value("MIDDLE_IN_LDS_TRANSPOSE", 1); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 middle_in_lds_transpose : {0, 1}) { + args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(middle_in_lds_transpose); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using MIDDLE_IN_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_in_lds_transpose, cost); + if (middle_in_lds_transpose == current_middle_in_lds_transpose) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_in_lds_transpose = middle_in_lds_transpose; } + } + log("Best MIDDLE_IN_LDS_TRANSPOSE is %u. Default MIDDLE_IN_LDS_TRANSPOSE is 1.\n", best_middle_in_lds_transpose); + configsUpdate(current_cost, best_cost, 0.000, "MIDDLE_IN_LDS_TRANSPOSE", best_middle_in_lds_transpose, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(best_middle_in_lds_transpose); + } + + // Find best MIDDLE_OUT_LDS_TRANSPOSE setting + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_middle_out_lds_transpose = 0; + u32 current_middle_out_lds_transpose = args->value("MIDDLE_OUT_LDS_TRANSPOSE", 1); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 middle_out_lds_transpose : {0, 1}) { + args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(middle_out_lds_transpose); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using MIDDLE_OUT_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_out_lds_transpose, cost); + if (middle_out_lds_transpose == current_middle_out_lds_transpose) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_out_lds_transpose = middle_out_lds_transpose; } + } + log("Best MIDDLE_OUT_LDS_TRANSPOSE is %u. Default MIDDLE_OUT_LDS_TRANSPOSE is 1.\n", best_middle_out_lds_transpose); + configsUpdate(current_cost, best_cost, 0.000, "MIDDLE_OUT_LDS_TRANSPOSE", best_middle_out_lds_transpose, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(best_middle_out_lds_transpose); + } + + // Find best INPLACE setting + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_inplace = 0; + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 inplace : {0, 1}) { + args->flags["INPLACE"] = to_string(inplace); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using INPLACE=%u is %6.1f\n", fft.spec().c_str(), inplace, cost); + if (inplace == current_inplace) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_inplace = inplace; } + } + log("Best INPLACE is %u. Default INPLACE is 0. 
Best INPLACE setting may be different for other FFT lengths.\n", best_inplace); + configsUpdate(current_cost, best_cost, 0.002, "INPLACE", best_inplace, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["INPLACE"] = to_string(best_inplace); + } + + // Find best NONTEMPORAL setting + if (1) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_nontemporal = 0; + u32 current_nontemporal = args->value("NONTEMPORAL", 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 nontemporal : {0, 1, 2}) { + args->flags["NONTEMPORAL"] = to_string(nontemporal); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using NONTEMPORAL=%u is %6.1f\n", fft.spec().c_str(), nontemporal, cost); + if (nontemporal == current_nontemporal) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_nontemporal = nontemporal; } + } + log("Best NONTEMPORAL is %u. Default NONTEMPORAL is 0.\n", best_nontemporal); + configsUpdate(current_cost, best_cost, 0.000, "NONTEMPORAL", best_nontemporal, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["NONTEMPORAL"] = to_string(best_nontemporal); } // Find best FAST_BARRIER setting - if (1 && AMDGPU) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + if (AMDGPU) { + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_fast_barrier = 0; + u32 current_fast_barrier = args->value("FAST_BARRIER", 0); double best_cost = -1.0; + double current_cost = -1.0; for (u32 fast_barrier : {0, 1}) { - shared.args->flags["FAST_BARRIER"] = to_string(fast_barrier); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["FAST_BARRIER"] = to_string(fast_barrier); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using FAST_BARRIER=%u is %6.1f\n", fft.spec().c_str(), fast_barrier, cost); + if (fast_barrier == current_fast_barrier) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_fast_barrier = fast_barrier; } } log("Best FAST_BARRIER is %u. Default FAST_BARRIER is 0.\n", best_fast_barrier); - shared.args->flags["FAST_BARRIER"] = to_string(best_fast_barrier); + configsUpdate(current_cost, best_cost, 0.000, "FAST_BARRIER", best_fast_barrier, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["FAST_BARRIER"] = to_string(best_fast_barrier); } - // Find best TAIL_TRIGS setting + // Find best TAIL_KERNELS setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tail_kernels = 0; + u32 current_tail_kernels = args->value("TAIL_KERNELS", 2); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 tail_kernels : {0, 1, 2, 3}) { + args->flags["TAIL_KERNELS"] = to_string(tail_kernels); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using TAIL_KERNELS=%u is %6.1f\n", fft.spec().c_str(), tail_kernels, cost); + if (tail_kernels == current_tail_kernels) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_kernels = tail_kernels; } + } + if (best_tail_kernels & 1) + log("Best TAIL_KERNELS is %u. 
Default TAIL_KERNELS is 2.\n", best_tail_kernels); + else + log("Best TAIL_KERNELS is %u (but best may be %u when running two workers on one GPU). Default TAIL_KERNELS is 2.\n", best_tail_kernels, best_tail_kernels | 1); + configsUpdate(current_cost, best_cost, 0.000, "TAIL_KERNELS", best_tail_kernels, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_KERNELS"] = to_string(best_tail_kernels); + } + + // Find best TAIL_TRIGS setting + if (time_FFTs) { + FFTConfig fft{defaultFFTShape, 101, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_tail_trigs = 0; + u32 current_tail_trigs = args->value("TAIL_TRIGS", 2); double best_cost = -1.0; + double current_cost = -1.0; for (u32 tail_trigs : {0, 1, 2}) { - shared.args->flags["TAIL_TRIGS"] = to_string(tail_trigs); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["TAIL_TRIGS"] = to_string(tail_trigs); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TAIL_TRIGS=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (tail_trigs == current_tail_trigs) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } } log("Best TAIL_TRIGS is %u. Default TAIL_TRIGS is 2.\n", best_tail_trigs); - shared.args->flags["TAIL_TRIGS"] = to_string(best_tail_trigs); + configsUpdate(current_cost, best_cost, 0.003, "TAIL_TRIGS", best_tail_trigs, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_TRIGS"] = to_string(best_tail_trigs); } - // Find best TAIL_KERNELS setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + // Find best TAIL_TRIGS31 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF31) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_tail_kernels = 0; + u32 best_tail_trigs = 0; + u32 current_tail_trigs = args->value("TAIL_TRIGS31", 0); double best_cost = -1.0; - for (u32 tail_kernels : {0, 1, 2, 3}) { - shared.args->flags["TAIL_KERNELS"] = to_string(tail_kernels); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - log("Time for %12s using TAIL_KERNELS=%u is %6.1f\n", fft.spec().c_str(), tail_kernels, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_kernels = tail_kernels; } + double current_cost = -1.0; + for (u32 tail_trigs : {0, 1}) { + args->flags["TAIL_TRIGS31"] = to_string(tail_trigs); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using TAIL_TRIGS31=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (tail_trigs == current_tail_trigs) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } } - if (best_tail_kernels & 1) - log("Best TAIL_KERNELS is %u. Default TAIL_KERNELS is 2.\n", best_tail_kernels); - else - log("Best TAIL_KERNELS is %u (but best may be %u when running two workers on one GPU). Default TAIL_KERNELS is 2.\n", best_tail_kernels, best_tail_kernels | 1); - shared.args->flags["TAIL_KERNELS"] = to_string(best_tail_kernels); + log("Best TAIL_TRIGS31 is %u. 
Default TAIL_TRIGS31 is 0.\n", best_tail_trigs); + configsUpdate(current_cost, best_cost, 0.003, "TAIL_TRIGS31", best_tail_trigs, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_TRIGS31"] = to_string(best_tail_trigs); + } + + // Find best TAIL_TRIGS32 setting + if (time_NTTs && time_FP32) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.FFT_FP32) fft = FFTConfig(FFTShape(FFT3261, 512, 8, 512), 202, CARRY_AUTO); + u32 exponent = primes.prevPrime(fft.maxBpw() * 0.95 * fft.shape.size()); // Back off the maxExp as different settings will have different maxBpw + u32 best_tail_trigs = 0; + u32 current_tail_trigs = args->value("TAIL_TRIGS32", 2); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 tail_trigs : {0, 1, 2}) { + args->flags["TAIL_TRIGS32"] = to_string(tail_trigs); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using TAIL_TRIGS32=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (tail_trigs == current_tail_trigs) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } + } + log("Best TAIL_TRIGS32 is %u. Default TAIL_TRIGS32 is 2.\n", best_tail_trigs); + configsUpdate(current_cost, best_cost, 0.003, "TAIL_TRIGS32", best_tail_trigs, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_TRIGS32"] = to_string(best_tail_trigs); + } + + // Find best TAIL_TRIGS61 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF61) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_tail_trigs = 0; + u32 current_tail_trigs = args->value("TAIL_TRIGS61", 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 tail_trigs : {0, 1}) { + args->flags["TAIL_TRIGS61"] = to_string(tail_trigs); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using TAIL_TRIGS61=%u is %6.1f\n", fft.spec().c_str(), tail_trigs, cost); + if (tail_trigs == current_tail_trigs) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tail_trigs = tail_trigs; } + } + log("Best TAIL_TRIGS61 is %u. Default TAIL_TRIGS61 is 0.\n", best_tail_trigs); + configsUpdate(current_cost, best_cost, 0.003, "TAIL_TRIGS61", best_tail_trigs, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TAIL_TRIGS61"] = to_string(best_tail_trigs); } // Find best TABMUL_CHAIN setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, 101, CARRY_32}; + if (time_FFTs) { + FFTConfig fft{defaultFFTShape, 101, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_tabmul_chain = 0; + u32 current_tabmul_chain = args->value("TABMUL_CHAIN", 0); double best_cost = -1.0; + double current_cost = -1.0; for (u32 tabmul_chain : {0, 1}) { - shared.args->flags["TABMUL_CHAIN"] = to_string(tabmul_chain); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["TABMUL_CHAIN"] = to_string(tabmul_chain); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using TABMUL_CHAIN=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (tabmul_chain == current_tabmul_chain) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } } log("Best TABMUL_CHAIN is %u. 
Default TABMUL_CHAIN is 0.\n", best_tabmul_chain); - shared.args->flags["TABMUL_CHAIN"] = to_string(best_tabmul_chain); + configsUpdate(current_cost, best_cost, 0.003, "TABMUL_CHAIN", best_tabmul_chain, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TABMUL_CHAIN"] = to_string(best_tabmul_chain); } - // Find best PAD setting. Default is 256 bytes for AMD, 0 for all others. - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + // Find best TABMUL_CHAIN31 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF31) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_pad = 0; + u32 best_tabmul_chain = 0; + u32 current_tabmul_chain = args->value("TABMUL_CHAIN31", 0); double best_cost = -1.0; - for (u32 pad : {0, 64, 128, 256, 512}) { - shared.args->flags["PAD"] = to_string(pad); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - log("Time for %12s using PAD=%u is %6.1f\n", fft.spec().c_str(), pad, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_pad = pad; } + double current_cost = -1.0; + for (u32 tabmul_chain : {0, 1}) { + args->flags["TABMUL_CHAIN31"] = to_string(tabmul_chain); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using TABMUL_CHAIN31=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (tabmul_chain == current_tabmul_chain) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } } - log("Best PAD is %u bytes. Default PAD is %u bytes.\n", best_pad, AMDGPU ? 256 : 0); - shared.args->flags["PAD"] = to_string(best_pad); + log("Best TABMUL_CHAIN31 is %u. Default TABMUL_CHAIN31 is 0.\n", best_tabmul_chain); + configsUpdate(current_cost, best_cost, 0.003, "TABMUL_CHAIN31", best_tabmul_chain, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TABMUL_CHAIN31"] = to_string(best_tabmul_chain); } - // Find best NONTEMPORAL setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + // Find best TABMUL_CHAIN32 setting + if (time_NTTs && time_FP32) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.FFT_FP32) fft = FFTConfig(FFTShape(FFT3261, 512, 8, 512), 202, CARRY_AUTO); + u32 exponent = primes.prevPrime(fft.maxBpw() * 0.95 * fft.shape.size()); // Back off the maxExp as different settings will have different maxBpw + u32 best_tabmul_chain = 0; + u32 current_tabmul_chain = args->value("TABMUL_CHAIN32", 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 tabmul_chain : {0, 1}) { + args->flags["TABMUL_CHAIN32"] = to_string(tabmul_chain); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using TABMUL_CHAIN32=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (tabmul_chain == current_tabmul_chain) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } + } + log("Best TABMUL_CHAIN32 is %u. 
Default TABMUL_CHAIN32 is 0.\n", best_tabmul_chain); + configsUpdate(current_cost, best_cost, 0.003, "TABMUL_CHAIN32", best_tabmul_chain, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TABMUL_CHAIN32"] = to_string(best_tabmul_chain); + } + + // Find best TABMUL_CHAIN61 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF61) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_nontemporal = 0; + u32 best_tabmul_chain = 0; + u32 current_tabmul_chain = args->value("TABMUL_CHAIN61", 0); double best_cost = -1.0; - for (u32 nontemporal : {0, 1}) { - shared.args->flags["NONTEMPORAL"] = to_string(nontemporal); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - log("Time for %12s using NONTEMPORAL=%u is %6.1f\n", fft.spec().c_str(), nontemporal, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_nontemporal = nontemporal; } + double current_cost = -1.0; + for (u32 tabmul_chain : {0, 1}) { + args->flags["TABMUL_CHAIN61"] = to_string(tabmul_chain); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using TABMUL_CHAIN61=%u is %6.1f\n", fft.spec().c_str(), tabmul_chain, cost); + if (tabmul_chain == current_tabmul_chain) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_tabmul_chain = tabmul_chain; } } - log("Best NONTEMPORAL is %u. Default NONTEMPORAL is 0.\n", best_nontemporal); - shared.args->flags["NONTEMPORAL"] = to_string(best_nontemporal); + log("Best TABMUL_CHAIN61 is %u. Default TABMUL_CHAIN61 is 0.\n", best_tabmul_chain); + configsUpdate(current_cost, best_cost, 0.003, "TABMUL_CHAIN61", best_tabmul_chain, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["TABMUL_CHAIN61"] = to_string(best_tabmul_chain); + } + + // Find best MODM31 setting + if (time_NTTs) { + FFTConfig fft{defaultNTTShape, 202, CARRY_AUTO}; + if (!fft.NTT_GF31) fft = FFTConfig(FFTShape(FFT3161, 512, 8, 512), 202, CARRY_AUTO); + u32 exponent = primes.prevPrime(fft.maxExp()); + u32 best_modm31 = 0; + u32 current_modm31 = args->value("MODM31", 0); + double best_cost = -1.0; + double current_cost = -1.0; + for (u32 modm31 : {0, 1, 2}) { + args->flags["MODM31"] = to_string(modm31); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); + log("Time for %12s using MODM31=%u is %6.1f\n", fft.spec().c_str(), modm31, cost); + if (modm31 == current_modm31) current_cost = cost; + if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_modm31 = modm31; } + } + log("Best MODM31 is %u. Default MODM31 is 0.\n", best_modm31); + configsUpdate(current_cost, best_cost, 0.003, "MODM31", best_modm31, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["MODM31"] = to_string(best_modm31); } // Find best UNROLL_W setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_unroll_w = 0; + u32 current_unroll_w = args->value("UNROLL_W", AMDGPU ? 
0 : 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 unroll_w : {0, 1}) { - shared.args->flags["UNROLL_W"] = to_string(unroll_w); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["UNROLL_W"] = to_string(unroll_w); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using UNROLL_W=%u is %6.1f\n", fft.spec().c_str(), unroll_w, cost); + if (unroll_w == current_unroll_w) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_unroll_w = unroll_w; } } log("Best UNROLL_W is %u. Default UNROLL_W is %u.\n", best_unroll_w, AMDGPU ? 0 : 1); - shared.args->flags["UNROLL_W"] = to_string(best_unroll_w); + configsUpdate(current_cost, best_cost, 0.003, "UNROLL_W", best_unroll_w, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["UNROLL_W"] = to_string(best_unroll_w); } // Find best UNROLL_H setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_unroll_h = 0; + u32 current_unroll_h = args->value("UNROLL_H", AMDGPU && defaultShape->height >= 1024 ? 0 : 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 unroll_h : {0, 1}) { - shared.args->flags["UNROLL_H"] = to_string(unroll_h); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["UNROLL_H"] = to_string(unroll_h); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using UNROLL_H=%u is %6.1f\n", fft.spec().c_str(), unroll_h, cost); + if (unroll_h == current_unroll_h) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_unroll_h = unroll_h; } } - log("Best UNROLL_H is %u. Default UNROLL_H is %u.\n", best_unroll_h, AMDGPU && shape.height >= 1024 ? 0 : 1); - shared.args->flags["UNROLL_H"] = to_string(best_unroll_h); + log("Best UNROLL_H is %u. Default UNROLL_H is %u.\n", best_unroll_h, AMDGPU && defaultShape->height >= 1024 ? 0 : 1); + configsUpdate(current_cost, best_cost, 0.003, "UNROLL_H", best_unroll_h, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["UNROLL_H"] = to_string(best_unroll_h); } // Find best ZEROHACK_W setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_zerohack_w = 0; + u32 current_zerohack_w = args->value("ZEROHACK_W", 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 zerohack_w : {0, 1}) { - shared.args->flags["ZEROHACK_W"] = to_string(zerohack_w); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["ZEROHACK_W"] = to_string(zerohack_w); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using ZEROHACK_W=%u is %6.1f\n", fft.spec().c_str(), zerohack_w, cost); + if (zerohack_w == current_zerohack_w) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_zerohack_w = zerohack_w; } } log("Best ZEROHACK_W is %u. 
Default ZEROHACK_W is 1.\n", best_zerohack_w); - shared.args->flags["ZEROHACK_W"] = to_string(best_zerohack_w); + configsUpdate(current_cost, best_cost, 0.003, "ZEROHACK_W", best_zerohack_w, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["ZEROHACK_W"] = to_string(best_zerohack_w); } // Find best ZEROHACK_H setting if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + FFTConfig fft{*defaultShape, variant, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_zerohack_h = 0; + u32 current_zerohack_h = args->value("ZEROHACK_H", 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 zerohack_h : {0, 1}) { - shared.args->flags["ZEROHACK_H"] = to_string(zerohack_h); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["ZEROHACK_H"] = to_string(zerohack_h); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using ZEROHACK_H=%u is %6.1f\n", fft.spec().c_str(), zerohack_h, cost); + if (zerohack_h == current_zerohack_h) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_zerohack_h = zerohack_h; } } log("Best ZEROHACK_H is %u. Default ZEROHACK_H is 1.\n", best_zerohack_h); - shared.args->flags["ZEROHACK_H"] = to_string(best_zerohack_h); - } - - // Find best MIDDLE_IN_LDS_TRANSPOSE setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; - u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_middle_in_lds_transpose = 0; - double best_cost = -1.0; - for (u32 middle_in_lds_transpose : {0, 1}) { - shared.args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(middle_in_lds_transpose); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - log("Time for %12s using MIDDLE_IN_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_in_lds_transpose, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_in_lds_transpose = middle_in_lds_transpose; } - } - log("Best MIDDLE_IN_LDS_TRANSPOSE is %u. Default MIDDLE_IN_LDS_TRANSPOSE is 1.\n", best_middle_in_lds_transpose); - shared.args->flags["MIDDLE_IN_LDS_TRANSPOSE"] = to_string(best_middle_in_lds_transpose); - } - - // Find best MIDDLE_OUT_LDS_TRANSPOSE setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; - u32 exponent = primes.prevPrime(fft.maxExp()); - u32 best_middle_out_lds_transpose = 0; - double best_cost = -1.0; - for (u32 middle_out_lds_transpose : {0, 1}) { - shared.args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(middle_out_lds_transpose); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); - log("Time for %12s using MIDDLE_OUT_LDS_TRANSPOSE=%u is %6.1f\n", fft.spec().c_str(), middle_out_lds_transpose, cost); - if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_middle_out_lds_transpose = middle_out_lds_transpose; } - } - log("Best MIDDLE_OUT_LDS_TRANSPOSE is %u. 
Default MIDDLE_OUT_LDS_TRANSPOSE is 1.\n", best_middle_out_lds_transpose); - shared.args->flags["MIDDLE_OUT_LDS_TRANSPOSE"] = to_string(best_middle_out_lds_transpose); + configsUpdate(current_cost, best_cost, 0.003, "ZEROHACK_H", best_zerohack_h, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["ZEROHACK_H"] = to_string(best_zerohack_h); } // Find best BIGLIT setting - if (1) { - const FFTShape& shape = shapes[0]; - FFTConfig fft{shape, variant, CARRY_32}; + if (time_FFTs) { + FFTConfig fft{*defaultShape, 101, CARRY_AUTO}; u32 exponent = primes.prevPrime(fft.maxExp()); u32 best_biglit = 0; + u32 current_biglit = args->value("BIGLIT", 1); double best_cost = -1.0; + double current_cost = -1.0; for (u32 biglit : {0, 1}) { - shared.args->flags["BIGLIT"] = to_string(biglit); - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + args->flags["BIGLIT"] = to_string(biglit); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); log("Time for %12s using BIGLIT=%u is %6.1f\n", fft.spec().c_str(), biglit, cost); + if (biglit == current_biglit) current_cost = cost; if (best_cost < 0.0 || cost < best_cost) { best_cost = cost; best_biglit = biglit; } } log("Best BIGLIT is %u. Default BIGLIT is 1. The BIGLIT=0 option will probably be deprecated.\n", best_biglit); - shared.args->flags["BIGLIT"] = to_string(best_biglit); + configsUpdate(current_cost, best_cost, 0.003, "BIGLIT", best_biglit, newConfigKeyVals, suggestedConfigKeyVals); + args->flags["BIGLIT"] = to_string(best_biglit); } - //GW: Time some IN/OUT_WG/SIZEX combos? + // Output new settings to config.txt + File config = File::openAppend("config.txt"); + if (newConfigKeyVals.size()) { + config.write("\n# New settings based on a -tune run."); + for (u32 i = 0; i < newConfigKeyVals.size(); ++i) { + config.write(i == 0 ? "\n -use " : ","); + config.printf("%s=%u", newConfigKeyVals[i].first.c_str(), newConfigKeyVals[i].second); + } + config.write("\n"); + } + if (suggestedConfigKeyVals.size()) { + config.write("\n# These settings were slightly faster in a -tune run."); + config.write("\n# It is suggested that each setting be timed over a longer duration to see if the setting really is faster."); + for (u32 i = 0; i < suggestedConfigKeyVals.size(); ++i) { + config.write(i == 0 ? "\n# -use " : ","); + config.printf("%s=%u", suggestedConfigKeyVals[i].first.c_str(), suggestedConfigKeyVals[i].second); + } + config.write("\n"); + } + if (args->logStep < 100000) { + config.write("\n# Less frequent save file creation improves throughput."); + config.write("\n -log 1000000\n"); + } + if (args->workers < 2) { + config.write("\n# Running two workers sometimes gives better throughput. Autoprimenet will need to create a second worktodo file."); + config.write("\n# -workers 2\n"); + config.write("\n# Changing TAIL_KERNELS to 3 when running two workers may be better."); + config.write("\n# -use TAIL_KERNELS=3\n"); + } } // Flags that prune the amount of shapes and variants to time. @@ -623,13 +968,11 @@ void Tune::tune() { skip_some_WH_variants = 2; // should default be 1?? skip_1K_256 = 0; -//GW: Suggest tuning with TAIL_KERNELS=2 even if production runs use TAIL_KERNELS=3 - - // For each width, time the 001, 101, and 201 variants to find the fastest width variant. + // For each width, time the 001, 101, and 201 FP64 variants to find the fastest width variant. // In an ideal world we'd use the -time feature and look at the kCarryFused timing. Then we'd save this info in config.txt or tune.txt.
map fastest_width_variants; - // For each height, time the 100, 101, and 102 variants to find the fastest height variant. + // For each height, time the 100, 101, and 102 FP64 variants to find the fastest height variant. // In an ideal world we'd use the -time feature and look at the tailSquare timing. Then we'd save this info in config.txt or tune.txt. map fastest_height_variants; @@ -638,22 +981,43 @@ skip_1K_256 = 0; // Loop through all possible FFT shapes for (const FFTShape& shape : shapes) { + // Skip some FFTs and NTTs + if (shape.fft_type == FFT64 && !time_FFTs) continue; + if (shape.fft_type != FFT64 && !time_NTTs) continue; + if ((shape.fft_type == FFT3261 || shape.fft_type == FFT323161 || shape.fft_type == FFT3231 || shape.fft_type == FFT32) && !time_FP32) continue; + // Time an exponent that's good for all variants and carry-config. u32 exponent = primes.prevPrime(FFTConfig{shape, shape.width <= 1024 ? 0u : 100u, CARRY_32}.maxExp()); + u32 adjusted_quick = (exponent < 50000000) ? quick - 1 : (exponent < 170000000) ? quick : (exponent < 350000000) ? quick + 1 : quick + 2; + if (adjusted_quick < 1) adjusted_quick = 1; + if (adjusted_quick > 10) adjusted_quick = 10; - // Loop through all possible variants + // Loop through all possible variants for (u32 variant = 0; variant <= LAST_VARIANT; variant = next_variant (variant)) { - // Only AMD GPUs support variant zero (BCAST) and only if width <= 1024. + // Only FP64 code supports variants + if (variant != 202 && !FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) continue; + + // Only AMD GPUs support variant zero (BCAST) and only if width <= 1024. CLANG doesn't support builtins. Let NO_ASM bypass variant zero. if (variant_W(variant) == 0) { if (!AMDGPU) continue; if (shape.width > 1024) continue; + if (args->value("NO_ASM", 0)) continue; } // Only AMD GPUs support variant zero (BCAST) and only if height <= 1024. if (variant_H(variant) == 0) { if (!AMDGPU) continue; if (shape.height > 1024) continue; + if (args->value("NO_ASM", 0)) continue; + } + + // Reject shapes that won't be used to test exponents in the user's desired range + { + FFTConfig fft{shape, variant, CARRY_AUTO}; + if (fft.maxExp() < min_exponent) continue; + if (fft.maxExp() > 2*max_exponent) continue; + if (shape.fft_type == FFT64 && fft.maxExp() > 1.2*max_exponent) continue; } // If only one shape was specified on the command line, time it. This lets the user time any shape, including non-favored ones. @@ -667,7 +1031,7 @@ skip_1K_256 = 0; // Skip variants where width or height are not using the fastest variant. // NOTE: We ought to offer a tune=option where we also test more accurate variants to extend the FFT's max exponent. 
- if (skip_some_WH_variants) { + if (skip_some_WH_variants && FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) { u32 fastest_width = 1; if (auto it = fastest_width_variants.find(shape.width); it != fastest_width_variants.end()) { fastest_width = it->second; @@ -678,7 +1042,7 @@ skip_1K_256 = 0; if (w == 0 && !AMDGPU) continue; if (w == 0 && test.width > 1024) continue; FFTConfig fft{test, variant_WMH (w, 0, 1), CARRY_32}; - cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(); + cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(adjusted_quick); log("Fast width search %6.1f %12s\n", cost, fft.spec().c_str()); if (min_cost < 0.0 || cost < min_cost) { min_cost = cost; fastest_width = w; } } @@ -689,7 +1053,7 @@ skip_1K_256 = 0; FFTConfig{shape, variant, CARRY_32}.maxBpw() < FFTConfig{shape, variant_WMH (fastest_width, variant_M(variant), variant_H(variant)), CARRY_32}.maxBpw()) continue; } - if (skip_some_WH_variants) { + if (skip_some_WH_variants && FFTConfig{shape, variant, CARRY_AUTO}.FFT_FP64) { u32 fastest_height = 1; if (auto it = fastest_height_variants.find(shape.height); it != fastest_height_variants.end()) { fastest_height = it->second; @@ -700,7 +1064,7 @@ skip_1K_256 = 0; if (h == 0 && !AMDGPU) continue; if (h == 0 && test.height > 1024) continue; FFTConfig fft{test, variant_WMH (1, 0, h), CARRY_32}; - cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(); + cost = Gpu::make(q, primes.prevPrime(fft.maxExp()), shared, fft, {}, false)->timePRP(quick); log("Fast height search %6.1f %12s\n", cost, fft.spec().c_str()); if (min_cost < 0.0 || cost < min_cost) { min_cost = cost; fastest_height = h; } } @@ -715,10 +1079,13 @@ skip_1K_256 = 0; //GW: If variant is specified on command line, time it (and only it)?? Or an option to only time one variant number?? - vector carryToTest{CARRY_32}; - // We need to test both carry-32 and carry-64 only when the carry transition is within the BPW range. - if (FFTConfig{shape, variant, CARRY_64}.maxBpw() > FFTConfig{shape, variant, CARRY_32}.maxBpw()) { - carryToTest.push_back(CARRY_64); + vector carryToTest{CARRY_AUTO}; + if (shape.fft_type == FFT64) { + carryToTest[0] = CARRY_32; + // We need to test both carry-32 and carry-64 only when the carry transition is within the BPW range. + if (FFTConfig{shape, variant, CARRY_64}.maxBpw() > FFTConfig{shape, variant, CARRY_32}.maxBpw()) { + carryToTest.push_back(CARRY_64); + } } for (auto carry : carryToTest) { @@ -727,12 +1094,11 @@ skip_1K_256 = 0; // Skip middle = 1, CARRY_32 if maximum exponent would be the same as middle = 0, CARRY_32 if (variant_M(variant) > 0 && carry == CARRY_32 && fft.maxExp() <= FFTConfig{shape, variant - 10, CARRY_32}.maxExp()) continue; - double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(); + double cost = Gpu::make(q, exponent, shared, fft, {}, false)->timePRP(quick); bool isUseful = TuneEntry{cost, fft}.update(results); - log("%c %6.1f %12s %9u\n", isUseful ? '*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); + log("%c %6.1f %12s %9" PRIu64 "\n", isUseful ? 
'*' : ' ', cost, fft.spec().c_str(), fft.maxExp()); + if (isUseful) TuneEntry::writeTuneFile(results); } } } - - TuneEntry::writeTuneFile(results); } diff --git a/src/tune.h b/src/tune.h index de50bf52..a64f2100 100644 --- a/src/tune.h +++ b/src/tune.h @@ -22,8 +22,8 @@ class Tune { GpuCommon shared; Primes primes; - double maxBpw(FFTConfig fft); - double zForBpw(double bpw, FFTConfig fft, u32); + float maxBpw(FFTConfig fft); + float zForBpw(float bpw, FFTConfig fft, u32); public: Tune(Queue *q, GpuCommon shared) : q{q}, shared{shared} {}
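
A few illustrative sketches of the tuning logic above (none of this code is part of the patch). First, every "-use" flag (IN_WG/IN_SIZEX, OUT_WG/OUT_SIZEX, PAD, INPLACE, NONTEMPORAL, FAST_BARRIER, TAIL_*, TABMUL_*, MODM31, UNROLL_*, ZEROHACK_*, BIGLIT) is timed with the same loop shape: set the flag, rebuild via Gpu::make(...)->timePRP(quick), and remember both the fastest candidate and the cost of the currently configured value. sweepFlag and SweepResult are invented names; the timeWith callback stands in for the Gpu::make/timePRP call.

// Illustration only: the per-flag timing loop used throughout the tune.cpp hunk above.
#include <functional>
#include <vector>

struct SweepResult {
  unsigned bestVal = 0;
  double bestCost = -1.0;     // fastest time seen so far
  double currentCost = -1.0;  // time of the value currently configured (if it was a candidate)
};

static SweepResult sweepFlag(const std::vector<unsigned>& candidates,
                             unsigned currentVal,
                             const std::function<double(unsigned)>& timeWith) {
  SweepResult r;
  for (unsigned v : candidates) {
    double cost = timeWith(v);                  // set the flag, build kernels, time a PRP slice
    if (v == currentVal) r.currentCost = cost;
    if (r.bestCost < 0.0 || cost < r.bestCost) { r.bestCost = cost; r.bestVal = v; }
  }
  return r;
}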
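
configsUpdate() is called throughout the hunk but defined outside it. Judging from the call sites — configsUpdate(current_cost, best_cost, threshold, key, bestVal, newConfigKeyVals, suggestedConfigKeyVals) with thresholds between 0.000 and 0.003 — one plausible reading is that improvements larger than the relative threshold become active settings while smaller ones are only suggested. The sketch below makes that assumption explicit; it is not the actual implementation.

// Assumed semantics of configsUpdate(), reconstructed from its call sites.
#include <string>
#include <utility>
#include <vector>

using KeyVals = std::vector<std::pair<std::string, unsigned>>;

static void configsUpdate(double currentCost, double bestCost, double threshold,
                          const std::string& key, unsigned bestVal,
                          KeyVals& newConfigKeyVals, KeyVals& suggestedConfigKeyVals) {
  if (currentCost < 0.0) return;        // current value was never timed; guess: record nothing
  if (bestCost >= currentCost) return;  // the configured value is already the fastest
  double gain = (currentCost - bestCost) / currentCost;
  if (gain > threshold) {
    newConfigKeyVals.emplace_back(key, bestVal);        // clear win: written as an active "-use" line
  } else {
    suggestedConfigKeyVals.emplace_back(key, bestVal);  // marginal win: written as a commented "# -use" line
  }
}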
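
For reference, the File::openAppend("config.txt") block appends output of the following form. The comment lines are the literal strings from the code; the KEY=VAL pairs are invented here for illustration, and the last two segments appear only when logStep < 100000 and workers < 2 respectively.

# New settings based on a -tune run.
 -use OUT_WG=256,OUT_SIZEX=32

# These settings were slightly faster in a -tune run.
# It is suggested that each setting be timed over a longer duration to see if the setting really is faster.
# -use TAIL_TRIGS=1

# Less frequent save file creation improves throughput.
 -log 1000000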
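
The adjusted_quick value computed per shape in the main tune loop scales timing effort with the exponent and clamps the result to 1..10. The restatement below assumes quick >= 1 (so the u32 subtraction cannot wrap) and that 1..10 is the range timePRP() accepts; adjustQuick is not a function in the patch.

// Restatement of the adjusted_quick computation in Tune::tune().
#include <algorithm>
#include <cstdint>

static uint32_t adjustQuick(uint32_t quick, uint32_t exponent) {
  int level = static_cast<int>(quick);
  if      (exponent <  50'000'000) level -= 1;  // small FFTs: cheaper timings, spend less effort
  else if (exponent < 170'000'000) level += 0;  // mid range: keep the requested level
  else if (exponent < 350'000'000) level += 1;  // large FFTs: noisier, spend more effort
  else                             level += 2;
  return static_cast<uint32_t>(std::clamp(level, 1, 10));
}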
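
The "Reject shapes that won't be used to test exponents in the user's desired range" filter keeps a shape only if it can reach the requested range, with a tighter cap for FP64 FFTs (1.2x maxExponent) than for NTT shapes (2x). A standalone predicate with the same three comparisons (shapeIsUseful is an invented name):

// Standalone restatement of the shape range filter.
#include <cstdint>

static bool shapeIsUseful(uint64_t maxExp, bool isFp64,
                          uint64_t minExponent, uint64_t maxExponent) {
  if (maxExp < minExponent) return false;                   // cannot reach the requested range
  if (maxExp > 2 * maxExponent) return false;               // far beyond anything requested
  if (isFp64 && maxExp > 1.2 * maxExponent) return false;   // FP64 shapes get the tighter cap
  return true;
}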
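
Finally, the carryToTest logic for FP64 shapes (non-FP64 shapes just use CARRY_AUTO) only adds CARRY_64 when it actually raises the usable bits-per-word over CARRY_32, i.e. when the 32/64-bit carry transition falls inside the shape's BPW range. A minimal restatement with a local illustrative enum, not the project's:

// Minimal restatement of the FP64 carry-variant selection.
#include <vector>

enum IllustrativeCarry { kCarry32, kCarry64 };

static std::vector<IllustrativeCarry> carriesToTest(float maxBpwCarry32, float maxBpwCarry64) {
  std::vector<IllustrativeCarry> carries{kCarry32};
  if (maxBpwCarry64 > maxBpwCarry32) carries.push_back(kCarry64);  // transition is within the BPW range
  return carries;
}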