diff --git a/jax_rocm_plugin/.bazelrc b/jax_rocm_plugin/.bazelrc
index e3ee8e806b..f71573e88d 100644
--- a/jax_rocm_plugin/.bazelrc
+++ b/jax_rocm_plugin/.bazelrc
@@ -99,6 +99,10 @@ build:rocm --copt=-Wno-gnu-offsetof-extensions
 build:rocm --copt=-Qunused-arguments
 build:rocm --action_env=TF_HIPCC_CLANG="1"
 
+# Strip Bazel sandbox rpaths and embed wheel-specific rpaths instead.
+# Use this config when building wheels to avoid post-build patchelf usage.
+build:rocm_wheel --@local_config_rocm//rocm:rocm_path_type=link_only
+
 #############################################################################
 # Configuration for running RBE builds and tests
 
diff --git a/jax_rocm_plugin/build/build.py b/jax_rocm_plugin/build/build.py
index b8cb454b19..6d440b5ad2 100755
--- a/jax_rocm_plugin/build/build.py
+++ b/jax_rocm_plugin/build/build.py
@@ -611,6 +611,7 @@ async def main():
 
     if "rocm" in args.wheels:
         wheel_build_command_base.append("--config=rocm_base")
+        wheel_build_command_base.append("--config=rocm_wheel")
         if args.use_clang:
             wheel_build_command_base.append("--config=rocm")
             wheel_build_command_base.append(
diff --git a/jax_rocm_plugin/build/rpath.bzl b/jax_rocm_plugin/build/rpath.bzl
new file mode 100644
index 0000000000..d9b881e30e
--- /dev/null
+++ b/jax_rocm_plugin/build/rpath.bzl
@@ -0,0 +1,67 @@
+"""Macros that set wheel RPATHs for ROCm binaries at link time.
+
+With rocm_path_type=link_only the macros enable the no_solib_rpaths feature
+and append the wheel RPATH linkopts; in any other configuration they add
+nothing to the wrapped rule.
+"""
+
+load("@jax//jaxlib:jax.bzl", "nanobind_extension")
+
+_ROCM_LINK_ONLY = "@local_config_rocm//rocm:link_only"
+
+# Mirrors the runpath the wheel scripts used to apply with patchelf.
+# "$$" escapes Bazel's Make-variable expansion, leaving a literal $ORIGIN.
+# --enable-new-dtags requests DT_RUNPATH (what patchelf --set-rpath writes by
+# default) rather than relying on the linker's configured default.
+_WHEEL_RPATHS = [
+    "-Wl,--enable-new-dtags",
+    "-Wl,-rpath,$$ORIGIN/../rocm/lib",
+    "-Wl,-rpath,$$ORIGIN/../../rocm/lib",
+    "-Wl,-rpath,/opt/rocm/lib",
+]
+
+def _wheel_features():
+    """Returns a select() enabling no_solib_rpaths only for link_only builds."""
+    return select({
+        _ROCM_LINK_ONLY: ["no_solib_rpaths"],
+        "//conditions:default": [],
+    })
+
+def _wheel_linkopts():
+    """Returns a select() adding _WHEEL_RPATHS only for link_only builds."""
+    return select({
+        _ROCM_LINK_ONLY: _WHEEL_RPATHS,
+        "//conditions:default": [],
+    })
+
+def rocm_cc_binary(name, features = [], linkopts = [], **kwargs):
+    """cc_binary that drops solib rpaths and embeds wheel RPATHs in link_only builds.
+
+    Args:
+      name: Target name.
+      features: Additional features (the rpath feature select() is appended).
+      linkopts: Additional linkopts (the wheel RPATH select() is appended).
+      **kwargs: Passed through to native.cc_binary.
+    """
+    native.cc_binary(
+        name = name,
+        features = features + _wheel_features(),
+        linkopts = linkopts + _wheel_linkopts(),
+        **kwargs
+    )
+
+def rocm_nanobind_extension(name, features = [], linkopts = [], **kwargs):
+    """nanobind_extension that drops solib rpaths and embeds wheel RPATHs in link_only builds.
+
+    Args:
+      name: Target name.
+      features: Additional features (the rpath feature select() is appended).
+      linkopts: Additional linkopts (the wheel RPATH select() is appended).
+      **kwargs: Passed through to nanobind_extension.
+    """
+    nanobind_extension(
+        name = name,
+        features = features + _wheel_features(),
+        linkopts = linkopts + _wheel_linkopts(),
+        **kwargs
+    )
diff --git a/jax_rocm_plugin/jaxlib_ext/rocm/BUILD b/jax_rocm_plugin/jaxlib_ext/rocm/BUILD
new file mode 100644
index 0000000000..6a4955025c
--- /dev/null
+++ b/jax_rocm_plugin/jaxlib_ext/rocm/BUILD
@@ -0,0 +1,221 @@
+# Copyright 2026 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local nanobind_extension wrappers for JAX ROCm GPU kernels.
+# +# These targets mirror @jax//jaxlib/rocm but allow us to control +# features and linkopts via select() on rocm_path_type: +# - link_only → strip solib rpaths, embed wheel-specific rpaths +# - (default) → preserve original Bazel rpath behavior + +load("@rules_python//python:defs.bzl", "py_library") +load("//build:rpath.bzl", "rocm_nanobind_extension") + +licenses(["notice"]) # Apache 2 + +package(default_visibility = ["//visibility:public"]) + +rocm_nanobind_extension( + name = "_rnn", + srcs = ["@jax//jaxlib/gpu:rnn.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_rnn", + deps = [ + "@jax//jaxlib/rocm:hip_vendor", + "@jax//jaxlib/rocm:miopen_rnn_kernels", + "@jax//jaxlib:absl_status_casters", + "@jax//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:str_format", + "@nanobind", + ], +) + +rocm_nanobind_extension( + name = "_solver", + srcs = ["@jax//jaxlib/gpu:solver.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_solver", + deps = [ + "@jax//jaxlib/rocm:hip_vendor", + "@jax//jaxlib/rocm:hipsolver_kernels_ffi", + "@jax//jaxlib:kernel_nanobind_helpers", + "@local_config_rocm//rocm:hipblas", + "@local_config_rocm//rocm:hipsolver", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], +) + +rocm_nanobind_extension( + name = "_sparse", + srcs = ["@jax//jaxlib/gpu:sparse.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_sparse", + deps = [ + "@jax//jaxlib/rocm:hip_gpu_kernel_helpers", + "@jax//jaxlib/rocm:hip_vendor", + "@jax//jaxlib/rocm:hipsparse_kernels", + "@jax//jaxlib:absl_status_casters", + "@jax//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + 
"@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@local_config_rocm//rocm:hipsparse", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + "@xla//xla/tsl/python/lib/core:numpy", + ], +) + +rocm_nanobind_extension( + name = "_linalg", + srcs = ["@jax//jaxlib/gpu:linalg.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_linalg", + deps = [ + "@jax//jaxlib/rocm:hip_gpu_kernel_helpers", + "@jax//jaxlib/rocm:hip_linalg_kernels", + "@jax//jaxlib/rocm:hip_vendor", + "@jax//jaxlib:kernel_nanobind_helpers", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], +) + +rocm_nanobind_extension( + name = "_prng", + srcs = ["@jax//jaxlib/gpu:prng.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_prng", + deps = [ + "@jax//jaxlib/rocm:hip_gpu_kernel_helpers", + "@jax//jaxlib/rocm:hip_prng_kernels", + "@jax//jaxlib/rocm:hip_vendor", + "@jax//jaxlib:kernel_nanobind_helpers", + "@local_config_rocm//rocm:hip", + "@local_config_rocm//rocm:hip_runtime", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], +) + +rocm_nanobind_extension( + name = "_hybrid", + srcs = ["@jax//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + linkopts = [ + "-L/opt/rocm/lib", + "-lamdhip64", + ], + module_name = "_hybrid", + deps = [ + "@jax//jaxlib/rocm:hip_gpu_kernel_helpers", + "@jax//jaxlib/rocm:hip_hybrid_kernels", + "@jax//jaxlib/rocm:hip_vendor", + "@jax//jaxlib:kernel_nanobind_helpers", + "@jax//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/base", + "@local_config_rocm//rocm:hip", + "@local_config_rocm//rocm:hip_runtime", + 
"@local_config_rocm//rocm:rocm_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) + +rocm_nanobind_extension( + name = "_triton", + srcs = ["@jax//jaxlib/gpu:triton.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_triton", + deps = [ + "@jax//jaxlib/rocm:hip_gpu_kernel_helpers", + "@jax//jaxlib/rocm:hip_vendor", + "@jax//jaxlib/rocm:triton_kernels", + "@jax//jaxlib/rocm:triton_utils", + "@jax//jaxlib:absl_status_casters", + "@jax//jaxlib:kernel_nanobind_helpers", + "@jax//jaxlib/gpu:triton_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_config_rocm//rocm:hip", + "@local_config_rocm//rocm:hip_runtime", + "@nanobind", + ], +) + +rocm_nanobind_extension( + name = "rocm_plugin_extension", + srcs = ["rocm_plugin_extension.cc"], + module_name = "rocm_plugin_extension", + deps = [ + "@jax//jaxlib/rocm:hip_gpu_kernel_helpers", + "@jax//jaxlib/rocm:py_client_gpu", + "@jax//jaxlib:kernel_nanobind_helpers", + "@jax//jaxlib/gpu:gpu_plugin_extension", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@local_config_rocm//rocm:hip_runtime", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], +) + +py_library( + name = "rocm_gpu_support", + deps = [ + ":_hybrid", + ":_linalg", + ":_prng", + ":_rnn", + ":_solver", + ":_sparse", + ":_triton", + ], +) diff --git a/jax_rocm_plugin/jaxlib_ext/rocm/rocm_plugin_extension.cc b/jax_rocm_plugin/jaxlib_ext/rocm/rocm_plugin_extension.cc new file mode 100644 index 0000000000..05be1e81c8 --- /dev/null +++ b/jax_rocm_plugin/jaxlib_ext/rocm/rocm_plugin_extension.cc @@ -0,0 +1,126 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <string>
+
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "rocm/include/hip/hip_runtime.h"
+#include "nanobind/nanobind.h"
+#include "jaxlib/gpu/gpu_plugin_extension.h"
+#include "jaxlib/gpu/py_client_gpu.h"
+#include "jaxlib/kernel_nanobind_helpers.h"
+
+namespace nb = nanobind;
+
+namespace jax {
+namespace {
+
+std::string ToString(hipError_t result) {
+#define OSTREAM_ROCM_ERROR(__name) \
+  case hipError##__name:           \
+    return "HIP_ERROR_" #__name;
+
+  switch (result) {
+    OSTREAM_ROCM_ERROR(InvalidValue)
+    OSTREAM_ROCM_ERROR(OutOfMemory)
+    OSTREAM_ROCM_ERROR(NotInitialized)
+    OSTREAM_ROCM_ERROR(Deinitialized)
+    OSTREAM_ROCM_ERROR(NoDevice)
+    OSTREAM_ROCM_ERROR(InvalidDevice)
+    OSTREAM_ROCM_ERROR(InvalidImage)
+    OSTREAM_ROCM_ERROR(InvalidContext)
+    OSTREAM_ROCM_ERROR(InvalidHandle)
+    OSTREAM_ROCM_ERROR(NotFound)
+    OSTREAM_ROCM_ERROR(NotReady)
+    OSTREAM_ROCM_ERROR(NoBinaryForGpu)
+
+    // Encountered an uncorrectable ECC error during execution.
+    OSTREAM_ROCM_ERROR(ECCNotCorrectable)
+
+    // Load/store on an invalid address. Must reboot all context.
+    case 700:
+      return "ROCM_ERROR_ILLEGAL_ADDRESS";
+    // Passed too many / wrong arguments, too many threads for register count.
+    case 701:
+      return "ROCM_ERROR_LAUNCH_OUT_OF_RESOURCES";
+
+    OSTREAM_ROCM_ERROR(ContextAlreadyInUse)
+    OSTREAM_ROCM_ERROR(PeerAccessUnsupported)
+    OSTREAM_ROCM_ERROR(Unknown)  // Unknown internal error to ROCM.
+    default:  // Any other code: fall back to its numeric value.
+      return absl::StrCat("hipError_t(", static_cast<int>(result), ")");
+  }
+}
+
+static nb::dict GpuTransposePlanCacheType() {
+  auto [type_id, type_info] = hip::GpuTransposePlanCacheTypeInfo();
+  nb::dict d;
+  d["type_id"] = nb::capsule(type_id);
+  d["type_info"] = nb::capsule(type_info);
+  return d;
+}
+
+nb::dict FfiTypes() {
+  nb::dict dict;
+  dict["GpuTransposePlanCache"] = GpuTransposePlanCacheType();
+  return dict;
+}
+
+nb::dict FfiHandlers() {
+  nb::dict dict;
+  nb::dict gpu_callback_dict;
+  gpu_callback_dict["instantiate"] =
+      EncapsulateFfiHandler(hip::kGpuTransposePlanCacheInstantiate);
+  gpu_callback_dict["execute"] =
+      EncapsulateFfiHandler(hip::kXlaFfiPythonGpuCallback);
+  dict["xla_ffi_python_gpu_callback"] = gpu_callback_dict;
+  dict["xla_ffi_partitioned_python_gpu_callback"] = gpu_callback_dict;
+  dict["xla_buffer_python_gpu_callback"] =
+      EncapsulateFfiHandler(hip::kXlaBufferPythonGpuCallback);
+  dict["xla_buffer_python_gpu_callback_cmd_buffer"] =
+      EncapsulateFfiHandler(hip::kXlaBufferPythonGpuCallbackCmdBuffer);
+  return dict;
+}
+
+}  // namespace
+
+NB_MODULE(rocm_plugin_extension, m) {
+  BuildGpuPluginExtension(m);
+  m.def("ffi_types", &FfiTypes);
+  m.def("ffi_handlers", &FfiHandlers);
+
+  m.def(
+      "get_device_ordinal",
+      [](std::intptr_t data_value) {
+        if (data_value == 0) {  // Null pointer: skip the HIP query.
+          return 0;
+        }
+        int device_ordinal;
+        void* data_ptr = reinterpret_cast<void*>(data_value);
+        hipError_t result =
+            hipPointerGetAttribute(static_cast<void*>(&device_ordinal),
+                                   HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
+                                   reinterpret_cast<hipDeviceptr_t>(data_ptr));
+        if (result != hipSuccess) {
+          LOG(FATAL) << "Not able to get the device_ordinal for ptr: "
+                     << data_ptr << ". Error: " << ToString(result);
+        }
+        return device_ordinal;
+      },
+      nb::arg("data_value"));
+}
+}  // namespace jax
diff --git a/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel b/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel
index 82a36bd5cd..9d8a392e89 100644
--- a/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel
+++ b/jax_rocm_plugin/jaxlib_ext/tools/BUILD.bazel
@@ -26,8 +26,8 @@ ROCM_PLUGIN_SOURCES = [
     "//jax_plugins/rocm:plugin_pyproject.toml",
     "//jax_plugins/rocm:plugin_setup.py",
     "//pjrt/python:version.py",
-    "@jax//jaxlib/rocm:rocm_gpu_support",
-    "@jax//jaxlib/rocm:rocm_plugin_extension",
+    "//jaxlib_ext/rocm:rocm_gpu_support",
+    "//jaxlib_ext/rocm:rocm_plugin_extension",
 ]
 
 py_binary(
diff --git a/jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py b/jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py
index 383186804b..64aa160138 100644
--- a/jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py
+++ b/jax_rocm_plugin/jaxlib_ext/tools/build_gpu_kernels_wheel.py
@@ -22,8 +22,6 @@
 import os
 import pathlib
 import shutil
-import stat
-import subprocess
 import tempfile
 
 # pylint: disable=import-error,invalid-name,consider-using-with
@@ -196,7 +194,9 @@ def prepare_wheel_rocm(wheel_sources_path: pathlib.Path, *, cpu, rocm_version, s
         plugin_dir, xla_commit_hash, jax_commit_hash, get_rocm_jax_git_hash()
     )
 
-    # Copy .so files: always from jax runfiles
+    # Copy .so files from local wrapper targets (//jaxlib_ext/rocm).
+    # RPATHs are set at build time via Bazel features/linkopts when
+    # --config=rocm_wheel is used (rocm_path_type=link_only).
for so_file in [ f"_linalg.{pyext}", f"_prng.{pyext}", @@ -207,47 +207,7 @@ def prepare_wheel_rocm(wheel_sources_path: pathlib.Path, *, cpu, rocm_version, s f"_triton.{pyext}", f"rocm_plugin_extension.{pyext}", ]: - shutil.copy(r.Rlocation(f"jax/jaxlib/rocm/{so_file}"), plugin_dir) - - # NOTE(mrodden): this is a hack to change/set rpath values - # in the shared objects that are produced by the bazel build - # before they get pulled into the wheel build process. - # we have to do this change here because setting rpath - # using bazel requires the rpath to be valid during the build - # which won't be correct until we make changes to - # the xla/tsl/jax plugin build - - try: - subprocess.check_output(["which", "patchelf"]) - except subprocess.CalledProcessError as ex: - mesg = ( - "rocm plugin and kernel wheel builds require patchelf. " - "please install 'patchelf' and run again" - ) - raise RuntimeError(mesg) from ex - - files = [ - f"_linalg.{pyext}", - f"_prng.{pyext}", - f"_solver.{pyext}", - f"_sparse.{pyext}", - f"_hybrid.{pyext}", - f"_rnn.{pyext}", - f"_triton.{pyext}", - f"rocm_plugin_extension.{pyext}", - ] - runpath = "$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib:/opt/rocm/lib" - # patchelf --set-rpath $RUNPATH $so - for f in files: - so_path = os.path.join(plugin_dir, f) - fix_perms = False - perms = os.stat(so_path).st_mode - if not perms & stat.S_IWUSR: - fix_perms = True - os.chmod(so_path, perms | stat.S_IWUSR) - subprocess.check_call(["patchelf", "--set-rpath", runpath, so_path]) - if fix_perms: - os.chmod(so_path, perms) + shutil.copy(rloc(f"jaxlib_ext/rocm/{so_file}"), plugin_dir) tmpdir = tempfile.TemporaryDirectory(prefix="jax_rocm_plugin") diff --git a/jax_rocm_plugin/pjrt/BUILD b/jax_rocm_plugin/pjrt/BUILD index 30b793257a..23e072db29 100644 --- a/jax_rocm_plugin/pjrt/BUILD +++ b/jax_rocm_plugin/pjrt/BUILD @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("//build:rpath.bzl", "rocm_cc_binary") + licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) -cc_binary( +rocm_cc_binary( name = "pjrt_c_api_gpu_plugin.so", linkopts = [ "-Wl,--version-script,$(location :gpu_version_script.lds)", diff --git a/jax_rocm_plugin/pjrt/tools/build_gpu_plugin_wheel.py b/jax_rocm_plugin/pjrt/tools/build_gpu_plugin_wheel.py index 1d9c2c335c..f269b26c6f 100644 --- a/jax_rocm_plugin/pjrt/tools/build_gpu_plugin_wheel.py +++ b/jax_rocm_plugin/pjrt/tools/build_gpu_plugin_wheel.py @@ -22,8 +22,6 @@ import os import pathlib import shutil -import stat -import subprocess import tempfile # pylint: disable=import-error,invalid-name,consider-using-with @@ -193,35 +191,6 @@ def prepare_rocm_plugin_wheel( plugin_dir, xla_commit_hash, jax_commit_hash, get_rocm_jax_git_hash() ) - # NOTE(mrodden): this is a hack to change/set rpath values - # in the shared objects that are produced by the bazel build - # before they get pulled into the wheel build process. - # we have to do this change here because setting rpath - # using bazel requires the rpath to be valid during the build - # which won't be correct until we make changes to - # the xla/tsl/jax plugin build - - try: - subprocess.check_output(["which", "patchelf"]) - except subprocess.CalledProcessError as ex: - mesg = ( - "rocm plugin and kernel wheel builds require patchelf. 
" - "please install 'patchelf' and run again" - ) - raise RuntimeError(mesg) from ex - - shared_obj_path = os.path.join(plugin_dir, "xla_rocm_plugin.so") - runpath = "$ORIGIN/../rocm/lib:$ORIGIN/../../rocm/lib:/opt/rocm/lib" - # patchelf --set-rpath $RUNPATH $so - fix_perms = False - perms = os.stat(shared_obj_path).st_mode - if not perms & stat.S_IWUSR: - fix_perms = True - os.chmod(shared_obj_path, perms | stat.S_IWUSR) - subprocess.check_call(["patchelf", "--set-rpath", runpath, shared_obj_path]) - if fix_perms: - os.chmod(shared_obj_path, perms) - tmpdir = None sources_path = args.sources_path diff --git a/jax_rocm_plugin/third_party/jax/0008-Add-rpath-control-for-wheel-builds.patch b/jax_rocm_plugin/third_party/jax/0008-Add-rpath-control-for-wheel-builds.patch new file mode 100644 index 0000000000..55ac4a52d5 --- /dev/null +++ b/jax_rocm_plugin/third_party/jax/0008-Add-rpath-control-for-wheel-builds.patch @@ -0,0 +1,128 @@ +diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD +--- a/jaxlib/gpu/BUILD ++++ b/jaxlib/gpu/BUILD +@@ -31,48 +31,51 @@ + default_visibility = ["//jax:internal"], + ) + +-exports_files(srcs = [ +- "blas_handle_pool.cc", +- "blas_handle_pool.h", +- "ffi_wrapper.h", +- "gpu_kernel_helpers.cc", +- "gpu_kernel_helpers.h", +- "gpu_kernels.cc", +- "hybrid.cc", +- "hybrid_kernels.cc", +- "hybrid_kernels.h", +- "linalg.cc", +- "linalg_kernels.cc", +- "linalg_kernels.cu.cc", +- "linalg_kernels.h", +- "make_batch_pointers.cu.cc", +- "make_batch_pointers.h", +- "prng.cc", +- "prng_kernels.cc", +- "prng_kernels.cu.cc", +- "prng_kernels.h", +- "py_client_gpu.cc", +- "py_client_gpu.h", +- "rnn.cc", +- "rnn_kernels.cc", +- "rnn_kernels.h", +- "solver.cc", +- "solver_handle_pool.cc", +- "solver_handle_pool.h", +- "solver_interface.cc", +- "solver_interface.h", +- "solver_kernels_ffi.cc", +- "solver_kernels_ffi.h", +- "sparse.cc", +- "sparse_kernels.cc", +- "sparse_kernels.h", +- "triton.cc", +- "triton_kernels.cc", +- "triton_kernels.h", +- 
"triton_utils.cc", +- "triton_utils.h", +- "vendor.h", +-]) ++exports_files( ++ srcs = [ ++ "blas_handle_pool.cc", ++ "blas_handle_pool.h", ++ "ffi_wrapper.h", ++ "gpu_kernel_helpers.cc", ++ "gpu_kernel_helpers.h", ++ "gpu_kernels.cc", ++ "hybrid.cc", ++ "hybrid_kernels.cc", ++ "hybrid_kernels.h", ++ "linalg.cc", ++ "linalg_kernels.cc", ++ "linalg_kernels.cu.cc", ++ "linalg_kernels.h", ++ "make_batch_pointers.cu.cc", ++ "make_batch_pointers.h", ++ "prng.cc", ++ "prng_kernels.cc", ++ "prng_kernels.cu.cc", ++ "prng_kernels.h", ++ "py_client_gpu.cc", ++ "py_client_gpu.h", ++ "rnn.cc", ++ "rnn_kernels.cc", ++ "rnn_kernels.h", ++ "solver.cc", ++ "solver_handle_pool.cc", ++ "solver_handle_pool.h", ++ "solver_interface.cc", ++ "solver_interface.h", ++ "solver_kernels_ffi.cc", ++ "solver_kernels_ffi.h", ++ "sparse.cc", ++ "sparse_kernels.cc", ++ "sparse_kernels.h", ++ "triton.cc", ++ "triton_kernels.cc", ++ "triton_kernels.h", ++ "triton_utils.cc", ++ "triton_utils.h", ++ "vendor.h", ++ ], ++ visibility = ["//visibility:public"], ++) + + proto_library( + name = "triton_proto", +@@ -82,10 +85,7 @@ + cc_proto_library( + name = "triton_cc_proto", + compatible_with = None, +- visibility = [ +- "//jax:internal", +- "//third_party/py/enzyme_ad:__subpackages__", +- ], ++ visibility = ["//visibility:public"], + deps = [":triton_proto"], + ) + +@@ -114,6 +114,7 @@ + name = "gpu_plugin_extension", + srcs = ["gpu_plugin_extension.cc"], + hdrs = ["gpu_plugin_extension.h"], ++ visibility = ["//visibility:public"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", +diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD +--- a/jaxlib/cpu/BUILD ++++ b/jaxlib/cpu/BUILD +@@ -33,6 +33,7 @@ + name = "lapack_kernels", + srcs = ["lapack_kernels.cc"], + hdrs = ["lapack_kernels.h"], ++ visibility = ["//visibility:public"], + # compatible with libtpu + copts = ["-fexceptions"], + features = ["-use_header_modules"], diff --git a/jax_rocm_plugin/third_party/jax/workspace.bzl 
b/jax_rocm_plugin/third_party/jax/workspace.bzl index ccc5352339..7432ba2f3b 100644 --- a/jax_rocm_plugin/third_party/jax/workspace.bzl +++ b/jax_rocm_plugin/third_party/jax/workspace.bzl @@ -13,5 +13,6 @@ def repo(): "//third_party/jax:0005-Fix-HIP-availability-errors.patch", "//third_party/jax:0006-Enable-testing-with-ROCm-plugin-wheels.patch", # TODO: remove due to: https://github.com/jax-ml/jax/pull/34641 "//third_party/jax:0007-Fix-legacy-create-init.patch", # TODO: remove due to: https://github.com/jax-ml/jax/pull/34770 + "//third_party/jax:0008-Add-rpath-control-for-wheel-builds.patch", ], ) diff --git a/jax_rocm_plugin/third_party/xla/0001-Add-support-of-empty-rpaths.patch b/jax_rocm_plugin/third_party/xla/0001-Add-support-of-empty-rpaths.patch new file mode 100644 index 0000000000..71dccf6a1e --- /dev/null +++ b/jax_rocm_plugin/third_party/xla/0001-Add-support-of-empty-rpaths.patch @@ -0,0 +1,125 @@ +diff --git a/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl b/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl +index a97202d8e9..26c6bcdff5 100644 +--- a/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl ++++ b/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl +@@ -382,6 +382,10 @@ def _impl(ctx): + provides = ["profile"], + ) + ++ # Targets can set features = ["no_solib_rpaths"] to suppress the ++ # automatic solib RPATH entries and rely on their own linkopts instead. 
++ no_solib_rpaths_feature = feature(name = "no_solib_rpaths") ++ + runtime_library_search_directories_feature = feature( + name = "runtime_library_search_directories", + flag_sets = [ +@@ -409,7 +413,10 @@ def _impl(ctx): + ), + ], + with_features = [ +- with_feature_set(features = ["static_link_cpp_runtimes"]), ++ with_feature_set( ++ features = ["static_link_cpp_runtimes"], ++ not_features = ["no_solib_rpaths"], ++ ), + ], + ), + flag_set( +@@ -430,7 +437,7 @@ def _impl(ctx): + ], + with_features = [ + with_feature_set( +- not_features = ["static_link_cpp_runtimes"], ++ not_features = ["static_link_cpp_runtimes", "no_solib_rpaths"], + ), + ], + ), +@@ -1084,6 +1091,7 @@ def _impl(ctx): + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, ++ no_solib_rpaths_feature, + runtime_library_search_directories_feature, + library_search_directories_feature, + archiver_flags_feature, +diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl +index b0f476883c..f1b4f36a84 100644 +--- a/third_party/gpus/rocm/BUILD.tpl ++++ b/third_party/gpus/rocm/BUILD.tpl +@@ -13,6 +13,7 @@ string_flag( + "hermetic", + "multiple", + "system", ++ "link_only", + ], + ) + +@@ -30,6 +31,13 @@ config_setting( + }, + ) + ++config_setting( ++ name = "link_only", ++ flag_values = { ++ ":rocm_path_type": "link_only", ++ }, ++) ++ + config_setting( + name = "using_hipcc", + values = { +@@ -133,9 +141,10 @@ cc_library( + deps = [":rocm_config"], + ) + +-# workaround to bring data to the same fs layout as expected in the rocm libs +-# rocblas assumes that miopen db files are located in ../share/miopen/db directory +-# hibplatslt assumes that tensile files are located in ../hipblaslt/library directory ++# Provides -L and -Wl,-rpath flags for ROCm libraries. ++# These must live in a cc_library (not a toolchain feature) because ++# cc_library linkopts propagate transitively through CcInfo to the ++# final linking target, whereas toolchain features do not. 
+ cc_library( + name = "rocm_rpath", + linkopts = select({ +@@ -144,6 +153,11 @@ cc_library( + "-Wl,-rpath,external/local_config_rocm/rocm/%{rocm_root}/lib/llvm/lib", + "-Lexternal/local_config_rocm/rocm/%{rocm_root}/lib", + ], ++ ":link_only": [ ++ "-Wl,-rpath-link,external/local_config_rocm/rocm/%{rocm_root}/lib", ++ "-Wl,-rpath-link,external/local_config_rocm/rocm/%{rocm_root}/lib/llvm/lib", ++ "-Lexternal/local_config_rocm/rocm/%{rocm_root}/lib", ++ ], + ":multiple_rocm_paths": [ + "-Wl,-rpath=%{rocm_lib_paths}", + "-Lexternal/local_config_rocm/rocm/%{rocm_root}/lib", +@@ -254,7 +268,6 @@ cc_library( + includes = [ + "%{rocm_root}/include", + ], +- linkopts = ["-Wl,-rpath,external/local_config_rocm/rocm/%{rocm_root}/lib"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ +@@ -270,7 +283,6 @@ cc_library( + includes = [ + "%{rocm_root}/include", + ], +- linkopts = ["-Wl,-rpath,external/local_config_rocm/rocm/%{rocm_root}/lib"], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ +@@ -301,11 +313,11 @@ cc_library( + "%{rocm_root}/lib/libMIOpen*.so*", + "%{rocm_root}/share/miopen/**", + ]), +- linkopts = ["-lMIOpen"], + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], ++ linkopts = ["-lMIOpen"], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [ diff --git a/jax_rocm_plugin/third_party/xla/workspace.bzl b/jax_rocm_plugin/third_party/xla/workspace.bzl index 10645fe5b6..9126b2fac6 100644 --- a/jax_rocm_plugin/third_party/xla/workspace.bzl +++ b/jax_rocm_plugin/third_party/xla/workspace.bzl @@ -25,7 +25,7 @@ def repo(): sha256 = XLA_SHA256, strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), urls = ["https://github.com/ROCm/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], - patch_file = [], + patch_file = ["//third_party/xla:0001-Add-support-of-empty-rpaths.patch"], ) # For development, one often wants to make changes to the TF repository as well