Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions jax_rocm_plugin/.bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ build:rocm --copt=-Wno-gnu-offsetof-extensions
build:rocm --copt=-Qunused-arguments
build:rocm --action_env=TF_HIPCC_CLANG="1"

# Strip Bazel sandbox rpaths and embed wheel-specific rpaths instead.
# Use this config when building wheels to avoid post-build patchelf usage.
build:rocm_wheel --@local_config_rocm//rocm:rocm_path_type=link_only


#############################################################################
# Configuration for running RBE builds and tests
Expand Down
1 change: 1 addition & 0 deletions jax_rocm_plugin/build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ async def main():

if "rocm" in args.wheels:
wheel_build_command_base.append("--config=rocm_base")
wheel_build_command_base.append("--config=rocm_wheel")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This config just turns the -rpath settings on and off. Could we call this config something like no_rocm_rpath?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think release_wheel would be more appropriate here. no_rocm_rpath seems to be very specific. It is in fact replace rpaths with wheel based rpaths.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is in fact replace rpaths with wheel based rpaths.

Which wheels are you referring to here? The JAX plugin wheels? Or the ROCm whels?

if args.use_clang:
wheel_build_command_base.append("--config=rocm")
wheel_build_command_base.append(
Expand Down
53 changes: 53 additions & 0 deletions jax_rocm_plugin/build/rpath.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
load("@jax//jaxlib:jax.bzl", "nanobind_extension")

_ROCM_LINK_ONLY = "@local_config_rocm//rocm:link_only"

_WHEEL_RPATHS = [
"-Wl,-rpath,$$ORIGIN/../rocm/lib",
"-Wl,-rpath,$$ORIGIN/../../rocm/lib",
"-Wl,-rpath,/opt/rocm/lib",
]

def _wheel_features():
return select({
_ROCM_LINK_ONLY: ["no_solib_rpaths"],
"//conditions:default": [],
})

def _wheel_linkopts():
return select({
_ROCM_LINK_ONLY: _WHEEL_RPATHS,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do these need to be conditional? Don't we always need the -rpath options?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. For tests we don't need rpath, tests shall be looking inside the sandbox

"//conditions:default": [],
})

def rocm_cc_binary(name, features = [], linkopts = [], **kwargs):
"""cc_binary that automatically strips solib rpaths and embeds wheel RPATHs.

Args:
name: Target name.
features: Additional features (rpath features are appended automatically).
linkopts: Additional linkopts (wheel RPATHs are appended automatically).
**kwargs: Passed through to native.cc_binary.
"""
native.cc_binary(
name = name,
features = features + _wheel_features(),
linkopts = linkopts + _wheel_linkopts(),
**kwargs
)

def rocm_nanobind_extension(name, features = [], linkopts = [], **kwargs):
"""nanobind_extension that automatically strips solib rpaths and embeds wheel RPATHs.

Args:
name: Target name.
features: Additional features (rpath features are appended automatically).
linkopts: Additional linkopts (wheel RPATHs are appended automatically).
**kwargs: Passed through to nanobind_extension.
"""
nanobind_extension(
name = name,
features = features + _wheel_features(),
linkopts = linkopts + _wheel_linkopts(),
**kwargs
)
221 changes: 221 additions & 0 deletions jax_rocm_plugin/jaxlib_ext/rocm/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local nanobind_extension wrappers for JAX ROCm GPU kernels.
#
# These targets mirror @jax//jaxlib/rocm but allow us to control
# features and linkopts via select() on rocm_path_type:
# - link_only → strip solib rpaths, embed wheel-specific rpaths
# - (default) → preserve original Bazel rpath behavior

load("@rules_python//python:defs.bzl", "py_library")
load("//build:rpath.bzl", "rocm_nanobind_extension")

licenses(["notice"]) # Apache 2

package(default_visibility = ["//visibility:public"])

rocm_nanobind_extension(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless we intend to completely remove this file from upstream, this file is going to be a pain to maintain, as we'll be occasionally porting upstream changes into the plugin repo. Do you intend to remove the upstream one?

It looks like all this is doing is using rocm_nanobind_extension instead of nanobind_extension so that we can pick up the rpath configuration. Is that correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes there is no difference except that rocm_nanobind_extension handles the rpaths. The question is if we want to implement it upstream (jax repo) or not....

name = "_rnn",
srcs = ["@jax//jaxlib/gpu:rnn.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_rnn",
deps = [
"@jax//jaxlib/rocm:hip_vendor",
"@jax//jaxlib/rocm:miopen_rnn_kernels",
"@jax//jaxlib:absl_status_casters",
"@jax//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@nanobind",
],
)

rocm_nanobind_extension(
name = "_solver",
srcs = ["@jax//jaxlib/gpu:solver.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_solver",
deps = [
"@jax//jaxlib/rocm:hip_vendor",
"@jax//jaxlib/rocm:hipsolver_kernels_ffi",
"@jax//jaxlib:kernel_nanobind_helpers",
"@local_config_rocm//rocm:hipblas",
"@local_config_rocm//rocm:hipsolver",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)

rocm_nanobind_extension(
name = "_sparse",
srcs = ["@jax//jaxlib/gpu:sparse.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_sparse",
deps = [
"@jax//jaxlib/rocm:hip_gpu_kernel_helpers",
"@jax//jaxlib/rocm:hip_vendor",
"@jax//jaxlib/rocm:hipsparse_kernels",
"@jax//jaxlib:absl_status_casters",
"@jax//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@local_config_rocm//rocm:hipsparse",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
"@xla//xla/tsl/python/lib/core:numpy",
],
)

rocm_nanobind_extension(
name = "_linalg",
srcs = ["@jax//jaxlib/gpu:linalg.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_linalg",
deps = [
"@jax//jaxlib/rocm:hip_gpu_kernel_helpers",
"@jax//jaxlib/rocm:hip_linalg_kernels",
"@jax//jaxlib/rocm:hip_vendor",
"@jax//jaxlib:kernel_nanobind_helpers",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)

rocm_nanobind_extension(
name = "_prng",
srcs = ["@jax//jaxlib/gpu:prng.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_prng",
deps = [
"@jax//jaxlib/rocm:hip_gpu_kernel_helpers",
"@jax//jaxlib/rocm:hip_prng_kernels",
"@jax//jaxlib/rocm:hip_vendor",
"@jax//jaxlib:kernel_nanobind_helpers",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:hip_runtime",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)

rocm_nanobind_extension(
name = "_hybrid",
srcs = ["@jax//jaxlib/gpu:hybrid.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
linkopts = [
"-L/opt/rocm/lib",
"-lamdhip64",
],
module_name = "_hybrid",
deps = [
"@jax//jaxlib/rocm:hip_gpu_kernel_helpers",
"@jax//jaxlib/rocm:hip_hybrid_kernels",
"@jax//jaxlib/rocm:hip_vendor",
"@jax//jaxlib:kernel_nanobind_helpers",
"@jax//jaxlib/cpu:lapack_kernels",
"@com_google_absl//absl/base",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:hip_runtime",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
"@xla//xla/ffi/api:ffi",
],
)

rocm_nanobind_extension(
name = "_triton",
srcs = ["@jax//jaxlib/gpu:triton.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_triton",
deps = [
"@jax//jaxlib/rocm:hip_gpu_kernel_helpers",
"@jax//jaxlib/rocm:hip_vendor",
"@jax//jaxlib/rocm:triton_kernels",
"@jax//jaxlib/rocm:triton_utils",
"@jax//jaxlib:absl_status_casters",
"@jax//jaxlib:kernel_nanobind_helpers",
"@jax//jaxlib/gpu:triton_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@local_config_rocm//rocm:hip",
"@local_config_rocm//rocm:hip_runtime",
"@nanobind",
],
)

rocm_nanobind_extension(
name = "rocm_plugin_extension",
srcs = ["rocm_plugin_extension.cc"],
module_name = "rocm_plugin_extension",
deps = [
"@jax//jaxlib/rocm:hip_gpu_kernel_helpers",
"@jax//jaxlib/rocm:py_client_gpu",
"@jax//jaxlib:kernel_nanobind_helpers",
"@jax//jaxlib/gpu:gpu_plugin_extension",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:hip_runtime",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)

py_library(
name = "rocm_gpu_support",
deps = [
":_hybrid",
":_linalg",
":_prng",
":_rnn",
":_solver",
":_sparse",
":_triton",
],
)
Loading