Skip to content

C2: spill scalar live ranges before high-pressure vector loops#64

Draft
raneashay wants to merge 1 commit intomicrosoft:mainfrom
raneashay:ashay/reduce-vector-spills
Draft

C2: spill scalar live ranges before high-pressure vector loops#64
raneashay wants to merge 1 commit intomicrosoft:mainfrom
raneashay:ashay/reduce-vector-spills

Conversation

@raneashay
Copy link
Copy Markdown

This patch adds a new pass that spills scalar values residing in vector
registers to memory so that hot loops have complete access to the vector
register file. Specifically, on both X64 and AArch64, scalar
floating-point values often reside in vector registers (e.g. XMM0-XMM15
on AVX2 or D0-D31 on AArch64), and if these scalar live ranges overlap
with vector live ranges inside a hot loop, then the loop will have fewer
vector registers to work with, resulting in frequent spilling and
restoring of the vector registers within the hot loop.

The newly introduced pass in this patch first analyzes loops that are
both high frequency and have a high register pressure, before finding
scalar live ranges that have no definitions inside the loop and which
overlap with the vector register file. Such live ranges are then split
so that the scalar values are spilled at the beginning of the loop and
restored at the end of the loop.

I have validated this using an AVX2 machine using a vectorized 3x4 DGEMM
outer product kernel, which shows vector spills in the inner loop
dropping from 17 to 2. However, it doesn't look there is a good way to
add a JTreg test for this change, since we don't seem to have a reliable
way to identify spill and restore assembly instructions using IRNode.
One alternative is to parse the output of PrintOptoAssembly but that
approach seems very fragile, especially since a source-level loop is
often broken down into multiple pre/main/post loops.

This patch adds a new pass that spills scalar values residing in vector
registers to memory so that hot loops have complete access to the vector
register file.  Specifically, on both X64 and AArch64, scalar
floating-point values often reside in vector registers (e.g. XMM0-XMM15
on AVX2 or D0-D31 on AArch64), and if these scalar live ranges overlap
with vector live ranges inside a hot loop, then the loop will have fewer
vector registers to work with, resulting in frequent spilling and
restoring of the vector registers within the hot loop.

The newly introduced pass in this patch first analyzes loops that are
both high frequency and have a high register pressure, before finding
scalar live ranges that have no definitions inside the loop and which
overlap with the vector register file.  Such live ranges are then split
so that the scalar values are spilled at the beginning of the loop and
restored at the end of the loop.

I have validated this using an AVX2 machine using a vectorized 3x4 DGEMM
outer product kernel, which shows vector spills in the inner loop
dropping from 17 to 2.  However, it doesn't look there is a good way to
add a JTreg test for this change, since we don't seem to have a reliable
way to identify spill and restore assembly instructions using `IRNode`.
One alternative is to parse the output of `PrintOptoAssembly` but that
approach seems very fragile, especially since a source-level loop is
often broken down into multiple pre/main/post loops.
@raneashay
Copy link
Copy Markdown
Author

In lieu of a functional test, here's a sample kernel (derived from Ludovic's BLAS implementation) that demonstrates the improvement (see running times below).

import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;


public class test {
    private static final VectorSpecies<Double> DMAX = DoubleVector.SPECIES_MAX;

    protected static int loopAlign(int index, int max, int size) {
        return Math.min(loopBound(index + size - 1, size), max);
    }

    protected static int loopBound(int index, int size) {
        return index - (index % size);
    }

