Skip to content

[water] priority-based index expression propagation#734

Open
ftynse wants to merge 1 commit intousers/ftynse/default-index-exprfrom
users/ftynse/priority-index-expr
Open

[water] priority-based index expression propagation#734
ftynse wants to merge 1 commit intousers/ftynse/default-index-exprfrom
users/ftynse/priority-index-expr

Conversation

@ftynse
Copy link
Contributor

@ftynse ftynse commented Jan 14, 2026

This mimics the index exrepssion propagation behavior pywave has with
its multiple passes of propagation based on different "source" ops in a
single dataflow process, with convergence guarantees. Earlier "source"
operations assign higher-priority index expressions to values that
completely override lower-priority expressions on join. Only
equal-priority expressions perform actual per-dimension join.

Several design decisions are replicated for consistency purposes, but
should be revised in the future to provide a unified approach to
inference without surprises such as the "mode switch" that results in
completely different expressions when an mma is removed.

This is particularly true for the case of missing vector_shape part of
the hardware constraint that is only necessary when a "write" operation
isn't writing a value transitively obtained from an MMA. Supporting this
requires deferring diagnostic emission to account for the case where the
process would converge thanks to a higher-priority expression coming
from elsewhere.

Comment on lines +1757 to +1760
// TODO: pywave just ignores this not sure if we want to, including the
// case below where there may be zero constraints. Interestingly, it
// asserts if trailing dimensions are not found when computing the
// stride...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to simply ignore the symbols for which there are no constraints when setting index sequences from write?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If no constraints are specified or the vector shape is not set to 0 (dimensions we don't want to expand), then the symbol either corresponds to the actual tensor dimension or is set dynamically in the kernel. I don't think we should ignore the symbol because it could be meaningful in the analysis.

Comment on lines +1782 to +1786
emitError() << "expected a single workgroup constraint for dimension "
<< tensorType.getShape()[i]
<< " used in the write operation without explicit "
"`elements_per_thread`";
return failure();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, but in absence of a workgroup constraint?

It feels like we need to set it to start=0, and likely size=1 and stride=1 but not sure

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the code does make some assumptions like that where it falls back to start = 0, size and stride of 1, but I think we shouldn't allow that and instead be more explicit.

Comment on lines +1861 to +1867
// TODO: in pywave, we always do `startExpr % threadsPerWave` where
// threadsPerWave == 1 for workgroup dims other than X, making it
// always zero. It mentions an assumption about the (64, 1, 1) thread
// shape, but it is unclear whether that assumption always holds.
// It looks like the intention for this was to express lane ID rather
// than thread ID, but it is unclear how it accounts for multiple
// wavefronts running in parallel.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment in the original source (

# We have an assumption that the thread dimensions in each wave is of shape (64,1,1).
# In cases other than dimension 0, we also calculate the modulus of thread_id with the
# number of threads in that dimension to prevent double counting of thread ID in thread
# independent index.
) says something about preventing double counting of thread id, but I can't infer where and why it would be counted twice. The support for it was added in a commit for atomics, 4eeee9a, which is doesn't provide an explanation either

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comes up in the SIMT context (no MMA, you can also see this in the example for the atomic case). If you look at the original code, what was happening was that because we dont have an MMA, the default pattern for SIMT is a thread linear pattern and so for the atomicAdd we were getting a dependence on x and y, even though that shouldn't be the case for the example. So this was a fix to handle that scenario. Will also tag @nithinsubbiah to add more context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, in absence of further comments (would have been appreciated), my investigation turns up the following: this is indeed laneId as I suspected and was intentionally added in 2070bcf. However, this has an implicit assumption that a WaveConstraint is present on the same dimension and contributing a component that involves wave_id, which is floordiv(threadId, threadPerWave), to the start expression. In absence of a WaveConstraint, it appears that the start expression will simply be incorrect for the multi-wave-along-X case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For other dimensions, I suppose the intent is that they are going to be expanded at which point they should just start with 0 and then the expansion will update them. No need to obfuscate that via modulo operations.

pair.first,
pair.second);
})),
std::max(lhs.getPriority(), rhs.getPriority()));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need priority per-symbol?..

Comment on lines +1764 to +1765
<< " used in the write operation without explicit "
"`elements_per_thread`";
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: at this point, we haven't checked for elements_per_thread yet...

Comment on lines +1742 to +1745
// XXX: don't report this error immediately since we may be able to proceed
// without it, e.g., when index expressions can be propagate from non-write
// operations to this one. This may be a questionable design choice carried
// over from the initial Python prototype.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still use these in the stride computation below... But it may be wrong

