Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions lit_tests/kernel/wave/mlir_to_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
compare_hardware_constraints_for_mlir_roundtrip,
)
from wave_lang.kernel.ops.wave_ops import get_custom, Placeholder
from wave_lang.kernel.wave.compile import build_graph_passes
from wave_lang.kernel._support.indexing import IndexingContext

# Keep emitter subprocesses alive for the entire test file instead of
# spawning fresh ones per call (~2s import overhead each).
Expand Down Expand Up @@ -425,3 +427,88 @@ def cyclic_write(

# CHECK: OK: mapping roundtrip
print("OK: mapping roundtrip")


# CHECK-LABEL: mlir_to_fx_reduction_roundtrip
@run_test
def mlir_to_fx_reduction_roundtrip():
    """Test MLIR roundtrip for sum and max reductions.

    Stops compilation before decompose_reduce_ops so the trace still
    contains wave.sum / wave.max_element ops.
    """
    constraints = [
        wave.WorkgroupConstraint(M, BLOCK_M, 0),
        wave.WorkgroupConstraint(N, BLOCK_N, 1),
        wave.WaveConstraint(M, sympy.floor(BLOCK_M / 2)),
        wave.WaveConstraint(N, sympy.floor(BLOCK_N / 2)),
        wave.HardwareConstraint(
            threads_per_wave=64, vector_shapes={M: BLOCK_M, N: BLOCK_N}
        ),
    ]

    subs = {BLOCK_M: 64, BLOCK_N: 64, M: 128, N: 128}

    def _pass_name(graph_pass):
        # Graph passes may be plain functions or partial-like wrappers; for
        # the latter the underlying callable is exposed via `func`.
        return getattr(graph_pass, "__name__", "") or getattr(
            getattr(graph_pass, "func", None), "__name__", ""
        )

    def _roundtrip_check(kernel, label):
        # Emit the kernel's trace to the Wave MLIR dialect, import it back to
        # FX, and assert both traces are equivalent.
        options = WaveCompileOptions(subs=subs, compile_to_mlir=True)
        with IndexingContext() as idxc:
            idxc.set_subs(options.subs)
            kernel.initialize_wave_constraints()
            kernel.initialize_symbolic_constraints()
            kernel.initialize_workgroup_constraints()
            trace = kernel._trace(
                location_capture_config=options.location_capture_config
            )
            # Run graph passes only up to (and excluding) decompose_reduce_ops
            # so the reduction ops survive in the trace.
            for graph_pass in build_graph_passes(kernel, trace, options):
                if _pass_name(graph_pass) == "decompose_reduce_ops":
                    break
                graph_pass()

            mlir_text, diagnostics, _ = emitter.emit_wave_dialect(
                trace, kernel.constraints, options
            )
            errors = error_diagnostics(diagnostics)
            assert errors == [], f"[{label}] unexpected emit errors: {errors}"

            fx_trace, fx_constraints, fx_options, fx_diags = emitter.mlir_to_fx(
                mlir_text
            )
            errors = error_diagnostics(fx_diags)
            assert errors == [], f"[{label}] unexpected import errors: {errors}"

            assert_traces_equivalent(trace, fx_trace, subs=options.subs)
        print(f"  {label}: OK")

    @wave.wave(constraints)
    def sum_kernel(
        a: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
        c: tkl.Memory[M, GLOBAL_ADDRESS_SPACE, tkl.f16],
    ):
        res = wave.read(a)
        init = wave.read(c)
        res = wave.sum(res, init, dim=N)
        wave.write(res, c)

    _roundtrip_check(sum_kernel, "sum roundtrip")

    @wave.wave(constraints)
    def max_kernel(
        a: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
        c: tkl.Memory[M, GLOBAL_ADDRESS_SPACE, tkl.f16],
    ):
        res = wave.read(a)
        init = wave.read(c)
        res = wave.max(res, init, dim=N)
        wave.write(res, c)

    _roundtrip_check(max_kernel, "max roundtrip")

    # CHECK: sum roundtrip: OK
    # CHECK: max roundtrip: OK
29 changes: 20 additions & 9 deletions water/include/water/Dialect/Wave/IR/WaveInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ class ReductionTypeInferenceOpTrait
auto concrete = llvm::cast<OpTy>(this->getOperation());
wave::WaveSymbolAttr axis = concrete.getReducedSymbol();
unsigned initOperandNum = concrete.getInitMutable().getOperandNumber();
unsigned inputOperandNum = concrete.getInputMutable().getOperandNumber();
// Use the first input for type propagation.
unsigned inputOperandNum = concrete.getInputs().getBeginOperandIndex();
return detail::propagateReductionTypesForward(
axis, initOperandNum, inputOperandNum, operandTypes, resultTypes, errs);
}
Expand All @@ -188,15 +189,17 @@ class ReductionTypeInferenceOpTrait
auto concrete = llvm::cast<OpTy>(this->getOperation());
wave::WaveSymbolAttr axis = concrete.getReducedSymbol();
unsigned initOperandNum = concrete.getInitMutable().getOperandNumber();
unsigned inputOperandNum = concrete.getInputMutable().getOperandNumber();
// Use the first input for type propagation.
unsigned inputOperandNum = concrete.getInputs().getBeginOperandIndex();
return detail::propagateReductionTypesBackward(
axis, initOperandNum, inputOperandNum, operandTypes, resultTypes, errs);
}