    static void kernel(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
        final int Ti = 1;

        assert rowe - rows == 3;
        assert cole - cols == 4;

        int row = rows;
        int col = cols;
        int i = is;
        double sum00 = 0.0;
        double sum01 = 0.0;
        double sum02 = 0.0;
        double sum03 = 0.0;
        double sum10 = 0.0;
        double sum11 = 0.0;
        double sum12 = 0.0;
        double sum13 = 0.0;
        double sum20 = 0.0;
        double sum21 = 0.0;
        double sum22 = 0.0;
        double sum23 = 0.0;
        for (; i < loopAlign(is, ie, Ti * DMAX.length()); i += 1) {
            double a00 = a[offseta + (i + 0) + (row + 0) * lda];
            double a01 = a[offseta + (i + 0) + (row + 1) * lda];
            double a02 = a[offseta + (i + 0) + (row + 2) * lda];
            double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
            sum00 = Math.fma(a00, b00, sum00);
            sum10 = Math.fma(a01, b00, sum10);
            sum20 = Math.fma(a02, b00, sum20);
            double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
            sum01 = Math.fma(a00, b01, sum01);
            sum11 = Math.fma(a01, b01, sum11);
            sum21 = Math.fma(a02, b01, sum21);
            double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
            sum02 = Math.fma(a00, b02, sum02);
            sum12 = Math.fma(a01, b02, sum12);
            sum22 = Math.fma(a02, b02, sum22);
            double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
            sum03 = Math.fma(a00, b03, sum03);
            sum13 = Math.fma(a01, b03, sum13);
            sum23 = Math.fma(a02, b03, sum23);
        }
        DoubleVector vsum00 = DoubleVector.zero(DMAX);
        DoubleVector vsum01 = DoubleVector.zero(DMAX);
        DoubleVector vsum02 = DoubleVector.zero(DMAX);
        DoubleVector vsum03 = DoubleVector.zero(DMAX);
        DoubleVector vsum10 = DoubleVector.zero(DMAX);
        DoubleVector vsum11 = DoubleVector.zero(DMAX);
        DoubleVector vsum12 = DoubleVector.zero(DMAX);
        DoubleVector vsum13 = DoubleVector.zero(DMAX);
        DoubleVector vsum20 = DoubleVector.zero(DMAX);
        DoubleVector vsum21 = DoubleVector.zero(DMAX);
        DoubleVector vsum22 = DoubleVector.zero(DMAX);
        DoubleVector vsum23 = DoubleVector.zero(DMAX);
        for (; i < loopBound(ie, Ti * DMAX.length()); i += Ti * DMAX.length()) {
            DoubleVector va00 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 0) * lda);
            DoubleVector va01 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 1) * lda);
            DoubleVector va02 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 2) * lda);
            DoubleVector vb00 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 0) * ldb);
            vsum00 = va00.fma(vb00, vsum00);
            vsum10 = va01.fma(vb00, vsum10);
            vsum20 = va02.fma(vb00, vsum20);
            DoubleVector vb01 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 1) * ldb);
            vsum01 = va00.fma(vb01, vsum01);
            vsum11 = va01.fma(vb01, vsum11);
            vsum21 = va02.fma(vb01, vsum21);
            DoubleVector vb02 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 2) * ldb);
            vsum02 = va00.fma(vb02, vsum02);
            vsum12 = va01.fma(vb02, vsum12);
            vsum22 = va02.fma(vb02, vsum22);
            DoubleVector vb03 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 3) * ldb);
            vsum03 = va00.fma(vb03, vsum03);
            vsum13 = va01.fma(vb03, vsum13);
            vsum23 = va02.fma(vb03, vsum23);
        }
        sum00 += vsum00.reduceLanes(VectorOperators.ADD);
        sum01 += vsum01.reduceLanes(VectorOperators.ADD);
        sum02 += vsum02.reduceLanes(VectorOperators.ADD);
        sum03 += vsum03.reduceLanes(VectorOperators.ADD);
        sum10 += vsum10.reduceLanes(VectorOperators.ADD);
        sum11 += vsum11.reduceLanes(VectorOperators.ADD);
        sum12 += vsum12.reduceLanes(VectorOperators.ADD);
        sum13 += vsum13.reduceLanes(VectorOperators.ADD);
        sum20 += vsum20.reduceLanes(VectorOperators.ADD);
        sum21 += vsum21.reduceLanes(VectorOperators.ADD);
        sum22 += vsum22.reduceLanes(VectorOperators.ADD);
        sum23 += vsum23.reduceLanes(VectorOperators.ADD);
        for (; i < ie; i += 1) {
            double a00 = a[offseta + (i + 0) + (row + 0) * lda];
            double a01 = a[offseta + (i + 0) + (row + 1) * lda];
            double a02 = a[offseta + (i + 0) + (row + 2) * lda];
            double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
            sum00 = Math.fma(a00, b00, sum00);
            sum10 = Math.fma(a01, b00, sum10);
            sum20 = Math.fma(a02, b00, sum20);
            double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
            sum01 = Math.fma(a00, b01, sum01);
            sum11 = Math.fma(a01, b01, sum11);
            sum21 = Math.fma(a02, b01, sum21);
            double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
            sum02 = Math.fma(a00, b02, sum02);
            sum12 = Math.fma(a01, b02, sum12);
            sum22 = Math.fma(a02, b02, sum22);
            double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
            sum03 = Math.fma(a00, b03, sum03);
            sum13 = Math.fma(a01, b03, sum13);
            sum23 = Math.fma(a02, b03, sum23);
        }
        c[offsetc + (row + 0) + (col + 0) * ldc] = Math.fma(alpha, sum00, c[offsetc + (row + 0) + (col + 0) * ldc]);
        c[offsetc + (row + 0) + (col + 1) * ldc] = Math.fma(alpha, sum01, c[offsetc + (row + 0) + (col + 1) * ldc]);
        c[offsetc + (row + 0) + (col + 2) * ldc] = Math.fma(alpha, sum02, c[offsetc + (row + 0) + (col + 2) * ldc]);
        c[offsetc + (row + 0) + (col + 3) * ldc] = Math.fma(alpha, sum03, c[offsetc + (row + 0) + (col + 3) * ldc]);
        c[offsetc + (row + 1) + (col + 0) * ldc] = Math.fma(alpha, sum10, c[offsetc + (row + 1) + (col + 0) * ldc]);
        c[offsetc + (row + 1) + (col + 1) * ldc] = Math.fma(alpha, sum11, c[offsetc + (row + 1) + (col + 1) * ldc]);
        c[offsetc + (row + 1) + (col + 2) * ldc] = Math.fma(alpha, sum12, c[offsetc + (row + 1) + (col + 2) * ldc]);
        c[offsetc + (row + 1) + (col + 3) * ldc] = Math.fma(alpha, sum13, c[offsetc + (row + 1) + (col + 3) * ldc]);
        c[offsetc + (row + 2) + (col + 0) * ldc] = Math.fma(alpha, sum20, c[offsetc + (row + 2) + (col + 0) * ldc]);
        c[offsetc + (row + 2) + (col + 1) * ldc] = Math.fma(alpha, sum21, c[offsetc + (row + 2) + (col + 1) * ldc]);
        c[offsetc + (row + 2) + (col + 2) * ldc] = Math.fma(alpha, sum22, c[offsetc + (row + 2) + (col + 2) * ldc]);
        c[offsetc + (row + 2) + (col + 3) * ldc] = Math.fma(alpha, sum23, c[offsetc + (row + 2) + (col + 3) * ldc]);
    }
    public static void main(String[] args) {
        int m = 16, n = 16, k = 16;
        int lda = k, ldb = k, ldc = m;
        double[] a = new double[k * m];
        double[] b = new double[k * n];
        double[] c = new double[m * n];
        for (int i = 0; i < k * m; i++) a[i] = (i % 1000) * 0.001;
        for (int i = 0; i < k * n; i++) b[i] = (i % 1000) * 0.001;
        for (int i = 0; i < m * n; i++) c[i] = (i % 1000) * 0.001;
        for (int iter = 0; iter < 100_000_000; iter++) {
            kernel(m, 0, 3, n, 0, 4, k, 0, k, 1.0, a, 0, lda, b, 0, ldb, 1.0, c, 0, ldc);
        }
        System.out.println("c[0] = " + c[0] + ", c[1] = " + c[1] + ", c[2] = " + c[2]);
    }
}

