From 041e70377a33b3e7681da31284fb46046280955e Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 20 Feb 2026 15:28:01 +0100 Subject: [PATCH] fix reshape construction Signed-off-by: Alex Zinenko --- wave_lang/kernel/wave/utils/mma_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wave_lang/kernel/wave/utils/mma_utils.py b/wave_lang/kernel/wave/utils/mma_utils.py index fa698b532a..5fa4ef3382 100644 --- a/wave_lang/kernel/wave/utils/mma_utils.py +++ b/wave_lang/kernel/wave/utils/mma_utils.py @@ -159,11 +159,11 @@ def add_reshape_if_needed(mma: MMABase, prev_mma: MMABase, arg_index: int): arg = mma.lhs if arg_index == 0 else mma.rhs arg = get_custom(arg) if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes): - reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph( + reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph( mma.graph, loc=mma.location ) custom_reshape = get_custom(reshape) - custom_reshape.vector_shapes = mma.vector_shapes + custom_reshape.vector_shapes = prev_mma.vector_shapes propagate_tag(mma.fx_node, reshape) mma.update_arg(arg_index, reshape)