From be097bf495218bd1ff5c9b34030e70e421b62ab8 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 6 Mar 2026 15:56:10 +0100 Subject: [PATCH] [water] priority-based index expression propagation This mimics the index exrepssion propagation behavior pywave has with its multiple passes of propagation based on different "source" ops in a single dataflow process, with convergence guarantees. Earlier "source" operations assign higher-priority index expressions to values that completely override lower-priority expressions on join. Only equal-priority expressions perform actual per-dimension join. Several design decisions are replicated for consistency purposes, but should be revised in the future to provide a unified approach to inference without surprises such as the "mode switch" that results in completely different expressions when an mma is removed. This is particularly true for the case of missing `vector_shape` part of the hardware constraint that is only necessary when a "write" operation isn't writing a value transitively obtained from an MMA. Supporting this requires deferring diagnostic emission to account for the case where the process would converge thanks to a higher-priority expression coming from elsewhere. Signed-off-by: Alex Zinenko --- lit_tests/kernel/wave/infer_index_exprs.py | 112 +++++++ .../water/Dialect/Wave/IR/WaveInterfaces.h | 19 +- .../water/Dialect/Wave/IR/WaveInterfaces.td | 3 +- .../include/water/Dialect/Wave/IR/WaveOps.td | 2 +- .../Wave/Transforms/DataFlowAnalyses.h | 22 +- .../water/Dialect/Wave/Transforms/Utils.h | 13 + water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | 45 ++- water/lib/Dialect/Wave/IR/WaveOps.cpp | 188 ++++++++++- .../Dialect/Wave/Transforms/InferTypes.cpp | 115 +++++-- .../Wave/infer-index-exprs-lattice.mlir | 316 ++++++++++++------ .../test/Dialect/Wave/infer-index-exprs.mlir | 246 ++++++++++++++ .../TestWaveDialectInferIndexExprs.cpp | 30 +- .../wave/analysis/index_sequence_analysis.py | 62 +++- 13 files changed, 996 insertions(+), 177 deletions(-) diff --git a/lit_tests/kernel/wave/infer_index_exprs.py b/lit_tests/kernel/wave/infer_index_exprs.py index 7ec72c0555..7f125bbf3d 100644 --- a/lit_tests/kernel/wave/infer_index_exprs.py +++ b/lit_tests/kernel/wave/infer_index_exprs.py @@ -9,6 +9,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel @@ -16,6 +17,104 @@ from wave_lang.kernel.wave.constraints import MMAType +def _get_matrix_add_kernel(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE + dtype = tkl.f16 + + constraints = [ + tkw.HardwareConstraint(threads_per_wave=64, vector_shapes={M: 4, N: 1}), + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + ] + + @tkw.wave(constraints) + def matrix_add( + a: tkl.Memory[M, N, ADDRESS_SPACE, dtype], + b: tkl.Memory[M, N, ADDRESS_SPACE, dtype], + c: tkl.Memory[M, N, ADDRESS_SPACE, dtype], + ): + a_reg = tkw.read(a) + b_reg = tkw.read(b) + c_reg = a_reg + b_reg + tkw.write(c_reg, c) + + hyperparams = { + M: 128, + N: 128, + BLOCK_M: 16, + BLOCK_N: 16, + ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + } + return matrix_add, hyperparams + + +def _get_mma_chain_kernel(): + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + P = tkl.sym.P + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + dtype = tkl.f16 + + constraints = [ + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=MMAType.F32_16x16x16_F16, + waves_per_block=(1, 2, 2), + ), + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + ] + + @tkw.wave(constraints) + def mma_chain( + a: tkl.Memory[M, K, GLOBAL_ADDRESS_SPACE, dtype], + b: tkl.Memory[N, K, GLOBAL_ADDRESS_SPACE, dtype], + c: tkl.Memory[M, P, GLOBAL_ADDRESS_SPACE, tkl.f32], + d: tkl.Memory[P, N, GLOBAL_ADDRESS_SPACE, dtype], + storage: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, dtype], + ): + a_read = tkw.read(a) + b_read = tkw.read(b) + c_reg = tkl.Register[M, N, tkl.f32](0.0) + mma1 = tkw.mma(a_read, b_read, c_reg) + mma1_casted = tkw.cast(mma1, tkl.f16) + tkw.write(mma1_casted, storage) + reloaded = tkw.read(storage) + d_read = tkw.read(d) + c_reg2 = tkl.Register[M, P, tkl.f32](0.0) + mma2 = tkw.mma(reloaded, d_read, c_reg2) + tkw.write(mma2, c) + + hyperparams = { + M: 128, + N: 128, + K: 128, + P: 128, + BLOCK_M: 16, + BLOCK_N: 16, + } + return mma_chain, hyperparams + + +def testMatrixAdd(): + kernel, params = _get_matrix_add_kernel() + options = WaveCompileOptions( + subs=params, + run_bench=False, + check_water_analysis=True, + ) + compiled_kernel = wave_compile(options, kernel) + assert compiled_kernel is not None + + def testGemm(): relevant_hyperparams = [ tkl.sym.M, @@ -54,5 +153,18 @@ def testGemm(): assert compiled_gemm is not None +def testMmaChain(): + kernel, params = _get_mma_chain_kernel() + options = WaveCompileOptions( + subs=params, + run_bench=False, + check_water_analysis=True, + ) + compiled_kernel = wave_compile(options, kernel) + assert compiled_kernel is not None + + if __name__ == "__main__": + testMatrixAdd() + testMmaChain() testGemm() diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h index 2d2373d12f..c7bf7092b5 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h @@ -18,6 +18,7 @@ #include "llvm/Support/raw_ostream.h" #include "water/Dialect/Wave/IR/WaveAttrs.h" +#include "water/Dialect/Wave/Transforms/Utils.h" namespace wave { @@ -563,9 +564,17 @@ class IndexExprsAnalysisInit { // expressions. class IndexExprsLatticeStorage { public: + // Priorities for specific operations that may be used. + static constexpr int32_t kHighestPriority = + std::numeric_limits::max(); + static constexpr int32_t kMmaPriority = 3; + static constexpr int32_t kWritePriority = 1; + static constexpr int32_t kLowestPriority = 0; + IndexExprsLatticeStorage(); IndexExprsLatticeStorage(const IndexExprsLatticeStorage &value) = default; - IndexExprsLatticeStorage(mlir::DictionaryAttr concreteValue); + IndexExprsLatticeStorage(mlir::DictionaryAttr concreteValue, + int32_t priority); IndexExprsLatticeStorage & operator=(const IndexExprsLatticeStorage &other) = default; @@ -583,6 +592,10 @@ class IndexExprsLatticeStorage { // specified or not, or null if the lattice instance is a top or a bottom. mlir::DictionaryAttr getConcreteValue() const; + // Return the priority of this lattice instance or -1 if it is not a concrete + // value. + int32_t getPriority() const { return getConcreteValue() ? priority : -1; } + // Return the top lattice instance. static IndexExprsLatticeStorage top(); @@ -628,6 +641,10 @@ class IndexExprsLatticeStorage { // symbol indexing the value or one of the top/bottom flags. llvm::PointerIntPair value; + // Priority of this value. Specific values with higher priority override + // values with lower priority in joins. + int32_t priority; + // State flags. constexpr static unsigned kUninitializedState = 0; constexpr static unsigned kSpecificTypeState = 1; diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.td b/water/include/water/Dialect/Wave/IR/WaveInterfaces.td index 1c7aaaff73..2f4c409008 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.td +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.td @@ -264,7 +264,8 @@ def WaveInferIndexExprsOpInterface : OpInterface<"WaveInferIndexExprsOpInterface "initializeIndexExprsBackward", (ins "::llvm::MutableArrayRef<::wave::IndexExprsLatticeStorage>":$operandExprs, "const ::wave::IndexExprsAnalysisInit &":$initObject, - "::wave::EmitErrorFn":$emitError), + "::wave::EmitErrorFn":$emitError, + "::wave::EmitDelayedErrorFn &":$delayedErrorEmitter), /*methodBody=*/"", /*defaultImplementation=*/[{ return ::llvm::success(); diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 75ca7002cc..3186fe05e6 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -574,7 +574,7 @@ def WriteOp : WaveOp<"write", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + ["getIndexExprValuesAndDescriptions", "initializeIndexExprsBackward"]>, RequiresSidewaysBackwardPropagationOpTrait]> { let summary = "Writes into memory"; let description = [{ diff --git a/water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h b/water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h index aefb0a3a38..a73f0d89bf 100644 --- a/water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h +++ b/water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h @@ -7,6 +7,7 @@ #ifndef WATER_DIALECT_WAVE_TRANSFORMS_DATAFLOWANALYSES_H #define WATER_DIALECT_WAVE_TRANSFORMS_DATAFLOWANALYSES_H +#include "water/Dialect/Wave/Transforms/Utils.h" #include "llvm/ADT/FunctionExtras.h" namespace llvm { @@ -23,7 +24,7 @@ class DictionaryAttr; namespace wave { using SetIndexLatticeFn = - llvm::function_ref; + llvm::function_ref; using OverrideInitializationFn = llvm::function_ref; @@ -36,13 +37,18 @@ struct WaveIndexExprsAnalysisOptions { }; // Add analyses for index expression propagation to the solver. -void addWaveIndexExprsAnalyses(mlir::DataFlowSolver &solver, - mlir::SymbolTableCollection &symbolTable, - WaveIndexExprsAnalysisOptions options = {}); - -llvm::LogicalResult -setWaveIndexExprAnalysisResults(mlir::Operation *top, - const mlir::DataFlowSolver &solver); +wave::DelayedErrorEmitterInfo +addWaveIndexExprsAnalyses(mlir::DataFlowSolver &solver, + mlir::SymbolTableCollection &symbolTable, + WaveIndexExprsAnalysisOptions options = {}); + +// Set the index attribute attributes on operations nested under `top` using the +// lattices computed by the dataflow analyses in the given solver. Emit delayed +// errors if they are related to operations for which we failed to infer index +// expressions. +llvm::LogicalResult setWaveIndexExprAnalysisResults( + mlir::Operation *top, const mlir::DataFlowSolver &solver, + const wave::DelayedErrorEmitterInfo &delayedErrorInfo); // Run the dataflow analyses and capture whether some diagnostics were emitted. // Only emit a generic diagnostic if no more specific diagnostic was emitted. diff --git a/water/include/water/Dialect/Wave/Transforms/Utils.h b/water/include/water/Dialect/Wave/Transforms/Utils.h index 4d3cc3692f..022762278d 100644 --- a/water/include/water/Dialect/Wave/Transforms/Utils.h +++ b/water/include/water/Dialect/Wave/Transforms/Utils.h @@ -11,6 +11,19 @@ namespace wave { +// Callback to generate a delayed diagnostic. The diagnostic message should be +// attached to the argument. It may or may not be emitted. +using EmitDelayedErrorFn = std::function; + +// Information to emit delayed errors. +struct DelayedErrorEmitterInfo { + // Returns the delayed error for the given operation. + std::function getDelayedError; + + // Returns true if there are any delayed errors. + std::function hasDelayedErrors; +}; + /// Get the hyperparameters from an ancestor operation. /// Returns nullptr if no hyperparameters are found. WaveHyperparameterAttr getHyperparameters(mlir::Operation *op); diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index f96aadeb6f..0dcce4ce1b 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -729,11 +729,11 @@ llvm::LogicalResult wave::detail::verifyCompatibleOperandsAndResultsOpTrait( //----------------------------------------------------------------------------- wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage() - : value(nullptr, kUninitializedState) {} + : value(nullptr, kUninitializedState), priority(kLowestPriority) {} wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage( - DictionaryAttr concreteValue) - : value(concreteValue, kSpecificTypeState) {} + DictionaryAttr concreteValue, int32_t priority) + : value(concreteValue, kSpecificTypeState), priority(priority) {} bool wave::IndexExprsLatticeStorage::operator==( const IndexExprsLatticeStorage &other) const { @@ -1097,12 +1097,18 @@ static FailureOr getIndexExprStepStrideJoinedMap( return failure(); } -// Join two concrete index expressions mappings by joining their -// start/step/stride maps independently. See getIndexExprStartJoinedMap and -// getIndexExprStepStrideJoinedMap for more details. +// Join two concrete index expressions mappings either by picking the +// higher-priority one or by joining their start/step/stride maps independently. +// See getIndexExprStartJoinedMap and getIndexExprStepStrideJoinedMap for more +// details on independent joining. static wave::WaveIndexMappingAttr getIndexExprsJoinMappings(wave::WaveIndexMappingAttr lhs, - wave::WaveIndexMappingAttr rhs) { + wave::WaveIndexMappingAttr rhs, int32_t lhsPriority, + int32_t rhsPriority) { + if (lhsPriority > rhsPriority) + return lhs; + if (rhsPriority > lhsPriority) + return rhs; // Collect all unique symbol names from both index mappings in order. llvm::SmallVector allSymbols; @@ -1162,7 +1168,8 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( namedAttr.getName().getValue()); }); return IndexExprsLatticeStorage( - DictionaryAttr::get(rhs.getConcreteValue().getContext(), filtered)); + DictionaryAttr::get(rhs.getConcreteValue().getContext(), filtered), + rhs.getPriority()); } if (rhs.isBottom()) @@ -1195,17 +1202,21 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( if (lhsValue == rhsValue) continue; - wave::WaveIndexMappingAttr joinedMapping = - getIndexExprsJoinMappings(lhsValue, rhsValue); + wave::WaveIndexMappingAttr joinedMapping = getIndexExprsJoinMappings( + lhsValue, rhsValue, lhs.getPriority(), rhs.getPriority()); if (!joinedMapping) return IndexExprsLatticeStorage::top(); result[namedAttr.getName()] = joinedMapping; } return IndexExprsLatticeStorage( - DictionaryAttr::get(ctx, llvm::map_to_vector(result, [](auto &&pair) { - return NamedAttribute(pair.first, pair.second); - }))); + DictionaryAttr::get(ctx, llvm::map_to_vector(result, + [](auto &&pair) { + return NamedAttribute( + pair.first, + pair.second); + })), + std::max(lhs.getPriority(), rhs.getPriority())); } wave::IndexExprsLatticeStorage @@ -1237,7 +1248,8 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::keepOnlySymbols( return bottom(); return IndexExprsLatticeStorage( - DictionaryAttr::get(getConcreteValue().getContext(), filtered)); + DictionaryAttr::get(getConcreteValue().getContext(), filtered), + getPriority()); } wave::IndexExprsLatticeStorage @@ -1257,7 +1269,8 @@ wave::IndexExprsLatticeStorage::withoutIterSymbols( } return NamedAttribute(attr.getName(), value); }); - return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, updated)); + return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, updated), + getPriority()); } void wave::IndexExprsLatticeStorage::print(llvm::raw_ostream &os) const { @@ -1266,7 +1279,7 @@ void wave::IndexExprsLatticeStorage::print(llvm::raw_ostream &os) const { } else if (isTop()) { os << ""; } else { - os << getConcreteValue(); + os << "[pri: " << getPriority() << "] " << getConcreteValue(); } } diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index bce01c65a8..a8009899f0 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1164,7 +1164,9 @@ LogicalResult MmaOp::initializeIndexExprsForward( mixInThreadIndependentConstraints( *this, initObject.hardwareConstraint.getThreadsPerWave(), indexingSymbols, initObject.symbolConstraints, symbolMappings); - resultExprs[0].unsafeSet(DictionaryAttr::get(getContext(), symbolMappings)); + resultExprs[0].unsafeSet(wave::IndexExprsLatticeStorage( + DictionaryAttr::get(getContext(), symbolMappings), + wave::IndexExprsLatticeStorage::kMmaPriority)); return llvm::success(); } @@ -1175,8 +1177,8 @@ LogicalResult MmaOp::initializeIndexExprsForward( // well as workgroup constraints (thread-independent). LogicalResult MmaOp::initializeIndexExprsBackward( llvm::MutableArrayRef operandExprs, - const wave::IndexExprsAnalysisInit &initObject, - wave::EmitErrorFn emitError) { + const wave::IndexExprsAnalysisInit &initObject, wave::EmitErrorFn emitError, + wave::EmitDelayedErrorFn &delayedErrorEmitter) { auto resultType = llvm::cast(getResult().getType()); auto lhsType = llvm::cast(getLhs().getType()); assert(resultType.getRank() == lhsType.getRank() && lhsType.getRank() >= 2 && @@ -1242,13 +1244,16 @@ LogicalResult MmaOp::initializeIndexExprsBackward( operandExprs[getLhsMutable().getOperandNumber()] = wave::IndexExprsLatticeStorage( - DictionaryAttr::get(getContext(), lhsSymbolMappings)); + DictionaryAttr::get(getContext(), lhsSymbolMappings), + wave::IndexExprsLatticeStorage::kMmaPriority); operandExprs[getRhsMutable().getOperandNumber()] = wave::IndexExprsLatticeStorage( - DictionaryAttr::get(getContext(), rhsSymbolMappings)); + DictionaryAttr::get(getContext(), rhsSymbolMappings), + wave::IndexExprsLatticeStorage::kMmaPriority); operandExprs[getAccumulatorMutable().getOperandNumber()] = wave::IndexExprsLatticeStorage( - DictionaryAttr::get(getContext(), accumulatorSymbolMappings)); + DictionaryAttr::get(getContext(), accumulatorSymbolMappings), + wave::IndexExprsLatticeStorage::kMmaPriority); return llvm::success(); } @@ -2233,6 +2238,174 @@ llvm::FailureOr wave::WriteOp::propagateIndexExprsForward( return ChangeResult::NoChange; } +/// Computes the vector stride for each dimension: stride[i] is the product of +/// vector shapes for dimensions i+1 .. rank-1 (so the last dimension has +/// stride 1). A dimension is contiguous iff its stride is 1. +static FailureOr> +getVectorStrides(wave::WaveTensorType tensorType, + HardwareConstraintAttr hardwareConstraint) { + assert(tensorType.getFullySpecified() && + "expected fully-specified tensor type"); + DictionaryAttr vectorShapes = hardwareConstraint.getVectorShapes(); + if (!vectorShapes) + return failure(); + int64_t rank = tensorType.getRank(); + SmallVector strides(rank); + strides[rank - 1] = 1; + for (int64_t i = rank - 2; i >= 0; --i) { + Attribute vectorShape = + vectorShapes.get(tensorType.getShape()[i + 1].getName()); + if (!vectorShape) + return failure(); + int64_t shape = cast(vectorShape).getValue().getSExtValue(); + strides[i] = shape * strides[i + 1]; + } + return strides; +} + +LogicalResult WriteOp::initializeIndexExprsBackward( + llvm::MutableArrayRef operandExprs, + const wave::IndexExprsAnalysisInit &initObject, wave::EmitErrorFn emitError, + wave::EmitDelayedErrorFn &delayedErrorEmitter) { + + // TODO: figure out how to propagate elements per threads from constraints to + // operations while avoiding the clash with index sequences. When propagating, + // we don't have sequences yet. + WaveTensorType tensorType = cast(getValueToStore().getType()); + HardwareConstraintAttr hardwareConstraint = initObject.hardwareConstraint; + + assert(tensorType.getFullySpecified()); + FailureOr> stridesOr = + getVectorStrides(tensorType, hardwareConstraint); + // XXX: don't report this error immediately since we may be able to proceed + // without it, e.g., when index expressions may be propagated from + // operations with higher priority operations to this one. This is a + // questionable design choice carried over from the initial Python + // prototype, but is needed for initial consistency. Consider revising. + if (failed(stridesOr)) { + delayedErrorEmitter = [](InFlightDiagnostic &diag) { + diag << "couldn't find vector shapes in the contiguity check"; + }; + return success(); + } + llvm::ArrayRef strides = *stridesOr; + // XXX: pywave confusingly calls this "contiguous" but it is actually the + // dimension along which SIMD vectorization is applied, i.e., the deepest + // dimension for which the per-thread vector shape is not 1 or, alternatively, + // the product of vector shapes for trailing dimensions remains 1. + int64_t vectorizedDimPos = -1; + for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) + if (strides[i] == 1) { + vectorizedDimPos = i; + break; + } + + SmallVector indexMappings; + for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) { + AffineExpr elementsPerThread = nullptr; + bool isVectorized = (i == vectorizedDimPos); + + // The absence of constraints for a dimension means it is not mapped to + // workgroups/wave/items, so there is nothing to do here. + // We expect it to be handled by thread-independent constraints setting + // the default (0, 1, 1) index expression or following the tiling + // constraint. + auto it = initObject.symbolConstraints.find(tensorType.getShape()[i]); + if (it == initObject.symbolConstraints.end()) + continue; + auto wgConstraintIt = + llvm::find_if(it->second, llvm::IsaPred); + if (wgConstraintIt == it->second.end()) + continue; + WorkgroupConstraintAttr wgConstraint = + cast(*wgConstraintIt); + + // The innermost dimension with vectorized size other than 1 is the one we + // want to vectorize along. + std::optional opElementsPerThread = getElementsPerThread(); + SmallVector symbols; + if (isVectorized) { + if (opElementsPerThread) { + elementsPerThread = + getAffineConstantExpr(*opElementsPerThread, getContext()); + } else { + AffineMap tileSizeMap = wgConstraint.getTileSize().getMap(); + assert(tileSizeMap.getNumResults() == 1 && + "expected a single expression in tile size affine map"); + unsigned numThreads = [&]() { + switch (wgConstraint.getWorkgroupDim().getValue()) { + case WaveWorkgroupDim::X: + return initObject.wavesPerBlock[0] * + initObject.hardwareConstraint.getThreadsPerWave(); + case WaveWorkgroupDim::Y: + return initObject.wavesPerBlock[1]; + case WaveWorkgroupDim::Z: + return initObject.wavesPerBlock[2]; + } + }(); + + elementsPerThread = tileSizeMap.getResult(0).ceilDiv(numThreads); + llvm::append_range(symbols, wgConstraint.getTileSize().getSymbols()); + } + } else { + elementsPerThread = getAffineConstantExpr(1, getContext()); + } + + int64_t stride = strides[i]; + assert(stride > 0 && "stride should be positive"); + + WaveIndexSymbol threadSymbol = [&]() { + switch (wgConstraint.getWorkgroupDim().getValue()) { + case WaveWorkgroupDim::X: + return WaveIndexSymbol::THREAD_0; + case WaveWorkgroupDim::Y: + return WaveIndexSymbol::THREAD_1; + case WaveWorkgroupDim::Z: + return WaveIndexSymbol::THREAD_2; + } + }(); + WaveIndexSymbolAttr threadSymbolAttr = + WaveIndexSymbolAttr::get(getContext(), threadSymbol); + symbols.push_back(threadSymbolAttr); + + AffineExpr startExpr = + getAffineSymbolExpr(symbols.size() - 1, getContext()); + if (wgConstraint.getWorkgroupDim().getValue() == WaveWorkgroupDim::X) { + startExpr = startExpr % hardwareConstraint.getThreadsPerWave(); + } else { + // TODO: in pywave, we always do `startExpr % threadsPerWave` where + // threadsPerWave == 1 for workgroup dims other than X, making it + // always zero. It mentions an assumption about the (64, 1, 1) thread + // shape, but it is unclear whether that assumption always holds. + // It looks like the intention for this was to express lane ID rather + // than thread ID, but it is unclear how it accounts for multiple + // wavefronts running in parallel. + startExpr = getAffineConstantExpr(0, getContext()); + } + + auto indexMapping = WaveIndexMappingAttr::get( + getContext(), symbols, + AffineMap::get(/*dimCount=*/0, symbols.size(), + startExpr * elementsPerThread), + AffineMap::get(/*dimCount=*/0, symbols.size(), elementsPerThread), + AffineMap::get(/*dimCount=*/0, symbols.size(), + getAffineConstantExpr(stride, getContext()))); + indexMappings.emplace_back(tensorType.getShape()[i].getName(), + indexMapping); + } + mixInThreadIndependentConstraints( + *this, initObject.hardwareConstraint.getThreadsPerWave(), + tensorType.getShape(), initObject.symbolConstraints, indexMappings); + operandExprs[getValueToStoreMutable().getOperandNumber()] = + IndexExprsLatticeStorage(DictionaryAttr::get(getContext(), indexMappings), + IndexExprsLatticeStorage::kWritePriority); + operandExprs[getMemoryMutable().getOperandNumber()] = + IndexExprsLatticeStorage(DictionaryAttr::get(getContext(), indexMappings), + IndexExprsLatticeStorage::kWritePriority); + + return success(); +} + // Propagating "sideways" between operands, but only if this would not result // in conflicts. llvm::FailureOr wave::WriteOp::propagateIndexExprsBackward( @@ -2761,7 +2934,8 @@ permuteIndexExprsStrides(const IndexExprsLatticeStorage &inputLattice, NamedAttribute(StringAttr::get(ctx, srcSymbol.getName()), newMapping)); } - return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, permutedMappings)); + return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, permutedMappings), + inputLattice.getPriority()); } llvm::FailureOr wave::PermuteOp::propagateIndexExprsForward( diff --git a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp index cc8126c53d..2af50ef48b 100644 --- a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp +++ b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp @@ -33,6 +33,7 @@ using namespace mlir; #define DEBUG_TYPE "wave-infer-types" +using namespace mlir; using wave::ElementsPerThreadLatticeValue; using wave::IndexExprsLatticeStorage; @@ -1396,7 +1397,10 @@ class IndexExprsForwardAnalysis wave::applyConstraint(tilingConstraint)}}); LDBG() << "setting iterate block argument lattice " << capture << " from " << PrintNoRegions(iterateOp) << " to " << dict; - unsafeSet(getLatticeElement(capture), dict); + unsafeSet( + getLatticeElement(capture), + wave::IndexExprsLatticeStorage( + dict, wave::IndexExprsLatticeStorage::kLowestPriority)); } } } @@ -1408,11 +1412,12 @@ class IndexExprsForwardAnalysis if (overrideInitialization) { if (llvm::failed(overrideInitialization( - top, [&](Value value, DictionaryAttr dict) { + top, [&](Value value, DictionaryAttr dict, int32_t priority) { if (!dict) return unsafeSet(getLatticeElement(value), IndexExprsLatticeStorage::top()); - unsafeSet(getLatticeElement(value), dict); + unsafeSet(getLatticeElement(value), + wave::IndexExprsLatticeStorage(dict, priority)); }))) return llvm::failure(); } @@ -1608,9 +1613,18 @@ class IndexExprsForwardAnalysis } } + // Return true if there are pending error reports. + bool hasDelayedErrors() const { return !delayedErrors.empty(); } + + // Return the emitter of a pending error report for the given operation. + wave::EmitDelayedErrorFn getDelayedError(Operation *op) const { + return delayedErrors.lookup_or(op, wave::EmitDelayedErrorFn()); + } + private: bool initialized = false; wave::OverrideInitializationFn overrideInitialization; + llvm::SmallDenseMap delayedErrors; }; class IndexExprsBackwardAnalysis @@ -1666,9 +1680,14 @@ class IndexExprsBackwardAnalysis LDBG() << "initializing index expressions backward for " << PrintNoRegions(op); + wave::EmitDelayedErrorFn delayedErrorEmitter = nullptr; if (llvm::failed(iface.initializeIndexExprsBackward( - operandExprs, *initObject, emitError))) + operandExprs, *initObject, emitError, delayedErrorEmitter))) return WalkResult::interrupt(); + if (delayedErrorEmitter) { + LDBG() << "delayed error recorded\n"; + delayedErrors[op] = delayedErrorEmitter; + } for (auto &&[i, operand, lattice] : llvm::enumerate(op->getOperands(), operandExprs)) { IndexExprsLattice *latticeObject = getLatticeElement(operand); @@ -1694,11 +1713,12 @@ class IndexExprsBackwardAnalysis if (overrideInitialization) { if (llvm::failed(overrideInitialization( - top, [&](Value value, DictionaryAttr dict) { + top, [&](Value value, DictionaryAttr dict, int32_t priority) { if (!dict) return unsafeSet(getLatticeElement(value), IndexExprsLatticeStorage::top()); - unsafeSet(getLatticeElement(value), dict); + unsafeSet(getLatticeElement(value), + wave::IndexExprsLatticeStorage(dict, priority)); }))) return llvm::failure(); } @@ -1873,9 +1893,18 @@ class IndexExprsBackwardAnalysis // by the forward analysis. } + // Returns true if there are any delayed errors. + bool hasDelayedErrors() const { return !delayedErrors.empty(); } + + // Returns the delayed error emitter for the given operation. + wave::EmitDelayedErrorFn getDelayedError(Operation *op) const { + return delayedErrors.lookup_or(op, wave::EmitDelayedErrorFn()); + } + private: bool initialized = false; wave::OverrideInitializationFn overrideInitialization; + llvm::SmallDenseMap delayedErrors; }; namespace { @@ -1899,16 +1928,17 @@ class InferIndexExprsPass config.setInterprocedural(false); DataFlowSolver solver(config); - solver.load(); - solver.load(); - wave::addWaveIndexExprsAnalyses(solver, symbolTable); + solver.load(); + solver.load(); + wave::DelayedErrorEmitterInfo delayedErrorInfo = + wave::addWaveIndexExprsAnalyses(solver, symbolTable); if (llvm::failed( wave::runSolverAndCaptureErrors(solver, getOperation(), false))) return signalPassFailure(); - if (llvm::failed( - wave::setWaveIndexExprAnalysisResults(getOperation(), solver))) + if (llvm::failed(wave::setWaveIndexExprAnalysisResults( + getOperation(), solver, delayedErrorInfo))) return signalPassFailure(); getOperation()->walk([&](wave::IterateOp iterateOp) { @@ -1922,21 +1952,46 @@ class InferIndexExprsPass }; } // namespace -void wave::addWaveIndexExprsAnalyses( - DataFlowSolver &solver, SymbolTableCollection &symbolTable, - wave::WaveIndexExprsAnalysisOptions options) { +wave::DelayedErrorEmitterInfo +wave::addWaveIndexExprsAnalyses(DataFlowSolver &solver, + SymbolTableCollection &symbolTable, + wave::WaveIndexExprsAnalysisOptions options) { + IndexExprsForwardAnalysis *forward = nullptr; if (!options.disableForward) { - solver.load(options.overrideInitialization); + forward = + solver.load(options.overrideInitialization); } + IndexExprsBackwardAnalysis *backward = nullptr; if (!options.disableBackward) { - solver.load(symbolTable, - options.overrideInitialization); + backward = solver.load( + symbolTable, options.overrideInitialization); } + + // Note that these lambdas are stored and used later so they must not capture + // anything that has a function-level lifetime. + wave::DelayedErrorEmitterInfo delayedErrorEmitterInfo; + delayedErrorEmitterInfo.getDelayedError = + [forward, backward](Operation *op) -> wave::EmitDelayedErrorFn { + if (forward) { + if (wave::EmitDelayedErrorFn delayedError = forward->getDelayedError(op)) + return delayedError; + } + if (backward) { + return backward->getDelayedError(op); + } + return nullptr; + }; + delayedErrorEmitterInfo.hasDelayedErrors = [forward, backward]() { + return (forward && forward->hasDelayedErrors()) || + (backward && backward->hasDelayedErrors()); + }; + return delayedErrorEmitterInfo; } -LogicalResult -wave::setWaveIndexExprAnalysisResults(Operation *top, - const DataFlowSolver &solver) { +LogicalResult wave::setWaveIndexExprAnalysisResults( + Operation *top, const DataFlowSolver &solver, + const DelayedErrorEmitterInfo &delayedErrorInfo) { + bool hadFailures = false; WalkResult walkResult = top->walk([&](wave::WaveInferIndexExprsOpInterface iface) { auto getLatticeValue = [&](Value value) { @@ -1956,13 +2011,27 @@ wave::setWaveIndexExprAnalysisResults(Operation *top, descriptionGenerator(os, i); if (failed(detail::checkAndAppendIndexExpr(iface->getLoc(), getLatticeValue(value), - os.str(), indexExprs))) - return WalkResult::interrupt(); + os.str(), indexExprs))) { + // Don't stop on the first reported error if there are some delayed + // errors that would be useful to report here. We need to wait and + // see whether the operation they are attached to actually has had + // inference issues as some errors may be corrected. + if (!delayedErrorInfo.hasDelayedErrors()) + return WalkResult::interrupt(); + + hadFailures = true; + if (auto delayedError = delayedErrorInfo.getDelayedError(iface)) { + InFlightDiagnostic diag = + iface->emitError() + << "the error above may be caused by the following: "; + delayedError(diag); + } + } } iface->setAttr(wave::WaveDialect::kIndexWaveExprListAttrName, ArrayAttr::get(iface->getContext(), indexExprs)); return WalkResult::advance(); }); - return llvm::failure(walkResult.wasInterrupted()); + return llvm::failure(hadFailures || walkResult.wasInterrupted()); } diff --git a/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir b/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir index e8256c754f..926038a51b 100644 --- a/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir +++ b/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir @@ -18,9 +18,9 @@ normalform.module [#wave.normal_form] { #wave.hardware_constraint ] } { - %lhs_override = wave.read %lhs { wave_test.override_result_index = [{ + %lhs_override = wave.read %lhs { wave_test.override_result_index = [[3, { N = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)> - }]} : (!wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]} : (!wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> // expected-error @below {{conflict when propagating forward to the result lattice in MmaOp}} // expected-note @below {{Result lattice}} // expected-note @below {{LHS lattice}} @@ -49,9 +49,9 @@ normalform.module [#wave.normal_form] { // expected-note @below {{LHS lattice}} // expected-note @below {{result lattice}} %r = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind, - wave_test.override_result_index = [ + wave_test.override_result_index = [[3, {K = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>} - ] + ]] } : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> @@ -77,7 +77,7 @@ normalform.module [#wave.normal_form] { %r = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind, wave_test.override_operand_index = [ unit, - {N = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>} + [3, {N = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>}] ] } : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) @@ -105,7 +105,7 @@ normalform.module [#wave.normal_form] { wave_test.override_operand_index = [ unit, unit, - {M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>} + [3, {M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>}] ] } : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) @@ -126,18 +126,18 @@ normalform.module [#wave.normal_form] { #wave.hardware_constraint ] } { - %add = wave.add %a, %b {wave_test.override_result_index = [{ + %add = wave.add %a, %b {wave_test.override_result_index = [[1,{ M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> // expected-error @below {{conflict when propagating index expressions from result to operand #0}} // expected-note @below {{original operand lattice}} // expected-note @below {{result #0 lattice}} - %mul = wave.mul %add, %c {wave_test.override_result_index = [{ + %mul = wave.mul %add, %c {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> return %mul : !wave.tensor<[@M, @K] of f16> } } @@ -154,19 +154,19 @@ normalform.module [#wave.normal_form] attributes { wave_test.disable #wave.hardware_constraint ] } { - %add = wave.add %a, %b {wave_test.override_result_index = [{ + %add = wave.add %a, %b {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> // expected-error @below {{conflict when propagating index expressions from operand to result #0}} // expected-note @below {{original result lattice}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %mul = wave.mul %add, %c {wave_test.override_result_index = [{ + %mul = wave.mul %add, %c {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> return %mul : !wave.tensor<[@M, @K] of f16> } } @@ -186,13 +186,13 @@ normalform.module [#wave.normal_form] attributes { wave_test.disable // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %add = wave.add %a, %b {wave_test.override_operand_index = [{ + %add = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 44, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> return %add : !wave.tensor<[@M, @K] of f16> } @@ -235,12 +235,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -262,12 +262,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -289,10 +289,10 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, + }], unit // will default-initialize to bottom. ]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> @@ -316,12 +316,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (0, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -343,11 +343,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40 + 1, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -369,11 +369,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40 + 2, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40 + 1, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -396,11 +396,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1,{ M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 3, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 2, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -422,11 +422,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 2)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 3)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -447,12 +447,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 2)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 2) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 2)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -473,12 +473,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, T0, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, T0, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, T0, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -500,11 +500,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0, T0, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (T0, WG0, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -528,11 +528,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 40, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -555,11 +555,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (WG0, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (WG1, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -582,12 +582,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: {M : <[#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: {M : [#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (WG0, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (T0, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -609,12 +609,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: {M : <[#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0 + 2, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: {M : [#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0 + 2, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (WG0 + 2, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 + 2 , 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -636,12 +636,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: {M : <[#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: {M : [#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (WG0, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -666,12 +666,12 @@ normalform.module [#wave.normal_form] { ^bb0(%a_arg: !wave.tensor<[@M, @K] of f32>): // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol, #wave.iter<"K">] -> (WG0 + _Iter_K, 1, 1)> - %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol, #wave.iter<"K">] -> (WG0 + _Iter_K, 1, 1) + %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (WG0, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -698,12 +698,12 @@ normalform.module [#wave.normal_form] { ^bb0(%a_arg: !wave.tensor<[@M, @K] of f32>): // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.iter<"K">] -> (_Iter_K + 42, 1, 1) + %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -740,12 +740,12 @@ normalform.module [#wave.normal_form] { ^bb1(%a_arg: !wave.tensor<[@M, @K] of f32>): // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.iter<"K">, #wave.iter<"M">] -> (_Iter_K + _Iter_M, 1, 1)> - %partial_result = wave.add %a_arg, %b_arg {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.iter<"K">, #wave.iter<"M">] -> (_Iter_K + _Iter_M, 1, 1) + %partial_result = wave.add %a_arg, %b_arg {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.iter<"M">] -> (_Iter_M, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -776,11 +776,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [{ + %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K * 2, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -808,11 +808,11 @@ normalform.module [#wave.normal_form] { %b_reg = wave.read %b : (!wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> %result = wave.iterate @K iter_args(%a) { ^bb0(%a_arg: !wave.tensor<[@M, @K] of f32>): - %partial_result = wave.add %a_arg, %b_reg {wave_test.override_operand_index = [{ + %partial_result = wave.add %a_arg, %b_reg {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -833,19 +833,21 @@ normalform.module [#wave.normal_form] { %c: !wave.tensor<[@M] of f32> ) attributes { wave.constraints = [ - #wave.hardware_constraint - ] + #wave.hardware_constraint, + #wave.workgroup_constraint, tile_size = <[] -> (42)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 1024}> } { // CHECK: index // CHECK: M : <[] -> (42, 1, )> wave.write %a, %b {wave_test.override_operand_index = [ - unit, { + unit, [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (, 1, )> - }] + }]] } : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> - %c_reg = wave.read %c {wave_test.override_result_index = [{ + %c_reg = wave.read %c {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (42, , )> - }]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + }]]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> wave.write %c_reg, %b : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> return } @@ -853,6 +855,109 @@ normalform.module [#wave.normal_form] { // ----- +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @priority_join + func.func @priority_join( + %a: !wave.tensor<[@M] of f32>, + %b: !wave.tensor<[@M] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint + ] + } { + // Low priority lattice. + %a_reg = wave.read %a {wave_test.override_result_index = [[0, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 10, 1, 1)> + }]]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + // Higher priority lattice on the other operand. Selected without causing a conflict. + // CHECK: wave.add + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 20, 1, 1) + %sum = wave.add %a_reg, %b {wave_test.override_operand_index = [ + unit, + [3, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 20, 1, 1)> + }] + ]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + return + } +} + +// ----- + +normalform.module [#wave.normal_form] attributes { wave_test.disable_backward } { + func.func @same_priority_conflict( + %a: !wave.tensor<[@M] of f32>, + %b: !wave.tensor<[@M] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint + ] + } { + // Both operands have same priority but different expressions - should conflict + // expected-error @below {{incompatible operand lattices when propagating from those to result}} + // expected-note @below {{operand #0 lattice}} + // expected-note @below {{operand #1 lattice}} + %sum = wave.add %a, %b {wave_test.override_operand_index = [[1, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 10, 1, 1)> + }], [1, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 20, 1, 1)> + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + return + } +} + +// ----- + +// Test that higher priority from write propagates backward through multiple operations. + +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @priority_backward_through_chain + func.func @priority_backward_through_chain( + %a: !wave.tensor<[@M] of f32>, + %b: !wave.tensor<[@M] of f32>, + %output: !wave.tensor<[@M] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint + ] + } { + // CHECK: wave.read + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 30, 1, 1) + %a_reg = wave.read %a : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + // The override has low priority, it only matters at initialization and is + // then itself overridden by the higher-priority lattice in backpropagation. + // CHECK: wave.add + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 30, 1, 1) + %sum1 = wave.add %a_reg, %b {wave_test.override_operand_index = [ + unit, + [0, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 5, 1, 1)> + }] + ]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + // CHECK: wave.add + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 30, 1, 1) + %sum2 = wave.add %sum1, %a_reg : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + // CHECK: wave.write + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 30, 1, 1) + wave.write %sum2, %output {wave_test.override_operand_index = [ + unit, + [3, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 30, 1, 1)> + }] + ]} : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> + + return + } +} + +// ----- + // Check that sideways propagation between operands of a write that would // lead to a conflict is not happening. @@ -864,23 +969,26 @@ normalform.module [#wave.normal_form] { %c: !wave.tensor<[@M] of f32> ) attributes { wave.constraints = [ - #wave.hardware_constraint - ] + #wave.hardware_constraint, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M floordiv 4)>> + ], + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64 : i64, M = 128}> } { // CHECK: wave.write // CHECK: index // CHECK: M : <[] -> (1, , )> wave.write %a, %b {wave_test.override_operand_index = [ - unit, { + unit, [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (1, , )> - }] + }]] } : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> // CHECK: wave.read // CHECK: index - // CHECK: M : <[] -> (42, , )> - %c_reg = wave.read %c {wave_test.override_result_index = [{ + // CHECK: M : [] -> (42, , ) + %c_reg = wave.read %c {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (42, , )> - }]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + }]]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> wave.write %c_reg, %b : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> return } diff --git a/water/test/Dialect/Wave/infer-index-exprs.mlir b/water/test/Dialect/Wave/infer-index-exprs.mlir index 6e1af3d240..eec60b399b 100644 --- a/water/test/Dialect/Wave/infer-index-exprs.mlir +++ b/water/test/Dialect/Wave/infer-index-exprs.mlir @@ -633,6 +633,7 @@ normalform.module [#wave.normal_form] { ]} { // expected-error @below {{failed to infer index expressions for value to store}} + // expected-error @below {{the error above may be caused by the following: couldn't find vector shapes in the contiguity check}} wave.write %src, %dst : !wave.tensor<[@M, @N] of f32>, !wave.tensor<[@M, @N] of f32, > return @@ -963,3 +964,248 @@ normalform.module [#wave.normal_form] { return } } + +// ----- + +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @propagate_from_write + func.func @propagate_from_write( + %a: !wave.tensor<[@M, @N] of f32>, + %b: !wave.tensor<[@M, @N] of f32>, + %output: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + // CHECK: wave.read + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %a_reg = wave.read %a : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + // CHECK: wave.read + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %b_reg = wave.read %b : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.add + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %sum = wave.add %a_reg, %b_reg : (!wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %sum, %output : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} + +// ----- + +// Elements per thread provided on the op used instead of the value inferred from workgroup constraints. +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @propagate_from_write_explicit_ept + func.func @propagate_from_write_explicit_ept( + %output: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + %cst = arith.constant 0.0 : f32 + %reg = wave.register %cst : !wave.tensor<[@M, @N] of f32, > + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * 8 + WG0 * BLOCK_M, 8, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %reg, %output {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} + +// ----- + +// Elements per thread is used for the trailing dimension because its vector shape is no longer 1. +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @propagate_from_write_vector_shape + func.func @propagate_from_write_vector_shape( + %output: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 16 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + %cst = arith.constant 0.0 : f32 + %reg = wave.register %cst : !wave.tensor<[@M, @N] of f32, > + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (T0 mod 64 + WG0 * BLOCK_M, 1, 16)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 8, 1)> + wave.write %reg, %output {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} + +// ----- + +// Test that unmapped dimensions get default (0, 1, 1) index expressions +// when there are no workgroup/wave/tiling constraints for them. + +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @unmapped_dimension_default + func.func @unmapped_dimension_default( + %a: !wave.tensor<[@B, @M, @N] of f32>, + %output: !wave.tensor<[@B, @M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{B = 8, M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + // Read should get default index expression for B dimension (no constraints) + // and computed expressions for M and N dimensions + // CHECK: wave.read + // CHECK-DAG: B : <[] -> (0, 1, 1)> + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %a_reg = wave.read %a : (!wave.tensor<[@B, @M, @N] of f32>) -> !wave.tensor<[@B, @M, @N] of f32, > + + // Write should preserve the same index expressions + // CHECK: wave.write + // CHECK-DAG: B : <[] -> (0, 1, 1)> + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %a_reg, %output : !wave.tensor<[@B, @M, @N] of f32, >, !wave.tensor<[@B, @M, @N] of f32> + + return + } +} + +// ----- + +// Test priority-based propagation with multiple write operations. +// All writes should establish index expressions with the same priority, +// and the join should succeed since they agree. + +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @multiple_writes_consistent + func.func @multiple_writes_consistent( + %a: !wave.tensor<[@M, @N] of f32>, + %b: !wave.tensor<[@M, @N] of f32>, + %out1: !wave.tensor<[@M, @N] of f32>, + %out2: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + // CHECK: wave.read + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %a_reg = wave.read %a : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.read + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %b_reg = wave.read %b : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.add + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %sum = wave.add %a_reg, %b_reg : (!wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + + // Both writes establish the same index expressions + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %sum, %out1 : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %a_reg, %out2 : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} + +// ----- + +// Test write when all dimension symbols are absent from constraints. +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @write_all_dimensions_unmapped + func.func @write_all_dimensions_unmapped( + %a: !wave.tensor<[@P, @Q] of f32>, + %output: !wave.tensor<[@P, @Q] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {P = 1 : i64, Q = 1 : i64}> + ], + wave.hyperparameters = #wave.hyperparameters<{P = 8 : i64, Q = 16 : i64}> + } { + // CHECK: wave.read + // CHECK-DAG: P : <[] -> (0, 1, 1)> + // CHECK-DAG: Q : <[] -> (0, 1, 1)> + %a_reg = wave.read %a : (!wave.tensor<[@P, @Q] of f32>) -> !wave.tensor<[@P, @Q] of f32, > + + // CHECK: wave.write + // CHECK-DAG: P : <[] -> (0, 1, 1)> + // CHECK-DAG: Q : <[] -> (0, 1, 1)> + wave.write %a_reg, %output : !wave.tensor<[@P, @Q] of f32, >, !wave.tensor<[@P, @Q] of f32> + + return + } +} + + +// ----- + +// MMa index expression has higher priority than write. + +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @write_after_mma_priority + func.func @write_after_mma_priority( + %a: !wave.tensor<[@M, @K] of f16>, + %b: !wave.tensor<[@N, @K] of f16>, + %c: !wave.tensor<[@M, @N] of f32>, + %output: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = >, + #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>>, + #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>> + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, K = 64, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + %a_reg = wave.read %a : (!wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16, > + %b_reg = wave.read %b : (!wave.tensor<[@N, @K] of f16>) -> !wave.tensor<[@N, @K] of f16, > + %mma = wave.mma %a_reg, %b_reg, %c {kind = #wave.mma_kind} + : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + // Write in isolation would have inferred step=floordiv(BLOCK_M, 64) since M is mapped + // to workgroup X with 64 threads in it, but we obtain step=4 propagate from the mma + // above, because that has higher priority. + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (((T0 mod 64) floordiv 16) * 4 + WG0 * BLOCK_M + // CHECK-DAG: N : <[#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (T0 mod 16 + WG1 * BLOCK_N + wave.write %mma, %output : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} diff --git a/water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp b/water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp index e46a4a0195..ff1131f718 100644 --- a/water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp +++ b/water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp @@ -38,21 +38,25 @@ overrideInitialization(Operation *top, continue; if (auto strAttr = llvm::dyn_cast(attr); strAttr && strAttr.getValue() == "") { - setIndexForValue(value, nullptr); + setIndexForValue(value, nullptr, + wave::IndexExprsLatticeStorage::kHighestPriority); continue; } - auto dict = llvm::dyn_cast(attr); - if (!dict || llvm::any_of(dict.getValue(), [](NamedAttribute attr) { - return !llvm::isa(attr.getValue()); - })) { - return op->emitError() - << "expected " << attributeName - << " to be an array of " - "dictionaries with WaveIndexMappingAttr or UnitAttr values"; + auto array = llvm::dyn_cast(attr); + auto priority = (array && !array.empty()) + ? llvm::dyn_cast(array.getValue()[0]) + : IntegerAttr(); + auto dict = (array && array.size() > 1) + ? llvm::dyn_cast(array.getValue()[1]) + : DictionaryAttr(); + if (!priority || !dict) { + return op->emitError() << "expected " << attributeName + << " to be an array containing an integer " + "priority and a dictionary mapping"; } - setIndexForValue(value, dict); + setIndexForValue(value, dict, priority.getInt()); } return success(); }; @@ -100,7 +104,8 @@ class TestWaveDialectInferIndexExprsPass options.disableForward = getOperation()->getAttrOfType( "wave_test.disable_forward") != nullptr; options.overrideInitialization = overrideInitialization; - addWaveIndexExprsAnalyses(solver, symbolTable, options); + wave::DelayedErrorEmitterInfo delayedErrorInfo = + wave::addWaveIndexExprsAnalyses(solver, symbolTable, options); IRRewriter rewriter(&getContext()); getOperation()->walk( @@ -109,7 +114,8 @@ class TestWaveDialectInferIndexExprsPass if (failed(wave::runSolverAndCaptureErrors(solver, getOperation(), false))) return signalPassFailure(); - if (failed(setWaveIndexExprAnalysisResults(getOperation(), solver))) + if (failed(setWaveIndexExprAnalysisResults(getOperation(), solver, + delayedErrorInfo))) return signalPassFailure(); getOperation()->walk([&](wave::IterateOp iterateOp) { diff --git a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py index 4c72b29663..26a0961fd2 100644 --- a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py +++ b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py @@ -21,7 +21,7 @@ from wave_lang.kernel.wave.compile_options import WaveCompileOptions from wave_lang.support.logging import get_logger -from ..._support.indexing import IndexSequence, IndexSymbol +from ..._support.indexing import IndexSequence, IndexSymbol, IndexExpr from ..._support.tracing import CapturedTrace from ...lang.global_symbols import * from ...ops.wave_ops import ( @@ -328,18 +328,72 @@ def _check_water_indices(trace: CapturedTrace, inferred: dict[str, IndexSequence if isinstance(custom, GetResult): continue + # Assumptions on symbols may be insufficiently tight, in particular for + # non-counting symbols we can assume strict positivity, which allows us + # to get rid of Max(1, ...) expressions that otherwise appear on the + # python side. + def ensure_symbols_positive( + seqs: dict[IndexSymbol, IndexSequence], + ) -> dict[IndexSymbol, IndexSequence]: + all_symbols = set() + for _, seq in seqs.items(): + if isinstance(seq.start, sympy.Expr): + all_symbols.update(seq.start.free_symbols) + if isinstance(seq.size, sympy.Expr): + all_symbols.update(seq.size.free_symbols) + if isinstance(seq.stride, sympy.Expr): + all_symbols.update(seq.stride.free_symbols) + + symbol_remapping = { + symbol: ( + sympy.Symbol(symbol.name, nonnegative=True, integer=True) + if symbol.name.startswith("$") + else sympy.Symbol(symbol.name, positive=True, integer=True) + ) + for symbol in all_symbols + } + return { + dim: IndexSequence( + start=( + sympy.simplify( + seq.start.subs(symbol_remapping, simultaneous=True) + ) + if isinstance(seq.start, IndexExpr) + else seq.start + ), + size=( + sympy.simplify( + seq.size.subs(symbol_remapping, simultaneous=True) + ) + if isinstance(seq.size, IndexExpr) + else seq.size + ), + stride=( + sympy.simplify( + seq.stride.subs(symbol_remapping, simultaneous=True) + ) + if isinstance(seq.stride, IndexExpr) + else seq.stride + ), + ) + for dim, seq in node.index.items() + } + + node_index = ensure_symbols_positive(node.index) + inferred_index = ensure_symbols_positive(inferred_index) + # Check that that indices match, raise an error if they don't. Start by # a trivial direct comparison, fall back to computing and simplifying # the difference. The latter can raise with additional information, # which this wants to preserve. try: - if node.index != inferred_index and not _check_index_difference_is_zero( - node.index, inferred_index + if node_index != inferred_index and not _check_index_difference_is_zero( + node_index, inferred_index ): raise ValueError("mismatching indices") except ValueError as e: raise RuntimeError( - f"Index for node {get_custom(node)}, {get_custom(node).index} does not match inferred index {inferred_index}." + f"Index for node {get_custom(node)}, {node_index} does not match inferred index {inferred_index}." ) from e