Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions xla/hlo/ir/hlo_original_value_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -462,5 +462,27 @@ ENTRY main {
EXPECT_EQ(gte->original_value()->ToString(), R"({"p0"})");
}

TEST_F(OriginalValueHloTest, CopyOriginalValueWithMap) {
auto src_original_value = std::make_shared<OriginalValue>(Node::Tuple({
Node::Leaf(OriginalArray{"instA", {0}}),
Node::Leaf(OriginalArray{"instB", {1}}),
Node::Leaf(OriginalArray{"instC", {2}}),
}));

auto dest_original_value =
std::make_shared<OriginalValue>(ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}));

absl::flat_hash_map<int64_t, int64_t> old_to_new_tuple_idx = {{2, 0}, {0, 1}};

CopyOriginalValue(src_original_value, dest_original_value,
old_to_new_tuple_idx);

EXPECT_THAT(dest_original_value->original_array({0}),
Optional(Eq(OriginalArray{"instC", {2}})));
EXPECT_THAT(dest_original_value->original_array({1}),
Optional(Eq(OriginalArray{"instA", {0}})));
}

} // namespace
} // namespace xla
51 changes: 37 additions & 14 deletions xla/hlo/ir/hlo_original_value_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,53 @@ limitations under the License.

namespace xla {

// Copies the original value of the source to the destination instruction.
// Original arrays in the source original value are rearranged in the new
// original value according to the given mapping of old to new tuple indices.
// Checks if the type of the map is a matching integer map.
template <typename T>
std::enable_if_t<std::is_integral_v<T>> CopyOriginalValue(
const HloInstruction* src_instruction, HloInstruction* dest_instruction,
const absl::flat_hash_map<T, T>& old_to_new_tuple_idx) {
std::shared_ptr<OriginalValue> old_original_value =
src_instruction->original_value();
if (!old_original_value) {
struct is_matching_integer_map {
static constexpr bool value =
std::is_integral<typename T::key_type>::value &&
std::is_same<typename T::key_type, typename T::mapped_type>::value;
};

// Copies original arrays in the source original value to the destination
// original value according to the given mapping of old to new tuple indices.
template <typename MapType>
typename std::enable_if<is_matching_integer_map<MapType>::value>::type
CopyOriginalValue(const std::shared_ptr<OriginalValue>& src_original_value,
const std::shared_ptr<OriginalValue>& dest_original_value,
const MapType& old_to_new_tuple_idx) {
if (!src_original_value || !dest_original_value) {
return;
}
const int64_t src_tuple_size = old_original_value->tree().num_leaves();
const int64_t src_tuple_size = src_original_value->tree().num_leaves();
const int64_t dest_tuple_size = old_to_new_tuple_idx.size();
std::shared_ptr<xla::OriginalValue> new_original_value =
std::make_shared<xla::OriginalValue>(dest_instruction->shape());
for (const auto& [old_idx, new_idx] : old_to_new_tuple_idx) {
if (old_idx < 0 || old_idx >= src_tuple_size || new_idx < 0 ||
new_idx >= dest_tuple_size) {
return;
}
new_original_value->mutable_tree()->CopySubtreeFrom(
old_original_value->tree(), {old_idx}, {new_idx});
dest_original_value->mutable_tree()->CopySubtreeFrom(
src_original_value->tree(), {old_idx}, {new_idx});
}
}

// Copies the original value of the source to the destination instruction.
// Original arrays in the source original value are rearranged in the new
// original value according to the given mapping of old to new tuple indices.
template <typename MapType>
typename std::enable_if<is_matching_integer_map<MapType>::value>::type
CopyOriginalValue(const HloInstruction* src_instruction,
HloInstruction* dest_instruction,
const MapType& old_to_new_tuple_idx) {
const std::shared_ptr<OriginalValue> old_original_value =
src_instruction->original_value();
if (!old_original_value) {
return;
}
auto new_original_value =
std::make_shared<xla::OriginalValue>(dest_instruction->shape());
CopyOriginalValue(old_original_value, new_original_value,
old_to_new_tuple_idx);
dest_instruction->set_original_value(new_original_value);
}

Expand Down
Loading