From a5d48b4f8bbc1d2ce1a7790cfc7d7ba5cb3ad5b1 Mon Sep 17 00:00:00 2001 From: Alexandros Theodoridis Date: Mon, 19 Jan 2026 16:49:44 +0000 Subject: [PATCH 1/4] My patch --- jaxlib/jax.bzl | 7 +++++-- jaxlib/tools/BUILD.bazel | 26 +++++++++++++++++++++++ jaxlib/tools/build_wheel.py | 42 ++++++++++++++++++++++++++++++------- 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 28f2367831a7..87dff699d2ba 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -192,10 +192,13 @@ def _gpu_test_deps(): "//jaxlib/rocm:gpu_only_test_deps", "//jax_plugins:gpu_plugin_only_test_deps", ], - "//jax:config_build_jaxlib_false": [ + "//jax:config_build_jaxlib_false": if_rocm_is_configured([ + "//jaxlib/tools:rocm_plugin_kernels_wheel", + "//jaxlib/tools:rocm_plugin_pjrt_wheel", + ]) + if_cuda_is_configured([ "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps", "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps", - ], + ]), "//jax:config_build_jaxlib_wheel": [ "//jaxlib/tools:jax_cuda_plugin_py_import", "//jaxlib/tools:jax_cuda_pjrt_py_import", diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 3a1a48736e17..8618f5e74af3 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -137,9 +137,25 @@ string_flag( py_binary( name = "build_wheel_tool", srcs = ["build_wheel.py"], + data = [ + "LICENSE.txt", + "//jaxlib", + "//jaxlib:README.md", + "//jaxlib:_ifrt_proxy.pyi", + "//jaxlib:_pathways.pyi", + "//jaxlib:init.py", + "//jaxlib:jaxlib_binaries", + "//jaxlib:setup.py", + "//jaxlib:weakref_lru_cache.pyi", + "//jaxlib/mlir/_mlir_libs:_triton_ext.pyi", + "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_enums.py", + "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_ops.py", + "@xla//xla/ffi/api:ffi", + ], main = "build_wheel.py", deps = [ ":build_utils", + "//jaxlib:jaxlib_files", "@bazel_tools//tools/python/runfiles", "@pypi//build", "@pypi//setuptools", @@ -451,6 +467,16 @@ py_import( wheel = ":jaxlib_wheel", ) +py_import( + name = "rocm_plugin_kernels_wheel", + wheel = "@pypi_jax_rocm7_plugin//:whl", +) + +py_import( + name = "rocm_plugin_pjrt_wheel", + wheel = "@pypi_jax_rocm7_pjrt//:whl", +) + py_import( name = "jax_cuda_plugin_py_import", wheel = ":jax_cuda{cuda}_plugin_wheel".format(cuda = cuda_major_version), diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 593dfcdaa708..a0d24d4f503f 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -59,6 +59,12 @@ parser.add_argument( "--srcs", help="source files for the wheel", action="append" ) +parser.add_argument( + "--build_from_external_workspace", + action="store_true", + help="Set when building from an external/umbrella workspace where jax is @jax. " + "When true, source prefix is 'jax/' instead of '__main__/'.", +) args = parser.parse_args() r = runfiles.Create() @@ -116,7 +122,7 @@ def verify_mac_libraries_dont_reference_chkstack( if not _is_mac(): return file_path = _get_file_path( - f"__main__/jaxlib/_jax.{pyext}", runfiles, wheel_sources_map + f"{jax_source_prefix}jaxlib/_jax.{pyext}", runfiles, wheel_sources_map ) nm = subprocess.run( ["nm", "-g", file_path], @@ -147,8 +153,16 @@ def write_setup_cfg(sources_path, cpu): def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): - """Assembles a source tree for the wheel in `wheel_sources_path`.""" - source_file_prefix = build_utils.get_source_file_prefix(wheel_sources) + """Assembles a source tree for the wheel in `wheel_sources_path`. In case of + build under @jax strip the prefix""" + # wheel_sources is a list of file paths, not a prefix. If provided, use empty prefix. + # Otherwise, determine prefix based on build_from_external_workspace flag. + if wheel_sources: + source_file_prefix = "" + elif args.build_from_external_workspace: + source_file_prefix = "jax/" + else: + source_file_prefix = "__main__/" # The wheel sources provided by the transitive rules might have different path # prefixes, so we need to create a map of paths relative to the root package # to the full paths. @@ -389,13 +403,25 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources): wheel_sources_map=wheel_sources_map, ) + # XLA FFI headers have different paths depending on build method: + # - wheel_sources_map: "xla/ffi/api/..." (mapped from external/xla/xla/ffi/api/...) + # - runfiles: "{source_file_prefix}jaxlib/include/xla/ffi/api/..." (copied to jaxlib/include/) + if args.build_from_external_workspace: + xla_ffi_files = [ + f"{source_file_prefix}jaxlib/include/xla/ffi/api/c_api.h", + f"{source_file_prefix}jaxlib/include/xla/ffi/api/api.h", + f"{source_file_prefix}jaxlib/include/xla/ffi/api/ffi.h", + ] + else: + xla_ffi_files = [ + "xla/ffi/api/c_api.h", + "xla/ffi/api/api.h", + "xla/ffi/api/ffi.h", + ] + copy_files( dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", - src_files=[ - "xla/ffi/api/c_api.h", - "xla/ffi/api/api.h", - "xla/ffi/api/ffi.h", - ], + src_files=xla_ffi_files, ) tmpdir = None From 49cb507ad844c43d54718981c81ae93cf9193cb2 Mon Sep 17 00:00:00 2001 From: Alexandros Theodoridis Date: Thu, 22 Jan 2026 16:39:19 +0000 Subject: [PATCH 2/4] Support wheels as build deps --- jaxlib/jax.bzl | 24 ++++++++++++++++-------- jaxlib/tools/BUILD.bazel | 4 ++-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 87dff699d2ba..fac64974a231 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -193,8 +193,8 @@ def _gpu_test_deps(): "//jax_plugins:gpu_plugin_only_test_deps", ], "//jax:config_build_jaxlib_false": if_rocm_is_configured([ - "//jaxlib/tools:rocm_plugin_kernels_wheel", - "//jaxlib/tools:rocm_plugin_pjrt_wheel", + "@jax_rocm_plugin//:pjrt.whl", + "@jax_rocm_plugin//:plugin.whl", ]) + if_cuda_is_configured([ "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps", "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps", @@ -307,6 +307,7 @@ def jax_multiplatform_test( tags = test_tags, main = main, exec_properties = tf_exec_properties({"tags": test_tags}), + legacy_create_init = 0, ) def jax_generate_backend_suites(backends = []): @@ -441,6 +442,9 @@ def _jax_wheel_impl(ctx): if ctx.attr.skip_gpu_kernels: args.add("--skip_gpu_kernels") + for extra_arg in ctx.attr.extra_args: + args.add(extra_arg) + srcs = [] for src in ctx.attr.source_files: for f in src.files.to_list(): @@ -486,6 +490,7 @@ _jax_wheel = rule( "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), "py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")), + "extra_args": attr.string_list(default = []), }, implementation = _jax_wheel_impl, executable = False, @@ -501,7 +506,8 @@ def jax_wheel( enable_cuda = False, enable_rocm = False, platform_version = "", - source_files = []): + source_files = [], + extra_args = []): """Create jax artifact wheels. Common artifact attributes are grouped within a single macro. @@ -517,6 +523,7 @@ def jax_wheel( enable_rocm: whether to build a rocm wheel platform_version: the cuda version to use for the wheel source_files: the source files to include in the wheel + extra_args: additional arguments to pass to the wheel binary Returns: A wheel file or a wheel directory. @@ -543,13 +550,14 @@ def jax_wheel( }), # TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0. cpu = select({ - "//jaxlib/tools:macos_arm64": "arm64", - "//jaxlib/tools:macos_x86_64": "x86_64", - "//jaxlib/tools:win_amd64": "AMD64", - "//jaxlib/tools:linux_aarch64": "aarch64", - "//jaxlib/tools:linux_x86_64": "x86_64", + "@jax//jaxlib/tools:macos_arm64": "arm64", + "@jax//jaxlib/tools:macos_x86_64": "x86_64", + "@jax//jaxlib/tools:win_amd64": "AMD64", + "@jax//jaxlib/tools:linux_aarch64": "aarch64", + "@jax//jaxlib/tools:linux_x86_64": "x86_64", }), source_files = source_files, + extra_args = extra_args, ) def jax_source_package( diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 8618f5e74af3..5da6c7edb7d7 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -469,12 +469,12 @@ py_import( py_import( name = "rocm_plugin_kernels_wheel", - wheel = "@pypi_jax_rocm7_plugin//:whl", + wheel = "@jax_rocm_plugin//:jax_rocm7_plugin_wheel", ) py_import( name = "rocm_plugin_pjrt_wheel", - wheel = "@pypi_jax_rocm7_pjrt//:whl", + wheel = "@jax_rocm_plugin//:jax_rocm7_pjrt_wheel", ) py_import( From 79f790d036f71ce0d8677e5158e2206929933f80 Mon Sep 17 00:00:00 2001 From: Alexandros Theodoridis Date: Fri, 23 Jan 2026 11:10:35 +0000 Subject: [PATCH 3/4] Make deps parametrized --- WORKSPACE | 4 ++ jax/BUILD | 8 +++ jaxlib/jax.bzl | 8 ++- third_party/plugins/BUILD.bazel | 15 ++++ third_party/plugins/workspace.bzl | 113 ++++++++++++++++++++++++++++++ 5 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 third_party/plugins/BUILD.bazel create mode 100644 third_party/plugins/workspace.bzl diff --git a/WORKSPACE b/WORKSPACE index 6b3d0e2aa010..dc695b2ee8f7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -99,6 +99,10 @@ load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() +load("//third_party/plugins:workspace.bzl", "plugin_wheel_deps_repository") + +plugin_wheel_deps_repository(name = "plugin_wheel_deps") + load("//:test_shard_count.bzl", "test_shard_count_repository") test_shard_count_repository( diff --git a/jax/BUILD b/jax/BUILD index 0aaed7b61d2c..2a7a368ed130 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -47,6 +47,7 @@ string_flag( "true", "false", "wheel", + "plugin_wheels", ], ) @@ -71,6 +72,13 @@ config_setting( }, ) +config_setting( + name = "config_build_jaxlib_plugin_wheels", + flag_values = { + ":build_jaxlib": "plugin_wheels", + }, +) + # The flag controls whether jax should be built by Bazel. # If ":build_jax=true", then jax will be built. # If ":build_jax=false", then jax is not built. It is assumed that the pre-built jax wheel diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index fac64974a231..5a7f27fb6246 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -20,6 +20,7 @@ load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION") load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") +load("@plugin_wheel_deps//:deps.bzl", "PLUGIN_WHEEL_DEPS") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") load("@rules_python//python:defs.bzl", "py_library", "py_test") @@ -173,6 +174,7 @@ def if_building_jaxlib( return select({ "//jax:config_build_jaxlib_true": if_building, "//jax:config_build_jaxlib_false": if_not_building, + "//jax:config_build_jaxlib_plugin_wheels": [], "//jax:config_build_jaxlib_wheel": [], }) @@ -181,6 +183,7 @@ def _cpu_test_deps(): return select({ "//jax:config_build_jaxlib_true": [], "//jax:config_build_jaxlib_false": ["@pypi//jaxlib"], + "//jax:config_build_jaxlib_plugin_wheels": [], "//jax:config_build_jaxlib_wheel": ["//jaxlib/tools:jaxlib_py_import"], }) @@ -192,9 +195,10 @@ def _gpu_test_deps(): "//jaxlib/rocm:gpu_only_test_deps", "//jax_plugins:gpu_plugin_only_test_deps", ], + "//jax:config_build_jaxlib_plugin_wheels": PLUGIN_WHEEL_DEPS, "//jax:config_build_jaxlib_false": if_rocm_is_configured([ - "@jax_rocm_plugin//:pjrt.whl", - "@jax_rocm_plugin//:plugin.whl", + "//jaxlib/tools:rocm_plugin_kernels_wheel", + "//jaxlib/tools:rocm_plugin_pjrt_wheel", ]) + if_cuda_is_configured([ "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps", "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps", diff --git a/third_party/plugins/BUILD.bazel b/third_party/plugins/BUILD.bazel new file mode 100644 index 000000000000..c200f4efb42b --- /dev/null +++ b/third_party/plugins/BUILD.bazel @@ -0,0 +1,15 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) # Apache 2.0 diff --git a/third_party/plugins/workspace.bzl b/third_party/plugins/workspace.bzl new file mode 100644 index 000000000000..a3e10521eb49 --- /dev/null +++ b/third_party/plugins/workspace.bzl @@ -0,0 +1,113 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Repository rule to parse plugin wheel dependencies from environment variable. + +This module provides a repository rule that reads the PLUGIN_WHEEL_DEPS +environment variable and generates a .bzl file containing a list of targets +that can be used as dependencies. + +The environment variable should contain a comma-separated list of targets, e.g.: + PLUGIN_WHEEL_DEPS=@jax_rocm_plugin//:plugin.whl,@jax_rocm_plugin//:pjrt.whl + +Usage: + # In WORKSPACE: + load("//third_party/plugins:workspace.bzl", "plugin_wheel_deps_repository") + plugin_wheel_deps_repository(name = "plugin_wheel_deps") + + # In BUILD files: + load("@plugin_wheel_deps//:deps.bzl", "PLUGIN_WHEEL_DEPS") + py_library( + name = "my_target", + deps = PLUGIN_WHEEL_DEPS, + ) +""" + +def _plugin_wheel_deps_repository_impl(repository_ctx): + """Implementation of the plugin_wheel_deps_repository rule. + + Reads the PLUGIN_WHEEL_DEPS environment variable and generates a deps.bzl + file containing a list of targets. + + Args: + repository_ctx: The repository context. + """ + env_var_name = repository_ctx.attr.env_var + deps_env = repository_ctx.os.environ.get(env_var_name, "") + + # Parse the comma-separated list of targets + deps_list = [] + if deps_env: + for dep in deps_env.split(","): + dep = dep.strip() + if dep: + deps_list.append(dep) + + # Generate the deps.bzl file with the list of targets + deps_content = """\ +# Auto-generated file. Do not edit. +# Generated from environment variable: {env_var} + +# List of plugin wheel dependency targets. +# These targets can be used as dependencies in BUILD files. +PLUGIN_WHEEL_DEPS = [ +{deps} +] +""".format( + env_var = env_var_name, + deps = "\n".join([' "{}",'.format(dep) for dep in deps_list]), + ) + + repository_ctx.file("deps.bzl", deps_content) + repository_ctx.file("BUILD.bazel", "# Auto-generated BUILD file\n") + +plugin_wheel_deps_repository = repository_rule( + implementation = _plugin_wheel_deps_repository_impl, + attrs = { + "env_var": attr.string( + default = "PLUGIN_WHEEL_DEPS", + doc = "The name of the environment variable containing the comma-separated list of targets.", + ), + }, + environ = ["PLUGIN_WHEEL_DEPS"], + doc = """Repository rule to parse plugin wheel dependencies from an environment variable. + +Reads a comma-separated list of Bazel targets from the specified environment +variable (default: PLUGIN_WHEEL_DEPS) and generates a deps.bzl file that +exports a PLUGIN_WHEEL_DEPS list containing these targets. + +Example: + # Set the environment variable: + export PLUGIN_WHEEL_DEPS="@jax_rocm_plugin//:plugin.whl,@jax_rocm_plugin//:pjrt.whl" + + # In WORKSPACE: + load("//third_party/plugins:workspace.bzl", "plugin_wheel_deps_repository") + plugin_wheel_deps_repository(name = "plugin_wheel_deps") + + # In BUILD files: + load("@plugin_wheel_deps//:deps.bzl", "PLUGIN_WHEEL_DEPS") + py_library( + name = "my_target", + deps = PLUGIN_WHEEL_DEPS, + ) +""", +) + +def repo(name = "plugin_wheel_deps"): + """Convenience function to create the plugin wheel deps repository. + + Args: + name: The name of the repository (default: "plugin_wheel_deps"). + """ + plugin_wheel_deps_repository(name = name) From 42a02dc0db18cbed651b955d32b8c2bd9274cdfb Mon Sep 17 00:00:00 2001 From: Alexandros Theodoridis Date: Wed, 4 Feb 2026 15:23:59 +0000 Subject: [PATCH 4/4] Support --- jax/BUILD | 7 ------- jaxlib/jax.bzl | 4 ---- jaxlib/rocm/BUILD | 11 ++++++----- jaxlib/tools/BUILD.bazel | 15 --------------- 4 files changed, 6 insertions(+), 31 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 2a7a368ed130..4630f36c02c1 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -72,13 +72,6 @@ config_setting( }, ) -config_setting( - name = "config_build_jaxlib_plugin_wheels", - flag_values = { - ":build_jaxlib": "plugin_wheels", - }, -) - # The flag controls whether jax should be built by Bazel. # If ":build_jax=true", then jax will be built. # If ":build_jax=false", then jax is not built. It is assumed that the pre-built jax wheel diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 5a7f27fb6246..55b97480822d 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -20,7 +20,6 @@ load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION") load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "WHEEL_VERSION_SUFFIX") load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") -load("@plugin_wheel_deps//:deps.bzl", "PLUGIN_WHEEL_DEPS") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION", "HERMETIC_PYTHON_VERSION_KIND") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") load("@rules_python//python:defs.bzl", "py_library", "py_test") @@ -174,7 +173,6 @@ def if_building_jaxlib( return select({ "//jax:config_build_jaxlib_true": if_building, "//jax:config_build_jaxlib_false": if_not_building, - "//jax:config_build_jaxlib_plugin_wheels": [], "//jax:config_build_jaxlib_wheel": [], }) @@ -183,7 +181,6 @@ def _cpu_test_deps(): return select({ "//jax:config_build_jaxlib_true": [], "//jax:config_build_jaxlib_false": ["@pypi//jaxlib"], - "//jax:config_build_jaxlib_plugin_wheels": [], "//jax:config_build_jaxlib_wheel": ["//jaxlib/tools:jaxlib_py_import"], }) @@ -195,7 +192,6 @@ def _gpu_test_deps(): "//jaxlib/rocm:gpu_only_test_deps", "//jax_plugins:gpu_plugin_only_test_deps", ], - "//jax:config_build_jaxlib_plugin_wheels": PLUGIN_WHEEL_DEPS, "//jax:config_build_jaxlib_false": if_rocm_is_configured([ "//jaxlib/tools:rocm_plugin_kernels_wheel", "//jaxlib/tools:rocm_plugin_pjrt_wheel", diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index 1b8e8dd1e64b..1a4b436ae25c 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -14,6 +14,7 @@ # AMD HIP kernels +load("@plugin_wheel_deps//:deps.bzl", "PLUGIN_WHEEL_DEPS") load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_python//python:defs.bzl", "py_library") load( @@ -391,6 +392,10 @@ nanobind_extension( "-fno-strict-aliasing", ], features = ["-use_header_modules"], + linkopts = [ + "-L/opt/rocm/lib", + "-lamdhip64", + ], module_name = "_hybrid", deps = [ ":hip_gpu_kernel_helpers", @@ -404,10 +409,6 @@ nanobind_extension( "@nanobind", "@xla//xla/ffi/api:ffi", ], - linkopts = [ - "-L/opt/rocm/lib", - "-lamdhip64", - ], ) cc_library( @@ -551,5 +552,5 @@ py_library( deps = if_rocm_is_configured([ ":rocm_gpu_support", ":rocm_plugin_extension", - ]), + ] + PLUGIN_WHEEL_DEPS), ) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 5da6c7edb7d7..f8347b64fbe8 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -137,21 +137,6 @@ string_flag( py_binary( name = "build_wheel_tool", srcs = ["build_wheel.py"], - data = [ - "LICENSE.txt", - "//jaxlib", - "//jaxlib:README.md", - "//jaxlib:_ifrt_proxy.pyi", - "//jaxlib:_pathways.pyi", - "//jaxlib:init.py", - "//jaxlib:jaxlib_binaries", - "//jaxlib:setup.py", - "//jaxlib:weakref_lru_cache.pyi", - "//jaxlib/mlir/_mlir_libs:_triton_ext.pyi", - "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_enums.py", - "//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_ops.py", - "@xla//xla/ffi/api:ffi", - ], main = "build_wheel.py", deps = [ ":build_utils",