From 76381664902493ff22a197856da09f9c5e24d4c3 Mon Sep 17 00:00:00 2001 From: Charles Hofer Date: Wed, 21 Jan 2026 19:06:24 +0000 Subject: [PATCH] Add script for running bazel tests on ROCm --- .bazelrc | 26 ++++ .../download-jax-rocm-wheels/action.yml | 81 ++++++++++++ .github/workflows/bazel_rocm.yml | 120 ++++++++++++++++++ .../workflows/wheel_tests_nightly_release.yml | 24 ++++ WORKSPACE | 2 + ci/envs/default.env | 2 +- ci/run_bazel_test_rocm_rbe.sh | 76 +++++++++++ jaxlib/jax.bzl | 9 +- jaxlib/tools/BUILD.bazel | 19 +++ 9 files changed, 355 insertions(+), 4 deletions(-) create mode 100644 .github/actions/download-jax-rocm-wheels/action.yml create mode 100644 .github/workflows/bazel_rocm.yml create mode 100755 ci/run_bazel_test_rocm_rbe.sh diff --git a/.bazelrc b/.bazelrc index 371f4b78a376..b33c6b4a5007 100644 --- a/.bazelrc +++ b/.bazelrc @@ -452,6 +452,32 @@ common:rbe_windows_amd64 --nobuild_python_zip common:rbe_windows_amd64 --config=ci_windows_amd64 +# RBE configs for ROCm +build:rocm_rbe --tls_client_certificate="ci-cert.crt" +build:rocm_rbe --tls_client_key="ci-cert.key" + +build:rocm_rbe --bes_backend="grpcs://wardite.cluster.engflow.com" +build:rocm_rbe --bes_results_url="https://wardite.cluster.engflow.com/invocation/" +build:rocm_rbe --remote_executor="grpcs://wardite.cluster.engflow.com" +build:rocm_rbe --remote_cache="grpcs://wardite.cluster.engflow.com" +build:rocm_rbe --host_platform="//platform/linux:manylinux" +build:rocm_rbe --extra_execution_platforms="//platform/linux:manylinux" +build:rocm_rbe --platforms="//platform/linux:manylinux" +build:rocm_rbe --bes_timeout=600s +build:rocm_rbe --spawn_strategy=local +build:rocm_rbe --grpc_keepalive_time=30s +build:rocm_rbe --repo_env=REMOTE_GPU_TESTING=1 + +test:rocm_rbe --host_platform="//platform/linux:ubuntu_gpu" +test:rocm_rbe --extra_execution_platforms="//platform/linux:ubuntu_gpu" +test:rocm_rbe --platforms="//platform/linux:ubuntu_gpu" +test:rocm_rbe --remote_timeout=3600 +test:rocm_rbe 
--jobs=200 +test:rocm_rbe --test_sharding_strategy=disabled +test:rocm_rbe --strategy=TestRunner=remote,local +test:rocm_rbe --worker_sandboxing=true +test:rocm_rbe --repo_env=REMOTE_GPU_TESTING=1 + # ############################################################################# # Cross-compile config options below. Native RBE support does not exist for # Linux Aarch64 and Mac x86. So, we use a cross-compile toolchain to build diff --git a/.github/actions/download-jax-rocm-wheels/action.yml b/.github/actions/download-jax-rocm-wheels/action.yml new file mode 100644 index 000000000000..4a9eb87c06b3 --- /dev/null +++ b/.github/actions/download-jax-rocm-wheels/action.yml @@ -0,0 +1,81 @@ +# Composite action to download the jax, jaxlib, and the ROCM plugin wheels +name: Download JAX ROCM wheels + +inputs: + python: + description: "Which python version should the artifact be downloaded for?" + type: string + required: true + rocm-version: + description: "Which rocm version should the artifact be downloaded for?" + type: string + default: "7" + skip-download-jaxlib-and-rocm-plugins-from-gh: + description: "Whether to skip downloading the jaxlib and rocm plugins from GitHub (e.g. for testing a jax only release)" + default: '0' + type: string + gh_download_uri: + description: "GitHub release URL prefix from where the artifacts should be downloaded" + default: 'https://github.com/ROCm/rocm-jax/releases/download' + #default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string +permissions: {} +runs: + using: "composite" + + steps: + # Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow. + - name: Set env vars for use in artifact download URL + shell: bash + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Get the major and minor version of Python.
+ # E.g. if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311 + # E.g. if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t + python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.') + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + # Python wheels follow a naming convention: standard wheels use the pattern + # `*-cpXY-cpXY-*`, while free-threaded wheels use + # `*-cpXY-cpXYt-*` (XY = major/minor digits, e.g. 311). + echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV + + # Get the ROCM major version only + full_rocm_version="${{ inputs.rocm-version }}" + echo "JAXCI_ROCM_VERSION=${full_rocm_version%%.*}" >> $GITHUB_ENV + - name: Download wheels + shell: bash + id: download-wheel-artifacts + # Set continue-on-error to true to prevent actions from failing the workflow if this step + # fails. Instead, we verify the outcome in the next step so that we can print a more + # informative error message. + continue-on-error: true + run: | + mkdir -p $(pwd)/dist + if [[ "${{ inputs.download-jax-from-gcs }}" == "1" ]]; then + gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/ + else + echo "JAX wheel won't be downloaded, only jaxlib pre-built wheel is tested." + fi + + # Do not download the jaxlib and ROCM plugin artifacts if we are testing a jax only + # release. + if [[ "${{ inputs.skip-download-jaxlib-and-rocm-plugins-from-gh }}" == "1" ]]; then + echo "JAX only release. Only downloading the jax wheel from the release bucket."
+ else + wget -P $(pwd)/dist/ "${{ inputs.gh_download_uri }}/jaxlib-0.8.2+rocm${{ inputs.rocm-version }}-${PYTHON_MAJOR_MINOR}manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" + wget -P $(pwd)/dist/ "${{ inputs.gh_download_uri }}/jax_rocm${JAXCI_ROCM_VERSION}_pjrt-0.8.2+rocm${{ inputs.rocm-version }}-py3-none-manylinux_2_28_x86_64.whl" + wget -P $(pwd)/dist/ "${{ inputs.gh_download_uri }}/jax_rocm${JAXCI_ROCM_VERSION}_plugin-0.8.2+rocm${{ inputs.rocm-version }}-${PYTHON_MAJOR_MINOR}manylinux_2_28_x86_64.whl" + fi + - name: Skip the test run if the wheel artifacts were not downloaded successfully + shell: bash + if: steps.download-wheel-artifacts.outcome == 'failure' + run: | + echo "Failed to download wheel artifacts. Please check if the wheels were" + echo "built successfully by the artifact build jobs and are available in the GCS bucket if" + echo "downloading from GCS." + echo "Skipping the test run." + exit 1 diff --git a/.github/workflows/bazel_rocm.yml b/.github/workflows/bazel_rocm.yml new file mode 100644 index 000000000000..89f35cce6ce2 --- /dev/null +++ b/.github/workflows/bazel_rocm.yml @@ -0,0 +1,120 @@ +# CI - Bazel ROCM tests +# +# This workflow runs the ROCM tests with Bazel. It can only be triggered by other workflows via +# `workflow_call`. It is used by the `CI - Bazel ROCM tests (RBE)`,`CI - Wheel Tests (Continuous)` +# and `CI - Wheel Tests (Nightly/Release)` workflows to run the Bazel ROCM tests. +# +# It consists of the following job: +# run-tests: +# - Downloads the jaxlib and ROCM artifacts from GitHub if build_jaxlib is `false`. +# Otherwise, the artifacts are built from source. +# - Downloads the jax artifact from a GCS bucket if build_jax is `false`. +# Otherwise, the artifact is built from source.
+# - If `run_multiaccelerator_tests` is `false`, executes the `run_bazel_test_rocm_rbe.sh` script, +# which performs the following actions: +# - `build_jaxlib=wheel`: Runs the Bazel ROCM tests with py_import dependencies. +# - `build_jaxlib=false`: Runs the Bazel ROCM tests with downloaded wheel dependencies. +# - `build_jaxlib=true`: Runs the Bazel ROCM tests with individual Bazel target dependencies. +# - If `run_multiaccelerator_tests` is `true`, executes the `run_bazel_test_rocm_non_rbe.sh` +# script, which performs the following actions: +# - `build_jaxlib=wheel`: Runs the Bazel ROCM tests with py_import dependencies. +# - `build_jaxlib=false`: Runs the Bazel ROCM tests with downloaded wheel dependencies. + +name: CI - Bazel ROCM tests + +on: + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-n4-16" + python: + description: "Which python version to test?" + type: string + default: "3.12" + rocm-version: + description: "Which ROCM version to test?" + type: string + default: "7" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + default: "0" + download-jax-from-gcs: + description: "Whether to download the jax wheel from GH" + default: '1' + type: string + skip-download-jaxlib-and-rocm-plugins-from-gh: + description: "Whether to skip downloading the jaxlib and rocm plugins from GH (e.g for testing a jax only release)" + default: '0' + type: string + gh_download_uri: + description: "GH location URI from where the artifacts should be downloaded" + default: 'https://github.com/ROCm/rocm-jax/releases/download' + type: string + build_jaxlib: + description: 'Should jaxlib be built from source?' + required: true + type: string + build_jax: + description: 'Should jax be built from source?'
+ required: true + type: string + write_to_bazel_remote_cache: + description: 'Whether to enable writing to the Bazel remote cache bucket' + required: false + default: '0' + type: string + run_multiaccelerator_tests: + description: 'Whether to run multi-accelerator tests' + required: false + default: 'false' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: string + default: 'no' +permissions: {} +jobs: + run-tests: + defaults: + run: + # Explicitly set the shell to bash + shell: bash + runs-on: ${{ inputs.runner }} + container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }} + JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }} + JAXCI_ROCM_VERSION: ${{ inputs.rocm-version }} + JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: ${{ inputs.write_to_bazel_remote_cache }} + JAXCI_BUILD_JAX: ${{ inputs.build_jax }} + JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }} +# Begin Presubmit Naming Check - name modification requires internal check to be updated + name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || + (contains(inputs.runner, 'linux-arm64') && 'linux arm64') || + (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, jaxlib=${{ inputs.jaxlib-version }}, ROCM=${{ inputs.rocm-version }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}" +# End Presubmit Naming Check github-cuda-presubmits + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Download JAX ROCM wheels + if: inputs.build_jaxlib == 'false' + uses: ./.github/actions/download-jax-rocm-wheels + with: + python: ${{ inputs.python }} + rocm-version: ${{ inputs.rocm-version }} + download-jax-from-gh: ${{ inputs.download-jax-from-gcs }} + skip-download-jaxlib-and-rocm-plugins-from-gh: ${{
inputs.skip-download-jaxlib-and-rocm-plugins-from-gh }} + gh_download_uri: ${{ inputs.gh_download_uri }} + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: "Bazel ROCM tests with build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}" + timeout-minutes: 60 + run: ${{ ((inputs.run_multiaccelerator_tests == 'false') && './ci/run_bazel_test_rocm_rbe.sh') || './ci/run_bazel_test_rocm_non_rbe.sh' }} diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 07592cd0e70c..d175c029147c 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -168,6 +168,30 @@ jobs: write_to_bazel_remote_cache: 1 run_multiaccelerator_tests: "true" + run-bazel-test-rocm: + uses: ./.github/workflows/bazel_rocm.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix strategy of our internal CI jobs + # that build the wheels.
+ runner: ["linux-x86-g2-48-l4-4gpu"] + python: ["3.11", "3.12", "3.13", "3.13-nogil", "3.14", "3.14-nogil"] + rocm-version: [7] + enable-x64: [0] + name: "Bazel ROCM Non-RBE with ${{ format('{0}', 'build_jaxlib=false') }}" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm-version }} + enable-x64: ${{ matrix.enable-x64 }} + halt-for-connection: ${{inputs.halt-for-connection}} + build_jaxlib: "false" + build_jax: "false" + jaxlib-version: "head" + write_to_bazel_remote_cache: 1 + run_multiaccelerator_tests: "true" + run-pytest-tpu: uses: ./.github/workflows/pytest_tpu.yml strategy: diff --git a/WORKSPACE b/WORKSPACE index 6b3d0e2aa010..7f2a35f25aba 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -59,6 +59,8 @@ python_init_repositories( "jaxlib*", "jax_cuda*", "jax-cuda*", + "jax_rocm*", + "jax-rocm*", ], local_wheel_workspaces = ["//jaxlib:jax.bzl"], requirements = { diff --git a/ci/envs/default.env b/ci/envs/default.env index c35ec991c276..87cdec2c63db 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -89,4 +89,4 @@ export JAXCI_BUILD_JAX=${JAXCI_BUILD_JAX:-true} export JAXCI_BAZEL_OUTPUT_BASE=${JAXCI_BAZEL_OUTPUT_BASE:-} # Controls whether to build or run CPU test targets. -export JAXCI_BAZEL_CPU_RBE_MODE=${JAXCI_BAZEL_CPU_RBE_MODE:-"test"} \ No newline at end of file +export JAXCI_BAZEL_CPU_RBE_MODE=${JAXCI_BAZEL_CPU_RBE_MODE:-"test"} diff --git a/ci/run_bazel_test_rocm_rbe.sh b/ci/run_bazel_test_rocm_rbe.sh new file mode 100755 index 000000000000..135a928823dd --- /dev/null +++ b/ci/run_bazel_test_rocm_rbe.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel GPU tests with RBE. This runs single accelerator tests with one +# GPU apiece on RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +if [[ "$JAXCI_BUILD_JAXLIB" == "false" ]]; then + WHEEL_SIZE_TESTS="" +else + WHEEL_SIZE_TESTS="//jaxlib/tools:jax_cuda_plugin_wheel_size_test \ + //jaxlib/tools:jax_cuda_pjrt_wheel_size_test \ + //jaxlib/tools:jaxlib_wheel_size_test" +fi + +if [[ "$JAXCI_BUILD_JAX" != "false" ]]; then + WHEEL_SIZE_TESTS="$WHEEL_SIZE_TESTS //:jax_wheel_size_test" +fi + +if [[ "$JAXCI_BUILD_JAXLIB" != "true" ]]; then + #cuda_libs_flag="--config=cuda_libraries_from_stubs" + cuda_libs_flag="" +else + cuda_libs_flag="--@local_config_cuda//cuda:override_include_cuda_libs=true" +fi + +# Run Bazel GPU tests with RBE (single accelerator tests with one GPU apiece). +echo "Running RBE GPU tests..." 
+ +bazel test --config=rocm_rbe \ + --config=rocm \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --color=yes \ + $cuda_libs_flag \ + --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ + --//jax:build_jax=$JAXCI_BUILD_JAX \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ + $WHEEL_SIZE_TESTS diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 28f2367831a7..7a875dcc578d 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -192,10 +192,13 @@ def _gpu_test_deps(): "//jaxlib/rocm:gpu_only_test_deps", "//jax_plugins:gpu_plugin_only_test_deps", ], - "//jax:config_build_jaxlib_false": [ + "//jax:config_build_jaxlib_false": if_cuda_is_configured([ "//jaxlib/tools:pypi_jax_cuda_plugin_with_cuda_deps", "//jaxlib/tools:pypi_jax_cuda_pjrt_with_cuda_deps", - ], + ]) + if_rocm_is_configured([ + "//jaxlib/tools:rocm_plugin_kernels_wheel", + "//jaxlib/tools:rocm_plugin_pjrt_wheel", + ]), "//jax:config_build_jaxlib_wheel": [ "//jaxlib/tools:jax_cuda_plugin_py_import", "//jaxlib/tools:jax_cuda_pjrt_py_import", @@ -303,7 +306,7 @@ def jax_multiplatform_test( shard_count = test_shards, tags = test_tags, main = main, - exec_properties = tf_exec_properties({"tags": test_tags}), + exec_properties = {} #tf_exec_properties({"tags": test_tags}), ) def jax_generate_backend_suites(backends = []): diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 3a1a48736e17..8ec4953f86bf 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -476,6 +476,25 @@ py_import( wheel_deps = 
if_pypi_cuda_wheel_deps([":nvidia_wheel_deps"]), ) +filegroup( + name = "rocm_wheel_deps", + srcs = [ + ], +) + +# Targets for importing ROCm plugin wheels +py_import( + name = "rocm_plugin_kernels_wheel", + wheel = "@pypi_jax_rocm7_plugin//:whl", + wheel_deps = [":rocm_wheel_deps"], +) + +py_import( + name = "rocm_plugin_pjrt_wheel", + wheel = "@pypi_jax_rocm7_pjrt//:whl", + wheel_deps = [":rocm_wheel_deps"], +) + # Mosaic GPU py_binary(