C2: spill scalar live ranges before high-pressure vector loops#64
Draft
raneashay wants to merge 1 commit into microsoft:main
This patch adds a new pass that spills scalar values residing in vector registers to memory so that hot loops have complete access to the vector register file. Specifically, on both X64 and AArch64, scalar floating-point values often reside in vector registers (e.g. XMM0-XMM15 on AVX2 or D0-D31 on AArch64), and if these scalar live ranges overlap with vector live ranges inside a hot loop, then the loop has fewer vector registers to work with, resulting in frequent spilling and restoring of vector registers within the hot loop.

The newly introduced pass first identifies loops that are both high frequency and under high register pressure, then finds scalar live ranges that have no definitions inside the loop and that occupy registers in the vector register file. Such live ranges are then split so that the scalar values are spilled at the beginning of the loop and restored at the end of the loop.

I have validated this on an AVX2 machine using a vectorized 3x4 DGEMM outer product kernel, which shows vector spills in the inner loop dropping from 17 to 2. However, it doesn't look like there is a good way to add a JTreg test for this change, since we don't seem to have a reliable way to identify spill and restore assembly instructions using `IRNode`. One alternative is to parse the output of `PrintOptoAssembly`, but that approach seems very fragile, especially since a source-level loop is often broken down into multiple pre/main/post loops.
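The selection step described above (hot loop, high vector pressure, scalar live ranges with no definitions inside the loop) can be illustrated with a toy model. This is only a sketch under stated assumptions, not the code in the patch: the `LiveRange` class, the `HOT_THRESHOLD` cutoff, the 16-register limit, and the `kernelLikeLoop` data are all hypothetical names and values chosen for illustration, and real C2 live ranges carry far more information.

```java
import java.util.ArrayList;
import java.util.List;

final class ScalarSpillSketch {
    /** Hypothetical, simplified live range; not C2's actual data structure. */
    static final class LiveRange {
        final String name;
        final boolean isVector;      // true for a vector value, false for a scalar float/double
        final boolean definedInLoop; // true if any definition sits inside the loop body

        LiveRange(String name, boolean isVector, boolean definedInLoop) {
            this.name = name;
            this.isVector = isVector;
            this.definedInLoop = definedInLoop;
        }
    }

    static final int VECTOR_REGS = 16;          // assumed: XMM0-XMM15
    static final double HOT_THRESHOLD = 1000.0; // assumed, arbitrary frequency cutoff

    /**
     * Returns the scalar live ranges worth spilling around the loop.
     * liveInLoop models every value that wants a vector register
     * somewhere inside the loop (vectors and scalar FP values alike).
     */
    static List<LiveRange> spillCandidates(double loopFrequency, List<LiveRange> liveInLoop) {
        List<LiveRange> candidates = new ArrayList<>();
        boolean hot = loopFrequency >= HOT_THRESHOLD;
        boolean highPressure = liveInLoop.size() > VECTOR_REGS;
        if (!hot || !highPressure) {
            return candidates; // cold or low-pressure loops are left alone
        }
        for (LiveRange lr : liveInLoop) {
            // Only scalars that merely pass through the loop are candidates.
            if (!lr.isVector && !lr.definedInLoop) {
                candidates.add(lr);
            }
        }
        return candidates;
    }

    /** Toy data loosely modeled on a 3x4 vector loop: 12 vector sums, 3 a-vectors, 1 b-vector, 12 scalar sums. */
    static List<LiveRange> kernelLikeLoop() {
        List<LiveRange> lrs = new ArrayList<>();
        for (int i = 0; i < 12; i++) lrs.add(new LiveRange("vsum" + i, true, true));
        for (int i = 0; i < 3; i++) lrs.add(new LiveRange("va" + i, true, true));
        lrs.add(new LiveRange("vb", true, true));
        for (int i = 0; i < 12; i++) lrs.add(new LiveRange("sum" + i, false, false));
        return lrs;
    }

    public static void main(String[] args) {
        List<LiveRange> candidates = spillCandidates(10_000.0, kernelLikeLoop());
        System.out.println("spill candidates: " + candidates.size());
    }
}
```

In this toy data the 16 vector values alone fill the assumed register file, so the 12 scalar sums that are only carried across the loop are the ones selected for spilling before it and restoring after it.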
In lieu of a functional test, here's a sample kernel (derived from Ludovic's BLAS implementation) that demonstrates the improvement (see running times below).
import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
public class test {
private static final VectorSpecies<Double> DMAX = DoubleVector.SPECIES_MAX;
protected static int loopAlign(int index, int max, int size) {
return Math.min(loopBound(index + size - 1, size), max);
}
protected static int loopBound(int index, int size) {
return index - (index % size);
}
static void kernel(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
final int Ti = 1;
assert rowe - rows == 3;
assert cole - cols == 4;
int row = rows;
int col = cols;
int i = is;
double sum00 = 0.0;
double sum01 = 0.0;
double sum02 = 0.0;
double sum03 = 0.0;
double sum10 = 0.0;
double sum11 = 0.0;
double sum12 = 0.0;
double sum13 = 0.0;
double sum20 = 0.0;
double sum21 = 0.0;
double sum22 = 0.0;
double sum23 = 0.0;
for (; i < loopAlign(is, ie, Ti * DMAX.length()); i += 1) {
double a00 = a[offseta + (i + 0) + (row + 0) * lda];
double a01 = a[offseta + (i + 0) + (row + 1) * lda];
double a02 = a[offseta + (i + 0) + (row + 2) * lda];
double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
sum00 = Math.fma(a00, b00, sum00);
sum10 = Math.fma(a01, b00, sum10);
sum20 = Math.fma(a02, b00, sum20);
double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
sum01 = Math.fma(a00, b01, sum01);
sum11 = Math.fma(a01, b01, sum11);
sum21 = Math.fma(a02, b01, sum21);
double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
sum02 = Math.fma(a00, b02, sum02);
sum12 = Math.fma(a01, b02, sum12);
sum22 = Math.fma(a02, b02, sum22);
double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
sum03 = Math.fma(a00, b03, sum03);
sum13 = Math.fma(a01, b03, sum13);
sum23 = Math.fma(a02, b03, sum23);
}
DoubleVector vsum00 = DoubleVector.zero(DMAX);
DoubleVector vsum01 = DoubleVector.zero(DMAX);
DoubleVector vsum02 = DoubleVector.zero(DMAX);
DoubleVector vsum03 = DoubleVector.zero(DMAX);
DoubleVector vsum10 = DoubleVector.zero(DMAX);
DoubleVector vsum11 = DoubleVector.zero(DMAX);
DoubleVector vsum12 = DoubleVector.zero(DMAX);
DoubleVector vsum13 = DoubleVector.zero(DMAX);
DoubleVector vsum20 = DoubleVector.zero(DMAX);
DoubleVector vsum21 = DoubleVector.zero(DMAX);
DoubleVector vsum22 = DoubleVector.zero(DMAX);
DoubleVector vsum23 = DoubleVector.zero(DMAX);
for (; i < loopBound(ie, Ti * DMAX.length()); i += Ti * DMAX.length()) {
DoubleVector va00 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 0) * lda);
DoubleVector va01 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 1) * lda);
DoubleVector va02 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 2) * lda);
DoubleVector vb00 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 0) * ldb);
vsum00 = va00.fma(vb00, vsum00);
vsum10 = va01.fma(vb00, vsum10);
vsum20 = va02.fma(vb00, vsum20);
DoubleVector vb01 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 1) * ldb);
vsum01 = va00.fma(vb01, vsum01);
vsum11 = va01.fma(vb01, vsum11);
vsum21 = va02.fma(vb01, vsum21);
DoubleVector vb02 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 2) * ldb);
vsum02 = va00.fma(vb02, vsum02);
vsum12 = va01.fma(vb02, vsum12);
vsum22 = va02.fma(vb02, vsum22);
DoubleVector vb03 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 3) * ldb);
vsum03 = va00.fma(vb03, vsum03);
vsum13 = va01.fma(vb03, vsum13);
vsum23 = va02.fma(vb03, vsum23);
}
sum00 += vsum00.reduceLanes(VectorOperators.ADD);
sum01 += vsum01.reduceLanes(VectorOperators.ADD);
sum02 += vsum02.reduceLanes(VectorOperators.ADD);
sum03 += vsum03.reduceLanes(VectorOperators.ADD);
sum10 += vsum10.reduceLanes(VectorOperators.ADD);
sum11 += vsum11.reduceLanes(VectorOperators.ADD);
sum12 += vsum12.reduceLanes(VectorOperators.ADD);
sum13 += vsum13.reduceLanes(VectorOperators.ADD);
sum20 += vsum20.reduceLanes(VectorOperators.ADD);
sum21 += vsum21.reduceLanes(VectorOperators.ADD);
sum22 += vsum22.reduceLanes(VectorOperators.ADD);
sum23 += vsum23.reduceLanes(VectorOperators.ADD);
for (; i < ie; i += 1) {
double a00 = a[offseta + (i + 0) + (row + 0) * lda];
double a01 = a[offseta + (i + 0) + (row + 1) * lda];
double a02 = a[offseta + (i + 0) + (row + 2) * lda];
double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
sum00 = Math.fma(a00, b00, sum00);
sum10 = Math.fma(a01, b00, sum10);
sum20 = Math.fma(a02, b00, sum20);
double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
sum01 = Math.fma(a00, b01, sum01);
sum11 = Math.fma(a01, b01, sum11);
sum21 = Math.fma(a02, b01, sum21);
double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
sum02 = Math.fma(a00, b02, sum02);
sum12 = Math.fma(a01, b02, sum12);
sum22 = Math.fma(a02, b02, sum22);
double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
sum03 = Math.fma(a00, b03, sum03);
sum13 = Math.fma(a01, b03, sum13);
sum23 = Math.fma(a02, b03, sum23);
}
c[offsetc + (row + 0) + (col + 0) * ldc] = Math.fma(alpha, sum00, c[offsetc + (row + 0) + (col + 0) * ldc]);
c[offsetc + (row + 0) + (col + 1) * ldc] = Math.fma(alpha, sum01, c[offsetc + (row + 0) + (col + 1) * ldc]);
c[offsetc + (row + 0) + (col + 2) * ldc] = Math.fma(alpha, sum02, c[offsetc + (row + 0) + (col + 2) * ldc]);
c[offsetc + (row + 0) + (col + 3) * ldc] = Math.fma(alpha, sum03, c[offsetc + (row + 0) + (col + 3) * ldc]);
c[offsetc + (row + 1) + (col + 0) * ldc] = Math.fma(alpha, sum10, c[offsetc + (row + 1) + (col + 0) * ldc]);
c[offsetc + (row + 1) + (col + 1) * ldc] = Math.fma(alpha, sum11, c[offsetc + (row + 1) + (col + 1) * ldc]);
c[offsetc + (row + 1) + (col + 2) * ldc] = Math.fma(alpha, sum12, c[offsetc + (row + 1) + (col + 2) * ldc]);
c[offsetc + (row + 1) + (col + 3) * ldc] = Math.fma(alpha, sum13, c[offsetc + (row + 1) + (col + 3) * ldc]);
c[offsetc + (row + 2) + (col + 0) * ldc] = Math.fma(alpha, sum20, c[offsetc + (row + 2) + (col + 0) * ldc]);
c[offsetc + (row + 2) + (col + 1) * ldc] = Math.fma(alpha, sum21, c[offsetc + (row + 2) + (col + 1) * ldc]);
c[offsetc + (row + 2) + (col + 2) * ldc] = Math.fma(alpha, sum22, c[offsetc + (row + 2) + (col + 2) * ldc]);
c[offsetc + (row + 2) + (col + 3) * ldc] = Math.fma(alpha, sum23, c[offsetc + (row + 2) + (col + 3) * ldc]);
}
public static void main(String[] args) {
int m = 16, n = 16, k = 16;
int lda = k, ldb = k, ldc = m;
double[] a = new double[k * m];
double[] b = new double[k * n];
double[] c = new double[m * n];
for (int i = 0; i < k * m; i++) a[i] = (i % 1000) * 0.001;
for (int i = 0; i < k * n; i++) b[i] = (i % 1000) * 0.001;
for (int i = 0; i < m * n; i++) c[i] = (i % 1000) * 0.001;
for (int iter = 0; iter < 100_000_000; iter++) {
kernel(m, 0, 3, n, 0, 4, k, 0, k, 1.0, a, 0, lda, b, 0, ldb, 1.0, c, 0, ldc);
}
System.out.println("c[0] = " + c[0] + ", c[1] = " + c[1] + ", c[2] = " + c[2]);
}
}
And here are the running time measurements:
Also of relevance, I copied the above standalone code into a JMH benchmark (see below), but I didn't see any improvement:
package org.jmh.suite;
import jdk.incubator.vector.DoubleVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import java.util.concurrent.TimeUnit;
/**
* Run with:
* java --add-modules jdk.incubator.vector -jar target/benchmarks.jar
* DgepdotTN3x4Benchmark
*/
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 10, time = 1)
@Fork(value = 5, jvmArgs = {"--add-modules", "jdk.incubator.vector",
"-XX:-TieredCompilation",
"-XX:UseAVX=2"})
public class DgepdotTN3x4Benchmark {
static final VectorSpecies<Double> DMAX = DoubleVector.SPECIES_MAX;
static final int REPS = 20_000;
int m, n, k, lda, ldb, ldc;
double[] a, b, c;
@Setup
public void setup() {
m = 16; n = 16; k = 32;
lda = k; ldb = k; ldc = m;
a = new double[k * m];
b = new double[k * n];
c = new double[m * n];
for (int i = 0; i < k * m; i++) a[i] = (i % 1000) * 0.001;
for (int i = 0; i < k * n; i++) b[i] = (i % 1000) * 0.001;
for (int i = 0; i < m * n; i++) c[i] = 0.0;
}
@Benchmark
public void dgepdotTN3x4() {
for (int r = 0; r < REPS; r++) {
kernel(m, 0, 3, n, 0, 4, k, 0, k, 1.0,
a, 0, lda, b, 0, ldb, 1.0, c, 0, ldc);
}
}
static void kernel(int m, int rows, int rowe, int n, int cols, int cole, int k, int is, int ie, double alpha, double[] a, int offseta, int lda, double[] b, int offsetb, int ldb, double beta, double[] c, int offsetc, int ldc) {
final int Ti = 1;
int row = rows;
int col = cols;
int i = is;
double sum00 = 0.0;
double sum01 = 0.0;
double sum02 = 0.0;
double sum03 = 0.0;
double sum10 = 0.0;
double sum11 = 0.0;
double sum12 = 0.0;
double sum13 = 0.0;
double sum20 = 0.0;
double sum21 = 0.0;
double sum22 = 0.0;
double sum23 = 0.0;
int alignBound = is + DMAX.length() - 1 - ((is + DMAX.length() - 1) % DMAX.length());
int alignLimit = Math.min(alignBound, ie);
for (; i < alignLimit; i += 1) {
double a00 = a[offseta + (i + 0) + (row + 0) * lda];
double a01 = a[offseta + (i + 0) + (row + 1) * lda];
double a02 = a[offseta + (i + 0) + (row + 2) * lda];
double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
sum00 = Math.fma(a00, b00, sum00);
sum10 = Math.fma(a01, b00, sum10);
sum20 = Math.fma(a02, b00, sum20);
double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
sum01 = Math.fma(a00, b01, sum01);
sum11 = Math.fma(a01, b01, sum11);
sum21 = Math.fma(a02, b01, sum21);
double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
sum02 = Math.fma(a00, b02, sum02);
sum12 = Math.fma(a01, b02, sum12);
sum22 = Math.fma(a02, b02, sum22);
double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
sum03 = Math.fma(a00, b03, sum03);
sum13 = Math.fma(a01, b03, sum13);
sum23 = Math.fma(a02, b03, sum23);
}
DoubleVector vsum00 = DoubleVector.zero(DMAX);
DoubleVector vsum01 = DoubleVector.zero(DMAX);
DoubleVector vsum02 = DoubleVector.zero(DMAX);
DoubleVector vsum03 = DoubleVector.zero(DMAX);
DoubleVector vsum10 = DoubleVector.zero(DMAX);
DoubleVector vsum11 = DoubleVector.zero(DMAX);
DoubleVector vsum12 = DoubleVector.zero(DMAX);
DoubleVector vsum13 = DoubleVector.zero(DMAX);
DoubleVector vsum20 = DoubleVector.zero(DMAX);
DoubleVector vsum21 = DoubleVector.zero(DMAX);
DoubleVector vsum22 = DoubleVector.zero(DMAX);
DoubleVector vsum23 = DoubleVector.zero(DMAX);
for (; i < ie - (ie % (Ti * DMAX.length())); i += Ti * DMAX.length()) {
DoubleVector va00 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 0) * lda);
DoubleVector va01 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 1) * lda);
DoubleVector va02 = DoubleVector.fromArray(DMAX, a, offseta + (i + 0 * DMAX.length()) + (row + 2) * lda);
DoubleVector vb00 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 0) * ldb);
vsum00 = va00.fma(vb00, vsum00);
vsum10 = va01.fma(vb00, vsum10);
vsum20 = va02.fma(vb00, vsum20);
DoubleVector vb01 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 1) * ldb);
vsum01 = va00.fma(vb01, vsum01);
vsum11 = va01.fma(vb01, vsum11);
vsum21 = va02.fma(vb01, vsum21);
DoubleVector vb02 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 2) * ldb);
vsum02 = va00.fma(vb02, vsum02);
vsum12 = va01.fma(vb02, vsum12);
vsum22 = va02.fma(vb02, vsum22);
DoubleVector vb03 = DoubleVector.fromArray(DMAX, b, offsetb + (i + 0 * DMAX.length()) + (col + 3) * ldb);
vsum03 = va00.fma(vb03, vsum03);
vsum13 = va01.fma(vb03, vsum13);
vsum23 = va02.fma(vb03, vsum23);
}
sum00 += vsum00.reduceLanes(VectorOperators.ADD);
sum01 += vsum01.reduceLanes(VectorOperators.ADD);
sum02 += vsum02.reduceLanes(VectorOperators.ADD);
sum03 += vsum03.reduceLanes(VectorOperators.ADD);
sum10 += vsum10.reduceLanes(VectorOperators.ADD);
sum11 += vsum11.reduceLanes(VectorOperators.ADD);
sum12 += vsum12.reduceLanes(VectorOperators.ADD);
sum13 += vsum13.reduceLanes(VectorOperators.ADD);
sum20 += vsum20.reduceLanes(VectorOperators.ADD);
sum21 += vsum21.reduceLanes(VectorOperators.ADD);
sum22 += vsum22.reduceLanes(VectorOperators.ADD);
sum23 += vsum23.reduceLanes(VectorOperators.ADD);
for (; i < ie; i += 1) {
double a00 = a[offseta + (i + 0) + (row + 0) * lda];
double a01 = a[offseta + (i + 0) + (row + 1) * lda];
double a02 = a[offseta + (i + 0) + (row + 2) * lda];
double b00 = b[offsetb + (i + 0) + (col + 0) * ldb];
sum00 = Math.fma(a00, b00, sum00);
sum10 = Math.fma(a01, b00, sum10);
sum20 = Math.fma(a02, b00, sum20);
double b01 = b[offsetb + (i + 0) + (col + 1) * ldb];
sum01 = Math.fma(a00, b01, sum01);
sum11 = Math.fma(a01, b01, sum11);
sum21 = Math.fma(a02, b01, sum21);
double b02 = b[offsetb + (i + 0) + (col + 2) * ldb];
sum02 = Math.fma(a00, b02, sum02);
sum12 = Math.fma(a01, b02, sum12);
sum22 = Math.fma(a02, b02, sum22);
double b03 = b[offsetb + (i + 0) + (col + 3) * ldb];
sum03 = Math.fma(a00, b03, sum03);
sum13 = Math.fma(a01, b03, sum13);
sum23 = Math.fma(a02, b03, sum23);
}
c[offsetc + (row + 0) + (col + 0) * ldc] = Math.fma(alpha, sum00, c[offsetc + (row + 0) + (col + 0) * ldc]);
c[offsetc + (row + 0) + (col + 1) * ldc] = Math.fma(alpha, sum01, c[offsetc + (row + 0) + (col + 1) * ldc]);
c[offsetc + (row + 0) + (col + 2) * ldc] = Math.fma(alpha, sum02, c[offsetc + (row + 0) + (col + 2) * ldc]);
c[offsetc + (row + 0) + (col + 3) * ldc] = Math.fma(alpha, sum03, c[offsetc + (row + 0) + (col + 3) * ldc]);
c[offsetc + (row + 1) + (col + 0) * ldc] = Math.fma(alpha, sum10, c[offsetc + (row + 1) + (col + 0) * ldc]);
c[offsetc + (row + 1) + (col + 1) * ldc] = Math.fma(alpha, sum11, c[offsetc + (row + 1) + (col + 1) * ldc]);
c[offsetc + (row + 1) + (col + 2) * ldc] = Math.fma(alpha, sum12, c[offsetc + (row + 1) + (col + 2) * ldc]);
c[offsetc + (row + 1) + (col + 3) * ldc] = Math.fma(alpha, sum13, c[offsetc + (row + 1) + (col + 3) * ldc]);
c[offsetc + (row + 2) + (col + 0) * ldc] = Math.fma(alpha, sum20, c[offsetc + (row + 2) + (col + 0) * ldc]);
c[offsetc + (row + 2) + (col + 1) * ldc] = Math.fma(alpha, sum21, c[offsetc + (row + 2) + (col + 1) * ldc]);
c[offsetc + (row + 2) + (col + 2) * ldc] = Math.fma(alpha, sum22, c[offsetc + (row + 2) + (col + 2) * ldc]);
c[offsetc + (row + 2) + (col + 3) * ldc] = Math.fma(alpha, sum23, c[offsetc + (row + 2) + (col + 3) * ldc]);
}
}
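Both versions of the kernel split the index range `[is, ie)` into a scalar alignment loop, a vector main loop, and a scalar tail loop. The sketch below reuses the `loopBound` and `loopAlign` helpers from the standalone kernel to show where those boundaries fall. The `partition` method and the lane count of 4 (doubles in a 256-bit vector) are assumptions added for illustration, and indices are assumed non-negative. Note that in Java `%` and `*` have equal precedence and associate left to right, so when the rounding is written inline the modulus operand needs its own parentheses, as in `ie % (Ti * lanes)`.

```java
final class LoopPartitionSketch {
    // Rounds index down to a multiple of size (same helper as in the standalone kernel).
    static int loopBound(int index, int size) {
        return index - (index % size);
    }

    // Rounds index up to a multiple of size, capped at max (same helper as in the standalone kernel).
    static int loopAlign(int index, int max, int size) {
        return Math.min(loopBound(index + size - 1, size), max);
    }

    /**
     * Hypothetical helper: returns {preEnd, mainEnd}, meaning the scalar
     * alignment loop covers [is, preEnd), the vector loop covers
     * [preEnd, mainEnd), and the scalar tail loop covers [mainEnd, ie).
     */
    static int[] partition(int is, int ie, int lanes) {
        int preEnd = loopAlign(is, ie, lanes);
        // If the rounded-down bound lies before preEnd, the vector loop does not run.
        int mainEnd = Math.max(preEnd, loopBound(ie, lanes));
        return new int[] { preEnd, mainEnd };
    }

    public static void main(String[] args) {
        int lanes = 4; // assumed: 4 doubles per 256-bit vector
        int[] p = partition(3, 30, lanes);
        System.out.println("pre  [3, " + p[0] + ")");
        System.out.println("main [" + p[0] + ", " + p[1] + ")");
        System.out.println("post [" + p[1] + ", 30)");
    }
}
```

With `is = 0` and `ie` a multiple of the lane count, as in both drivers above, the alignment and tail loops are empty and every iteration runs in the vector loop.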