diff --git a/xla/hlo/ir/hlo_original_value_test.cc b/xla/hlo/ir/hlo_original_value_test.cc index b56966848f17f..8a84d37c7f5f1 100644 --- a/xla/hlo/ir/hlo_original_value_test.cc +++ b/xla/hlo/ir/hlo_original_value_test.cc @@ -462,5 +462,27 @@ ENTRY main { EXPECT_EQ(gte->original_value()->ToString(), R"({"p0"})"); } +TEST_F(OriginalValueHloTest, CopyOriginalValueWithMap) { + auto src_original_value = std::make_shared(Node::Tuple({ + Node::Leaf(OriginalArray{"instA", {0}}), + Node::Leaf(OriginalArray{"instB", {1}}), + Node::Leaf(OriginalArray{"instC", {2}}), + })); + + auto dest_original_value = + std::make_shared(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})})); + + absl::flat_hash_map old_to_new_tuple_idx = {{2, 0}, {0, 1}}; + + CopyOriginalValue(src_original_value, dest_original_value, + old_to_new_tuple_idx); + + EXPECT_THAT(dest_original_value->original_array({0}), + Optional(Eq(OriginalArray{"instC", {2}}))); + EXPECT_THAT(dest_original_value->original_array({1}), + Optional(Eq(OriginalArray{"instA", {0}}))); +} + } // namespace } // namespace xla diff --git a/xla/hlo/ir/hlo_original_value_util.h b/xla/hlo/ir/hlo_original_value_util.h index 94ca6f5debdc1..16c8c850e77af 100644 --- a/xla/hlo/ir/hlo_original_value_util.h +++ b/xla/hlo/ir/hlo_original_value_util.h @@ -26,30 +26,53 @@ limitations under the License. namespace xla { -// Copies the original value of the source to the destination instruction. -// Original arrays in the source original value are rearranged in the new -// original value according to the given mapping of old to new tuple indices. +// Checks if the type of the map is a matching integer map. template -std::enable_if_t> CopyOriginalValue( - const HloInstruction* src_instruction, HloInstruction* dest_instruction, - const absl::flat_hash_map& old_to_new_tuple_idx) { - std::shared_ptr old_original_value = - src_instruction->original_value(); - if (!old_original_value) { +struct is_matching_integer_map { + static constexpr bool value = + std::is_integral::value && + std::is_same::value; +}; + +// Copies original arrays in the source original value to the destination +// original value according to the given mapping of old to new tuple indices. +template +typename std::enable_if::value>::type +CopyOriginalValue(const std::shared_ptr& src_original_value, + const std::shared_ptr& dest_original_value, + const MapType& old_to_new_tuple_idx) { + if (!src_original_value || !dest_original_value) { return; } - const int64_t src_tuple_size = old_original_value->tree().num_leaves(); + const int64_t src_tuple_size = src_original_value->tree().num_leaves(); const int64_t dest_tuple_size = old_to_new_tuple_idx.size(); - std::shared_ptr new_original_value = - std::make_shared(dest_instruction->shape()); for (const auto& [old_idx, new_idx] : old_to_new_tuple_idx) { if (old_idx < 0 || old_idx >= src_tuple_size || new_idx < 0 || new_idx >= dest_tuple_size) { return; } - new_original_value->mutable_tree()->CopySubtreeFrom( - old_original_value->tree(), {old_idx}, {new_idx}); + dest_original_value->mutable_tree()->CopySubtreeFrom( + src_original_value->tree(), {old_idx}, {new_idx}); + } +} + +// Copies the original value of the source to the destination instruction. +// Original arrays in the source original value are rearranged in the new +// original value according to the given mapping of old to new tuple indices. +template +typename std::enable_if::value>::type +CopyOriginalValue(const HloInstruction* src_instruction, + HloInstruction* dest_instruction, + const MapType& old_to_new_tuple_idx) { + const std::shared_ptr old_original_value = + src_instruction->original_value(); + if (!old_original_value) { + return; } + auto new_original_value = + std::make_shared(dest_instruction->shape()); + CopyOriginalValue(old_original_value, new_original_value, + old_to_new_tuple_idx); dest_instruction->set_original_value(new_original_value); }