llvm::LogicalResult finalizeTypeInference() {
auto concrete = llvm::cast<OpTy>(this->getOperation());
if (detail::isReductionTypeInferenceComplete(
concrete.getInput(), concrete.getInit(), concrete.getResult()))
if (detail::isReductionTypeInferenceComplete(concrete.getInputs().front(),
concrete.getInit(),
concrete.getResult()))
concrete.removeAxisAttr();
return llvm::success();
}
Expand Down Expand Up @@ -732,16 +735,24 @@ llvm::LogicalResult verifyReductionOperation(mlir::Operation *op,
// Returns the symbol attribute naming the reduced dimension of `op`,
// resolving it from the explicit axis attribute and the type of the first
// input. Assumes the op has at least one input (enforced by the verifier).
// Note: the scraped diff retained the superseded `op.getInput()` call line
// next to its replacement; only the variadic form is kept here.
template <typename OpTy>
static inline WaveSymbolAttr getReducedSymbol(OpTy op) {
  return wave::detail::getReducedSymbol(op, op.getAxisAttr(),
                                        op.getInputs().front().getType());
}

// Common verification logic for reduction operations. All inputs must have the
// same type; we verify against the first input.
template <typename OpTy>
static inline llvm::LogicalResult verifyReductionOperation(OpTy op) {
  // Reductions are variadic but must carry at least one input; other helpers
  // (e.g. getReducedSymbol) rely on getInputs().front() being valid.
  if (op.getInputs().empty())
    return op.emitOpError("expected at least one input");
  // Inputs must agree on their type so the op can later be expanded into a
  // chain of identical single-input reductions.
  mlir::Type firstInputType = op.getInputs().front().getType();
  for (mlir::Value input : op.getInputs().drop_front()) {
    if (input.getType() != firstInputType)
      return op.emitOpError() << "all inputs must have the same type, but got "
                              << firstInputType << " and " << input.getType();
  }
  // Delegate the input/init/result shape and axis checks to the shared
  // non-template implementation.
  return wave::detail::verifyReductionOperation(
      op, firstInputType, op.getInit().getType(), op.getResult().getType(),
      op.getAxisAttr());
}
} // namespace detail

Expand Down
6 changes: 4 additions & 2 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ class ReductionWaveOp<string mnemonic>
WaveElementsPerThreadOpInterface, ReductionElementsPerThreadOpTrait,
RequiresSidewaysBackwardPropagationOpTrait, WaveReductionOpTrait]>,
WaveArithmeticOpDoc {
// We cannot use Variadic<WaveTensorInRegister> because Variadic requires
// a Type, not a TypeConstraint (same pattern as wave.iterate's iter_args).
// TODO(#889): revisit once variadic type constraints are supported so the
// inputs can be constrained to WaveTensorInRegister directly.
Comment on lines +109 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a TODO(#889) here, I think I may have a solution

let arguments = !con((ins
Arg<WaveTensorInRegister, "Input tensor to reduce">:$input,
Arg<Variadic<WaveIterableType>, "Tensor(s) to reduce">:$inputs,
Arg<WaveTensorInRegister, "Initial value for the reduction">:$init,
Arg<WaveReductionScopeAttr, "Scope of the reduction">:$scope,
Arg<OptionalAttr<WaveSymbolAttr>, "Reduction axis">:$axis
Expand All @@ -118,7 +120,7 @@ class ReductionWaveOp<string mnemonic>
);

let assemblyFormat =
"$input `init` `(` $init `)` (`along` custom<SingleSymbol>($axis)^)? "
"$inputs `init` `(` $init `)` (`along` custom<SingleSymbol>($axis)^)? "
"$scope " # commonArgumentsSyntax # " attr-dict `:`"
"functional-type(operands, results)";
}
Expand Down
24 changes: 24 additions & 0 deletions water/include/water/Dialect/Wave/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,28 @@ def WaterWaveResolveDistributedAllocationsPass
];
}

