diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc index f15d2bdbb92ec..6de57d6209781 100644 --- a/xla/hlo/utils/hlo_sharding_util.cc +++ b/xla/hlo/utils/hlo_sharding_util.cc @@ -277,7 +277,12 @@ bool IsSubTilingOrEqualNamedSharding(const Shape& potential_sharded_shape, !sharding.manual_axes().empty()) { return false; } - CHECK(sub_mesh.DeviceAssignmentEquals(mesh)); + if (!sub_mesh.DeviceAssignmentEquals(mesh)) { + return IsSubTilingOrEqualSharding( + potential_sharded_shape, + HloSharding::V3ToV2Sharding(potential_subsharding), + HloSharding::V3ToV2Sharding(sharding)); + } CHECK_EQ(potential_subsharding.num_dimensions(), sharding.num_dimensions()); diff --git a/xla/service/spmd/BUILD b/xla/service/spmd/BUILD index 45181fcd60a53..634101f5cbcf4 100644 --- a/xla/service/spmd/BUILD +++ b/xla/service/spmd/BUILD @@ -440,6 +440,7 @@ xla_cc_test( srcs = ["spmd_partitioner_util_test.cc"], deps = [ ":spmd_partitioner", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/ir:mesh_and_axis", "//xla/hlo/ir:named_sharding", diff --git a/xla/service/spmd/gather_scatter_handler.cc b/xla/service/spmd/gather_scatter_handler.cc index 9181db89c1a14..6da717268aca2 100644 --- a/xla/service/spmd/gather_scatter_handler.cc +++ b/xla/service/spmd/gather_scatter_handler.cc @@ -1544,8 +1544,12 @@ HloInstruction* SelectOperandForScatterIndexPassthroughDimensions( // Update partition_id for partial replicate. auto partition_id = indices.state().partition_id; if (indices.sharding().HasPartialReplication()) { + HloSharding sharding_v2 = + indices.sharding().UseNamedShardingLeaf() + ? HloSharding::V3ToV2Sharding(indices.sharding().named_sharding()) + : indices.sharding(); auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims( - indices.sharding(), {indices.sharding().SubgroupReplicationDim()}); + sharding_v2, {sharding_v2.SubgroupReplicationDim()}); partition_id = GetInGroupPartitionId(partition_id, sharding_grouped.device_groups, b); } @@ -1965,25 +1969,29 @@ absl::Status SpmdPartitioningVisitor::HandleScatterWithoutConflicts( // guaranteed by the scatter semantics. HloSharding indices_sharding = indices.sharding(); for (int64_t i = 0; i < indices.num_dimensions(); ++i) { + int64_t dim_size = + indices_sharding.IsReplicated() ? 1 : indices_sharding.dimension(i); if (indices.base_shape().dimensions(i) != - indices_sharding.dimension(i) * indices.hlo()->shape().dimensions(i)) { + dim_size * indices.hlo()->shape().dimensions(i)) { indices = indices.Replicate().Reshard( indices_sharding, /*pad_value=*/LiteralUtil::CreateR0(-1)); break; } } - SpmdBuilder* b = builder(); - std::vector partitioned_inserted_window_dims; for (int64_t dim : dnums.inserted_window_dims()) { - if (operands[0].sharding().dimension(dim) > 1) { + int64_t dim_size = operands[0].sharding().IsReplicated() + ? 1 + : operands[0].sharding().dimension(dim); + if (dim_size > 1) { // TODO(b/496605332). inserted_window_dims may not be in // scatter_dims_to_operand_dims. CHECK(absl::c_linear_search(dnums.scatter_dims_to_operand_dims(), dim)); partitioned_inserted_window_dims.push_back(dim); } } + SpmdBuilder* b = builder(); if (!partitioned_inserted_window_dims.empty()) { HloInstruction* indices_min; std::tie(indices_min, std::ignore) = @@ -2125,8 +2133,10 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { // Reshard indices with -1 padding, which will have no effect on the result as // guaranteed by the scatter semantics. for (auto i = 0; i != indices.num_dimensions(); ++i) { + int64_t dim_size = + indices_sharding.IsReplicated() ? 1 : indices_sharding.dimension(i); if (indices.base_shape().dimensions(i) != - indices_sharding.dimension(i) * indices.hlo()->shape().dimensions(i)) { + dim_size * indices.hlo()->shape().dimensions(i)) { // Reshard only when we know that some dimension is padded. indices = indices.Replicate().Reshard( indices_sharding, /*pad_value=*/LiteralUtil::CreateR0(-1)); diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index 2a3fdb179c700..04d324eba497f 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -2167,11 +2167,23 @@ std::optional PartitionedHlo::TryComplexReshardHandling( const HloSharding& target) const { VLOG(5) << "Trying to split complicated reshard: " << sharding().ToString() << " to " << target.ToString(); + + // TODO(b/498118846): Remove V2 conversion and support V3 directly. + HloSharding source_v2 = + sharding().UseNamedShardingLeaf() + ? HloSharding::V3ToV2Sharding(sharding().named_sharding()) + : sharding(); + HloSharding target_v2 = + target.UseNamedShardingLeaf() + ? HloSharding::V3ToV2Sharding(target.named_sharding()) + : target; + const bool is_source_partially_replicated = - sharding().ReplicateOnLastTileDim(); - const bool is_target_partially_replicated = target.ReplicateOnLastTileDim(); + source_v2.ReplicateOnLastTileDim(); + const bool is_target_partially_replicated = + target_v2.ReplicateOnLastTileDim(); if (auto reshape = PatternMatchMergeOrSplitSharding(this->base_shape(), - sharding(), target)) { + source_v2, target_v2)) { auto& [before_sharding, new_reshaped_sharding, source_dim] = *reshape; PartitionedHlo reshaped = SplitReshapeHelper( *this, source_dim, this->hlo()->shape().dimensions(source_dim), @@ -2199,7 +2211,7 @@ std::optional PartitionedHlo::TryComplexReshardHandling( return reshaped; } if (auto intermediate_target = - PatternMatchPartiallyReplicateDim(sharding(), target)) { + PatternMatchPartiallyReplicateDim(source_v2, target_v2)) { VLOG(5) << "Matched \"pattern_match_partially_replicate_dim()\": " << intermediate_target->ToString(); auto intermediate_reshard = Reshard(*intermediate_target); @@ -2211,17 +2223,17 @@ std::optional PartitionedHlo::TryComplexReshardHandling( return final_reshard; } if (is_source_partially_replicated && !is_target_partially_replicated) { - const int64_t partial_repl_amount = sharding().dimensions().back(); + const int64_t partial_repl_amount = source_v2.dimensions().back(); int64_t first_different_dimension = -1; // Trying to match conditions like [..,X,..,Z,..,Y] last_tile_dim_replicate // to [..,Y,..,Z,..,X,..], where Y in the source is partially replicated, // but in the target it is not and some other dimension got moved or // modified. Try to remove the partial replication to simplify the step from // source to target sharding. - for (int64_t i = 0; i < target.num_dimensions(); ++i) { - if (target.dimension(i) != sharding().dimension(i) && - sharding().dimension(i) == 1 && - target.dimension(i) % partial_repl_amount == 0) { + for (int64_t i = 0; i < target_v2.num_dimensions(); ++i) { + if (target_v2.dimension(i) != source_v2.dimension(i) && + source_v2.dimension(i) == 1 && + target_v2.dimension(i) % partial_repl_amount == 0) { first_different_dimension = i; break; } @@ -2230,12 +2242,12 @@ std::optional PartitionedHlo::TryComplexReshardHandling( return std::nullopt; } VLOG(5) << "Matched partially replicated to non partially replicated: " - << sharding().ToString(); - std::vector transpose_dims(sharding().num_dimensions(), 0); + << source_v2.ToString(); + std::vector transpose_dims(source_v2.num_dimensions(), 0); absl::c_iota(transpose_dims, 0); std::swap(transpose_dims[first_different_dimension], transpose_dims.back()); auto intermediate_sharding = - hlo_sharding_util::TransposeSharding(sharding(), transpose_dims); + hlo_sharding_util::TransposeSharding(source_v2, transpose_dims); auto intermediate_reshard = Reshard(intermediate_sharding); auto reshard = intermediate_reshard.ReshardNoCache( target, /*pad_value=*/std::nullopt, /*allow_full_replication=*/false); diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 500346b73e7d7..1de06a23cf7aa 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -1771,7 +1771,7 @@ ENTRY entry { EXPECT_EQ(root->operand(0)->window().dimensions(0).padding_high(), 0); } -TEST_P(SpmdPartitioningTest, SelectAndScatterNoOverlap) { +TEST_P(SpmdPartitioningAllShardingTest, SelectAndScatterNoOverlap) { absl::string_view hlo_string = R"( HloModule module @@ -1821,7 +1821,7 @@ ENTRY entry { EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); } -TEST_P(SpmdPartitioningTest, SelectAndScatterNoOverlapReshard) { +TEST_P(SpmdPartitioningAllShardingTest, SelectAndScatterNoOverlapReshard) { absl::string_view hlo_string = R"( HloModule module @@ -1875,7 +1875,7 @@ ENTRY entry { EXPECT_EQ(root->window().dimensions(0).padding_high(), 0); } -TEST_P(SpmdPartitioningTest, SelectAndScatterWithOverlap) { +TEST_P(SpmdPartitioningAllShardingTest, SelectAndScatterWithOverlap) { absl::string_view hlo_string = R"( HloModule module @@ -5250,7 +5250,7 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]"))); } -TEST_P(SpmdPartitioningTest, SelectAndScatter_RetinaNet) { +TEST_P(SpmdPartitioningAllShardingTest, SelectAndScatter_RetinaNet) { absl::string_view hlo_string = R"( HloModule module @@ -9028,7 +9028,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, UnpartitionedScatter) { +TEST_P(SpmdPartitioningAllShardingTest, UnpartitionedScatter) { absl::string_view hlo_string = R"( HloModule module @@ -9062,7 +9062,7 @@ ENTRY entry { op::Shape("f32[2,5]"))); } -TEST_P(SpmdPartitioningTest, VariadicScatter) { +TEST_P(SpmdPartitioningAllShardingTest, VariadicScatter) { absl::string_view hlo_string = R"( HloModule module @@ -9108,7 +9108,7 @@ ENTRY entry { op::Shape("(f32[2,3],f32[2,3])"))); } -TEST_P(SpmdPartitioningTest, VariadicScatterSharedOperands) { +TEST_P(SpmdPartitioningAllShardingTest, VariadicScatterSharedOperands) { absl::string_view hlo_string = R"( HloModule module @@ -9153,7 +9153,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, PassthroughScatter) { +TEST_P(SpmdPartitioningAllShardingTest, PassthroughScatter) { absl::string_view hlo_string = R"( HloModule module @@ -9188,7 +9188,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, PassthroughScatterVariadic) { +TEST_P(SpmdPartitioningAllShardingTest, PassthroughScatterVariadic) { absl::string_view hlo_string = R"( HloModule module @@ -9231,7 +9231,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, PassthroughScatter_PartialReplicate) { +TEST_P(SpmdPartitioningAllShardingTest, PassthroughScatter_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -9269,7 +9269,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, PassthroughScatterVariadic_PartialReplicate) { +TEST_P(SpmdPartitioningAllShardingTest, + PassthroughScatterVariadic_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -9316,7 +9317,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatter) { +TEST_P(SpmdPartitioningAllShardingTest, IndexPassthroughScatter) { absl::string_view hlo_string = R"( HloModule module @@ -9355,7 +9356,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatter_PartialReplicate) { +TEST_P(SpmdPartitioningAllShardingTest, + IndexPassthroughScatter_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -9396,7 +9398,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatterPartitionedIndexVectorDim) { +TEST_P(SpmdPartitioningAllShardingTest, + IndexPassthroughScatterPartitionedIndexVectorDim) { absl::string_view hlo_string = R"( HloModule module @@ -9430,7 +9433,7 @@ ENTRY entry { EXPECT_THAT(root, op::AllReduce(op::AllReduce(scatter))); } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatterReshardIndices) { +TEST_P(SpmdPartitioningAllShardingTest, IndexPassthroughScatterReshardIndices) { absl::string_view hlo_string = R"( HloModule module @@ -9463,7 +9466,7 @@ ENTRY entry { op::AllReduce(op::AllReduce(scatter))); } -TEST_P(SpmdPartitioningTest, IndexPassthroughScatter_Min) { +TEST_P(SpmdPartitioningAllShardingTest, IndexPassthroughScatter_Min) { absl::string_view hlo_string = R"( HloModule module @@ -9502,7 +9505,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterExplicitBatchDims) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterExplicitBatchDims) { absl::string_view hlo_string = R"( HloModule module @@ -9538,7 +9541,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndOperandPassthroughDims) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterExplicitBatchAndOperandPassthroughDims) { absl::string_view hlo_string = R"( HloModule module @@ -9574,7 +9578,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndIndexPassthroughDims1) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterExplicitBatchAndIndexPassthroughDims1) { absl::string_view hlo_string = R"( HloModule module @@ -9611,7 +9616,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndIndexPassthroughDims2) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterExplicitBatchAndIndexPassthroughDims2) { absl::string_view hlo_string = R"( HloModule module @@ -9650,7 +9656,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDims) { absl::string_view hlo_string = R"( HloModule module @@ -9688,7 +9694,8 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDimsVariadic) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterPartitionedOnTrivialSliceDimsVariadic) { absl::string_view hlo_string = R"( HloModule module @@ -9734,7 +9741,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDims_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -9776,7 +9783,7 @@ ENTRY entry { } } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDimsVariadic_PartialReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -12392,7 +12399,8 @@ ENTRY entry { EXPECT_THAT(root, op::Tuple(op::AllGather(gather), _, _)); } -TEST_P(SpmdPartitioningTest, ScatterRepsOnLastTileDimDontDivideGroups) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterRepsOnLastTileDimDontDivideGroups) { absl::string_view hlo_string = R"( HloModule module @@ -12435,7 +12443,8 @@ ENTRY entry { EXPECT_THAT(module->entry_computation()->root_instruction(), scatter); } -TEST_P(SpmdPartitioningTest, ParallelDimFromOutsideConditionalPositive) { +TEST_P(SpmdPartitioningAllShardingTest, + ParallelDimFromOutsideConditionalPositive) { absl::string_view hlo_string = R"( HloModule module @@ -12695,7 +12704,8 @@ ENTRY %main.14 (Arg_0.1: s32[4,32], Arg_1.2: s32[4]) -> s32[4] { EXPECT_THAT(root, op::AllReduce(op::Select(_, _, gather))); } -TEST_P(SpmdPartitioningTest, GatherMergedIndexParallelAndIndexPassthrough) { +TEST_P(SpmdPartitioningAllShardingTest, + GatherMergedIndexParallelAndIndexPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -12976,7 +12986,7 @@ ENTRY %module { } // Tests for Gather partitioning with SPMD config option. -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, GatherPartitionedOnTrivialSliceDimsForceTrivialSlice) { absl::string_view hlo_string = R"( HloModule module @@ -13007,7 +13017,7 @@ ENTRY entry { EXPECT_THAT(collective_permute, nullptr); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, GatherPartitionedOnTrivialSliceDimsForceIndexParallel) { absl::string_view hlo_string = R"( HloModule module @@ -13085,7 +13095,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimReplicatedIndices) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterParallelDimReplicatedIndices) { absl::string_view hlo_string = R"( HloModule module @@ -13129,7 +13139,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimReplicatedOperand) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterParallelDimReplicatedOperand) { absl::string_view hlo_string = R"( HloModule module @@ -13172,7 +13182,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimReplicatedUpdate) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterParallelDimReplicatedUpdate) { absl::string_view hlo_string = R"( HloModule module @@ -13215,7 +13225,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimPartialReplicatedIndices) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimPartialReplicatedIndices) { absl::string_view hlo_string = R"( HloModule module @@ -13259,7 +13270,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimPartialReplicatedOperand) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimPartialReplicatedOperand) { absl::string_view hlo_string = R"( HloModule module @@ -13303,7 +13315,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimPartialReplicatedUpdate) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimPartialReplicatedUpdate) { absl::string_view hlo_string = R"( HloModule module @@ -13347,7 +13360,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(scatter)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimSwappedDimensions) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterParallelDimSwappedDimensions) { absl::string_view hlo_string = R"( HloModule module @@ -13391,7 +13404,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(scatter))); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimFromOutsideWhilePositive) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimFromOutsideWhilePositive) { absl::string_view hlo_string = R"( HloModule module @@ -13469,7 +13483,8 @@ ENTRY entry { EXPECT_THAT(root, op::Tuple(op::AllGather(scatter), _, _, _)); } -TEST_P(SpmdPartitioningTest, ScatterParallelDimAndNonParallelDimPartitioned) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterParallelDimAndNonParallelDimPartitioned) { absl::string_view hlo_string = R"( HloModule module @@ -13518,7 +13533,8 @@ ENTRY %module { op::DynamicSlice(op::AllReduce(scatter), _, _, _, _)))); } -TEST_P(SpmdPartitioningTest, GatherScatterPartitioningIndexParallelCase) { +TEST_P(SpmdPartitioningAllShardingTest, + GatherScatterPartitioningIndexParallelCase) { absl::string_view hlo_string = R"( HloModule jit__init @@ -13545,12 +13561,13 @@ ENTRY main.22 { const auto root = module->entry_computation()->root_instruction(); auto operand = AllOf(op::Shape("f32[16,2]"), op::Broadcast()); auto indices = AllOf(op::Shape("s32[8,2]"), op::Subtract()); - auto update = AllOf(op::Shape("f32[8]"), op::Broadcast(op::Constant())); + auto update = AllOf(op::Shape("f32[8]"), op::Broadcast()); EXPECT_THAT(root, AllOf(op::Shape("f32[16,2]"), op::Scatter(operand, indices, update))); } -TEST_P(SpmdPartitioningTest, ScatterMergedIndexParallelAndOperandPassthrough) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterMergedIndexParallelAndOperandPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -13594,7 +13611,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(scatter))); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedIndexParallelAndTrivialSlicedOperand) { absl::string_view hlo_string = R"( HloModule module @@ -13639,7 +13656,8 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(scatter))); } -TEST_P(SpmdPartitioningTest, ScatterMergedIndexParallelAndIndexPassthrough) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterMergedIndexParallelAndIndexPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -13691,7 +13709,7 @@ ENTRY %module { } } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedOperandPassthroughAndTrivialSlicedOperand) { absl::string_view hlo_string = R"( HloModule module @@ -13731,7 +13749,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(op::AllGather(scatter)))); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedOperandPassthroughAndIndexPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -13771,7 +13789,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllReduce(scatter))); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedOperandPassthroughAndIndexPassthrough_PartialGrouping) { absl::string_view hlo_string = R"( HloModule module @@ -13812,7 +13830,7 @@ ENTRY %module { _, _, _)); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedTrivialSlicedOperandAndIndexPassthrough) { absl::string_view hlo_string = R"( HloModule module @@ -13852,7 +13870,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllGather(op::AllReduce(scatter)))); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterMergedTrivialSlicedOperandAndIndexPassthrough_PartialGrouping) { absl::string_view hlo_string = R"( HloModule module @@ -13892,7 +13910,7 @@ ENTRY %module { EXPECT_THAT(root, op::AllGather(op::AllReduce(op::AllReduce(scatter)))); } -TEST_P(SpmdPartitioningTest, ScatterTrivialSlicedOperandPartial) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterTrivialSlicedOperandPartial) { absl::string_view hlo_string = R"( HloModule module @@ -13927,7 +13945,7 @@ ENTRY main.4 { } // Tests for scatter partitioning methods with SPMD config option. -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDimsForceTrivialSlice) { absl::string_view hlo_string = R"( HloModule module @@ -13966,7 +13984,7 @@ ENTRY entry { EXPECT_THAT(collective_permute, nullptr); } -TEST_P(SpmdPartitioningTest, +TEST_P(SpmdPartitioningAllShardingTest, ScatterPartitionedOnTrivialSliceDimsForceIndexParallel) { absl::string_view hlo_string = R"( HloModule module @@ -14423,7 +14441,8 @@ ENTRY entry { op::Shape("f32[32,16,24,512]"))); } -TEST_P(SpmdPartitioningTest, PartitionPassthroughScatterCorrectOutputSharding) { +TEST_P(SpmdPartitioningAllShardingTest, + PartitionPassthroughScatterCorrectOutputSharding) { absl::string_view hlo_string = R"( HloModule module @@ -14865,7 +14884,7 @@ ENTRY entry { EXPECT_EQ(root->replica_groups().size(), 2); } -TEST_P(SpmdPartitioningTest, ScatterPreferUpdateIndexIfSmaller) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterPreferUpdateIndexIfSmaller) { absl::string_view hlo_string = R"( HloModule module @@ -14905,7 +14924,8 @@ ENTRY entry { op::Shape("bf16[512,1024,1020]"))))))); } -TEST_P(SpmdPartitioningTest, ScatterPreferTrivialIfSmallerThanIndices) { +TEST_P(SpmdPartitioningAllShardingTest, + ScatterPreferTrivialIfSmallerThanIndices) { absl::string_view hlo_string = R"( HloModule module @@ -14944,6 +14964,8 @@ ENTRY entry { op::Shape("bf16[32,256]")))))); } +// TODO(b/498965034): When device assignment is supported for v3 conversion, +// enable this test for V3. TEST_P(SpmdPartitioningTest, GatherOperandPassthroughIndexPassthrough) { const char* const hlo_string = R"( HloModule module @@ -15802,7 +15824,7 @@ ENTRY %main.21 { op::GetTupleElement(op::Reduce(_, _, _, _))); } -TEST_P(SpmdPartitioningTest, CombiningScatterPartitiong) { +TEST_P(SpmdPartitioningAllShardingTest, CombiningScatterPartitioning) { const char* const hlo_string = R"( HloModule pjit @@ -15906,7 +15928,7 @@ ENTRY %main.21 { EXPECT_THAT(gather, op::Shape("bf16[4096,32]")); } -TEST_P(SpmdPartitioningTest, ScatterCostModelForUnmatchedSharding) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterCostModelForUnmatchedSharding) { const char* const hlo_string = R"( HloModule pjit @@ -15935,7 +15957,7 @@ ENTRY %main.21 { EXPECT_THAT(updates, op::Shape("bf16[4096,64]")); } -TEST_P(SpmdPartitioningTest, ScatterAllOperandsAreSameInstruction) { +TEST_P(SpmdPartitioningAllShardingTest, ScatterAllOperandsAreSameInstruction) { const char* const hlo_string = R"( HloModule pjit diff --git a/xla/service/spmd/spmd_partitioner_util.cc b/xla/service/spmd/spmd_partitioner_util.cc index 83b11466aedf0..e8d4fda17a9fc 100644 --- a/xla/service/spmd/spmd_partitioner_util.cc +++ b/xla/service/spmd/spmd_partitioner_util.cc @@ -2661,6 +2661,208 @@ HloSharding CreateMatchingShardingOnDims( target_dims, source_sharding, source_dims); } +namespace { + +bool IsAxisSuperset(const AxisRef& available, const AxisRef& requested) { + // If the mesh axes are not the same, they are not comparable. + if (available.mesh_axis_index() != requested.mesh_axis_index()) { + return false; + } + // If the available axis does not have sub-axis info, it represents the full + // mesh axis and is therefore a superset of any requested sub-axis. + if (!available.sub_axis_info().has_value()) { + return true; + } + // If the requested axis does not have sub-axis info, it represents the full + // mesh axis and is therefore not a subset of any requested sub-axis. + if (!requested.sub_axis_info().has_value()) { + return false; + } + // The available axis must contain the requested sub-axis. + return available.sub_axis_info()->pre_size <= + requested.sub_axis_info()->pre_size && + requested.sub_axis_info()->next_pre_size() <= + available.sub_axis_info()->next_pre_size(); +} +bool IsAxisUsedInSharding( + const AxisRef& axis, int64_t target_dim, + absl::Span dim_shardings) { + for (int64_t i = 0; i < dim_shardings.size(); ++i) { + if (i == target_dim) { + continue; + } + for (const AxisRef& other_axis : dim_shardings[i].axes()) { + if (axis.Overlaps(other_axis)) { + return true; + } + } + } + return false; +} + +void PropagateDimensionSharding( + const NamedSharding::DimensionSharding& source_dim_sharding, + int64_t target_dim, const Mesh& mesh, + std::vector* target_dim_shardings, + std::vector* target_replicated_axes) { + std::vector to_remove; + std::vector to_add; + + for (const AxisRef& axis : source_dim_sharding.axes()) { + for (const AxisRef& replicated_axis : *target_replicated_axes) { + if (replicated_axis.mesh_axis_index() != axis.mesh_axis_index()) { + continue; + } + if (!replicated_axis.sub_axis_info().has_value() && + axis.sub_axis_info().has_value()) { + to_remove.push_back(replicated_axis); + int64_t mesh_size = mesh.axis_size(axis.mesh_axis_index()); + int64_t p = axis.sub_axis_info()->pre_size; + int64_t s = axis.sub_axis_info()->size; + + if (p > 1) { + to_add.push_back(AxisRef(axis.mesh_axis_index(), {1, p})); + } + int64_t next_p = p * s; + int64_t rem_size = mesh_size / next_p; + if (rem_size > 1) { + to_add.push_back(AxisRef(axis.mesh_axis_index(), {next_p, rem_size})); + } + } else if (IsAxisSuperset(axis, replicated_axis)) { + to_remove.push_back(replicated_axis); + } + } + } + + (*target_dim_shardings)[target_dim] = source_dim_sharding; + std::erase_if(*target_replicated_axes, + [&to_remove](const auto& replicated_axis) { + return absl::c_linear_search(to_remove, replicated_axis); + }); + target_replicated_axes->insert(target_replicated_axes->end(), to_add.begin(), + to_add.end()); + SortAndMergeAxes(*target_replicated_axes, mesh); +} + +} // namespace + +std::optional +GatherScatterOperandsShardedAcrossParallelDimsNamedSharding( + const HloInstruction& operand, const HloInstruction& indices, + const hlo_sharding_util::GatherScatterDims& parallel_dims) { + const hlo_sharding_util::GatherScatterDims& dims = parallel_dims; + const DimensionVector& indices_parallel_dims = dims.indices_dims; + const DimensionVector& operand_parallel_dims = dims.operand_dims; + + const HloSharding& idx_sharding = indices.sharding(); + const HloSharding& op_sharding = operand.sharding(); + + if (idx_sharding.named_sharding().mesh() != + op_sharding.named_sharding().mesh()) { + return std::nullopt; + } + + absl::Span idx_dim_shardings = + idx_sharding.named_sharding().dim_shardings(); + absl::Span op_dim_shardings = + op_sharding.named_sharding().dim_shardings(); + + std::vector new_indices_dims; + std::vector new_operand_dims; + std::vector indices_rep; + std::vector operand_rep; + + bool changed = false; + + // Lazily initialize the new shardings only when we find a dimension that + // needs propagation. + auto ensure_initialized = [&]() { + if (changed) { + return; + } + changed = true; + if (idx_sharding.named_sharding().IsReplicated()) { + new_indices_dims.resize(indices.shape().dimensions().size()); + } else { + new_indices_dims.assign(idx_dim_shardings.begin(), + idx_dim_shardings.end()); + } + if (op_sharding.named_sharding().IsReplicated()) { + new_operand_dims.resize(operand.shape().dimensions().size()); + } else { + new_operand_dims.assign(op_dim_shardings.begin(), op_dim_shardings.end()); + } + indices_rep.assign(idx_sharding.named_sharding().replicated_axes().begin(), + idx_sharding.named_sharding().replicated_axes().end()); + operand_rep.assign(op_sharding.named_sharding().replicated_axes().begin(), + op_sharding.named_sharding().replicated_axes().end()); + }; + + for (int64_t i = 0; i < indices_parallel_dims.size(); ++i) { + int64_t idx_dim = indices_parallel_dims[i]; + int64_t op_dim = operand_parallel_dims[i]; + + bool idx_replicated = idx_sharding.named_sharding().IsReplicated() || + idx_dim >= idx_dim_shardings.size() || + idx_dim_shardings[idx_dim].axes().empty(); + bool op_replicated = op_sharding.named_sharding().IsReplicated() || + op_dim >= op_dim_shardings.size() || + op_dim_shardings[op_dim].axes().empty(); + + if (idx_replicated) { + if (!op_replicated) { + bool conflict = false; + for (const AxisRef& axis : op_dim_shardings[op_dim].axes()) { + if (IsAxisUsedInSharding(axis, idx_dim, idx_dim_shardings)) { + conflict = true; + break; + } + } + if (conflict) { + return std::nullopt; + } + ensure_initialized(); + PropagateDimensionSharding(op_dim_shardings[op_dim], idx_dim, + idx_sharding.named_sharding().mesh(), + &new_indices_dims, &indices_rep); + } + } else if (op_replicated) { + bool conflict = false; + for (const AxisRef& axis : idx_dim_shardings[idx_dim].axes()) { + if (IsAxisUsedInSharding(axis, op_dim, op_dim_shardings)) { + conflict = true; + break; + } + } + if (conflict) { + return std::nullopt; + } + ensure_initialized(); + PropagateDimensionSharding(idx_dim_shardings[idx_dim], op_dim, + op_sharding.named_sharding().mesh(), + &new_operand_dims, &operand_rep); + } else { + absl::Span idx_axes = idx_dim_shardings[idx_dim].axes(); + absl::Span op_axes = op_dim_shardings[op_dim].axes(); + if (idx_axes != op_axes) { + return std::nullopt; + } + } + } + + if (!changed) { + return std::nullopt; + } + + return GatherScatterParallelDimSharding{ + HloSharding(NamedSharding( + idx_sharding.named_sharding().mesh(), std::move(new_indices_dims), + std::move(indices_rep), idx_sharding.named_sharding().manual_axes())), + HloSharding(NamedSharding( + op_sharding.named_sharding().mesh(), std::move(new_operand_dims), + std::move(operand_rep), op_sharding.named_sharding().manual_axes()))}; +} + std::optional GatherScatterOperandsShardedAcrossParallelDims( const HloInstruction& operand, const HloInstruction& indices, @@ -2670,8 +2872,27 @@ GatherScatterOperandsShardedAcrossParallelDims( if (indices_parallel_dims.size() != operand_parallel_dims.size()) { return std::nullopt; } - auto new_index_shard = indices.sharding(); - auto new_operand_shard = operand.sharding(); + const HloSharding& idx_sharding = indices.sharding(); + const HloSharding& op_sharding = operand.sharding(); + + bool indices_v3 = idx_sharding.UseNamedShardingLeaf(); + bool operand_v3 = op_sharding.UseNamedShardingLeaf(); + + if (indices_v3 && operand_v3) { + auto res = GatherScatterOperandsShardedAcrossParallelDimsNamedSharding( + operand, indices, parallel_dims); + if (res) { + return res; + } + } + + HloSharding new_index_shard = + indices_v3 ? HloSharding::V3ToV2Sharding(idx_sharding.named_sharding()) + : idx_sharding; + HloSharding new_operand_shard = + operand_v3 ? HloSharding::V3ToV2Sharding(op_sharding.named_sharding()) + : op_sharding; + int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims); int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims); if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) { diff --git a/xla/service/spmd/spmd_partitioner_util_test.cc b/xla/service/spmd/spmd_partitioner_util_test.cc index c797a70a5772a..f3c5c71c58d69 100644 --- a/xla/service/spmd/spmd_partitioner_util_test.cc +++ b/xla/service/spmd/spmd_partitioner_util_test.cc @@ -16,22 +16,32 @@ limitations under the License. #include "xla/service/spmd/spmd_partitioner_util.h" #include +#include #include #include #include #include +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/mesh_and_axis.h" #include "xla/hlo/ir/named_sharding.h" #include "xla/hlo/ir/replica_group.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/service/spmd/spmd_partitioner_util_internal.h" +#include "xla/shape.h" namespace xla { namespace spmd { namespace { +std::unique_ptr CreateParameterWithSharding( + int64_t index, const Shape& shape, const NamedSharding& sharding) { + auto inst = HloInstruction::CreateParameter(index, shape, "param"); + inst->set_sharding(HloSharding(sharding)); + return inst; +} + TEST(SPMDPartitionerUtilTest, PartialReplicateReshardCompatibleSharding1) { HloSharding partial_sharding = HloSharding::PartialTile(TileAssignment({1, 2, 2})); @@ -661,6 +671,91 @@ TEST(SPMDPartitionerUtilTest, } } +TEST(SPMDPartitionerUtilTest, + GatherScatterOperandsShardedAcrossParallelDimsNamedSharding) { + Mesh mesh({2, 2}, {"a", "b"}); + NamedSharding indices_named = test_utils::FromAxisNames(mesh, {{"a"}}); + NamedSharding operand_named = test_utils::FromAxisNames(mesh, {{}}); + + auto result = GatherScatterOperandsShardedAcrossParallelDims( + *CreateParameterWithSharding(1, Shape(xla::PrimitiveType::F32, {20}), + operand_named), + *CreateParameterWithSharding(0, Shape(xla::PrimitiveType::F32, {10}), + indices_named), + /*parallel_dims=*/{{0}, {0}}); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->indices_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); + EXPECT_EQ(result->operand_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); +} + +TEST(SPMDPartitionerUtilTest, + GatherScatterOperandsShardedAcrossParallelDimsNamedShardingSubAxes) { + Mesh mesh({4, 2}, {"a", "b"}); + NamedSharding indices_named = test_utils::FromAxisNames(mesh, {{"a:(1)2"}}); + NamedSharding operand_named = test_utils::FromAxisNames(mesh, {{}}, {"a"}); + + auto result = GatherScatterOperandsShardedAcrossParallelDims( + *CreateParameterWithSharding(1, Shape(xla::PrimitiveType::F32, {20}), + operand_named), + *CreateParameterWithSharding(0, Shape(xla::PrimitiveType::F32, {10}), + indices_named), + /*parallel_dims=*/{{0}, {0}}); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->indices_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); + EXPECT_EQ(result->operand_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); + EXPECT_EQ( + result->operand_sharding.named_sharding().replicated_axes(), + test_utils::FromAxisNames(mesh, {{}}, {"a:(2)2"}).replicated_axes()); +} + +TEST(SPMDPartitionerUtilTest, + GatherScatterOperandsShardedAcrossParallelDimsNamedShardingExplicit) { + Mesh mesh({2, 2}, {"a", "b"}); + NamedSharding indices_named = test_utils::FromAxisNames(mesh, {{"a"}}); + NamedSharding operand_named = test_utils::FromAxisNames(mesh, {{}}, {"a"}); + + auto result = GatherScatterOperandsShardedAcrossParallelDims( + *CreateParameterWithSharding(1, Shape(xla::PrimitiveType::F32, {20}), + operand_named), + *CreateParameterWithSharding(0, Shape(xla::PrimitiveType::F32, {10}), + indices_named), + /*parallel_dims=*/{{0}, {0}}); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->operand_sharding.named_sharding().dim_sharding(0).axes(), + indices_named.dim_sharding(0).axes()); + EXPECT_TRUE( + result->operand_sharding.named_sharding().replicated_axes().empty()); +} + +TEST(SPMDPartitionerUtilTest, + GatherScatterOperandsShardedAcrossParallelDimsNamedShardingMixed) { + Mesh mesh({2, 2}, {"a", "b"}); + NamedSharding indices_named = test_utils::FromAxisNames(mesh, {{"a"}, {}}); + NamedSharding operand_named = + test_utils::FromAxisNames(mesh, {{}, {"b"}}, {"a"}); + + auto result = GatherScatterOperandsShardedAcrossParallelDims( + *CreateParameterWithSharding(1, Shape(xla::PrimitiveType::F32, {20, 20}), + operand_named), + *CreateParameterWithSharding(0, Shape(xla::PrimitiveType::F32, {10, 10}), + indices_named), + /*parallel_dims=*/{{0, 1}, {0, 1}}); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result->operand_sharding.named_sharding().dim_sharding(0).axes(), + test_utils::FromAxisNames(mesh, {{"a"}}).dim_sharding(0).axes()); + EXPECT_EQ( + result->indices_sharding.named_sharding().dim_sharding(1).axes(), + test_utils::FromAxisNames(mesh, {{}, {"b"}}).dim_sharding(1).axes()); +} + } // namespace } // namespace spmd } // namespace xla