-
Notifications
You must be signed in to change notification settings - Fork 28
[water] Support variadic reduction ops in Water dialect and add corresponding simplification pass #1053
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[water] Support variadic reduction ops in Water dialect and add corresponding simplification pass #1053
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -177,4 +177,28 @@ def WaterWaveResolveDistributedAllocationsPass | |
| ]; | ||
| } | ||
|
|
||
def WaterWaveExpandVariadicReductionsPass
    : Pass<"water-wave-expand-variadic-reductions"> {
  let summary = "Expand variadic reduction inputs into chained single-input "
                "reductions";
  let description = [{
    Rewrites reduction operations (e.g. wave.sum, wave.max_element) that have
    multiple inputs into chains of single-input reductions. For example:

      %r = wave.sum %a, %b, %c init(%init) <scope>

    becomes:

      %0 = wave.sum %a init(%init) <scope>
      %1 = wave.sum %b init(%0) <scope>
      %r = wave.sum %c init(%1) <scope>

    Variadic inputs arise from the Python-side graph expansion pass, which
    tiles the reduction dimension into multiple slices. The Wave dialect
    supports variadic inputs for faithful roundtripping with the Python
    representation. This pass normalizes them before lowering, which
    requires single-input reductions.

    Relation to the Python counterpart: unlike the Python-side expansion,
    this pass performs no tiling or re-slicing of the reduction dimension;
    it only chains inputs that already exist on the operation.
    TODO(review): confirm and document any further behavioral differences
    from the Python implementation.
  }];
}
|
|
||
| #endif // WATER_DIALECT_WAVE_TRANSFORMS_PASSES | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| // Copyright 2026 The Wave Authors | ||
| // | ||
| // Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
|
||
| #include "mlir/IR/PatternMatch.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
| #include "water/Dialect/Wave/IR/WaveOps.h" | ||
| #include "water/Dialect/Wave/Transforms/Passes.h" | ||
|
|
||
| #define DEBUG_TYPE "wave-expand-variadic-reductions" | ||
|
|
||
| namespace wave { | ||
| #define GEN_PASS_DEF_WATERWAVEEXPANDVARIADICREDUCTIONSPASS | ||
| #include "water/Dialect/Wave/Transforms/Passes.h.inc" | ||
| } // namespace wave | ||
|
|
||
| using namespace mlir; | ||
| using namespace wave; | ||
|
|
||
| namespace { | ||
|
|
||
| /// Expand a variadic reduction with N inputs into N chained single-input | ||
| /// reductions: | ||
| /// %r = wave.sum %a, %b, %c init(%init) <scope> | ||
| /// becomes: | ||
| /// %0 = wave.sum %a init(%init) <scope> | ||
| /// %1 = wave.sum %b init(%0) <scope> | ||
| /// %r = wave.sum %c init(%1) <scope> | ||
| template <typename ReductionOp> | ||
| struct ExpandVariadicReduction : public OpRewritePattern<ReductionOp> { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have a trait for reductions, would it make sense to make this |
||
| using OpRewritePattern<ReductionOp>::OpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewrite(ReductionOp op, | ||
| PatternRewriter &rewriter) const override { | ||
| OperandRange inputs = op.getInputs(); | ||
| if (inputs.size() <= 1) | ||
| return failure(); | ||
|
|
||
| Value acc = op.getInit(); | ||
| Value result; | ||
| for (Value input : inputs) { | ||
| auto newOp = ReductionOp::create( | ||
| rewriter, op.getLoc(), op.getResult().getType(), input, acc, | ||
| op.getScopeAttr(), op.getAxisAttr(), op.getIndexAttr()); | ||
| result = newOp.getResult(); | ||
| acc = result; | ||
| } | ||
| rewriter.replaceOp(op, result); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| struct ExpandVariadicReductions | ||
| : public wave::impl::WaterWaveExpandVariadicReductionsPassBase< | ||
| ExpandVariadicReductions> { | ||
| void runOnOperation() override { | ||
| RewritePatternSet patterns(&getContext()); | ||
| patterns.add<ExpandVariadicReduction<SumOp>, | ||
| ExpandVariadicReduction<MaxElementOp>>(&getContext()); | ||
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) | ||
| return signalPassFailure(); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| // RUN: water-opt %s --water-wave-expand-variadic-reductions --split-input-file | FileCheck %s | ||
|
|
||
| // Variadic reductions are expanded into chained single-input reductions. | ||
|
|
||
// A three-input wave.sum is replaced by three chained single-input sums,
// each feeding its result into the next reduction's init operand.
// CHECK-LABEL: @expand_variadic_sum
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[C:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @expand_variadic_sum(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %c: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
  // CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) <warp>
  // CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) <warp>
  // CHECK: %[[R2:.*]] = wave.sum %[[C]] init(%[[R1]]) <warp>
  // CHECK: return %[[R2]]
  %result = wave.sum %a, %b, %c init(%init) <warp> : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
  return %result : !wave.tensor<[@N] of f32>
}
|
|
||
| // ----- | ||
|
|
||
// The expansion applies uniformly to other reduction ops: a two-input
// wave.max_element becomes two chained single-input reductions.
// CHECK-LABEL: @expand_variadic_max_element
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @expand_variadic_max_element(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
  // CHECK: %[[R0:.*]] = wave.max_element %[[A]] init(%[[INIT]]) <warp>
  // CHECK: %[[R1:.*]] = wave.max_element %[[B]] init(%[[R0]]) <warp>
  // CHECK: return %[[R1]]
  %result = wave.max_element %a, %b init(%init) <warp> : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
  return %result : !wave.tensor<[@N] of f32>
}
|
|
||
| // ----- | ||
|
|
||
// Axis attribute is preserved on each chained reduction: the `along @K`
// clause must appear on every link of the expanded chain.
// CHECK-LABEL: @expand_preserves_axis
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @expand_preserves_axis(%a: !wave.tensor<any of f32>, %b: !wave.tensor<any of f32>, %init: !wave.tensor<any of f32>) -> !wave.tensor<any of f32> {
  // CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) along @K <warp>
  // CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) along @K <warp>
  // CHECK: return %[[R1]]
  %result = wave.sum %a, %b init(%init) along @K <warp> : (!wave.tensor<any of f32>, !wave.tensor<any of f32>, !wave.tensor<any of f32>) -> !wave.tensor<any of f32>
  return %result : !wave.tensor<any of f32>
}
|
|
||
| // ----- | ||
|
|
||
// Index attribute is preserved on each chained reduction: the `index [...]`
// clause must be copied verbatim onto every link of the expanded chain.
// CHECK-LABEL: @expand_preserves_index
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @expand_preserves_index(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
  // CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) <warp> index [{N : <[] -> (42, 1, 1)>}]
  // CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) <warp> index [{N : <[] -> (42, 1, 1)>}]
  // CHECK: return %[[R1]]
  %result = wave.sum %a, %b init(%init) <warp> index [{N : <[] -> (42, 1, 1)>}] : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
  return %result : !wave.tensor<[@N] of f32>
}
|
|
||
| // ----- | ||
|
|
||
// Single-input reductions are left unchanged: the pattern only fires for
// two or more inputs, and CHECK-NOT guards against a spurious second op.
// CHECK-LABEL: @single_input_sum_unchanged
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @single_input_sum_unchanged(%a: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
  // CHECK: wave.sum %[[A]] init(%[[INIT]]) <warp>
  // CHECK-NOT: wave.sum
  %result = wave.sum %a init(%init) <warp> : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
  return %result : !wave.tensor<[@N] of f32>
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a `TODO(#889)` here; I think I may have a solution.