Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions jax_rocm_plugin/.bazelrc
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# #############################################################################
# All default build options below. These apply to all build commands.
# #############################################################################
# TODO: Enable Bzlmod
common --noenable_bzlmod
# Bzlmod enabled - JAX and XLA from upstream via MODULE.bazel
common --enable_bzlmod

common:bzlmod --enable_bzlmod
common:bzlmod --noenable_workspace
common:bzlmod --check_direct_dependencies=error

# Make Bazel print out all options from rc files.
common --announce_rc
Expand Down Expand Up @@ -89,6 +93,7 @@ build:rocm_base --config=clang_local
build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm_base --repo_env TF_NEED_ROCM=1
build:rocm_base --repo_env ROCM_PATH="/opt/rocm"
build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201"

# Build with hipcc for ROCm and clang for the host.
Expand All @@ -99,7 +104,6 @@ build:rocm --copt=-Wno-gnu-offsetof-extensions
build:rocm --copt=-Qunused-arguments
build:rocm --action_env=TF_HIPCC_CLANG="1"


#############################################################################
# Configuration for running RBE builds and tests
#############################################################################
Expand Down
113 changes: 113 additions & 0 deletions jax_rocm_plugin/MODULE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bazel module definition for JAX ROCm Plugin.

Uses upstream JAX and XLA via bzlmod.
"""

module(
name = "jax_rocm_plugin",
version = "0.1.0",
)

##############################################################
# Core Bazel dependencies
##############################################################

bazel_dep(name = "bazel_skylib", version = "1.8.1")
bazel_dep(name = "rules_python", version = "1.6.1")

##############################################################
# JAX - upstream (from jax-ml/jax main)
# XLA and rules_ml_toolchain versions are synced from JAX's MODULE.bazel
##############################################################

bazel_dep(name = "jax")
archive_override(
module_name = "jax",
integrity = "sha256-vd8B9HSmEF2KZLbJhv5iINEikZAa6rnGD2vxXxMRGMM=",
strip_prefix = "jax-b19be1aa34969e312b7ec30abbae828ec35e1d12",
urls = ["https://github.com/jax-ml/jax/archive/b19be1aa34969e312b7ec30abbae828ec35e1d12.tar.gz"],
)

# XLA - synced from JAX's MODULE.bazel
bazel_dep(name = "xla")
archive_override(
module_name = "xla",
integrity = "sha256-FmmcrZgng83LaOHnuYpTmouqcQdCh5jbt0tnUQQdj9g=",
strip_prefix = "xla-ed953c01bb51f95a36abd907d1a64295feef16fc",
urls = ["https://github.com/openxla/xla/archive/ed953c01bb51f95a36abd907d1a64295feef16fc.tar.gz"],
)

# rules_ml_toolchain - synced from JAX's MODULE.bazel
bazel_dep(name = "rules_ml_toolchain")
archive_override(
module_name = "rules_ml_toolchain",
integrity = "sha256-s6lSdgWx8xRrtDEin6H4ui5EFI7S653N6rgAy+SnA/Y=",
strip_prefix = "rules_ml_toolchain-469be4eea388140207e2a31b8f4d4d612532fcde",
urls = ["https://github.com/google-ml-infra/rules_ml_toolchain/archive/469be4eea388140207e2a31b8f4d4d612532fcde.tar.gz"],
)

##############################################################
# Patches for transitive dependencies (from JAX's MODULE.bazel)
##############################################################

single_version_override(
module_name = "grpc",
patch_strip = 1,
patches = ["//third_party/grpc:grpc.patch"],
)

##############################################################
# Use extensions from XLA (via JAX) for ROCm configuration
##############################################################

rocm = use_extension("@xla//third_party/extensions:rocm_configure.bzl", "rocm_configure_ext")
use_repo(rocm, "local_config_rocm")

##############################################################
# Toolchain registration
##############################################################

register_toolchains("@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64")

##############################################################
# Python dependencies
##############################################################

python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.defaults(
python_version = "3.11",
python_version_env = "HERMETIC_PYTHON_VERSION",
)
python.toolchain(python_version = "3.11")
python.toolchain(python_version = "3.12")
python.toolchain(python_version = "3.13")
python.toolchain(python_version = "3.14")

pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")

[pip.parse(
hub_name = "rocm_plugin_pypi",
python_version = python_version,
requirements_lock = "//build:requirements_lock_{}.txt".format(python_version.replace(".", "_")),
) for python_version in [
"3.11",
"3.12",
"3.13",
"3.14",
]]

use_repo(pip, pypi = "rocm_plugin_pypi")
Loading
Loading