And here are the running time measurements:

$ FLAGS='-Xbatch -XX:-TieredCompilation -XX:CompileThreshold=1 --add-modules jdk.incubator.vector -XX:CompileCommand=compileonly,test::kernel -XX:UseAVX=2' hyperfine --warmup=3 'old/jdk/bin/java $FLAGS test.java > /dev/null' 'new/jdk/bin/java $FLAGS test.java > /dev/null'
Benchmark 1: old/jdk/bin/java $FLAGS test.java > /dev/null
  Time (mean ± σ):     11.126 s ±  0.194 s    [User: 10.837 s, System: 0.399 s]
  Range (min … max):   10.960 s … 11.623 s    10 runs

Benchmark 2: new/jdk/bin/java $FLAGS test.java > /dev/null
  Time (mean ± σ):     10.037 s ±  0.143 s    [User: 9.743 s, System: 0.399 s]
  Range (min … max):    9.759 s … 10.227 s    10 runs

Summary
  new/jdk/bin/java $FLAGS test.java > /dev/null ran
    1.11 ± 0.02 times faster than old/jdk/bin/java $FLAGS test.java > /dev/null

@raneashay
Copy link
Copy Markdown
Author

Also of relevance, I copied the above standalone code into a JMH benchmark (see below), but I didn't see any improvement:

package org.jmh.suite;

import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.concurrent.TimeUnit;