Comment on lines +1861 to +1867
// TODO: in pywave, we always do `startExpr % threadsPerWave` where
// threadsPerWave == 1 for workgroup dims other than X, making it
// always zero. It mentions an assumption about the (64, 1, 1) thread
// shape, but it is unclear whether that assumption always holds.
// It looks like the intention for this was to express lane ID rather
// than thread ID, but it is unclear how it accounts for multiple
// wavefronts running in parallel.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, in absence of further comments (would have been appreciated), my investigation turns up the following: this is indeed laneId as I suspected and was intentionally added in 2070bcf. However, this has an implicit assumption that a WaveConstraint is present on the same dimension and contributing a component that involves wave_id, which is floordiv(threadId, threadPerWave), to the start expression. In absence of a WaveConstraint, it appears that the start expression will simply be incorrect for the multi-wave-along-X case.

Comment on lines +1861 to +1867
// TODO: in pywave, we always do `startExpr % threadsPerWave` where
// threadsPerWave == 1 for workgroup dims other than X, making it
// always zero. It mentions an assumption about the (64, 1, 1) thread
// shape, but it is unclear whether that assumption always holds.
// It looks like the intention for this was to express lane ID rather
// than thread ID, but it is unclear how it accounts for multiple
// wavefronts running in parallel.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For other dimensions, I suppose the intent is that they are going to be expanded at which point they should just start with 0 and then the expansion will update them. No need to obfuscate that via modulo operations.

