From d4fb498b240640f0cd29a130bc2c973e5444fa7c Mon Sep 17 00:00:00 2001 From: Kingston Mandisodza Date: Wed, 1 Apr 2026 04:11:08 -0700 Subject: [PATCH] Add `NamedSharding` support to `GatherScatterOperandsShardedAcrossParallelDims`. This change introduces a new helper function to handle `NamedSharding` in `GatherScatterOperandsShardedAcrossParallelDims`. When both the operand and indices use `NamedSharding`, this function attempts to align the sharding of parallel dimensions. If one dimension is sharded and the corresponding parallel dimension is replicated, it will propagate the sharding to the replicated dimension if the sharding axes are available. PiperOrigin-RevId: 892847147 --- xla/hlo/utils/hlo_sharding_util.cc | 7 +- xla/service/spmd/BUILD | 1 + xla/service/spmd/gather_scatter_handler.cc | 22 +- xla/service/spmd/spmd_partitioner.cc | 36 ++- xla/service/spmd/spmd_partitioner_test.cc | 136 ++++++----- xla/service/spmd/spmd_partitioner_util.cc | 225 +++++++++++++++++- .../spmd/spmd_partitioner_util_test.cc | 95 ++++++++ 7 files changed, 444 insertions(+), 78 deletions(-) diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc index f15d2bdbb92ec..6de57d6209781 100644 --- a/xla/hlo/utils/hlo_sharding_util.cc +++ b/xla/hlo/utils/hlo_sharding_util.cc @@ -277,7 +277,12 @@ bool IsSubTilingOrEqualNamedSharding(const Shape& potential_sharded_shape, !sharding.manual_axes().empty()) { return false; } - CHECK(sub_mesh.DeviceAssignmentEquals(mesh)); + if (!sub_mesh.DeviceAssignmentEquals(mesh)) { + return IsSubTilingOrEqualSharding( + potential_sharded_shape, + HloSharding::V3ToV2Sharding(potential_subsharding), + HloSharding::V3ToV2Sharding(sharding)); + } CHECK_EQ(potential_subsharding.num_dimensions(), sharding.num_dimensions()); diff --git a/xla/service/spmd/BUILD b/xla/service/spmd/BUILD index 45181fcd60a53..634101f5cbcf4 100644 --- a/xla/service/spmd/BUILD +++ b/xla/service/spmd/BUILD @@ -440,6 +440,7 @@ xla_cc_test( srcs = 
["spmd_partitioner_util_test.cc"], deps = [ ":spmd_partitioner", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/ir:mesh_and_axis", "//xla/hlo/ir:named_sharding", diff --git a/xla/service/spmd/gather_scatter_handler.cc b/xla/service/spmd/gather_scatter_handler.cc index 9181db89c1a14..6da717268aca2 100644 --- a/xla/service/spmd/gather_scatter_handler.cc +++ b/xla/service/spmd/gather_scatter_handler.cc @@ -1544,8 +1544,12 @@ HloInstruction* SelectOperandForScatterIndexPassthroughDimensions( // Update partition_id for partial replicate. auto partition_id = indices.state().partition_id; if (indices.sharding().HasPartialReplication()) { + HloSharding sharding_v2 = + indices.sharding().UseNamedShardingLeaf() + ? HloSharding::V3ToV2Sharding(indices.sharding().named_sharding()) + : indices.sharding(); auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims( - indices.sharding(), {indices.sharding().SubgroupReplicationDim()}); + sharding_v2, {sharding_v2.SubgroupReplicationDim()}); partition_id = GetInGroupPartitionId(partition_id, sharding_grouped.device_groups, b); } @@ -1965,25 +1969,29 @@ absl::Status SpmdPartitioningVisitor::HandleScatterWithoutConflicts( // guaranteed by the scatter semantics. HloSharding indices_sharding = indices.sharding(); for (int64_t i = 0; i < indices.num_dimensions(); ++i) { + int64_t dim_size = + indices_sharding.IsReplicated() ? 1 : indices_sharding.dimension(i); if (indices.base_shape().dimensions(i) != - indices_sharding.dimension(i) * indices.hlo()->shape().dimensions(i)) { + dim_size * indices.hlo()->shape().dimensions(i)) { indices = indices.Replicate().Reshard( indices_sharding, /*pad_value=*/LiteralUtil::CreateR0(-1)); break; } } - SpmdBuilder* b = builder(); - std::vector partitioned_inserted_window_dims; for (int64_t dim : dnums.inserted_window_dims()) { - if (operands[0].sharding().dimension(dim) > 1) { + int64_t dim_size = operands[0].sharding().IsReplicated() + ? 
1 + : operands[0].sharding().dimension(dim); + if (dim_size > 1) { // TODO(b/496605332). inserted_window_dims may not be in // scatter_dims_to_operand_dims. CHECK(absl::c_linear_search(dnums.scatter_dims_to_operand_dims(), dim)); partitioned_inserted_window_dims.push_back(dim); } } + SpmdBuilder* b = builder(); if (!partitioned_inserted_window_dims.empty()) { HloInstruction* indices_min; std::tie(indices_min, std::ignore) = @@ -2125,8 +2133,10 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { // Reshard indices with -1 padding, which will have no effect on the result as // guaranteed by the scatter semantics. for (auto i = 0; i != indices.num_dimensions(); ++i) { + int64_t dim_size = + indices_sharding.IsReplicated() ? 1 : indices_sharding.dimension(i); if (indices.base_shape().dimensions(i) != - indices_sharding.dimension(i) * indices.hlo()->shape().dimensions(i)) { + dim_size * indices.hlo()->shape().dimensions(i)) { // Reshard only when we know that some dimension is padded. indices = indices.Replicate().Reshard( indices_sharding, /*pad_value=*/LiteralUtil::CreateR0(-1)); diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index 2a3fdb179c700..04d324eba497f 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -2167,11 +2167,23 @@ std::optional PartitionedHlo::TryComplexReshardHandling( const HloSharding& target) const { VLOG(5) << "Trying to split complicated reshard: " << sharding().ToString() << " to " << target.ToString(); + + // TODO(b/498118846): Remove V2 conversion and support V3 directly. + HloSharding source_v2 = + sharding().UseNamedShardingLeaf() + ? HloSharding::V3ToV2Sharding(sharding().named_sharding()) + : sharding(); + HloSharding target_v2 = + target.UseNamedShardingLeaf() + ? 
HloSharding::V3ToV2Sharding(target.named_sharding()) + : target; + const bool is_source_partially_replicated = - sharding().ReplicateOnLastTileDim(); - const bool is_target_partially_replicated = target.ReplicateOnLastTileDim(); + source_v2.ReplicateOnLastTileDim(); + const bool is_target_partially_replicated = + target_v2.ReplicateOnLastTileDim(); if (auto reshape = PatternMatchMergeOrSplitSharding(this->base_shape(), - sharding(), target)) { + source_v2, target_v2)) { auto& [before_sharding, new_reshaped_sharding, source_dim] = *reshape; PartitionedHlo reshaped = SplitReshapeHelper( *this, source_dim, this->hlo()->shape().dimensions(source_dim), @@ -2199,7 +2211,7 @@ std::optional PartitionedHlo::TryComplexReshardHandling( return reshaped; } if (auto intermediate_target = - PatternMatchPartiallyReplicateDim(sharding(), target)) { + PatternMatchPartiallyReplicateDim(source_v2, target_v2)) { VLOG(5) << "Matched \"pattern_match_partially_replicate_dim()\": " << intermediate_target->ToString(); auto intermediate_reshard = Reshard(*intermediate_target); @@ -2211,17 +2223,17 @@ std::optional PartitionedHlo::TryComplexReshardHandling( return final_reshard; } if (is_source_partially_replicated && !is_target_partially_replicated) { - const int64_t partial_repl_amount = sharding().dimensions().back(); + const int64_t partial_repl_amount = source_v2.dimensions().back(); int64_t first_different_dimension = -1; // Trying to match conditions like [..,X,..,Z,..,Y] last_tile_dim_replicate // to [..,Y,..,Z,..,X,..], where Y in the source is partially replicated, // but in the target it is not and some other dimension got moved or // modified. Try to remove the partial replication to simplify the step from // source to target sharding. 
- for (int64_t i = 0; i < target.num_dimensions(); ++i) { - if (target.dimension(i) != sharding().dimension(i) && - sharding().dimension(i) == 1 && - target.dimension(i) % partial_repl_amount == 0) { + for (int64_t i = 0; i < target_v2.num_dimensions(); ++i) { + if (target_v2.dimension(i) != source_v2.dimension(i) && + source_v2.dimension(i) == 1 && + target_v2.dimension(i) % partial_repl_amount == 0) { first_different_dimension = i; break; } @@ -2230,12 +2242,12 @@ std::optional PartitionedHlo::TryComplexReshardHandling( return std::nullopt; } VLOG(5) << "Matched partially replicated to non partially replicated: " - << sharding().ToString(); - std::vector transpose_dims(sharding().num_dimensions(), 0); + << source_v2.ToString(); + std::vector transpose_dims(source_v2.num_dimensions(), 0); absl::c_iota(transpose_dims, 0); std::swap(transpose_dims[first_different_dimension], transpose_dims.back()); auto intermediate_sharding = - hlo_sharding_util::TransposeSharding(sharding(), transpose_dims); + hlo_sharding_util::TransposeSharding(source_v2, transpose_dims); auto intermediate_reshard = Reshard(intermediate_sharding); auto reshard = intermediate_reshard.ReshardNoCache( target, /*pad_value=*/std::nullopt, /*allow_full_replication=*/false); diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 500346b73e7d7..1de06a23cf7aa 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -1771,7 +1771,7 @@ ENTRY entry { EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); } -TEST_P(SpmdPartitioningTest, SelectAndScatterNoOverlap) { +TEST_P(SpmdPartitioningAllShardingTest, SelectAndScatterNoOverlap) { absl::string_view hlo_string = R"( HloModule module @@ -1821,7 +1821,7 @@ ENTRY entry { EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); } -TEST_P(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) { +TEST_P(SpmdPartitioningAllShardingTest, 
SelectAndScatterNoOverlapReshard) { absl::string_view hlo_string = R"( HloModule module @@ -1875,7 +1875,7 @@ ENTRY entry { EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); } -TEST_P(SpmdPartitioningTest, SelectAndScatterWithOverlap) { +TEST_P(SpmdPartitioningAllShardingTest, SelectAndScatterWithOverlap) { absl::string_view hlo_string = R"( HloModule module @@ -5250,7 +5250,7 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]"))); } -TEST_P(SpmdPartitioningTest, SelectAndScatter_RetinaNet) { +TEST_P(SpmdPartitioningAllShardingTest, SelectAndScatter_RetinaNet) { absl::string_view hlo_string = R"( HloModule module @@ -9028,7 +9028,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, UnpartitionedScatter) { +TEST_P(SpmdPartitioningAllShardingTest, UnpartitionedScatter) { absl::string_view hlo_string = R"( HloModule module @@ -9062,7 +9062,7 @@ ENTRY entry { op::Shape("f32[2,5]"))); } -TEST_P(SpmdPartitioningTest, VariadicScatter) { +TEST_P(SpmdPartitioningAllShardingTest, VariadicScatter) { absl::string_view hlo_string = R"( HloModule module @@ -9108,7 +9108,7 @@ ENTRY entry { op::Shape("(f32[2,3],f32[2,3])"))); } -TEST_P(SpmdPartitioningTest, VariadicScatterSharedOperands) { +TEST_P(SpmdPartitioningAllShardingTest, VariadicScatterSharedOperands) { absl::string_view hlo_string = R"( HloModule module @@ -9153,7 +9153,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, PassthroughScatter) { +TEST_P(SpmdPartitioningAllShardingTest, PassthroughScatter) { absl::string_view hlo_string = R"( HloModule module @@ -9188,7 +9188,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, PassthroughScatterVariadic) { +TEST_P(SpmdPartitioningAllShardingTest, PassthroughScatterVariadic) { absl::string_view hlo_string = R"( HloModule module @@ -9231,7 +9231,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) { +TEST_P(SpmdPartitioningAllShardingTest, PassthroughScatter_PartialReplicate) { absl::string_view hlo_string 
= R"( HloModule module @@ -9269,7 +9269,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, PassthroughScatterVariadic_PartialReplicate) { +TEST_P(SpmdPartitioningAllShardingTest, + PassthroughScatterVariadic_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -9316,7 +9317,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatter) { +TEST_P(SpmdPartitioningAllShardingTest, IndexPassthroughScatter) { absl::string_view hlo_string = R"( HloModule module @@ -9355,7 +9356,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) { +TEST_P(SpmdPartitioningAllShardingTest, + IndexPassthroughScatter_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -9396,7 +9398,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatterPartitionedIndexVectorDim) { +TEST_P(SpmdPartitioningAllShardingTest, + IndexPassthroughScatterPartitionedIndexVectorDim) { absl::string_view hlo_string = R"( HloModule module @@ -9430,7 +9433,7 @@ ENTRY entry { EXPECT_THAT(root, op::AllReduce(op::AllReduce(scatter))); } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatterReshardIndices) { +TEST_P(SpmdPartitioningAllShardingTest, IndexPassthroughScatterReshardIndices) { absl::string_view hlo_string = R"( HloModule module @@ -9463,7 +9466,7 @@ ENTRY entry { op::AllReduce(op::AllReduce(scatter))); } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatter_Min) { +TEST_P(SpmdPartitioningAllShardingTest, IndexPassthroughScatter_Min) { absl::string_view hlo_string = R"( HloModule module @@ -9502,7 +9505,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterExplicitBatchDims) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterExplicitBatchDims) { absl::string_view hlo_string = R"( HloModule module @@ -9538,7 +9541,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndOperandPassthroughDims) { +TEST_P(SpmdPartitioningAllShardingTest, + 
ScatterExplicitBatchAndOperandPassthroughDims) { absl::string_view hlo_string = R"( HloModule module @@ -9574,7 +9578,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndIndexPassthroughDims1) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterExplicitBatchAndIndexPassthroughDims1) { absl::string_view hlo_string = R"( HloModule module @@ -9611,7 +9616,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndIndexPassthroughDims2) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterExplicitBatchAndIndexPassthroughDims2) { absl::string_view hlo_string = R"( HloModule module @@ -9650,7 +9656,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDims) { absl::string_view hlo_string = R"( HloModule module @@ -9688,7 +9694,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDimsVariadic) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterPartitionedOnTrivialSliceDimsVariadic) { absl::string_view hlo_string = R"( HloModule module @@ -9734,7 +9741,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDims_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -9776,7 +9783,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDimsVariadic_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -12392,7 +12399,8 @@ ENTRY entry { EXPECT_THAT(root, op::Tuple(op::AllGather(gather), _, _)); } -TEST_P(SpmdPartitioningTest, ScatterRepsOnLastTileDimDontDivideGroups) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterRepsOnLastTileDimDontDivideGroups) { absl::string_view hlo_string = R"( HloModule module @@ -12435,7 +12443,8 @@ ENTRY entry { EXPECT_THAT(module->entry_computation()->root_instruction(), scatter); } 
-TEST_P(SpmdPartitioningTest, ParallelDimFromOutsideConditionalPositive) { +TEST_P(SpmdPartitioningAllShardingTest, + ParallelDimFromOutsideConditionalPositive) { absl::string_view hlo_string = R"( HloModule module @@ -12695,7 +12704,8 @@ ENTRY %main.14 (Arg_0.1: s32[4,32], Arg_1.2: s32[4]) -> s32[4] { EXPECT_THAT(root, op::AllReduce(op::Select(_, _, gather))); } -TEST_P(SpmdPartitioningTest, GatherMergedIndexParallelAndIndexPassthrough) { +TEST_P(SpmdPartitioningAllShardingTest, + GatherMergedIndexParallelAndIndexPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -12976,7 +12986,7 @@ ENTRY %module { } // Tests for Gather partitioning with SPMD config option. -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, GatherPartitionedOnTrivialSliceDimsForceTrivialSlice) { absl::string_view hlo_string = R"( HloModule module @@ -13007,7 +13017,7 @@ ENTRY entry { EXPECT_THAT(collective_permute, nullptr); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, GatherPartitionedOnTrivialSliceDimsForceIndexParallel) { absl::string_view hlo_string = R"( HloModule module @@ -13085,7 +13095,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimReplicatedIndices) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterParallelDimReplicatedIndices) { absl::string_view hlo_string = R"( HloModule module @@ -13129,7 +13139,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimReplicatedOperand) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterParallelDimReplicatedOperand) { absl::string_view hlo_string = R"( HloModule module @@ -13172,7 +13182,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimReplicatedUpdate) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterParallelDimReplicatedUpdate) { absl::string_view hlo_string = R"( HloModule module @@ -13215,7 
+13225,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimPartialReplicatedIndices) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimPartialReplicatedIndices) { absl::string_view hlo_string = R"( HloModule module @@ -13259,7 +13270,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimPartialReplicatedOperand) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimPartialReplicatedOperand) { absl::string_view hlo_string = R"( HloModule module @@ -13303,7 +13315,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimPartialReplicatedUpdate) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimPartialReplicatedUpdate) { absl::string_view hlo_string = R"( HloModule module @@ -13347,7 +13360,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimSwappedDimensions) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterParallelDimSwappedDimensions) { absl::string_view hlo_string = R"( HloModule module @@ -13391,7 +13404,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(scatter))); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimFromOutsideWhilePositive) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimFromOutsideWhilePositive) { absl::string_view hlo_string = R"( HloModule module @@ -13469,7 +13483,8 @@ ENTRY entry { EXPECT_THAT(root, op::Tuple(op::AllGather(scatter), _, _, _)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimAndNonParallelDimPartitioned) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimAndNonParallelDimPartitioned) { absl::string_view hlo_string = R"( HloModule module @@ -13518,7 +13533,8 @@ ENTRY %module { op::DynamicSlice(op::AllReduce(scatter), _, _, _, _)))); } -TEST_P(SpmdPartitioningTest, GatherScatterPartitioningIndexParallelCase) { 
+TEST_P(SpmdPartitioningAllShardingTest, + GatherScatterPartitioningIndexParallelCase) { absl::string_view hlo_string = R"( HloModule jit__init @@ -13545,12 +13561,13 @@ ENTRY main.22 { const auto root = module->entry_computation()->root_instruction(); auto operand = AllOf(op::Shape("f32[16,2]"), op::Broadcast()); auto indices = AllOf(op::Shape("s32[8,2]"), op::Subtract()); - auto update = AllOf(op::Shape("f32[8]"), op::Broadcast(op::Constant())); + auto update = AllOf(op::Shape("f32[8]"), op::Broadcast()); EXPECT_THAT(root, AllOf(op::Shape("f32[16,2]"), op::Scatter(operand, indices, update))); } -TEST_P(SpmdPartitioningTest, ScatterMergedIndexParallelAndOperandPassthrough) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterMergedIndexParallelAndOperandPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -13594,7 +13611,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(scatter))); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedIndexParallelAndTrivialSlicedOperand) { absl::string_view hlo_string = R"( HloModule module @@ -13639,7 +13656,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(scatter))); } -TEST_P(SpmdPartitioningTest, ScatterMergedIndexParallelAndIndexPassthrough) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterMergedIndexParallelAndIndexPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -13691,7 +13709,7 @@ ENTRY %module { } } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedOperandPassthroughAndTrivialSlicedOperand) { absl::string_view hlo_string = R"( HloModule module @@ -13731,7 +13749,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(op::AllGather(scatter)))); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedOperandPassthroughAndIndexPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -13771,7 +13789,7 @@ ENTRY %module { 
EXPECT_THAT(root, op::AllGather(op::AllReduce(scatter))); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedOperandPassthroughAndIndexPassthrough_PartialGrouping) { absl::string_view hlo_string = R"( HloModule module @@ -13812,7 +13830,7 @@ ENTRY %module { _, _, _)); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedTrivialSlicedOperandAndIndexPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -13852,7 +13870,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(op::AllReduce(scatter)))); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedTrivialSlicedOperandAndIndexPassthrough_PartialGrouping) { absl::string_view hlo_string = R"( HloModule module @@ -13892,7 +13910,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllReduce(op::AllReduce(scatter)))); } -TEST_P(SpmdPartitioningTest, ScatterTrivialSlicedOperandPartial) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterTrivialSlicedOperandPartial) { absl::string_view hlo_string = R"( HloModule module @@ -13927,7 +13945,7 @@ ENTRY main.4 { } // Tests for scatter partitioning methods with SPMD config option. 
-TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDimsForceTrivialSlice) { absl::string_view hlo_string = R"( HloModule module @@ -13966,7 +13984,7 @@ ENTRY entry { EXPECT_THAT(collective_permute, nullptr); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDimsForceIndexParallel) { absl::string_view hlo_string = R"( HloModule module @@ -14423,7 +14441,8 @@ ENTRY entry { op::Shape("f32[32,16,24,512]"))); } -TEST_P(SpmdPartitioningTest, PartitionPassthroughScatterCorrectOutputSharding) { +TEST_P(SpmdPartitioningAllShardingTest, + PartitionPassthroughScatterCorrectOutputSharding) { absl::string_view hlo_string = R"( HloModule module @@ -14865,7 +14884,7 @@ ENTRY entry { EXPECT_EQ(root->replica_groups().size(), 2); } -TEST_P(SpmdPartitioningTest, ScatterPreferUpdateIndexIfSmaller) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterPreferUpdateIndexIfSmaller) { absl::string_view hlo_string = R"( HloModule module @@ -14905,7 +14924,8 @@ ENTRY entry { op::Shape("bf16[512,1024,1020]"))))))); } -TEST_P(SpmdPartitioningTest, ScatterPreferTrivialIfSmallerThanIndices) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterPreferTrivialIfSmallerThanIndices) { absl::string_view hlo_string = R"( HloModule module @@ -14944,6 +14964,8 @@ ENTRY entry { op::Shape("bf16[32,256]")))))); } +// TODO(b/498965034): When device assignment is supported for v3 conversion, +// enable this test for V3. 
TEST_P(SpmdPartitioningTest, GatherOperandPassthroughIndexPassthrough) { const char* const hlo_string = R"( HloModule module @@ -15802,7 +15824,7 @@ ENTRY %main.21 { op::GetTupleElement(op::Reduce(_, _, _, _))); } -TEST_P(SpmdPartitioningTest, CombiningScatterPartitiong) { +TEST_P(SpmdPartitioningAllShardingTest, CombiningScatterPartitioning) { const char* const hlo_string = R"( HloModule pjit @@ -15906,7 +15928,7 @@ ENTRY %main.21 { EXPECT_THAT(gather, op::Shape("bf16[4096,32]")); } -TEST_P(SpmdPartitioningTest, ScatterCostModelForUnmatchedSharding) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterCostModelForUnmatchedSharding) { const char* const hlo_string = R"( HloModule pjit @@ -15935,7 +15957,7 @@ ENTRY %main.21 { EXPECT_THAT(updates, op::Shape("bf16[4096,64]")); } -TEST_P(SpmdPartitioningTest, ScatterAllOperandsAreSameInstruction) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterAllOperandsAreSameInstruction) { const char* const hlo_string = R"( HloModule pjit diff --git a/xla/service/spmd/spmd_partitioner_util.cc b/xla/service/spmd/spmd_partitioner_util.cc index 83b11466aedf0..e8d4fda17a9fc 100644 --- a/xla/service/spmd/spmd_partitioner_util.cc +++ b/xla/service/spmd/spmd_partitioner_util.cc @@ -2661,6 +2661,208 @@ HloSharding CreateMatchingShardingOnDims( target_dims, source_sharding, source_dims); } +namespace { + +bool IsAxisSuperset(const AxisRef& available, const AxisRef& requested) { + // If the mesh axes are not the same, they are not comparable. + if (available.mesh_axis_index() != requested.mesh_axis_index()) { + return false; + } + // If the available axis does not have sub-axis info, it represents the full + // mesh axis and is therefore a superset of any requested sub-axis. + if (!available.sub_axis_info().has_value()) { + return true; + } + // If the requested axis does not have sub-axis info, it represents the full + // mesh axis and is therefore not a subset of any available sub-axis.
+ if (!requested.sub_axis_info().has_value()) { + return false; + } + // The available axis must contain the requested sub-axis. + return available.sub_axis_info()->pre_size <= + requested.sub_axis_info()->pre_size && + requested.sub_axis_info()->next_pre_size() <= + available.sub_axis_info()->next_pre_size(); +} +bool IsAxisUsedInSharding( + const AxisRef& axis, int64_t target_dim, + absl::Span dim_shardings) { + for (int64_t i = 0; i < dim_shardings.size(); ++i) { + if (i == target_dim) { + continue; + } + for (const AxisRef& other_axis : dim_shardings[i].axes()) { + if (axis.Overlaps(other_axis)) { + return true; + } + } + } + return false; +} + +void PropagateDimensionSharding( + const NamedSharding::DimensionSharding& source_dim_sharding, + int64_t target_dim, const Mesh& mesh, + std::vector* target_dim_shardings, + std::vector* target_replicated_axes) { + std::vector to_remove; + std::vector to_add; + + for (const AxisRef& axis : source_dim_sharding.axes()) { + for (const AxisRef& replicated_axis : *target_replicated_axes) { + if (replicated_axis.mesh_axis_index() != axis.mesh_axis_index()) { + continue; + } + if (!replicated_axis.sub_axis_info().has_value() && + axis.sub_axis_info().has_value()) { + to_remove.push_back(replicated_axis); + int64_t mesh_size = mesh.axis_size(axis.mesh_axis_index()); + int64_t p = axis.sub_axis_info()->pre_size; + int64_t s = axis.sub_axis_info()->size; + + if (p > 1) { + to_add.push_back(AxisRef(axis.mesh_axis_index(), {1, p})); + } + int64_t next_p = p * s; + int64_t rem_size = mesh_size / next_p; + if (rem_size > 1) { + to_add.push_back(AxisRef(axis.mesh_axis_index(), {next_p, rem_size})); + } + } else if (IsAxisSuperset(axis, replicated_axis)) { + to_remove.push_back(replicated_axis); + } + } + } + + (*target_dim_shardings)[target_dim] = source_dim_sharding; + std::erase_if(*target_replicated_axes, + [&to_remove](const auto& replicated_axis) { + return absl::c_linear_search(to_remove, replicated_axis); + }); + 
target_replicated_axes->insert(target_replicated_axes->end(), to_add.begin(), + to_add.end()); + SortAndMergeAxes(*target_replicated_axes, mesh); +} + +} // namespace + +std::optional +GatherScatterOperandsShardedAcrossParallelDimsNamedSharding( + const HloInstruction& operand, const HloInstruction& indices, + const hlo_sharding_util::GatherScatterDims& parallel_dims) { + const hlo_sharding_util::GatherScatterDims& dims = parallel_dims; + const DimensionVector& indices_parallel_dims = dims.indices_dims; + const DimensionVector& operand_parallel_dims = dims.operand_dims; + + const HloSharding& idx_sharding = indices.sharding(); + const HloSharding& op_sharding = operand.sharding(); + + if (idx_sharding.named_sharding().mesh() != + op_sharding.named_sharding().mesh()) { + return std::nullopt; + } + + absl::Span idx_dim_shardings = + idx_sharding.named_sharding().dim_shardings(); + absl::Span op_dim_shardings = + op_sharding.named_sharding().dim_shardings(); + + std::vector new_indices_dims; + std::vector new_operand_dims; + std::vector indices_rep; + std::vector operand_rep; + + bool changed = false; + + // Lazily initialize the new shardings only when we find a dimension that + // needs propagation. 
+ auto ensure_initialized = [&]() { + if (changed) { + return; + } + changed = true; + if (idx_sharding.named_sharding().IsReplicated()) { + new_indices_dims.resize(indices.shape().dimensions().size()); + } else { + new_indices_dims.assign(idx_dim_shardings.begin(), + idx_dim_shardings.end()); + } + if (op_sharding.named_sharding().IsReplicated()) { + new_operand_dims.resize(operand.shape().dimensions().size()); + } else { + new_operand_dims.assign(op_dim_shardings.begin(), op_dim_shardings.end()); + } + indices_rep.assign(idx_sharding.named_sharding().replicated_axes().begin(), + idx_sharding.named_sharding().replicated_axes().end()); + operand_rep.assign(op_sharding.named_sharding().replicated_axes().begin(), + op_sharding.named_sharding().replicated_axes().end()); + }; + + for (int64_t i = 0; i < indices_parallel_dims.size(); ++i) { + int64_t idx_dim = indices_parallel_dims[i]; + int64_t op_dim = operand_parallel_dims[i]; + + bool idx_replicated = idx_sharding.named_sharding().IsReplicated() || + idx_dim >= idx_dim_shardings.size() || + idx_dim_shardings[idx_dim].axes().empty(); + bool op_replicated = op_sharding.named_sharding().IsReplicated() || + op_dim >= op_dim_shardings.size() || + op_dim_shardings[op_dim].axes().empty(); + + if (idx_replicated) { + if (!op_replicated) { + bool conflict = false; + for (const AxisRef& axis : op_dim_shardings[op_dim].axes()) { + if (IsAxisUsedInSharding(axis, idx_dim, idx_dim_shardings)) { + conflict = true; + break; + } + } + if (conflict) { + return std::nullopt; + } + ensure_initialized(); + PropagateDimensionSharding(op_dim_shardings[op_dim], idx_dim, + idx_sharding.named_sharding().mesh(), + &new_indices_dims, &indices_rep); + } + } else if (op_replicated) { + bool conflict = false; + for (const AxisRef& axis : idx_dim_shardings[idx_dim].axes()) { + if (IsAxisUsedInSharding(axis, op_dim, op_dim_shardings)) { + conflict = true; + break; + } + } + if (conflict) { + return std::nullopt; + } + ensure_initialized(); + 
PropagateDimensionSharding(idx_dim_shardings[idx_dim], op_dim, + op_sharding.named_sharding().mesh(), + &new_operand_dims, &operand_rep); + } else { + absl::Span idx_axes = idx_dim_shardings[idx_dim].axes(); + absl::Span op_axes = op_dim_shardings[op_dim].axes(); + if (idx_axes != op_axes) { + return std::nullopt; + } + } + } + + if (!changed) { + return std::nullopt; + } + + return GatherScatterParallelDimSharding{ + HloSharding(NamedSharding( + idx_sharding.named_sharding().mesh(), std::move(new_indices_dims), + std::move(indices_rep), idx_sharding.named_sharding().manual_axes())), + HloSharding(NamedSharding( + op_sharding.named_sharding().mesh(), std::move(new_operand_dims), + std::move(operand_rep), op_sharding.named_sharding().manual_axes()))}; +} + std::optional GatherScatterOperandsShardedAcrossParallelDims( const HloInstruction& operand, const HloInstruction& indices, @@ -2670,8 +2872,27 @@ GatherScatterOperandsShardedAcrossParallelDims( if (indices_parallel_dims.size() != operand_parallel_dims.size()) { return std::nullopt; } - auto new_index_shard = indices.sharding(); - auto new_operand_shard = operand.sharding(); + const HloSharding& idx_sharding = indices.sharding(); + const HloSharding& op_sharding = operand.sharding(); + + bool indices_v3 = idx_sharding.UseNamedShardingLeaf(); + bool operand_v3 = op_sharding.UseNamedShardingLeaf(); + + if (indices_v3 && operand_v3) { + auto res = GatherScatterOperandsShardedAcrossParallelDimsNamedSharding( + operand, indices, parallel_dims); + if (res) { + return res; + } + } + + HloSharding new_index_shard = + indices_v3 ? HloSharding::V3ToV2Sharding(idx_sharding.named_sharding()) + : idx_sharding; + HloSharding new_operand_shard = + operand_v3 ? 
HloSharding::V3ToV2Sharding(op_sharding.named_sharding()) + : op_sharding; + int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims); int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims); if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) { diff --git a/xla/service/spmd/spmd_partitioner_util_test.cc b/xla/service/spmd/spmd_partitioner_util_test.cc index c797a70a5772a..f3c5c71c58d69 100644 --- a/xla/service/spmd/spmd_partitioner_util_test.cc +++ b/xla/service/spmd/spmd_partitioner_util_test.cc @@ -16,22 +16,32 @@ limitations under the License. #include "xla/service/spmd/spmd_partitioner_util.h" #include +#include #include #include #include #include +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/mesh_and_axis.h" #include "xla/hlo/ir/named_sharding.h" #include "xla/hlo/ir/replica_group.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/service/spmd/spmd_partitioner_util_internal.h" +#include "xla/shape.h" namespace xla { namespace spmd { namespace { +std::unique_ptr CreateParameterWithSharding( + int64_t index, const Shape& shape, const NamedSharding& sharding) { + auto inst = HloInstruction::CreateParameter(index, shape, "param"); + inst->set_sharding(HloSharding(sharding)); + return inst; +} + TEST(SPMDPartitionerUtilTest, PartialReplicateReshardCompatibleSharding1) { HloSharding partial_sharding = HloSharding::PartialTile(TileAssignment({1, 2, 2})); @@ -661,6 +671,91 @@ TEST(SPMDPartitionerUtilTest, } } +TEST(SPMDPartitionerUtilTest, + GatherScatterOperandsShardedAcrossParallelDimsNamedSharding) { + Mesh mesh({2, 2}, {"a", "b"}); + NamedSharding indices_named = test_utils::FromAxisNames(mesh, {{"a"}}); + NamedSharding operand_named = test_utils::FromAxisNames(mesh, {{}}); + + auto result = GatherScatterOperandsShardedAcrossParallelDims( + *CreateParameterWithSharding(1, Shape(xla::PrimitiveType::F32, {20}), + operand_named), + 
*CreateParameterWithSharding(0, Shape(xla::PrimitiveType::F32, {10}), + indices_named), + /*parallel_dims=*/{{0}, {0}}); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->indices_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); + EXPECT_EQ(result->operand_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); +} + +TEST(SPMDPartitionerUtilTest, + GatherScatterOperandsShardedAcrossParallelDimsNamedShardingSubAxes) { + Mesh mesh({4, 2}, {"a", "b"}); + NamedSharding indices_named = test_utils::FromAxisNames(mesh, {{"a:(1)2"}}); + NamedSharding operand_named = test_utils::FromAxisNames(mesh, {{}}, {"a"}); + + auto result = GatherScatterOperandsShardedAcrossParallelDims( + *CreateParameterWithSharding(1, Shape(xla::PrimitiveType::F32, {20}), + operand_named), + *CreateParameterWithSharding(0, Shape(xla::PrimitiveType::F32, {10}), + indices_named), + /*parallel_dims=*/{{0}, {0}}); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->indices_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); + EXPECT_EQ(result->operand_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); + EXPECT_EQ( + result->operand_sharding.named_sharding().replicated_axes(), + test_utils::FromAxisNames(mesh, {{}}, {"a:(2)2"}).replicated_axes()); +} + +TEST(SPMDPartitionerUtilTest, + GatherScatterOperandsShardedAcrossParallelDimsNamedShardingExplicit) { + Mesh mesh({2, 2}, {"a", "b"}); + NamedSharding indices_named = test_utils::FromAxisNames(mesh, {{"a"}}); + NamedSharding operand_named = test_utils::FromAxisNames(mesh, {{}}, {"a"}); + + auto result = GatherScatterOperandsShardedAcrossParallelDims( + *CreateParameterWithSharding(1, Shape(xla::PrimitiveType::F32, {20}), + operand_named), + *CreateParameterWithSharding(0, Shape(xla::PrimitiveType::F32, {10}), + indices_named), + /*parallel_dims=*/{{0}, {0}}); + + 
ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->operand_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); + EXPECT_TRUE( + result->operand_sharding.named_sharding().replicated_axes().empty()); +} + +TEST(SPMDPartitionerUtilTest, + GatherScatterOperandsShardedAcrossParallelDimsNamedShardingMixed) { + Mesh mesh({2, 2}, {"a", "b"}); + NamedSharding indices_named = test_utils::FromAxisNames(mesh, {{"a"}, {}}); + NamedSharding operand_named = + test_utils::FromAxisNames(mesh, {{}, {"b"}}, {"a"}); + + auto result = GatherScatterOperandsShardedAcrossParallelDims( + *CreateParameterWithSharding(1, Shape(xla::PrimitiveType::F32, {20, 20}), + operand_named), + *CreateParameterWithSharding(0, Shape(xla::PrimitiveType::F32, {10, 10}), + indices_named), + /*parallel_dims=*/{{0, 1}, {0, 1}}); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->operand_sharding.named_sharding().dim_sharding(0).axes(), + test_utils::FromAxisNames(mesh, {{"a"}}).dim_sharding(0).axes()); + EXPECT_EQ( + result->indices_sharding.named_sharding().dim_sharding(1).axes(), + test_utils::FromAxisNames(mesh, {{}, {"b"}}).dim_sharding(1).axes()); +} + } // namespace } // namespace spmd } // namespace xla