diff --git a/lit_tests/kernel/wave/infer_index_exprs.py b/lit_tests/kernel/wave/infer_index_exprs.py index 7ec72c0555..7f125bbf3d 100644 --- a/lit_tests/kernel/wave/infer_index_exprs.py +++ b/lit_tests/kernel/wave/infer_index_exprs.py @@ -9,6 +9,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel @@ -16,6 +17,104 @@ from wave_lang.kernel.wave.constraints import MMAType +def _get_matrix_add_kernel(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE + dtype = tkl.f16 + + constraints = [ + tkw.HardwareConstraint(threads_per_wave=64, vector_shapes={M: 4, N: 1}), + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + ] + + @tkw.wave(constraints) + def matrix_add( + a: tkl.Memory[M, N, ADDRESS_SPACE, dtype], + b: tkl.Memory[M, N, ADDRESS_SPACE, dtype], + c: tkl.Memory[M, N, ADDRESS_SPACE, dtype], + ): + a_reg = tkw.read(a) + b_reg = tkw.read(b) + c_reg = a_reg + b_reg + tkw.write(c_reg, c) + + hyperparams = { + M: 128, + N: 128, + BLOCK_M: 16, + BLOCK_N: 16, + ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + } + return matrix_add, hyperparams + + +def _get_mma_chain_kernel(): + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + P = tkl.sym.P + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + dtype = tkl.f16 + + constraints = [ + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=MMAType.F32_16x16x16_F16, + waves_per_block=(1, 2, 2), + ), + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + ] + + @tkw.wave(constraints) + def mma_chain( + a: tkl.Memory[M, K, GLOBAL_ADDRESS_SPACE, dtype], + b: tkl.Memory[N, K, GLOBAL_ADDRESS_SPACE, dtype], + c: tkl.Memory[M, P, 
GLOBAL_ADDRESS_SPACE, tkl.f32], + d: tkl.Memory[P, N, GLOBAL_ADDRESS_SPACE, dtype], + storage: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, dtype], + ): + a_read = tkw.read(a) + b_read = tkw.read(b) + c_reg = tkl.Register[M, N, tkl.f32](0.0) + mma1 = tkw.mma(a_read, b_read, c_reg) + mma1_casted = tkw.cast(mma1, tkl.f16) + tkw.write(mma1_casted, storage) + reloaded = tkw.read(storage) + d_read = tkw.read(d) + c_reg2 = tkl.Register[M, P, tkl.f32](0.0) + mma2 = tkw.mma(reloaded, d_read, c_reg2) + tkw.write(mma2, c) + + hyperparams = { + M: 128, + N: 128, + K: 128, + P: 128, + BLOCK_M: 16, + BLOCK_N: 16, + } + return mma_chain, hyperparams + + +def testMatrixAdd(): + kernel, params = _get_matrix_add_kernel() + options = WaveCompileOptions( + subs=params, + run_bench=False, + check_water_analysis=True, + ) + compiled_kernel = wave_compile(options, kernel) + assert compiled_kernel is not None + + def testGemm(): relevant_hyperparams = [ tkl.sym.M, @@ -54,5 +153,18 @@ def testGemm(): assert compiled_gemm is not None +def testMmaChain(): + kernel, params = _get_mma_chain_kernel() + options = WaveCompileOptions( + subs=params, + run_bench=False, + check_water_analysis=True, + ) + compiled_kernel = wave_compile(options, kernel) + assert compiled_kernel is not None + + if __name__ == "__main__": + testMatrixAdd() + testMmaChain() testGemm() diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h index 2d2373d12f..c7bf7092b5 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h @@ -18,6 +18,7 @@ #include "llvm/Support/raw_ostream.h" #include "water/Dialect/Wave/IR/WaveAttrs.h" +#include "water/Dialect/Wave/Transforms/Utils.h" namespace wave { @@ -563,9 +564,17 @@ class IndexExprsAnalysisInit { // expressions. class IndexExprsLatticeStorage { public: + // Priorities for specific operations that may be used. 
+ static constexpr int32_t kHighestPriority = + std::numeric_limits<int32_t>::max(); + static constexpr int32_t kMmaPriority = 3; + static constexpr int32_t kWritePriority = 1; + static constexpr int32_t kLowestPriority = 0; + IndexExprsLatticeStorage(); IndexExprsLatticeStorage(const IndexExprsLatticeStorage &value) = default; - IndexExprsLatticeStorage(mlir::DictionaryAttr concreteValue); + IndexExprsLatticeStorage(mlir::DictionaryAttr concreteValue, + int32_t priority); IndexExprsLatticeStorage & operator=(const IndexExprsLatticeStorage &other) = default; @@ -583,6 +592,10 @@ class IndexExprsLatticeStorage { // specified or not, or null if the lattice instance is a top or a bottom. mlir::DictionaryAttr getConcreteValue() const; + // Return the priority of this lattice instance or -1 if it is not a concrete + // value. + int32_t getPriority() const { return getConcreteValue() ? priority : -1; } + // Return the top lattice instance. static IndexExprsLatticeStorage top(); @@ -628,6 +641,10 @@ class IndexExprsLatticeStorage { // symbol indexing the value or one of the top/bottom flags. llvm::PointerIntPair value; + // Priority of this value. Specific values with higher priority override + // values with lower priority in joins. + int32_t priority; + // State flags.
constexpr static unsigned kUninitializedState = 0; constexpr static unsigned kSpecificTypeState = 1; diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.td b/water/include/water/Dialect/Wave/IR/WaveInterfaces.td index 1c7aaaff73..2f4c409008 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.td +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.td @@ -264,7 +264,8 @@ def WaveInferIndexExprsOpInterface : OpInterface<"WaveInferIndexExprsOpInterface "initializeIndexExprsBackward", (ins "::llvm::MutableArrayRef<::wave::IndexExprsLatticeStorage>":$operandExprs, "const ::wave::IndexExprsAnalysisInit &":$initObject, - "::wave::EmitErrorFn":$emitError), + "::wave::EmitErrorFn":$emitError, + "::wave::EmitDelayedErrorFn &":$delayedErrorEmitter), /*methodBody=*/"", /*defaultImplementation=*/[{ return ::llvm::success(); diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 75ca7002cc..3186fe05e6 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -574,7 +574,7 @@ def WriteOp : WaveOp<"write", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + ["getIndexExprValuesAndDescriptions", "initializeIndexExprsBackward"]>, RequiresSidewaysBackwardPropagationOpTrait]> { let summary = "Writes into memory"; let description = [{ diff --git a/water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h b/water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h index aefb0a3a38..a73f0d89bf 100644 --- a/water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h +++ b/water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h @@ -7,6 +7,7 @@ #ifndef WATER_DIALECT_WAVE_TRANSFORMS_DATAFLOWANALYSES_H #define WATER_DIALECT_WAVE_TRANSFORMS_DATAFLOWANALYSES_H +#include "water/Dialect/Wave/Transforms/Utils.h" #include "llvm/ADT/FunctionExtras.h" namespace llvm { @@ -23,7 +24,7 @@ class 
DictionaryAttr; namespace wave { using SetIndexLatticeFn = - llvm::function_ref<void(mlir::Value, mlir::DictionaryAttr)>; + llvm::function_ref<void(mlir::Value, mlir::DictionaryAttr, int32_t)>; using OverrideInitializationFn = llvm::function_ref<llvm::LogicalResult(mlir::Operation *, SetIndexLatticeFn)>; @@ -36,13 +37,18 @@ struct WaveIndexExprsAnalysisOptions { }; // Add analyses for index expression propagation to the solver. -void addWaveIndexExprsAnalyses(mlir::DataFlowSolver &solver, - mlir::SymbolTableCollection &symbolTable, - WaveIndexExprsAnalysisOptions options = {}); - -llvm::LogicalResult -setWaveIndexExprAnalysisResults(mlir::Operation *top, - const mlir::DataFlowSolver &solver); +wave::DelayedErrorEmitterInfo +addWaveIndexExprsAnalyses(mlir::DataFlowSolver &solver, + mlir::SymbolTableCollection &symbolTable, + WaveIndexExprsAnalysisOptions options = {}); + +// Set the index expression attributes on operations nested under `top` using the +// lattices computed by the dataflow analyses in the given solver. Emit delayed +// errors if they are related to operations for which we failed to infer index +// expressions. +llvm::LogicalResult setWaveIndexExprAnalysisResults( + mlir::Operation *top, const mlir::DataFlowSolver &solver, + const wave::DelayedErrorEmitterInfo &delayedErrorInfo); // Run the dataflow analyses and capture whether some diagnostics were emitted. // Only emit a generic diagnostic if no more specific diagnostic was emitted. diff --git a/water/include/water/Dialect/Wave/Transforms/Utils.h b/water/include/water/Dialect/Wave/Transforms/Utils.h index 4d3cc3692f..022762278d 100644 --- a/water/include/water/Dialect/Wave/Transforms/Utils.h +++ b/water/include/water/Dialect/Wave/Transforms/Utils.h @@ -11,6 +11,19 @@ namespace wave { +// Callback to generate a delayed diagnostic. The diagnostic message should be +// attached to the argument. It may or may not be emitted. +using EmitDelayedErrorFn = std::function<void(mlir::InFlightDiagnostic &)>; + +// Information to emit delayed errors. +struct DelayedErrorEmitterInfo { + // Returns the delayed error for the given operation. 
+ std::function<EmitDelayedErrorFn(mlir::Operation *)> getDelayedError; + + // Returns true if there are any delayed errors. + std::function<bool()> hasDelayedErrors; +}; + /// Get the hyperparameters from an ancestor operation. /// Returns nullptr if no hyperparameters are found. WaveHyperparameterAttr getHyperparameters(mlir::Operation *op); diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index f96aadeb6f..0dcce4ce1b 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -729,11 +729,11 @@ llvm::LogicalResult wave::detail::verifyCompatibleOperandsAndResultsOpTrait( //----------------------------------------------------------------------------- wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage() - : value(nullptr, kUninitializedState) {} + : value(nullptr, kUninitializedState), priority(kLowestPriority) {} wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage( - DictionaryAttr concreteValue) - : value(concreteValue, kSpecificTypeState) {} + DictionaryAttr concreteValue, int32_t priority) + : value(concreteValue, kSpecificTypeState), priority(priority) {} bool wave::IndexExprsLatticeStorage::operator==( const IndexExprsLatticeStorage &other) const { @@ -1097,12 +1097,18 @@ static FailureOr getIndexExprStepStrideJoinedMap( return failure(); } -// Join two concrete index expressions mappings by joining their -// start/step/stride maps independently. See getIndexExprStartJoinedMap and -// getIndexExprStepStrideJoinedMap for more details. +// Join two concrete index expressions mappings either by picking the +// higher-priority one or by joining their start/step/stride maps independently. +// See getIndexExprStartJoinedMap and getIndexExprStepStrideJoinedMap for more +// details on independent joining. 
static wave::WaveIndexMappingAttr getIndexExprsJoinMappings(wave::WaveIndexMappingAttr lhs, - wave::WaveIndexMappingAttr rhs) { + wave::WaveIndexMappingAttr rhs, int32_t lhsPriority, + int32_t rhsPriority) { + if (lhsPriority > rhsPriority) + return lhs; + if (rhsPriority > lhsPriority) + return rhs; // Collect all unique symbol names from both index mappings in order. llvm::SmallVector allSymbols; @@ -1162,7 +1168,8 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( namedAttr.getName().getValue()); }); return IndexExprsLatticeStorage( - DictionaryAttr::get(rhs.getConcreteValue().getContext(), filtered)); + DictionaryAttr::get(rhs.getConcreteValue().getContext(), filtered), + rhs.getPriority()); } if (rhs.isBottom()) @@ -1195,17 +1202,21 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( if (lhsValue == rhsValue) continue; - wave::WaveIndexMappingAttr joinedMapping = - getIndexExprsJoinMappings(lhsValue, rhsValue); + wave::WaveIndexMappingAttr joinedMapping = getIndexExprsJoinMappings( + lhsValue, rhsValue, lhs.getPriority(), rhs.getPriority()); if (!joinedMapping) return IndexExprsLatticeStorage::top(); result[namedAttr.getName()] = joinedMapping; } return IndexExprsLatticeStorage( - DictionaryAttr::get(ctx, llvm::map_to_vector(result, [](auto &&pair) { - return NamedAttribute(pair.first, pair.second); - }))); + DictionaryAttr::get(ctx, llvm::map_to_vector(result, + [](auto &&pair) { + return NamedAttribute( + pair.first, + pair.second); + })), + std::max(lhs.getPriority(), rhs.getPriority())); } wave::IndexExprsLatticeStorage @@ -1237,7 +1248,8 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::keepOnlySymbols( return bottom(); return IndexExprsLatticeStorage( - DictionaryAttr::get(getConcreteValue().getContext(), filtered)); + DictionaryAttr::get(getConcreteValue().getContext(), filtered), + getPriority()); } wave::IndexExprsLatticeStorage @@ -1257,7 +1269,8 @@ 
wave::IndexExprsLatticeStorage::withoutIterSymbols( } return NamedAttribute(attr.getName(), value); }); - return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, updated)); + return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, updated), + getPriority()); } void wave::IndexExprsLatticeStorage::print(llvm::raw_ostream &os) const { @@ -1266,7 +1279,7 @@ void wave::IndexExprsLatticeStorage::print(llvm::raw_ostream &os) const { } else if (isTop()) { os << ""; } else { - os << getConcreteValue(); + os << "[pri: " << getPriority() << "] " << getConcreteValue(); } } diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index bce01c65a8..a8009899f0 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1164,7 +1164,9 @@ LogicalResult MmaOp::initializeIndexExprsForward( mixInThreadIndependentConstraints( *this, initObject.hardwareConstraint.getThreadsPerWave(), indexingSymbols, initObject.symbolConstraints, symbolMappings); - resultExprs[0].unsafeSet(DictionaryAttr::get(getContext(), symbolMappings)); + resultExprs[0].unsafeSet(wave::IndexExprsLatticeStorage( + DictionaryAttr::get(getContext(), symbolMappings), + wave::IndexExprsLatticeStorage::kMmaPriority)); return llvm::success(); } @@ -1175,8 +1177,8 @@ LogicalResult MmaOp::initializeIndexExprsForward( // well as workgroup constraints (thread-independent). 
LogicalResult MmaOp::initializeIndexExprsBackward( llvm::MutableArrayRef operandExprs, - const wave::IndexExprsAnalysisInit &initObject, - wave::EmitErrorFn emitError) { + const wave::IndexExprsAnalysisInit &initObject, wave::EmitErrorFn emitError, + wave::EmitDelayedErrorFn &delayedErrorEmitter) { auto resultType = llvm::cast(getResult().getType()); auto lhsType = llvm::cast(getLhs().getType()); assert(resultType.getRank() == lhsType.getRank() && lhsType.getRank() >= 2 && @@ -1242,13 +1244,16 @@ LogicalResult MmaOp::initializeIndexExprsBackward( operandExprs[getLhsMutable().getOperandNumber()] = wave::IndexExprsLatticeStorage( - DictionaryAttr::get(getContext(), lhsSymbolMappings)); + DictionaryAttr::get(getContext(), lhsSymbolMappings), + wave::IndexExprsLatticeStorage::kMmaPriority); operandExprs[getRhsMutable().getOperandNumber()] = wave::IndexExprsLatticeStorage( - DictionaryAttr::get(getContext(), rhsSymbolMappings)); + DictionaryAttr::get(getContext(), rhsSymbolMappings), + wave::IndexExprsLatticeStorage::kMmaPriority); operandExprs[getAccumulatorMutable().getOperandNumber()] = wave::IndexExprsLatticeStorage( - DictionaryAttr::get(getContext(), accumulatorSymbolMappings)); + DictionaryAttr::get(getContext(), accumulatorSymbolMappings), + wave::IndexExprsLatticeStorage::kMmaPriority); return llvm::success(); } @@ -2233,6 +2238,174 @@ llvm::FailureOr wave::WriteOp::propagateIndexExprsForward( return ChangeResult::NoChange; } +/// Computes the vector stride for each dimension: stride[i] is the product of +/// vector shapes for dimensions i+1 .. rank-1 (so the last dimension has +/// stride 1). A dimension is contiguous iff its stride is 1. 
+static FailureOr> +getVectorStrides(wave::WaveTensorType tensorType, + HardwareConstraintAttr hardwareConstraint) { + assert(tensorType.getFullySpecified() && + "expected fully-specified tensor type"); + DictionaryAttr vectorShapes = hardwareConstraint.getVectorShapes(); + if (!vectorShapes) + return failure(); + int64_t rank = tensorType.getRank(); + SmallVector strides(rank); + strides[rank - 1] = 1; + for (int64_t i = rank - 2; i >= 0; --i) { + Attribute vectorShape = + vectorShapes.get(tensorType.getShape()[i + 1].getName()); + if (!vectorShape) + return failure(); + int64_t shape = cast(vectorShape).getValue().getSExtValue(); + strides[i] = shape * strides[i + 1]; + } + return strides; +} + +LogicalResult WriteOp::initializeIndexExprsBackward( + llvm::MutableArrayRef operandExprs, + const wave::IndexExprsAnalysisInit &initObject, wave::EmitErrorFn emitError, + wave::EmitDelayedErrorFn &delayedErrorEmitter) { + + // TODO: figure out how to propagate elements per threads from constraints to + // operations while avoiding the clash with index sequences. When propagating, + // we don't have sequences yet. + WaveTensorType tensorType = cast(getValueToStore().getType()); + HardwareConstraintAttr hardwareConstraint = initObject.hardwareConstraint; + + assert(tensorType.getFullySpecified()); + FailureOr> stridesOr = + getVectorStrides(tensorType, hardwareConstraint); + // XXX: don't report this error immediately since we may be able to proceed + // without it, e.g., when index expressions may be propagated from + // operations with higher priority operations to this one. This is a + // questionable design choice carried over from the initial Python + // prototype, but is needed for initial consistency. Consider revising. 
+ if (failed(stridesOr)) { + delayedErrorEmitter = [](InFlightDiagnostic &diag) { + diag << "couldn't find vector shapes in the contiguity check"; + }; + return success(); + } + llvm::ArrayRef strides = *stridesOr; + // XXX: pywave confusingly calls this "contiguous" but it is actually the + // dimension along which SIMD vectorization is applied, i.e., the deepest + // dimension for which the per-thread vector shape is not 1 or, alternatively, + // the product of vector shapes for trailing dimensions remains 1. + int64_t vectorizedDimPos = -1; + for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) + if (strides[i] == 1) { + vectorizedDimPos = i; + break; + } + + SmallVector indexMappings; + for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) { + AffineExpr elementsPerThread = nullptr; + bool isVectorized = (i == vectorizedDimPos); + + // The absence of constraints for a dimension means it is not mapped to + // workgroups/wave/items, so there is nothing to do here. + // We expect it to be handled by thread-independent constraints setting + // the default (0, 1, 1) index expression or following the tiling + // constraint. + auto it = initObject.symbolConstraints.find(tensorType.getShape()[i]); + if (it == initObject.symbolConstraints.end()) + continue; + auto wgConstraintIt = + llvm::find_if(it->second, llvm::IsaPred); + if (wgConstraintIt == it->second.end()) + continue; + WorkgroupConstraintAttr wgConstraint = + cast(*wgConstraintIt); + + // The innermost dimension with vectorized size other than 1 is the one we + // want to vectorize along. 
+ std::optional opElementsPerThread = getElementsPerThread(); + SmallVector symbols; + if (isVectorized) { + if (opElementsPerThread) { + elementsPerThread = + getAffineConstantExpr(*opElementsPerThread, getContext()); + } else { + AffineMap tileSizeMap = wgConstraint.getTileSize().getMap(); + assert(tileSizeMap.getNumResults() == 1 && + "expected a single expression in tile size affine map"); + unsigned numThreads = [&]() { + switch (wgConstraint.getWorkgroupDim().getValue()) { + case WaveWorkgroupDim::X: + return initObject.wavesPerBlock[0] * + initObject.hardwareConstraint.getThreadsPerWave(); + case WaveWorkgroupDim::Y: + return initObject.wavesPerBlock[1]; + case WaveWorkgroupDim::Z: + return initObject.wavesPerBlock[2]; + } + }(); + + elementsPerThread = tileSizeMap.getResult(0).ceilDiv(numThreads); + llvm::append_range(symbols, wgConstraint.getTileSize().getSymbols()); + } + } else { + elementsPerThread = getAffineConstantExpr(1, getContext()); + } + + int64_t stride = strides[i]; + assert(stride > 0 && "stride should be positive"); + + WaveIndexSymbol threadSymbol = [&]() { + switch (wgConstraint.getWorkgroupDim().getValue()) { + case WaveWorkgroupDim::X: + return WaveIndexSymbol::THREAD_0; + case WaveWorkgroupDim::Y: + return WaveIndexSymbol::THREAD_1; + case WaveWorkgroupDim::Z: + return WaveIndexSymbol::THREAD_2; + } + }(); + WaveIndexSymbolAttr threadSymbolAttr = + WaveIndexSymbolAttr::get(getContext(), threadSymbol); + symbols.push_back(threadSymbolAttr); + + AffineExpr startExpr = + getAffineSymbolExpr(symbols.size() - 1, getContext()); + if (wgConstraint.getWorkgroupDim().getValue() == WaveWorkgroupDim::X) { + startExpr = startExpr % hardwareConstraint.getThreadsPerWave(); + } else { + // TODO: in pywave, we always do `startExpr % threadsPerWave` where + // threadsPerWave == 1 for workgroup dims other than X, making it + // always zero. 
It mentions an assumption about the (64, 1, 1) thread + // shape, but it is unclear whether that assumption always holds. + // It looks like the intention for this was to express lane ID rather + // than thread ID, but it is unclear how it accounts for multiple + // wavefronts running in parallel. + startExpr = getAffineConstantExpr(0, getContext()); + } + + auto indexMapping = WaveIndexMappingAttr::get( + getContext(), symbols, + AffineMap::get(/*dimCount=*/0, symbols.size(), + startExpr * elementsPerThread), + AffineMap::get(/*dimCount=*/0, symbols.size(), elementsPerThread), + AffineMap::get(/*dimCount=*/0, symbols.size(), + getAffineConstantExpr(stride, getContext()))); + indexMappings.emplace_back(tensorType.getShape()[i].getName(), + indexMapping); + } + mixInThreadIndependentConstraints( + *this, initObject.hardwareConstraint.getThreadsPerWave(), + tensorType.getShape(), initObject.symbolConstraints, indexMappings); + operandExprs[getValueToStoreMutable().getOperandNumber()] = + IndexExprsLatticeStorage(DictionaryAttr::get(getContext(), indexMappings), + IndexExprsLatticeStorage::kWritePriority); + operandExprs[getMemoryMutable().getOperandNumber()] = + IndexExprsLatticeStorage(DictionaryAttr::get(getContext(), indexMappings), + IndexExprsLatticeStorage::kWritePriority); + + return success(); +} + // Propagating "sideways" between operands, but only if this would not result // in conflicts. 
llvm::FailureOr wave::WriteOp::propagateIndexExprsBackward( @@ -2761,7 +2934,8 @@ permuteIndexExprsStrides(const IndexExprsLatticeStorage &inputLattice, NamedAttribute(StringAttr::get(ctx, srcSymbol.getName()), newMapping)); } - return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, permutedMappings)); + return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, permutedMappings), + inputLattice.getPriority()); } llvm::FailureOr wave::PermuteOp::propagateIndexExprsForward( diff --git a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp index cc8126c53d..2af50ef48b 100644 --- a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp +++ b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp @@ -33,6 +33,7 @@ using namespace mlir; #define DEBUG_TYPE "wave-infer-types" +using namespace mlir; using wave::ElementsPerThreadLatticeValue; using wave::IndexExprsLatticeStorage; @@ -1396,7 +1397,10 @@ class IndexExprsForwardAnalysis wave::applyConstraint(tilingConstraint)}}); LDBG() << "setting iterate block argument lattice " << capture << " from " << PrintNoRegions(iterateOp) << " to " << dict; - unsafeSet(getLatticeElement(capture), dict); + unsafeSet( + getLatticeElement(capture), + wave::IndexExprsLatticeStorage( + dict, wave::IndexExprsLatticeStorage::kLowestPriority)); } } } @@ -1408,11 +1412,12 @@ class IndexExprsForwardAnalysis if (overrideInitialization) { if (llvm::failed(overrideInitialization( - top, [&](Value value, DictionaryAttr dict) { + top, [&](Value value, DictionaryAttr dict, int32_t priority) { if (!dict) return unsafeSet(getLatticeElement(value), IndexExprsLatticeStorage::top()); - unsafeSet(getLatticeElement(value), dict); + unsafeSet(getLatticeElement(value), + wave::IndexExprsLatticeStorage(dict, priority)); }))) return llvm::failure(); } @@ -1608,9 +1613,18 @@ class IndexExprsForwardAnalysis } } + // Return true if there are pending error reports. 
+ bool hasDelayedErrors() const { return !delayedErrors.empty(); } + + // Return the emitter of a pending error report for the given operation. + wave::EmitDelayedErrorFn getDelayedError(Operation *op) const { + return delayedErrors.lookup_or(op, wave::EmitDelayedErrorFn()); + } + private: bool initialized = false; wave::OverrideInitializationFn overrideInitialization; + llvm::SmallDenseMap delayedErrors; }; class IndexExprsBackwardAnalysis @@ -1666,9 +1680,14 @@ class IndexExprsBackwardAnalysis LDBG() << "initializing index expressions backward for " << PrintNoRegions(op); + wave::EmitDelayedErrorFn delayedErrorEmitter = nullptr; if (llvm::failed(iface.initializeIndexExprsBackward( - operandExprs, *initObject, emitError))) + operandExprs, *initObject, emitError, delayedErrorEmitter))) return WalkResult::interrupt(); + if (delayedErrorEmitter) { + LDBG() << "delayed error recorded\n"; + delayedErrors[op] = delayedErrorEmitter; + } for (auto &&[i, operand, lattice] : llvm::enumerate(op->getOperands(), operandExprs)) { IndexExprsLattice *latticeObject = getLatticeElement(operand); @@ -1694,11 +1713,12 @@ class IndexExprsBackwardAnalysis if (overrideInitialization) { if (llvm::failed(overrideInitialization( - top, [&](Value value, DictionaryAttr dict) { + top, [&](Value value, DictionaryAttr dict, int32_t priority) { if (!dict) return unsafeSet(getLatticeElement(value), IndexExprsLatticeStorage::top()); - unsafeSet(getLatticeElement(value), dict); + unsafeSet(getLatticeElement(value), + wave::IndexExprsLatticeStorage(dict, priority)); }))) return llvm::failure(); } @@ -1873,9 +1893,18 @@ class IndexExprsBackwardAnalysis // by the forward analysis. } + // Returns true if there are any delayed errors. + bool hasDelayedErrors() const { return !delayedErrors.empty(); } + + // Returns the delayed error emitter for the given operation. 
+ wave::EmitDelayedErrorFn getDelayedError(Operation *op) const { + return delayedErrors.lookup_or(op, wave::EmitDelayedErrorFn()); + } + private: bool initialized = false; wave::OverrideInitializationFn overrideInitialization; + llvm::SmallDenseMap delayedErrors; }; namespace { @@ -1899,16 +1928,17 @@ class InferIndexExprsPass config.setInterprocedural(false); DataFlowSolver solver(config); - solver.load(); - solver.load(); - wave::addWaveIndexExprsAnalyses(solver, symbolTable); + solver.load(); + solver.load(); + wave::DelayedErrorEmitterInfo delayedErrorInfo = + wave::addWaveIndexExprsAnalyses(solver, symbolTable); if (llvm::failed( wave::runSolverAndCaptureErrors(solver, getOperation(), false))) return signalPassFailure(); - if (llvm::failed( - wave::setWaveIndexExprAnalysisResults(getOperation(), solver))) + if (llvm::failed(wave::setWaveIndexExprAnalysisResults( + getOperation(), solver, delayedErrorInfo))) return signalPassFailure(); getOperation()->walk([&](wave::IterateOp iterateOp) { @@ -1922,21 +1952,46 @@ class InferIndexExprsPass }; } // namespace -void wave::addWaveIndexExprsAnalyses( - DataFlowSolver &solver, SymbolTableCollection &symbolTable, - wave::WaveIndexExprsAnalysisOptions options) { +wave::DelayedErrorEmitterInfo +wave::addWaveIndexExprsAnalyses(DataFlowSolver &solver, + SymbolTableCollection &symbolTable, + wave::WaveIndexExprsAnalysisOptions options) { + IndexExprsForwardAnalysis *forward = nullptr; if (!options.disableForward) { - solver.load(options.overrideInitialization); + forward = + solver.load(options.overrideInitialization); } + IndexExprsBackwardAnalysis *backward = nullptr; if (!options.disableBackward) { - solver.load(symbolTable, - options.overrideInitialization); + backward = solver.load( + symbolTable, options.overrideInitialization); } + + // Note that these lambdas are stored and used later so they must not capture + // anything that has a function-level lifetime. 
+ wave::DelayedErrorEmitterInfo delayedErrorEmitterInfo; + delayedErrorEmitterInfo.getDelayedError = + [forward, backward](Operation *op) -> wave::EmitDelayedErrorFn { + if (forward) { + if (wave::EmitDelayedErrorFn delayedError = forward->getDelayedError(op)) + return delayedError; + } + if (backward) { + return backward->getDelayedError(op); + } + return nullptr; + }; + delayedErrorEmitterInfo.hasDelayedErrors = [forward, backward]() { + return (forward && forward->hasDelayedErrors()) || + (backward && backward->hasDelayedErrors()); + }; + return delayedErrorEmitterInfo; } -LogicalResult -wave::setWaveIndexExprAnalysisResults(Operation *top, - const DataFlowSolver &solver) { +LogicalResult wave::setWaveIndexExprAnalysisResults( + Operation *top, const DataFlowSolver &solver, + const DelayedErrorEmitterInfo &delayedErrorInfo) { + bool hadFailures = false; WalkResult walkResult = top->walk([&](wave::WaveInferIndexExprsOpInterface iface) { auto getLatticeValue = [&](Value value) { @@ -1956,13 +2011,27 @@ wave::setWaveIndexExprAnalysisResults(Operation *top, descriptionGenerator(os, i); if (failed(detail::checkAndAppendIndexExpr(iface->getLoc(), getLatticeValue(value), - os.str(), indexExprs))) - return WalkResult::interrupt(); + os.str(), indexExprs))) { + // Don't stop on the first reported error if there are some delayed + // errors that would be useful to report here. We need to wait and + // see whether the operation they are attached to actually has had + // inference issues as some errors may be corrected. 
+ if (!delayedErrorInfo.hasDelayedErrors()) + return WalkResult::interrupt(); + + hadFailures = true; + if (auto delayedError = delayedErrorInfo.getDelayedError(iface)) { + InFlightDiagnostic diag = + iface->emitError() + << "the error above may be caused by the following: "; + delayedError(diag); + } + } } iface->setAttr(wave::WaveDialect::kIndexWaveExprListAttrName, ArrayAttr::get(iface->getContext(), indexExprs)); return WalkResult::advance(); }); - return llvm::failure(walkResult.wasInterrupted()); + return llvm::failure(hadFailures || walkResult.wasInterrupted()); } diff --git a/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir b/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir index e8256c754f..926038a51b 100644 --- a/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir +++ b/water/test/Dialect/Wave/infer-index-exprs-lattice.mlir @@ -18,9 +18,9 @@ normalform.module [#wave.normal_form] { #wave.hardware_constraint ] } { - %lhs_override = wave.read %lhs { wave_test.override_result_index = [{ + %lhs_override = wave.read %lhs { wave_test.override_result_index = [[3, { N = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)> - }]} : (!wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]} : (!wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> // expected-error @below {{conflict when propagating forward to the result lattice in MmaOp}} // expected-note @below {{Result lattice}} // expected-note @below {{LHS lattice}} @@ -49,9 +49,9 @@ normalform.module [#wave.normal_form] { // expected-note @below {{LHS lattice}} // expected-note @below {{result lattice}} %r = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind, - wave_test.override_result_index = [ + wave_test.override_result_index = [[3, {K = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>} - ] + ]] } : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> @@ -77,7 +77,7 @@ 
normalform.module [#wave.normal_form] { %r = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind, wave_test.override_operand_index = [ unit, - {N = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>} + [3, {N = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>}] ] } : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) @@ -105,7 +105,7 @@ normalform.module [#wave.normal_form] { wave_test.override_operand_index = [ unit, unit, - {M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>} + [3, {M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>}] ] } : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) @@ -126,18 +126,18 @@ normalform.module [#wave.normal_form] { #wave.hardware_constraint ] } { - %add = wave.add %a, %b {wave_test.override_result_index = [{ + %add = wave.add %a, %b {wave_test.override_result_index = [[1,{ M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> // expected-error @below {{conflict when propagating index expressions from result to operand #0}} // expected-note @below {{original operand lattice}} // expected-note @below {{result #0 lattice}} - %mul = wave.mul %add, %c {wave_test.override_result_index = [{ + %mul = wave.mul %add, %c {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> return %mul : !wave.tensor<[@M, @K] 
of f16> } } @@ -154,19 +154,19 @@ normalform.module [#wave.normal_form] attributes { wave_test.disable #wave.hardware_constraint ] } { - %add = wave.add %a, %b {wave_test.override_result_index = [{ + %add = wave.add %a, %b {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> // expected-error @below {{conflict when propagating index expressions from operand to result #0}} // expected-note @below {{original result lattice}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %mul = wave.mul %add, %c {wave_test.override_result_index = [{ + %mul = wave.mul %add, %c {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> return %mul : !wave.tensor<[@M, @K] of f16> } } @@ -186,13 +186,13 @@ normalform.module [#wave.normal_form] attributes { wave_test.disable // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %add = wave.add %a, %b {wave_test.override_operand_index = [{ + %add = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 32, 1, 1)>, K = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 44, 1, 1)>, K 
= #wave.index_mapping<[#wave.index_symbol] -> (T1 * 16, 1, 1)> - }]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> + }]]}: (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16> return %add : !wave.tensor<[@M, @K] of f16> } @@ -235,12 +235,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -262,12 +262,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -289,10 +289,10 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { 
M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, + }], unit // will default-initialize to bottom. ]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> @@ -316,12 +316,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (0, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -343,11 +343,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40 + 1, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -369,11 +369,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40 
+ 2, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40 + 1, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -396,11 +396,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1,{ M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 3, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 2, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -422,11 +422,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 2)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 3)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -447,12 +447,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, 1, 2)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, 1, 2) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = 
#wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 2)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -473,12 +473,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol] -> (T0 * 40, T0, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 40, T0, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, T0, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -500,11 +500,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0, T0, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (T0, WG0, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -528,11 +528,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b 
{wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 40, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T1 * 40, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -555,11 +555,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (WG0, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (WG1, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -582,12 +582,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: {M : <[#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: {M : [#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (WG0, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol, #wave.index_symbol] -> (T0, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -609,12 +609,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: {M : <[#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0 + 2, 1, 1)> - %result = 
wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: {M : [#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0 + 2, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (WG0 + 2, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0 + 2 , 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -636,12 +636,12 @@ normalform.module [#wave.normal_form] { } { // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: {M : <[#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0, 1, 1)> - %result = wave.add %a, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: {M : [#wave.index_symbol, #wave.index_symbol] -> (WG0 + T0, 1, 1) + %result = wave.add %a, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (WG0, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.index_symbol] -> (T0, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> return %result : !wave.tensor<[@M] of f32> } @@ -666,12 +666,12 @@ normalform.module [#wave.normal_form] { ^bb0(%a_arg: !wave.tensor<[@M, @K] of f32>): // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.index_symbol, #wave.iter<"K">] -> (WG0 + _Iter_K, 1, 1)> - %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.index_symbol, #wave.iter<"K">] -> (WG0 + _Iter_K, 1, 1) + %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.index_symbol] -> (WG0, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ 
-698,12 +698,12 @@ normalform.module [#wave.normal_form] { ^bb0(%a_arg: !wave.tensor<[@M, @K] of f32>): // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.iter<"K">] -> (_Iter_K + 42, 1, 1) + %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -740,12 +740,12 @@ normalform.module [#wave.normal_form] { ^bb1(%a_arg: !wave.tensor<[@M, @K] of f32>): // CHECK: wave.add // CHECK-SAME: index - // CHECK-SAME: M : <[#wave.iter<"K">, #wave.iter<"M">] -> (_Iter_K + _Iter_M, 1, 1)> - %partial_result = wave.add %a_arg, %b_arg {wave_test.override_operand_index = [{ + // CHECK-SAME: M : [#wave.iter<"K">, #wave.iter<"M">] -> (_Iter_K + _Iter_M, 1, 1) + %partial_result = wave.add %a_arg, %b_arg {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.iter<"M">] -> (_Iter_M, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -776,11 +776,11 @@ normalform.module [#wave.normal_form] { // expected-error @below {{incompatible operand lattices when propagating from those to result}} // expected-note @below {{operand #0 lattice}} // expected-note @below {{operand #1 lattice}} - %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [{ + %partial_result = wave.add %a_arg, %b {wave_test.override_operand_index = [[1, { M = 
#wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K * 2, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -808,11 +808,11 @@ normalform.module [#wave.normal_form] { %b_reg = wave.read %b : (!wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> %result = wave.iterate @K iter_args(%a) { ^bb0(%a_arg: !wave.tensor<[@M, @K] of f32>): - %partial_result = wave.add %a_arg, %b_reg {wave_test.override_operand_index = [{ + %partial_result = wave.add %a_arg, %b_reg {wave_test.override_operand_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }, { + }], [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (_Iter_K + 42, 1, 1)> - }]} + }]]} : (!wave.tensor<[@M, @K] of f32>, !wave.tensor<[@M, @K] of f32>) -> !wave.tensor<[@M, @K] of f32> wave.yield %partial_result : !wave.tensor<[@M, @K] of f32> @@ -833,19 +833,21 @@ normalform.module [#wave.normal_form] { %c: !wave.tensor<[@M] of f32> ) attributes { wave.constraints = [ - #wave.hardware_constraint - ] + #wave.hardware_constraint, + #wave.workgroup_constraint, tile_size = <[] -> (42)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 1024}> } { // CHECK: index // CHECK: M : <[] -> (42, 1, )> wave.write %a, %b {wave_test.override_operand_index = [ - unit, { + unit, [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (, 1, )> - }] + }]] } : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> - %c_reg = wave.read %c {wave_test.override_result_index = [{ + %c_reg = wave.read %c {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (42, , )> - }]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + }]]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> wave.write %c_reg, %b : !wave.tensor<[@M] of f32>, 
!wave.tensor<[@M] of f32> return } @@ -853,6 +855,109 @@ normalform.module [#wave.normal_form] { // ----- +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @priority_join + func.func @priority_join( + %a: !wave.tensor<[@M] of f32>, + %b: !wave.tensor<[@M] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint + ] + } { + // Low priority lattice. + %a_reg = wave.read %a {wave_test.override_result_index = [[0, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 10, 1, 1)> + }]]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + // Higher priority lattice on the other operand. Selected without causing a conflict. + // CHECK: wave.add + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 20, 1, 1) + %sum = wave.add %a_reg, %b {wave_test.override_operand_index = [ + unit, + [3, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 20, 1, 1)> + }] + ]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + return + } +} + +// ----- + +normalform.module [#wave.normal_form] attributes { wave_test.disable_backward } { + func.func @same_priority_conflict( + %a: !wave.tensor<[@M] of f32>, + %b: !wave.tensor<[@M] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint + ] + } { + // Both operands have same priority but different expressions - should conflict + // expected-error @below {{incompatible operand lattices when propagating from those to result}} + // expected-note @below {{operand #0 lattice}} + // expected-note @below {{operand #1 lattice}} + %sum = wave.add %a, %b {wave_test.override_operand_index = [[1, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 10, 1, 1)> + }], [1, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 20, 1, 1)> + }]]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + return + } +} + +// ----- + +// Test that higher priority from write propagates backward through multiple operations. 
+ +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @priority_backward_through_chain + func.func @priority_backward_through_chain( + %a: !wave.tensor<[@M] of f32>, + %b: !wave.tensor<[@M] of f32>, + %output: !wave.tensor<[@M] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint + ] + } { + // CHECK: wave.read + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 30, 1, 1) + %a_reg = wave.read %a : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + // The override has low priority, it only matters at initialization and is + // then itself overridden by the higher-priority lattice in backpropagation. + // CHECK: wave.add + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 30, 1, 1) + %sum1 = wave.add %a_reg, %b {wave_test.override_operand_index = [ + unit, + [0, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 5, 1, 1)> + }] + ]} : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + // CHECK: wave.add + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 30, 1, 1) + %sum2 = wave.add %sum1, %a_reg : (!wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + + // CHECK: wave.write + // CHECK-SAME: M : [#wave.index_symbol] -> (T0 * 30, 1, 1) + wave.write %sum2, %output {wave_test.override_operand_index = [ + unit, + [3, { + M = #wave.index_mapping<[#wave.index_symbol] -> (T0 * 30, 1, 1)> + }] + ]} : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> + + return + } +} + +// ----- + // Check that sideways propagation between operands of a write that would // lead to a conflict is not happening. 
@@ -864,23 +969,26 @@ normalform.module [#wave.normal_form] { %c: !wave.tensor<[@M] of f32> ) attributes { wave.constraints = [ - #wave.hardware_constraint - ] + #wave.hardware_constraint, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M floordiv 4)>> + ], + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64 : i64, M = 128}> } { // CHECK: wave.write // CHECK: index // CHECK: M : <[] -> (1, , )> wave.write %a, %b {wave_test.override_operand_index = [ - unit, { + unit, [1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (1, , )> - }] + }]] } : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> // CHECK: wave.read // CHECK: index - // CHECK: M : <[] -> (42, , )> - %c_reg = wave.read %c {wave_test.override_result_index = [{ + // CHECK: M : [] -> (42, , ) + %c_reg = wave.read %c {wave_test.override_result_index = [[1, { M = #wave.index_mapping<[#wave.iter<"K">] -> (42, , )> - }]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> + }]]} : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32> wave.write %c_reg, %b : !wave.tensor<[@M] of f32>, !wave.tensor<[@M] of f32> return } diff --git a/water/test/Dialect/Wave/infer-index-exprs.mlir b/water/test/Dialect/Wave/infer-index-exprs.mlir index 6e1af3d240..eec60b399b 100644 --- a/water/test/Dialect/Wave/infer-index-exprs.mlir +++ b/water/test/Dialect/Wave/infer-index-exprs.mlir @@ -633,6 +633,7 @@ normalform.module [#wave.normal_form] { ]} { // expected-error @below {{failed to infer index expressions for value to store}} + // expected-error @below {{the error above may be caused by the following: couldn't find vector shapes in the contiguity check}} wave.write %src, %dst : !wave.tensor<[@M, @N] of f32>, !wave.tensor<[@M, @N] of f32, > return @@ -963,3 +964,248 @@ normalform.module [#wave.normal_form] { return } } + +// ----- + +normalform.module [#wave.normal_form] { + // 
CHECK-LABEL: @propagate_from_write + func.func @propagate_from_write( + %a: !wave.tensor<[@M, @N] of f32>, + %b: !wave.tensor<[@M, @N] of f32>, + %output: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + // CHECK: wave.read + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %a_reg = wave.read %a : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + // CHECK: wave.read + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %b_reg = wave.read %b : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.add + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %sum = wave.add %a_reg, %b_reg : (!wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: 
N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %sum, %output : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} + +// ----- + +// Elements per thread provided on the op used instead of the value inferred from workgroup constraints. +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @propagate_from_write_explicit_ept + func.func @propagate_from_write_explicit_ept( + %output: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + %cst = arith.constant 0.0 : f32 + %reg = wave.register %cst : !wave.tensor<[@M, @N] of f32, > + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * 8 + WG0 * BLOCK_M, 8, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %reg, %output {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} + +// ----- + +// Elements per thread is used for the trailing dimension because its vector shape is no longer 1. 
+normalform.module [#wave.normal_form] { + // CHECK-LABEL: @propagate_from_write_vector_shape + func.func @propagate_from_write_vector_shape( + %output: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 16 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + %cst = arith.constant 0.0 : f32 + %reg = wave.register %cst : !wave.tensor<[@M, @N] of f32, > + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (T0 mod 64 + WG0 * BLOCK_M, 1, 16)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 8, 1)> + wave.write %reg, %output {elements_per_thread = 8} : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} + +// ----- + +// Test that unmapped dimensions get default (0, 1, 1) index expressions +// when there are no workgroup/wave/tiling constraints for them. 
+ +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @unmapped_dimension_default + func.func @unmapped_dimension_default( + %a: !wave.tensor<[@B, @M, @N] of f32>, + %output: !wave.tensor<[@B, @M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{B = 8, M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + // Read should get default index expression for B dimension (no constraints) + // and computed expressions for M and N dimensions + // CHECK: wave.read + // CHECK-DAG: B : <[] -> (0, 1, 1)> + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %a_reg = wave.read %a : (!wave.tensor<[@B, @M, @N] of f32>) -> !wave.tensor<[@B, @M, @N] of f32, > + + // Write should preserve the same index expressions + // CHECK: wave.write + // CHECK-DAG: B : <[] -> (0, 1, 1)> + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %a_reg, %output : !wave.tensor<[@B, @M, @N] of f32, >, !wave.tensor<[@B, @M, @N] of f32> + + return + } +} + +// ----- + +// Test priority-based propagation with multiple write operations. +// All writes should establish index expressions with the same priority, +// and the join should succeed since they agree. 
+ +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @multiple_writes_consistent + func.func @multiple_writes_consistent( + %a: !wave.tensor<[@M, @N] of f32>, + %b: !wave.tensor<[@M, @N] of f32>, + %out1: !wave.tensor<[@M, @N] of f32>, + %out2: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = > + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + // CHECK: wave.read + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %a_reg = wave.read %a : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.read + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %b_reg = wave.read %b : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + + // CHECK: wave.add + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + %sum = wave.add %a_reg, %b_reg : (!wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > + + // Both writes establish the same index expressions + // CHECK: wave.write + // CHECK-DAG: M : 
<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %sum, %out1 : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> ((T0 mod 64) * (BLOCK_M ceildiv 64) + WG0 * BLOCK_M, BLOCK_M ceildiv 64, 1)> + // CHECK-DAG: N : <[#wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N, 1, 1)> + wave.write %a_reg, %out2 : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} + +// ----- + +// Test write when all dimension symbols are absent from constraints. +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @write_all_dimensions_unmapped + func.func @write_all_dimensions_unmapped( + %a: !wave.tensor<[@P, @Q] of f32>, + %output: !wave.tensor<[@P, @Q] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {P = 1 : i64, Q = 1 : i64}> + ], + wave.hyperparameters = #wave.hyperparameters<{P = 8 : i64, Q = 16 : i64}> + } { + // CHECK: wave.read + // CHECK-DAG: P : <[] -> (0, 1, 1)> + // CHECK-DAG: Q : <[] -> (0, 1, 1)> + %a_reg = wave.read %a : (!wave.tensor<[@P, @Q] of f32>) -> !wave.tensor<[@P, @Q] of f32, > + + // CHECK: wave.write + // CHECK-DAG: P : <[] -> (0, 1, 1)> + // CHECK-DAG: Q : <[] -> (0, 1, 1)> + wave.write %a_reg, %output : !wave.tensor<[@P, @Q] of f32, >, !wave.tensor<[@P, @Q] of f32> + + return + } +} + + +// ----- + +// MMa index expression has higher priority than write. 
+ +normalform.module [#wave.normal_form] { + // CHECK-LABEL: @write_after_mma_priority + func.func @write_after_mma_priority( + %a: !wave.tensor<[@M, @K] of f16>, + %b: !wave.tensor<[@N, @K] of f16>, + %c: !wave.tensor<[@M, @N] of f32>, + %output: !wave.tensor<[@M, @N] of f32> + ) attributes { + wave.constraints = [ + #wave.hardware_constraint, vector_shapes = {M = 4 : i64, N = 1 : i64}>, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = >, + #wave.workgroup_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = >, + #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>>, + #wave.wave_constraint, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>> + ], + wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128, K = 64, BLOCK_M = 64 : i64, BLOCK_N = 64 : i64}> + } { + %a_reg = wave.read %a : (!wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16, > + %b_reg = wave.read %b : (!wave.tensor<[@N, @K] of f16>) -> !wave.tensor<[@N, @K] of f16, > + %mma = wave.mma %a_reg, %b_reg, %c {kind = #wave.mma_kind} + : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32, > + // Write in isolation would have inferred step=floordiv(BLOCK_M, 64) since M is mapped + // to workgroup X with 64 threads in it, but we obtain step=4 propagate from the mma + // above, because that has higher priority. 
+ // CHECK: wave.write + // CHECK-DAG: M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (((T0 mod 64) floordiv 16) * 4 + WG0 * BLOCK_M + // CHECK-DAG: N : <[#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (T0 mod 16 + WG1 * BLOCK_N + wave.write %mma, %output : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + + return + } +} diff --git a/water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp b/water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp index e46a4a0195..ff1131f718 100644 --- a/water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp +++ b/water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp @@ -38,21 +38,25 @@ overrideInitialization(Operation *top, continue; if (auto strAttr = llvm::dyn_cast(attr); strAttr && strAttr.getValue() == "") { - setIndexForValue(value, nullptr); + setIndexForValue(value, nullptr, + wave::IndexExprsLatticeStorage::kHighestPriority); continue; } - auto dict = llvm::dyn_cast(attr); - if (!dict || llvm::any_of(dict.getValue(), [](NamedAttribute attr) { - return !llvm::isa(attr.getValue()); - })) { - return op->emitError() - << "expected " << attributeName - << " to be an array of " - "dictionaries with WaveIndexMappingAttr or UnitAttr values"; + auto array = llvm::dyn_cast(attr); + auto priority = (array && !array.empty()) + ? llvm::dyn_cast(array.getValue()[0]) + : IntegerAttr(); + auto dict = (array && array.size() > 1) + ? 
llvm::dyn_cast(array.getValue()[1]) + : DictionaryAttr(); + if (!priority || !dict) { + return op->emitError() << "expected " << attributeName + << " to be an array containing an integer " + "priority and a dictionary mapping"; } - setIndexForValue(value, dict); + setIndexForValue(value, dict, priority.getInt()); } return success(); }; @@ -100,7 +104,8 @@ class TestWaveDialectInferIndexExprsPass options.disableForward = getOperation()->getAttrOfType( "wave_test.disable_forward") != nullptr; options.overrideInitialization = overrideInitialization; - addWaveIndexExprsAnalyses(solver, symbolTable, options); + wave::DelayedErrorEmitterInfo delayedErrorInfo = + wave::addWaveIndexExprsAnalyses(solver, symbolTable, options); IRRewriter rewriter(&getContext()); getOperation()->walk( @@ -109,7 +114,8 @@ class TestWaveDialectInferIndexExprsPass if (failed(wave::runSolverAndCaptureErrors(solver, getOperation(), false))) return signalPassFailure(); - if (failed(setWaveIndexExprAnalysisResults(getOperation(), solver))) + if (failed(setWaveIndexExprAnalysisResults(getOperation(), solver, + delayedErrorInfo))) return signalPassFailure(); getOperation()->walk([&](wave::IterateOp iterateOp) { diff --git a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py index 4c72b29663..26a0961fd2 100644 --- a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py +++ b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py @@ -21,7 +21,7 @@ from wave_lang.kernel.wave.compile_options import WaveCompileOptions from wave_lang.support.logging import get_logger -from ..._support.indexing import IndexSequence, IndexSymbol +from ..._support.indexing import IndexSequence, IndexSymbol, IndexExpr from ..._support.tracing import CapturedTrace from ...lang.global_symbols import * from ...ops.wave_ops import ( @@ -328,18 +328,72 @@ def _check_water_indices(trace: CapturedTrace, inferred: dict[str, IndexSequence if isinstance(custom, 
GetResult): continue + # Assumptions on symbols may be insufficiently tight, in particular for + # non-counting symbols we can assume strict positivity, which allows us + # to get rid of Max(1, ...) expressions that otherwise appear on the + # python side. + def ensure_symbols_positive( + seqs: dict[IndexSymbol, IndexSequence], + ) -> dict[IndexSymbol, IndexSequence]: + all_symbols = set() + for _, seq in seqs.items(): + if isinstance(seq.start, sympy.Expr): + all_symbols.update(seq.start.free_symbols) + if isinstance(seq.size, sympy.Expr): + all_symbols.update(seq.size.free_symbols) + if isinstance(seq.stride, sympy.Expr): + all_symbols.update(seq.stride.free_symbols) + + symbol_remapping = { + symbol: ( + sympy.Symbol(symbol.name, nonnegative=True, integer=True) + if symbol.name.startswith("$") + else sympy.Symbol(symbol.name, positive=True, integer=True) + ) + for symbol in all_symbols + } + return { + dim: IndexSequence( + start=( + sympy.simplify( + seq.start.subs(symbol_remapping, simultaneous=True) + ) + if isinstance(seq.start, IndexExpr) + else seq.start + ), + size=( + sympy.simplify( + seq.size.subs(symbol_remapping, simultaneous=True) + ) + if isinstance(seq.size, IndexExpr) + else seq.size + ), + stride=( + sympy.simplify( + seq.stride.subs(symbol_remapping, simultaneous=True) + ) + if isinstance(seq.stride, IndexExpr) + else seq.stride + ), + ) + for dim, seq in node.index.items() + } + + node_index = ensure_symbols_positive(node.index) + inferred_index = ensure_symbols_positive(inferred_index) + # Check that that indices match, raise an error if they don't. Start by # a trivial direct comparison, fall back to computing and simplifying # the difference. The latter can raise with additional information, # which this wants to preserve. 
try: - if node.index != inferred_index and not _check_index_difference_is_zero( - node.index, inferred_index + if node_index != inferred_index and not _check_index_difference_is_zero( + node_index, inferred_index ): raise ValueError("mismatching indices") except ValueError as e: raise RuntimeError( - f"Index for node {get_custom(node)}, {get_custom(node).index} does not match inferred index {inferred_index}." + f"Index for node {get_custom(node)}, {node_index} does not match inferred index {inferred_index}." ) from e