Comment on lines +1823 to +1833
int64_t stride = 1;
for (int64_t j = i + 1; j < e; ++j) {
Attribute vectorShape = hardwareConstraint.getVectorShapes().get(
tensorType.getShape()[j].getName());
if (!vectorShape) {
emitError() << "couldn't find vector shape for dimension "
<< tensorType.getShape()[j];
return failure();
}
stride *= cast<IntegerAttr>(vectorShape).getValue().getSExtValue();
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a strong suspicion that usage of strides is inconsistent: here we use linear row-major stride (changed in 1ddf92d), but for MMAs these remain per-dimension strides. That being said, they don't seem to affect code generation at all, which makes me wonder why do we even use them (except for mma-style strides that may be involved in strided write splitting)....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed in #956

This mimics the index exrepssion propagation behavior pywave has with
its multiple passes of propagation based on different "source" ops in a
single dataflow process, with convergence guarantees. Earlier "source"
operations assign higher-priority index expressions to values that
completely override lower-priority expressions on join. Only
equal-priority expressions perform actual per-dimension join.

Several design decisions are replicated for consistency purposes, but
should be revised in the future to provide a unified approach to
inference without surprises such as the "mode switch" that results in
completely different expressions when an mma is removed.

This is particularly true for the case of missing `vector_shape` part of
the hardware constraint that is only necessary when a "write" operation
isn't writing a value transitively obtained from an MMA. Supporting this
requires deferring diagnostic emission to account for the case where the
process would converge thanks to a higher-priority expression coming
from elsewhere.

Signed-off-by: Alex Zinenko <git@ozinenko.com>
@ftynse ftynse force-pushed the users/ftynse/priority-index-expr branch from f6202f7 to be097bf Compare March 6, 2026 15:01
@ftynse ftynse changed the base branch from main to users/ftynse/default-index-expr March 6, 2026 15:02
@ftynse ftynse requested review from Copilot and martin-luecke March 6, 2026 15:03
@ftynse ftynse marked this pull request as ready for review March 6, 2026 15:03
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces priority-based index-expression propagation in the Wave dialect dataflow analyses (so higher-confidence producers like MMA can override lower-confidence ones like writes), adds delayed diagnostics to improve error context, and expands MLIR + Python tests for index inference/propagation.

Changes:

  • Add lattice priority to index-expression inference and use it during joins to prefer higher-priority mappings.
  • Implement write-based backward initialization of index expressions and surface “delayed” diagnostic context when inference fails.
  • Update/extend MLIR and Python tests to cover propagation from writes and priority conflict resolution.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
wave_lang/kernel/wave/analysis/index_sequence_analysis.py Adjusts Python-side Water-vs-Python index checking with stronger positivity assumptions for symbols.
water/test/lib/Transforms/TestWaveDialectInferIndexExprs.cpp Updates test pass override format to include per-value lattice priority; threads delayed-error info through result setting.
water/test/Dialect/Wave/infer-index-exprs.mlir Adds/updates MLIR tests for write-driven propagation and priority behavior.
water/test/Dialect/Wave/infer-index-exprs-lattice.mlir Updates lattice override syntax to include priority and adds priority-join/conflict tests.
water/lib/Dialect/Wave/Transforms/InferTypes.cpp Wires priority + delayed-error plumbing through the analysis pipeline and result materialization.
water/lib/Dialect/Wave/IR/WaveOps.cpp Assigns priorities to MMA-derived mappings and adds WriteOp::initializeIndexExprsBackward.
water/lib/Dialect/Wave/IR/WaveInterfaces.cpp Extends the lattice storage with priority and modifies join behavior to prefer higher-priority mappings.
water/include/water/Dialect/Wave/Transforms/Utils.h Introduces delayed-error types (EmitDelayedErrorFn, DelayedErrorEmitterInfo).
water/include/water/Dialect/Wave/Transforms/DataFlowAnalyses.h Updates analysis API to return delayed-error info and to accept priority during initialization overrides.
water/include/water/Dialect/Wave/IR/WaveOps.td Declares WriteOp as implementing initializeIndexExprsBackward.
water/include/water/Dialect/Wave/IR/WaveInterfaces.td Extends interface method signature to accept a delayed-error emitter out-param.
water/include/water/Dialect/Wave/IR/WaveInterfaces.h Adds lattice priority constants and priority-carrying constructor.
lit_tests/kernel/wave/infer_index_exprs.py Adds lit-level Python kernels covering matrix add and an MMA->write->read->MMA chain under Water analysis checking.
Comments suppressed due to low confidence (1)

water/lib/Dialect/Wave/IR/WaveInterfaces.cpp:741

  • IndexExprsLatticeStorage now has a priority, but equality/inequality still compares only the value pointer/int pair. This breaks priority-based propagation: lattices with the same concrete dict but different priorities will be treated as equal, so higher-priority information may not propagate. Related: join has fast-paths based on value equality that will also drop priority unless updated. Make equality (and any value-equality early returns) account for priority when the lattice is concrete.
wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage()
    : value(nullptr, kUninitializedState), priority(kLowestPriority) {}

wave::IndexExprsLatticeStorage::IndexExprsLatticeStorage(
    DictionaryAttr concreteValue, int32_t priority)
    : value(concreteValue, kSpecificTypeState), priority(priority) {}

bool wave::IndexExprsLatticeStorage::operator==(
    const IndexExprsLatticeStorage &other) const {
  return value == other.value;
}

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 10 to +25
#include "water/Dialect/Wave/IR/WaveAttrs.h"

namespace wave {

// Callback to generate a delayed diagnostic. The diagnostic message should be
// attached to the argument. It may or may not be emitted.
using EmitDelayedErrorFn = std::function<void(mlir::InFlightDiagnostic &)>;

// Information to emit delayed errors.
struct DelayedErrorEmitterInfo {
// Returns the delayed error for the given operation.
std::function<EmitDelayedErrorFn(mlir::Operation *)> getDelayedError;

// Returns true if there are any delayed errors.
std::function<bool()> hasDelayedErrors;
};
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Utils.h introduces std::function (and EmitDelayedErrorFn) but the header does not include <functional>. This can cause build failures depending on include order; add the standard header include so Utils.h is self-contained.

Copilot uses AI. Check for mistakes.

#define DEBUG_TYPE "wave-infer-types"

using namespace mlir;
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are now two identical using namespace mlir; directives at the top of the file. Remove the duplicate to avoid redundancy and keep the header section clean.

Suggested change
using namespace mlir;

Copilot uses AI. Check for mistakes.
Comment on lines 2012 to 2032
if (failed(detail::checkAndAppendIndexExpr(iface->getLoc(),
getLatticeValue(value),
os.str(), indexExprs)))
return WalkResult::interrupt();
os.str(), indexExprs))) {
// Don't stop on the first reported error if there are some delayed
// errors that would be useful to report here. We need to wait and
// see whether the operation they are attached to actually has had
// inference issues as some errors may be corrected.
if (!delayedErrorInfo.hasDelayedErrors())
return WalkResult::interrupt();

hadFailures = true;
if (auto delayedError = delayedErrorInfo.getDelayedError(iface)) {
InFlightDiagnostic diag =
iface->emitError()
<< "the error above may be caused by the following: ";
delayedError(diag);
}
}
}
iface->setAttr(wave::WaveDialect::kIndexWaveExprListAttrName,
ArrayAttr::get(iface->getContext(), indexExprs));
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When checkAndAppendIndexExpr fails and delayed errors exist, the code continues but still unconditionally sets the index attribute using the partially-built indexExprs array. This can leave ops with an invalid/partial index attribute (e.g., wrong length), which may trigger verifier assertions or produce confusing downstream diagnostics even though the pass ultimately fails. Consider skipping setAttr (or erasing/restoring the attr) for ops where any index expr failed, while still emitting the delayed diagnostic context.

Copilot uses AI. Check for mistakes.
else seq.stride
),
)
for dim, seq in node.index.items()
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ensure_symbols_positive ignores its seqs argument and always iterates over node.index. As a result, when you call it for inferred_index, it will still rebuild sequences from node.index, so the comparison can incorrectly pass/fail (and substitutions may not be applied to the inferred sequences at all). Iterate over seqs.items() in the returned dict comprehension instead of node.index.items().

Suggested change
for dim, seq in node.index.items()
for dim, seq in seqs.items()

Copilot uses AI. Check for mistakes.
@ftynse ftynse changed the title WIP: priority-based index propagation [water] priority-based index expression propagation Mar 6, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants