Description
After expand_graph (pass 14), reduction ops have their single input tiled into a list of slices along the reduction dimension. For example, a ReduceOp that originally took one tensor now has node.arg = [slice_0, slice_1, ..., slice_n]. The Water emitter in water_emitter.py:732-736 explicitly checks for this and raises NotImplementedError("Only single-operand reductions are currently supported.").
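A minimal sketch of the two node shapes and a guard in the spirit of the check at water_emitter.py:732-736. `ReduceNode` and `check_single_operand` are hypothetical stand-ins, not the real fx node or emitter code, and the exact conditions of the real guard may differ:

```python
from dataclasses import dataclass
from typing import Sequence, Union


# Hypothetical stand-in for a reduction node; the real ReduceOp wraps an fx.Node.
@dataclass
class ReduceNode:
    op: str                         # e.g. "sum" or "max"
    arg: Union[str, Sequence[str]]  # one operand, or a list of slices after expand_graph
    dim: str                        # reduction dimension symbol


def check_single_operand(node: ReduceNode) -> str:
    """Return the single operand, or reject a multi-slice reduction."""
    if isinstance(node.arg, (list, tuple)):
        # Assumption: a one-element list is unwrapped rather than rejected.
        if len(node.arg) != 1:
            raise NotImplementedError(
                "Only single-operand reductions are currently supported."
            )
        return node.arg[0]
    return node.arg


before = ReduceNode(op="sum", arg="x", dim="K")  # shape before expand_graph
after = ReduceNode(op="sum", arg=["x_slice_0", "x_slice_1", "x_slice_2"], dim="K")  # after
```

Calling `check_single_operand(after)` raises the same `NotImplementedError` message the issue quotes, which is how any kernel with a tiled reduction currently fails.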
This intermediate state exists between expand_graph and decompose_reduce_ops (pass 35), which decomposes the multi-slice reductions into per-slice reduces with combining logic. The Water dialect's reduction ops (wave.sum, wave.max_element in WaveOps.td) also only accept a single $input, so even if the emitter didn't reject them, the dialect couldn't represent them.
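The semantics that decompose_reduce_ops restores can be illustrated on plain Python lists: reduce each slice on its own, then fold the partial results with the reduction's binary combiner. This is a hedged numerical sketch. `REDUCTIONS` and `decompose_reduce` are hypothetical names, and the real pass rewrites graph nodes rather than computing values:

```python
import math
from functools import reduce
from typing import Callable, Dict, List, Tuple

# Hypothetical table: (per-slice reducer, binary combiner, identity) for each reduction kind.
REDUCTIONS: Dict[str, Tuple[Callable[[List[float]], float],
                            Callable[[float, float], float],
                            float]] = {
    "sum": (sum, lambda a, b: a + b, 0.0),
    "max": (max, max, -math.inf),
}


def decompose_reduce(kind: str, slices: List[List[float]]) -> float:
    """Per-slice reduce, then combine. Assumes every slice is non-empty."""
    local_reduce, combine, identity = REDUCTIONS[kind]
    partials = [local_reduce(s) for s in slices]  # one single-operand reduce per slice
    return reduce(combine, partials, identity)    # merge partials with the combiner


slices = [[1.0, 2.0], [3.0], [4.0, 5.0]]
```

For sum and max, the result matches reducing the concatenated input, because both combiners are associative and commutative. For floating-point sums this holds up to rounding, since the summation order changes.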
This blocks passes 14-34 (the entire middle section of the pipeline) for any kernel with reductions: all attention variants, and any future kernel with tiled reduction dimensions.
Fix options:
(a) Extend the Water dialect reduction ops to accept variadic inputs and teach the emitter to serialize them.
(b) Have the emitter concatenate the slices back into a single tensor before emitting the reduce.
(c) Emit multiple single-operand reduces and combine results with the reduction's binary combiner (add for sum, max for max).
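Option (c) needs no dialect change, because each emitted reduce still takes a single `$input`. A minimal sketch of the emission plan follows, under stated assumptions: `emit_sliced_reduce`, `COMBINER`, and the opcode strings are hypothetical placeholders, not Water dialect mnemonics or the emitter's actual API:

```python
from typing import List, Tuple

# Hypothetical mapping from reduction kind to the binary op that merges partial results.
# The opcode strings are placeholders, not real Water dialect op names.
COMBINER = {"sum": "add", "max": "max"}

Op = Tuple[str, str, Tuple[str, ...]]  # (result name, opcode, operand names)


def emit_sliced_reduce(kind: str, slices: List[str]) -> Tuple[List[Op], str]:
    """One single-operand reduce per slice, then a left fold with the combiner."""
    if not slices:
        raise ValueError("expected at least one slice")
    ops: List[Op] = []
    partials: List[str] = []
    for i, s in enumerate(slices):
        name = f"%partial_{i}"
        ops.append((name, f"reduce_{kind}", (s,)))  # each reduce sees exactly one input
        partials.append(name)
    acc = partials[0]
    for i, p in enumerate(partials[1:], start=1):
        name = f"%acc_{i}"
        ops.append((name, COMBINER[kind], (acc, p)))  # merge the running result with the next partial
        acc = name
    return ops, acc  # acc names the final reduced value
```

For three slices of a sum, this emits three reduces followed by two `add` ops and returns `%acc_2`. A left fold is the simplest correct order. A balanced tree would shorten the dependency chain, and either order may round floating-point sums differently from what decompose_reduce_ops later produces.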