Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,6 @@ cc_library(
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/transforms/simplifiers:hlo_dce",
"//xla/hlo/utils:hlo_query",
"//xla/service/spmd/shardy:constants",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
Expand Down
18 changes: 0 additions & 18 deletions xla/service/call_inliner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ limitations under the License.
#include "xla/hlo/utils/hlo_query.h"
#include "xla/service/call_graph.h"
#include "xla/service/hlo_domain_isolator.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/status_macros.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
Expand Down Expand Up @@ -364,23 +363,6 @@ bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
if (!prerequisite) {
return false;
}
if (instruction->GetModule()->config().use_shardy_partitioner() &&
(absl::StrContains(instruction->to_apply()->name(), "shmap_body") ||
absl::StrContains(instruction->to_apply()->name(),
sdy::kManualComputationFuncName.str()))) {
// TODO(b/436603025). Remove this special handling by marking the
// instruction as uninlineable with the frontend attribute.
//
// Specific inlining rules when needing to round-trip from MLIR->HLO->MLIR
// when using Shardy (github.com/openxla/shardy).
//
// - shmap_body: We do not want to inline the bodies of JAX shard maps to
// import them into an `sdy.ManualComputationOp`. This is for the MHLO
// round-trip pipeline
// - kManualComputationFuncName: Same as shmap_body except for the SDY
// round-trip pipeline.
return false;
}
return InlineComposites(instruction, composites_to_preserve_);
}

Expand Down
88 changes: 0 additions & 88 deletions xla/service/call_inliner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -513,94 +513,6 @@ TEST_F(CallInlinerTest, InlineCallWithOverriddenAttributeInlineableFalse) {
EXPECT_EQ(call, nullptr);
}

// With the Shardy partitioner enabled, a call whose callee name contains
// "shmap_body" (here surrounded by a prefix and a suffix) must be left in
// place by CallInliner.
TEST_F(CallInlinerTest, UseShardyMhloToHloShmapBodyNotInlined) {
  const char* const kHloText = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}

%prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}

ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3}
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4
%custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6}
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  module->mutable_config().set_use_shardy_partitioner(true);

  TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get()));
  VLOG(1) << module->ToString();
  // The module contains exactly one call, and the pass must not touch it.
  EXPECT_FALSE(changed);

  HloInstruction* surviving_call =
      FindInstruction(module.get(), xla::HloOpcode::kCall);
  // Fatal assertion: the checks below dereference the pointer, so a missing
  // call must stop the test here rather than crash the test binary.
  ASSERT_NE(surviving_call, nullptr);
  EXPECT_TRUE(surviving_call->has_to_apply());
  EXPECT_EQ(surviving_call->to_apply()->name(), "prefix_shmap_body_suffix.4");
}

// With the Shardy partitioner enabled, a call whose callee is named
// "xla.sdy.manual_computation_body" (plus a numeric suffix) must be left in
// place by CallInliner.
TEST_F(CallInlinerTest, UseShardManualComputationBodyNotInlined) {
  const char* const kHloText = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}

%xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}

ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  module->mutable_config().set_use_shardy_partitioner(true);

  TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get()));
  // The module contains exactly one call, and the pass must not touch it.
  EXPECT_FALSE(changed);

  HloInstruction* surviving_call =
      FindInstruction(module.get(), xla::HloOpcode::kCall);
  // Fatal assertion: the checks below dereference the pointer, so a missing
  // call must stop the test here rather than crash the test binary.
  ASSERT_NE(surviving_call, nullptr);
  EXPECT_TRUE(surviving_call->has_to_apply());
  EXPECT_EQ(surviving_call->to_apply()->name(),
            "xla.sdy.manual_computation_body.4");
}

// The manual-computation marker may appear anywhere in the callee name: here
// it is surrounded by both a prefix and a suffix, and the call must still be
// left in place by CallInliner when the Shardy partitioner is enabled.
TEST_F(CallInlinerTest, UseShardManualComputationBodySurroundedNotInlined) {
  const char* const kHloText = R"(
HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}}

%my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234 (Arg_0.5: f32[1,8]) -> f32[1,8] {
%Arg_0.5 = f32[1,8]{1,0} parameter(0)
ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11}
}

ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] {
%Arg_0.1 = f32[8,8]{1,0} parameter(0)
%custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4}
%call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234
ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  module->mutable_config().set_use_shardy_partitioner(true);

  TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get()));
  // The module contains exactly one call, and the pass must not touch it.
  EXPECT_FALSE(changed);

  HloInstruction* surviving_call =
      FindInstruction(module.get(), xla::HloOpcode::kCall);
  // Fatal assertion: the checks below dereference the pointer, so a missing
  // call must stop the test here rather than crash the test binary.
  ASSERT_NE(surviving_call, nullptr);
  EXPECT_TRUE(surviving_call->has_to_apply());
  EXPECT_EQ(surviving_call->to_apply()->name(),
            "my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234");
}

TEST_F(CallInlinerTest, ControlDepsPropagateToRootOfInlinedInstructions) {
const char* hlo = R"(
HloModule test
Expand Down
6 changes: 0 additions & 6 deletions xla/service/spmd/shardy/shardy_xla_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,12 +413,6 @@ bool eraseInlineableAttrForShardyManualComputations(HloModule* module) {
absl::StrContains(instruction->to_apply()->name(),
sdy::kManualComputationFuncName.str())) {
instruction->erase_frontend_attribute(kXlaInlineableAttr);
// TODO(b/436603025): CallInliner does not inline Shardy-related manual
// computations, which it identifies by callee name. We have to rename the
// callee so that it can be inlined. Once the special handling in
// CallInliner is removed, this renaming can be removed as well.
module->SetAndUniquifyComputationName(instruction->to_apply(),
"inlineable_callee");
changed = true;
}
}
Expand Down
1 change: 0 additions & 1 deletion xla/service/spmd/shardy/shardy_xla_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,6 @@ TEST_F(ShardyXLATest, UpdateInlineableAttr) {
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kCall);
EXPECT_FALSE(root->has_frontend_attributes());
EXPECT_EQ(root->to_apply()->name(), "inlineable_callee");
}

TEST_F(ShardyXLATest, ManualComputationCallOpWithToken) {
Expand Down
Loading