Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions lit_tests/kernel/wave/infer_index_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,112 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile

from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.constraints import MMAType


def _get_matrix_add_kernel():
    """Build the element-wise matrix addition kernel and its hyperparameters.

    Returns:
        A `(kernel, hyperparams)` pair: the `tkw.wave`-decorated `matrix_add`
        function and a dictionary mapping each symbol used by the kernel to a
        concrete value.
    """
    # Symbols for the problem sizes and the per-workgroup tile sizes.
    M = tkl.sym.M
    N = tkl.sym.N
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    # Symbol standing for the address space of all three buffers; it is given
    # a concrete value in `hyperparams` below.
    ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE
    dtype = tkl.f16

    # No MMA type is specified here, so the vector shapes for M and N are
    # provided explicitly on the hardware constraint.
    # NOTE(review): the trailing 0/1 in WorkgroupConstraint is presumably the
    # workgroup dimension index -- confirm against tkw.WorkgroupConstraint.
    constraints = [
        tkw.HardwareConstraint(threads_per_wave=64, vector_shapes={M: 4, N: 1}),
        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    ]

    @tkw.wave(constraints)
    def matrix_add(
        a: tkl.Memory[M, N, ADDRESS_SPACE, dtype],
        b: tkl.Memory[M, N, ADDRESS_SPACE, dtype],
        c: tkl.Memory[M, N, ADDRESS_SPACE, dtype],
    ):
        # Read both inputs, add them, and write the sum to `c`.
        a_reg = tkw.read(a)
        b_reg = tkw.read(b)
        c_reg = a_reg + b_reg
        tkw.write(c_reg, c)

    # Concrete values substituted for the symbols at compile time.
    # GLOBAL_ADDRESS_SPACE is not defined in this function; presumably it comes
    # from the wildcard import of wave_lang.kernel.lang.global_symbols.
    hyperparams = {
        M: 128,
        N: 128,
        BLOCK_M: 16,
        BLOCK_N: 16,
        ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
    }
    return matrix_add, hyperparams


def _get_mma_chain_kernel():
    """Build a kernel that chains two MMA operations through memory.

    The result of the first MMA is cast to f16, written to `storage`, read
    back, and then used as an operand of the second MMA.

    Returns:
        A `(kernel, hyperparams)` pair: the `tkw.wave`-decorated `mma_chain`
        function and a dictionary mapping each symbol used by the kernel to a
        concrete value.
    """
    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    K = tkl.sym.K
    P = tkl.sym.P
    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    dtype = tkl.f16

    # Unlike a kernel without MMA, no explicit vector shapes are given here;
    # the hardware constraint names an MMA type instead.
    # NOTE(review): the trailing 0/1 in WorkgroupConstraint is presumably the
    # workgroup dimension index -- confirm against tkw.WorkgroupConstraint.
    constraints = [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            mma_type=MMAType.F32_16x16x16_F16,
            waves_per_block=(1, 2, 2),
        ),
        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
    ]

    @tkw.wave(constraints)
    def mma_chain(
        a: tkl.Memory[M, K, GLOBAL_ADDRESS_SPACE, dtype],
        b: tkl.Memory[N, K, GLOBAL_ADDRESS_SPACE, dtype],
        c: tkl.Memory[M, P, GLOBAL_ADDRESS_SPACE, tkl.f32],
        d: tkl.Memory[P, N, GLOBAL_ADDRESS_SPACE, dtype],
        storage: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, dtype],
    ):
        # First MMA: a[M, K] and b[N, K] accumulate into an f32 [M, N] register
        # initialized to zero.
        a_read = tkw.read(a)
        b_read = tkw.read(b)
        c_reg = tkl.Register[M, N, tkl.f32](0.0)
        mma1 = tkw.mma(a_read, b_read, c_reg)
        # Round-trip the f16-cast result through `storage` before reusing it.
        mma1_casted = tkw.cast(mma1, tkl.f16)
        tkw.write(mma1_casted, storage)
        reloaded = tkw.read(storage)
        # Second MMA: the reloaded [M, N] value and d[P, N] accumulate into an
        # f32 [M, P] register, which is written to `c`.
        d_read = tkw.read(d)
        c_reg2 = tkl.Register[M, P, tkl.f32](0.0)
        mma2 = tkw.mma(reloaded, d_read, c_reg2)
        tkw.write(mma2, c)

    # Concrete values substituted for the symbols at compile time.
    hyperparams = {
        M: 128,
        N: 128,
        K: 128,
        P: 128,
        BLOCK_M: 16,
        BLOCK_N: 16,
    }
    return mma_chain, hyperparams


