Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/backends/gpu/codegen/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ cc_library(
"//xla/backends/gpu/runtime:device_to_device_copy_thunk",
"//xla/backends/gpu/runtime:dynamic_slice_thunk",
"//xla/backends/gpu/runtime:gemm_thunk",
"//xla/backends/gpu/runtime:legacy_custom_call_thunk",
"//xla/backends/gpu/runtime:thunk",
"//xla/codegen/emitters:kernel_arguments",
"//xla/ffi:attribute_map",
Expand Down
19 changes: 9 additions & 10 deletions xla/backends/gpu/codegen/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/device_to_device_copy_thunk.h"
#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/backends/gpu/runtime/gemm_thunk.h"
#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/codegen/emitters/kernel_arguments.h"
#include "xla/ffi/attribute_map.h"
Expand Down Expand Up @@ -878,7 +879,7 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(

// For legacy custom calls we convert all API versions into the latest
// status-returning one and pass backend config as an opaque string.
CustomCallThunk::CustomCallTarget custom_call_target;
LegacyCustomCallThunk::CustomCallTarget custom_call_target;

// For XLA FFI handlers we decode opaque backend config into attributes map
// at IR emission time, so that we do not need to parse MLIR at run time.
Expand Down Expand Up @@ -926,9 +927,8 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(
&fusion, ir_emitter_context.GetNextThunkId());

auto ffi_thunk =
[&](Slices ops,
Slices res) -> absl::StatusOr<std::unique_ptr<CustomCallThunk>> {
auto ffi_thunk = [&](Slices ops,
Slices res) -> absl::StatusOr<std::unique_ptr<Thunk>> {
auto& called_computations = custom_call.called_computations();
auto& backend_config_str =
backend_config.ok()
Expand All @@ -954,16 +954,15 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
};

auto legacy_thunk =
[&](Slices ops,
Slices res) -> absl::StatusOr<std::unique_ptr<CustomCallThunk>> {
[&](Slices ops, Slices res) -> absl::StatusOr<std::unique_ptr<Thunk>> {
std::string opaque =
backend_config.ok()
? backend_config->custom_call_backend_config().opaque()
: custom_call.raw_backend_config_string();
return CustomCallThunk::Create(thunk_info, call_target_name, std::move(ops),
std::move(res), std::move(opaque),
custom_call.api_version(),
ir_emitter_context.platform_name());
return LegacyCustomCallThunk::Create(
thunk_info, call_target_name, std::move(ops), std::move(res),
std::move(opaque), custom_call.api_version(),
ir_emitter_context.platform_name());
};

std::vector<BufferAllocation> fake_allocations(num_args, {0, 0, 0});
Expand Down
79 changes: 71 additions & 8 deletions xla/backends/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ cc_library(
":dynamic_memcpy_thunk",
":dynamic_slice_thunk",
":gpublas_lt_matmul_thunk",
":legacy_custom_call_thunk",
":p2p_thunk_common",
":ragged_all_to_all_thunk",
":recv_thunk",
Expand Down Expand Up @@ -345,6 +346,7 @@ cc_library(
":gemm_thunk",
":gpublas_lt_matmul_thunk",
":kernel_thunk",
":legacy_custom_call_thunk",
":memset_thunk",
":ragged_all_to_all_thunk",
":recv_thunk",
Expand Down Expand Up @@ -1082,8 +1084,8 @@ cc_library(
":collective_cliques",
":collective_memory",
":collective_params",
":custom_call_target",
":thunk",
":thunk_proto_cc",
"//xla:executable_run_options",
"//xla:shape_util",
"//xla:util",
Expand All @@ -1101,16 +1103,13 @@ cc_library(
"//xla/runtime:buffer_use",
"//xla/runtime:object_pool",
"//xla/service:buffer_assignment",
"//xla/service:custom_call_status",
"//xla/service:custom_call_status_internal",
"//xla/service:custom_call_target_registry",
"//xla/service:hlo_proto_cc",
"//xla/service:shaped_slice",
"//xla/service/gpu:buffer_allocations",
"//xla/stream_executor:device_address",
"//xla/stream_executor:device_address_allocator",
"//xla/stream_executor:device_description",
"//xla/stream_executor:stream",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status_macros",
"//xla/tsl/platform:statusor",
"//xla/tsl/util:unique_any",
Expand All @@ -1123,7 +1122,35 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],
)

cc_library(
name = "legacy_custom_call_thunk",
srcs = ["legacy_custom_call_thunk.cc"],
hdrs = ["legacy_custom_call_thunk.h"],
deps = [
":custom_call_target",
":thunk",
":thunk_proto_cc",
"//xla:util",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/service:custom_call_status",
"//xla/service:custom_call_status_internal",
"//xla/service:custom_call_target_registry",
"//xla/service:hlo_proto_cc",
"//xla/service:shaped_slice",
"//xla/service/gpu:buffer_allocations",
"//xla/stream_executor:stream",
"//xla/tsl/platform:status_macros",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/log",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform",
Expand All @@ -1135,9 +1162,12 @@ xla_test(
srcs = ["custom_call_thunk_test.cc"],
backends = ["gpu"],
deps = [
":collective_clique_requests",
":collective_memory_requests",
":collective_params",
":custom_call_thunk",
":thunk",
":thunk_proto_cc",
"//xla:executable_run_options",
"//xla:shape_util",
"//xla/backends/cpu:target_machine_options",
Expand All @@ -1149,10 +1179,9 @@ xla_test(
"//xla/hlo/ir:hlo",
"//xla/runtime:device_id",
"//xla/service:buffer_assignment",
"//xla/service:custom_call_status_public_headers",
"//xla/service:custom_call_target_registry",
"//xla/service:executable",
"//xla/service:hlo_module_config",
"//xla/service:hlo_proto_cc",
"//xla/service:platform_util",
"//xla/service:shaped_slice",
"//xla/service/gpu:buffer_allocations",
Expand All @@ -1162,6 +1191,7 @@ xla_test(
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_address_allocator",
"//xla/tsl/platform:status_macros",
"//xla/tsl/platform:statusor",
"//xla/tsl/util/proto:parse_text_proto",
"@com_google_absl//absl/base",
Expand All @@ -1175,6 +1205,36 @@ xla_test(
],
)

xla_test(
name = "legacy_custom_call_thunk_test",
srcs = ["legacy_custom_call_thunk_test.cc"],
backends = ["gpu"],
deps = [
":legacy_custom_call_thunk",
":thunk",
":thunk_proto_cc",
"//xla/service:buffer_assignment",
"//xla/service:custom_call_status_public_headers",
"//xla/service:custom_call_target_registry",
"//xla/service:executable",
"//xla/service:hlo_proto_cc",
"//xla/service:platform_util",
"//xla/service/gpu:buffer_allocations",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream",
"//xla/stream_executor:stream_executor_address_allocator",
"//xla/tsl/platform:status_macros",
"//xla/tsl/platform:statusor",
"//xla/tsl/util/proto:parse_text_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "fft_thunk",
srcs = ["fft_thunk.cc"],
Expand Down Expand Up @@ -3606,6 +3666,7 @@ cc_library(
":host_to_device_copy_thunk",
":infeed_thunk",
":kernel_thunk",
":legacy_custom_call_thunk",
":memset_thunk",
":norm_thunk",
":nvshmem_all_reduce_thunk",
Expand All @@ -3627,6 +3688,7 @@ cc_library(
"//xla/backends/cpu:target_machine_options",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
"//xla/service:hlo_proto_cc",
"//xla/stream_executor:device_description",
"//xla/stream_executor:kernel_spec",
"//xla/stream_executor:stream_executor_h",
Expand All @@ -3639,6 +3701,7 @@ cc_library(
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@com_google_protobuf//:protobuf",
"@com_google_protobuf//:protobuf_lite",
],
)

Expand Down
66 changes: 34 additions & 32 deletions xla/backends/gpu/runtime/command_buffer_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/dynamic_memcpy_thunk.h"
#include "xla/backends/gpu/runtime/dynamic_slice_thunk.h"
#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h"
#include "xla/backends/gpu/runtime/legacy_custom_call_thunk.h"
#include "xla/backends/gpu/runtime/p2p_thunk_common.h"
#include "xla/backends/gpu/runtime/ragged_all_to_all_thunk.h"
#include "xla/backends/gpu/runtime/recv_thunk.h"
Expand Down Expand Up @@ -1211,27 +1212,16 @@ Command::BufferUses CuDnnCmd::buffer_uses() const {
}

//===----------------------------------------------------------------------===//
// CustomCallCmd
// LegacyCustomCallCmd
//===----------------------------------------------------------------------===//

absl::StatusOr<const se::CommandBuffer::Command*> CustomCallCmd::Record(
const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) {
if (handler_ == nullptr) {
return RecordLegacyCustomCall(execute_params, record_params,
std::move(record_action), command_buffer);
}
return RecordXlaFfiCall(execute_params, record_params,
std::move(record_action), command_buffer);
}

namespace {
// Records each buffer associated with each slice into the provided vector.
// Returns an error if any of the slices is missing a buffer allocation.
absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params,
absl::Span<const NullableShapedSlice> slices,
std::vector<void*>& buffers, absl::string_view label) {
static absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params,
absl::Span<const NullableShapedSlice> slices,
std::vector<void*>& buffers,
absl::string_view label) {
for (int i = 0; i < slices.size(); ++i) {
if (!slices[i].has_value()) {
buffers.push_back(nullptr);
Expand All @@ -1253,21 +1243,18 @@ absl::Status GetBuffers(const Thunk::ExecuteParams& execute_params,
}
} // namespace

absl::StatusOr<const se::CommandBuffer::Command*>
CustomCallCmd::RecordLegacyCustomCall(
absl::StatusOr<const se::CommandBuffer::Command*> LegacyCustomCallCmd::Record(
const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) {
std::vector<void*> buffers;
buffers.reserve(operands_.size() + results_.size());

VLOG(5) << "CustomCallCmd: target_name=" << target_name_;
TF_RETURN_IF_ERROR(
GetBuffers(execute_params, operands_, buffers, " Operand "));
TF_RETURN_IF_ERROR(
GetBuffers(execute_params, results_, buffers, " Result "));
VLOG(5) << "LegacyCustomCallCmd: target_name=" << target_name_;
RETURN_IF_ERROR(GetBuffers(execute_params, operands_, buffers, " Operand "));
RETURN_IF_ERROR(GetBuffers(execute_params, results_, buffers, " Result "));

TF_ASSIGN_OR_RETURN(
ASSIGN_OR_RETURN(
auto nested_cmd,
se::TraceCommandBufferFactory::Create(
execute_params.stream->parent(),
Expand All @@ -1293,11 +1280,26 @@ CustomCallCmd::RecordLegacyCustomCall(
});
}

absl::StatusOr<const se::CommandBuffer::Command*>
CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params,
RecordAction record_action,
se::CommandBuffer* command_buffer) {
Command::BufferUses LegacyCustomCallCmd::buffer_uses() const {
Command::BufferUses buffer_usage;
for (auto& slices : {operands_, results_}) {
for (const std::optional<ShapedSlice>& slice : slices) {
if (slice.has_value()) {
buffer_usage.push_back(BufferUse::Write(slice->slice, slice->shape));
}
}
}
return buffer_usage;
}

//===----------------------------------------------------------------------===//
// CustomCallCmd (FFI)
//===----------------------------------------------------------------------===//

absl::StatusOr<const se::CommandBuffer::Command*> CustomCallCmd::Record(
const Thunk::ExecuteParams& execute_params,
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer) {
// TODO(ezhulenev): This is not the most optimal approach, as we'll be doing
// a lot of extra allocation on every call. We have to keep attributes
// separate from arguments, as they do not change after thunk is
Expand Down Expand Up @@ -1342,8 +1344,8 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params,

// Borrow the FFI call frame from the object pool and update with the actual
// device memory addresses.
TF_ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate());
TF_RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results));
ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate());
RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results));

RunId run_id = execute_params.collective_params->run_id;

Expand All @@ -1360,7 +1362,7 @@ CustomCallCmd::RecordXlaFfiCall(const Thunk::ExecuteParams& execute_params,
}
}

TF_ASSIGN_OR_RETURN(
ASSIGN_OR_RETURN(
auto nested_cmd,
se::TraceCommandBufferFactory::Create(
execute_params.stream->parent(),
Expand Down
Loading
Loading