[water] priority-based index expression propagation #734
ftynse wants to merge 1 commit into users/ftynse/default-index-expr from
Conversation
```cpp
// TODO: pywave just ignores this not sure if we want to, including the
// case below where there may be zero constraints. Interestingly, it
// asserts if trailing dimensions are not found when computing the
// stride...
```
Is it safe to simply ignore the symbols for which there are no constraints when setting index sequences from write?
If no constraints are specified or the vector shape is not set to 0 (dimensions we don't want to expand), then the symbol either corresponds to the actual tensor dimension or is set dynamically in the kernel. I don't think we should ignore the symbol because it could be meaningful in the analysis.
```cpp
emitError() << "expected a single workgroup constraint for dimension "
            << tensorType.getShape()[i]
            << " used in the write operation without explicit "
               "`elements_per_thread`";
return failure();
```
Ditto, but in the absence of a workgroup constraint?
It feels like we need to set it to start=0, and likely size=1 and stride=1, but I'm not sure.
I think the code does make some assumptions like that, where it falls back to start = 0 and a size and stride of 1, but I think we shouldn't allow that and should instead be more explicit.
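To make the distinction concrete, here is a minimal sketch of the two policies under discussion (all names are invented for illustration; this is not the project's code): a silent fallback to `(0, 1, 1)` versus an explicit error when a dimension has no workgroup constraint.

```python
def index_sequence(dim, workgroup_constraints, *, allow_fallback):
    """Return a (start, size, stride) triple for `dim`."""
    if dim in workgroup_constraints:
        return workgroup_constraints[dim]
    if allow_fallback:
        # Silent default the code is suspected to assume.
        return (0, 1, 1)
    # Being explicit instead: surface the missing constraint to the user.
    raise ValueError(f"expected a workgroup constraint for dimension {dim}")

constraints = {"M": (0, 4, 1)}
print(index_sequence("M", constraints, allow_fallback=False))  # (0, 4, 1)
print(index_sequence("N", constraints, allow_fallback=True))   # (0, 1, 1)
```

The explicit variant trades convenience for a diagnostic that points at the actual missing constraint, which is the direction the review argues for.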
```cpp
// TODO: in pywave, we always do `startExpr % threadsPerWave` where
// threadsPerWave == 1 for workgroup dims other than X, making it
// always zero. It mentions an assumption about the (64, 1, 1) thread
// shape, but it is unclear whether that assumption always holds.
// It looks like the intention for this was to express lane ID rather
// than thread ID, but it is unclear how it accounts for multiple
// wavefronts running in parallel.
```
The comment in the original source (wave/wave_lang/kernel/wave/constraints.py, lines 498 to 501 in 601ab68)
This comes up in the SIMT context (no MMA; you can also see this in the example for the atomic case). If you look at the original code, what was happening was that, because we don't have an MMA, the default pattern for SIMT is a thread-linear pattern, so for the atomicAdd we were getting a dependence on x and y even though that shouldn't be the case for this example. So this was a fix to handle that scenario. Will also tag @nithinsubbiah to add more context.
Okay, in the absence of further comments (which would have been appreciated), my investigation turns up the following: this is indeed laneId, as I suspected, and it was intentionally added in 2070bcf. However, it carries an implicit assumption that a WaveConstraint is present on the same dimension, contributing a component involving wave_id, i.e. floordiv(threadId, threadsPerWave), to the start expression. In the absence of a WaveConstraint, the start expression will simply be incorrect for the multi-wave-along-X case.
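The decomposition described above can be sketched as follows (a toy model for a 1-D thread block; names are illustrative, not from the codebase): the modulo yields the lane ID, and the complementary floordiv term, expected to come from the WaveConstraint, yields the wave ID.

```python
def lane_and_wave_id(thread_id, threads_per_wave=64):
    # pywave's `startExpr % threadsPerWave` along X produces the lane ID...
    lane_id = thread_id % threads_per_wave
    # ...while the wave ID, floordiv(threadId, threadsPerWave), is expected
    # to be contributed by a WaveConstraint on the same dimension.
    wave_id = thread_id // threads_per_wave
    return lane_id, wave_id

# Without the wave-ID component, threads 5 and 69 (in different waves)
# would get the same start expression: the multi-wave-along-X problem.
print(lane_and_wave_id(5))   # (5, 0)
print(lane_and_wave_id(69))  # (5, 1)
```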
For other dimensions, I suppose the intent is that they are going to be expanded at which point they should just start with 0 and then the expansion will update them. No need to obfuscate that via modulo operations.
```cpp
        pair.first,
        pair.second);
  })),
  std::max(lhs.getPriority(), rhs.getPriority()));
```
Do we need priority per symbol?..
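For context, the join semantics the PR describes can be modeled with a single priority per lattice value (a toy sketch, not the actual WaveInterfaces.cpp implementation): a higher-priority mapping overrides wholesale, and only equal priorities trigger a per-symbol join.

```python
def join(lhs, rhs):
    """lhs/rhs are (priority, exprs) pairs; exprs maps symbols to index exprs."""
    lhs_prio, lhs_exprs = lhs
    rhs_prio, rhs_exprs = rhs
    if lhs_prio != rhs_prio:
        # The higher-priority mapping completely overrides the lower one.
        return max(lhs, rhs, key=lambda pair: pair[0])
    # Equal priorities: per-symbol join (here: keep only agreeing entries,
    # as a stand-in for the real per-dimension join).
    joined = {s: e for s, e in lhs_exprs.items() if rhs_exprs.get(s) == e}
    return (lhs_prio, joined)

print(join((2, {"M": "a"}), (1, {"M": "b", "N": "c"})))
print(join((1, {"M": "a", "N": "c"}), (1, {"M": "b", "N": "c"})))
```

A per-symbol priority, as the question suggests, would let different symbols of the same value come from sources of different confidence instead of forcing the whole dictionary to one level.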
```cpp
<< " used in the write operation without explicit "
   "`elements_per_thread`";
```
Nit: at this point, we haven't checked for `elements_per_thread` yet...
```cpp
// XXX: don't report this error immediately since we may be able to proceed
// without it, e.g., when index expressions can be propagated from non-write
// operations to this one. This may be a questionable design choice carried
// over from the initial Python prototype.
```
We still use these in the stride computation below... but it may be wrong.
```cpp
int64_t stride = 1;
for (int64_t j = i + 1; j < e; ++j) {
  Attribute vectorShape = hardwareConstraint.getVectorShapes().get(
      tensorType.getShape()[j].getName());
  if (!vectorShape) {
    emitError() << "couldn't find vector shape for dimension "
                << tensorType.getShape()[j];
    return failure();
  }
  stride *= cast<IntegerAttr>(vectorShape).getValue().getSExtValue();
}
```
I have a strong suspicion that the usage of strides is inconsistent: here we use a linear row-major stride (changed in 1ddf92d), but for MMAs these remain per-dimension strides. That being said, they don't seem to affect code generation at all, which makes me wonder why we even use them (except for MMA-style strides that may be involved in strided write splitting)...
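To make the suspicion concrete, this is what the quoted loop computes: a linear row-major stride, i.e. the product of the vector shapes of all trailing dimensions (sketch with made-up shapes).

```python
def row_major_strides(vector_shapes):
    """Mirror of the C++ loop: stride[i] = product of vector_shapes[i+1:]."""
    strides = []
    for i in range(len(vector_shapes)):
        stride = 1
        for s in vector_shapes[i + 1:]:
            stride *= s
        strides.append(stride)
    return strides

# Per-dimension (MMA-style) strides would instead be independent of the
# other dimensions, which is the inconsistency pointed out above.
print(row_major_strides([4, 2, 8]))  # [16, 8, 1]
```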
This mimics the index expression propagation behavior pywave has with its multiple passes of propagation based on different "source" ops in a single dataflow process, with convergence guarantees. Earlier "source" operations assign higher-priority index expressions to values that completely override lower-priority expressions on join. Only equal-priority expressions perform an actual per-dimension join.

Several design decisions are replicated for consistency purposes, but should be revised in the future to provide a unified approach to inference without surprises such as the "mode switch" that results in completely different expressions when an MMA is removed. This is particularly true for the case of a missing `vector_shape` part of the hardware constraint, which is only necessary when a "write" operation isn't writing a value transitively obtained from an MMA. Supporting this requires deferring diagnostic emission to account for the case where the process would converge thanks to a higher-priority expression coming from elsewhere.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
f6202f7 to be097bf
Pull request overview
This PR introduces priority-based index-expression propagation in the Wave dialect dataflow analyses (so higher-confidence producers like MMA can override lower-confidence ones like writes), adds delayed diagnostics to improve error context, and expands MLIR + Python tests for index inference/propagation.
Changes:
- Add lattice priority to index-expression inference and use it during joins to prefer higher-priority mappings.
- Implement write-based backward initialization of index expressions and surface “delayed” diagnostic context when inference fails.
- Update/extend MLIR and Python tests to cover propagation from writes and priority conflict resolution.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| wave_lang/kernel/wave/analysis/index_sequence_analysis.py | Adjusts Python-side Water-vs-Python index checking with stronger positivity assumptions for symbols. |
| water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp | Updates test pass override format to include per-value lattice priority; threads delayed-error info through result setting. |
| water/test/Dialect/Wave/infer-index-exprs.mlir | Adds/updates MLIR tests for write-driven propagation and priority behavior. |
| water/test/Dialect/Wave/infer-index-exprs-lattice.mlir | Updates lattice override syntax to include priority and adds priority-join/conflict tests. |
| water/lib/Dialect/Wave/Transforms/InferTypes.cpp | Wires priority + delayed-error plumbing through the analysis pipeline and result materialization. |
| water/lib/Dialect/Wave/IR/WaveOps.cpp | Assigns priorities to MMA-derived mappings and adds WriteOp::initializeIndexExprsBackward. |
| water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | Extends the lattice storage with priority and modifies join behavior to prefer higher-priority mappings. |
| water/include/water/Dialect/Wave/Transforms/Utils.h | Introduces delayed-error types (EmitDelayedErrorFn, DelayedErrorEmitterInfo). |
| water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h | Updates analysis API to return delayed-error info and to accept priority during initialization overrides. |
| water/include/water/Dialect/Wave/IR/WaveOps.td | Declares WriteOp as implementing initializeIndexExprsBackward. |
| water/include/water/Dialect/Wave/IR/WaveInterfaces.td | Extends interface method signature to accept a delayed-error emitter out-param. |
| water/include/water/Dialect/Wave/IR/WaveInterfaces.h | Adds lattice priority constants and priority-carrying constructor. |
| lit_tests/kernel/wave/infer_index_exprs.py | Adds lit-level Python kernels covering matrix add and an MMA->write->read->MMA chain under Water analysis checking. |
Comments suppressed due to low confidence (1)
water/lib/Dialect/Wave/IR/WaveInterfaces.cpp:741
`IndexExprsLatticeStorage` now has a `priority`, but equality/inequality still compares only the `value` pointer/int pair. This breaks priority-based propagation: lattices with the same concrete dict but different priorities will be treated as equal, so higher-priority information may not propagate. Related: `join` has fast paths based on `value` equality that will also drop priority unless updated. Make equality (and any `value`-equality early returns) account for `priority` when the lattice is concrete.
```cpp
wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage()
    : value(nullptr, kUninitializedState), priority(kLowestPriority) {}

wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage(
    DictionaryAttr concreteValue, int32_t priority)
    : value(concreteValue, kSpecificTypeState), priority(priority) {}

bool wave::IndexExprsLatticeStorage::operator==(
    const IndexExprsLatticeStorage &other) const {
  return value == other.value;
}
```
```cpp
#include "water/Dialect/Wave/IR/WaveAttrs.h"

namespace wave {

// Callback to generate a delayed diagnostic. The diagnostic message should be
// attached to the argument. It may or may not be emitted.
using EmitDelayedErrorFn = std::function<void(mlir::InFlightDiagnostic &)>;

// Information to emit delayed errors.
struct DelayedErrorEmitterInfo {
  // Returns the delayed error for the given operation.
  std::function<EmitDelayedErrorFn(mlir::Operation *)> getDelayedError;

  // Returns true if there are any delayed errors.
  std::function<bool()> hasDelayedErrors;
};
```
Utils.h introduces `std::function` (and `EmitDelayedErrorFn`), but the header does not include `<functional>`. This can cause build failures depending on include order; add the standard header include so Utils.h is self-contained.
```cpp
#define DEBUG_TYPE "wave-infer-types"

using namespace mlir;
using namespace mlir;
```
There are now two identical `using namespace mlir;` directives at the top of the file. Remove the duplicate to avoid redundancy and keep the header section clean.
```cpp
if (failed(detail::checkAndAppendIndexExpr(iface->getLoc(),
                                           getLatticeValue(value),
                                           os.str(), indexExprs))) {
  // Don't stop on the first reported error if there are some delayed
  // errors that would be useful to report here. We need to wait and
  // see whether the operation they are attached to actually has had
  // inference issues as some errors may be corrected.
  if (!delayedErrorInfo.hasDelayedErrors())
    return WalkResult::interrupt();

  hadFailures = true;
  if (auto delayedError = delayedErrorInfo.getDelayedError(iface)) {
    InFlightDiagnostic diag =
        iface->emitError()
        << "the error above may be caused by the following: ";
    delayedError(diag);
  }
}
iface->setAttr(wave::WaveDialect::kIndexWaveExprListAttrName,
               ArrayAttr::get(iface->getContext(), indexExprs));
```
When `checkAndAppendIndexExpr` fails and delayed errors exist, the code continues but still unconditionally sets the index attribute using the partially-built `indexExprs` array. This can leave ops with an invalid/partial index attribute (e.g., of the wrong length), which may trigger verifier assertions or produce confusing downstream diagnostics even though the pass ultimately fails. Consider skipping `setAttr` (or erasing/restoring the attribute) for ops where any index expression failed, while still emitting the delayed diagnostic context.
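The suggested mitigation amounts to tracking per-op success and materializing the attribute only when every index expression was produced, sketched here in Python (all names invented; not the pass's actual code):

```python
def materialize_index_attr(op_values, build_expr):
    """Return the full expression list, or None if any value failed."""
    index_exprs = []
    for value in op_values:
        expr = build_expr(value)
        if expr is None:  # inference failure for this value
            # Skip setting the attribute rather than attaching a
            # partially-built (wrong-length) list to the op.
            return None
        index_exprs.append(expr)
    return index_exprs

print(materialize_index_attr(["a", "b"], lambda v: f"expr({v})"))
print(materialize_index_attr(["a", "b"], {"a": "x"}.get))  # None: "b" failed
```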
```python
        else seq.stride
    ),
)
for dim, seq in node.index.items()
```
`ensure_symbols_positive` ignores its `seqs` argument and always iterates over `node.index`. As a result, when you call it for `inferred_index`, it will still rebuild sequences from `node.index`, so the comparison can incorrectly pass/fail (and substitutions may not be applied to the inferred sequences at all). Iterate over `seqs.items()` in the returned dict comprehension instead of `node.index.items()`.
```diff
-for dim, seq in node.index.items()
+for dim, seq in seqs.items()
```
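The bug can be demonstrated with a minimal stand-in (heavily simplified; `Node` and the sequence handling are mocked, and `abs` stands in for the positivity substitution): the broken version closes over `node.index` while the fixed one iterates the `seqs` it was given.

```python
class Node:
    def __init__(self, index):
        self.index = index

def ensure_positive_broken(node, seqs):
    # Mirrors the reported bug: `seqs` is ignored in favor of node.index.
    return {dim: abs(seq) for dim, seq in node.index.items()}

def ensure_positive_fixed(node, seqs):
    # Suggested fix: iterate over the sequences actually passed in.
    return {dim: abs(seq) for dim, seq in seqs.items()}

node = Node({"M": -1})
inferred = {"M": -2, "N": -3}
print(ensure_positive_broken(node, inferred))  # {'M': 1}: inferred ignored
print(ensure_positive_fixed(node, inferred))   # {'M': 2, 'N': 3}
```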