Skip to content

[water] Support variadic reduction ops in Water dialect and add corresponding simplification pass#1053

Open
martin-luecke wants to merge 1 commit intomainfrom
users/martin/multi_operand_reduce
Open

[water] Support variadic reduction ops in Water dialect and add corresponding simplification pass#1053
martin-luecke wants to merge 1 commit intomainfrom
users/martin/multi_operand_reduce

Conversation

@martin-luecke
Copy link
Contributor

Extends the Water dialect reduction ops (wave.sum, wave.max_element) to accept variadic inputs, matching the PyWave representation, where expand_graph tiles reduction inputs into a list of slices. This simplifies FX <-> MLIR roundtrips by allowing the dialect to directly represent the intermediate form, rather than requiring the Python side to decompose reductions before emission, track which reductions stem from this, and fuse them again for the roundtrip.

A new ExpandVariadicReductions pass chains N variadic inputs into N single-input reductions, each feeding its result as the next accumulator — a partial port of the logic in PyWave's decompose_reduce_ops pass. Both the Water emitter and FX importer have been updated to handle variadic forms in both directions.
A normal-form annotation for expanded reductions could be added to indicate where in the pipeline single-input reductions are expected, though currently this would only be relevant for codegen, I think.

Signed-off-by: Martin Lücke <martin.luecke@amd.com>
@martin-luecke martin-luecke requested a review from ftynse March 5, 2026 21:50
@martin-luecke martin-luecke changed the title [water] Support variadic reduction ops in Water dialect and corresponding simplification pass [water] Support variadic reduction ops in Water dialect and add corresponding simplification pass Mar 5, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant