diff --git a/xprof/utils/hlo_module_map.h b/xprof/utils/hlo_module_map.h index 01309e09..3aea3ec4 100644 --- a/xprof/utils/hlo_module_map.h +++ b/xprof/utils/hlo_module_map.h @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" @@ -62,7 +63,7 @@ class HloInstructionInterface { virtual void ProcessXlaCostAnalysis( const xla::HloCostAnalysis* cost_analysis) = 0; - virtual std::string OpLocationStack(int32_t frame_id) const = 0; + virtual std::string OpLocationStack(xla::StackFrameId frame_id) const = 0; virtual tsl::profiler::OpSourceInfo SourceInfo() const = 0; virtual const ::tensorflow::profiler::PerformanceInfoWrapper* GetPerformanceInfoWrapper() const = 0; @@ -128,7 +129,7 @@ class HloInstructionWrapper : public HloInstructionInterface { return fused_children_; } - std::string OpLocationStack(int32_t frame_id) const override { + std::string OpLocationStack(xla::StackFrameId frame_id) const override { return GetOpLocationStack(frame_id, *instr_); } diff --git a/xprof/utils/hlo_module_utils.cc b/xprof/utils/hlo_module_utils.cc index af686f22..dce32e68 100644 --- a/xprof/utils/hlo_module_utils.cc +++ b/xprof/utils/hlo_module_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module_metadata.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" @@ -159,7 +160,9 @@ OpSourceInfo GetSourceInfo(const xla::HloInstructionProto& instr, OpSourceInfo GetSourceInfo(const xla::HloInstruction& instr) { const auto stack_frame_id = instr.metadata().stack_frame_id(); const std::string stack_frame = - stack_frame_id != 0 ? GetOpLocationStack(stack_frame_id, instr) : ""; + stack_frame_id != 0 + ? GetOpLocationStack(xla::StackFrameId{stack_frame_id}, instr) + : ""; return GetSourceInfo(instr.metadata().source_file(), instr.metadata().source_line(), stack_frame); } diff --git a/xprof/utils/hlo_module_utils.h b/xprof/utils/hlo_module_utils.h index 58573ef2..c8ecd776 100644 --- a/xprof/utils/hlo_module_utils.h +++ b/xprof/utils/hlo_module_utils.h @@ -27,9 +27,19 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_metadata.h" #include "xla/hlo/ir/hlo_print_options.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +// This is a temporary hack to work around version skew between XLA and xprof. +// This may be removed after XLA is updated past Feb 16, 2026. +#ifndef XLA_HAVE_STACK_FRAME_ID +namespace xla { +using StackFrameId = int; +} // namespace xla +#endif + + namespace tensorflow { namespace profiler { @@ -85,11 +95,15 @@ inline std::string UncachedExpression(const xla::HloInstruction& instr, return expression; } -inline std::string GetOpLocationStack(int32_t frame_id, +inline std::string GetOpLocationStack(xla::StackFrameId frame_id, const xla::HloInstruction& instr) { std::string stack_lines; xla::HloModule* hlo_module = instr.GetModule(); +#ifdef XLA_HAVE_STACK_FRAME_ID + while (frame_id.valid()) { +#else while (frame_id != 0) { +#endif xla::HloModule::StackFrame frame = hlo_module->get_stack_frame(frame_id); if (frame.empty()) { break; diff --git a/xprof/utils/hlo_module_utils_test.cc b/xprof/utils/hlo_module_utils_test.cc index a822d971..310abe69 100644 --- a/xprof/utils/hlo_module_utils_test.cc +++ b/xprof/utils/hlo_module_utils_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_metadata.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/platform/statusor.h" @@ -129,7 +130,7 @@ TEST_F(HloModuleUtilsTest, TestGetLocationStack) { GetModuleWithStackFrames()); const auto* root_instruction = module_with_stack_frames->entry_computation()->root_instruction(); - EXPECT_EQ(GetOpLocationStack(2, *root_instruction), + EXPECT_EQ(GetOpLocationStack(xla::StackFrameId{2}, *root_instruction), "main.py:20:1\nmain.py:10:5\n"); }