From 75411b4e4beb536c183f66e6b15e2b826465db4e Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Fri, 3 Apr 2026 09:43:15 -0700 Subject: [PATCH] PR #39201: [xla:gpu] Split legacy custom calls into LegacyCustomCallThunk MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/39201 Split `CustomCallThunk` into two separate classes to cleanly separate FFI from legacy custom call APIs (to be deleted). This change is **NFC**, simply splitting existing thunk/command into two classes to clearly separate legacy calling conventions from FFI. - **`CustomCallThunk`** — now exclusively handles XLA FFI custom calls. Simplified by removing all legacy code paths, optional bundle checks, and the `ExecuteCustomCall()` fallback. The `bundle_` field is now a plain `std::variant` (always present) instead of `std::optional`. - **`LegacyCustomCallThunk`** (new) — handles deprecated legacy custom calls (`API_VERSION_STATUS_RETURNING`, etc.) using `CustomCallTargetRegistry`. Marked as deprecated in its class comment, directing users to FFI via `CustomCallThunk`. Also split `CustomCallCmd` in the command buffer layer into: - **`CustomCallCmd`** — FFI only, builds `ffi::CallFrame` and invokes FFI handlers directly. - **`LegacyCustomCallCmd`** — legacy only, traces custom call execution via `TraceCommandBufferFactory`. Unified `CustomCallThunk` accumulated too much complexity related to dispatching between two unrelated custom call ABIs; by splitting them into separate files we can keep on improving FFI without having to deal with legacy custom calls. 
Copybara import of the project: -- 2d33e5be1a6cbe6f2dfe3dfba5e961592d5074c8 by Eugene Zhulenev : [xla:gpu] Split legacy custom calls into LegacyCustomCallThunk Merging this change closes #39201 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/39201 from ezhulenev:split-custom-call-thunk 2d33e5be1a6cbe6f2dfe3dfba5e961592d5074c8 PiperOrigin-RevId: 894119940 --- xla/backends/gpu/codegen/BUILD | 1 + xla/backends/gpu/codegen/custom.cc | 19 +- xla/backends/gpu/runtime/BUILD | 79 +++- .../gpu/runtime/command_buffer_cmd.cc | 66 +-- xla/backends/gpu/runtime/command_buffer_cmd.h | 75 +-- .../gpu/runtime/command_buffer_cmd_emitter.cc | 23 +- .../runtime/command_buffer_conversion_pass.cc | 10 +- xla/backends/gpu/runtime/custom_call_thunk.cc | 437 +++++------------- xla/backends/gpu/runtime/custom_call_thunk.h | 71 +-- .../gpu/runtime/custom_call_thunk_test.cc | 213 ++------- .../gpu/runtime/legacy_custom_call_thunk.cc | 224 +++++++++ .../gpu/runtime/legacy_custom_call_thunk.h | 120 +++++ .../runtime/legacy_custom_call_thunk_test.cc | 181 ++++++++ .../runtime/thunk_proto_deserialization.cc | 16 +- xla/service/gpu/BUILD | 1 + xla/service/gpu/thunk_emitter.cc | 10 +- 16 files changed, 903 insertions(+), 643 deletions(-) create mode 100644 xla/backends/gpu/runtime/legacy_custom_call_thunk.cc create mode 100644 xla/backends/gpu/runtime/legacy_custom_call_thunk.h create mode 100644 xla/backends/gpu/runtime/legacy_custom_call_thunk_test.cc diff --git a/xla/backends/gpu/codegen/BUILD b/xla/backends/gpu/codegen/BUILD index faa434da3e745..35fb29e3e9c92 100644 --- a/xla/backends/gpu/codegen/BUILD +++ b/xla/backends/gpu/codegen/BUILD @@ -144,6 +144,7 @@ cc_library( "//xla/backends/gpu/runtime:device_to_device_copy_thunk", "//xla/backends/gpu/runtime:dynamic_slice_thunk", "//xla/backends/gpu/runtime:gemm_thunk", + "//xla/backends/gpu/runtime:legacy_custom_call_thunk", "//xla/backends/gpu/runtime:thunk", "//xla/codegen/emitters:kernel_arguments", 
"//xla/ffi:attribute_map", diff --git a/xla/backends/gpu/codegen/custom.cc b/xla/backends/gpu/codegen/custom.cc index f4d8e7ee99ddf..d1260cbca8d6a 100644 --- a/xla/backends/gpu/codegen/custom.cc +++ b/xla/backends/gpu/codegen/custom.cc @@ -46,6 +46,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/device_to_device_copy_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/gemm_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/codegen/emitters/kernel_arguments.h" #include "xla/ffi/attribute_map.h" @@ -878,7 +879,7 @@ absl::StatusOr EmitCustomCall( // For legacy custom calls we convert all API versions into the latest // status-returning one and pass backend config as an opaque string. - CustomCallThunk::CustomCallTarget custom_call_target; + LegacyCustomCallThunk::CustomCallTarget custom_call_target; // For XLA FFI handlers we decode opaque backend config into attributes map // at IR emission time, so that we do not need to parse MLIR at run time. @@ -926,9 +927,8 @@ absl::StatusOr EmitCustomCall( auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation( &fusion, ir_emitter_context.GetNextThunkId()); - auto ffi_thunk = - [&](Slices ops, - Slices res) -> absl::StatusOr> { + auto ffi_thunk = [&](Slices ops, + Slices res) -> absl::StatusOr> { auto& called_computations = custom_call.called_computations(); auto& backend_config_str = backend_config.ok() @@ -954,16 +954,15 @@ absl::StatusOr EmitCustomCall( }; auto legacy_thunk = - [&](Slices ops, - Slices res) -> absl::StatusOr> { + [&](Slices ops, Slices res) -> absl::StatusOr> { std::string opaque = backend_config.ok() ? 
backend_config->custom_call_backend_config().opaque() : custom_call.raw_backend_config_string(); - return CustomCallThunk::Create(thunk_info, call_target_name, std::move(ops), - std::move(res), std::move(opaque), - custom_call.api_version(), - ir_emitter_context.platform_name()); + return LegacyCustomCallThunk::Create( + thunk_info, call_target_name, std::move(ops), std::move(res), + std::move(opaque), custom_call.api_version(), + ir_emitter_context.platform_name()); }; std::vector fake_allocations(num_args, {0, 0, 0}); diff --git a/xla/backends/gpu/runtime/BUILD b/xla/backends/gpu/runtime/BUILD index 3c366022aadf2..50e5d52b4b14f 100644 --- a/xla/backends/gpu/runtime/BUILD +++ b/xla/backends/gpu/runtime/BUILD @@ -188,6 +188,7 @@ cc_library( ":dynamic_memcpy_thunk", ":dynamic_slice_thunk", ":gpublas_lt_matmul_thunk", + ":legacy_custom_call_thunk", ":p2p_thunk_common", ":ragged_all_to_all_thunk", ":recv_thunk", @@ -345,6 +346,7 @@ cc_library( ":gemm_thunk", ":gpublas_lt_matmul_thunk", ":kernel_thunk", + ":legacy_custom_call_thunk", ":memset_thunk", ":ragged_all_to_all_thunk", ":recv_thunk", @@ -1082,8 +1084,8 @@ cc_library( ":collective_cliques", ":collective_memory", ":collective_params", - ":custom_call_target", ":thunk", + ":thunk_proto_cc", "//xla:executable_run_options", "//xla:shape_util", "//xla:util", @@ -1101,16 +1103,13 @@ cc_library( "//xla/runtime:buffer_use", "//xla/runtime:object_pool", "//xla/service:buffer_assignment", - "//xla/service:custom_call_status", - "//xla/service:custom_call_status_internal", - "//xla/service:custom_call_target_registry", + "//xla/service:hlo_proto_cc", "//xla/service:shaped_slice", "//xla/service/gpu:buffer_allocations", "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_description", "//xla/stream_executor:stream", - "//xla/tsl/platform:errors", "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:unique_any", @@ -1123,7 
+1122,35 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "legacy_custom_call_thunk", + srcs = ["legacy_custom_call_thunk.cc"], + hdrs = ["legacy_custom_call_thunk.h"], + deps = [ + ":custom_call_target", + ":thunk", + ":thunk_proto_cc", + "//xla:util", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:custom_call_status", + "//xla/service:custom_call_status_internal", + "//xla/service:custom_call_target_registry", + "//xla/service:hlo_proto_cc", + "//xla/service:shaped_slice", + "//xla/service/gpu:buffer_allocations", + "//xla/stream_executor:stream", + "//xla/tsl/platform:status_macros", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/platform", @@ -1135,9 +1162,12 @@ xla_test( srcs = ["custom_call_thunk_test.cc"], backends = ["gpu"], deps = [ + ":collective_clique_requests", ":collective_memory_requests", + ":collective_params", ":custom_call_thunk", ":thunk", + ":thunk_proto_cc", "//xla:executable_run_options", "//xla:shape_util", "//xla/backends/cpu:target_machine_options", @@ -1149,10 +1179,9 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/runtime:device_id", "//xla/service:buffer_assignment", - "//xla/service:custom_call_status_public_headers", - "//xla/service:custom_call_target_registry", "//xla/service:executable", "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", "//xla/service:platform_util", "//xla/service:shaped_slice", "//xla/service/gpu:buffer_allocations", @@ -1162,6 +1191,7 @@ xla_test( "//xla/stream_executor:platform_manager", 
"//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_address_allocator", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:parse_text_proto", "@com_google_absl//absl/base", @@ -1175,6 +1205,36 @@ xla_test( ], ) +xla_test( + name = "legacy_custom_call_thunk_test", + srcs = ["legacy_custom_call_thunk_test.cc"], + backends = ["gpu"], + deps = [ + ":legacy_custom_call_thunk", + ":thunk", + ":thunk_proto_cc", + "//xla/service:buffer_assignment", + "//xla/service:custom_call_status_public_headers", + "//xla/service:custom_call_target_registry", + "//xla/service:executable", + "//xla/service:hlo_proto_cc", + "//xla/service:platform_util", + "//xla/service/gpu:buffer_allocations", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_address_allocator", + "//xla/tsl/platform:status_macros", + "//xla/tsl/platform:statusor", + "//xla/tsl/util/proto:parse_text_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "fft_thunk", srcs = ["fft_thunk.cc"], @@ -3606,6 +3666,7 @@ cc_library( ":host_to_device_copy_thunk", ":infeed_thunk", ":kernel_thunk", + ":legacy_custom_call_thunk", ":memset_thunk", ":norm_thunk", ":nvshmem_all_reduce_thunk", @@ -3627,6 +3688,7 @@ cc_library( "//xla/backends/cpu:target_machine_options", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", + "//xla/service:hlo_proto_cc", "//xla/stream_executor:device_description", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:stream_executor_h", @@ -3639,6 +3701,7 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:protobuf_lite", ], ) diff 
--git a/xla/backends/gpu/runtime/command_buffer_cmd.cc b/xla/backends/gpu/runtime/command_buffer_cmd.cc index eb41cb8bb362f..ec99954882444 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -63,6 +63,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/dynamic_memcpy_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/p2p_thunk_common.h" #include "xla/backends/gpu/runtime/ragged_all_to_all_thunk.h" #include "xla/backends/gpu/runtime/recv_thunk.h" @@ -1211,27 +1212,16 @@ Command::BufferUses CuDnnCmd::buffer_uses() const { } //===----------------------------------------------------------------------===// -// CustomCallCmd +// LegacyCustomCallCmd //===----------------------------------------------------------------------===// -absl::StatusOr CustomCallCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, RecordAction record_action, - se::CommandBuffer* command_buffer) { - if (handler_ == nullptr) { - return RecordLegacyCustomCall(execute_params, record_params, - std::move(record_action), command_buffer); - } - return RecordXlaFfiCall(execute_params, record_params, - std::move(record_action), command_buffer); -} - namespace { // Records each buffer associated with each slice into the provided vector. // Returns an error if any of the slices is missing a buffer allocation. 
-absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params, - absl::Span slices, - std::vector& buffers, absl::string_view label) { +static absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params, + absl::Span slices, + std::vector& buffers, + absl::string_view label) { for (int i = 0; i < slices.size(); ++i) { if (!slices[i].has_value()) { buffers.push_back(nullptr); @@ -1253,21 +1243,18 @@ absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params, } } // namespace -absl::StatusOr -CustomCallCmd::RecordLegacyCustomCall( +absl::StatusOr LegacyCustomCallCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { std::vector buffers; buffers.reserve(operands_.size() + results_.size()); - VLOG(5) << "CustomCallCmd: target_name=" << target_name_; - TF_RETURN_IF_ERROR( - GetBuffers(execute_params, operands_, buffers, " Operand ")); - TF_RETURN_IF_ERROR( - GetBuffers(execute_params, results_, buffers, " Result ")); + VLOG(5) << "LegacyCustomCallCmd: target_name=" << target_name_; + RETURN_IF_ERROR(GetBuffers(execute_params, operands_, buffers, " Operand ")); + RETURN_IF_ERROR(GetBuffers(execute_params, results_, buffers, " Result ")); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto nested_cmd, se::TraceCommandBufferFactory::Create( execute_params.stream->parent(), @@ -1293,11 +1280,26 @@ CustomCallCmd::RecordLegacyCustomCall( }); } -absl::StatusOr -CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - RecordAction record_action, - se::CommandBuffer* command_buffer) { +Command::BufferUses LegacyCustomCallCmd::buffer_uses() const { + Command::BufferUses buffer_usage; + for (auto& slices : {operands_, results_}) { + for (const std::optional& slice : slices) { + if (slice.has_value()) { + buffer_usage.push_back(BufferUse::Write(slice->slice, slice->shape)); + } + } + } + return buffer_usage; 
+} + +//===----------------------------------------------------------------------===// +// CustomCallCmd (FFI) +//===----------------------------------------------------------------------===// + +absl::StatusOr CustomCallCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing // a lot of extra allocation on every call. We have to keep attributes // separate from arguments, as they do not change after thunk is @@ -1342,8 +1344,8 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, // Borrow the FFI call frame from the object pool and update with the actual // device memory addresses. - TF_ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate()); - TF_RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results)); + ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate()); + RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results)); RunId run_id = execute_params.collective_params->run_id; @@ -1360,7 +1362,7 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto nested_cmd, se::TraceCommandBufferFactory::Create( execute_params.stream->parent(), diff --git a/xla/backends/gpu/runtime/command_buffer_cmd.h b/xla/backends/gpu/runtime/command_buffer_cmd.h index ee072b630ea83..263e04c54e0f3 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -37,10 +37,10 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_executor.h" #include "xla/backends/gpu/runtime/command_state.h" -#include "xla/backends/gpu/runtime/custom_call_thunk.h" #include "xla/backends/gpu/runtime/dynamic_memcpy_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/p2p_thunk_common.h" #include "xla/backends/gpu/runtime/ragged_all_to_all_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -451,27 +451,11 @@ class CuDnnCmd : public TracedCommandBufferCmd { }; //===----------------------------------------------------------------------===// -// CustomCallCmd +// CustomCallCmd (FFI) //===----------------------------------------------------------------------===// class CustomCallCmd : public Command { public: - using CustomCallTarget = CustomCallThunk::CustomCallTarget; - using AttributesMap = ffi::AttributesMap; - - // This is a legacy custom call API that is discouraged, and will be - // deprecated once XLA:FFI mechanism is ready. 
- CustomCallCmd(std::string target_name, CustomCallTarget call_target, - std::vector operands, - std::vector results, - absl::string_view opaque) - : Command(CommandType::kCustomCallCmd), - target_name_(std::move(target_name)), - call_target_(std::move(call_target)), - opaque_(opaque), - operands_(std::move(operands)), - results_(std::move(results)) {} - CustomCallCmd(std::string target_name, XLA_FFI_Handler* handler, std::vector operands, std::vector results, @@ -498,26 +482,8 @@ class CustomCallCmd : public Command { bool IsNestedCommandBuffer() const final { return true; } private: - absl::StatusOr RecordLegacyCustomCall( - const Thunk::ExecuteParams& execute_param, - const RecordParams& record_params, RecordAction record_action, - se::CommandBuffer* command_buffer); - - absl::StatusOr RecordXlaFfiCall( - const Thunk::ExecuteParams& execute_param, - const RecordParams& record_params, RecordAction record_action, - se::CommandBuffer* command_buffer); - std::string target_name_; - // This is a legacy custom call API that is discouraged, and will be - // deprecated once XLA:FFI mechanism is ready. - CustomCallTarget call_target_; - std::string opaque_; - - // XLA FFI provides a right type safe mechanism for registering external - // functions with XLA runtime. It's under construction, and still misses - // a lot of features. Long term it will replace legacy custom calls. XLA_FFI_Handler* handler_ = nullptr; // Reference call frame pre-initialized at construction time. 
@@ -540,6 +506,43 @@ class CustomCallCmd : public Command { std::vector results_; }; +//===----------------------------------------------------------------------===// +// LegacyCustomCallCmd +//===----------------------------------------------------------------------===// + +class LegacyCustomCallCmd : public Command { + public: + using CustomCallTarget = LegacyCustomCallThunk::CustomCallTarget; + + LegacyCustomCallCmd(std::string target_name, CustomCallTarget call_target, + std::vector operands, + std::vector results, + absl::string_view opaque) + : Command(CommandType::kCustomCallCmd), + target_name_(std::move(target_name)), + call_target_(std::move(call_target)), + opaque_(opaque), + operands_(std::move(operands)), + results_(std::move(results)) {} + + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) override; + + BufferUses buffer_uses() const override; + bool IsNestedCommandBuffer() const final { return true; } + + private: + std::string target_name_; + + CustomCallTarget call_target_; + std::string opaque_; + + std::vector operands_; + std::vector results_; +}; + //===----------------------------------------------------------------------===// // CollectiveCmd //===----------------------------------------------------------------------===// diff --git a/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc index 338353bd071c1..37201e0d1edf5 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc @@ -49,6 +49,7 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/gemm_thunk.h" #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/backends/gpu/runtime/kernel_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/memset_thunk.h" #include "xla/backends/gpu/runtime/ragged_all_to_all_thunk.h" #include "xla/backends/gpu/runtime/recv_thunk.h" @@ -277,9 +278,16 @@ static absl::StatusOr> Convert( thunk.execution_state(), /*called_computation=*/nullptr); // TODO(b/342285364) } - return std::make_unique(thunk.target_name(), - thunk.call_target(), thunk.operands(), - thunk.results(), thunk.opaque()); + return absl::InternalError( + "CustomCallThunk without FFI handler bundle cannot be converted to a " + "command buffer command"); +} + +static absl::StatusOr> Convert( + const LegacyCustomCallThunk& thunk) { + return std::make_unique( + thunk.target_name(), thunk.call_target(), thunk.operands(), + thunk.results(), thunk.opaque()); } static absl::StatusOr> Convert( @@ -329,7 +337,14 @@ static absl::Status AppendCommands(ConversionContext& ctx, return append(Convert(thunk)); } case Thunk::Kind::kCustomCall: - return append(Convert(thunk)); + if (auto* ffi_thunk = dynamic_cast(&thunk)) { + return append(Convert(*ffi_thunk)); + } + if (auto* legacy_thunk = + dynamic_cast(&thunk)) { + return append(Convert(*legacy_thunk)); + } + return absl::InternalError("Unknown custom call thunk type"); case Thunk::Kind::kCustomKernel: return append(Convert(thunk)); case Thunk::Kind::kKernel: diff --git a/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc b/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc index 928ef38f9f521..6de74dcb533a4 100644 --- a/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc +++ b/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc @@ -204,8 +204,8 @@ bool IsConvertible(const ConditionalThunk& conditional_thunk, } // Returns true if the CustomCallThunk is convertible to a command buffer 
-// operation. Checks if the custom call target is in the legacy allowlist or if -// the registered FFI handler is compatible with command buffers. +// operation. Checks if the registered FFI handler is compatible with command +// buffers. bool IsConvertible(const CustomCallThunk& custom_call_thunk, const CommandBufferConfig& config) { const std::string& target_name = custom_call_thunk.target_name(); @@ -312,7 +312,11 @@ bool IsConvertible(const Thunk& thunk, const CommandBufferConfig& config) { } if (thunk.kind() == Thunk::kCustomCall) { - return IsConvertible(static_cast(thunk), config); + if (auto* ffi_thunk = dynamic_cast(&thunk)) { + return IsConvertible(*ffi_thunk, config); + } + // Legacy custom calls are not command-buffer compatible. + return false; } if (thunk.kind() == Thunk::kDynamicSlice) { diff --git a/xla/backends/gpu/runtime/custom_call_thunk.cc b/xla/backends/gpu/runtime/custom_call_thunk.cc index 4788b4279384f..f7267d79189b7 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk.cc +++ b/xla/backends/gpu/runtime/custom_call_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/custom_call_thunk.h" +#include #include #include #include @@ -34,15 +35,14 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/target_machine_options.h" #include "xla/backends/gpu/runtime/collective_clique_requests.h" #include "xla/backends/gpu/runtime/collective_cliques.h" #include "xla/backends/gpu/runtime/collective_params.h" -#include "xla/backends/gpu/runtime/custom_call_target.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/executable_run_options.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" @@ -55,24 +55,18 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/runtime/object_pool.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/custom_call_status_internal.h" -#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/hlo.pb.h" #include "xla/service/shaped_slice.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream.h" -#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/unique_any.h" -#include "xla/util.h" -#include "tsl/platform/platform.h" #include "xla/tsl/platform/status_macros.h" -namespace xla { -namespace gpu { +namespace xla::gpu { using xla::ffi::CallFrame; using xla::ffi::CallFrameBuilder; @@ -89,8 +83,6 @@ static absl::StatusOr BuildCallFramePrototype( /*num_args=*/operands.size(), /*num_rets=*/results.size()); - // Add prototype input buffers with actual data types and shapes. Device - // memory addresses will be updated at runtime. 
for (int i = 0; i < operands.size(); ++i) { auto& operand = operands[i]; @@ -107,8 +99,6 @@ static absl::StatusOr BuildCallFramePrototype( operand->shape.dimensions()); } - // Add prototype output buffers with actual data types and shapes. Device - // memory addresses will be updated at runtime. for (int i = 0; i < results.size(); ++i) { auto& result = results[i]; @@ -125,7 +115,6 @@ static absl::StatusOr BuildCallFramePrototype( result->shape.dimensions()); } - // Add attributes if any. if (!attributes.empty()) { ffi::CallFrameBuilder::AttributesBuilder attrs; attrs.Append(std::move(attributes)); @@ -135,111 +124,6 @@ static absl::StatusOr BuildCallFramePrototype( return builder.Build(); } -static absl::StatusOr -ResolveLegacyCustomCall(const CustomCallTargetRegistry& registry, - absl::string_view target_name, - absl::string_view platform_name, - CustomCallApiVersion api_version) { - void* call_target = - registry.Lookup(std::string(target_name), std::string(platform_name)); - - if (call_target == nullptr) { - return NotFound( - "No registered implementation for custom call to %s for platform %s", - target_name, platform_name); - } - - // For information about this calling convention, see - // xla/g3doc/custom_call.md. - switch (api_version) { - case CustomCallApiVersion::API_VERSION_ORIGINAL: { - constexpr absl::string_view kErrorMessage = - "Custom call API version `API_VERSION_ORIGINAL` is not supported by " - "XLA:GPU. Prefer https://docs.jax.dev/en/latest/ffi.html. 
It will be " - "fully removed in November 2025."; - if constexpr (tsl::kIsOpenSource) { - LOG(ERROR) << kErrorMessage; - } else { - LOG(FATAL) << kErrorMessage; - } - - return [call_target](stream_executor::Stream* stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus*) { - reinterpret_cast(call_target)( - stream->platform_specific_handle().stream, buffers, opaque, - opaque_len); - }; - break; - } - case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: - case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: - return [call_target](stream_executor::Stream* stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status) { - reinterpret_cast( - call_target)(stream->platform_specific_handle().stream, buffers, - opaque, opaque_len, status); - }; - break; - case CustomCallApiVersion::API_VERSION_TYPED_FFI: - return absl::InvalidArgumentError( - "Called ResolveLegacyCustomCall with API_VERSION_TYPED_FFI"); - default: - return Internal("Unknown custom-call API version enum value: %d", - api_version); - } -} - -absl::StatusOr> CustomCallThunk::Create( - ThunkInfo thunk_info, std::string target_name, CustomCallTarget call_target, - std::vector operands, - std::vector results, std::string opaque) { - return absl::WrapUnique(new CustomCallThunk( - thunk_info, std::move(target_name), std::move(operands), - std::move(results), std::move(opaque), std::move(call_target), - /*api_version=*/std::nullopt)); -} - -absl::StatusOr> CustomCallThunk::Create( - ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, std::string opaque, - CustomCallApiVersion api_version, absl::string_view platform_name) { - if (api_version == CustomCallApiVersion::API_VERSION_TYPED_FFI) { - return absl::InvalidArgumentError( - "Called overload of CustomCallThunk::Create that is intended for " - "legacy custom calls with api_version=API_VERSION_TYPED_FFI"); - } - - TF_ASSIGN_OR_RETURN( - 
CustomCallTarget call_target, - ResolveLegacyCustomCall(*CustomCallTargetRegistry::Global(), target_name, - platform_name, api_version)); - - return absl::WrapUnique(new CustomCallThunk( - thunk_info, std::move(target_name), std::move(operands), - std::move(results), std::move(opaque), call_target, api_version)); -} - -absl::StatusOr> CustomCallThunk::Create( - ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, ffi::AttributesMap attributes, - const HloComputation* called_computation, absl::string_view platform_name, - const se::GpuComputeCapability& gpu_compute_capability, - std::unique_ptr execution_state, - std::optional cpu_target_machine_options) { - TF_ASSIGN_OR_RETURN(ffi::HandlerRegistration registration, - ffi::FindHandler(target_name, platform_name)); - - return Create(thunk_info, std::move(target_name), - std::move(registration.bundle), std::move(operands), - std::move(results), std::move(attributes), called_computation, - gpu_compute_capability, std::move(execution_state), - std::move(cpu_target_machine_options)); -} - static InvokeContext BuildInstantiateInvokeContext( ffi::ExecutionState* execution_state, const se::GpuComputeCapability* gpu_compute_capability, @@ -260,6 +144,24 @@ static InvokeContext BuildInstantiateInvokeContext( return context; } +absl::StatusOr> CustomCallThunk::Create( + ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, ffi::AttributesMap attributes, + const HloComputation* called_computation, absl::string_view platform_name, + const se::GpuComputeCapability& gpu_compute_capability, + std::unique_ptr execution_state, + std::optional cpu_target_machine_options) { + ASSIGN_OR_RETURN(ffi::HandlerRegistration registration, + ffi::FindHandler(target_name, platform_name)); + + return Create(thunk_info, std::move(target_name), + std::move(registration.bundle), std::move(operands), + std::move(results), std::move(attributes), called_computation, + 
gpu_compute_capability, std::move(execution_state), + std::move(cpu_target_machine_options)); +} + absl::StatusOr> CustomCallThunk::Create( ThunkInfo thunk_info, std::string target_name, XLA_FFI_Handler_Bundle bundle, std::vector operands, @@ -289,8 +191,8 @@ absl::StatusOr> CustomCallThunk::Create( } } - TF_ASSIGN_OR_RETURN(CallFrame call_frame, - BuildCallFramePrototype(operands, results, attributes)); + ASSIGN_OR_RETURN(CallFrame call_frame, + BuildCallFramePrototype(operands, results, attributes)); return absl::WrapUnique(new CustomCallThunk( thunk_info, std::move(target_name), std::move(bundle), std::move(operands), std::move(results), std::move(call_frame), @@ -313,12 +215,11 @@ absl::StatusOr> CustomCallThunk::Create( auto execution_state = std::make_unique(); - // Initialize FFI handler state if it has an instantiate callback. if (bundle.instantiate) { // Build a call frame with placeholder buffers so the instantiate handler // can read operand/result types and shapes. Data pointers are nullptr. 
- TF_ASSIGN_OR_RETURN(CallFrame call_frame, - BuildCallFramePrototype(operands, results, attributes)); + ASSIGN_OR_RETURN(CallFrame call_frame, + BuildCallFramePrototype(operands, results, attributes)); if (!cpu_target_machine_options.has_value()) { cpu_target_machine_options = xla::cpu::TargetMachineOptions(); @@ -326,13 +227,12 @@ absl::StatusOr> CustomCallThunk::Create( InvokeContext context = BuildInstantiateInvokeContext( execution_state.get(), &gpu_compute_capability, &*cpu_target_machine_options); - TF_RETURN_IF_ERROR(Invoke(ffi::GetXlaFfiApi(), *bundle.instantiate, - call_frame, context, - xla::ffi::ExecutionStage::kInstantiate)); + RETURN_IF_ERROR(Invoke(ffi::GetXlaFfiApi(), *bundle.instantiate, call_frame, + context, xla::ffi::ExecutionStage::kInstantiate)); } - TF_ASSIGN_OR_RETURN(CallFrame call_frame, - BuildCallFramePrototype(operands, results, attributes)); + ASSIGN_OR_RETURN(CallFrame call_frame, + BuildCallFramePrototype(operands, results, attributes)); return absl::WrapUnique(new CustomCallThunk( thunk_info, std::move(target_name), std::move(bundle), std::move(operands), std::move(results), std::move(call_frame), @@ -340,20 +240,6 @@ absl::StatusOr> CustomCallThunk::Create( cpu_target_machine_options)); } -CustomCallThunk::CustomCallThunk( - ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, std::string opaque, - CustomCallTarget call_target, - const std::optional& api_version) - : Thunk(Thunk::kCustomCall, thunk_info), - api_version_(api_version), - target_name_(std::move(target_name)), - operands_(std::move(operands)), - results_(std::move(results)), - call_target_(std::move(call_target)), - opaque_(std::move(opaque)) {} - CustomCallThunk::CustomCallThunk( ThunkInfo thunk_info, std::string target_name, std::variant bundle, @@ -364,7 +250,6 @@ CustomCallThunk::CustomCallThunk( const HloComputation* called_computation, std::optional cpu_target_machine_options) : Thunk(Thunk::kCustomCall, thunk_info), - 
api_version_(CustomCallApiVersion::API_VERSION_TYPED_FFI), target_name_(std::move(target_name)), operands_(std::move(operands)), results_(std::move(results)), @@ -376,42 +261,6 @@ CustomCallThunk::CustomCallThunk( called_computation_(called_computation), cpu_target_machine_options_(std::move(cpu_target_machine_options)) {} -absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) { - // gpu_stream is CUstream or e.g. the equivalent type in ROCm. - std::vector buffers; - buffers.reserve(operands_.size() + results_.size()); - for (auto& slices : {operands_, results_}) { - for (const std::optional& slice : slices) { - if (!slice.has_value()) { - buffers.push_back(nullptr); - continue; - } - - if (!slice->slice.allocation()) { - return Internal("custom call input missing buffer allocation"); - } - - buffers.push_back( - params.buffer_allocations->GetDeviceAddress(slice->slice).opaque()); - } - } - - XlaCustomCallStatus custom_call_status; - call_target_(params.stream, buffers.data(), opaque_.data(), opaque_.size(), - &custom_call_status); - auto message = CustomCallStatusGetMessage(&custom_call_status); - if (message) { - return Internal("CustomCall failed: %s", *message); - } - return absl::OkStatus(); -} - -// Builds a call frame for the custom call. -// -// If `buffer_allocations` is provided, the call frame will contain the actual -// device memory addresses of the buffers. Otherwise, the call frame will -// contain placeholders - this should only be the case when calling Prepare() -// stage handler. absl::StatusOr::BorrowedObject> CustomCallThunk::BuildCallFrame( const BufferAllocations* absl_nullable buffer_allocations) { @@ -420,7 +269,6 @@ CustomCallThunk::BuildCallFrame( : se::DeviceAddressBase{}; }; - // Collect arguments buffers. absl::InlinedVector arguments; arguments.reserve(operands_.size()); for (auto& operand : operands_) { @@ -431,7 +279,6 @@ CustomCallThunk::BuildCallFrame( } } - // Collect results buffers. 
absl::InlinedVector results; results.reserve(results_.size()); for (auto& result : results_) { @@ -442,17 +289,11 @@ CustomCallThunk::BuildCallFrame( } } - // Borrow the FFI call frame from the object pool and update with the actual - // device memory addresses. - TF_ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate()); - TF_RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results)); + ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate()); + RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results)); return call_frame; } -// Builds call options object for the custom call. -// -// `stream` and `buffer_allocations may only be non-null for options passed to -// Prepare()_stage handler. InvokeContext CustomCallThunk::BuildInvokeContext( RunId run_id, se::Stream* absl_nullable stream, Thunk::ExecutionScopedState* absl_nullable execution_scoped_state, @@ -476,7 +317,6 @@ InvokeContext CustomCallThunk::BuildInvokeContext( &stream->parent()->GetDeviceDescription().gpu_compute_capability(); } - // Lookup per-execution state for prepare and init stages. 
ffi::ExecutionState* prepare_state = nullptr; ffi::ExecutionState* initialize_state = nullptr; @@ -522,7 +362,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( return absl::InternalError("buffer allocations and stream are required"); } - TF_ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations)); + ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations)); InvokeContext context = BuildInvokeContext( run_id, stream, execution_scoped_state, buffer_allocations, collective_params, collective_clique_requests, collective_memory_requests, @@ -545,7 +385,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( return absl::InternalError("buffer allocations and stream are required"); } - TF_ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations)); + ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations)); InvokeContext context = BuildInvokeContext( run_id, stream, execution_scoped_state, buffer_allocations, collective_params, collective_clique_requests, collective_memory_requests, @@ -554,74 +394,66 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( } absl::Status CustomCallThunk::Prepare(const PrepareParams& params) { - if (bundle_.has_value()) { - const RunId run_id = - params.collective_params ? 
params.collective_params->run_id : RunId{-1}; - - if (const auto* c_bundle = - std::get_if(&bundle_.value()); - c_bundle && c_bundle->prepare) { - return ExecuteFfiHandler( - run_id, c_bundle->prepare, XLA_FFI_ExecutionStage_PREPARE, - /*stream=*/nullptr, - /*execution_scoped_state=*/params.execution_scoped_state, - /*execution_context=*/nullptr, - /*buffer_allocations=*/params.buffer_allocations, - /*collective_params=*/params.collective_params, - /*collective_clique_requests=*/params.collective_clique_requests, - /*collective_memory_requests=*/params.collective_memory_requests, - /*collective_cliques=*/nullptr, - /*collective_memory=*/nullptr); - } - if (const auto* owned_bundle = - std::get_if(&bundle_.value()); - owned_bundle && owned_bundle->prepare) { - return ExecuteFfiHandler( - run_id, *owned_bundle->prepare, xla::ffi::ExecutionStage::kPrepare, - /*stream=*/nullptr, - /*execution_scoped_state=*/params.execution_scoped_state, - /*execution_context=*/nullptr, - /*buffer_allocations=*/params.buffer_allocations, - /*collective_params=*/params.collective_params, - /*collective_clique_requests=*/params.collective_clique_requests, - /*collective_memory_requests=*/params.collective_memory_requests, - /*collective_cliques=*/nullptr, - /*collective_memory=*/nullptr); - } + const RunId run_id = + params.collective_params ? 
params.collective_params->run_id : RunId{-1}; + + if (const auto* c_bundle = std::get_if(&bundle_); + c_bundle && c_bundle->prepare) { + return ExecuteFfiHandler( + run_id, c_bundle->prepare, XLA_FFI_ExecutionStage_PREPARE, + /*stream=*/nullptr, + /*execution_scoped_state=*/params.execution_scoped_state, + /*execution_context=*/nullptr, + /*buffer_allocations=*/params.buffer_allocations, + /*collective_params=*/params.collective_params, + /*collective_clique_requests=*/params.collective_clique_requests, + /*collective_memory_requests=*/params.collective_memory_requests, + /*collective_cliques=*/nullptr, + /*collective_memory=*/nullptr); + } + if (const auto* owned_bundle = std::get_if(&bundle_); + owned_bundle && owned_bundle->prepare) { + return ExecuteFfiHandler( + run_id, *owned_bundle->prepare, xla::ffi::ExecutionStage::kPrepare, + /*stream=*/nullptr, + /*execution_scoped_state=*/params.execution_scoped_state, + /*execution_context=*/nullptr, + /*buffer_allocations=*/params.buffer_allocations, + /*collective_params=*/params.collective_params, + /*collective_clique_requests=*/params.collective_clique_requests, + /*collective_memory_requests=*/params.collective_memory_requests, + /*collective_cliques=*/nullptr, + /*collective_memory=*/nullptr); } return absl::OkStatus(); } absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { - if (bundle_.has_value()) { - const RunId run_id = - params.collective_params ? 
params.collective_params->run_id : RunId{-1}; - - if (const auto* c_bundle = - std::get_if(&bundle_.value()); - c_bundle && c_bundle->initialize) { - return ExecuteFfiHandler( - run_id, *c_bundle->initialize, XLA_FFI_ExecutionStage_INITIALIZE, - params.stream, params.execution_scoped_state, - params.ffi_execution_context, params.buffer_allocations, - params.collective_params, - /*collective_clique_requests=*/nullptr, - /*collective_memory_requests=*/nullptr, params.collective_cliques, - params.collective_memory); - } - if (const auto* owned_bundle = - std::get_if(&bundle_.value()); - owned_bundle && owned_bundle->initialize) { - return ExecuteFfiHandler( - run_id, *owned_bundle->initialize, - xla::ffi::ExecutionStage::kInitialize, params.stream, - params.execution_scoped_state, params.ffi_execution_context, - params.buffer_allocations, params.collective_params, - /*collective_clique_requests=*/nullptr, - /*collective_memory_requests=*/nullptr, params.collective_cliques, - params.collective_memory); - } + const RunId run_id = + params.collective_params ? 
params.collective_params->run_id : RunId{-1}; + + if (const auto* c_bundle = std::get_if(&bundle_); + c_bundle && c_bundle->initialize) { + return ExecuteFfiHandler( + run_id, *c_bundle->initialize, XLA_FFI_ExecutionStage_INITIALIZE, + params.stream, params.execution_scoped_state, + params.ffi_execution_context, params.buffer_allocations, + params.collective_params, + /*collective_clique_requests=*/nullptr, + /*collective_memory_requests=*/nullptr, params.collective_cliques, + params.collective_memory); + } + if (const auto* owned_bundle = std::get_if(&bundle_); + owned_bundle && owned_bundle->initialize) { + return ExecuteFfiHandler( + run_id, *owned_bundle->initialize, + xla::ffi::ExecutionStage::kInitialize, params.stream, + params.execution_scoped_state, params.ffi_execution_context, + params.buffer_allocations, params.collective_params, + /*collective_clique_requests=*/nullptr, + /*collective_memory_requests=*/nullptr, params.collective_cliques, + params.collective_memory); } return absl::OkStatus(); } @@ -629,64 +461,53 @@ absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { se::Stream* stream = params.stream; - if (bundle_.has_value()) { - const RunId run_id = - params.collective_params ? 
params.collective_params->run_id : RunId{-1}; - if (const auto* c_bundle = - std::get_if(&bundle_.value()); - c_bundle) { - return ExecuteFfiHandler( - run_id, c_bundle->execute, XLA_FFI_ExecutionStage_EXECUTE, stream, - params.execution_scoped_state, params.ffi_execution_context, - params.buffer_allocations, params.collective_params, - /*collective_clique_requests=*/nullptr, - /*collective_memory_requests=*/nullptr, params.collective_cliques, - params.collective_memory); - } - if (const auto* owned_bundle = - std::get_if(&bundle_.value()); - owned_bundle) { - if (!owned_bundle->execute) { - return absl::InternalError("FFI execute handler is not set"); - } - return ExecuteFfiHandler( - run_id, *owned_bundle->execute, xla::ffi::ExecutionStage::kExecute, - stream, params.execution_scoped_state, params.ffi_execution_context, - params.buffer_allocations, params.collective_params, - /*collective_clique_requests=*/nullptr, - /*collective_memory_requests=*/nullptr, params.collective_cliques, - params.collective_memory); + const RunId run_id = + params.collective_params ? 
params.collective_params->run_id : RunId{-1}; + + if (const auto* c_bundle = std::get_if(&bundle_)) { + return ExecuteFfiHandler( + run_id, c_bundle->execute, XLA_FFI_ExecutionStage_EXECUTE, stream, + params.execution_scoped_state, params.ffi_execution_context, + params.buffer_allocations, params.collective_params, + /*collective_clique_requests=*/nullptr, + /*collective_memory_requests=*/nullptr, params.collective_cliques, + params.collective_memory); + } + if (const auto* owned_bundle = std::get_if(&bundle_)) { + if (!owned_bundle->execute) { + return absl::InternalError("FFI execute handler is not set"); } + return ExecuteFfiHandler( + run_id, *owned_bundle->execute, xla::ffi::ExecutionStage::kExecute, + stream, params.execution_scoped_state, params.ffi_execution_context, + params.buffer_allocations, params.collective_params, + /*collective_clique_requests=*/nullptr, + /*collective_memory_requests=*/nullptr, params.collective_cliques, + params.collective_memory); } - return ExecuteCustomCall(params); + return absl::InternalError("No FFI handler bundle set"); } absl::StatusOr CustomCallThunk::ToProto() const { - if (!api_version_.has_value()) { - return absl::FailedPreconditionError( - "CustomCallThunk was created from a non-registered target and cannot " - "be serialized to a proto"); - } - ThunkProto proto; *proto.mutable_thunk_info() = thunk_info().ToProto(); proto.mutable_custom_call_thunk()->set_target_name(target_name_); - proto.mutable_custom_call_thunk()->set_opaque(opaque_); - proto.mutable_custom_call_thunk()->set_api_version(api_version_.value()); + proto.mutable_custom_call_thunk()->set_api_version( + CustomCallApiVersion::API_VERSION_TYPED_FFI); if (called_computation_ != nullptr) { proto.mutable_custom_call_thunk()->set_called_computation( called_computation_->name()); } for (const NullableShapedSlice& operand : operands_) { - TF_ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_operands(), - operand.ToProto()); + 
ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_operands(), + operand.ToProto()); } for (const NullableShapedSlice& result : results_) { - TF_ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_results(), - result.ToProto()); + ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_results(), + result.ToProto()); } if (attributes_.has_value()) { @@ -695,7 +516,7 @@ absl::StatusOr CustomCallThunk::ToProto() const { } if (execution_state_ && execution_state_->IsSerializable()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( *proto.mutable_custom_call_thunk()->mutable_execution_state(), execution_state_->ToProto()); } @@ -716,31 +537,24 @@ absl::StatusOr> CustomCallThunk::FromProto( std::vector operands, results; for (const auto& operand_proto : proto.operands()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( NullableShapedSlice operand, NullableShapedSlice::FromProto(operand_proto, buffer_allocations)); operands.push_back(std::move(operand)); } for (const auto& result_proto : proto.results()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( NullableShapedSlice result, NullableShapedSlice::FromProto(result_proto, buffer_allocations)); results.push_back(std::move(result)); } - if (proto.api_version() != CustomCallApiVersion::API_VERSION_TYPED_FFI) { - // Create a thunk that uses the legacy custom call registry. - return CustomCallThunk::Create( - std::move(thunk_info), proto.target_name(), std::move(operands), - std::move(results), proto.opaque(), proto.api_version(), platform_name); - } - - TF_ASSIGN_OR_RETURN(ffi::AttributesMap attributes, - ffi::AttributesMap::FromProto(proto.attributes())); + ASSIGN_OR_RETURN(ffi::AttributesMap attributes, + ffi::AttributesMap::FromProto(proto.attributes())); HloComputation* called_computation = nullptr; if (proto.has_called_computation()) { - CHECK(hlo_module != nullptr); // This check is needed for static analysis. 
+ CHECK(hlo_module != nullptr); called_computation = hlo_module->GetComputationWithName(proto.called_computation()); if (called_computation == nullptr) { @@ -758,7 +572,7 @@ absl::StatusOr> CustomCallThunk::FromProto( } else { LOG(WARNING) << "Failed to deserialize the custom call execution state. Falling " - "back to runtime instantiaton of the execution state. Reason: " + "back to runtime instantiation of the execution state. Reason: " << state.status(); } } @@ -770,5 +584,4 @@ absl::StatusOr> CustomCallThunk::FromProto( std::move(cpu_target_machine_options)); } -} // namespace gpu -} // namespace xla +} // namespace xla::gpu diff --git a/xla/backends/gpu/runtime/custom_call_thunk.h b/xla/backends/gpu/runtime/custom_call_thunk.h index 810d2c002777d..cd077bb5fb617 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk.h +++ b/xla/backends/gpu/runtime/custom_call_thunk.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef XLA_BACKENDS_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ #define XLA_BACKENDS_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ -#include -#include #include #include #include @@ -33,6 +31,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/collective_cliques.h" #include "xla/backends/gpu/runtime/collective_memory.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/executable_run_options.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" @@ -45,20 +44,20 @@ limitations under the License. #include "xla/runtime/buffer_use.h" #include "xla/runtime/object_pool.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/custom_call_status.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/shaped_slice.h" #include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream.h" -namespace xla { -namespace gpu { +namespace xla::gpu { -// Thunk to run a GPU custom call. 
+// Thunk to run an XLA FFI custom call on a GPU. // -// This thunk's `ExecuteOnStream` implementation executes a host function -// `call_target` which is expected to enqueue operations onto the GPU. +// This thunk handles custom calls registered via the XLA FFI mechanism, which +// provides a type-safe API for registering external functions with XLA runtime. +// +// For legacy (non-FFI) custom calls, see LegacyCustomCallThunk. // // Note that not all kCustomCall HLOs in XLA:GPU end up being run by this thunk. // XLA itself creates kCustomCall instructions when lowering kConvolution HLOs @@ -85,26 +84,6 @@ class CustomCallThunk : public Thunk { ffi::ExecutionState init; }; - using CustomCallTarget = - std::function; - - // Creates a serializable custom call thunk. The callback is resolved using - // the legacy CustomCall registry. For new code please use XLA FFI instead. - static absl::StatusOr> Create( - ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, std::string opaque, - CustomCallApiVersion api_version, absl::string_view platform_name); - - // Creates a custom call thunk from the given legacy custom call target. - // Note that a thunk created this way can't be serialized to a proto. - // This function is only permitted for unit testing code. - static absl::StatusOr> Create( - ThunkInfo thunk_info, std::string target_name, - CustomCallTarget call_target, std::vector operands, - std::vector results, std::string opaque); - // Creates a serializable custom call thunk. The callback is resolved using // XLA FFI. 
static absl::StatusOr> Create( @@ -150,14 +129,10 @@ class CustomCallThunk : public Thunk { absl::Status ExecuteOnStream(const ExecuteParams& params) override; const std::string& target_name() const { return target_name_; } - CustomCallTarget call_target() const { return call_target_; } std::optional bundle() const { - if (!bundle_.has_value()) { - return std::nullopt; - } const XLA_FFI_Handler_Bundle* c_bundle = - std::get_if(&bundle_.value()); + std::get_if(&bundle_); return c_bundle ? std::make_optional(*c_bundle) : std::nullopt; } @@ -172,8 +147,6 @@ class CustomCallThunk : public Thunk { const std::vector& operands() const { return operands_; } const std::vector& results() const { return results_; } - absl::string_view opaque() const { return opaque_; } - BufferUses buffer_uses() const override { BufferUses res; res.reserve(operands_.size() + results_.size()); @@ -197,12 +170,6 @@ class CustomCallThunk : public Thunk { std::optional cpu_target_machine_options); private: - CustomCallThunk(ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, std::string opaque, - CustomCallTarget call_target, - const std::optional& api_version); - CustomCallThunk( ThunkInfo thunk_info, std::string target_name, std::variant bundle, @@ -213,8 +180,6 @@ class CustomCallThunk : public Thunk { const HloComputation* called_computation, std::optional cpu_target_machine_options); - absl::Status ExecuteCustomCall(const ExecuteParams& params); - absl::StatusOr::BorrowedObject> BuildCallFrame(const BufferAllocations* absl_nullable buffer_allocations); @@ -251,26 +216,15 @@ class CustomCallThunk : public Thunk { const CollectiveCliques* absl_nullable collective_cliques, const CollectiveMemory* absl_nullable collective_memory); - // API version of the custom call. If not set, it means the custom call thunk - // was initialized from a non-registered function pointer and can't be - // serialized to a proto. 
- std::optional api_version_; std::string target_name_; // Nulled shape slices represent null pointer arguments to the thunk. std::vector operands_; std::vector results_; - // This is a legacy custom call API that is discouraged, and will be - // deprecated once XLA:FFI mechanism is ready. - CustomCallTarget call_target_; - std::string opaque_; - - // XLA FFI provides a right type safe mechanism for registering external - // functions with XLA runtime. It's under construction, and still misses - // a lot of features. Long term it will replace legacy custom calls. - std::optional> - bundle_; + // XLA FFI handler bundle: either a C API bundle (from the global FFI + // registry) or an owned bundle (from xla::ffi::Bind()). + std::variant bundle_; std::optional attributes_; // Reference call frame pre-initialized at construction time. @@ -296,7 +250,6 @@ class CustomCallThunk : public Thunk { std::optional cpu_target_machine_options_; }; -} // namespace gpu -} // namespace xla +} // namespace xla::gpu #endif // XLA_BACKENDS_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ diff --git a/xla/backends/gpu/runtime/custom_call_thunk_test.cc b/xla/backends/gpu/runtime/custom_call_thunk_test.cc index 4a37af9e61a85..14d12660bb09e 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk_test.cc +++ b/xla/backends/gpu/runtime/custom_call_thunk_test.cc @@ -35,8 +35,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/backends/cpu/target_machine_options.h" #include "xla/backends/gpu/ffi.h" +#include "xla/backends/gpu/runtime/collective_clique_requests.h" #include "xla/backends/gpu/runtime/collective_memory_requests.h" +#include "xla/backends/gpu/runtime/collective_params.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/executable_run_options.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/execution_state.h" @@ -47,9 +50,8 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/runtime/device_id.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" @@ -63,6 +65,8 @@ limitations under the License. #include "xla/stream_executor/stream_executor_address_allocator.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/proto/parse_text_proto.h" +#include "xla/xla_data.pb.h" +#include "xla/tsl/platform/status_macros.h" namespace xla::gpu { struct TestState { @@ -132,38 +136,11 @@ using absl_testing::StatusIs; using ::testing::HasSubstr; static absl::StatusOr GpuExecutor() { - TF_ASSIGN_OR_RETURN(auto name, PlatformUtil::CanonicalPlatformName("gpu")); - TF_ASSIGN_OR_RETURN(auto* platform, - se::PlatformManager::PlatformWithName(name)); + ASSIGN_OR_RETURN(auto name, PlatformUtil::CanonicalPlatformName("gpu")); + ASSIGN_OR_RETURN(auto* platform, se::PlatformManager::PlatformWithName(name)); return platform->ExecutorForDevice(0); } -TEST(CustomCallThunkTest, SimpleCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); - - bool was_called = false; - - CustomCallThunk::CustomCallTarget target = - [&](se::Stream* stream_in_callback, void** args, const char* target_name, - size_t num_args, XlaCustomCallStatus* status) { - was_called = true; - EXPECT_THAT(stream_in_callback, ::testing::Eq(stream.get())); - }; - - TF_ASSERT_OK_AND_ASSIGN( - auto thunk, CustomCallThunk::Create(Thunk::ThunkInfo(), "target_name", - target, {}, {}, "")); - stream_executor::StreamExecutorAddressAllocator allocator(executor); - Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( - 
ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator), - stream.get(), stream.get(), nullptr, nullptr, nullptr); - EXPECT_THAT(thunk->ExecuteOnStream(Thunk::ExecuteParams(params)), - absl_testing::IsOk()); - EXPECT_TRUE(was_called); -} - // A simple callback function that always returns an error. absl::Status ReturnError() { return absl::UnknownError("Custom call was executed!"); @@ -181,11 +158,11 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), kReturnErrorCustomCallName, "ROCM", kReturnError); TEST(CustomCallThunkTest, ResolvesFFICustomCall) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -210,52 +187,10 @@ TEST(CustomCallThunkTest, ResolvesFFICustomCall) { HasSubstr("Custom call was executed!"))); } -// A simple callback function that always returns an error and has the function -// signature for a legacy custom call. 
-void Callback_WithStatusFailed(void* /*stream*/, void** /*buffers*/, - const char* /*opaque*/, size_t /*opaque_len*/, - XlaCustomCallStatus* status) { - constexpr absl::string_view kErrorMessage = - "Legacy Custom call was executed!"; - XlaCustomCallStatusSetFailure(status, kErrorMessage.data(), - kErrorMessage.size()); -} - -XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, "ROCM"); - -TEST(CustomCallThunkTest, ResolvesLegacyCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr thunk, - CustomCallThunk::Create( - Thunk::ThunkInfo(), - /*target_name=*/"Callback_WithStatusFailed", - /*operands=*/{}, - /*results=*/{}, /*opaque=*/"", - CustomCallApiVersion::API_VERSION_STATUS_RETURNING, - /*platform_name=*/executor->GetPlatform()->Name())); - - stream_executor::StreamExecutorAddressAllocator allocator(executor); - BufferAllocations empty_unused_allocations({}, 0, &allocator); - Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( - ServiceExecutableRunOptions(), empty_unused_allocations, - /*stream=*/stream.get(), - /*command_buffer_trace_stream=*/stream.get(), - /*collective_params=*/nullptr, - /*collective_cliques=*/nullptr, /*collective_memory=*/nullptr); - EXPECT_THAT(thunk->ExecuteOnStream(params), - StatusIs(absl::StatusCode::kInternal, - HasSubstr("Legacy Custom call was executed!"))); -} - TEST(CustomCallThunkTest, CustomCallWithOwnedHandlers) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); int instantiate_calls = 0; int prepare_calls = 0; int initialize_calls = 0; @@ 
-304,7 +239,7 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlers) { ServiceExecutableRunOptions(), buffer_allocations, stream.get(), stream.get(), nullptr, nullptr, nullptr); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), "target_name", std::move(bundle), @@ -336,9 +271,9 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlers) { } TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutOptionalOnes) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); int execute_calls = 0; CustomCallThunk::OwnedHandlerBundle bundle; bundle.execute = ffi::Ffi::Bind().To([&]() { @@ -369,7 +304,7 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutOptionalOnes) { stream.get(), nullptr, nullptr, nullptr); // Optional handlers are null and shouldn't be invoked. 
- TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), "target_name", std::move(bundle), @@ -383,9 +318,9 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutOptionalOnes) { } TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutExecute) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); CustomCallThunk::OwnedHandlerBundle bundle; // all handlers null stream_executor::StreamExecutorAddressAllocator allocator(executor); Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create( @@ -436,9 +371,9 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), kTestPlatformName, kVerifyCallbackArguments); TEST(CustomCallThunkTest, ProtoConversion) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); HloModuleConfig config; HloModule hlo_module("test_module", config); @@ -460,7 +395,7 @@ TEST(CustomCallThunkTest, ProtoConversion) { std::make_unique(TestState{"some state"})), IsOk()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr original_thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -470,14 +405,14 @@ TEST(CustomCallThunkTest, ProtoConversion) { hlo_module.entry_computation(), /*platform_name=*/kTestPlatformName, /*gpu_compute_capability=*/{}, std::move(execution_state))); - TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); + ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); ASSERT_TRUE(proto.has_custom_call_thunk()); 
ASSERT_TRUE(proto.custom_call_thunk().has_execution_state()); original_thunk.reset(); std::array allocations = {alloc0, alloc1}; - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr new_thunk, CustomCallThunk::FromProto(Thunk::ThunkInfo(), proto.custom_call_thunk(), allocations, &hlo_module, kTestPlatformName, @@ -549,7 +484,7 @@ TEST(CustomCallThunkTest, DeserializationFailsGracefully) { 0, ShapeUtil::MakeShape(U32, {42}), "parameter")); hlo_module.AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::FromProto(Thunk::ThunkInfo(), proto, /*buffer_allocations=*/{}, &hlo_module, @@ -562,9 +497,9 @@ TEST(CustomCallThunkTest, DeserializationFailsGracefully) { } TEST(CustomCallThunkTest, RoundtripWithNonSerializableExecutionState) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); HloModuleConfig config; HloModule hlo_module("test_module", config); @@ -579,7 +514,7 @@ TEST(CustomCallThunkTest, RoundtripWithNonSerializableExecutionState) { IsOk()); EXPECT_FALSE(execution_state->IsSerializable()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr original_thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -589,13 +524,13 @@ TEST(CustomCallThunkTest, RoundtripWithNonSerializableExecutionState) { /*platform_name=*/kTestPlatformName, /*gpu_compute_capability=*/{}, std::move(execution_state))); - TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); + ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); ASSERT_TRUE(proto.has_custom_call_thunk()); EXPECT_FALSE(proto.custom_call_thunk().has_execution_state()); original_thunk.reset(); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( 
std::unique_ptr new_thunk, CustomCallThunk::FromProto( Thunk::ThunkInfo(), proto.custom_call_thunk(), @@ -620,7 +555,7 @@ TEST(CustomCallThunkTest, SerializationFails) { FailingSerializableTestState{42}))); EXPECT_TRUE(execution_state->IsSerializable()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -658,68 +593,6 @@ TEST(CustomCallThunkTest, ParseFFIProtoWithNonUtf8Attribute) { EXPECT_TRUE(reconstructed_proto.ParseFromString(serialized_to_wire_format)); } -TEST(CustomCallThunkTest, ParseLegacyProtoWithNonUtf8Opaque) { - // This test ensures that legacy custom calls can contain non-UTF-8 opaque - // data, and these will be correctly parsed (and not fail). - - CustomCallThunkProto proto = - tsl::proto_testing::ParseTextProtoOrDie( - R"pb( - target_name: "Callback_WithStatusFailed" - api_version: API_VERSION_STATUS_RETURNING - opaque: "\xfe" - )pb"); - - std::string serialized_to_wire_format; - proto.SerializeToString(&serialized_to_wire_format); - - CustomCallThunkProto reconstructed_proto; - EXPECT_TRUE(reconstructed_proto.ParseFromString(serialized_to_wire_format)); -} - -TEST(CustomCallThunkTest, LegacyCustomCallRoundTrip) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr original_thunk, - CustomCallThunk::Create( - Thunk::ThunkInfo(), - /*target_name=*/"Callback_WithStatusFailed", - /*operands=*/{}, - /*results=*/{}, /*opaque=*/"opaque", - CustomCallApiVersion::API_VERSION_STATUS_RETURNING, - /*platform_name=*/executor->GetPlatform()->Name())); - - TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); - original_thunk.reset(); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr new_thunk, - CustomCallThunk::FromProto( - Thunk::ThunkInfo(), proto.custom_call_thunk(), - /*buffer_allocations=*/{}, - /*hlo_module=*/nullptr, 
executor->GetPlatform()->Name(), - executor->GetDeviceDescription().gpu_compute_capability(), - /*cpu_target_machine_options=*/std::nullopt)); - - stream_executor::StreamExecutorAddressAllocator allocator(executor); - BufferAllocations empty_unused_allocations({}, 0, &allocator); - Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( - ServiceExecutableRunOptions(), empty_unused_allocations, - /*stream=*/stream.get(), - /*command_buffer_trace_stream=*/stream.get(), - /*collective_params=*/nullptr, /*collective_cliques=*/nullptr, - /*collective_memory=*/nullptr); - - // We check that the new thunk behaves like the original one (returning - // internal error with specific message). - EXPECT_THAT(new_thunk->ExecuteOnStream(params), - StatusIs(absl::StatusCode::kInternal, - HasSubstr("Legacy Custom call was executed!"))); -} - static bool passes_cpu_target_machine_options_instantiate_called = false; absl::Status VerifyCpuTargetMachineOptionsInstantiate( @@ -759,14 +632,14 @@ XLA_FFI_REGISTER_HANDLER( static_cast(ffi::Traits::kCmdBufferCompatible)); TEST(CustomCallThunkTest, PassesCpuTargetMachineOptionsToInstantiate) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); xla::cpu::TargetMachineOptions options("test-triple", "test-cpu", ""); passes_cpu_target_machine_options_instantiate_called = false; - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -784,10 +657,10 @@ TEST(CustomCallThunkTest, PassesCpuTargetMachineOptionsToInstantiate) { // Also check that FromProto restores the CPU target machine options. // We clear the execution state from the proto to force a re-instantiation. 
passes_cpu_target_machine_options_instantiate_called = false; - TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk->ToProto()); + ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk->ToProto()); proto.mutable_custom_call_thunk()->clear_execution_state(); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr new_thunk, CustomCallThunk::FromProto( Thunk::ThunkInfo(), proto.custom_call_thunk(), diff --git a/xla/backends/gpu/runtime/legacy_custom_call_thunk.cc b/xla/backends/gpu/runtime/legacy_custom_call_thunk.cc new file mode 100644 index 0000000000000..3cc38c4a09e65 --- /dev/null +++ b/xla/backends/gpu/runtime/legacy_custom_call_thunk.cc @@ -0,0 +1,224 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/backends/gpu/runtime/custom_call_target.h" +#include "xla/backends/gpu/runtime/thunk.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_status_internal.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/shaped_slice.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/platform/platform.h" +#include "xla/tsl/platform/status_macros.h" + +namespace xla::gpu { + +static absl::StatusOr +ResolveLegacyCustomCall(const CustomCallTargetRegistry& registry, + absl::string_view target_name, + absl::string_view platform_name, + CustomCallApiVersion api_version) { + void* call_target = + registry.Lookup(std::string(target_name), std::string(platform_name)); + + if (call_target == nullptr) { + return NotFound( + "No registered implementation for custom call to %s for platform %s", + target_name, platform_name); + } + + switch (api_version) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: { + constexpr absl::string_view kErrorMessage = + "Custom call API version `API_VERSION_ORIGINAL` is not supported by " + "XLA:GPU. Prefer https://docs.jax.dev/en/latest/ffi.html. 
It will be " + "fully removed in November 2025."; + if constexpr (tsl::kIsOpenSource) { + LOG(ERROR) << kErrorMessage; + } else { + LOG(FATAL) << kErrorMessage; + } + + return [call_target](stream_executor::Stream* stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus*) { + reinterpret_cast(call_target)( + stream->platform_specific_handle().stream, buffers, opaque, + opaque_len); + }; + break; + } + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + return [call_target](stream_executor::Stream* stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + reinterpret_cast( + call_target)(stream->platform_specific_handle().stream, buffers, + opaque, opaque_len, status); + }; + break; + case CustomCallApiVersion::API_VERSION_TYPED_FFI: + return absl::InvalidArgumentError( + "Called ResolveLegacyCustomCall with API_VERSION_TYPED_FFI"); + default: + return Internal("Unknown custom-call API version enum value: %d", + api_version); + } +} + +absl::StatusOr> +LegacyCustomCallThunk::Create(ThunkInfo thunk_info, std::string target_name, + CustomCallTarget call_target, + std::vector operands, + std::vector results, + std::string opaque) { + return absl::WrapUnique(new LegacyCustomCallThunk( + thunk_info, std::move(target_name), std::move(operands), + std::move(results), std::move(opaque), std::move(call_target), + /*api_version=*/std::nullopt)); +} + +absl::StatusOr> +LegacyCustomCallThunk::Create(ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, + std::string opaque, + CustomCallApiVersion api_version, + absl::string_view platform_name) { + ASSIGN_OR_RETURN( + CustomCallTarget call_target, + ResolveLegacyCustomCall(*CustomCallTargetRegistry::Global(), target_name, + platform_name, api_version)); + + return absl::WrapUnique(new LegacyCustomCallThunk( + thunk_info, std::move(target_name), 
std::move(operands), + std::move(results), std::move(opaque), call_target, api_version)); +} + +LegacyCustomCallThunk::LegacyCustomCallThunk( + ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, std::string opaque, + CustomCallTarget call_target, + const std::optional& api_version) + : Thunk(Thunk::kCustomCall, thunk_info), + api_version_(api_version), + target_name_(std::move(target_name)), + operands_(std::move(operands)), + results_(std::move(results)), + call_target_(std::move(call_target)), + opaque_(std::move(opaque)) {} + +absl::Status LegacyCustomCallThunk::ExecuteOnStream( + const ExecuteParams& params) { + std::vector buffers; + buffers.reserve(operands_.size() + results_.size()); + for (auto& slices : {operands_, results_}) { + for (const std::optional& slice : slices) { + if (!slice.has_value()) { + buffers.push_back(nullptr); + continue; + } + + if (!slice->slice.allocation()) { + return Internal("custom call input missing buffer allocation"); + } + + buffers.push_back( + params.buffer_allocations->GetDeviceAddress(slice->slice).opaque()); + } + } + + XlaCustomCallStatus custom_call_status; + call_target_(params.stream, buffers.data(), opaque_.data(), opaque_.size(), + &custom_call_status); + auto message = CustomCallStatusGetMessage(&custom_call_status); + if (message) { + return Internal("CustomCall failed: %s", *message); + } + return absl::OkStatus(); +} + +absl::StatusOr LegacyCustomCallThunk::ToProto() const { + if (!api_version_.has_value()) { + return absl::FailedPreconditionError( + "LegacyCustomCallThunk was created from a non-registered target and " + "cannot be serialized to a proto"); + } + + ThunkProto proto; + *proto.mutable_thunk_info() = thunk_info().ToProto(); + proto.mutable_custom_call_thunk()->set_target_name(target_name_); + proto.mutable_custom_call_thunk()->set_opaque(opaque_); + proto.mutable_custom_call_thunk()->set_api_version(api_version_.value()); + + for (const 
NullableShapedSlice& operand : operands_) { + ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_operands(), + operand.ToProto()); + } + + for (const NullableShapedSlice& result : results_) { + ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_results(), + result.ToProto()); + } + + return proto; +} + +absl::StatusOr> +LegacyCustomCallThunk::FromProto( + ThunkInfo thunk_info, const CustomCallThunkProto& proto, + absl::Span buffer_allocations, + absl::string_view platform_name) { + std::vector operands, results; + for (const auto& operand_proto : proto.operands()) { + ASSIGN_OR_RETURN( + NullableShapedSlice operand, + NullableShapedSlice::FromProto(operand_proto, buffer_allocations)); + operands.push_back(std::move(operand)); + } + for (const auto& result_proto : proto.results()) { + ASSIGN_OR_RETURN( + NullableShapedSlice result, + NullableShapedSlice::FromProto(result_proto, buffer_allocations)); + results.push_back(std::move(result)); + } + + return LegacyCustomCallThunk::Create( + std::move(thunk_info), proto.target_name(), std::move(operands), + std::move(results), proto.opaque(), proto.api_version(), platform_name); +} + +} // namespace xla::gpu diff --git a/xla/backends/gpu/runtime/legacy_custom_call_thunk.h b/xla/backends/gpu/runtime/legacy_custom_call_thunk.h new file mode 100644 index 0000000000000..eecdc65cc8510 --- /dev/null +++ b/xla/backends/gpu/runtime/legacy_custom_call_thunk.h @@ -0,0 +1,120 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_GPU_RUNTIME_LEGACY_CUSTOM_CALL_THUNK_H_ +#define XLA_BACKENDS_GPU_RUNTIME_LEGACY_CUSTOM_CALL_THUNK_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/shaped_slice.h" +#include "xla/stream_executor/stream.h" + +namespace xla::gpu { + +// Thunk to run a legacy (non-FFI) GPU custom call. +// +// This thunk is DEPRECATED. All new custom calls should use the XLA FFI +// mechanism via CustomCallThunk instead. This class exists only to support +// legacy custom calls that have not yet migrated to FFI. +// +// For the FFI-based custom call thunk, see custom_call_thunk.h. +class LegacyCustomCallThunk : public Thunk { + public: + using CustomCallTarget = + std::function; + + // Creates a serializable legacy custom call thunk. The callback is resolved + // using the legacy CustomCallTargetRegistry. + static absl::StatusOr> Create( + ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, std::string opaque, + CustomCallApiVersion api_version, absl::string_view platform_name); + + // Creates a legacy custom call thunk from a given call target. A thunk + // created this way cannot be serialized to a proto. This overload is only + // permitted for unit testing code. 
+ static absl::StatusOr> Create( + ThunkInfo thunk_info, std::string target_name, + CustomCallTarget call_target, std::vector operands, + std::vector results, std::string opaque); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + const std::string& target_name() const { return target_name_; } + CustomCallTarget call_target() const { return call_target_; } + + const std::vector& operands() const { return operands_; } + const std::vector& results() const { return results_; } + + absl::string_view opaque() const { return opaque_; } + + BufferUses buffer_uses() const override { + BufferUses res; + res.reserve(operands_.size() + results_.size()); + for (const NullableShapedSlice& shaped_slice : operands_) { + if (!shaped_slice.has_value()) { + continue; + } + res.push_back(BufferUse::Read(shaped_slice->slice, shaped_slice->shape)); + } + return res; + } + + absl::StatusOr ToProto() const override; + + static absl::StatusOr> FromProto( + ThunkInfo thunk_info, const CustomCallThunkProto& proto, + absl::Span buffer_allocations, + absl::string_view platform_name); + + private: + LegacyCustomCallThunk(ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, + std::string opaque, CustomCallTarget call_target, + const std::optional& api_version); + + // API version of the custom call. If not set, the thunk was created from a + // non-registered function pointer and cannot be serialized. 
+ std::optional api_version_; + std::string target_name_; + + std::vector operands_; + std::vector results_; + + CustomCallTarget call_target_; + std::string opaque_; +}; + +} // namespace xla::gpu + +#endif // XLA_BACKENDS_GPU_RUNTIME_LEGACY_CUSTOM_CALL_THUNK_H_ diff --git a/xla/backends/gpu/runtime/legacy_custom_call_thunk_test.cc b/xla/backends/gpu/runtime/legacy_custom_call_thunk_test.cc new file mode 100644 index 0000000000000..72abb72b1035d --- /dev/null +++ b/xla/backends/gpu/runtime/legacy_custom_call_thunk_test.cc @@ -0,0 +1,181 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" + +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/platform_util.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor_address_allocator.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/util/proto/parse_text_proto.h" +#include "xla/tsl/platform/status_macros.h" + +namespace xla::gpu { +namespace { +using absl_testing::IsOk; +using absl_testing::StatusIs; +using ::testing::HasSubstr; + +static absl::StatusOr GpuExecutor() { + ASSIGN_OR_RETURN(auto name, PlatformUtil::CanonicalPlatformName("gpu")); + ASSIGN_OR_RETURN(auto* platform, se::PlatformManager::PlatformWithName(name)); + return platform->ExecutorForDevice(0); +} + +TEST(LegacyCustomCallThunkTest, SimpleCustomCall) { + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + + bool was_called = false; + + LegacyCustomCallThunk::CustomCallTarget target = + [&](se::Stream* stream_in_callback, void** args, const char* target_name, + size_t num_args, XlaCustomCallStatus* status) { + was_called = true; + EXPECT_THAT(stream_in_callback, ::testing::Eq(stream.get())); + }; + + ASSERT_OK_AND_ASSIGN( + auto thunk, LegacyCustomCallThunk::Create( + 
Thunk::ThunkInfo(), "target_name", target, {}, {}, "")); + stream_executor::StreamExecutorAddressAllocator allocator(executor); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator), + stream.get(), stream.get(), nullptr, nullptr, nullptr); + EXPECT_THAT(thunk->ExecuteOnStream(Thunk::ExecuteParams(params)), + absl_testing::IsOk()); + EXPECT_TRUE(was_called); +} + +// A simple callback function that always returns an error and has the function +// signature for a legacy custom call. +void Callback_WithStatusFailed(void* /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/, + XlaCustomCallStatus* status) { + constexpr absl::string_view kErrorMessage = + "Legacy Custom call was executed!"; + XlaCustomCallStatusSetFailure(status, kErrorMessage.data(), + kErrorMessage.size()); +} + +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, "ROCM"); + +TEST(LegacyCustomCallThunkTest, ResolvesLegacyCustomCall) { + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr thunk, + LegacyCustomCallThunk::Create( + Thunk::ThunkInfo(), + /*target_name=*/"Callback_WithStatusFailed", + /*operands=*/{}, + /*results=*/{}, /*opaque=*/"", + CustomCallApiVersion::API_VERSION_STATUS_RETURNING, + /*platform_name=*/executor->GetPlatform()->Name())); + + stream_executor::StreamExecutorAddressAllocator allocator(executor); + BufferAllocations empty_unused_allocations({}, 0, &allocator); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + ServiceExecutableRunOptions(), empty_unused_allocations, + /*stream=*/stream.get(), + /*command_buffer_trace_stream=*/stream.get(), + /*collective_params=*/nullptr, + /*collective_cliques=*/nullptr, /*collective_memory=*/nullptr); + 
EXPECT_THAT(thunk->ExecuteOnStream(params), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Legacy Custom call was executed!"))); +} + +TEST(LegacyCustomCallThunkTest, ParseLegacyProtoWithNonUtf8Opaque) { + // This test ensures that legacy custom calls can contain non-UTF-8 opaque + // data, and these will be correctly parsed (and not fail). + + CustomCallThunkProto proto = + tsl::proto_testing::ParseTextProtoOrDie( + R"pb( + target_name: "Callback_WithStatusFailed" + api_version: API_VERSION_STATUS_RETURNING + opaque: "\xfe" + )pb"); + + std::string serialized_to_wire_format; + proto.SerializeToString(&serialized_to_wire_format); + + CustomCallThunkProto reconstructed_proto; + EXPECT_TRUE(reconstructed_proto.ParseFromString(serialized_to_wire_format)); +} + +TEST(LegacyCustomCallThunkTest, LegacyCustomCallRoundTrip) { + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr original_thunk, + LegacyCustomCallThunk::Create( + Thunk::ThunkInfo(), + /*target_name=*/"Callback_WithStatusFailed", + /*operands=*/{}, + /*results=*/{}, /*opaque=*/"opaque", + CustomCallApiVersion::API_VERSION_STATUS_RETURNING, + /*platform_name=*/executor->GetPlatform()->Name())); + + ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); + original_thunk.reset(); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr new_thunk, + LegacyCustomCallThunk::FromProto( + Thunk::ThunkInfo(), proto.custom_call_thunk(), + /*buffer_allocations=*/{}, executor->GetPlatform()->Name())); + + stream_executor::StreamExecutorAddressAllocator allocator(executor); + BufferAllocations empty_unused_allocations({}, 0, &allocator); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + ServiceExecutableRunOptions(), empty_unused_allocations, + /*stream=*/stream.get(), + /*command_buffer_trace_stream=*/stream.get(), + /*collective_params=*/nullptr, 
/*collective_cliques=*/nullptr, + /*collective_memory=*/nullptr); + + // We check that the new thunk behaves like the original one (returning + // internal error with specific message). + EXPECT_THAT(new_thunk->ExecuteOnStream(params), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Legacy Custom call was executed!"))); +} + +} // namespace +} // namespace xla::gpu diff --git a/xla/backends/gpu/runtime/thunk_proto_deserialization.cc b/xla/backends/gpu/runtime/thunk_proto_deserialization.cc index 6731666119bc5..5cf90b6e0cf46 100644 --- a/xla/backends/gpu/runtime/thunk_proto_deserialization.cc +++ b/xla/backends/gpu/runtime/thunk_proto_deserialization.cc @@ -56,6 +56,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/host_to_device_copy_thunk.h" #include "xla/backends/gpu/runtime/infeed_thunk.h" #include "xla/backends/gpu/runtime/kernel_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/memset_thunk.h" #include "xla/backends/gpu/runtime/norm_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.h" @@ -75,6 +76,7 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/while_thunk.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/hlo.pb.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/platform/statusor.h" @@ -225,11 +227,17 @@ absl::StatusOr> DeserializeThunkProtoImpl( thunk_proto.dynamic_slice_thunk(), buffer_allocations, deserializer); } - case ThunkProto::kCustomCallThunk: + case ThunkProto::kCustomCallThunk: { + const auto& cc_proto = thunk_proto.custom_call_thunk(); + if (cc_proto.api_version() != + CustomCallApiVersion::API_VERSION_TYPED_FFI) { + return LegacyCustomCallThunk::FromProto( + std::move(thunk_info), cc_proto, buffer_allocations, platform_name); + } return CustomCallThunk::FromProto( - std::move(thunk_info), thunk_proto.custom_call_thunk(), - buffer_allocations, hlo_module, platform_name, gpu_compute_capability, - cpu_target_machine_options); + std::move(thunk_info), cc_proto, buffer_allocations, hlo_module, + platform_name, gpu_compute_capability, cpu_target_machine_options); + } case ThunkProto::kHostExecuteStartThunk: return HostExecuteStartThunk::FromProto( std::move(thunk_info), thunk_proto.host_execute_start_thunk(), diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index d4b2fb6991732..02099a5ef7915 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -530,6 +530,7 @@ cc_library( "//xla/backends/gpu/runtime:host_to_device_copy_thunk", "//xla/backends/gpu/runtime:infeed_thunk", "//xla/backends/gpu/runtime:kernel_thunk", + "//xla/backends/gpu/runtime:legacy_custom_call_thunk", "//xla/backends/gpu/runtime:norm_thunk", "//xla/backends/gpu/runtime:nvshmem_all_reduce_thunk", "//xla/backends/gpu/runtime:nvshmem_collective_permute_thunk", diff --git a/xla/service/gpu/thunk_emitter.cc b/xla/service/gpu/thunk_emitter.cc index fdbb44f735996..e6560b9e1d53c 100644 --- a/xla/service/gpu/thunk_emitter.cc +++ 
b/xla/service/gpu/thunk_emitter.cc @@ -94,6 +94,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/host_to_device_copy_thunk.h" #include "xla/backends/gpu/runtime/infeed_thunk.h" #include "xla/backends/gpu/runtime/kernel_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/norm_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_collective_permute_thunk.h" @@ -973,7 +974,7 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( << "Fall back to parse the raw backend config str."; } - auto ffi_thunk = [&]() -> absl::StatusOr> { + auto ffi_thunk = [&]() -> absl::StatusOr> { auto& called_computations = instr->called_computations(); auto& backend_config_str = backend_config.ok() @@ -1004,13 +1005,12 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( ir_emitter_context_->cpu_target_machine_options()); }; - auto legacy_thunk = - [&]() -> absl::StatusOr> { + auto legacy_thunk = [&]() -> absl::StatusOr> { std::string opaque = backend_config.ok() ? backend_config->custom_call_backend_config().opaque() : instr->raw_backend_config_string(); - return CustomCallThunk::Create( + return LegacyCustomCallThunk::Create( Thunk::ThunkInfo::WithProfileAnnotation( instr, ir_emitter_context_->GetNextThunkId()), call_target_name, std::move(operands), std::move(results), @@ -1018,7 +1018,7 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( ir_emitter_context_->platform_name()); }; - absl::StatusOr> custom_call_thunk = + absl::StatusOr> custom_call_thunk = is_ffi_custom_call ? ffi_thunk() : legacy_thunk(); ThunkSequence thunks;