/**
 * Run with:
 *   java --add-modules jdk.incubator.vector -jar target/benchmarks.jar
 *        DgepdotTN3x4Benchmark
 */

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 10, time = 1)
@Fork(value = 5, jvmArgs = {"--add-modules", "jdk.incubator.vector",
                             "-XX:-TieredCompilation",
                             "-XX:UseAVX=2"})
public class DgepdotTN3x4Benchmark {

    static final VectorSpecies<Double> DMAX = DoubleVector.SPECIES_MAX;

    static final int REPS = 20_000;
    int m, n, k, lda, ldb, ldc;
    double[] a, b, c;

    @Setup
    public void setup() {
        m = 16; n = 16; k = 32;
        lda = k; ldb = k; ldc = m;
        a = new double[k * m];
        b = new double[k * n];
        c = new double[m * n];
        for (int i = 0; i < k * m; i++) a[i] = (i % 1000) * 0.001;
        for (int i = 0; i < k * n; i++) b[i] = (i % 1000) * 0.001;
        for (int i = 0; i < m * n; i++) c[i] = 0.0;
    }

    @Benchmark
    public void dgepdotTN3x4() {
        for (int r = 0; r < REPS; r++) {
            kernel(m, 0, 3, n, 0, 4, k, 0, k, 1.0,
                   a, 0, lda, b, 0, ldb, 1.0, c, 0, ldc);
        }
    }

