From 672b879ef58e4fa299c836b7aaeeaa7a9acbe460 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 3 Apr 2026 13:09:50 -0700 Subject: [PATCH] PR #40376: [xla:gpu] TraceMe for while loop iterations Imported from GitHub PR https://github.com/openxla/xla/pull/40376 Add `TraceMe` for while loop iterations to make host profiles more informative Copybara import of the project: -- 94fe20bf68cee874695396461a095ddda9d44a1d by Eugene Zhulenev : [xla:gpu] TraceMe for while loop iterations Merging this change closes #40376 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/40376 from ezhulenev:trace-while-iters 94fe20bf68cee874695396461a095ddda9d44a1d PiperOrigin-RevId: 894217452 --- xla/backends/gpu/runtime/while_thunk.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/xla/backends/gpu/runtime/while_thunk.cc b/xla/backends/gpu/runtime/while_thunk.cc index 75febd8cd7f86..a8009456bdf24 100644 --- a/xla/backends/gpu/runtime/while_thunk.cc +++ b/xla/backends/gpu/runtime/while_thunk.cc @@ -86,6 +86,8 @@ absl::Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { XLA_VLOG_DEVICE(2, device_ordinal) << "Executing WhileThunk for " << *trip_count_ << " iterations"; for (size_t i = 0; i < trip_count_; loop.IncLoopIteration(), ++i) { + TraceMe trace( + [&] { return absl::StrFormat("[iter=%d]", loop.loop_iteration()); }); XLA_VLOG_DEVICE(3, device_ordinal) << "Executing iteration # " << i << " (Device: " << stream.parent()->device_ordinal() << ")"; @@ -106,9 +108,8 @@ absl::Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { condition_result_buffer_index_); for (;; loop.IncLoopIteration()) { - TraceMe trace([&] { - return TraceMeEncode("While", {{"iteration:", loop.loop_iteration()}}); - }); + TraceMe trace( + [&] { return absl::StrFormat("[iter=%d]", loop.loop_iteration()); }); XLA_VLOG_DEVICE(3, device_ordinal) << "Executing WhileThunk condition computation; iter=" @@ -120,9 +121,9 @@ absl::Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { stream.Memcpy(condition_result, condition_result_data, sizeof(bool))); if (absl::Status blocked = stream.BlockHostUntilDone(); !blocked.ok()) { - return absl::InternalError(absl::StrFormat( + return Internal( "Failed to complete all kernels launched on stream %p: %s", &stream, - blocked.message())); + blocked.message()); } XLA_VLOG_DEVICE(3, device_ordinal)