Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")

flatbuffers()

load("//third_party/plugins:workspace.bzl", "plugin_wheel_deps_repository")

plugin_wheel_deps_repository(name = "plugin_wheel_deps")

load("//:test_shard_count.bzl", "test_shard_count_repository")

test_shard_count_repository(
Expand Down
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ string_flag(
"true",
"false",
"wheel",
"plugin_wheels",
],
)

Expand Down
27 changes: 19 additions & 8 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,13 @@ def _gpu_test_deps():
"//jaxlib/rocm:gpu_only_test_deps",
"//jax_plugins:gpu_plugin_only_test_deps",
],
"//jax:config_build_jaxlib_false": [
"//jax:config_build_jaxlib_false": if_rocm_is_configured([
"//jaxlib/tools:rocm_plugin_kernels_wheel",
"//jaxlib/tools:rocm_plugin_pjrt_wheel",
]) + if_cuda_is_configured([
"//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps",
"//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps",
],
]),
"//jax:config_build_jaxlib_wheel": [
"//jaxlib/tools:jax_cuda_plugin_py_import",
"//jaxlib/tools:jax_cuda_pjrt_py_import",
Expand Down Expand Up @@ -304,6 +307,7 @@ def jax_multiplatform_test(
tags = test_tags,
main = main,
exec_properties = tf_exec_properties({"tags": test_tags}),
legacy_create_init = 0,
)

def jax_generate_backend_suites(backends = []):
Expand Down Expand Up @@ -438,6 +442,9 @@ def _jax_wheel_impl(ctx):
if ctx.attr.skip_gpu_kernels:
args.add("--skip_gpu_kernels")

for extra_arg in ctx.attr.extra_args:
args.add(extra_arg)

srcs = []
for src in ctx.attr.source_files:
for f in src.files.to_list():
Expand Down Expand Up @@ -483,6 +490,7 @@ _jax_wheel = rule(
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
"py_freethreaded": attr.label(default = Label("@rules_python//python/config_settings:py_freethreaded")),
"extra_args": attr.string_list(default = []),
},
implementation = _jax_wheel_impl,
executable = False,
Expand All @@ -498,7 +506,8 @@ def jax_wheel(
enable_cuda = False,
enable_rocm = False,
platform_version = "",
source_files = []):
source_files = [],
extra_args = []):
"""Create jax artifact wheels.

Common artifact attributes are grouped within a single macro.
Expand All @@ -514,6 +523,7 @@ def jax_wheel(
enable_rocm: whether to build a rocm wheel
platform_version: the cuda version to use for the wheel
source_files: the source files to include in the wheel
extra_args: additional arguments to pass to the wheel binary

Returns:
A wheel file or a wheel directory.
Expand All @@ -540,13 +550,14 @@ def jax_wheel(
}),
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
cpu = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:macos_x86_64": "x86_64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:linux_aarch64": "aarch64",
"//jaxlib/tools:linux_x86_64": "x86_64",
"@jax//jaxlib/tools:macos_arm64": "arm64",
"@jax//jaxlib/tools:macos_x86_64": "x86_64",
"@jax//jaxlib/tools:win_amd64": "AMD64",
"@jax//jaxlib/tools:linux_aarch64": "aarch64",
"@jax//jaxlib/tools:linux_x86_64": "x86_64",
}),
source_files = source_files,
extra_args = extra_args,
)

def jax_source_package(
Expand Down
11 changes: 6 additions & 5 deletions jaxlib/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# AMD HIP kernels

load("@plugin_wheel_deps//:deps.bzl", "PLUGIN_WHEEL_DEPS")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_python//python:defs.bzl", "py_library")
load(
Expand Down Expand Up @@ -391,6 +392,10 @@ nanobind_extension(
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
linkopts = [
"-L/opt/rocm/lib",
"-lamdhip64",
],
module_name = "_hybrid",
deps = [
":hip_gpu_kernel_helpers",
Expand All @@ -404,10 +409,6 @@ nanobind_extension(
"@nanobind",
"@xla//xla/ffi/api:ffi",
],
linkopts = [
"-L/opt/rocm/lib",
"-lamdhip64",
],
)

cc_library(
Expand Down Expand Up @@ -551,5 +552,5 @@ py_library(
deps = if_rocm_is_configured([
":rocm_gpu_support",
":rocm_plugin_extension",
]),
] + PLUGIN_WHEEL_DEPS),
)
11 changes: 11 additions & 0 deletions jaxlib/tools/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ py_binary(
main = "build_wheel.py",
deps = [
":build_utils",
"//jaxlib:jaxlib_files",
"@bazel_tools//tools/python/runfiles",
"@pypi//build",
"@pypi//setuptools",
Expand Down Expand Up @@ -451,6 +452,16 @@ py_import(
wheel = ":jaxlib_wheel",
)

py_import(
name = "rocm_plugin_kernels_wheel",
wheel = "@jax_rocm_plugin//:jax_rocm7_plugin_wheel",
)

py_import(
name = "rocm_plugin_pjrt_wheel",
wheel = "@jax_rocm_plugin//:jax_rocm7_pjrt_wheel",
)

py_import(
name = "jax_cuda_plugin_py_import",
wheel = ":jax_cuda{cuda}_plugin_wheel".format(cuda = cuda_major_version),
Expand Down
42 changes: 34 additions & 8 deletions jaxlib/tools/build_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@
parser.add_argument(
"--srcs", help="source files for the wheel", action="append"
)
parser.add_argument(
"--build_from_external_workspace",
action="store_true",
help="Set when building from an external/umbrella workspace where jax is @jax. "
"When true, source prefix is 'jax/' instead of '__main__/'.",
)
args = parser.parse_args()

r = runfiles.Create()
Expand Down Expand Up @@ -116,7 +122,7 @@ def verify_mac_libraries_dont_reference_chkstack(
if not _is_mac():
return
file_path = _get_file_path(
f"__main__/jaxlib/_jax.{pyext}", runfiles, wheel_sources_map
f"{jax_source_prefix}jaxlib/_jax.{pyext}", runfiles, wheel_sources_map
)
nm = subprocess.run(
["nm", "-g", file_path],
Expand Down Expand Up @@ -147,8 +153,16 @@ def write_setup_cfg(sources_path, cpu):


def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
"""Assembles a source tree for the wheel in `wheel_sources_path`."""
source_file_prefix = build_utils.get_source_file_prefix(wheel_sources)
"""Assembles a source tree for the wheel in `wheel_sources_path`. In case of
build under @jax strip the prefix"""
# wheel_sources is a list of file paths, not a prefix. If provided, use empty prefix.
# Otherwise, determine prefix based on build_from_external_workspace flag.
if wheel_sources:
source_file_prefix = ""
elif args.build_from_external_workspace:
source_file_prefix = "jax/"
else:
source_file_prefix = "__main__/"
# The wheel sources provided by the transitive rules might have different path
# prefixes, so we need to create a map of paths relative to the root package
# to the full paths.
Expand Down Expand Up @@ -389,13 +403,25 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
wheel_sources_map=wheel_sources_map,
)

# XLA FFI headers have different paths depending on build method:
# - wheel_sources_map: "xla/ffi/api/..." (mapped from external/xla/xla/ffi/api/...)
# - runfiles: "{source_file_prefix}jaxlib/include/xla/ffi/api/..." (copied to jaxlib/include/)
if args.build_from_external_workspace:
xla_ffi_files = [
f"{source_file_prefix}jaxlib/include/xla/ffi/api/c_api.h",
f"{source_file_prefix}jaxlib/include/xla/ffi/api/api.h",
f"{source_file_prefix}jaxlib/include/xla/ffi/api/ffi.h",
]
else:
xla_ffi_files = [
"xla/ffi/api/c_api.h",
"xla/ffi/api/api.h",
"xla/ffi/api/ffi.h",
]

copy_files(
dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api",
src_files=[
"xla/ffi/api/c_api.h",
"xla/ffi/api/api.h",
"xla/ffi/api/ffi.h",
],
src_files=xla_ffi_files,
)

tmpdir = None
Expand Down
15 changes: 15 additions & 0 deletions third_party/plugins/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"]) # Apache 2.0
113 changes: 113 additions & 0 deletions third_party/plugins/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Repository rule to parse plugin wheel dependencies from environment variable.

This module provides a repository rule that reads the PLUGIN_WHEEL_DEPS
environment variable and generates a .bzl file containing a list of targets
that can be used as dependencies.

The environment variable should contain a comma-separated list of targets, e.g.:
PLUGIN_WHEEL_DEPS=@jax_rocm_plugin//:plugin.whl,@jax_rocm_plugin//:pjrt.whl

Usage:
# In WORKSPACE:
load("//third_party/plugins:workspace.bzl", "plugin_wheel_deps_repository")
plugin_wheel_deps_repository(name = "plugin_wheel_deps")

# In BUILD files:
load("@plugin_wheel_deps//:deps.bzl", "PLUGIN_WHEEL_DEPS")
py_library(
name = "my_target",
deps = PLUGIN_WHEEL_DEPS,
)
"""

def _plugin_wheel_deps_repository_impl(repository_ctx):
"""Implementation of the plugin_wheel_deps_repository rule.

Reads the PLUGIN_WHEEL_DEPS environment variable and generates a deps.bzl
file containing a list of targets.

Args:
repository_ctx: The repository context.
"""
env_var_name = repository_ctx.attr.env_var
deps_env = repository_ctx.os.environ.get(env_var_name, "")

# Parse the comma-separated list of targets
deps_list = []
if deps_env:
for dep in deps_env.split(","):
dep = dep.strip()
if dep:
deps_list.append(dep)

# Generate the deps.bzl file with the list of targets
deps_content = """\
# Auto-generated file. Do not edit.
# Generated from environment variable: {env_var}

# List of plugin wheel dependency targets.
# These targets can be used as dependencies in BUILD files.
PLUGIN_WHEEL_DEPS = [
{deps}
]
""".format(
env_var = env_var_name,
deps = "\n".join([' "{}",'.format(dep) for dep in deps_list]),
)

repository_ctx.file("deps.bzl", deps_content)
repository_ctx.file("BUILD.bazel", "# Auto-generated BUILD file\n")

plugin_wheel_deps_repository = repository_rule(
implementation = _plugin_wheel_deps_repository_impl,
attrs = {
"env_var": attr.string(
default = "PLUGIN_WHEEL_DEPS",
doc = "The name of the environment variable containing the comma-separated list of targets.",
),
},
environ = ["PLUGIN_WHEEL_DEPS"],
doc = """Repository rule to parse plugin wheel dependencies from an environment variable.

Reads a comma-separated list of Bazel targets from the specified environment
variable (default: PLUGIN_WHEEL_DEPS) and generates a deps.bzl file that
exports a PLUGIN_WHEEL_DEPS list containing these targets.

Example:
# Set the environment variable:
export PLUGIN_WHEEL_DEPS="@jax_rocm_plugin//:plugin.whl,@jax_rocm_plugin//:pjrt.whl"

# In WORKSPACE:
load("//third_party/plugins:workspace.bzl", "plugin_wheel_deps_repository")
plugin_wheel_deps_repository(name = "plugin_wheel_deps")

# In BUILD files:
load("@plugin_wheel_deps//:deps.bzl", "PLUGIN_WHEEL_DEPS")
py_library(
name = "my_target",
deps = PLUGIN_WHEEL_DEPS,
)
""",
)

def repo(name = "plugin_wheel_deps"):
"""Convenience function to create the plugin wheel deps repository.

Args:
name: The name of the repository (default: "plugin_wheel_deps").
"""
plugin_wheel_deps_repository(name = name)