    static void kernel(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
        final int Ti = 1;

        int row = rows;
        int col = cols;
        int i = is;
        double sum00 = 0.0;
        double sum01 = 0.0;
        double sum02 = 0.0;
        double sum03 = 0.0;
        double sum10 = 0.0;
        double sum11 = 0.0;
        double sum12 = 0.0;
        double sum13 = 0.0;
        double sum20 = 0.0;
        double sum21 = 0.0;
        double sum22 = 0.0;
        double sum23 = 0.0;

        int alignBound = is + DMAX.length() - 1 - ((is + DMAX.length() - 1) % DMAX.length());
        int alignLimit = Math.min(alignBound, ie);

        for (; i < alignLimit; i += 1) {
            double a00 = a[offseta + (i + 0) + (row + 0) * lda];
            double a01 = a[offseta + (i + 0) + (row + 1) * lda];
            double a02 = a[offseta + (i + 0) + (row + 2) * lda];
            double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
            sum00 = Math.fma(a00, b00, sum00);
            sum10 = Math.fma(a01, b00, sum10);
            sum20 = Math.fma(a02, b00, sum20);
            double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
            sum01 = Math.fma(a00, b01, sum01);
            sum11 = Math.fma(a01, b01, sum11);
            sum21 = Math.fma(a02, b01, sum21);
            double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
            sum02 = Math.fma(a00, b02, sum02);
            sum12 = Math.fma(a01, b02, sum12);
            sum22 = Math.fma(a02, b02, sum22);
            double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
            sum03 = Math.fma(a00, b03, sum03);
            sum13 = Math.fma(a01, b03, sum13);
            sum23 = Math.fma(a02, b03, sum23);
        }
        DoubleVector vsum00 = DoubleVector.zero(DMAX);
        DoubleVector vsum01 = DoubleVector.zero(DMAX);
        DoubleVector vsum02 = DoubleVector.zero(DMAX);
        DoubleVector vsum03 = DoubleVector.zero(DMAX);
        DoubleVector vsum10 = DoubleVector.zero(DMAX);
        DoubleVector vsum11 = DoubleVector.zero(DMAX);
        DoubleVector vsum12 = DoubleVector.zero(DMAX);
        DoubleVector vsum13 = DoubleVector.zero(DMAX);
        DoubleVector vsum20 = DoubleVector.zero(DMAX);
        DoubleVector vsum21 = DoubleVector.zero(DMAX);
        DoubleVector vsum22 = DoubleVector.zero(DMAX);
        DoubleVector vsum23 = DoubleVector.zero(DMAX);

        for (; i < ie - (ie % Ti * DMAX.length()); i += Ti * DMAX.length()) {
            DoubleVector va00 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 0) * lda);
            DoubleVector va01 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 1) * lda);
            DoubleVector va02 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 2) * lda);
            DoubleVector vb00 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 0) * ldb);
            vsum00 = va00.fma(vb00, vsum00);
            vsum10 = va01.fma(vb00, vsum10);
            vsum20 = va02.fma(vb00, vsum20);
            DoubleVector vb01 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 1) * ldb);
            vsum01 = va00.fma(vb01, vsum01);
            vsum11 = va01.fma(vb01, vsum11);
            vsum21 = va02.fma(vb01, vsum21);
            DoubleVector vb02 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 2) * ldb);
            vsum02 = va00.fma(vb02, vsum02);
            vsum12 = va01.fma(vb02, vsum12);
            vsum22 = va02.fma(vb02, vsum22);
            DoubleVector vb03 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 3) * ldb);
            vsum03 = va00.fma(vb03, vsum03);
            vsum13 = va01.fma(vb03, vsum13);
            vsum23 = va02.fma(vb03, vsum23);
        }
        sum00 += vsum00.reduceLanes(VectorOperators.ADD);
        sum01 += vsum01.reduceLanes(VectorOperators.ADD);
        sum02 += vsum02.reduceLanes(VectorOperators.ADD);
        sum03 += vsum03.reduceLanes(VectorOperators.ADD);
        sum10 += vsum10.reduceLanes(VectorOperators.ADD);
        sum11 += vsum11.reduceLanes(VectorOperators.ADD);
        sum12 += vsum12.reduceLanes(VectorOperators.ADD);
        sum13 += vsum13.reduceLanes(VectorOperators.ADD);
        sum20 += vsum20.reduceLanes(VectorOperators.ADD);
        sum21 += vsum21.reduceLanes(VectorOperators.ADD);
        sum22 += vsum22.reduceLanes(VectorOperators.ADD);
        sum23 += vsum23.reduceLanes(VectorOperators.ADD);
        for (; i < ie; i += 1) {
            double a00 = a[offseta + (i + 0) + (row + 0) * lda];
            double a01 = a[offseta + (i + 0) + (row + 1) * lda];
            double a02 = a[offseta + (i + 0) + (row + 2) * lda];
            double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
            sum00 = Math.fma(a00, b00, sum00);
            sum10 = Math.fma(a01, b00, sum10);
            sum20 = Math.fma(a02, b00, sum20);
            double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
            sum01 = Math.fma(a00, b01, sum01);
            sum11 = Math.fma(a01, b01, sum11);
            sum21 = Math.fma(a02, b01, sum21);
            double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
            sum02 = Math.fma(a00, b02, sum02);
            sum12 = Math.fma(a01, b02, sum12);
            sum22 = Math.fma(a02, b02, sum22);
            double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
            sum03 = Math.fma(a00, b03, sum03);
            sum13 = Math.fma(a01, b03, sum13);
            sum23 = Math.fma(a02, b03, sum23);
        }
        c[offsetc + (row + 0) + (col + 0) * ldc] = Math.fma(alpha, sum00, c[offsetc + (row + 0) + (col + 0) * ldc]);
        c[offsetc + (row + 0) + (col + 1) * ldc] = Math.fma(alpha, sum01, c[offsetc + (row + 0) + (col + 1) * ldc]);
        c[offsetc + (row + 0) + (col + 2) * ldc] = Math.fma(alpha, sum02, c[offsetc + (row + 0) + (col + 2) * ldc]);
        c[offsetc + (row + 0) + (col + 3) * ldc] = Math.fma(alpha, sum03, c[offsetc + (row + 0) + (col + 3) * ldc]);
        c[offsetc + (row + 1) + (col + 0) * ldc] = Math.fma(alpha, sum10, c[offsetc + (row + 1) + (col + 0) * ldc]);
        c[offsetc + (row + 1) + (col + 1) * ldc] = Math.fma(alpha, sum11, c[offsetc + (row + 1) + (col + 1) * ldc]);
        c[offsetc + (row + 1) + (col + 2) * ldc] = Math.fma(alpha, sum12, c[offsetc + (row + 1) + (col + 2) * ldc]);
        c[offsetc + (row + 1) + (col + 3) * ldc] = Math.fma(alpha, sum13, c[offsetc + (row + 1) + (col + 3) * ldc]);
        c[offsetc + (row + 2) + (col + 0) * ldc] = Math.fma(alpha, sum20, c[offsetc + (row + 2) + (col + 0) * ldc]);
        c[offsetc + (row + 2) + (col + 1) * ldc] = Math.fma(alpha, sum21, c[offsetc + (row + 2) + (col + 1) * ldc]);
        c[offsetc + (row + 2) + (col + 2) * ldc] = Math.fma(alpha, sum22, c[offsetc + (row + 2) + (col + 2) * ldc]);
        c[offsetc + (row + 2) + (col + 3) * ldc] = Math.fma(alpha, sum23, c[offsetc + (row + 2) + (col + 3) * ldc]);
    }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant