From 2bcc7acb9a636e5ce2a3f7f14ad5705287a41cba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Thu, 5 Mar 2026 22:23:01 +0100 Subject: [PATCH 1/2] [water] Support variadic reduction ops in Water dialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Martin Lücke --- lit_tests/kernel/wave/mlir_to_fx.py | 87 +++++++++++++++++++ .../water/Dialect/Wave/IR/WaveInterfaces.h | 29 +++++-- .../include/water/Dialect/Wave/IR/WaveOps.td | 6 +- .../water/Dialect/Wave/Transforms/Passes.td | 24 +++++ .../Dialect/Wave/Transforms/CMakeLists.txt | 1 + .../Transforms/ExpandVariadicReductions.cpp | 67 ++++++++++++++ .../Wave/Transforms/LoweringPatterns.cpp | 9 +- .../Wave/expand-variadic-reductions.mlir | 64 ++++++++++++++ water/test/Dialect/Wave/ops-invalid.mlir | 8 ++ water/test/Dialect/Wave/ops.mlir | 19 ++++ wave_lang/kernel/ops/wave_ops.py | 3 +- .../kernel/wave/mlir_converter/fx_emitter.py | 52 +++++++++++ .../wave/mlir_converter/water_emitter.py | 27 ++++-- 13 files changed, 378 insertions(+), 18 deletions(-) create mode 100644 water/lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp create mode 100644 water/test/Dialect/Wave/expand-variadic-reductions.mlir diff --git a/lit_tests/kernel/wave/mlir_to_fx.py b/lit_tests/kernel/wave/mlir_to_fx.py index a234cb641b..9eb57f2c73 100644 --- a/lit_tests/kernel/wave/mlir_to_fx.py +++ b/lit_tests/kernel/wave/mlir_to_fx.py @@ -26,6 +26,8 @@ compare_hardware_constraints_for_mlir_roundtrip, ) from wave_lang.kernel.ops.wave_ops import get_custom, Placeholder +from wave_lang.kernel.wave.compile import build_graph_passes +from wave_lang.kernel._support.indexing import IndexingContext # Keep emitter subprocesses alive for the entire test file instead of # spawning fresh ones per call (~2s import overhead each). @@ -425,3 +427,88 @@ def cyclic_write( # CHECK: OK: mapping roundtrip print("OK: mapping roundtrip") + + +# CHECK-LABEL: mlir_to_fx_reduction_roundtrip +@run_test +def mlir_to_fx_reduction_roundtrip(): + """Test MLIR roundtrip for sum and max reductions. + + Stops compilation before decompose_reduce_ops so the trace still + contains wave.sum / wave.max_element ops. + """ + constraints = [ + wave.WorkgroupConstraint(M, BLOCK_M, 0), + wave.WorkgroupConstraint(N, BLOCK_N, 1), + wave.WaveConstraint(M, sympy.floor(BLOCK_M / 2)), + wave.WaveConstraint(N, sympy.floor(BLOCK_N / 2)), + wave.HardwareConstraint( + threads_per_wave=64, vector_shapes={M: BLOCK_M, N: BLOCK_N} + ), + ] + + subs = { + BLOCK_M: 64, + BLOCK_N: 64, + M: 128, + N: 128, + } + + def _assert_reduction_roundtrip(kernel, label): + options = WaveCompileOptions(subs=subs, compile_to_mlir=True) + with IndexingContext() as idxc: + idxc.set_subs(options.subs) + kernel.initialize_wave_constraints() + kernel.initialize_symbolic_constraints() + kernel.initialize_workgroup_constraints() + trace = kernel._trace( + location_capture_config=options.location_capture_config + ) + graph_passes = build_graph_passes(kernel, trace, options) + for p in graph_passes: + name = getattr(p, "__name__", "") or getattr( + getattr(p, "func", None), "__name__", "" + ) + if name == "decompose_reduce_ops": + break + p() + + mlir_text, diagnostics, _ = emitter.emit_wave_dialect( + trace, kernel.constraints, options + ) + errors = error_diagnostics(diagnostics) + assert errors == [], f"[{label}] unexpected emit errors: {errors}" + + fx_trace, fx_constraints, fx_options, fx_diags = emitter.mlir_to_fx(mlir_text) + errors = error_diagnostics(fx_diags) + assert errors == [], f"[{label}] unexpected import errors: {errors}" + + assert_traces_equivalent(trace, fx_trace, subs=options.subs) + print(f" {label}: OK") + + @wave.wave(constraints) + def sum_kernel( + a: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, GLOBAL_ADDRESS_SPACE, tkl.f16], + ): + res = wave.read(a) + init = wave.read(c) + res = wave.sum(res, init, dim=N) + wave.write(res, c) + + _assert_reduction_roundtrip(sum_kernel, "sum roundtrip") + + @wave.wave(constraints) + def max_kernel( + a: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, GLOBAL_ADDRESS_SPACE, tkl.f16], + ): + res = wave.read(a) + init = wave.read(c) + res = wave.max(res, init, dim=N) + wave.write(res, c) + + _assert_reduction_roundtrip(max_kernel, "max roundtrip") + + # CHECK: sum roundtrip: OK + # CHECK: max roundtrip: OK diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h index 2d2373d12f..fe1ad1ab22 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h @@ -176,7 +176,8 @@ class ReductionTypeInferenceOpTrait auto concrete = llvm::cast(this->getOperation()); wave::WaveSymbolAttr axis = concrete.getReducedSymbol(); unsigned initOperandNum = concrete.getInitMutable().getOperandNumber(); - unsigned inputOperandNum = concrete.getInputMutable().getOperandNumber(); + // Use the first input for type propagation. + unsigned inputOperandNum = concrete.getInputs().getBeginOperandIndex(); return detail::propagateReductionTypesForward( axis, initOperandNum, inputOperandNum, operandTypes, resultTypes, errs); } @@ -188,15 +189,17 @@ class ReductionTypeInferenceOpTrait auto concrete = llvm::cast(this->getOperation()); wave::WaveSymbolAttr axis = concrete.getReducedSymbol(); unsigned initOperandNum = concrete.getInitMutable().getOperandNumber(); - unsigned inputOperandNum = concrete.getInputMutable().getOperandNumber(); + // Use the first input for type propagation. + unsigned inputOperandNum = concrete.getInputs().getBeginOperandIndex(); return detail::propagateReductionTypesBackward( axis, initOperandNum, inputOperandNum, operandTypes, resultTypes, errs); } llvm::LogicalResult finalizeTypeInference() { auto concrete = llvm::cast(this->getOperation()); - if (detail::isReductionTypeInferenceComplete( - concrete.getInput(), concrete.getInit(), concrete.getResult())) + if (detail::isReductionTypeInferenceComplete(concrete.getInputs().front(), + concrete.getInit(), + concrete.getResult())) concrete.removeAxisAttr(); return llvm::success(); } @@ -732,16 +735,24 @@ llvm::LogicalResult verifyReductionOperation(mlir::Operation *op, template static inline WaveSymbolAttr getReducedSymbol(OpTy op) { return wave::detail::getReducedSymbol(op, op.getAxisAttr(), - op.getInput().getType()); + op.getInputs().front().getType()); } -// Common verification logic for reduction operations. We expect the input type -// to have one more dimension that precisely matches the reduction axis. +// Common verification logic for reduction operations. All inputs must have the +// same type; we verify against the first input. template static inline llvm::LogicalResult verifyReductionOperation(OpTy op) { + if (op.getInputs().empty()) + return op.emitOpError("expected at least one input"); + mlir::Type firstInputType = op.getInputs().front().getType(); + for (mlir::Value input : op.getInputs().drop_front()) { + if (input.getType() != firstInputType) + return op.emitOpError() << "all inputs must have the same type, but got " + << firstInputType << " and " << input.getType(); + } return wave::detail::verifyReductionOperation( - op, op.getInput().getType(), op.getInit().getType(), - op.getResult().getType(), op.getAxisAttr()); + op, firstInputType, op.getInit().getType(), op.getResult().getType(), + op.getAxisAttr()); } } // namespace detail diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 75ca7002cc..4ca32ae607 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -106,8 +106,10 @@ class ReductionWaveOp WaveElementsPerThreadOpInterface, ReductionElementsPerThreadOpTrait, RequiresSidewaysBackwardPropagationOpTrait, WaveReductionOpTrait]>, WaveArithmeticOpDoc { + // We cannot use Variadic because Variadic requires + // a Type, not a TypeConstraint (same pattern as wave.iterate's iter_args). let arguments = !con((ins - Arg:$input, + Arg, "Tensor(s) to reduce">:$inputs, Arg:$init, Arg:$scope, Arg, "Reduction axis">:$axis @@ -118,7 +120,7 @@ class ReductionWaveOp ); let assemblyFormat = - "$input `init` `(` $init `)` (`along` custom($axis)^)? " + "$inputs `init` `(` $init `)` (`along` custom($axis)^)? " "$scope " # commonArgumentsSyntax # " attr-dict `:`" "functional-type(operands, results)"; } diff --git a/water/include/water/Dialect/Wave/Transforms/Passes.td b/water/include/water/Dialect/Wave/Transforms/Passes.td index 01f955abb5..f92f607fad 100644 --- a/water/include/water/Dialect/Wave/Transforms/Passes.td +++ b/water/include/water/Dialect/Wave/Transforms/Passes.td @@ -177,4 +177,28 @@ def WaterWaveResolveDistributedAllocationsPass ]; } +def WaterWaveExpandVariadicReductionsPass + : Pass<"water-wave-expand-variadic-reductions"> { + let summary = "Expand variadic reduction inputs into chained single-input " + "reductions"; + let description = [{ + Rewrites reduction operations (e.g. wave.sum, wave.max_element) that have + multiple inputs into chains of single-input reductions. For example: + + %r = wave.sum %a, %b, %c init(%init) + + becomes: + + %0 = wave.sum %a init(%init) + %1 = wave.sum %b init(%0) + %r = wave.sum %c init(%1) + + Variadic inputs arise from the Python-side graph expansion pass, which + tiles the reduction dimension into multiple slices. The Wave dialect + supports variadic inputs for faithful roundtripping with the Python + representation. This pass normalizes them before lowering, which + requires single-input reductions. + }]; +} + #endif // WATER_DIALECT_WAVE_TRANSFORMS_PASSES diff --git a/water/lib/Dialect/Wave/Transforms/CMakeLists.txt b/water/lib/Dialect/Wave/Transforms/CMakeLists.txt index b0f4006358..2f5bbd4252 100644 --- a/water/lib/Dialect/Wave/Transforms/CMakeLists.txt +++ b/water/lib/Dialect/Wave/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRWaveTransforms DetectNormalForms.cpp + ExpandVariadicReductions.cpp InferTypes.cpp LoweringPatterns.cpp LowerReadWriteOps.cpp diff --git a/water/lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp b/water/lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp new file mode 100644 index 0000000000..65ccbbc041 --- /dev/null +++ b/water/lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp @@ -0,0 +1,67 @@ +// Copyright 2026 The Wave Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "water/Dialect/Wave/IR/WaveOps.h" +#include "water/Dialect/Wave/Transforms/Passes.h" + +#define DEBUG_TYPE "wave-expand-variadic-reductions" + +namespace wave { +#define GEN_PASS_DEF_WATERWAVEEXPANDVARIADICREDUCTIONSPASS +#include "water/Dialect/Wave/Transforms/Passes.h.inc" +} // namespace wave + +using namespace mlir; +using namespace wave; + +namespace { + +/// Expand a variadic reduction with N inputs into N chained single-input +/// reductions: +/// %r = wave.sum %a, %b, %c init(%init) +/// becomes: +/// %0 = wave.sum %a init(%init) +/// %1 = wave.sum %b init(%0) +/// %r = wave.sum %c init(%1) +template +struct ExpandVariadicReduction : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReductionOp op, + PatternRewriter &rewriter) const override { + OperandRange inputs = op.getInputs(); + if (inputs.size() <= 1) + return failure(); + + Value acc = op.getInit(); + Value result; + for (Value input : inputs) { + auto newOp = ReductionOp::create( + rewriter, op.getLoc(), op.getResult().getType(), input, acc, + op.getScopeAttr(), op.getAxisAttr(), op.getIndexAttr()); + result = newOp.getResult(); + acc = result; + } + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct ExpandVariadicReductions + : public wave::impl::WaterWaveExpandVariadicReductionsPassBase< + ExpandVariadicReductions> { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add, + ExpandVariadicReduction>(&getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp b/water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp index ee1099b24b..7259622c33 100644 --- a/water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp +++ b/water/lib/Dialect/Wave/Transforms/LoweringPatterns.cpp @@ -1363,9 +1363,16 @@ class ReductionOpLoweringPattern : public OpConversionPattern { "unsupported reduction kind"); // Expect PropagateElementsPerThread pass to have run, converting // WaveTensorType results to VectorType. + + // Variadic reductions must be expanded to single-input form by the + // water-wave-expand-variadic-reductions pass before lowering. + if (adaptor.getInputs().size() != 1) + return op.emitOpError("expected single input, run " + "--water-wave-expand-variadic-reductions first"); + Location loc = op.getLoc(); - Value input = adaptor.getInput(); + Value input = adaptor.getInputs().front(); Value init = adaptor.getInit(); bool isBlockReduction = op.getScope() == wave::WaveReductionScope::Block; diff --git a/water/test/Dialect/Wave/expand-variadic-reductions.mlir b/water/test/Dialect/Wave/expand-variadic-reductions.mlir new file mode 100644 index 0000000000..644bb86fd4 --- /dev/null +++ b/water/test/Dialect/Wave/expand-variadic-reductions.mlir @@ -0,0 +1,64 @@ +// RUN: water-opt %s --water-wave-expand-variadic-reductions --split-input-file | FileCheck %s + +// Variadic reductions are expanded into chained single-input reductions. + +// CHECK-LABEL: @expand_variadic_sum +// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[C:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}}) +func.func @expand_variadic_sum(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %c: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> { + // CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) + // CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) + // CHECK: %[[R2:.*]] = wave.sum %[[C]] init(%[[R1]]) + // CHECK: return %[[R2]] + %result = wave.sum %a, %b, %c init(%init) : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> + return %result : !wave.tensor<[@N] of f32> +} + +// ----- + +// CHECK-LABEL: @expand_variadic_max_element +// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}}) +func.func @expand_variadic_max_element(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> { + // CHECK: %[[R0:.*]] = wave.max_element %[[A]] init(%[[INIT]]) + // CHECK: %[[R1:.*]] = wave.max_element %[[B]] init(%[[R0]]) + // CHECK: return %[[R1]] + %result = wave.max_element %a, %b init(%init) : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> + return %result : !wave.tensor<[@N] of f32> +} + +// ----- + +// Axis attribute is preserved on each chained reduction. +// CHECK-LABEL: @expand_preserves_axis +// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}}) +func.func @expand_preserves_axis(%a: !wave.tensor, %b: !wave.tensor, %init: !wave.tensor) -> !wave.tensor { + // CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) along @K + // CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) along @K + // CHECK: return %[[R1]] + %result = wave.sum %a, %b init(%init) along @K : (!wave.tensor, !wave.tensor, !wave.tensor) -> !wave.tensor + return %result : !wave.tensor +} + +// ----- + +// Index attribute is preserved on each chained reduction. +// CHECK-LABEL: @expand_preserves_index +// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[B:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}}) +func.func @expand_preserves_index(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> { + // CHECK: %[[R0:.*]] = wave.sum %[[A]] init(%[[INIT]]) index [{N : <[] -> (42, 1, 1)>}] + // CHECK: %[[R1:.*]] = wave.sum %[[B]] init(%[[R0]]) index [{N : <[] -> (42, 1, 1)>}] + // CHECK: return %[[R1]] + %result = wave.sum %a, %b init(%init) index [{N : <[] -> (42, 1, 1)>}] : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> + return %result : !wave.tensor<[@N] of f32> +} + +// ----- + +// Single-input reductions are left unchanged. +// CHECK-LABEL: @single_input_sum_unchanged +// CHECK-SAME: (%[[A:.*]]: {{.*}}, %[[INIT:.*]]: {{.*}}) +func.func @single_input_sum_unchanged(%a: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> { + // CHECK: wave.sum %[[A]] init(%[[INIT]]) + // CHECK-NOT: wave.sum + %result = wave.sum %a init(%init) : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> + return %result : !wave.tensor<[@N] of f32> +} diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index da3ba333b1..a839f46229 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -927,6 +927,14 @@ func.func @reduction_init_and_result_contain_axis(%input: !wave.tensor, %b: !wave.tensor<[@N, @K] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> { + // expected-error @below {{all inputs must have the same type, but got '!wave.tensor<[@N, @M] of f32>' and '!wave.tensor<[@N, @K] of f32>'}} + %result = wave.sum %a, %b init(%init) : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @K] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> + return %result : !wave.tensor<[@N] of f32> +} + +// ----- + func.func @broadcast_source_dim_not_in_result(%arg0: !wave.tensor<[@M, @N] of f32, >) { // Source has [@M, @N], result has [@M, @P, @K] - N is missing (replaced by P). // expected-error @below {{source dimension 'N' not found in result shape}} diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index f319370b54..a8e97f0c57 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -690,6 +690,25 @@ func.func @underspecified_reduction(%input: !wave.tensor, %init: !wa // ----- +// Variadic reductions: multiple inputs are accepted. +// CHECK-LABEL: @variadic_sum +func.func @variadic_sum(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %c: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> { + // CHECK: wave.sum %{{.*}}, %{{.*}}, %{{.*}} init(%{{.*}}) + %result = wave.sum %a, %b, %c init(%init) : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> + return %result : !wave.tensor<[@N] of f32> +} + +// ----- + +// CHECK-LABEL: @variadic_max_element +func.func @variadic_max_element(%a: !wave.tensor<[@N, @M] of f32>, %b: !wave.tensor<[@N, @M] of f32>, %init: !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> { + // CHECK: wave.max_element %{{.*}}, %{{.*}} init(%{{.*}}) + %result = wave.max_element %a, %b init(%init) : (!wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N, @M] of f32>, !wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32> + return %result : !wave.tensor<[@N] of f32> +} + +// ----- + // CHECK-LABEL: @broadcast_1d_to_2d func.func @broadcast_1d_to_2d(%arg0: !wave.tensor<[@M] of f32, >) -> !wave.tensor<[@M, @N] of f32, > { // CHECK: wave.broadcast %{{.*}} : (!wave.tensor<[@M] of f32, >) -> !wave.tensor<[@M, @N] of f32, > diff --git a/wave_lang/kernel/ops/wave_ops.py b/wave_lang/kernel/ops/wave_ops.py index d729b7fb27..41c2012381 100644 --- a/wave_lang/kernel/ops/wave_ops.py +++ b/wave_lang/kernel/ops/wave_ops.py @@ -570,7 +570,8 @@ class NewSubclass(cls): NewSubclass.tkw_op_name = op_name pascal_op_name = op_name.replace("_", " ").title().replace(" ", "") - NewSubclass.__name__ = f"{pascal_op_name}" + NewSubclass.__name__ = pascal_op_name + NewSubclass.__qualname__ = pascal_op_name NewSubclass.__module__ = cls.__module__ current_module = sys.modules[NewSubclass.__module__] setattr(current_module, NewSubclass.__name__, NewSubclass) diff --git a/wave_lang/kernel/wave/mlir_converter/fx_emitter.py b/wave_lang/kernel/wave/mlir_converter/fx_emitter.py index 198b5e0fdb..5886afd66d 100644 --- a/wave_lang/kernel/wave/mlir_converter/fx_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/fx_emitter.py @@ -36,16 +36,19 @@ from water_mlir.water_mlir.dialects.wave import ( AllocateOp, BroadcastOp, + MaxElementOp, ReadOp, WriteOp, MmaOp, RegisterOp, ExtractSliceOp, IterateOp, + SumOp, YieldOp, WaveAddressSpaceAttr, WaveExprListAttr, WaveMmaKindAttr, + WaveReductionScope, WaveSymbolMappingAttr, WaveWorkgroupDimAttr, WaveTensorType, @@ -90,6 +93,7 @@ Write, MMA, MMABase, + Max, NewRegister, ExtractSlice, Iterate, @@ -98,6 +102,7 @@ Output, GetResult, SharedMemoryBarrier, + Sum, get_custom, ) from attr_type_converter import ( @@ -788,6 +793,49 @@ def _handle_mma_op(op: MmaOp, parse_ctx: _OpParseContext) -> None: parse_ctx.add_mapping(op.result, mma_op.fx_node) +def _handle_reduction_op( + op: SumOp | MaxElementOp, + reduce_cls: type, + parse_ctx: _OpParseContext, +) -> None: + """Handle wave.sum / wave.max_element operations.""" + input_nodes = [parse_ctx.resolve_operand(v) for v in op.inputs] + init_node = parse_ctx.resolve_operand(op.init) + is_block = op.scope.value == WaveReductionScope.Block + + axis_attr = op.axis + if axis_attr is not None: + dim = index_symbol(symbol_attr_to_name(axis_attr)) + else: + # Infer the reduction dimension from the shape difference between + # input and result (the verifier guarantees fully-specified types + # when axis is absent). + input_shape = set( + index_symbol(symbol_attr_to_name(s)) for s in op.inputs[0].type.shape + ) + result_shape = set( + index_symbol(symbol_attr_to_name(s)) for s in op.result.type.shape + ) + reduced_dims = input_shape - result_shape + assert ( + len(reduced_dims) == 1 + ), f"Expected exactly one reduced dimension, got {reduced_dims}" + dim = reduced_dims.pop() + + converted_attrs = _convert_supported_attrs(op, ignore_attrs={"scope", "axis"}) + + reduce_op = reduce_cls.create( + parse_ctx.graph, + arg=input_nodes, + init=init_node, + dim=dim, + block=is_block, + type=_convert_wave_tensor_type(op.result.type, parse_ctx), + ) + _apply_mlir_attrs_to_fx_node(reduce_op.fx_node, converted_attrs) + parse_ctx.add_mapping(op.result, reduce_op.fx_node) + + def _handle_extract_slice_op(op: ExtractSliceOp, parse_ctx: _OpParseContext) -> None: """Handle wave.extract_slice operation.""" src_node = parse_ctx.resolve_operand(op.memory) @@ -1027,6 +1075,10 @@ def _convert_ops(ops: Sequence[ir.Operation], parse_ctx: _OpParseContext) -> Non _handle_write_op(op, parse_ctx) case MmaOp(): _handle_mma_op(op, parse_ctx) + case SumOp(): + _handle_reduction_op(op, Sum, parse_ctx) + case MaxElementOp(): + _handle_reduction_op(op, Max, parse_ctx) case ExtractSliceOp(): _handle_extract_slice_op(op, parse_ctx) case IterateOp(): diff --git a/wave_lang/kernel/wave/mlir_converter/water_emitter.py b/wave_lang/kernel/wave/mlir_converter/water_emitter.py index df5b1a4f4a..045f046a52 100644 --- a/wave_lang/kernel/wave/mlir_converter/water_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/water_emitter.py @@ -754,18 +754,35 @@ def create_mlir_operands(): result_type, *create_mlir_operands(), kind=mma_kind ) elif isinstance(node, Reduce): - if isinstance(node.arg, Sequence): - raise NotImplementedError( - "Only single-operand reductions are currently supported." - ) + args = node.arg if isinstance(node.arg, Sequence) else [node.arg] + inputs = [get_single_mapped_value(a) for a in args] + init = get_single_mapped_value(node.init) + # The axis attribute is only emitted when the input type + # is not fully specified (uses 'any' shapes). For + # fully-specified types, the reduction dimension is + # inferred from the shape difference between input and + # result, and emitting axis would be rejected by the + # verifier. + input_type = inputs[0].type + fully_specified = ( + hasattr(input_type, "fully_specified") + and input_type.fully_specified + ) + axis = ( + symbol_name_to_attribute(node.dim.name) + if node.dim is not None and not fully_specified + else None + ) mlir_op = op_builder( result_type, - *create_mlir_operands(), + inputs, + init, scope=wave.WaveReductionScopeAttr.get( wave.WaveReductionScope.Block if node.block else wave.WaveReductionScope.Warp ), + axis=axis, ) elif isinstance(node, Allocate): # Get parent value from value_map if it exists. From a1ef8147e420a84208f867631dbccbf3a2736311 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20L=C3=BCcke?= Date: Fri, 6 Mar 2026 16:58:13 +0100 Subject: [PATCH 2/2] address PR comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Martin Lücke --- water/include/water/Dialect/Wave/IR/WaveOps.td | 10 ++++++---- water/include/water/Dialect/Wave/Transforms/Passes.td | 4 ++++ wave_lang/kernel/wave/mlir_converter/fx_emitter.py | 2 +- wave_lang/kernel/wave/mlir_converter/water_emitter.py | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 4ca32ae607..6aa7e4ad34 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -106,8 +106,9 @@ class ReductionWaveOp WaveElementsPerThreadOpInterface, ReductionElementsPerThreadOpTrait, RequiresSidewaysBackwardPropagationOpTrait, WaveReductionOpTrait]>, WaveArithmeticOpDoc { - // We cannot use Variadic because Variadic requires - // a Type, not a TypeConstraint (same pattern as wave.iterate's iter_args). + // TODO(#889): We cannot use Variadic because Variadic + // requires a Type, not a TypeConstraint (same pattern as wave.iterate's + // iter_args). let arguments = !con((ins Arg, "Tensor(s) to reduce">:$inputs, Arg:$init, @@ -477,7 +478,7 @@ def ReshapeOp : WaveOp<"reshape", [ }]; let arguments = !con((ins - // TODO: cannot use WaveTensorInRegister here because it's not a Type, requires + // TODO(#889): cannot use WaveTensorInRegister here because it's not a Type, requires // an upstream change to Variadic to allow TypeConstraint. Arg, "Tensor to reshape, potentially composed from a sequence " @@ -696,7 +697,8 @@ def ApplyExprOp : WaveOp<"apply_expr", }] # baseDescription; // Accept both WaveTensorType (before PropagateElementsPerThread) and - // 1D vectors (after). We cannot use Variadic because + // 1D vectors (after). + // TODO(#889): We cannot use Variadic because // Variadic requires a Type and not a TypeConstraint. let arguments = !con((ins Arg, "Input registers">:$arguments, diff --git a/water/include/water/Dialect/Wave/Transforms/Passes.td b/water/include/water/Dialect/Wave/Transforms/Passes.td index f92f607fad..12f13d0be4 100644 --- a/water/include/water/Dialect/Wave/Transforms/Passes.td +++ b/water/include/water/Dialect/Wave/Transforms/Passes.td @@ -198,6 +198,10 @@ def WaterWaveExpandVariadicReductionsPass supports variadic inputs for faithful roundtripping with the Python representation. This pass normalizes them before lowering, which requires single-input reductions. + + This corresponds to the "source reduce" step (step 1) of the Python-side + `decompose_reduce_ops` pass, which additionally performs local element + reduction, cross-thread butterfly shuffles, and accumulator combination. }]; } diff --git a/wave_lang/kernel/wave/mlir_converter/fx_emitter.py b/wave_lang/kernel/wave/mlir_converter/fx_emitter.py index 5886afd66d..838bf138c6 100644 --- a/wave_lang/kernel/wave/mlir_converter/fx_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/fx_emitter.py @@ -795,7 +795,7 @@ def _handle_mma_op(op: MmaOp, parse_ctx: _OpParseContext) -> None: def _handle_reduction_op( op: SumOp | MaxElementOp, - reduce_cls: type, + reduce_cls: type[Sum | Max], parse_ctx: _OpParseContext, ) -> None: """Handle wave.sum / wave.max_element operations.""" diff --git a/wave_lang/kernel/wave/mlir_converter/water_emitter.py b/wave_lang/kernel/wave/mlir_converter/water_emitter.py index 045f046a52..1cd444859c 100644 --- a/wave_lang/kernel/wave/mlir_converter/water_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/water_emitter.py @@ -765,7 +765,7 @@ def create_mlir_operands(): # verifier. input_type = inputs[0].type fully_specified = ( - hasattr(input_type, "fully_specified") + isinstance(input_type, WaveTensorType) and input_type.fully_specified ) axis = (