Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions wave_lang/kernel/wave/utils/mma_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ def add_reshape_if_needed(mma: MMABase, prev_mma: MMABase, arg_index: int):
arg = mma.lhs if arg_index == 0 else mma.rhs
arg = get_custom(arg)
if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes):
reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph(
reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph(
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reshape is being constructed with args=arg.fx_node (a single fx.Node). Downstream codegen (handle_reshape) treats args as a sequence (calls len(args) and indexes args[0]), so passing a single node here will raise at runtime. Wrap the argument in a 1-element sequence (e.g., [mma.lhs] / [mma.rhs] or [arg.fx_node]) so args is always a list/tuple of nodes.

Suggested change
reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph(
reshape = Reshape([arg.fx_node], mma.vector_shapes).add_to_graph(

Copilot uses AI. Check for mistakes.
mma.graph, loc=mma.location
)
custom_reshape = get_custom(reshape)
custom_reshape.vector_shapes = mma.vector_shapes
custom_reshape.vector_shapes = prev_mma.vector_shapes
Comment on lines +162 to +166
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Reshape construction appears to swap the meaning of target_vector_shape vs vector_shapes compared to other reshape sites in the codebase. Elsewhere (e.g. wave_lang/kernel/wave/decompose_vmma_ops.py:136-153), target_vector_shape is set to the source (pre-reshape) vector shape and reshape.vector_shapes is set to the destination (post-reshape) vector shape so expansion/fixup can derive num_slices correctly. Here target_vector_shape is set to mma.vector_shapes and custom_reshape.vector_shapes to prev_mma.vector_shapes, which would make the reshape compute slice/concat factors in the wrong direction when chaining MMAs. Consider restoring the prior direction: set target_vector_shape from prev_mma.vector_shapes and set custom_reshape.vector_shapes to mma.vector_shapes (the layout the current MMA consumes).

Copilot uses AI. Check for mistakes.
propagate_tag(mma.fx_node, reshape)
mma.update_arg(arg_index, reshape)

Expand Down
Loading