diff --git a/xla/backends/gpu/codegen/BUILD b/xla/backends/gpu/codegen/BUILD index faa434da3e745..35fb29e3e9c92 100644 --- a/xla/backends/gpu/codegen/BUILD +++ b/xla/backends/gpu/codegen/BUILD @@ -144,6 +144,7 @@ cc_library( "//xla/backends/gpu/runtime:device_to_device_copy_thunk", "//xla/backends/gpu/runtime:dynamic_slice_thunk", "//xla/backends/gpu/runtime:gemm_thunk", + "//xla/backends/gpu/runtime:legacy_custom_call_thunk", "//xla/backends/gpu/runtime:thunk", "//xla/codegen/emitters:kernel_arguments", "//xla/ffi:attribute_map", diff --git a/xla/backends/gpu/codegen/custom.cc b/xla/backends/gpu/codegen/custom.cc index f4d8e7ee99ddf..d1260cbca8d6a 100644 --- a/xla/backends/gpu/codegen/custom.cc +++ b/xla/backends/gpu/codegen/custom.cc @@ -46,6 +46,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/device_to_device_copy_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/gemm_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/codegen/emitters/kernel_arguments.h" #include "xla/ffi/attribute_map.h" @@ -878,7 +879,7 @@ absl::StatusOr EmitCustomCall( // For legacy custom calls we convert all API versions into the latest // status-returning one and pass backend config as an opaque string. - CustomCallThunk::CustomCallTarget custom_call_target; + LegacyCustomCallThunk::CustomCallTarget custom_call_target; // For XLA FFI handlers we decode opaque backend config into attributes map // at IR emission time, so that we do not need to parse MLIR at run time. 
@@ -926,9 +927,8 @@ absl::StatusOr EmitCustomCall( auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation( &fusion, ir_emitter_context.GetNextThunkId()); - auto ffi_thunk = - [&](Slices ops, - Slices res) -> absl::StatusOr> { + auto ffi_thunk = [&](Slices ops, + Slices res) -> absl::StatusOr> { auto& called_computations = custom_call.called_computations(); auto& backend_config_str = backend_config.ok() @@ -954,16 +954,15 @@ absl::StatusOr EmitCustomCall( }; auto legacy_thunk = - [&](Slices ops, - Slices res) -> absl::StatusOr> { + [&](Slices ops, Slices res) -> absl::StatusOr> { std::string opaque = backend_config.ok() ? backend_config->custom_call_backend_config().opaque() : custom_call.raw_backend_config_string(); - return CustomCallThunk::Create(thunk_info, call_target_name, std::move(ops), - std::move(res), std::move(opaque), - custom_call.api_version(), - ir_emitter_context.platform_name()); + return LegacyCustomCallThunk::Create( + thunk_info, call_target_name, std::move(ops), std::move(res), + std::move(opaque), custom_call.api_version(), + ir_emitter_context.platform_name()); }; std::vector fake_allocations(num_args, {0, 0, 0}); diff --git a/xla/backends/gpu/runtime/BUILD b/xla/backends/gpu/runtime/BUILD index 3c366022aadf2..50e5d52b4b14f 100644 --- a/xla/backends/gpu/runtime/BUILD +++ b/xla/backends/gpu/runtime/BUILD @@ -188,6 +188,7 @@ cc_library( ":dynamic_memcpy_thunk", ":dynamic_slice_thunk", ":gpublas_lt_matmul_thunk", + ":legacy_custom_call_thunk", ":p2p_thunk_common", ":ragged_all_to_all_thunk", ":recv_thunk", @@ -345,6 +346,7 @@ cc_library( ":gemm_thunk", ":gpublas_lt_matmul_thunk", ":kernel_thunk", + ":legacy_custom_call_thunk", ":memset_thunk", ":ragged_all_to_all_thunk", ":recv_thunk", @@ -1082,8 +1084,8 @@ cc_library( ":collective_cliques", ":collective_memory", ":collective_params", - ":custom_call_target", ":thunk", + ":thunk_proto_cc", "//xla:executable_run_options", "//xla:shape_util", "//xla:util", @@ -1101,16 +1103,13 @@ cc_library( 
"//xla/runtime:buffer_use", "//xla/runtime:object_pool", "//xla/service:buffer_assignment", - "//xla/service:custom_call_status", - "//xla/service:custom_call_status_internal", - "//xla/service:custom_call_target_registry", + "//xla/service:hlo_proto_cc", "//xla/service:shaped_slice", "//xla/service/gpu:buffer_allocations", "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_description", "//xla/stream_executor:stream", - "//xla/tsl/platform:errors", "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util:unique_any", @@ -1123,7 +1122,35 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "legacy_custom_call_thunk", + srcs = ["legacy_custom_call_thunk.cc"], + hdrs = ["legacy_custom_call_thunk.h"], + deps = [ + ":custom_call_target", + ":thunk", + ":thunk_proto_cc", + "//xla:util", + "//xla/runtime:buffer_use", + "//xla/service:buffer_assignment", + "//xla/service:custom_call_status", + "//xla/service:custom_call_status_internal", + "//xla/service:custom_call_target_registry", + "//xla/service:hlo_proto_cc", + "//xla/service:shaped_slice", + "//xla/service/gpu:buffer_allocations", + "//xla/stream_executor:stream", + "//xla/tsl/platform:status_macros", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@tsl//tsl/platform", @@ -1135,9 +1162,12 @@ xla_test( srcs = ["custom_call_thunk_test.cc"], backends = ["gpu"], deps = [ + ":collective_clique_requests", ":collective_memory_requests", + ":collective_params", ":custom_call_thunk", 
":thunk", + ":thunk_proto_cc", "//xla:executable_run_options", "//xla:shape_util", "//xla/backends/cpu:target_machine_options", @@ -1149,10 +1179,9 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/runtime:device_id", "//xla/service:buffer_assignment", - "//xla/service:custom_call_status_public_headers", - "//xla/service:custom_call_target_registry", "//xla/service:executable", "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", "//xla/service:platform_util", "//xla/service:shaped_slice", "//xla/service/gpu:buffer_allocations", @@ -1162,6 +1191,7 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_address_allocator", + "//xla/tsl/platform:status_macros", "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:parse_text_proto", "@com_google_absl//absl/base", @@ -1175,6 +1205,36 @@ xla_test( ], ) +xla_test( + name = "legacy_custom_call_thunk_test", + srcs = ["legacy_custom_call_thunk_test.cc"], + backends = ["gpu"], + deps = [ + ":legacy_custom_call_thunk", + ":thunk", + ":thunk_proto_cc", + "//xla/service:buffer_assignment", + "//xla/service:custom_call_status_public_headers", + "//xla/service:custom_call_target_registry", + "//xla/service:executable", + "//xla/service:hlo_proto_cc", + "//xla/service:platform_util", + "//xla/service/gpu:buffer_allocations", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_address_allocator", + "//xla/tsl/platform:status_macros", + "//xla/tsl/platform:statusor", + "//xla/tsl/util/proto:parse_text_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "fft_thunk", srcs = ["fft_thunk.cc"], @@ -3606,6 +3666,7 @@ cc_library( ":host_to_device_copy_thunk", 
":infeed_thunk", ":kernel_thunk", + ":legacy_custom_call_thunk", ":memset_thunk", ":norm_thunk", ":nvshmem_all_reduce_thunk", @@ -3627,6 +3688,7 @@ cc_library( "//xla/backends/cpu:target_machine_options", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", + "//xla/service:hlo_proto_cc", "//xla/stream_executor:device_description", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:stream_executor_h", @@ -3639,6 +3701,7 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:protobuf_lite", ], ) diff --git a/xla/backends/gpu/runtime/command_buffer_cmd.cc b/xla/backends/gpu/runtime/command_buffer_cmd.cc index eb41cb8bb362f..ec99954882444 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -63,6 +63,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/dynamic_memcpy_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/p2p_thunk_common.h" #include "xla/backends/gpu/runtime/ragged_all_to_all_thunk.h" #include "xla/backends/gpu/runtime/recv_thunk.h" @@ -1211,27 +1212,16 @@ Command::BufferUses CuDnnCmd::buffer_uses() const { } //===----------------------------------------------------------------------===// -// CustomCallCmd +// LegacyCustomCallCmd //===----------------------------------------------------------------------===// -absl::StatusOr CustomCallCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, RecordAction record_action, - se::CommandBuffer* command_buffer) { - if (handler_ == nullptr) { - return RecordLegacyCustomCall(execute_params, record_params, - std::move(record_action), command_buffer); - } - return RecordXlaFfiCall(execute_params, record_params, 
- std::move(record_action), command_buffer); -} - namespace { // Records each buffer associated with each slice into the provided vector. // Returns an error if any of the slices is missing a buffer allocation. -absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params, - absl::Span slices, - std::vector& buffers, absl::string_view label) { +static absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params, + absl::Span slices, + std::vector& buffers, + absl::string_view label) { for (int i = 0; i < slices.size(); ++i) { if (!slices[i].has_value()) { buffers.push_back(nullptr); @@ -1253,21 +1243,18 @@ absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params, } } // namespace -absl::StatusOr -CustomCallCmd::RecordLegacyCustomCall( +absl::StatusOr LegacyCustomCallCmd::Record( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { std::vector buffers; buffers.reserve(operands_.size() + results_.size()); - VLOG(5) << "CustomCallCmd: target_name=" << target_name_; - TF_RETURN_IF_ERROR( - GetBuffers(execute_params, operands_, buffers, " Operand ")); - TF_RETURN_IF_ERROR( - GetBuffers(execute_params, results_, buffers, " Result ")); + VLOG(5) << "LegacyCustomCallCmd: target_name=" << target_name_; + RETURN_IF_ERROR(GetBuffers(execute_params, operands_, buffers, " Operand ")); + RETURN_IF_ERROR(GetBuffers(execute_params, results_, buffers, " Result ")); - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto nested_cmd, se::TraceCommandBufferFactory::Create( execute_params.stream->parent(), @@ -1293,11 +1280,26 @@ CustomCallCmd::RecordLegacyCustomCall( }); } -absl::StatusOr -CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - RecordAction record_action, - se::CommandBuffer* command_buffer) { +Command::BufferUses LegacyCustomCallCmd::buffer_uses() const { + Command::BufferUses buffer_usage; + for (auto& 
slices : {operands_, results_}) { + for (const std::optional& slice : slices) { + if (slice.has_value()) { + buffer_usage.push_back(BufferUse::Write(slice->slice, slice->shape)); + } + } + } + return buffer_usage; +} + +//===----------------------------------------------------------------------===// +// CustomCallCmd (FFI) +//===----------------------------------------------------------------------===// + +absl::StatusOr CustomCallCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) { // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing // a lot of extra allocation on every call. We have to keep attributes // separate from arguments, as they do not change after thunk is @@ -1342,8 +1344,8 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, // Borrow the FFI call frame from the object pool and update with the actual // device memory addresses. - TF_ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate()); - TF_RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results)); + ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate()); + RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results)); RunId run_id = execute_params.collective_params->run_id; @@ -1360,7 +1362,7 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params, } } - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( auto nested_cmd, se::TraceCommandBufferFactory::Create( execute_params.stream->parent(), diff --git a/xla/backends/gpu/runtime/command_buffer_cmd.h b/xla/backends/gpu/runtime/command_buffer_cmd.h index ee072b630ea83..263e04c54e0f3 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -37,10 +37,10 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/command.h" #include "xla/backends/gpu/runtime/command_executor.h" #include "xla/backends/gpu/runtime/command_state.h" -#include "xla/backends/gpu/runtime/custom_call_thunk.h" #include "xla/backends/gpu/runtime/dynamic_memcpy_thunk.h" #include "xla/backends/gpu/runtime/dynamic_slice_thunk.h" #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/p2p_thunk_common.h" #include "xla/backends/gpu/runtime/ragged_all_to_all_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -451,27 +451,11 @@ class CuDnnCmd : public TracedCommandBufferCmd { }; //===----------------------------------------------------------------------===// -// CustomCallCmd +// CustomCallCmd (FFI) //===----------------------------------------------------------------------===// class CustomCallCmd : public Command { public: - using CustomCallTarget = CustomCallThunk::CustomCallTarget; - using AttributesMap = ffi::AttributesMap; - - // This is a legacy custom call API that is discouraged, and will be - // deprecated once XLA:FFI mechanism is ready. 
- CustomCallCmd(std::string target_name, CustomCallTarget call_target, - std::vector operands, - std::vector results, - absl::string_view opaque) - : Command(CommandType::kCustomCallCmd), - target_name_(std::move(target_name)), - call_target_(std::move(call_target)), - opaque_(opaque), - operands_(std::move(operands)), - results_(std::move(results)) {} - CustomCallCmd(std::string target_name, XLA_FFI_Handler* handler, std::vector operands, std::vector results, @@ -498,26 +482,8 @@ class CustomCallCmd : public Command { bool IsNestedCommandBuffer() const final { return true; } private: - absl::StatusOr RecordLegacyCustomCall( - const Thunk::ExecuteParams& execute_param, - const RecordParams& record_params, RecordAction record_action, - se::CommandBuffer* command_buffer); - - absl::StatusOr RecordXlaFfiCall( - const Thunk::ExecuteParams& execute_param, - const RecordParams& record_params, RecordAction record_action, - se::CommandBuffer* command_buffer); - std::string target_name_; - // This is a legacy custom call API that is discouraged, and will be - // deprecated once XLA:FFI mechanism is ready. - CustomCallTarget call_target_; - std::string opaque_; - - // XLA FFI provides a right type safe mechanism for registering external - // functions with XLA runtime. It's under construction, and still misses - // a lot of features. Long term it will replace legacy custom calls. XLA_FFI_Handler* handler_ = nullptr; // Reference call frame pre-initialized at construction time. 
@@ -540,6 +506,43 @@ class CustomCallCmd : public Command { std::vector results_; }; +//===----------------------------------------------------------------------===// +// LegacyCustomCallCmd +//===----------------------------------------------------------------------===// + +class LegacyCustomCallCmd : public Command { + public: + using CustomCallTarget = LegacyCustomCallThunk::CustomCallTarget; + + LegacyCustomCallCmd(std::string target_name, CustomCallTarget call_target, + std::vector operands, + std::vector results, + absl::string_view opaque) + : Command(CommandType::kCustomCallCmd), + target_name_(std::move(target_name)), + call_target_(std::move(call_target)), + opaque_(opaque), + operands_(std::move(operands)), + results_(std::move(results)) {} + + absl::StatusOr Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, RecordAction record_action, + se::CommandBuffer* command_buffer) override; + + BufferUses buffer_uses() const override; + bool IsNestedCommandBuffer() const final { return true; } + + private: + std::string target_name_; + + CustomCallTarget call_target_; + std::string opaque_; + + std::vector operands_; + std::vector results_; +}; + //===----------------------------------------------------------------------===// // CollectiveCmd //===----------------------------------------------------------------------===// diff --git a/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc b/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc index 338353bd071c1..37201e0d1edf5 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/xla/backends/gpu/runtime/command_buffer_cmd_emitter.cc @@ -49,6 +49,7 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/gemm_thunk.h" #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/backends/gpu/runtime/kernel_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/memset_thunk.h" #include "xla/backends/gpu/runtime/ragged_all_to_all_thunk.h" #include "xla/backends/gpu/runtime/recv_thunk.h" @@ -277,9 +278,16 @@ static absl::StatusOr> Convert( thunk.execution_state(), /*called_computation=*/nullptr); // TODO(b/342285364) } - return std::make_unique(thunk.target_name(), - thunk.call_target(), thunk.operands(), - thunk.results(), thunk.opaque()); + return absl::InternalError( + "CustomCallThunk without FFI handler bundle cannot be converted to a " + "command buffer command"); +} + +static absl::StatusOr> Convert( + const LegacyCustomCallThunk& thunk) { + return std::make_unique( + thunk.target_name(), thunk.call_target(), thunk.operands(), + thunk.results(), thunk.opaque()); } static absl::StatusOr> Convert( @@ -329,7 +337,14 @@ static absl::Status AppendCommands(ConversionContext& ctx, return append(Convert(thunk)); } case Thunk::Kind::kCustomCall: - return append(Convert(thunk)); + if (auto* ffi_thunk = dynamic_cast(&thunk)) { + return append(Convert(*ffi_thunk)); + } + if (auto* legacy_thunk = + dynamic_cast(&thunk)) { + return append(Convert(*legacy_thunk)); + } + return absl::InternalError("Unknown custom call thunk type"); case Thunk::Kind::kCustomKernel: return append(Convert(thunk)); case Thunk::Kind::kKernel: diff --git a/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc b/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc index 928ef38f9f521..6de74dcb533a4 100644 --- a/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc +++ b/xla/backends/gpu/runtime/command_buffer_conversion_pass.cc @@ -204,8 +204,8 @@ bool IsConvertible(const ConditionalThunk& conditional_thunk, } // Returns true if the CustomCallThunk is convertible to a command buffer 
-// operation. Checks if the custom call target is in the legacy allowlist or if -// the registered FFI handler is compatible with command buffers. +// operation. Checks if the registered FFI handler is compatible with command +// buffers. bool IsConvertible(const CustomCallThunk& custom_call_thunk, const CommandBufferConfig& config) { const std::string& target_name = custom_call_thunk.target_name(); @@ -312,7 +312,11 @@ bool IsConvertible(const Thunk& thunk, const CommandBufferConfig& config) { } if (thunk.kind() == Thunk::kCustomCall) { - return IsConvertible(static_cast(thunk), config); + if (auto* ffi_thunk = dynamic_cast(&thunk)) { + return IsConvertible(*ffi_thunk, config); + } + // Legacy custom calls are not command-buffer compatible. + return false; } if (thunk.kind() == Thunk::kDynamicSlice) { diff --git a/xla/backends/gpu/runtime/custom_call_thunk.cc b/xla/backends/gpu/runtime/custom_call_thunk.cc index 4788b4279384f..f7267d79189b7 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk.cc +++ b/xla/backends/gpu/runtime/custom_call_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/custom_call_thunk.h" +#include #include #include #include @@ -34,15 +35,14 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/cpu/target_machine_options.h" #include "xla/backends/gpu/runtime/collective_clique_requests.h" #include "xla/backends/gpu/runtime/collective_cliques.h" #include "xla/backends/gpu/runtime/collective_params.h" -#include "xla/backends/gpu/runtime/custom_call_target.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/executable_run_options.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" @@ -55,24 +55,18 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/runtime/object_pool.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/custom_call_status_internal.h" -#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/hlo.pb.h" #include "xla/service/shaped_slice.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream.h" -#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/unique_any.h" -#include "xla/util.h" -#include "tsl/platform/platform.h" #include "xla/tsl/platform/status_macros.h" -namespace xla { -namespace gpu { +namespace xla::gpu { using xla::ffi::CallFrame; using xla::ffi::CallFrameBuilder; @@ -89,8 +83,6 @@ static absl::StatusOr BuildCallFramePrototype( /*num_args=*/operands.size(), /*num_rets=*/results.size()); - // Add prototype input buffers with actual data types and shapes. Device - // memory addresses will be updated at runtime. 
for (int i = 0; i < operands.size(); ++i) { auto& operand = operands[i]; @@ -107,8 +99,6 @@ static absl::StatusOr BuildCallFramePrototype( operand->shape.dimensions()); } - // Add prototype output buffers with actual data types and shapes. Device - // memory addresses will be updated at runtime. for (int i = 0; i < results.size(); ++i) { auto& result = results[i]; @@ -125,7 +115,6 @@ static absl::StatusOr BuildCallFramePrototype( result->shape.dimensions()); } - // Add attributes if any. if (!attributes.empty()) { ffi::CallFrameBuilder::AttributesBuilder attrs; attrs.Append(std::move(attributes)); @@ -135,111 +124,6 @@ static absl::StatusOr BuildCallFramePrototype( return builder.Build(); } -static absl::StatusOr -ResolveLegacyCustomCall(const CustomCallTargetRegistry& registry, - absl::string_view target_name, - absl::string_view platform_name, - CustomCallApiVersion api_version) { - void* call_target = - registry.Lookup(std::string(target_name), std::string(platform_name)); - - if (call_target == nullptr) { - return NotFound( - "No registered implementation for custom call to %s for platform %s", - target_name, platform_name); - } - - // For information about this calling convention, see - // xla/g3doc/custom_call.md. - switch (api_version) { - case CustomCallApiVersion::API_VERSION_ORIGINAL: { - constexpr absl::string_view kErrorMessage = - "Custom call API version `API_VERSION_ORIGINAL` is not supported by " - "XLA:GPU. Prefer https://docs.jax.dev/en/latest/ffi.html. 
It will be " - "fully removed in November 2025."; - if constexpr (tsl::kIsOpenSource) { - LOG(ERROR) << kErrorMessage; - } else { - LOG(FATAL) << kErrorMessage; - } - - return [call_target](stream_executor::Stream* stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus*) { - reinterpret_cast(call_target)( - stream->platform_specific_handle().stream, buffers, opaque, - opaque_len); - }; - break; - } - case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: - case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: - return [call_target](stream_executor::Stream* stream, void** buffers, - const char* opaque, size_t opaque_len, - XlaCustomCallStatus* status) { - reinterpret_cast( - call_target)(stream->platform_specific_handle().stream, buffers, - opaque, opaque_len, status); - }; - break; - case CustomCallApiVersion::API_VERSION_TYPED_FFI: - return absl::InvalidArgumentError( - "Called ResolveLegacyCustomCall with API_VERSION_TYPED_FFI"); - default: - return Internal("Unknown custom-call API version enum value: %d", - api_version); - } -} - -absl::StatusOr> CustomCallThunk::Create( - ThunkInfo thunk_info, std::string target_name, CustomCallTarget call_target, - std::vector operands, - std::vector results, std::string opaque) { - return absl::WrapUnique(new CustomCallThunk( - thunk_info, std::move(target_name), std::move(operands), - std::move(results), std::move(opaque), std::move(call_target), - /*api_version=*/std::nullopt)); -} - -absl::StatusOr> CustomCallThunk::Create( - ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, std::string opaque, - CustomCallApiVersion api_version, absl::string_view platform_name) { - if (api_version == CustomCallApiVersion::API_VERSION_TYPED_FFI) { - return absl::InvalidArgumentError( - "Called overload of CustomCallThunk::Create that is intended for " - "legacy custom calls with api_version=API_VERSION_TYPED_FFI"); - } - - TF_ASSIGN_OR_RETURN( - 
CustomCallTarget call_target, - ResolveLegacyCustomCall(*CustomCallTargetRegistry::Global(), target_name, - platform_name, api_version)); - - return absl::WrapUnique(new CustomCallThunk( - thunk_info, std::move(target_name), std::move(operands), - std::move(results), std::move(opaque), call_target, api_version)); -} - -absl::StatusOr> CustomCallThunk::Create( - ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, ffi::AttributesMap attributes, - const HloComputation* called_computation, absl::string_view platform_name, - const se::GpuComputeCapability& gpu_compute_capability, - std::unique_ptr execution_state, - std::optional cpu_target_machine_options) { - TF_ASSIGN_OR_RETURN(ffi::HandlerRegistration registration, - ffi::FindHandler(target_name, platform_name)); - - return Create(thunk_info, std::move(target_name), - std::move(registration.bundle), std::move(operands), - std::move(results), std::move(attributes), called_computation, - gpu_compute_capability, std::move(execution_state), - std::move(cpu_target_machine_options)); -} - static InvokeContext BuildInstantiateInvokeContext( ffi::ExecutionState* execution_state, const se::GpuComputeCapability* gpu_compute_capability, @@ -260,6 +144,24 @@ static InvokeContext BuildInstantiateInvokeContext( return context; } +absl::StatusOr> CustomCallThunk::Create( + ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, ffi::AttributesMap attributes, + const HloComputation* called_computation, absl::string_view platform_name, + const se::GpuComputeCapability& gpu_compute_capability, + std::unique_ptr execution_state, + std::optional cpu_target_machine_options) { + ASSIGN_OR_RETURN(ffi::HandlerRegistration registration, + ffi::FindHandler(target_name, platform_name)); + + return Create(thunk_info, std::move(target_name), + std::move(registration.bundle), std::move(operands), + std::move(results), std::move(attributes), called_computation, + 
gpu_compute_capability, std::move(execution_state), + std::move(cpu_target_machine_options)); +} + absl::StatusOr> CustomCallThunk::Create( ThunkInfo thunk_info, std::string target_name, XLA_FFI_Handler_Bundle bundle, std::vector operands, @@ -289,8 +191,8 @@ absl::StatusOr> CustomCallThunk::Create( } } - TF_ASSIGN_OR_RETURN(CallFrame call_frame, - BuildCallFramePrototype(operands, results, attributes)); + ASSIGN_OR_RETURN(CallFrame call_frame, + BuildCallFramePrototype(operands, results, attributes)); return absl::WrapUnique(new CustomCallThunk( thunk_info, std::move(target_name), std::move(bundle), std::move(operands), std::move(results), std::move(call_frame), @@ -313,12 +215,11 @@ absl::StatusOr> CustomCallThunk::Create( auto execution_state = std::make_unique(); - // Initialize FFI handler state if it has an instantiate callback. if (bundle.instantiate) { // Build a call frame with placeholder buffers so the instantiate handler // can read operand/result types and shapes. Data pointers are nullptr. 
- TF_ASSIGN_OR_RETURN(CallFrame call_frame, - BuildCallFramePrototype(operands, results, attributes)); + ASSIGN_OR_RETURN(CallFrame call_frame, + BuildCallFramePrototype(operands, results, attributes)); if (!cpu_target_machine_options.has_value()) { cpu_target_machine_options = xla::cpu::TargetMachineOptions(); @@ -326,13 +227,12 @@ absl::StatusOr> CustomCallThunk::Create( InvokeContext context = BuildInstantiateInvokeContext( execution_state.get(), &gpu_compute_capability, &*cpu_target_machine_options); - TF_RETURN_IF_ERROR(Invoke(ffi::GetXlaFfiApi(), *bundle.instantiate, - call_frame, context, - xla::ffi::ExecutionStage::kInstantiate)); + RETURN_IF_ERROR(Invoke(ffi::GetXlaFfiApi(), *bundle.instantiate, call_frame, + context, xla::ffi::ExecutionStage::kInstantiate)); } - TF_ASSIGN_OR_RETURN(CallFrame call_frame, - BuildCallFramePrototype(operands, results, attributes)); + ASSIGN_OR_RETURN(CallFrame call_frame, + BuildCallFramePrototype(operands, results, attributes)); return absl::WrapUnique(new CustomCallThunk( thunk_info, std::move(target_name), std::move(bundle), std::move(operands), std::move(results), std::move(call_frame), @@ -340,20 +240,6 @@ absl::StatusOr> CustomCallThunk::Create( cpu_target_machine_options)); } -CustomCallThunk::CustomCallThunk( - ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, std::string opaque, - CustomCallTarget call_target, - const std::optional& api_version) - : Thunk(Thunk::kCustomCall, thunk_info), - api_version_(api_version), - target_name_(std::move(target_name)), - operands_(std::move(operands)), - results_(std::move(results)), - call_target_(std::move(call_target)), - opaque_(std::move(opaque)) {} - CustomCallThunk::CustomCallThunk( ThunkInfo thunk_info, std::string target_name, std::variant bundle, @@ -364,7 +250,6 @@ CustomCallThunk::CustomCallThunk( const HloComputation* called_computation, std::optional cpu_target_machine_options) : Thunk(Thunk::kCustomCall, thunk_info), - 
api_version_(CustomCallApiVersion::API_VERSION_TYPED_FFI), target_name_(std::move(target_name)), operands_(std::move(operands)), results_(std::move(results)), @@ -376,42 +261,6 @@ CustomCallThunk::CustomCallThunk( called_computation_(called_computation), cpu_target_machine_options_(std::move(cpu_target_machine_options)) {} -absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) { - // gpu_stream is CUstream or e.g. the equivalent type in ROCm. - std::vector buffers; - buffers.reserve(operands_.size() + results_.size()); - for (auto& slices : {operands_, results_}) { - for (const std::optional& slice : slices) { - if (!slice.has_value()) { - buffers.push_back(nullptr); - continue; - } - - if (!slice->slice.allocation()) { - return Internal("custom call input missing buffer allocation"); - } - - buffers.push_back( - params.buffer_allocations->GetDeviceAddress(slice->slice).opaque()); - } - } - - XlaCustomCallStatus custom_call_status; - call_target_(params.stream, buffers.data(), opaque_.data(), opaque_.size(), - &custom_call_status); - auto message = CustomCallStatusGetMessage(&custom_call_status); - if (message) { - return Internal("CustomCall failed: %s", *message); - } - return absl::OkStatus(); -} - -// Builds a call frame for the custom call. -// -// If `buffer_allocations` is provided, the call frame will contain the actual -// device memory addresses of the buffers. Otherwise, the call frame will -// contain placeholders - this should only be the case when calling Prepare() -// stage handler. absl::StatusOr::BorrowedObject> CustomCallThunk::BuildCallFrame( const BufferAllocations* absl_nullable buffer_allocations) { @@ -420,7 +269,6 @@ CustomCallThunk::BuildCallFrame( : se::DeviceAddressBase{}; }; - // Collect arguments buffers. absl::InlinedVector arguments; arguments.reserve(operands_.size()); for (auto& operand : operands_) { @@ -431,7 +279,6 @@ CustomCallThunk::BuildCallFrame( } } - // Collect results buffers. 
absl::InlinedVector results; results.reserve(results_.size()); for (auto& result : results_) { @@ -442,17 +289,11 @@ CustomCallThunk::BuildCallFrame( } } - // Borrow the FFI call frame from the object pool and update with the actual - // device memory addresses. - TF_ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate()); - TF_RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results)); + ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate()); + RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results)); return call_frame; } -// Builds call options object for the custom call. -// -// `stream` and `buffer_allocations may only be non-null for options passed to -// Prepare()_stage handler. InvokeContext CustomCallThunk::BuildInvokeContext( RunId run_id, se::Stream* absl_nullable stream, Thunk::ExecutionScopedState* absl_nullable execution_scoped_state, @@ -476,7 +317,6 @@ InvokeContext CustomCallThunk::BuildInvokeContext( &stream->parent()->GetDeviceDescription().gpu_compute_capability(); } - // Lookup per-execution state for prepare and init stages. 
ffi::ExecutionState* prepare_state = nullptr; ffi::ExecutionState* initialize_state = nullptr; @@ -522,7 +362,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( return absl::InternalError("buffer allocations and stream are required"); } - TF_ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations)); + ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations)); InvokeContext context = BuildInvokeContext( run_id, stream, execution_scoped_state, buffer_allocations, collective_params, collective_clique_requests, collective_memory_requests, @@ -545,7 +385,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( return absl::InternalError("buffer allocations and stream are required"); } - TF_ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations)); + ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations)); InvokeContext context = BuildInvokeContext( run_id, stream, execution_scoped_state, buffer_allocations, collective_params, collective_clique_requests, collective_memory_requests, @@ -554,74 +394,66 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( } absl::Status CustomCallThunk::Prepare(const PrepareParams& params) { - if (bundle_.has_value()) { - const RunId run_id = - params.collective_params ? 
params.collective_params->run_id : RunId{-1}; - - if (const auto* c_bundle = - std::get_if(&bundle_.value()); - c_bundle && c_bundle->prepare) { - return ExecuteFfiHandler( - run_id, c_bundle->prepare, XLA_FFI_ExecutionStage_PREPARE, - /*stream=*/nullptr, - /*execution_scoped_state=*/params.execution_scoped_state, - /*execution_context=*/nullptr, - /*buffer_allocations=*/params.buffer_allocations, - /*collective_params=*/params.collective_params, - /*collective_clique_requests=*/params.collective_clique_requests, - /*collective_memory_requests=*/params.collective_memory_requests, - /*collective_cliques=*/nullptr, - /*collective_memory=*/nullptr); - } - if (const auto* owned_bundle = - std::get_if(&bundle_.value()); - owned_bundle && owned_bundle->prepare) { - return ExecuteFfiHandler( - run_id, *owned_bundle->prepare, xla::ffi::ExecutionStage::kPrepare, - /*stream=*/nullptr, - /*execution_scoped_state=*/params.execution_scoped_state, - /*execution_context=*/nullptr, - /*buffer_allocations=*/params.buffer_allocations, - /*collective_params=*/params.collective_params, - /*collective_clique_requests=*/params.collective_clique_requests, - /*collective_memory_requests=*/params.collective_memory_requests, - /*collective_cliques=*/nullptr, - /*collective_memory=*/nullptr); - } + const RunId run_id = + params.collective_params ? 
params.collective_params->run_id : RunId{-1}; + + if (const auto* c_bundle = std::get_if(&bundle_); + c_bundle && c_bundle->prepare) { + return ExecuteFfiHandler( + run_id, c_bundle->prepare, XLA_FFI_ExecutionStage_PREPARE, + /*stream=*/nullptr, + /*execution_scoped_state=*/params.execution_scoped_state, + /*execution_context=*/nullptr, + /*buffer_allocations=*/params.buffer_allocations, + /*collective_params=*/params.collective_params, + /*collective_clique_requests=*/params.collective_clique_requests, + /*collective_memory_requests=*/params.collective_memory_requests, + /*collective_cliques=*/nullptr, + /*collective_memory=*/nullptr); + } + if (const auto* owned_bundle = std::get_if(&bundle_); + owned_bundle && owned_bundle->prepare) { + return ExecuteFfiHandler( + run_id, *owned_bundle->prepare, xla::ffi::ExecutionStage::kPrepare, + /*stream=*/nullptr, + /*execution_scoped_state=*/params.execution_scoped_state, + /*execution_context=*/nullptr, + /*buffer_allocations=*/params.buffer_allocations, + /*collective_params=*/params.collective_params, + /*collective_clique_requests=*/params.collective_clique_requests, + /*collective_memory_requests=*/params.collective_memory_requests, + /*collective_cliques=*/nullptr, + /*collective_memory=*/nullptr); } return absl::OkStatus(); } absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { - if (bundle_.has_value()) { - const RunId run_id = - params.collective_params ? 
params.collective_params->run_id : RunId{-1}; - - if (const auto* c_bundle = - std::get_if(&bundle_.value()); - c_bundle && c_bundle->initialize) { - return ExecuteFfiHandler( - run_id, *c_bundle->initialize, XLA_FFI_ExecutionStage_INITIALIZE, - params.stream, params.execution_scoped_state, - params.ffi_execution_context, params.buffer_allocations, - params.collective_params, - /*collective_clique_requests=*/nullptr, - /*collective_memory_requests=*/nullptr, params.collective_cliques, - params.collective_memory); - } - if (const auto* owned_bundle = - std::get_if(&bundle_.value()); - owned_bundle && owned_bundle->initialize) { - return ExecuteFfiHandler( - run_id, *owned_bundle->initialize, - xla::ffi::ExecutionStage::kInitialize, params.stream, - params.execution_scoped_state, params.ffi_execution_context, - params.buffer_allocations, params.collective_params, - /*collective_clique_requests=*/nullptr, - /*collective_memory_requests=*/nullptr, params.collective_cliques, - params.collective_memory); - } + const RunId run_id = + params.collective_params ? 
params.collective_params->run_id : RunId{-1}; + + if (const auto* c_bundle = std::get_if(&bundle_); + c_bundle && c_bundle->initialize) { + return ExecuteFfiHandler( + run_id, *c_bundle->initialize, XLA_FFI_ExecutionStage_INITIALIZE, + params.stream, params.execution_scoped_state, + params.ffi_execution_context, params.buffer_allocations, + params.collective_params, + /*collective_clique_requests=*/nullptr, + /*collective_memory_requests=*/nullptr, params.collective_cliques, + params.collective_memory); + } + if (const auto* owned_bundle = std::get_if(&bundle_); + owned_bundle && owned_bundle->initialize) { + return ExecuteFfiHandler( + run_id, *owned_bundle->initialize, + xla::ffi::ExecutionStage::kInitialize, params.stream, + params.execution_scoped_state, params.ffi_execution_context, + params.buffer_allocations, params.collective_params, + /*collective_clique_requests=*/nullptr, + /*collective_memory_requests=*/nullptr, params.collective_cliques, + params.collective_memory); } return absl::OkStatus(); } @@ -629,64 +461,53 @@ absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { se::Stream* stream = params.stream; - if (bundle_.has_value()) { - const RunId run_id = - params.collective_params ? 
params.collective_params->run_id : RunId{-1}; - if (const auto* c_bundle = - std::get_if(&bundle_.value()); - c_bundle) { - return ExecuteFfiHandler( - run_id, c_bundle->execute, XLA_FFI_ExecutionStage_EXECUTE, stream, - params.execution_scoped_state, params.ffi_execution_context, - params.buffer_allocations, params.collective_params, - /*collective_clique_requests=*/nullptr, - /*collective_memory_requests=*/nullptr, params.collective_cliques, - params.collective_memory); - } - if (const auto* owned_bundle = - std::get_if(&bundle_.value()); - owned_bundle) { - if (!owned_bundle->execute) { - return absl::InternalError("FFI execute handler is not set"); - } - return ExecuteFfiHandler( - run_id, *owned_bundle->execute, xla::ffi::ExecutionStage::kExecute, - stream, params.execution_scoped_state, params.ffi_execution_context, - params.buffer_allocations, params.collective_params, - /*collective_clique_requests=*/nullptr, - /*collective_memory_requests=*/nullptr, params.collective_cliques, - params.collective_memory); + const RunId run_id = + params.collective_params ? 
params.collective_params->run_id : RunId{-1}; + + if (const auto* c_bundle = std::get_if(&bundle_)) { + return ExecuteFfiHandler( + run_id, c_bundle->execute, XLA_FFI_ExecutionStage_EXECUTE, stream, + params.execution_scoped_state, params.ffi_execution_context, + params.buffer_allocations, params.collective_params, + /*collective_clique_requests=*/nullptr, + /*collective_memory_requests=*/nullptr, params.collective_cliques, + params.collective_memory); + } + if (const auto* owned_bundle = std::get_if(&bundle_)) { + if (!owned_bundle->execute) { + return absl::InternalError("FFI execute handler is not set"); } + return ExecuteFfiHandler( + run_id, *owned_bundle->execute, xla::ffi::ExecutionStage::kExecute, + stream, params.execution_scoped_state, params.ffi_execution_context, + params.buffer_allocations, params.collective_params, + /*collective_clique_requests=*/nullptr, + /*collective_memory_requests=*/nullptr, params.collective_cliques, + params.collective_memory); } - return ExecuteCustomCall(params); + return absl::InternalError("No FFI handler bundle set"); } absl::StatusOr CustomCallThunk::ToProto() const { - if (!api_version_.has_value()) { - return absl::FailedPreconditionError( - "CustomCallThunk was created from a non-registered target and cannot " - "be serialized to a proto"); - } - ThunkProto proto; *proto.mutable_thunk_info() = thunk_info().ToProto(); proto.mutable_custom_call_thunk()->set_target_name(target_name_); - proto.mutable_custom_call_thunk()->set_opaque(opaque_); - proto.mutable_custom_call_thunk()->set_api_version(api_version_.value()); + proto.mutable_custom_call_thunk()->set_api_version( + CustomCallApiVersion::API_VERSION_TYPED_FFI); if (called_computation_ != nullptr) { proto.mutable_custom_call_thunk()->set_called_computation( called_computation_->name()); } for (const NullableShapedSlice& operand : operands_) { - TF_ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_operands(), - operand.ToProto()); + 
ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_operands(), + operand.ToProto()); } for (const NullableShapedSlice& result : results_) { - TF_ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_results(), - result.ToProto()); + ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_results(), + result.ToProto()); } if (attributes_.has_value()) { @@ -695,7 +516,7 @@ absl::StatusOr CustomCallThunk::ToProto() const { } if (execution_state_ && execution_state_->IsSerializable()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( *proto.mutable_custom_call_thunk()->mutable_execution_state(), execution_state_->ToProto()); } @@ -716,31 +537,24 @@ absl::StatusOr> CustomCallThunk::FromProto( std::vector operands, results; for (const auto& operand_proto : proto.operands()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( NullableShapedSlice operand, NullableShapedSlice::FromProto(operand_proto, buffer_allocations)); operands.push_back(std::move(operand)); } for (const auto& result_proto : proto.results()) { - TF_ASSIGN_OR_RETURN( + ASSIGN_OR_RETURN( NullableShapedSlice result, NullableShapedSlice::FromProto(result_proto, buffer_allocations)); results.push_back(std::move(result)); } - if (proto.api_version() != CustomCallApiVersion::API_VERSION_TYPED_FFI) { - // Create a thunk that uses the legacy custom call registry. - return CustomCallThunk::Create( - std::move(thunk_info), proto.target_name(), std::move(operands), - std::move(results), proto.opaque(), proto.api_version(), platform_name); - } - - TF_ASSIGN_OR_RETURN(ffi::AttributesMap attributes, - ffi::AttributesMap::FromProto(proto.attributes())); + ASSIGN_OR_RETURN(ffi::AttributesMap attributes, + ffi::AttributesMap::FromProto(proto.attributes())); HloComputation* called_computation = nullptr; if (proto.has_called_computation()) { - CHECK(hlo_module != nullptr); // This check is needed for static analysis. 
+ CHECK(hlo_module != nullptr); called_computation = hlo_module->GetComputationWithName(proto.called_computation()); if (called_computation == nullptr) { @@ -758,7 +572,7 @@ absl::StatusOr> CustomCallThunk::FromProto( } else { LOG(WARNING) << "Failed to deserialize the custom call execution state. Falling " - "back to runtime instantiaton of the execution state. Reason: " + "back to runtime instantiation of the execution state. Reason: " << state.status(); } } @@ -770,5 +584,4 @@ absl::StatusOr> CustomCallThunk::FromProto( std::move(cpu_target_machine_options)); } -} // namespace gpu -} // namespace xla +} // namespace xla::gpu diff --git a/xla/backends/gpu/runtime/custom_call_thunk.h b/xla/backends/gpu/runtime/custom_call_thunk.h index 810d2c002777d..cd077bb5fb617 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk.h +++ b/xla/backends/gpu/runtime/custom_call_thunk.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef XLA_BACKENDS_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ #define XLA_BACKENDS_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ -#include -#include #include #include #include @@ -33,6 +31,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/collective_cliques.h" #include "xla/backends/gpu/runtime/collective_memory.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/executable_run_options.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" @@ -45,20 +44,20 @@ limitations under the License. #include "xla/runtime/buffer_use.h" #include "xla/runtime/object_pool.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/custom_call_status.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/shaped_slice.h" #include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream.h" -namespace xla { -namespace gpu { +namespace xla::gpu { -// Thunk to run a GPU custom call. 
+// Thunk to run an XLA FFI custom call on a GPU. // -// This thunk's `ExecuteOnStream` implementation executes a host function -// `call_target` which is expected to enqueue operations onto the GPU. +// This thunk handles custom calls registered via the XLA FFI mechanism, which +// provides a type-safe API for registering external functions with XLA runtime. +// +// For legacy (non-FFI) custom calls, see LegacyCustomCallThunk. // // Note that not all kCustomCall HLOs in XLA:GPU end up being run by this thunk. // XLA itself creates kCustomCall instructions when lowering kConvolution HLOs @@ -85,26 +84,6 @@ class CustomCallThunk : public Thunk { ffi::ExecutionState init; }; - using CustomCallTarget = - std::function; - - // Creates a serializable custom call thunk. The callback is resolved using - // the legacy CustomCall registry. For new code please use XLA FFI instead. - static absl::StatusOr> Create( - ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, std::string opaque, - CustomCallApiVersion api_version, absl::string_view platform_name); - - // Creates a custom call thunk from the given legacy custom call target. - // Note that a thunk created this way can't be serialized to a proto. - // This function is only permitted for unit testing code. - static absl::StatusOr> Create( - ThunkInfo thunk_info, std::string target_name, - CustomCallTarget call_target, std::vector operands, - std::vector results, std::string opaque); - // Creates a serializable custom call thunk. The callback is resolved using // XLA FFI. 
static absl::StatusOr> Create( @@ -150,14 +129,10 @@ class CustomCallThunk : public Thunk { absl::Status ExecuteOnStream(const ExecuteParams& params) override; const std::string& target_name() const { return target_name_; } - CustomCallTarget call_target() const { return call_target_; } std::optional bundle() const { - if (!bundle_.has_value()) { - return std::nullopt; - } const XLA_FFI_Handler_Bundle* c_bundle = - std::get_if(&bundle_.value()); + std::get_if(&bundle_); return c_bundle ? std::make_optional(*c_bundle) : std::nullopt; } @@ -172,8 +147,6 @@ class CustomCallThunk : public Thunk { const std::vector& operands() const { return operands_; } const std::vector& results() const { return results_; } - absl::string_view opaque() const { return opaque_; } - BufferUses buffer_uses() const override { BufferUses res; res.reserve(operands_.size() + results_.size()); @@ -197,12 +170,6 @@ class CustomCallThunk : public Thunk { std::optional cpu_target_machine_options); private: - CustomCallThunk(ThunkInfo thunk_info, std::string target_name, - std::vector operands, - std::vector results, std::string opaque, - CustomCallTarget call_target, - const std::optional& api_version); - CustomCallThunk( ThunkInfo thunk_info, std::string target_name, std::variant bundle, @@ -213,8 +180,6 @@ class CustomCallThunk : public Thunk { const HloComputation* called_computation, std::optional cpu_target_machine_options); - absl::Status ExecuteCustomCall(const ExecuteParams& params); - absl::StatusOr::BorrowedObject> BuildCallFrame(const BufferAllocations* absl_nullable buffer_allocations); @@ -251,26 +216,15 @@ class CustomCallThunk : public Thunk { const CollectiveCliques* absl_nullable collective_cliques, const CollectiveMemory* absl_nullable collective_memory); - // API version of the custom call. If not set, it means the custom call thunk - // was initialized from a non-registered function pointer and can't be - // serialized to a proto. 
- std::optional api_version_; std::string target_name_; // Nulled shape slices represent null pointer arguments to the thunk. std::vector operands_; std::vector results_; - // This is a legacy custom call API that is discouraged, and will be - // deprecated once XLA:FFI mechanism is ready. - CustomCallTarget call_target_; - std::string opaque_; - - // XLA FFI provides a right type safe mechanism for registering external - // functions with XLA runtime. It's under construction, and still misses - // a lot of features. Long term it will replace legacy custom calls. - std::optional> - bundle_; + // XLA FFI handler bundle: either a C API bundle (from the global FFI + // registry) or an owned bundle (from xla::ffi::Bind()). + std::variant bundle_; std::optional attributes_; // Reference call frame pre-initialized at construction time. @@ -296,7 +250,6 @@ class CustomCallThunk : public Thunk { std::optional cpu_target_machine_options_; }; -} // namespace gpu -} // namespace xla +} // namespace xla::gpu #endif // XLA_BACKENDS_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ diff --git a/xla/backends/gpu/runtime/custom_call_thunk_test.cc b/xla/backends/gpu/runtime/custom_call_thunk_test.cc index 4a37af9e61a85..14d12660bb09e 100644 --- a/xla/backends/gpu/runtime/custom_call_thunk_test.cc +++ b/xla/backends/gpu/runtime/custom_call_thunk_test.cc @@ -35,8 +35,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/backends/cpu/target_machine_options.h" #include "xla/backends/gpu/ffi.h" +#include "xla/backends/gpu/runtime/collective_clique_requests.h" #include "xla/backends/gpu/runtime/collective_memory_requests.h" +#include "xla/backends/gpu/runtime/collective_params.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" #include "xla/executable_run_options.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/execution_state.h" @@ -47,9 +50,8 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_module.h" #include "xla/runtime/device_id.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/custom_call_status.h" -#include "xla/service/custom_call_target_registry.h" #include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" @@ -63,6 +65,8 @@ limitations under the License. #include "xla/stream_executor/stream_executor_address_allocator.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/proto/parse_text_proto.h" +#include "xla/xla_data.pb.h" +#include "xla/tsl/platform/status_macros.h" namespace xla::gpu { struct TestState { @@ -132,38 +136,11 @@ using absl_testing::StatusIs; using ::testing::HasSubstr; static absl::StatusOr GpuExecutor() { - TF_ASSIGN_OR_RETURN(auto name, PlatformUtil::CanonicalPlatformName("gpu")); - TF_ASSIGN_OR_RETURN(auto* platform, - se::PlatformManager::PlatformWithName(name)); + ASSIGN_OR_RETURN(auto name, PlatformUtil::CanonicalPlatformName("gpu")); + ASSIGN_OR_RETURN(auto* platform, se::PlatformManager::PlatformWithName(name)); return platform->ExecutorForDevice(0); } -TEST(CustomCallThunkTest, SimpleCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); - - bool was_called = false; - - CustomCallThunk::CustomCallTarget target = - [&](se::Stream* stream_in_callback, void** args, const char* target_name, - size_t num_args, XlaCustomCallStatus* status) { - was_called = true; - EXPECT_THAT(stream_in_callback, ::testing::Eq(stream.get())); - }; - - TF_ASSERT_OK_AND_ASSIGN( - auto thunk, CustomCallThunk::Create(Thunk::ThunkInfo(), "target_name", - target, {}, {}, "")); - stream_executor::StreamExecutorAddressAllocator allocator(executor); - Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( - 
ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator), - stream.get(), stream.get(), nullptr, nullptr, nullptr); - EXPECT_THAT(thunk->ExecuteOnStream(Thunk::ExecuteParams(params)), - absl_testing::IsOk()); - EXPECT_TRUE(was_called); -} - // A simple callback function that always returns an error. absl::Status ReturnError() { return absl::UnknownError("Custom call was executed!"); @@ -181,11 +158,11 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), kReturnErrorCustomCallName, "ROCM", kReturnError); TEST(CustomCallThunkTest, ResolvesFFICustomCall) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -210,52 +187,10 @@ TEST(CustomCallThunkTest, ResolvesFFICustomCall) { HasSubstr("Custom call was executed!"))); } -// A simple callback function that always returns an error and has the function -// signature for a legacy custom call. 
-void Callback_WithStatusFailed(void* /*stream*/, void** /*buffers*/, - const char* /*opaque*/, size_t /*opaque_len*/, - XlaCustomCallStatus* status) { - constexpr absl::string_view kErrorMessage = - "Legacy Custom call was executed!"; - XlaCustomCallStatusSetFailure(status, kErrorMessage.data(), - kErrorMessage.size()); -} - -XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, "CUDA"); -XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, "ROCM"); - -TEST(CustomCallThunkTest, ResolvesLegacyCustomCall) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr thunk, - CustomCallThunk::Create( - Thunk::ThunkInfo(), - /*target_name=*/"Callback_WithStatusFailed", - /*operands=*/{}, - /*results=*/{}, /*opaque=*/"", - CustomCallApiVersion::API_VERSION_STATUS_RETURNING, - /*platform_name=*/executor->GetPlatform()->Name())); - - stream_executor::StreamExecutorAddressAllocator allocator(executor); - BufferAllocations empty_unused_allocations({}, 0, &allocator); - Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( - ServiceExecutableRunOptions(), empty_unused_allocations, - /*stream=*/stream.get(), - /*command_buffer_trace_stream=*/stream.get(), - /*collective_params=*/nullptr, - /*collective_cliques=*/nullptr, /*collective_memory=*/nullptr); - EXPECT_THAT(thunk->ExecuteOnStream(params), - StatusIs(absl::StatusCode::kInternal, - HasSubstr("Legacy Custom call was executed!"))); -} - TEST(CustomCallThunkTest, CustomCallWithOwnedHandlers) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); int instantiate_calls = 0; int prepare_calls = 0; int initialize_calls = 0; @@ 
-304,7 +239,7 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlers) { ServiceExecutableRunOptions(), buffer_allocations, stream.get(), stream.get(), nullptr, nullptr, nullptr); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), "target_name", std::move(bundle), @@ -336,9 +271,9 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlers) { } TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutOptionalOnes) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); int execute_calls = 0; CustomCallThunk::OwnedHandlerBundle bundle; bundle.execute = ffi::Ffi::Bind().To([&]() { @@ -369,7 +304,7 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutOptionalOnes) { stream.get(), nullptr, nullptr, nullptr); // Optional handlers are null and shouldn't be invoked. 
- TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), "target_name", std::move(bundle), @@ -383,9 +318,9 @@ TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutOptionalOnes) { } TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutExecute) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); CustomCallThunk::OwnedHandlerBundle bundle; // all handlers null stream_executor::StreamExecutorAddressAllocator allocator(executor); Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create( @@ -436,9 +371,9 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), kTestPlatformName, kVerifyCallbackArguments); TEST(CustomCallThunkTest, ProtoConversion) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); HloModuleConfig config; HloModule hlo_module("test_module", config); @@ -460,7 +395,7 @@ TEST(CustomCallThunkTest, ProtoConversion) { std::make_unique(TestState{"some state"})), IsOk()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr original_thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -470,14 +405,14 @@ TEST(CustomCallThunkTest, ProtoConversion) { hlo_module.entry_computation(), /*platform_name=*/kTestPlatformName, /*gpu_compute_capability=*/{}, std::move(execution_state))); - TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); + ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); ASSERT_TRUE(proto.has_custom_call_thunk()); 
ASSERT_TRUE(proto.custom_call_thunk().has_execution_state()); original_thunk.reset(); std::array allocations = {alloc0, alloc1}; - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr new_thunk, CustomCallThunk::FromProto(Thunk::ThunkInfo(), proto.custom_call_thunk(), allocations, &hlo_module, kTestPlatformName, @@ -549,7 +484,7 @@ TEST(CustomCallThunkTest, DeserializationFailsGracefully) { 0, ShapeUtil::MakeShape(U32, {42}), "parameter")); hlo_module.AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::FromProto(Thunk::ThunkInfo(), proto, /*buffer_allocations=*/{}, &hlo_module, @@ -562,9 +497,9 @@ TEST(CustomCallThunkTest, DeserializationFailsGracefully) { } TEST(CustomCallThunkTest, RoundtripWithNonSerializableExecutionState) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); HloModuleConfig config; HloModule hlo_module("test_module", config); @@ -579,7 +514,7 @@ TEST(CustomCallThunkTest, RoundtripWithNonSerializableExecutionState) { IsOk()); EXPECT_FALSE(execution_state->IsSerializable()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr original_thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -589,13 +524,13 @@ TEST(CustomCallThunkTest, RoundtripWithNonSerializableExecutionState) { /*platform_name=*/kTestPlatformName, /*gpu_compute_capability=*/{}, std::move(execution_state))); - TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); + ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); ASSERT_TRUE(proto.has_custom_call_thunk()); EXPECT_FALSE(proto.custom_call_thunk().has_execution_state()); original_thunk.reset(); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( 
std::unique_ptr new_thunk, CustomCallThunk::FromProto( Thunk::ThunkInfo(), proto.custom_call_thunk(), @@ -620,7 +555,7 @@ TEST(CustomCallThunkTest, SerializationFails) { FailingSerializableTestState{42}))); EXPECT_TRUE(execution_state->IsSerializable()); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -658,68 +593,6 @@ TEST(CustomCallThunkTest, ParseFFIProtoWithNonUtf8Attribute) { EXPECT_TRUE(reconstructed_proto.ParseFromString(serialized_to_wire_format)); } -TEST(CustomCallThunkTest, ParseLegacyProtoWithNonUtf8Opaque) { - // This test ensures that legacy custom calls can contain non-UTF-8 opaque - // data, and these will be correctly parsed (and not fail). - - CustomCallThunkProto proto = - tsl::proto_testing::ParseTextProtoOrDie( - R"pb( - target_name: "Callback_WithStatusFailed" - api_version: API_VERSION_STATUS_RETURNING - opaque: "\xfe" - )pb"); - - std::string serialized_to_wire_format; - proto.SerializeToString(&serialized_to_wire_format); - - CustomCallThunkProto reconstructed_proto; - EXPECT_TRUE(reconstructed_proto.ParseFromString(serialized_to_wire_format)); -} - -TEST(CustomCallThunkTest, LegacyCustomCallRoundTrip) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr original_thunk, - CustomCallThunk::Create( - Thunk::ThunkInfo(), - /*target_name=*/"Callback_WithStatusFailed", - /*operands=*/{}, - /*results=*/{}, /*opaque=*/"opaque", - CustomCallApiVersion::API_VERSION_STATUS_RETURNING, - /*platform_name=*/executor->GetPlatform()->Name())); - - TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); - original_thunk.reset(); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr new_thunk, - CustomCallThunk::FromProto( - Thunk::ThunkInfo(), proto.custom_call_thunk(), - /*buffer_allocations=*/{}, - /*hlo_module=*/nullptr, 
executor->GetPlatform()->Name(), - executor->GetDeviceDescription().gpu_compute_capability(), - /*cpu_target_machine_options=*/std::nullopt)); - - stream_executor::StreamExecutorAddressAllocator allocator(executor); - BufferAllocations empty_unused_allocations({}, 0, &allocator); - Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( - ServiceExecutableRunOptions(), empty_unused_allocations, - /*stream=*/stream.get(), - /*command_buffer_trace_stream=*/stream.get(), - /*collective_params=*/nullptr, /*collective_cliques=*/nullptr, - /*collective_memory=*/nullptr); - - // We check that the new thunk behaves like the original one (returning - // internal error with specific message). - EXPECT_THAT(new_thunk->ExecuteOnStream(params), - StatusIs(absl::StatusCode::kInternal, - HasSubstr("Legacy Custom call was executed!"))); -} - static bool passes_cpu_target_machine_options_instantiate_called = false; absl::Status VerifyCpuTargetMachineOptionsInstantiate( @@ -759,14 +632,14 @@ XLA_FFI_REGISTER_HANDLER( static_cast(ffi::Traits::kCmdBufferCompatible)); TEST(CustomCallThunkTest, PassesCpuTargetMachineOptionsToInstantiate) { - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, - executor->CreateStream()); + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); xla::cpu::TargetMachineOptions options("test-triple", "test-cpu", ""); passes_cpu_target_machine_options_instantiate_called = false; - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr thunk, CustomCallThunk::Create( Thunk::ThunkInfo(), @@ -784,10 +657,10 @@ TEST(CustomCallThunkTest, PassesCpuTargetMachineOptionsToInstantiate) { // Also check that FromProto restores the CPU target machine options. // We clear the execution state from the proto to force a re-instantiation. 
passes_cpu_target_machine_options_instantiate_called = false; - TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk->ToProto()); + ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk->ToProto()); proto.mutable_custom_call_thunk()->clear_execution_state(); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN( std::unique_ptr new_thunk, CustomCallThunk::FromProto( Thunk::ThunkInfo(), proto.custom_call_thunk(), diff --git a/xla/backends/gpu/runtime/legacy_custom_call_thunk.cc b/xla/backends/gpu/runtime/legacy_custom_call_thunk.cc new file mode 100644 index 0000000000000..3cc38c4a09e65 --- /dev/null +++ b/xla/backends/gpu/runtime/legacy_custom_call_thunk.cc @@ -0,0 +1,224 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/backends/gpu/runtime/custom_call_target.h" +#include "xla/backends/gpu/runtime/thunk.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_status_internal.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/shaped_slice.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" +#include "tsl/platform/platform.h" +#include "xla/tsl/platform/status_macros.h" + +namespace xla::gpu { + +static absl::StatusOr +ResolveLegacyCustomCall(const CustomCallTargetRegistry& registry, + absl::string_view target_name, + absl::string_view platform_name, + CustomCallApiVersion api_version) { + void* call_target = + registry.Lookup(std::string(target_name), std::string(platform_name)); + + if (call_target == nullptr) { + return NotFound( + "No registered implementation for custom call to %s for platform %s", + target_name, platform_name); + } + + switch (api_version) { + case CustomCallApiVersion::API_VERSION_ORIGINAL: { + constexpr absl::string_view kErrorMessage = + "Custom call API version `API_VERSION_ORIGINAL` is not supported by " + "XLA:GPU. Prefer https://docs.jax.dev/en/latest/ffi.html. 
It will be " + "fully removed in November 2025."; + if constexpr (tsl::kIsOpenSource) { + LOG(ERROR) << kErrorMessage; + } else { + LOG(FATAL) << kErrorMessage; + } + + return [call_target](stream_executor::Stream* stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus*) { + reinterpret_cast(call_target)( + stream->platform_specific_handle().stream, buffers, opaque, + opaque_len); + }; + break; + } + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING: + case CustomCallApiVersion::API_VERSION_STATUS_RETURNING_UNIFIED: + return [call_target](stream_executor::Stream* stream, void** buffers, + const char* opaque, size_t opaque_len, + XlaCustomCallStatus* status) { + reinterpret_cast( + call_target)(stream->platform_specific_handle().stream, buffers, + opaque, opaque_len, status); + }; + break; + case CustomCallApiVersion::API_VERSION_TYPED_FFI: + return absl::InvalidArgumentError( + "Called ResolveLegacyCustomCall with API_VERSION_TYPED_FFI"); + default: + return Internal("Unknown custom-call API version enum value: %d", + api_version); + } +} + +absl::StatusOr> +LegacyCustomCallThunk::Create(ThunkInfo thunk_info, std::string target_name, + CustomCallTarget call_target, + std::vector operands, + std::vector results, + std::string opaque) { + return absl::WrapUnique(new LegacyCustomCallThunk( + thunk_info, std::move(target_name), std::move(operands), + std::move(results), std::move(opaque), std::move(call_target), + /*api_version=*/std::nullopt)); +} + +absl::StatusOr> +LegacyCustomCallThunk::Create(ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, + std::string opaque, + CustomCallApiVersion api_version, + absl::string_view platform_name) { + ASSIGN_OR_RETURN( + CustomCallTarget call_target, + ResolveLegacyCustomCall(*CustomCallTargetRegistry::Global(), target_name, + platform_name, api_version)); + + return absl::WrapUnique(new LegacyCustomCallThunk( + thunk_info, std::move(target_name), 
std::move(operands), + std::move(results), std::move(opaque), call_target, api_version)); +} + +LegacyCustomCallThunk::LegacyCustomCallThunk( + ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, std::string opaque, + CustomCallTarget call_target, + const std::optional& api_version) + : Thunk(Thunk::kCustomCall, thunk_info), + api_version_(api_version), + target_name_(std::move(target_name)), + operands_(std::move(operands)), + results_(std::move(results)), + call_target_(std::move(call_target)), + opaque_(std::move(opaque)) {} + +absl::Status LegacyCustomCallThunk::ExecuteOnStream( + const ExecuteParams& params) { + std::vector buffers; + buffers.reserve(operands_.size() + results_.size()); + for (auto& slices : {operands_, results_}) { + for (const std::optional& slice : slices) { + if (!slice.has_value()) { + buffers.push_back(nullptr); + continue; + } + + if (!slice->slice.allocation()) { + return Internal("custom call input missing buffer allocation"); + } + + buffers.push_back( + params.buffer_allocations->GetDeviceAddress(slice->slice).opaque()); + } + } + + XlaCustomCallStatus custom_call_status; + call_target_(params.stream, buffers.data(), opaque_.data(), opaque_.size(), + &custom_call_status); + auto message = CustomCallStatusGetMessage(&custom_call_status); + if (message) { + return Internal("CustomCall failed: %s", *message); + } + return absl::OkStatus(); +} + +absl::StatusOr LegacyCustomCallThunk::ToProto() const { + if (!api_version_.has_value()) { + return absl::FailedPreconditionError( + "LegacyCustomCallThunk was created from a non-registered target and " + "cannot be serialized to a proto"); + } + + ThunkProto proto; + *proto.mutable_thunk_info() = thunk_info().ToProto(); + proto.mutable_custom_call_thunk()->set_target_name(target_name_); + proto.mutable_custom_call_thunk()->set_opaque(opaque_); + proto.mutable_custom_call_thunk()->set_api_version(api_version_.value()); + + for (const 
NullableShapedSlice& operand : operands_) { + ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_operands(), + operand.ToProto()); + } + + for (const NullableShapedSlice& result : results_) { + ASSIGN_OR_RETURN(*proto.mutable_custom_call_thunk()->add_results(), + result.ToProto()); + } + + return proto; +} + +absl::StatusOr> +LegacyCustomCallThunk::FromProto( + ThunkInfo thunk_info, const CustomCallThunkProto& proto, + absl::Span buffer_allocations, + absl::string_view platform_name) { + std::vector operands, results; + for (const auto& operand_proto : proto.operands()) { + ASSIGN_OR_RETURN( + NullableShapedSlice operand, + NullableShapedSlice::FromProto(operand_proto, buffer_allocations)); + operands.push_back(std::move(operand)); + } + for (const auto& result_proto : proto.results()) { + ASSIGN_OR_RETURN( + NullableShapedSlice result, + NullableShapedSlice::FromProto(result_proto, buffer_allocations)); + results.push_back(std::move(result)); + } + + return LegacyCustomCallThunk::Create( + std::move(thunk_info), proto.target_name(), std::move(operands), + std::move(results), proto.opaque(), proto.api_version(), platform_name); +} + +} // namespace xla::gpu diff --git a/xla/backends/gpu/runtime/legacy_custom_call_thunk.h b/xla/backends/gpu/runtime/legacy_custom_call_thunk.h new file mode 100644 index 0000000000000..eecdc65cc8510 --- /dev/null +++ b/xla/backends/gpu/runtime/legacy_custom_call_thunk.h @@ -0,0 +1,120 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_GPU_RUNTIME_LEGACY_CUSTOM_CALL_THUNK_H_ +#define XLA_BACKENDS_GPU_RUNTIME_LEGACY_CUSTOM_CALL_THUNK_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/shaped_slice.h" +#include "xla/stream_executor/stream.h" + +namespace xla::gpu { + +// Thunk to run a legacy (non-FFI) GPU custom call. +// +// This thunk is DEPRECATED. All new custom calls should use the XLA FFI +// mechanism via CustomCallThunk instead. This class exists only to support +// legacy custom calls that have not yet migrated to FFI. +// +// For the FFI-based custom call thunk, see custom_call_thunk.h. +class LegacyCustomCallThunk : public Thunk { + public: + using CustomCallTarget = + std::function; + + // Creates a serializable legacy custom call thunk. The callback is resolved + // using the legacy CustomCallTargetRegistry. + static absl::StatusOr> Create( + ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, std::string opaque, + CustomCallApiVersion api_version, absl::string_view platform_name); + + // Creates a legacy custom call thunk from a given call target. A thunk + // created this way cannot be serialized to a proto. This overload is only + // permitted for unit testing code. 
+ static absl::StatusOr> Create( + ThunkInfo thunk_info, std::string target_name, + CustomCallTarget call_target, std::vector operands, + std::vector results, std::string opaque); + + absl::Status ExecuteOnStream(const ExecuteParams& params) override; + + const std::string& target_name() const { return target_name_; } + CustomCallTarget call_target() const { return call_target_; } + + const std::vector& operands() const { return operands_; } + const std::vector& results() const { return results_; } + + absl::string_view opaque() const { return opaque_; } + + BufferUses buffer_uses() const override { + BufferUses res; + res.reserve(operands_.size() + results_.size()); + for (const NullableShapedSlice& shaped_slice : operands_) { + if (!shaped_slice.has_value()) { + continue; + } + res.push_back(BufferUse::Read(shaped_slice->slice, shaped_slice->shape)); + } + // Results are written by the custom call; without recording them as + // writes, downstream scheduling may incorrectly overlap or reorder + // users of these buffers. + for (const NullableShapedSlice& shaped_slice : results_) { + if (!shaped_slice.has_value()) { + continue; + } + res.push_back(BufferUse::Write(shaped_slice->slice, shaped_slice->shape)); + } + return res; + } + + absl::StatusOr ToProto() const override; + + static absl::StatusOr> FromProto( + ThunkInfo thunk_info, const CustomCallThunkProto& proto, + absl::Span buffer_allocations, + absl::string_view platform_name); + + private: + LegacyCustomCallThunk(ThunkInfo thunk_info, std::string target_name, + std::vector operands, + std::vector results, + std::string opaque, CustomCallTarget call_target, + const std::optional& api_version); + + // API version of the custom call. If not set, the thunk was created from a + // non-registered function pointer and cannot be serialized.
+ std::optional api_version_; + std::string target_name_; + + std::vector operands_; + std::vector results_; + + CustomCallTarget call_target_; + std::string opaque_; +}; + +} // namespace xla::gpu + +#endif // XLA_BACKENDS_GPU_RUNTIME_LEGACY_CUSTOM_CALL_THUNK_H_ diff --git a/xla/backends/gpu/runtime/legacy_custom_call_thunk_test.cc b/xla/backends/gpu/runtime/legacy_custom_call_thunk_test.cc new file mode 100644 index 0000000000000..72abb72b1035d --- /dev/null +++ b/xla/backends/gpu/runtime/legacy_custom_call_thunk_test.cc @@ -0,0 +1,181 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" + +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/backends/gpu/runtime/thunk.h" +#include "xla/backends/gpu/runtime/thunk.pb.h" +#include "xla/service/custom_call_status.h" +#include "xla/service/custom_call_target_registry.h" +#include "xla/service/gpu/buffer_allocations.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/platform_util.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor_address_allocator.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/util/proto/parse_text_proto.h" +#include "xla/tsl/platform/status_macros.h" + +namespace xla::gpu { +namespace { +using absl_testing::IsOk; +using absl_testing::StatusIs; +using ::testing::HasSubstr; + +static absl::StatusOr GpuExecutor() { + ASSIGN_OR_RETURN(auto name, PlatformUtil::CanonicalPlatformName("gpu")); + ASSIGN_OR_RETURN(auto* platform, se::PlatformManager::PlatformWithName(name)); + return platform->ExecutorForDevice(0); +} + +TEST(LegacyCustomCallThunkTest, SimpleCustomCall) { + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + + bool was_called = false; + + LegacyCustomCallThunk::CustomCallTarget target = + [&](se::Stream* stream_in_callback, void** args, const char* target_name, + size_t num_args, XlaCustomCallStatus* status) { + was_called = true; + EXPECT_THAT(stream_in_callback, ::testing::Eq(stream.get())); + }; + + ASSERT_OK_AND_ASSIGN( + auto thunk, LegacyCustomCallThunk::Create( + 
Thunk::ThunkInfo(), "target_name", target, {}, {}, "")); + stream_executor::StreamExecutorAddressAllocator allocator(executor); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator), + stream.get(), stream.get(), nullptr, nullptr, nullptr); + EXPECT_THAT(thunk->ExecuteOnStream(Thunk::ExecuteParams(params)), + absl_testing::IsOk()); + EXPECT_TRUE(was_called); +} + +// A simple callback function that always returns an error and has the function +// signature for a legacy custom call. +void Callback_WithStatusFailed(void* /*stream*/, void** /*buffers*/, + const char* /*opaque*/, size_t /*opaque_len*/, + XlaCustomCallStatus* status) { + constexpr absl::string_view kErrorMessage = + "Legacy Custom call was executed!"; + XlaCustomCallStatusSetFailure(status, kErrorMessage.data(), + kErrorMessage.size()); +} + +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, "ROCM"); + +TEST(LegacyCustomCallThunkTest, ResolvesLegacyCustomCall) { + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr thunk, + LegacyCustomCallThunk::Create( + Thunk::ThunkInfo(), + /*target_name=*/"Callback_WithStatusFailed", + /*operands=*/{}, + /*results=*/{}, /*opaque=*/"", + CustomCallApiVersion::API_VERSION_STATUS_RETURNING, + /*platform_name=*/executor->GetPlatform()->Name())); + + stream_executor::StreamExecutorAddressAllocator allocator(executor); + BufferAllocations empty_unused_allocations({}, 0, &allocator); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + ServiceExecutableRunOptions(), empty_unused_allocations, + /*stream=*/stream.get(), + /*command_buffer_trace_stream=*/stream.get(), + /*collective_params=*/nullptr, + /*collective_cliques=*/nullptr, /*collective_memory=*/nullptr); + 
EXPECT_THAT(thunk->ExecuteOnStream(params), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Legacy Custom call was executed!"))); +} + +TEST(LegacyCustomCallThunkTest, ParseLegacyProtoWithNonUtf8Opaque) { + // This test ensures that legacy custom calls can contain non-UTF-8 opaque + // data, and these will be correctly parsed (and not fail). + + CustomCallThunkProto proto = + tsl::proto_testing::ParseTextProtoOrDie( + R"pb( + target_name: "Callback_WithStatusFailed" + api_version: API_VERSION_STATUS_RETURNING + opaque: "\xfe" + )pb"); + + std::string serialized_to_wire_format; + proto.SerializeToString(&serialized_to_wire_format); + + CustomCallThunkProto reconstructed_proto; + EXPECT_TRUE(reconstructed_proto.ParseFromString(serialized_to_wire_format)); +} + +TEST(LegacyCustomCallThunkTest, LegacyCustomCallRoundTrip) { + ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr original_thunk, + LegacyCustomCallThunk::Create( + Thunk::ThunkInfo(), + /*target_name=*/"Callback_WithStatusFailed", + /*operands=*/{}, + /*results=*/{}, /*opaque=*/"opaque", + CustomCallApiVersion::API_VERSION_STATUS_RETURNING, + /*platform_name=*/executor->GetPlatform()->Name())); + + ASSERT_OK_AND_ASSIGN(ThunkProto proto, original_thunk->ToProto()); + original_thunk.reset(); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr new_thunk, + LegacyCustomCallThunk::FromProto( + Thunk::ThunkInfo(), proto.custom_call_thunk(), + /*buffer_allocations=*/{}, executor->GetPlatform()->Name())); + + stream_executor::StreamExecutorAddressAllocator allocator(executor); + BufferAllocations empty_unused_allocations({}, 0, &allocator); + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + ServiceExecutableRunOptions(), empty_unused_allocations, + /*stream=*/stream.get(), + /*command_buffer_trace_stream=*/stream.get(), + /*collective_params=*/nullptr, 
/*collective_cliques=*/nullptr, + /*collective_memory=*/nullptr); + + // We check that the new thunk behaves like the original one (returning + // internal error with specific message). + EXPECT_THAT(new_thunk->ExecuteOnStream(params), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Legacy Custom call was executed!"))); +} + +} // namespace +} // namespace xla::gpu diff --git a/xla/backends/gpu/runtime/thunk_proto_deserialization.cc b/xla/backends/gpu/runtime/thunk_proto_deserialization.cc index 6731666119bc5..5cf90b6e0cf46 100644 --- a/xla/backends/gpu/runtime/thunk_proto_deserialization.cc +++ b/xla/backends/gpu/runtime/thunk_proto_deserialization.cc @@ -56,6 +56,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/host_to_device_copy_thunk.h" #include "xla/backends/gpu/runtime/infeed_thunk.h" #include "xla/backends/gpu/runtime/kernel_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/memset_thunk.h" #include "xla/backends/gpu/runtime/norm_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.h" @@ -75,6 +76,7 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/while_thunk.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" +#include "xla/service/hlo.pb.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/platform/statusor.h" @@ -225,11 +227,17 @@ absl::StatusOr> DeserializeThunkProtoImpl( thunk_proto.dynamic_slice_thunk(), buffer_allocations, deserializer); } - case ThunkProto::kCustomCallThunk: + case ThunkProto::kCustomCallThunk: { + const auto& cc_proto = thunk_proto.custom_call_thunk(); + if (cc_proto.api_version() != + CustomCallApiVersion::API_VERSION_TYPED_FFI) { + return LegacyCustomCallThunk::FromProto( + std::move(thunk_info), cc_proto, buffer_allocations, platform_name); + } return CustomCallThunk::FromProto( - std::move(thunk_info), thunk_proto.custom_call_thunk(), - buffer_allocations, hlo_module, platform_name, gpu_compute_capability, - cpu_target_machine_options); + std::move(thunk_info), cc_proto, buffer_allocations, hlo_module, + platform_name, gpu_compute_capability, cpu_target_machine_options); + } case ThunkProto::kHostExecuteStartThunk: return HostExecuteStartThunk::FromProto( std::move(thunk_info), thunk_proto.host_execute_start_thunk(), diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index d4b2fb6991732..02099a5ef7915 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -530,6 +530,7 @@ cc_library( "//xla/backends/gpu/runtime:host_to_device_copy_thunk", "//xla/backends/gpu/runtime:infeed_thunk", "//xla/backends/gpu/runtime:kernel_thunk", + "//xla/backends/gpu/runtime:legacy_custom_call_thunk", "//xla/backends/gpu/runtime:norm_thunk", "//xla/backends/gpu/runtime:nvshmem_all_reduce_thunk", "//xla/backends/gpu/runtime:nvshmem_collective_permute_thunk", diff --git a/xla/service/gpu/thunk_emitter.cc b/xla/service/gpu/thunk_emitter.cc index fdbb44f735996..e6560b9e1d53c 100644 --- a/xla/service/gpu/thunk_emitter.cc +++ 
b/xla/service/gpu/thunk_emitter.cc @@ -94,6 +94,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/host_to_device_copy_thunk.h" #include "xla/backends/gpu/runtime/infeed_thunk.h" #include "xla/backends/gpu/runtime/kernel_thunk.h" +#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h" #include "xla/backends/gpu/runtime/norm_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_all_reduce_thunk.h" #include "xla/backends/gpu/runtime/nvshmem_collective_permute_thunk.h" @@ -973,7 +974,7 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( << "Fall back to parse the raw backend config str."; } - auto ffi_thunk = [&]() -> absl::StatusOr> { + auto ffi_thunk = [&]() -> absl::StatusOr> { auto& called_computations = instr->called_computations(); auto& backend_config_str = backend_config.ok() @@ -1004,13 +1005,12 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( ir_emitter_context_->cpu_target_machine_options()); }; - auto legacy_thunk = - [&]() -> absl::StatusOr> { + auto legacy_thunk = [&]() -> absl::StatusOr> { std::string opaque = backend_config.ok() ? backend_config->custom_call_backend_config().opaque() : instr->raw_backend_config_string(); - return CustomCallThunk::Create( + return LegacyCustomCallThunk::Create( Thunk::ThunkInfo::WithProfileAnnotation( instr, ir_emitter_context_->GetNextThunkId()), call_target_name, std::move(operands), std::move(results), @@ -1018,7 +1018,7 @@ absl::StatusOr ThunkEmitter::EmitCustomCallThunk( ir_emitter_context_->platform_name()); }; - absl::StatusOr> custom_call_thunk = + absl::StatusOr> custom_call_thunk = is_ffi_custom_call ? ffi_thunk() : legacy_thunk(); ThunkSequence thunks;