def WaterWaveExpandVariadicReductionsPass
    : Pass<"water-wave-expand-variadic-reductions"> {
  let summary = "Expand variadic reduction inputs into chained single-input "
                "reductions";
  let description = [{
    Rewrites reduction operations (e.g. wave.sum, wave.max_element) that have
    multiple inputs into chains of single-input reductions. For example:

      %r = wave.sum %a, %b, %c init(%init) <scope>

    becomes:

      %0 = wave.sum %a init(%init) <scope>
      %1 = wave.sum %b init(%0) <scope>
      %r = wave.sum %c init(%1) <scope>

    Variadic inputs arise from the Python-side graph expansion pass, which
    tiles the reduction dimension into multiple slices. The Wave dialect
    supports variadic inputs for faithful roundtripping with the Python
    representation. This pass normalizes them before lowering, which
    requires single-input reductions.

    Note: unlike the Python-side decompose_reduce_ops pass, which removes the
    reduction ops from the trace entirely, this pass only serializes the
    variadic inputs into a chain; the resulting single-input reductions are
    lowered by the regular lowering patterns afterwards.
  }];
}

#endif // WATER_DIALECT_WAVE_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions water/lib/Dialect/Wave/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRWaveTransforms
DetectNormalForms.cpp
ExpandVariadicReductions.cpp
InferTypes.cpp
LoweringPatterns.cpp
LowerReadWriteOps.cpp
Expand Down
67 changes: 67 additions & 0 deletions water/lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2026 The Wave Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "water/Dialect/Wave/IR/WaveOps.h"
#include "water/Dialect/Wave/Transforms/Passes.h"

#define DEBUG_TYPE "wave-expand-variadic-reductions"

namespace wave {
#define GEN_PASS_DEF_WATERWAVEEXPANDVARIADICREDUCTIONSPASS
#include "water/Dialect/Wave/Transforms/Passes.h.inc"
} // namespace wave

using namespace mlir;
using namespace wave;

namespace {

/// Expand a variadic reduction with N inputs into N chained single-input
/// reductions:
///   %r = wave.sum %a, %b, %c init(%init) <scope>
/// becomes:
///   %0 = wave.sum %a init(%init) <scope>
///   %1 = wave.sum %b init(%0) <scope>
///   %r = wave.sum %c init(%1) <scope>
template <typename ReductionOp>
struct ExpandVariadicReduction : public OpRewritePattern<ReductionOp> {
  using OpRewritePattern<ReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReductionOp op,
                                PatternRewriter &rewriter) const override {
    OperandRange inputs = op.getInputs();
    // Single-input reductions are already in normal form; nothing to do.
    if (inputs.size() <= 1)
      return failure();

    // Fold each input into a running accumulator: the previous partial
    // reduction becomes the init value of the next link in the chain. Every
    // new op keeps the original result type, scope, axis, and index.
    Value acc = op.getInit();
    for (Value input : inputs) {
      auto chained = ReductionOp::create(
          rewriter, op.getLoc(), op.getResult().getType(), input, acc,
          op.getScopeAttr(), op.getAxisAttr(), op.getIndexAttr());
      acc = chained.getResult();
    }
    // `inputs.size() >= 2` here, so `acc` is always a freshly created value.
    rewriter.replaceOp(op, acc);
    return success();
  }
};

/// Pass driver: greedily applies the expansion pattern to every supported
/// reduction op kind in the payload.
struct ExpandVariadicReductions
    : public wave::impl::WaterWaveExpandVariadicReductionsPassBase<
          ExpandVariadicReductions> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<ExpandVariadicReduction<SumOp>,
                 ExpandVariadicReduction<MaxElementOp>>(ctx);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace
9 changes: 8 additions & 1 deletion water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1363,9 +1363,16 @@ class ReductionOpLoweringPattern : public OpConversionPattern<WaveOp> {
"unsupported reduction kind");
// Expect PropagateElementsPerThread pass to have run, converting
// WaveTensorType results to VectorType.