def testMatrixAdd():
    """Compile the matrix addition kernel with water analysis checking on."""
    matrix_add, hyperparams = _get_matrix_add_kernel()
    # Benchmarking is disabled; only compilation and the analysis check run.
    compile_options = WaveCompileOptions(
        subs=hyperparams, run_bench=False, check_water_analysis=True
    )
    # Keep the compile call outside the assert so it still runs when asserts
    # are stripped.
    result = wave_compile(compile_options, matrix_add)
    assert result is not None


def testGemm():
relevant_hyperparams = [
tkl.sym.M,
Expand Down Expand Up @@ -54,5 +153,18 @@ def testGemm():
assert compiled_gemm is not None


def testMmaChain():
    """Compile the chained-MMA kernel with water analysis checking on."""
    mma_chain, hyperparams = _get_mma_chain_kernel()
    # Benchmarking is disabled; only compilation and the analysis check run.
    compile_options = WaveCompileOptions(
        subs=hyperparams, run_bench=False, check_water_analysis=True
    )
    # Keep the compile call outside the assert so it still runs when asserts
    # are stripped.
    result = wave_compile(compile_options, mma_chain)
    assert result is not None


if __name__ == "__main__":
    # Run every test in this file, in the same order as before.
    for test_case in (testMatrixAdd, testMmaChain, testGemm):
        test_case()
19 changes: 18 additions & 1 deletion water/include/water/Dialect/Wave/IR/WaveInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/Support/raw_ostream.h"

#include "water/Dialect/Wave/IR/WaveAttrs.h"
#include "water/Dialect/Wave/Transforms/Utils.h"

namespace wave {

Expand Down Expand Up @@ -563,9 +564,17 @@ class IndexExprsAnalysisInit {
// expressions.
class IndexExprsLatticeStorage {
public:
// Priorities for specific operations that may be used.
static constexpr int32_t kHighestPriority =
std::numeric_limits<int32_t>::max();
static constexpr int32_t kMmaPriority = 3;
static constexpr int32_t kWritePriority = 1;
static constexpr int32_t kLowestPriority = 0;

IndexExprsLatticeStorage();
IndexExprsLatticeStorage(const IndexExprsLatticeStorage &value) = default;
IndexExprsLatticeStorage(mlir::DictionaryAttr concreteValue);
IndexExprsLatticeStorage(mlir::DictionaryAttr concreteValue,
int32_t priority);

IndexExprsLatticeStorage &
operator=(const IndexExprsLatticeStorage &other) = default;
Expand All @@ -583,6 +592,10 @@ class IndexExprsLatticeStorage {
// specified or not, or null if the lattice instance is a top or a bottom.
mlir::DictionaryAttr getConcreteValue() const;

// Return the priority of this lattice instance or -1 if it is not a concrete
// value.
int32_t getPriority() const { return getConcreteValue() ? priority : -1; }

// Return the top lattice instance.
static IndexExprsLatticeStorage top();

Expand Down Expand Up @@ -628,6 +641,10 @@ class IndexExprsLatticeStorage {
// symbol indexing the value or one of the top/bottom flags.
llvm::PointerIntPair<mlir::Attribute, 2> value;

// Priority of this value. Specific values with higher priority override
// values with lower priority in joins.
int32_t priority;

// State flags.
constexpr static unsigned kUninitializedState = 0;
constexpr static unsigned kSpecificTypeState = 1;
Expand Down
3 changes: 2 additions & 1 deletion water/include/water/Dialect/Wave/IR/WaveInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ def WaveInferIndexExprsOpInterface : OpInterface<"WaveInferIndexExprsOpInterface
"initializeIndexExprsBackward",
(ins "::llvm::MutableArrayRef<::wave::IndexExprsLatticeStorage>":$operandExprs,
"const ::wave::IndexExprsAnalysisInit &":$initObject,
"::wave::EmitErrorFn":$emitError),
"::wave::EmitErrorFn":$emitError,
"::wave::EmitDelayedErrorFn &":$delayedErrorEmitter),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::llvm::success();
Expand Down
2 changes: 1 addition & 1 deletion water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def WriteOp : WaveOp<"write", [
DeclareOpInterfaceMethods<WaveInferTypeOpInterface>,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface,
["getIndexExprValuesAndDescriptions"]>,
["getIndexExprValuesAndDescriptions", "initializeIndexExprsBackward"]>,
RequiresSidewaysBackwardPropagationOpTrait]> {
let summary = "Writes into memory";
let description = [{
Expand Down
22 changes: 14 additions & 8 deletions water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef WATER_DIALECT_WAVE_TRANSFORMS_DATAFLOWANALYSES_H
#define WATER_DIALECT_WAVE_TRANSFORMS_DATAFLOWANALYSES_H

#include "water/Dialect/Wave/Transforms/Utils.h"
#include "llvm/ADT/FunctionExtras.h"

namespace llvm {
Expand All @@ -23,7 +24,7 @@ class DictionaryAttr;

namespace wave {
using SetIndexLatticeFn =
llvm::function_ref<void(mlir::Value, mlir::DictionaryAttr)>;
llvm::function_ref<void(mlir::Value, mlir::DictionaryAttr, int32_t)>;
using OverrideInitializationFn = llvm::function_ref<llvm::LogicalResult(
mlir::Operation *, SetIndexLatticeFn)>;

Expand All @@ -36,13 +37,18 @@ struct WaveIndexExprsAnalysisOptions {
};

// Add analyses for index expression propagation to the solver.
void addWaveIndexExprsAnalyses(mlir::DataFlowSolver &solver,
mlir::SymbolTableCollection &symbolTable,
WaveIndexExprsAnalysisOptions options = {});

llvm::LogicalResult
setWaveIndexExprAnalysisResults(mlir::Operation *top,
const mlir::DataFlowSolver &solver);
wave::DelayedErrorEmitterInfo
addWaveIndexExprsAnalyses(mlir::DataFlowSolver &solver,
mlir::SymbolTableCollection &symbolTable,
WaveIndexExprsAnalysisOptions options = {});

// Set the index attributes on operations nested under `top` using the
// lattices computed by the dataflow analyses in the given solver. Emit delayed
// errors if they are related to operations for which we failed to infer index
// expressions.
llvm::LogicalResult setWaveIndexExprAnalysisResults(
mlir::Operation *top, const mlir::DataFlowSolver &solver,
const wave::DelayedErrorEmitterInfo &delayedErrorInfo);

// Run the dataflow analyses and capture whether some diagnostics were emitted.
// Only emit a generic diagnostic if no more specific diagnostic was emitted.
Expand Down
13 changes: 13 additions & 0 deletions water/include/water/Dialect/Wave/Transforms/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@

namespace wave {

// Callback to generate a delayed diagnostic. The diagnostic message should be
// attached to the argument, i.e. the callback receives an in-flight diagnostic
// by reference and appends its message to it. The diagnostic may or may not be
// emitted, so the callback must not be relied upon to run.
using EmitDelayedErrorFn = std::function<void(mlir::InFlightDiagnostic &)>;

// Information to emit delayed errors. Bundles two callbacks so that code that
// finalizes the analysis results can query errors that were recorded earlier
// without knowing how they are stored.
struct DelayedErrorEmitterInfo {
  // Returns the delayed error for the given operation.
  // NOTE(review): the result for an operation without a recorded error is not
  // visible here (presumably an empty std::function) -- confirm before calling
  // the returned callback unconditionally.
  std::function<EmitDelayedErrorFn(mlir::Operation *)> getDelayedError;

  // Returns true if there are any delayed errors.
  std::function<bool()> hasDelayedErrors;
};

/// Get the hyperparameters from an ancestor operation.
/// Returns nullptr if no hyperparameters are found.
WaveHyperparameterAttr getHyperparameters(mlir::Operation *op);
Expand Down
45 changes: 29 additions & 16 deletions water/lib/Dialect/Wave/IR/WaveInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,11 +729,11 @@ llvm::LogicalResult wave::detail::verifyCompatibleOperandsAndResultsOpTrait(
//-----------------------------------------------------------------------------

wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage()
: value(nullptr, kUninitializedState) {}
: value(nullptr, kUninitializedState), priority(kLowestPriority) {}

wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage(
DictionaryAttr concreteValue)
: value(concreteValue, kSpecificTypeState) {}
DictionaryAttr concreteValue, int32_t priority)
: value(concreteValue, kSpecificTypeState), priority(priority) {}

bool wave::IndexExprsLatticeStorage::operator==(
const IndexExprsLatticeStorage &other) const {
Expand Down Expand Up @@ -1097,12 +1097,18 @@ static FailureOr<AffineMap> getIndexExprStepStrideJoinedMap(
return failure();
}

// Join two concrete index expressions mappings by joining their
// start/step/stride maps independently. See getIndexExprStartJoinedMap and
// getIndexExprStepStrideJoinedMap for more details.
// Join two concrete index expressions mappings either by picking the
// higher-priority one or by joining their start/step/stride maps independently.
// See getIndexExprStartJoinedMap and getIndexExprStepStrideJoinedMap for more
// details on independent joining.
static wave::WaveIndexMappingAttr
getIndexExprsJoinMappings(wave::WaveIndexMappingAttr lhs,
wave::WaveIndexMappingAttr rhs) {
wave::WaveIndexMappingAttr rhs, int32_t lhsPriority,
int32_t rhsPriority) {
if (lhsPriority > rhsPriority)
return lhs;
if (rhsPriority > lhsPriority)
return rhs;

// Collect all unique symbol names from both index mappings in order.
llvm::SmallVector<Attribute> allSymbols;
Expand Down Expand Up @@ -1162,7 +1168,8 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join(
namedAttr.getName().getValue());
});
return IndexExprsLatticeStorage(
DictionaryAttr::get(rhs.getConcreteValue().getContext(), filtered));
DictionaryAttr::get(rhs.getConcreteValue().getContext(), filtered),
rhs.getPriority());
}

if (rhs.isBottom())
Expand Down Expand Up @@ -1195,17 +1202,21 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join(
if (lhsValue == rhsValue)
continue;

wave::WaveIndexMappingAttr joinedMapping =
getIndexExprsJoinMappings(lhsValue, rhsValue);
wave::WaveIndexMappingAttr joinedMapping = getIndexExprsJoinMappings(
lhsValue, rhsValue, lhs.getPriority(), rhs.getPriority());
if (!joinedMapping)
return IndexExprsLatticeStorage::top();

result[namedAttr.getName()] = joinedMapping;
}
return IndexExprsLatticeStorage(
DictionaryAttr::get(ctx, llvm::map_to_vector(result, [](auto &&pair) {
return NamedAttribute(pair.first, pair.second);
})));
DictionaryAttr::get(ctx, llvm::map_to_vector(result,
[](auto &&pair) {
return NamedAttribute(
pair.first,
pair.second);
})),
std::max(lhs.getPriority(), rhs.getPriority()));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need priority per-symbol?..

}

wave::IndexExprsLatticeStorage
Expand Down Expand Up @@ -1237,7 +1248,8 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::keepOnlySymbols(
return bottom();

return IndexExprsLatticeStorage(
DictionaryAttr::get(getConcreteValue().getContext(), filtered));
DictionaryAttr::get(getConcreteValue().getContext(), filtered),
getPriority());
}

wave::IndexExprsLatticeStorage
Expand All @@ -1257,7 +1269,8 @@ wave::IndexExprsLatticeStorage::withoutIterSymbols(
}
return NamedAttribute(attr.getName(), value);
});
return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, updated));
return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, updated),
getPriority());
}

void wave::IndexExprsLatticeStorage::print(llvm::raw_ostream &os) const {
Expand All @@ -1266,7 +1279,7 @@ void wave::IndexExprsLatticeStorage::print(llvm::raw_ostream &os) const {
} else if (isTop()) {
os << "<top>";
} else {
os << getConcreteValue();
os << "[pri: " << getPriority() << "] " << getConcreteValue();
}
}

Expand Down
Loading
Loading