From b24c0ca09757a808ae976aa2f6c393e6bb3a0ead Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 5 Mar 2026 13:29:45 +0100 Subject: [PATCH] [water] verify index expr step/strire are positive Add a verifier that index expression step and stride are strictly positive to avoid semantically undefined behavior (what does it mean to have a 0-element piece of the tensor? stride zero is hidden broadcasting, negative stride is hidden reversal, both are deterimental to dependence analysis). Closes #1012. Signed-off-by: Alex Zinenko --- water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | 63 +++++++++++++------ .../Dialect/Wave/Transforms/InferTypes.cpp | 10 +-- water/test/Dialect/Wave/ops-invalid.mlir | 40 ++++++++++++ .../Wave/propagate-elements-per-thread.mlir | 14 ----- 4 files changed, 86 insertions(+), 41 deletions(-) diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index 7125e466f..eb6b459a5 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -99,34 +99,57 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { } } - // For ops with the index attribute, verify that each index expression has at - // most one dimension whose step evaluates to a static value different from 1 - // (with hyperparameters substituted). Structural checks stay in op verifiers. + // For ops with the index attribute, verify that (1) each index expression has + // at most one dimension whose step evaluates to a static value different from + // 1 (with hyperparameters substituted), and (2) when step or stride can be + // evaluated to a concrete value, that value is strictly positive. Be + // defensive because we may not have verified anything but the basic + // well-formedness yet, e.g., the op verifier checking for single-result + // affine expressions in mappings did not run yet. wave::WaveHyperparameterAttr hyperparams = wave::getHyperparameters(op); for (DictionaryAttr dictAttr : dicts) { int nonUnitCount = 0; for (const NamedAttribute &named : dictAttr) { auto mapping = dyn_cast(named.getValue()); - if (!mapping || !mapping.getStep()) + if (!mapping) continue; - std::optional> stepValues = - wave::evaluateMapWithHyperparams(mapping.getStep(), - mapping.getSymbols(), hyperparams); - if (!stepValues || stepValues->size() != 1) - continue; - - int64_t step = (*stepValues)[0]; - if (step == 1 || step == ShapedType::kDynamic) - continue; + if (AffineMap stepMap = mapping.getStep()) { + std::optional> stepValues = + wave::evaluateMapWithHyperparams(stepMap, mapping.getSymbols(), + hyperparams); + if (stepValues && stepValues->size() == 1) { + int64_t step = (*stepValues)[0]; + if (step != ShapedType::kDynamic && step <= 0) { + return op->emitOpError() + << "step in index expression must be strictly positive, got " + << step << " for dimension " << named.getName(); + } + if (step != 1 && step != ShapedType::kDynamic && ++nonUnitCount > 1) { + InFlightDiagnostic diag = + op->emitOpError() + << "'" << WaveDialect::kIndexWaveExprListAttrName + << "' has more than one entry with non-unit step"; + diag.attachNote() + << "second non-unit step dimension: " << named.getName(); + return failure(); + } + } + } - if (++nonUnitCount > 1) { - InFlightDiagnostic diag = - op->emitOpError() << "'" << WaveDialect::kIndexWaveExprListAttrName - << "' has more than one entry with non-unit step"; - diag.attachNote() << "second non-unit step dimension: " - << named.getName(); - return failure(); + if (AffineMap strideMap = mapping.getStride()) { + std::optional> strideValues = + wave::evaluateMapWithHyperparams(strideMap, mapping.getSymbols(), + hyperparams); + if (strideValues && strideValues->size() == 1) { + int64_t stride = (*strideValues)[0]; + if (stride != ShapedType::kDynamic && stride <= 0) { + return op->emitOpError() + << "stride in index expression must be strictly positive, " + "got " + << stride << " for dimension " << named.getName(); + } + } } } } diff --git a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp index 61362d1a5..3d0a8895a 100644 --- a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp +++ b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp @@ -818,14 +818,10 @@ class ElementsPerThreadForwardAnalysis wave::evaluateMapWithHyperparams(step, symbols, init.hyperparams); if (!stepValues) continue; - // TODO(#1012): turn this into an assertion when the verifier is - // implemented. + int64_t stepValue = (*stepValues)[0]; - if (stepValue <= 0) { - op->emitError() << "expected positive step in index expressions " - "(missing verifier)"; - return WalkResult::interrupt(); - } + assert(stepValue > 0 && + "expected positive step in index expressions"); // Elements per thread may be 1 if _all_ dimensions have a unit step, // otherwise it should be the one non-unit step. diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index 0a2132509..d49725bff 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -669,6 +669,46 @@ func.func @write_index_multi_step_eval(%val: !wave.tensor<[@M, @N] of f32, ) attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 128}> +} { + // expected-error @below {{step in index expression must be strictly positive, got 0 for dimension "M"}} + wave.read %mem index [{M : <[] -> (0, 0, 1)>}] : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32, > + return +} + +// ----- + +func.func @read_index_negative_step(%mem: !wave.tensor<[@M] of f32>) attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 128}> +} { + // expected-error @below {{step in index expression must be strictly positive, got -1 for dimension "M"}} + wave.read %mem index [{M : <[] -> (0, -1, 1)>}] : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32, > + return +} + +// ----- + +func.func @read_index_zero_stride(%mem: !wave.tensor<[@M] of f32>) attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 128}> +} { + // expected-error @below {{stride in index expression must be strictly positive, got 0 for dimension "M"}} + wave.read %mem index [{M : <[] -> (0, 1, 0)>}] : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32, > + return +} + +// ----- + +func.func @read_index_negative_stride(%mem: !wave.tensor<[@M] of f32>) attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 128}> +} { + // expected-error @below {{stride in index expression must be strictly positive, got -1 for dimension "M"}} + wave.read %mem index [{M : <[] -> (0, 1, -1)>}] : (!wave.tensor<[@M] of f32>) -> !wave.tensor<[@M] of f32, > + return +} + +// ----- + func.func @extract_invalid_position_rank(%src: !wave.tensor<[@M, @N] of f32>) { // expected-error @below {{position must contain exactly one expression, but got 2}} wave.extract %src[#wave.expr_list<[] -> (0, 1)>] : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M] of f32> diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir index cc680e928..0d92ab29c 100644 --- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir +++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir @@ -134,20 +134,6 @@ normalform.module [#wave.normal_form] { // ----- -// Step is zero; pass must report "expected positive step". -normalform.module [#wave.normal_form] { - func.func @index_step_zero(%mem: !wave.tensor<[@M] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>, wave.constraints = []} { - %cst = arith.constant 0.0 : f16 - // expected-error @below {{expected positive step in index expressions}} - %reg = wave.register %cst index [{M : <[] -> (, 0, )>}] : !wave.tensor<[@M] of f16, > - wave.write %reg, %mem index [{M : <[] -> (, 0, )>}] - : !wave.tensor<[@M] of f16, >, !wave.tensor<[@M] of f16, > - return - } -} - -// ----- - // Index missing dimension N for result type [M, N]; pass must report missing dimensions. normalform.module [#wave.normal_form] { func.func @index_missing_dimension(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} {