[water] give unmapped dimensions a default index expression#1066
[water] give unmapped dimensions a default index expression#1066
Conversation
When no constraintsa are provided for a symbol indexing a value and it is not entangled with hardware cosntraints via, e.g, MmaOp mapping, it needs to have some index expression. Give it the default index expression with start=0, step=1 and stride=1, assuming the dimension will be fully unrolled by the expansion pass later on. This is consistent with pywave behavior, though inconsistencies may be discovered later depending on how the factors are computed for expansion. In general, the expansion process likely shouldn't be dissociated from index expression propagation. Signed-off-by: Alex Zinenko <git@ozinenko.com>
There was a problem hiding this comment.
Pull request overview
Adds a default (start=0, step=1, stride=1) index expression for dimensions that otherwise end up without any mapping/constraints during Wave index-expr inference (notably for MMA), aligning behavior with pywave.
Changes:
- Update Wave index-expression propagation to synthesize a default mapping for symbols with no constraints and no existing mapping.
- Add an MLIR regression test covering an unmapped/batch dimension in a 3D MMA case.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
water/lib/Dialect/Wave/IR/WaveOps.cpp |
Ensures unconstrained + unmapped symbols get a default (0,1,1) index mapping during constraint mixing. |
water/test/Dialect/Wave/infer-index-exprs.mlir |
Adds a test asserting default index exprs are present for an unmapped dimension (B) across MMA operands/results. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| // If no other mapping is present, default to (start=0, step=1, stride=1), | ||
| // assuming this will be replicated enough times by the expansion pass. | ||
| // When some mapping is present, we don't need to do anything as it would | ||
| // be equivalent to adding 0 to the start of the mapping and taking its | ||
| // step and stride. | ||
| if (!mapping) { | ||
| symbolMappings.emplace_back( | ||
| symbol.getName(), | ||
| wave::WaveIndexMappingAttr::get( | ||
| where->getContext(), /*symbols=*/{}, | ||
| AffineMap::get(/*dimCount=*/0, /*numSymbols=*/0, | ||
| getAffineConstantExpr(0, where->getContext())), | ||
| AffineMap::get(/*dimCount=*/0, /*numSymbols=*/0, | ||
| getAffineConstantExpr(1, where->getContext())), | ||
| AffineMap::get(/*dimCount=*/0, /*numSymbols=*/0, | ||
| getAffineConstantExpr(1, where->getContext())))); |
There was a problem hiding this comment.
This re-creates three constant AffineMaps (0/1/1) inside the per-symbol loop. Since the file already has appendDefaultIndexMapping() that builds reusable zero/one maps once, consider reusing that helper (or factoring out a shared default-mapping builder) to avoid duplication and repeated AffineMap construction in this hot-ish path.
martin-luecke
left a comment
There was a problem hiding this comment.
LGTM, but I think there is opportunity to reuse existing code instead
| appendDefaultIndexMapping(MLIRContext *context, | ||
| llvm::SmallVectorImpl<NamedAttribute> &symbolMappings, | ||
| ArrayRef<wave::WaveSymbolAttr> indexingSymbols) { | ||
|
|
||
| auto zero = AffineMap::get(/*dimCount=*/0, /*numSymbols=*/0, | ||
| getAffineConstantExpr(0, context)); | ||
| auto one = AffineMap::get(/*dimCount=*/0, /*numSymbols=*/0, | ||
| getAffineConstantExpr(1, context)); | ||
|
|
||
| for (wave::WaveSymbolAttr symbol : indexingSymbols) { | ||
| symbolMappings.emplace_back( | ||
| symbol.getName(), | ||
| wave::WaveIndexMappingAttr::get(context, {}, zero, one, one)); | ||
| } |
There was a problem hiding this comment.
We could reuse this bit of code above instead of the newly added code.
| } | ||
| } | ||
|
|
||
| // Append index mappings with offset=0, size=1 and stride=1 to the |
There was a problem hiding this comment.
While we are touching this, we should update the comment:
| // Append index mappings with start=0, step=1, stride=1 to the |
When no constraintsa are provided for a symbol indexing a value and it is not entangled with hardware cosntraints via, e.g, MmaOp mapping, it needs to have some index expression. Give it the default index expression with start=0, step=1 and stride=1, assuming the dimension will be fully unrolled by the expansion pass later on. This is consistent with pywave behavior, though inconsistencies may be discovered later depending on how the factors are computed for expansion.
In general, the expansion process likely shouldn't be dissociated from index expression propagation.