Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ bool IsSubTilingOrEqualNamedSharding(const Shape& potential_sharded_shape,
!sharding.manual_axes().empty()) {
return false;
}
CHECK(sub_mesh.DeviceAssignmentEquals(mesh));
if (!sub_mesh.DeviceAssignmentEquals(mesh)) {
return IsSubTilingOrEqualSharding(
potential_sharded_shape,
HloSharding::V3ToV2Sharding(potential_subsharding),
HloSharding::V3ToV2Sharding(sharding));
}

CHECK_EQ(potential_subsharding.num_dimensions(), sharding.num_dimensions());

Expand Down
1 change: 1 addition & 0 deletions xla/service/spmd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ xla_cc_test(
srcs = ["spmd_partitioner_util_test.cc"],
deps = [
":spmd_partitioner",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:mesh_and_axis",
"//xla/hlo/ir:named_sharding",
Expand Down
22 changes: 16 additions & 6 deletions xla/service/spmd/gather_scatter_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1544,8 +1544,12 @@ HloInstruction* SelectOperandForScatterIndexPassthroughDimensions(
// Update partition_id for partial replicate.
auto partition_id = indices.state().partition_id;
if (indices.sharding().HasPartialReplication()) {
HloSharding sharding_v2 =
indices.sharding().UseNamedShardingLeaf()
? HloSharding::V3ToV2Sharding(indices.sharding().named_sharding())
: indices.sharding();
auto sharding_grouped = hlo_sharding_util::GroupShardingOnDims(
indices.sharding(), {indices.sharding().SubgroupReplicationDim()});
sharding_v2, {sharding_v2.SubgroupReplicationDim()});
partition_id =
GetInGroupPartitionId(partition_id, sharding_grouped.device_groups, b);
}
Expand Down Expand Up @@ -1965,25 +1969,29 @@ absl::Status SpmdPartitioningVisitor::HandleScatterWithoutConflicts(
// guaranteed by the scatter semantics.
HloSharding indices_sharding = indices.sharding();
for (int64_t i = 0; i < indices.num_dimensions(); ++i) {
int64_t dim_size =
indices_sharding.IsReplicated() ? 1 : indices_sharding.dimension(i);
if (indices.base_shape().dimensions(i) !=
indices_sharding.dimension(i) * indices.hlo()->shape().dimensions(i)) {
dim_size * indices.hlo()->shape().dimensions(i)) {
indices = indices.Replicate().Reshard(
indices_sharding, /*pad_value=*/LiteralUtil::CreateR0<int32_t>(-1));
break;
}
}

SpmdBuilder* b = builder();

std::vector<int64_t> partitioned_inserted_window_dims;
for (int64_t dim : dnums.inserted_window_dims()) {
if (operands[0].sharding().dimension(dim) > 1) {
int64_t dim_size = operands[0].sharding().IsReplicated()
? 1
: operands[0].sharding().dimension(dim);
if (dim_size > 1) {
// TODO(b/496605332). inserted_window_dims may not be in
// scatter_dims_to_operand_dims.
CHECK(absl::c_linear_search(dnums.scatter_dims_to_operand_dims(), dim));
partitioned_inserted_window_dims.push_back(dim);
}
}
SpmdBuilder* b = builder();
if (!partitioned_inserted_window_dims.empty()) {
HloInstruction* indices_min;
std::tie(indices_min, std::ignore) =
Expand Down Expand Up @@ -2125,8 +2133,10 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
// Reshard indices with -1 padding, which will have no effect on the result as
// guaranteed by the scatter semantics.
for (auto i = 0; i != indices.num_dimensions(); ++i) {
int64_t dim_size =
indices_sharding.IsReplicated() ? 1 : indices_sharding.dimension(i);
if (indices.base_shape().dimensions(i) !=
indices_sharding.dimension(i) * indices.hlo()->shape().dimensions(i)) {
dim_size * indices.hlo()->shape().dimensions(i)) {
// Reshard only when we know that some dimension is padded.
indices = indices.Replicate().Reshard(
indices_sharding, /*pad_value=*/LiteralUtil::CreateR0<int32_t>(-1));
Expand Down
36 changes: 24 additions & 12 deletions xla/service/spmd/spmd_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2167,11 +2167,23 @@ std::optional<PartitionedHlo> PartitionedHlo::TryComplexReshardHandling(
const HloSharding& target) const {
VLOG(5) << "Trying to split complicated reshard: " << sharding().ToString()
<< " to " << target.ToString();

// TODO(b/498118846): Remove V2 conversion and support V3 directly.
HloSharding source_v2 =
sharding().UseNamedShardingLeaf()
? HloSharding::V3ToV2Sharding(sharding().named_sharding())
: sharding();
HloSharding target_v2 =
target.UseNamedShardingLeaf()
? HloSharding::V3ToV2Sharding(target.named_sharding())
: target;

const bool is_source_partially_replicated =
sharding().ReplicateOnLastTileDim();
const bool is_target_partially_replicated = target.ReplicateOnLastTileDim();
source_v2.ReplicateOnLastTileDim();
const bool is_target_partially_replicated =
target_v2.ReplicateOnLastTileDim();
if (auto reshape = PatternMatchMergeOrSplitSharding(this->base_shape(),
sharding(), target)) {
source_v2, target_v2)) {
auto& [before_sharding, new_reshaped_sharding, source_dim] = *reshape;
PartitionedHlo reshaped = SplitReshapeHelper(
*this, source_dim, this->hlo()->shape().dimensions(source_dim),
Expand Down Expand Up @@ -2199,7 +2211,7 @@ std::optional<PartitionedHlo> PartitionedHlo::TryComplexReshardHandling(
return reshaped;
}
if (auto intermediate_target =
PatternMatchPartiallyReplicateDim(sharding(), target)) {
PatternMatchPartiallyReplicateDim(source_v2, target_v2)) {
VLOG(5) << "Matched \"pattern_match_partially_replicate_dim()\": "
<< intermediate_target->ToString();
auto intermediate_reshard = Reshard(*intermediate_target);
Expand All @@ -2211,17 +2223,17 @@ std::optional<PartitionedHlo> PartitionedHlo::TryComplexReshardHandling(
return final_reshard;
}
if (is_source_partially_replicated && !is_target_partially_replicated) {
const int64_t partial_repl_amount = sharding().dimensions().back();
const int64_t partial_repl_amount = source_v2.dimensions().back();
int64_t first_different_dimension = -1;
// Trying to match conditions like [..,X,..,Z,..,Y] last_tile_dim_replicate
// to [..,Y,..,Z,..,X,..], where Y in the source is partially replicated,
// but in the target it is not and some other dimension got moved or
// modified. Try to remove the partial replication to simplify the step from
// source to target sharding.
for (int64_t i = 0; i < target.num_dimensions(); ++i) {
if (target.dimension(i) != sharding().dimension(i) &&
sharding().dimension(i) == 1 &&
target.dimension(i) % partial_repl_amount == 0) {
for (int64_t i = 0; i < target_v2.num_dimensions(); ++i) {
if (target_v2.dimension(i) != source_v2.dimension(i) &&
source_v2.dimension(i) == 1 &&
target_v2.dimension(i) % partial_repl_amount == 0) {
first_different_dimension = i;
break;
}
Expand All @@ -2230,12 +2242,12 @@ std::optional<PartitionedHlo> PartitionedHlo::TryComplexReshardHandling(
return std::nullopt;
}
VLOG(5) << "Matched partially replicated to non partially replicated: "
<< sharding().ToString();
std::vector<int64_t> transpose_dims(sharding().num_dimensions(), 0);
<< source_v2.ToString();
std::vector<int64_t> transpose_dims(source_v2.num_dimensions(), 0);
absl::c_iota(transpose_dims, 0);
std::swap(transpose_dims[first_different_dimension], transpose_dims.back());
auto intermediate_sharding =
hlo_sharding_util::TransposeSharding(sharding(), transpose_dims);
hlo_sharding_util::TransposeSharding(source_v2, transpose_dims);
auto intermediate_reshard = Reshard(intermediate_sharding);
auto reshard = intermediate_reshard.ReshardNoCache(
target, /*pad_value=*/std::nullopt, /*allow_full_replication=*/false);
Expand Down
Loading
Loading