diff --git a/wave_lang/kernel/wave/utils/mma_utils.py b/wave_lang/kernel/wave/utils/mma_utils.py index fa698b532..5fa4ef338 100644 --- a/wave_lang/kernel/wave/utils/mma_utils.py +++ b/wave_lang/kernel/wave/utils/mma_utils.py @@ -159,11 +159,11 @@ def add_reshape_if_needed(mma: MMABase, prev_mma: MMABase, arg_index: int): arg = mma.lhs if arg_index == 0 else mma.rhs arg = get_custom(arg) if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes): - reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph( + reshape = Reshape(arg.fx_node, mma.vector_shapes).add_to_graph( mma.graph, loc=mma.location ) custom_reshape = get_custom(reshape) - custom_reshape.vector_shapes = mma.vector_shapes + custom_reshape.vector_shapes = prev_mma.vector_shapes propagate_tag(mma.fx_node, reshape) mma.update_arg(arg_index, reshape)