diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h index 2d2373d12f..256d7b2af2 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h @@ -26,6 +26,10 @@ using EmitErrorFn = llvm::function_ref; class WaveTensorType; +/// Get the hyperparameters from an ancestor operation. +/// Returns nullptr if no hyperparameters are found. +WaveHyperparameterAttr getHyperparameters(mlir::Operation *op); + //----------------------------------------------------------------------------- // HasWaveIndexMapping trait //----------------------------------------------------------------------------- diff --git a/water/include/water/Dialect/Wave/Transforms/Utils.h b/water/include/water/Dialect/Wave/Transforms/Utils.h index 4d3cc3692f..e57195160a 100644 --- a/water/include/water/Dialect/Wave/Transforms/Utils.h +++ b/water/include/water/Dialect/Wave/Transforms/Utils.h @@ -11,10 +11,6 @@ namespace wave { -/// Get the hyperparameters from an ancestor operation. -/// Returns nullptr if no hyperparameters are found. -WaveHyperparameterAttr getHyperparameters(mlir::Operation *op); - // Populates `constraints` with a mapping from an operation with a Wave // constraints attribute attached to that attribute. llvm::LogicalResult collectWaveConstraints( diff --git a/water/lib/Dialect/Wave/IR/WaveDialect.cpp b/water/lib/Dialect/Wave/IR/WaveDialect.cpp index 7a7e7218c0..2b525b222f 100644 --- a/water/lib/Dialect/Wave/IR/WaveDialect.cpp +++ b/water/lib/Dialect/Wave/IR/WaveDialect.cpp @@ -14,14 +14,13 @@ #include "mlir/IR/Dialect.h" #include "water/Dialect/Wave/IR/WaveDialect.cpp.inc" +#include "water/Dialect/Wave/IR/WaveInterfaces.h" #include "water/Dialect/Wave/IR/WaveUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/LogicalResult.h" -#include #include using namespace mlir; diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index f96aadeb6f..7125e466fa 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -23,6 +23,19 @@ using namespace mlir; +//----------------------------------------------------------------------------- +// getHyperparameters +//----------------------------------------------------------------------------- + +wave::WaveHyperparameterAttr wave::getHyperparameters(Operation *op) { + for (Operation *current = op; current; current = current->getParentOp()) { + if (auto hyperparams = current->getAttrOfType( + WaveDialect::kHyperparameterAttrName)) + return hyperparams; + } + return nullptr; +} + //----------------------------------------------------------------------------- // Index attribute verification //----------------------------------------------------------------------------- @@ -86,6 +99,38 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { } } + // For ops with the index attribute, verify that each index expression has at + // most one dimension whose step evaluates to a static value different from 1 + // (with hyperparameters substituted). Structural checks stay in op verifiers. + wave::WaveHyperparameterAttr hyperparams = wave::getHyperparameters(op); + for (DictionaryAttr dictAttr : dicts) { + int nonUnitCount = 0; + for (const NamedAttribute &named : dictAttr) { + auto mapping = dyn_cast(named.getValue()); + if (!mapping || !mapping.getStep()) + continue; + + std::optional> stepValues = + wave::evaluateMapWithHyperparams(mapping.getStep(), + mapping.getSymbols(), hyperparams); + if (!stepValues || stepValues->size() != 1) + continue; + + int64_t step = (*stepValues)[0]; + if (step == 1 || step == ShapedType::kDynamic) + continue; + + if (++nonUnitCount > 1) { + InFlightDiagnostic diag = + op->emitOpError() << "'" << WaveDialect::kIndexWaveExprListAttrName + << "' has more than one entry with non-unit step"; + diag.attachNote() << "second non-unit step dimension: " + << named.getName(); + return failure(); + } + } + } + // When the operation implements WaveInferIndexExprsOpInterface, the index // attribute length must match the number of values from // getIndexExprValuesAndDescriptions. Otherwise, default to the number of op diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index aa5cbb5d70..d87ad5808e 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1572,9 +1572,8 @@ ReadOp::propagateBackward(MutableArrayRef operandTypes, LogicalResult ReadOp::finalizeTypeInference() { return success(); } -// Check the well-formedness of the index attribute (must have at most one -// non-unit dimension) and its correspondence with the explicit elements per -// thread, if provided, and with the number of elements in the vector type. +// Check the correspondence of the index attribute with the explicit elements +// per thread, if provided, and with the number of elements in the vector type. static LogicalResult verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, std::optional elementsPerThread, @@ -1605,7 +1604,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, if (!indexDict) return success(); - wave::WaveHyperparameterAttr hyper = wave::WaveHyperparameterAttr(); + wave::WaveHyperparameterAttr hyper = nullptr; for (Operation *cur = op; cur != nullptr && !hyper; cur = cur->getParentOp()) { hyper = cur->getAttrOfType( @@ -1621,7 +1620,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, getUncollapsedVectorShape(tensorType.getShape(), indexDict, hyper); int64_t nonUnit = 1; bool hadDynamic = false; - for (auto [i, size] : llvm::enumerate(shape)) { + for (int64_t size : shape) { if (ShapedType::isDynamic(size)) { hadDynamic = true; continue; @@ -1632,13 +1631,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, } if (nonUnit == 1) { nonUnit = size; - continue; } - - InFlightDiagnostic diag = - op->emitError() << "'index' has more than one entry with non-unit step"; - diag.attachNote() << "second non-unit step dimension: " << i; - return diag; } // If there were unevaluated steps, they may end up matching later on. diff --git a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp index cc8126c53d..61362d1a5e 100644 --- a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp +++ b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp @@ -829,18 +829,11 @@ class ElementsPerThreadForwardAnalysis // Elements per thread may be 1 if _all_ dimensions have a unit step, // otherwise it should be the one non-unit step. - // TODO(#1013): this logic can be reused in the verifier. - if (!elementsPerThread.has_value()) { - elementsPerThread = (*stepValues)[0]; - } else if (*elementsPerThread == 1) { - elementsPerThread = (*stepValues)[0]; - } else if (stepValue != 1) { - // TODO(#1013): turn this into an assertion when the verifier is - // implemented. - op->emitError() << "expected only one non-unit index step, found " - << (*stepValues)[0] << " and " << *elementsPerThread - << " (missing verifier)"; - return WalkResult::interrupt(); + assert((!elementsPerThread.has_value() || *elementsPerThread == 1 || + stepValue == 1) && + "expected only one non-unit index step"); + if (!elementsPerThread.has_value() || *elementsPerThread == 1) { + elementsPerThread = stepValue; } } diff --git a/water/lib/Dialect/Wave/Transforms/Utils.cpp b/water/lib/Dialect/Wave/Transforms/Utils.cpp index 25003e33e0..7234ff5405 100644 --- a/water/lib/Dialect/Wave/Transforms/Utils.cpp +++ b/water/lib/Dialect/Wave/Transforms/Utils.cpp @@ -21,15 +21,6 @@ using namespace mlir; -wave::WaveHyperparameterAttr wave::getHyperparameters(Operation *op) { - for (Operation *current = op; current; current = current->getParentOp()) { - if (auto hyperparams = current->getAttrOfType( - WaveDialect::kHyperparameterAttrName)) - return hyperparams; - } - return nullptr; -} - llvm::LogicalResult wave::collectWaveConstraints( Operation *top, llvm::DenseMap &constraints) { auto *waveDialect = top->getContext()->getLoadedDialect(); diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index da3ba333b1..0a21325099 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -602,9 +602,11 @@ func.func @bounds_wrong_rank(%mem: !wave.tensor<[@N] of f32>) { // ----- -func.func @read_index_multi_step(%mem: !wave.tensor<[@M, @N] of f32>) { +func.func @read_index_multi_step(%mem: !wave.tensor<[@M, @N] of f32>) attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 1, N = 1}> +} { // expected-error @below {{'index' has more than one entry with non-unit step}} - // expected-note @below {{second non-unit step dimension: 1}} + // expected-note @below {{second non-unit step dimension: "N"}} wave.read %mem index [{ M : <[#wave.index_symbol] -> (T0, 2, 1)>, N : <[#wave.index_symbol] -> (T1, 2, 1)> @@ -643,7 +645,7 @@ func.func @read_index_multi_step_eval(%mem: !wave.tensor<[@M, @N] of f32>) attri wave.hyperparameters = #wave.hyperparameters<{X = 1, Y = 1, M = 100, N = 200}> } { // expected-error @below {{'index' has more than one entry with non-unit step}} - // expected-note @below {{second non-unit step dimension: 1}} + // expected-note @below {{second non-unit step dimension: "N"}} wave.read %mem index [{ M : <[#wave.index_symbol, #wave.symbol<"X">] -> (T0, 2 * X, 1)>, N : <[#wave.index_symbol, #wave.symbol<"X">, #wave.symbol<"Y">] -> (T1, X + Y, 1)> @@ -653,6 +655,20 @@ func.func @read_index_multi_step_eval(%mem: !wave.tensor<[@M, @N] of f32>) attri // ----- +func.func @write_index_multi_step_eval(%val: !wave.tensor<[@M, @N] of f32, >, %mem: !wave.tensor<[@M, @N] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{X = 1, Y = 1, M = 100, N = 200}> +} { + // expected-error @below {{'index' has more than one entry with non-unit step}} + // expected-note @below {{second non-unit step dimension: "N"}} + wave.write %val, %mem index [{ + M : <[#wave.index_symbol, #wave.symbol<"X">] -> (T0, 2 * X, 1)>, + N : <[#wave.index_symbol, #wave.symbol<"X">, #wave.symbol<"Y">] -> (T1, X + Y, 1)> + }] : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, > + return +} + +// ----- + func.func @extract_invalid_position_rank(%src: !wave.tensor<[@M, @N] of f32>) { // expected-error @below {{position must contain exactly one expression, but got 2}} wave.extract %src[#wave.expr_list<[] -> (0, 1)>] : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M] of f32> diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir index f3ab0fd78a..cc680e9280 100644 --- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir +++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir @@ -148,21 +148,6 @@ normalform.module [#wave.normal_form] { // ----- -// Two dimensions with non-unit steps; pass must report "expected only one non-unit". -// Use only register (write has its own verifier for multi-step index). -normalform.module [#wave.normal_form] { - func.func @index_multi_non_unit_step(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} { - %cst = arith.constant 0.0 : f16 - // expected-error @below {{expected only one non-unit index step}} - %reg = wave.register %cst index [{M : <[] -> (, 4, )>, N : <[] -> (, 8, )>}] : !wave.tensor<[@M, @N] of f16, > - wave.write %reg, %mem index [{M : <[] -> (, 4, )>, N : <[] -> (, 1, )>}] - : !wave.tensor<[@M, @N] of f16, >, !wave.tensor<[@M, @N] of f16, > - return - } -} - -// ----- - // Index missing dimension N for result type [M, N]; pass must report missing dimensions. normalform.module [#wave.normal_form] { func.func @index_missing_dimension(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} {