// Variadic reductions must be expanded to single-input form by the
// water-wave-expand-variadic-reductions pass before lowering.
if (adaptor.getInputs().size() != 1)
return op.emitOpError("expected single input, run "
"--water-wave-expand-variadic-reductions first");

Location loc = op.getLoc();

Value input = adaptor.getInput();
Value input = adaptor.getInputs().front();
Value init = adaptor.getInit();
bool isBlockReduction = op.getScope() == wave::WaveReductionScope::Block;

Expand Down
64 changes: 64 additions & 0 deletions water/test/Dialect/Wave/expand-variadic-reductions.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// RUN: water-opt %s --water-wave-expand-variadic-reductions --split-input-file | FileCheck %s

// Variadic reductions are expanded into chained single-input reductions.

// The init value seeds the first link of the chain; each later link consumes
// the previous partial result as its init.
// CHECK-LABEL: @expand_variadic_sum
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[C:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @expand_variadic_sum(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %c: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
// CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) <warp>
// CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) <warp>
// CHECK: %[[R2:.*]] = wave.sum %[[C]] init(%[[R1]]) <warp>
// CHECK: return %[[R2]]
%result = wave.sum %a, %b, %c init(%init) <warp> : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
return %result : !wave.tensor<[@N] of f32>
}

// -----

// Same expansion applies to max_element, the other reduction kind.
// CHECK-LABEL: @expand_variadic_max_element
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @expand_variadic_max_element(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
// CHECK: %[[R0:.*]] = wave.max_element %[[A]] init(%[[INIT]]) <warp>
// CHECK: %[[R1:.*]] = wave.max_element %[[B]] init(%[[R0]]) <warp>
// CHECK: return %[[R1]]
%result = wave.max_element %a, %b init(%init) <warp> : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
return %result : !wave.tensor<[@N] of f32>
}

// -----

// Axis attribute is preserved on each chained reduction.
// CHECK-LABEL: @expand_preserves_axis
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @expand_preserves_axis(%a: !wave.tensor<any of f32>, %b: !wave.tensor<any of f32>, %init: !wave.tensor<any of f32>) -> !wave.tensor<any of f32> {
// CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) along @K <warp>
// CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) along @K <warp>
// CHECK: return %[[R1]]
%result = wave.sum %a, %b init(%init) along @K <warp> : (!wave.tensor<any of f32>, !wave.tensor<any of f32>, !wave.tensor<any of f32>) -> !wave.tensor<any of f32>
return %result : !wave.tensor<any of f32>
}

// -----

// Index attribute is preserved on each chained reduction.
// CHECK-LABEL: @expand_preserves_index
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @expand_preserves_index(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
// CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) <warp> index [{N : <[] -> (42, 1, 1)>}]
// CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) <warp> index [{N : <[] -> (42, 1, 1)>}]
// CHECK: return %[[R1]]
%result = wave.sum %a, %b init(%init) <warp> index [{N : <[] -> (42, 1, 1)>}] : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
return %result : !wave.tensor<[@N] of f32>
}

// -----

// Single-input reductions are left unchanged.
// CHECK-LABEL: @single_input_sum_unchanged
// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}})
func.func @single_input_sum_unchanged(%a: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
// CHECK: wave.sum %[[A]] init(%[[INIT]]) <warp>
// CHECK-NOT: wave.sum
%result = wave.sum %a init(%init) <warp> : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
return %result : !wave.tensor<[@N] of f32>
}
8 changes: 8 additions & 0 deletions water/test/Dialect/Wave/ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,14 @@ func.func @reduction_init_and_result_contain_axis(%input: !wave.tensor<any of f3

// -----

// A variadic reduction requires all inputs to share one tensor type; a shape
// mismatch between inputs is rejected by the verifier.
func.func @variadic_reduction_mismatched_input_types(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @K] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> {
// expected-error @below {{all inputs must have the same type, but got '!wave.tensor<[@N, @M] of f32>' and '!wave.tensor<[@N, @K] of f32>'}}
%result = wave.sum %a, %b init(%init) <warp> : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @K] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32>
return %result : !wave.tensor<[@N] of f32>
}

// -----

func.func @broadcast_source_dim_not_in_result(%arg0: !wave.tensor<[@M, @N] of f32, <register>>) {
// Source has [@M, @N], result has [@M, @P, @K] - N is missing (replaced by P).
// expected-error @below {{source dimension 'N' not found in result shape}}
Expand Down
Loading