From 5332b1ba761b7a7b01e7a8da297adfaf3a1a19fb Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Mon, 2 Mar 2026 20:33:00 -0600 Subject: [PATCH 01/14] prepare for JAX 0.9.1 release --- .github/workflows/build-wheels.yml | 2 +- .github/workflows/ci.yml | 4 ++-- .github/workflows/llama-perf.yml | 20 +++++++++---------- .github/workflows/rocm-perf.yml | 8 +++++--- .github/workflows/test-and-upload.yml | 2 +- BUILDING.md | 2 +- build/requirements.txt | 4 ++-- ci/Dockerfile.maxtext | 2 +- .../build/rocm/tools/build_wheels.py | 4 ++-- jax_rocm_plugin/pjrt/python/version.py | 4 ++-- jax_rocm_plugin/third_party/jax/workspace.bzl | 8 +++----- jax_rocm_plugin/third_party/xla/workspace.bzl | 4 ++-- stack.py | 4 ++-- 13 files changed, 34 insertions(+), 34 deletions(-) diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml index 814bcc5681..8461b216dc 100644 --- a/.github/workflows/build-wheels.yml +++ b/.github/workflows/build-wheels.yml @@ -31,7 +31,7 @@ on: description: 'JAX git ref (branch, tag, or commit) to checkout for jaxlib build' required: false type: string - default: 'rocm-jaxlib-v0.8.2' + default: 'rocm-jaxlib-v0.9.1' builder-image: required: false type: string diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d9a058be28..fcdfa02c30 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,7 +62,7 @@ jobs: # manage tests clean: false repository: rocm/jax - ref: rocm-jaxlib-v0.8.2 + ref: rocm-jaxlib-v0.9.1 path: jax - name: Apply patches to rocm/jax test repo run: | @@ -105,5 +105,5 @@ jobs: run: | python3 build/ci_build test \ "ghcr.io/rocm/jax-ubu24.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \ - --test-cmd "bash ci/jax_rbe/pr_setup.sh && ci/jax_rbe/pr_test.sh 0.8.2 3.12" + --test-cmd "bash ci/jax_rbe/pr_setup.sh && ci/jax_rbe/pr_test.sh 0.9.1 3.12" diff --git a/.github/workflows/llama-perf.yml b/.github/workflows/llama-perf.yml index e6f727d7f2..1497b0bdbc 100644 --- a/.github/workflows/llama-perf.yml +++ b/.github/workflows/llama-perf.yml @@ -1,10 +1,10 @@ name: Llama Performance Benchmarks # This workflow runs Llama performance benchmarks -# using two different JAX versions: 0.6.0 & 0.8.2 +# using two different JAX versions: 0.6.0 & 0.9.1 # For JAX 0.6.0: uses official released wheels # and Docker image. -# For JAX 0.8.2: uses wheels and Docker image +# For JAX 0.9.1: uses wheels and Docker image # built from the latest nightly results. # PS: Ubuntu 24 & ROCm 7.2.0. @@ -42,17 +42,17 @@ jobs: strategy: fail-fast: false matrix: - jax-version: ["0.6.0", "0.8.2"] + jax-version: ["0.6.0", "0.9.1"] include: - jax-version: "0.6.0" jaxlib-version: "0.6.0" # There is no docker image with ROCm 7.2.0 and Jax 0.6.0 - # use a ROCm 7.2.0 / Jax 0.8.2 docker image and update Jax + # use a ROCm 7.2.0 / Jax 0.9.1 docker image and update Jax # before building TE. For the application workload Bazel installs # its own Jax docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly" - - jax-version: "0.8.2" - jaxlib-version: "0.8.2" + - jax-version: "0.9.1" + jaxlib-version: "0.9.1" docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly" env: NVTE_FRAMEWORK: jax @@ -148,16 +148,16 @@ jobs: fail-fast: false max-parallel: 1 matrix: - jax-version: ["0.6.0", "0.8.2"] + jax-version: ["0.6.0", "0.9.1"] model-name: ["train_dense"] include: - jax-version: "0.6.0" model-name: "train_dense" jaxlib-version: "0.6.0" docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly" - - jax-version: "0.8.2" + - jax-version: "0.9.1" model-name: "train_dense" - jaxlib-version: "0.8.2" + jaxlib-version: "0.9.1" docker-image: "ghcr.io/rocm/jax-ubu24.rocm720:nightly" steps: - name: Checkout source repo @@ -294,7 +294,7 @@ jobs: model-name: "train_dense" rocm-version: "7.2.0" python-version: "3.12" - - jax-version: "0.8.2" + - jax-version: "0.9.1" model-name: "train_dense" rocm-version: "7.2.0" python-version: "3.12" diff --git a/.github/workflows/rocm-perf.yml b/.github/workflows/rocm-perf.yml index ae6d5edea5..168f2883d0 100644 --- a/.github/workflows/rocm-perf.yml +++ b/.github/workflows/rocm-perf.yml @@ -56,7 +56,7 @@ jobs: uses: actions/checkout@v4 with: repository: ROCm/jax - ref: rocm-jaxlib-v0.8.2 + ref: rocm-jaxlib-v0.9.1 path: jax - name: Build plugin wheels @@ -128,7 +128,9 @@ jobs: - name: Analyze logs to compute median step time run: | - pip install numpy --break-system-packages + python3 -m venv venv + source venv/bin/activate + pip install numpy python3 build/analyze_maxtext_logs.py cat summary.json @@ -190,4 +192,4 @@ jobs: --python-version "$PYTHON_VERSION" \ --rocm-version "$ROCM_VERSION" \ --gfx-version gfx942 \ - --jax-version 0.8.2 + --jax-version 0.9.1 diff --git a/.github/workflows/test-and-upload.yml b/.github/workflows/test-and-upload.yml index dfd7951000..b292fd9937 100644 --- a/.github/workflows/test-and-upload.yml +++ b/.github/workflows/test-and-upload.yml @@ -59,7 +59,7 @@ jobs: # TODO: Change the repo and ref once we figure out how exactly we're going to # manage tests repository: rocm/jax - ref: rocm-jaxlib-v0.8.2 + ref: rocm-jaxlib-v0.9.1 path: jax - name: Apply patches to rocm/jax test repo run: | diff --git a/BUILDING.md b/BUILDING.md index f993def2f6..f2f411a8fb 100644 --- a/BUILDING.md +++ b/BUILDING.md @@ -136,7 +136,7 @@ python3 build/ci_build test $TEST_IMAGE --test-cmd "pytest jax_rocm_plugin/tests We keep unit tests in the `rocm/jax` repository, and you'll need to clone it to run the regular JAX unit tests with ROCm, ```shell -git clone --depth 1 --branch rocm-jaxlib-v0.8.2 git@github.com:ROCm/jax.git +git clone --depth 1 --branch rocm-jaxlib-v0.9.1 git@github.com:ROCm/jax.git # Each release of the ROCm plugin has a corresponding branch. You can find # more at https://github.com/ROCm/rocm-jax/branches/all?query=rocm-jaxlib diff --git a/build/requirements.txt b/build/requirements.txt index b6be3ecb53..40a09f3bec 100644 --- a/build/requirements.txt +++ b/build/requirements.txt @@ -1,2 +1,2 @@ -jaxlib==0.8.2 -jax==0.8.2 +jaxlib==0.9.1 +jax==0.9.1 diff --git a/ci/Dockerfile.maxtext b/ci/Dockerfile.maxtext index 58717c388f..206757b11e 100644 --- a/ci/Dockerfile.maxtext +++ b/ci/Dockerfile.maxtext @@ -15,4 +15,4 @@ WORKDIR /maxtext # Explicitly install jax,jaxlib to avoid pip pulling a newer version (e.g. 0.8.1) RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ pip install -r requirements.txt && \ - pip install jax==0.8.2 jaxlib==0.8.2 && pip freeze + pip install jax==0.9.1 jaxlib==0.9.1 && pip freeze diff --git a/jax_rocm_plugin/build/rocm/tools/build_wheels.py b/jax_rocm_plugin/build/rocm/tools/build_wheels.py index 7211870b8e..5246b01d29 100644 --- a/jax_rocm_plugin/build/rocm/tools/build_wheels.py +++ b/jax_rocm_plugin/build/rocm/tools/build_wheels.py @@ -367,8 +367,8 @@ def fix_wheel(path, jax_path): def is_release_jaxlib(filename): """Check if wheel is a release jaxlib wheel (not selfbuilt).""" - # e.g. jaxlib-0.8.2-cp312-....whl (release) - # reject jaxlib-0.8.2.dev0+selfbuilt-cp312-....whl + # e.g. jaxlib-0.9.1-cp312-....whl (release) + # reject jaxlib-0.9.1.dev0+selfbuilt-cp312-....whl return filename.startswith("jaxlib-") and "+selfbuilt" not in filename diff --git a/jax_rocm_plugin/pjrt/python/version.py b/jax_rocm_plugin/pjrt/python/version.py index e4f15f46c6..3434a5e1ab 100644 --- a/jax_rocm_plugin/pjrt/python/version.py +++ b/jax_rocm_plugin/pjrt/python/version.py @@ -28,7 +28,7 @@ import subprocess # pylint: disable=invalid-name -_version = "0.8.2" +_version = "0.9.1" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -163,7 +163,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.8.2" +_minimum_jaxlib_version = "0.9.1" def _version_as_tuple(version_str): diff --git a/jax_rocm_plugin/third_party/jax/workspace.bzl b/jax_rocm_plugin/third_party/jax/workspace.bzl index 3e2f5e57a8..b0b7433be7 100644 --- a/jax_rocm_plugin/third_party/jax/workspace.bzl +++ b/jax_rocm_plugin/third_party/jax/workspace.bzl @@ -1,10 +1,10 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") # To update JAX: -# 1. Find the commit hash you want to pin to (e.g., from rocm-jaxlib-v0.8.2 branch) +# 1. Find the commit hash you want to pin to (e.g., from rocm-jaxlib-v0.9.1 branch) # 2. Update JAX_COMMIT below -JAX_COMMIT = "fbfa695aea59ed578b81d8fc72ab23bba5d2cfaa" +JAX_COMMIT = "58cb6e556c996bf0361bca9e64890a551e513280" def repo(): git_repository( @@ -14,8 +14,6 @@ def repo(): patch_tool = "patch", patch_args = ["-p1"], patches = [ - "//third_party/jax:0005-Fix-HIP-availability-errors.patch", - "//third_party/jax:0006-Enable-testing-with-ROCm-plugin-wheels.patch", # TODO: remove due to: https://github.com/jax-ml/jax/pull/34641 - "//third_party/jax:0007-Fix-legacy-create-init.patch", # TODO: remove due to: https://github.com/jax-ml/jax/pull/34770 + "//third_party/jax:0005-Fix-HIP-availability-errors.patch", #TODO(gulsumgudukbay): check if this is still needed ], ) diff --git a/jax_rocm_plugin/third_party/xla/workspace.bzl b/jax_rocm_plugin/third_party/xla/workspace.bzl index 382b075c7a..b0e7c3d5bb 100644 --- a/jax_rocm_plugin/third_party/xla/workspace.bzl +++ b/jax_rocm_plugin/third_party/xla/workspace.bzl @@ -10,10 +10,10 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") # To update XLA: -# 1. Find the commit hash you want to pin to (e.g., from rocm-jaxlib-v0.8.2 branch) +# 1. Find the commit hash you want to pin to (e.g., from rocm-jaxlib-v0.9.1 branch) # 2. Update XLA_COMMIT below -XLA_COMMIT = "24c5f10ae8fc24aefd20b43c501ade7f66fd0cfd" +XLA_COMMIT = "3cc8846c10052cc1c32c4db87866eac4e4cdbccd" def repo(): git_repository( diff --git a/stack.py b/stack.py index 6f97915745..797c6480e2 100644 --- a/stack.py +++ b/stack.py @@ -5,8 +5,8 @@ import os import subprocess -TEST_JAX_REPO_REF = "rocm-jaxlib-v0.8.2" -XLA_REPO_REF = "rocm-jaxlib-v0.8.2" +TEST_JAX_REPO_REF = "rocm-jaxlib-v0.9.1" +XLA_REPO_REF = "rocm-jaxlib-v0.9.1" JAX_REPL_URL = "https://github.com/rocm/jax" From be917a50f2f647e69260e96a59d5ea443593f272 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 3 Mar 2026 11:18:19 -0600 Subject: [PATCH 02/14] add visibility:public for plugin targets --- .../jax/0006-Expose-rocm-plugin-targets.patch | 39 +++++++++++++++++++ jax_rocm_plugin/third_party/jax/workspace.bzl | 1 + 2 files changed, 40 insertions(+) create mode 100644 jax_rocm_plugin/third_party/jax/0006-Expose-rocm-plugin-targets.patch diff --git a/jax_rocm_plugin/third_party/jax/0006-Expose-rocm-plugin-targets.patch b/jax_rocm_plugin/third_party/jax/0006-Expose-rocm-plugin-targets.patch new file mode 100644 index 0000000000..29edc41267 --- /dev/null +++ b/jax_rocm_plugin/third_party/jax/0006-Expose-rocm-plugin-targets.patch @@ -0,0 +1,39 @@ +diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD +index 67fc1893..5ee3ef4c 100644 +--- a/jaxlib/rocm/BUILD ++++ b/jaxlib/rocm/BUILD +@@ -476,12 +476,14 @@ nanobind_extension( + ) + + py_library( + name = "rocm_gpu_support", ++ visibility = ["//visibility:public"], + deps = [ + ":_hybrid", + ":_linalg", + ":_prng", + ":_rnn", + ":_solver", + ":_sparse", + ":_triton", + ], + ) +@@ -526,6 +528,7 @@ cc_library( + ) + + nanobind_extension( + name = "rocm_plugin_extension", ++ visibility = ["//visibility:public"], + srcs = ["rocm_plugin_extension.cc"], + module_name = "rocm_plugin_extension", + deps = [ + ":py_client_gpu", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/gpu:gpu_plugin_extension", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@local_config_rocm//rocm:hip_runtime", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + ], + ) diff --git a/jax_rocm_plugin/third_party/jax/workspace.bzl b/jax_rocm_plugin/third_party/jax/workspace.bzl index b0b7433be7..9a0e63d835 100644 --- a/jax_rocm_plugin/third_party/jax/workspace.bzl +++ b/jax_rocm_plugin/third_party/jax/workspace.bzl @@ -15,5 +15,6 @@ def repo(): patch_args = ["-p1"], patches = [ "//third_party/jax:0005-Fix-HIP-availability-errors.patch", #TODO(gulsumgudukbay): check if this is still needed + "//third_party/jax:0006-Expose-rocm-plugin-targets.patch", ], ) From ef0c436defbad4c77d84a9e7ceec3c57ec34f8d7 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 3 Mar 2026 11:48:57 -0600 Subject: [PATCH 03/14] update patch --- ...rgets.patch => 0008-Expose-rocm-plugin-targets.patch} | 9 ++++----- jax_rocm_plugin/third_party/jax/workspace.bzl | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) rename jax_rocm_plugin/third_party/jax/{0006-Expose-rocm-plugin-targets.patch => 0008-Expose-rocm-plugin-targets.patch} (88%) diff --git a/jax_rocm_plugin/third_party/jax/0006-Expose-rocm-plugin-targets.patch b/jax_rocm_plugin/third_party/jax/0008-Expose-rocm-plugin-targets.patch similarity index 88% rename from jax_rocm_plugin/third_party/jax/0006-Expose-rocm-plugin-targets.patch rename to jax_rocm_plugin/third_party/jax/0008-Expose-rocm-plugin-targets.patch index 29edc41267..fd668216b9 100644 --- a/jax_rocm_plugin/third_party/jax/0006-Expose-rocm-plugin-targets.patch +++ b/jax_rocm_plugin/third_party/jax/0008-Expose-rocm-plugin-targets.patch @@ -1,10 +1,9 @@ diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD -index 67fc1893..5ee3ef4c 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD -@@ -476,12 +476,14 @@ nanobind_extension( +@@ -475,12 +475,13 @@ nanobind_extension( ) - + py_library( name = "rocm_gpu_support", + visibility = ["//visibility:public"], @@ -18,9 +17,9 @@ index 67fc1893..5ee3ef4c 100644 ":_triton", ], ) -@@ -526,6 +528,7 @@ cc_library( +@@ -528,6 +529,7 @@ cc_library( ) - + nanobind_extension( name = "rocm_plugin_extension", + visibility = ["//visibility:public"], diff --git a/jax_rocm_plugin/third_party/jax/workspace.bzl b/jax_rocm_plugin/third_party/jax/workspace.bzl index 9a0e63d835..fc565326f8 100644 --- a/jax_rocm_plugin/third_party/jax/workspace.bzl +++ b/jax_rocm_plugin/third_party/jax/workspace.bzl @@ -15,6 +15,6 @@ def repo(): patch_args = ["-p1"], patches = [ "//third_party/jax:0005-Fix-HIP-availability-errors.patch", #TODO(gulsumgudukbay): check if this is still needed - "//third_party/jax:0006-Expose-rocm-plugin-targets.patch", + "//third_party/jax:0008-Expose-rocm-plugin-targets.patch", ], ) From 646db9f292cbb3fdff28661a7cb5d8a9930bd58d Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 3 Mar 2026 12:50:16 -0600 Subject: [PATCH 04/14] fix patch --- .../jax/0008-Expose-rocm-plugin-targets.patch | 22 ++----------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/jax_rocm_plugin/third_party/jax/0008-Expose-rocm-plugin-targets.patch b/jax_rocm_plugin/third_party/jax/0008-Expose-rocm-plugin-targets.patch index fd668216b9..0c488c2f2f 100644 --- a/jax_rocm_plugin/third_party/jax/0008-Expose-rocm-plugin-targets.patch +++ b/jax_rocm_plugin/third_party/jax/0008-Expose-rocm-plugin-targets.patch @@ -1,8 +1,8 @@ diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD +index 72f2ec6da..11f620658 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD -@@ -475,12 +475,13 @@ nanobind_extension( - ) +@@ -477,6 +477,7 @@ nanobind_extension( py_library( name = "rocm_gpu_support", @@ -10,15 +10,7 @@ diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD deps = [ ":_hybrid", ":_linalg", - ":_prng", - ":_rnn", - ":_solver", - ":_sparse", - ":_triton", - ], - ) @@ -528,6 +529,7 @@ cc_library( - ) nanobind_extension( name = "rocm_plugin_extension", @@ -26,13 +18,3 @@ diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD srcs = ["rocm_plugin_extension.cc"], module_name = "rocm_plugin_extension", deps = [ - ":py_client_gpu", - "//jaxlib:kernel_nanobind_helpers", - "//jaxlib/gpu:gpu_plugin_extension", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@local_config_rocm//rocm:hip_runtime", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", - ], - ) From 27545992dcc9572e4fffc0043ab8d98fcf7f1d39 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 10 Mar 2026 18:14:56 +0000 Subject: [PATCH 05/14] update jaxlib resolution --- docker/Dockerfile.jax-ubu24 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.jax-ubu24 b/docker/Dockerfile.jax-ubu24 index f5e742e1e2..357707f4d0 100644 --- a/docker/Dockerfile.jax-ubu24 +++ b/docker/Dockerfile.jax-ubu24 @@ -64,5 +64,5 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ ls -lah /wheelhouse && \ for py in python3.11 python3.12 python3.13 python3.14; do \ $py -m pip install -f /wheelhouse --no-deps --no-index "jax_rocm${PLUGIN_NAMESPACE}_plugin" "jax_rocm${PLUGIN_NAMESPACE}_pjrt" && \ - $py -m pip install -f /wheelhouse --no-deps --no-index --force-reinstall "jaxlib"; \ + $py -m pip install --break-system-packages --no-deps --force-reinstall "jaxlib==0.9.1"; \ done From 621cba58d95e320cb5070002c6562d152860823f Mon Sep 17 00:00:00 2001 From: JeniferC99 <150404595+JeniferC99@users.noreply.github.com> Date: Tue, 3 Mar 2026 15:35:31 -0800 Subject: [PATCH 06/14] Cherry-pick PR #293 + #306: dev tarball Docker tag fix and double URL-encoding fix (#335) * Cherry-pick PR #293 + #306: dev tarball Docker tag fix and double URL-encoding fix - build/ci_build: replace '+' with '.' in rocm_ver_tag and rocm_version_tag for valid Docker tags - jax_rocm_plugin/build/rocm/ci_build: ROCM_VERSION_EXTRA use replace('+', '.') - tools/get_rocm.py: unquote then quote therock URL to avoid double-encoding; create amdgcn symlink only if missing Made-with: Cursor * Fix pylint C0301: shorten comment line (line-too-long 103/100) Made-with: Cursor * Use try/except FileExistsError for amdgcn symlink (PR #330) Made-with: Cursor --------- Co-authored-by: Jenkins (cherry picked from commit ea9a0213a8e8a72417fb87270ee186ee9d05c811) --- build/ci_build | 8 +++++--- jax_rocm_plugin/build/rocm/ci_build | 2 +- tools/get_rocm.py | 6 +++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/build/ci_build b/build/ci_build index a8b97db54b..bb9e0b55d9 100755 --- a/build/ci_build +++ b/build/ci_build @@ -378,8 +378,9 @@ def build_dockers( commit_info = _get_commit_info_from_wheel() dockerfiles = _apply_filters(docker_filters, "Dockerfile.jax") - rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")) - rocm_version_tag = "".join(rocm_version.split(".")) + # Docker tags cannot contain '+', so replace it with '.' for consistency with wheel versions + rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")).replace("+", ".") + rocm_version_tag = "".join(rocm_version.split(".")).replace("+", ".") plugin_namespace = rocm_version[0] if plugin_namespace == "6": plugin_namespace = "60" @@ -456,7 +457,8 @@ def build_base_dockers( push_latest=False, ): dockerfiles = _apply_filters(docker_filters, "Dockerfile.base") - rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")) + # Docker tags cannot contain '+', so replace it with '.' for consistency with wheel versions + rocm_ver_tag = "rocm%s" % "".join(rocm_version.split(".")).replace("+", ".") extra_args = [] if add_llvm: diff --git a/jax_rocm_plugin/build/rocm/ci_build b/jax_rocm_plugin/build/rocm/ci_build index 7d4840fe29..f6d8b425a4 100755 --- a/jax_rocm_plugin/build/rocm/ci_build +++ b/jax_rocm_plugin/build/rocm/ci_build @@ -116,7 +116,7 @@ def dist_wheels( "--shm-size", "64G", "-e", - "ROCM_VERSION_EXTRA=" + rocm_version, + "ROCM_VERSION_EXTRA=" + rocm_version.replace("+", "."), "-e", "ROCM_JAX_COMMIT=" + rocm_jax_commit, builder_image, diff --git a/tools/get_rocm.py b/tools/get_rocm.py index c459e1d315..b3534e2299 100644 --- a/tools/get_rocm.py +++ b/tools/get_rocm.py @@ -29,6 +29,7 @@ import ssl import subprocess import sys +import urllib.parse import urllib.request # pylint: disable=unspecified-encoding @@ -222,7 +223,10 @@ def _install_therock(rocm_version, therock_path): else: os.makedirs(rocm_real_path) tar_path = "/tmp/therock.tar.gz" - with urllib.request.urlopen(therock_path) as response: + # Unquote first to avoid double-encoding if URL already encoded (e.g. '%2B' -> '%252B'). + decoded_url = urllib.parse.unquote(therock_path) + encoded_url = urllib.parse.quote(decoded_url, safe=":/?&=") + with urllib.request.urlopen(encoded_url) as response: if response.status == 200: with open(tar_path, "wb") as tar_file: tar_file.write(response.read()) From b1d232cc49eb56a9c050c1e87535cd730176b1fa Mon Sep 17 00:00:00 2001 From: Flora Cui Date: Fri, 6 Mar 2026 17:43:22 +0800 Subject: [PATCH 07/14] [WSL] pjrt/python: hardcode GPU count to 1 in WSL (#336) On WSL, ROCm GPU discovery via KFD topology is unavailable, which can cause plugin initialization to fail. This change adds a WSL-specific fast path in GPU counting: when /dev/dxg exists, it returns 1 directly for initialization gating. The existing KFD-based counting logic remains unchanged for other environments, and the docstring is updated to clearly document this intentional hardcoded behavior. Signed-off-by: Flora Cui (cherry picked from commit ab4424b24159edfaaedb8944a124c25aa48ff3e5) --- jax_rocm_plugin/pjrt/python/__init__.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/jax_rocm_plugin/pjrt/python/__init__.py b/jax_rocm_plugin/pjrt/python/__init__.py index 94e8b999da..0637acbc7d 100644 --- a/jax_rocm_plugin/pjrt/python/__init__.py +++ b/jax_rocm_plugin/pjrt/python/__init__.py @@ -134,13 +134,15 @@ def count_amd_gpus(stop_at: int = None) -> int: """Count AMD GPUs available via KFD kernel driver. This function checks for the presence of AMD GPUs by examining KFD kernel - driver entities as a proxy. This approach provides a good compromise between - performance, reliability and simplicity. Presence of such entities doesn't - guarantee that the GPUs are usable through HIP and PJRT, however, we can't - do much better without spawning an additional process with a potentially - complicated setup to run actual HIP code. And we don't want to initialize - HIP right now inside the current process, because doing so might spoil a - proper initialization of the rocprofiler-sdk later during PJRT startup. + driver entities as a proxy. In WSL setups, if /dev/dxg exists, this check + hardcodes the result to 1 GPU for initialization gating. This approach + provides a good compromise between performance, reliability and simplicity. + Presence of such entities doesn't guarantee that the GPUs are usable + through HIP and PJRT, however, we can't do much better without spawning an + additional process with a potentially complicated setup to run actual HIP + code. And we don't want to initialize HIP right now inside the current + process, because doing so might spoil a proper initialization of the + rocprofiler-sdk later during PJRT startup. Args: stop_at: If provided, stop counting once this many GPUs are found. @@ -150,6 +152,9 @@ def count_amd_gpus(stop_at: int = None) -> int: The number of AMD GPUs detected (up to stop_at if provided). """ try: + if os.path.exists("/dev/dxg"): + return 1 + kfd_nodes_path = "/sys/class/kfd/kfd/topology/nodes/" if not os.path.exists(kfd_nodes_path): return 0 From d032adbf4f74caba0b96048f1d504896334063c0 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Fri, 6 Mar 2026 11:29:26 -0600 Subject: [PATCH 08/14] Shrink docker image size (#329) (cherry picked from commit 33fbb4bc8c3f59ba91b81a5a152fd17da1e09bf9) --- build/ci_build | 1 + docker/Dockerfile.base-ubu24 | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/build/ci_build b/build/ci_build index bb9e0b55d9..c772fcfa09 100755 --- a/build/ci_build +++ b/build/ci_build @@ -672,6 +672,7 @@ def parse_args(): "--llvm-version", help="Set the version number of LLVM to be installed.", type=int, + default="18", ) testp = subp.add_parser("test") diff --git a/docker/Dockerfile.base-ubu24 b/docker/Dockerfile.base-ubu24 index 8482a140bc..9251a5f52e 100644 --- a/docker/Dockerfile.base-ubu24 +++ b/docker/Dockerfile.base-ubu24 @@ -103,7 +103,7 @@ RUN --mount=type=bind,source=tools/get_rocm.py,target=get_rocm.py \ --mount=type=bind,from=therock,target=/tmp/therock/ \ python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM --therock-path=$THEROCK_PATH -# Install LLVM when enabled: +# Install LLVM's clang and lld when enabled: RUN --mount=type=cache,target=/var/cache/apt \ [ $INSTALL_LLVM -eq 0 ] || ( \ apt update && \ @@ -115,7 +115,7 @@ RUN --mount=type=cache,target=/var/cache/apt \ cd /tmp && \ wget https://apt.llvm.org/llvm.sh && \ chmod +x llvm.sh && \ - ./llvm.sh $LLVM_VERSION all && \ + ./llvm.sh $LLVM_VERSION clang lld && \ rm llvm.sh && \ apt-get clean && rm -rf /var/lib/apt/lists/* ) @@ -143,7 +143,7 @@ RUN [ "$ROCM_VERSION" = "7.1.1" ] || ( \ -DMIOPEN_USE_MLIR=OFF -DMIOPEN_ENABLE_AI_KERNEL_TUNING=OFF \ -DMIOPEN_ENABLE_AI_IMMED_MODE_FALLBACK=OFF .. && \ make -j10 && make install && \ - apt-get clean && rm -rf /var/lib/apt/lists/* ) + rm -rf /rocm-libraries && apt-get clean && rm -rf /var/lib/apt/lists/* ) # This mitigates crashes related to the kernel database. ENV MIOPEN_FIND_ENFORCE=SEARCH_DB_UPDATE From 3f4e3918ed92a42f54b77a43507b413964c47a88 Mon Sep 17 00:00:00 2001 From: Pakize Sanal Date: Mon, 9 Mar 2026 10:38:17 -0500 Subject: [PATCH 09/14] Ingest JAX Pytest results from S3 into DB (#294) (cherry picked from commit 7fe40553750a443f601d2d39d11facac8f353203) --- .github/workflows/pytest-results-to-db.yml | 43 +- ci/ingest_jax_ci_logs.sh | 72 ++ ci/upload_pytest_to_db.py | 753 +++++++++++++++++++++ 3 files changed, 864 insertions(+), 4 deletions(-) create mode 100755 ci/ingest_jax_ci_logs.sh create mode 100644 ci/upload_pytest_to_db.py diff --git a/.github/workflows/pytest-results-to-db.yml b/.github/workflows/pytest-results-to-db.yml index 8fca2be4b7..2a566c0a0b 100644 --- a/.github/workflows/pytest-results-to-db.yml +++ b/.github/workflows/pytest-results-to-db.yml @@ -1,10 +1,19 @@ -name: Pytest Results to DB +name: Jax CI Test Results to DB on: + schedule: + - cron: "10 * * * *" workflow_dispatch: inputs: - run-id: - required: true + source-repo: + description: "ORG/REPO" + required: false + default: "ROCm/jax" + type: string + source-github-run-id: + description: "Repo GitHub Run ID" + required: false + default: "" type: string secrets: ROCM_JAX_DB_HOSTNAME: @@ -16,9 +25,35 @@ on: ROCM_JAX_DB_NAME: required: true +concurrency: + group: jax-ci-test-results-to-db + cancel-in-progress: false + jobs: - upload-to-db: + process-test-results: runs-on: mysqldb steps: - name: Checkout source uses: actions/checkout@v4 + - name: Set up Python environment + run: | + python3 -m venv venv + source venv/bin/activate + pip install --upgrade pip + pip install mysql-connector-python + - name: Upload logs to MySQL database + env: + ROCM_JAX_DB_HOSTNAME: ${{ secrets.ROCM_JAX_DB_HOSTNAME }} + ROCM_JAX_DB_USERNAME: ${{ secrets.ROCM_JAX_DB_USERNAME }} + ROCM_JAX_DB_PASSWORD: ${{ secrets.ROCM_JAX_DB_PASSWORD }} + ROCM_JAX_DB_NAME: ${{ secrets.ROCM_JAX_DB_NAME }} + BUCKET: ${{ secrets.AMD_S3_BUCKET_NAME }} + AWS_ACCESS_KEY_ID: ${{ secrets.AMD_S3_BUCKET_ACCESS_KEY }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AMD_S3_BUCKET_SECRET_KEY }} + FILTER_REPO: ${{ inputs.source-repo }} + FILTER_RUN_ID: ${{ inputs.source-github-run-id }} + INPUT_GPU_ARCH: "MI350" + run: | + source venv/bin/activate + ./ci/ingest_jax_ci_logs.sh + diff --git a/ci/ingest_jax_ci_logs.sh b/ci/ingest_jax_ci_logs.sh new file mode 100755 index 0000000000..32e99eadae --- /dev/null +++ b/ci/ingest_jax_ci_logs.sh @@ -0,0 +1,72 @@ +#!/bin/bash +set -euo pipefail + +: "${INPUT_GPU_ARCH:?}" +: "${FILTER_REPO:=ROCm/jax}" +: "${FILTER_RUN_ID:=}" + +ROOT="jax-ci-test-logs/${FILTER_REPO}/" +CUTOFF="$(date -u -d '2 days ago' +%F)" + +echo "Scanning S3 prefix: s3://${BUCKET}/${ROOT}" +echo "Cutoff date (UTC): ${CUTOFF}" + +aws s3 ls "s3://${BUCKET}/${ROOT}" --recursive \ + | awk '/\/_SUCCESS$/ {print $4}' \ + | while read -r SUCCESS_KEY; do + + PREFIX="${SUCCESS_KEY%/_SUCCESS}" + + # Filter by run id (optional manual override) + if [[ -n "${FILTER_RUN_ID:-}" ]] && [[ "${PREFIX}" != *"_${FILTER_RUN_ID}_"* ]]; then + continue + fi + + # Skip already ingested + if aws s3 ls "s3://${BUCKET}/${PREFIX}/_INGESTED" >/dev/null 2>&1; then + continue + fi + + # Extract run_dir + RUN_DIR="$(basename "$(dirname "${PREFIX}")")" + RUN_DATE="${RUN_DIR:0:10}" + + # Skip older than cutoff + if [[ "${RUN_DATE}" < "${CUTOFF}" ]]; then + continue + fi + + # Skip if logs.tar.gz missing + if ! aws s3 ls "s3://${BUCKET}/${PREFIX}/logs.tar.gz" >/dev/null 2>&1; then + echo "Skipping ${PREFIX}: logs.tar.gz not found" + continue + fi + + echo "Ingesting: ${PREFIX}" + + WD="$(mktemp -d)" + + aws s3 cp "s3://${BUCKET}/${PREFIX}/run-manifest.json" "${WD}/run-manifest.json" + aws s3 cp "s3://${BUCKET}/${PREFIX}/logs.tar.gz" "${WD}/logs.tar.gz" + + mkdir -p "${WD}/logs_dir/extracted" + cp "${WD}/run-manifest.json" "${WD}/logs_dir/" + tar -xzf "${WD}/logs.tar.gz" -C "${WD}/logs_dir/extracted" + + if python3 ci/upload_pytest_to_db.py \ + --local_logs_dir "${WD}/logs_dir" \ + --run-tag "ci-run" \ + --gpu-tag "${INPUT_GPU_ARCH}" \ + --artifact_uri "s3://${BUCKET}/${PREFIX}" + then + printf '' | aws s3 cp - "s3://${BUCKET}/${PREFIX}/_INGESTED" + echo "Marked _INGESTED: ${PREFIX}" + else + echo "Skipping ${PREFIX}: ingest failed" + fi + + rm -rf "${WD}" + done + +echo "Done" + diff --git a/ci/upload_pytest_to_db.py b/ci/upload_pytest_to_db.py new file mode 100644 index 0000000000..624920c518 --- /dev/null +++ b/ci/upload_pytest_to_db.py @@ -0,0 +1,753 @@ +#!/usr/bin/env python3 +""" +Ingest JAX Pytest results from S3 into MySQL. + +Recursively locates one Pytest report under given log dir + +Tables: + - jax_ci_runs: one row per run + - jax_ci_tests: one row per unique test + - jax_ci_results: one row per test per run + +Run-level manifest (GitHub vars, etc.) is sourced from the CI. +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +from datetime import datetime, timezone +from functools import lru_cache +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +# pylint: disable=import-error +import mysql.connector +from mysql.connector import Error as MySQLError + +# ----------------------------- +# Constants +# ----------------------------- +TEXT_LIMIT = 250 +BATCH_SIZE = 2000 +DEFAULT_LABEL = "Skipped Upstream" +MANIFEST_FILENAME = "run-manifest.json" + + +# ----------------------------- +# Helpers +# ----------------------------- +def extract_skip_reason(reason: str) -> str: + """Parse pytest skip longrepr tuple-string into its reason text. + + Example input: "('/path/test_x.py', 42, 'Skipped: some reason')" + Also works for xdist header: + "[gw0] ... \\n('/path/test_x.py', 42, 'Skipped: some reason')" + """ + + # strip outer parentheses, + # then split into 3 parts + parts = reason[1:-1].split(",", 2) + if len(parts) != 3: + return reason + + msg = parts[2].strip() + # drop matching quotes if any + if msg[:1] in {"'", '"'} and msg[-1:] == msg[:1]: + msg = msg[1:-1] + return msg + + +def nodeid_parts(nodeid: str) -> Tuple[str, str, str]: + """Split pytest nodeid into (filename, classname, test_name). + + Support both variants: "file::Class::test" and "file::test" in nodeids. + """ + parts = nodeid.split("::", 2) + f = parts[0] + if len(parts) == 2: + return f, "", parts[1] + c = parts[1] + t = parts[2] + return f, c, t + + +def pipe_split(raw: Optional[str]) -> List[str]: + """Split a pipe-separated manifest field into a list of values.""" + if not raw: + return [] + return [p for p in (x.strip() for x in raw.split("|")) if p] + + +def parse_iso_dt(s: Optional[str]) -> Optional[datetime]: + """Parse ISO8601 timestamp into UTC datetime.""" + if not s: + return None + return datetime.fromisoformat(s.replace("Z", "+00:00")).replace(tzinfo=None) + + +def parse_run_key_and_combo(artifact_uri: str) -> tuple[str, str]: + """Extract (run_key, combo) from the artifact URI path. + + Expected format: ...// + """ + parts = [p for p in artifact_uri.split("/") if p] + if len(parts) < 2: + raise SystemExit(f"Invalid artifact_uri format: {artifact_uri}") + return parts[-2], parts[-1] # run_key, combo + + +# ----------------------------- +# Input/File Loading +# ----------------------------- +def find_pytest_report_json(local_logs_dir: Path) -> Optional[Path]: + """Find the single Pytest JSON report under the given logs directory, if present""" + ignore = {MANIFEST_FILENAME, "last_running.json"} + reports = [ + p + for p in local_logs_dir.rglob("*.json") + if p.is_file() + and p.name not in ignore + and not p.name.endswith("last_running.json") + ] + if not reports: + return None + if len(reports) != 1: + listing = "\n - " + "\n - ".join( + str(p.relative_to(local_logs_dir)) for p in sorted(reports) + ) + raise SystemExit( + f"Expected exactly ONE pytest JSON report; found {len(reports)}:{listing}" + ) + return reports[0] + + +def load_from_pytest_json(path: Path) -> Tuple[Optional[datetime], List[dict]]: + """Load pytest JSON report and return (report_created_at, tests), if present""" + with path.open("r", encoding="utf-8") as fh: + data = json.load(fh) + if isinstance(data, dict): + tests = data.get("tests", []) + created = data.get("created") + report_created_at = ( + datetime.fromtimestamp(float(created), tz=timezone.utc).replace(tzinfo=None) + if created is not None + else None + ) + return report_created_at, tests + if isinstance(data, list): + return None, data + raise ValueError(f"Unexpected report JSON structure: {path}") + + +def load_manifest(local_logs_dir: Path) -> dict: + """Load CI run metadata from run-manifest.json. + + Schema Version_1 Fields: + run_started_at CI run start time + run_completed_at CI run completion time + github_run_url URL of the GitHub Actions run + github_repository Repository name (e.g. "ROCm/jax") + github_ref_name Branch or tag name + github_ref Full Git reference + github_sha Commit SHA for the run + github_event_name GitHub event type (push, pull_request, etc.) + github_run_id GitHub Actions run identifier + github_run_attempt Retry attempt number + github_run_number Sequential run number + github_workflow Workflow name + github_job Job name within the workflow + python_version Python version used in the run + rocm_version ROCm version used + rocm_tag ROCm container tag + is_nightly Whether it is a nightly or continuous run + gpu_count Number of GPUs used + runner CI runner label + base_image_name Base container image name + base_image_digest Container image digest + jax_packages_raw Installed JAX Python packages (pipe-separated) + wheels_sha_raw Wheel SHA256 list (pipe-separated) + + Some fields may be optional and missing.""" + p = local_logs_dir / MANIFEST_FILENAME + if not p.exists(): + raise FileNotFoundError(f"{MANIFEST_FILENAME} not found: {p}") + with p.open("r", encoding="utf-8") as fh: + return json.load(fh) + + +# ----------------------------- +# Run/Result Fields Preparation +# ----------------------------- +def require_field(m: dict, key: str): + """Return a required manifest field, failing early if it is missing.""" + v = m.get(key) + if v in (None, ""): + raise SystemExit(f"Manifest missing required field: {key}") + return v + + +def packages_json_and_jax_version( + raw: Optional[str], +) -> Tuple[Optional[str], Optional[str]]: + """Parse package list into JSON and extract the JAX version.""" + if not raw: + return None, None + pkgs = [] + jax_ver = None + for item in pipe_split(raw): + name, sep, ver = item.partition("==") + name = name.strip() + ver = ver.strip() if sep else None + pkgs.append({"name": name, "version": ver, "raw": item}) + if name == "jax" and ver: + jax_ver = ver + return json.dumps(pkgs), jax_ver + + +def wheels_json(raw: Optional[str]) -> Optional[str]: + """Parse wheels metadata into normalized JSON.""" + if not raw: + return None + wheels = [] + for line in pipe_split(raw): + m = re.match(r"^([0-9a-fA-F]{64})\s+(.+)$", line) + if m: + wheels.append({"sha256": m.group(1).lower(), "file": m.group(2).strip()}) + else: + wheels.append({"sha256": None, "file": line}) + return json.dumps(wheels) + + +def build_run_fields( # pylint: disable=too-many-locals + m: dict, + *, + artifact_uri: str, + run_tag: str, + gpu_tag: str, +) -> dict: + """Normalize manifest fields for jax_ci_runs insertion.""" + run_key, combo = parse_run_key_and_combo(artifact_uri) + github_repository = require_field(m, "github_repository") + github_ref_name = require_field(m, "github_ref_name") + github_run_id = int(require_field(m, "github_run_id")) + python_version = require_field(m, "python_version") + rocm_version = require_field(m, "rocm_version") + runner = require_field(m, "runner") + is_nightly = require_field(m, "is_nightly") + if is_nightly not in {"nightly", "continuous"}: + raise SystemExit(f"Invalid is_nightly value: {is_nightly}") + github_run_attempt = int(m.get("github_run_attempt") or 1) + github_run_number = ( + int(m["github_run_number"]) + if m.get("github_run_number") not in (None, "") + else None + ) + pkgs_json, jax_ver_guess = packages_json_and_jax_version(m.get("jax_packages_raw")) + whl_json = wheels_json(m.get("wheels_sha_raw")) + return { + "github_repository": github_repository, + "github_ref_name": github_ref_name, + "github_ref": m.get("github_ref"), + "github_event_name": m.get("github_event_name"), + "github_run_url": m.get("github_run_url"), + "github_sha": m.get("github_sha"), + "github_run_id": github_run_id, + "github_run_attempt": github_run_attempt, + "github_run_number": github_run_number, + "github_workflow": m.get("github_workflow"), + "github_job": m.get("github_job"), + "runner": runner, + "python_version": python_version, + "rocm_version": rocm_version, + "rocm_tag": m.get("rocm_tag"), + "gpu_count": m.get("gpu_count"), + "gpu_tag": gpu_tag, + "is_nightly": is_nightly, + "run_tag": run_tag, + "run_key": run_key, + "combo": combo, + "artifact_uri": artifact_uri, + "jax_version": m.get("jax_version") or jax_ver_guess, + "jax_commit": m.get("jax_commit"), + "xla_commit": m.get("xla_commit"), + "base_image_name": m.get("base_image_name"), + "base_image_digest": m.get("base_image_digest"), + "packages_json": pkgs_json, + "wheels_json": whl_json, + "run_started_at": parse_iso_dt(m.get("run_started_at")), + "run_completed_at": parse_iso_dt(m.get("run_completed_at")), + } + + +def extract_result_fields( + t: dict, +) -> Tuple[str, str, float, Optional[str], Optional[str]]: + """Extract core fields from a pytest test dict. + + Args: + t: Pytest test dict with keys 'nodeid', 'outcome', and optional 'call'. + + Returns: + Tuple (nodeid, outcome, duration, longrepr, message), where: + - longrepr: skip reason (tuple-string normalized if needed), truncated. + - message: crash message with normalized spaces, truncated (or None). + """ + nodeid = t["nodeid"] + outcome = t["outcome"] + call = t.get("call") or {} + duration = float(call.get("duration", 0.0)) + + longrepr_raw = call.get("longrepr") + if isinstance(longrepr_raw, str) and longrepr_raw: + longrepr_raw = extract_skip_reason(longrepr_raw) + longrepr = str(longrepr_raw)[:TEXT_LIMIT] if longrepr_raw is not None else None + + message = None + crash = call.get("crash") + if isinstance(crash, dict): + # Normalize excessive/irregular whitespace, then truncate. + raw_msg = crash.get("message", "") + msg = " ".join(str(raw_msg).split()) + message = msg[:TEXT_LIMIT] if msg else None + + return nodeid, outcome, duration, longrepr, message + + +# ----------------------------- +# Skip reason categorizer +# ----------------------------- +# Precompile skip categorization rules (regex etc.) once. +# categorize_reason() reuses them; lru_cache avoids recompute. +# Rules are evaluated in order - more specific rules should come before generic ones +_RULES_RAW = [ + # TPU-specific (checked first) + {"contains": "tpu", "label": "TPU-Only"}, + # Mosaic (check reason only, filename/testname checked separately) + {"contains": "mosaic", "label": "Mosaic"}, + # ROCm-specific checks + {"any": ["skip on rocm", "skip for rocm"], "label": "Not Supported on ROCm"}, + {"all": ["not supported on", "rocm"], "label": "Not Supported on ROCm"}, + {"contains": "is not available for rocm", "label": "Not Supported on ROCm"}, + # Multiple devices required (before generic "support" check) + {"all": [">=", "devices"], "label": "Multiple Devices Required"}, + {"all": ["test", "requires", "device"], "label": "Multiple Devices Required"}, + # NVIDIA-specific + { + "any": [ + "cuda", + "sm90", + "sm100a", + "sm80", + "cudnn", + "nvidia", + "cupy", + "capability", + ], + "label": "NVIDIA-Specific", + }, + {"contains": "at least", "label": "NVIDIA-Specific"}, + # Apple-specific + {"any": ["metal", "apple"], "label": "Apple-Specific"}, + # CPU-only tests + {"contains": "test enabled only for cpu", "label": "CPU-Only"}, + {"contains": "jax implements eig only on cpu", "label": "CPU-Only"}, + {"contains": "schur decomposition is only implemented on cpu", "label": "CPU-Only"}, + {"contains": "backend is not cpu", "label": "CPU-Only"}, + {"contains": "only for cpu", "label": "CPU-Only"}, + # Device inapplicability + {"contains": "x64", "label": "Device Inapplicability"}, + {"contains": "x32", "label": "Device Inapplicability"}, + { + "contains": "memories do not work on cpu and gpu backends yet", + "label": "Device Inapplicability", + }, + # Missing modules/plugins + {"contains": "magma is not installed", "label": "Missing Module/API/Plugin"}, + {"contains": "no module named", "label": "Missing Module/API/Plugin"}, + {"contains": "requires pytorch", "label": "Missing Module/API/Plugin"}, + {"contains": "requires tensorflow", "label": "Missing Module/API/Plugin"}, + { + "regex": re.compile(r"tests?\s+require?\s+(.+?)\s+plugin", re.I), + "label": "Missing Module/API/Plugin", + }, + # Memory Limit + {"contains": "memory size limit exceeded", "label": "Memory Limit Exceeded"}, + # Performance-related skips + {"contains": "too slow", "label": "Too Slow (Skipped Upstream)"}, + { + "contains": "skipping big tests under sanitizers due to slowdown", + "label": "Too Slow (Skipped Upstream)", + }, + # Maintenance-related skips + { + "any": ["unmaintained", "not maintained"], + "label": "Currently Unmaintained (Skipped Upstream)", + }, + # Generic "Skipped Upstream" checks (at the bottom) + {"contains": "dimension", "label": "Skipped Upstream"}, + {"contains": "not supported in interpret mode", "label": "Skipped Upstream"}, + {"contains": "not implemented", "label": "Skipped Upstream"}, + {"contains": "not relevant", "label": "Skipped Upstream"}, + {"contains": "support", "label": "Skipped Upstream"}, +] + +_CATEG_RULES = tuple(dict(r) for r in _RULES_RAW) + + +@lru_cache(maxsize=4096) +def categorize_reason(reason: Optional[str]) -> str: + """Map a skip reason to a category label. + + Matching is case- and whitespace-insensitive. Rules are evaluated in order; + the first match wins. Unknown/empty reasons fall back to DEFAULT_LABEL. + """ + if not reason: + return DEFAULT_LABEL + + s = " ".join(str(reason).split()).casefold() + + for rule in _CATEG_RULES: + if "contains" in rule and rule["contains"] in s: + return rule["label"] + if "any" in rule and any(k in s for k in rule["any"]): + return rule["label"] + if "all" in rule and all(k in s for k in rule["all"]): + return rule["label"] + if "regex" in rule and rule["regex"].search(s): + return rule["label"] + return DEFAULT_LABEL + + +# ----------------------------- +# DB Ops +# ----------------------------- +def connect(): + """Open a MySQL connection from environment variables.""" + return mysql.connector.connect( + host=os.environ["ROCM_JAX_DB_HOSTNAME"], + user=os.environ["ROCM_JAX_DB_USERNAME"], + password=os.environ["ROCM_JAX_DB_PASSWORD"], + database=os.environ["ROCM_JAX_DB_NAME"], + autocommit=False, + ) + + +def find_existing_run_id( # pylint: disable=too-many-arguments, too-many-positional-arguments + cur, + github_repository: str, + github_ref_name: str, + is_nightly: str, + run_key: str, + combo: str, +) -> Optional[int]: + """Return run id for an existing logical run, if present.""" + cur.execute( + """ + SELECT id + FROM jax_ci_runs + WHERE github_repository = %s + AND github_ref_name = %s + AND is_nightly = %s + AND run_key = %s + AND combo = %s + LIMIT 1 + """, + (github_repository, github_ref_name, is_nightly, run_key, combo), + ) + row = cur.fetchone() + return int(row[0]) if row else None + + +def insert_run(cur, report_created_at: Optional[datetime], fields: dict) -> int: + """Insert one row into jax_ci_runs and return run_id. + + Runs are treated as immutable; duplicate logical runs should be + detected before calling this function.""" + ingested_at = datetime.now(timezone.utc).replace(tzinfo=None) + if report_created_at is None: + report_created_at = ingested_at + else: + report_created_at = report_created_at.replace(tzinfo=None) + + cur.execute( + """ + INSERT INTO jax_ci_runs ( + report_created_at, run_started_at, run_completed_at, ingested_at, + github_repository, github_ref_name, github_ref, github_event_name, + github_run_url, github_sha, + github_run_id, github_run_attempt, github_run_number, + github_workflow, github_job, + runner, python_version, rocm_version, rocm_tag, + gpu_count, gpu_tag, is_nightly, + run_tag, run_key, combo, artifact_uri, + jax_version, jax_commit, xla_commit, + base_image_name, base_image_digest, + packages_json, wheels_json + ) VALUES ( + %s, %s, %s, %s, + %s, %s, %s, %s, + %s, %s, + %s, %s, %s, + %s, %s, + %s, %s, %s, %s, + %s, %s, %s, + %s, %s, %s, %s, + %s, %s, %s, + %s, %s, + %s, %s + ) + """, + ( + report_created_at, + fields["run_started_at"], + fields["run_completed_at"], + ingested_at, + fields["github_repository"], + fields["github_ref_name"], + fields["github_ref"], + fields["github_event_name"], + fields["github_run_url"], + fields["github_sha"], + fields["github_run_id"], + fields["github_run_attempt"], + fields["github_run_number"], + fields["github_workflow"], + fields["github_job"], + fields["runner"], + fields["python_version"], + fields["rocm_version"], + fields["rocm_tag"], + fields["gpu_count"], + fields["gpu_tag"], + fields["is_nightly"], + fields["run_tag"], + fields["run_key"], + fields["combo"], + fields["artifact_uri"], + fields["jax_version"], + fields["jax_commit"], + fields["xla_commit"], + fields["base_image_name"], + fields["base_image_digest"], + fields["packages_json"], + fields["wheels_json"], + ), + ) + return int(cur.lastrowid) + + +def sync_tests_and_get_ids(cur, tests: List[dict]) -> Dict[Tuple[str, str, str], int]: + """Ensure all tests exist in jax_ci_tests and return an ID mapping. + + Uses a TEMPORARY TABLE for efficiency with large runs: + 1) Bulk insert unique (filename, classname, test_name) into a temp table. + 2) INSERT any missing rows into jax_ci_tests in one set operation. + 3) SELECT back (file, class, test) -> id mapping in one query. + """ + uniq = {nodeid_parts(t["nodeid"]) for t in tests} + if not uniq: + return {} + + cur.execute("DROP TEMPORARY TABLE IF EXISTS tmp_tests_") + # fmt: off + cur.execute( + """ + CREATE TEMPORARY TABLE tmp_tests_ ( + filename VARCHAR(100) NOT NULL, + classname VARCHAR(100) NOT NULL, + test_name VARCHAR(500) NOT NULL, + PRIMARY KEY (filename, classname, test_name) + ) ENGINE=InnoDB + """ + ) + cur.executemany( + "INSERT IGNORE INTO tmp_tests_ (filename, classname, test_name) VALUES (%s,%s,%s)", + list(uniq), + ) + + cur.execute( + """ + INSERT INTO jax_ci_tests (filename, classname, test_name) + SELECT s.filename, s.classname, s.test_name + FROM tmp_tests_ s + LEFT JOIN jax_ci_tests t + ON t.filename = s.filename + AND t.classname = s.classname + AND t.test_name = s.test_name + WHERE t.id IS NULL + """ + ) + + cur.execute( + """ + SELECT t.id, s.filename, s.classname, s.test_name + FROM tmp_tests_ s + JOIN jax_ci_tests t + ON t.filename = s.filename + AND t.classname = s.classname + AND t.test_name = s.test_name + """ + ) + # fmt: on + return {(f, c, n): int(test_id) for (test_id, f, c, n) in cur.fetchall()} + + +def batch_insert_results(cur, rows) -> None: + """Bulk insert/update result rows in chunks. + + Uses ON DUPLICATE KEY UPDATE to keep results idempotent per (run_id, test_id). + """ + if not rows: + return + + sql = """ + INSERT INTO jax_ci_results + (run_id, test_id, outcome, duration, longrepr, message, skip_label) + VALUES (%s,%s,%s,%s,%s,%s,%s) + ON DUPLICATE KEY UPDATE + outcome=VALUES(outcome), + duration=VALUES(duration), + longrepr=VALUES(longrepr), + message=VALUES(message), + skip_label=VALUES(skip_label) + """ + for i in range(0, len(rows), BATCH_SIZE): + cur.executemany(sql, rows[i : i + BATCH_SIZE]) + + +# ----------------------------- +# Entry point +# ----------------------------- +def upload_pytest_results( # pylint: disable=too-many-locals + local_logs_dir: Path, + *, + run_tag: str, + gpu_tag: str, + artifact_uri: str, +) -> None: + """Load per-test JSON reports and upload results to MySQL. + + Flow: + 1) Parse JSONs and gather tests. + 2) Insert a jax_ci_runs row and get run_id. + 3) Ensure all tests exist; get (file,class,test) -> test_id map. + 4) Bulk insert/update jax_ci_results for this run. + """ + report = find_pytest_report_json(local_logs_dir) + if report is None: + report_created_at = None + tests = [] + else: + report_created_at, tests = load_from_pytest_json(report) + + manifest = load_manifest(local_logs_dir) + fields = build_run_fields( + manifest, + artifact_uri=artifact_uri, + run_tag=run_tag, + gpu_tag=gpu_tag, + ) + + conn = connect() + cur = conn.cursor() + try: + existing_run_id = find_existing_run_id( + cur, + fields["github_repository"], + fields["github_ref_name"], + fields["is_nightly"], + fields["run_key"], + fields["combo"], + ) + if existing_run_id is not None: + conn.rollback() + print( + "[DUPLICATE] run already exists: " + f"run_id={existing_run_id} " + f"repo={fields['github_repository']} " + f"ref={fields['github_ref_name']} " + f"is_nightly={fields['is_nightly']} " + f"run_key={fields['run_key']} " + f"combo={fields['combo']}" + ) + return + run_id = insert_run(cur, report_created_at, fields) + + rows = [] + test_id_map = {} + + if tests: + test_id_map = sync_tests_and_get_ids(cur, tests) + for t in tests: + nodeid, outcome, duration, longrepr, message = extract_result_fields(t) + f, c, n = nodeid_parts(nodeid) + test_id = test_id_map[(f, c, n)] + + # Categorize skip reason, with special check for Mosaic + # in filename/testname (including mgpu) + skip_label = None + if outcome == "skipped": + # Check if "mosaic" or "mgpu" is in filename or test name + if ( + "mosaic" in f.lower() + or "mosaic" in n.lower() + or "mgpu" in f.lower() + ): + skip_label = "Mosaic" + else: + skip_label = categorize_reason(longrepr) + + rows.append( + (run_id, test_id, outcome, duration, longrepr, message, skip_label) + ) + batch_insert_results(cur, rows) + conn.commit() + print( + f"[summary] run_id={run_id} total_results={len(rows)} unique_tests={len(test_id_map)}" + ) + # NOTE: optionally print Grafana dashboard URL, e.g. {URL}?var-run_id={id} + except MySQLError as e: + conn.rollback() + # INSERT may still hit a duplicate key (e.g. artifact_uri or logical identity) + # even if the earlier SELECT did not detect it. + if getattr(e, "errno", None) == 1062: + print( + "[DUPLICATE] insert hit unique constraint: " + f"repo={fields['github_repository']} " + f"ref={fields['github_ref_name']} " + f"is_nightly={fields['is_nightly']} " + f"run_key={fields['run_key']} " + f"combo={fields['combo']} " + f"artifact_uri={fields['artifact_uri']}" + ) + return + raise SystemExit(f"MySQL error: {e}") from e + except Exception: + conn.rollback() + raise + finally: + cur.close() + conn.close() + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments for pytest DB uploader.""" + p = argparse.ArgumentParser(description="Upload pytest report + manifest to MySQL") + p.add_argument("--local_logs_dir", required=True, help="Directory with JSON files") + p.add_argument("--run-tag", required=True, help="Run tag, e.g. ci-run") + p.add_argument("--gpu-tag", required=True, help="GPU architecture, e.g. MI350") + p.add_argument("--artifact_uri", required=True, help="Unique artifact path for CI") + return p.parse_args() + + +if __name__ == "__main__": + args = parse_args() + upload_pytest_results( + Path(args.local_logs_dir), + run_tag=args.run_tag, + gpu_tag=args.gpu_tag, + artifact_uri=args.artifact_uri, + ) From 71a78a10d959909a36335af98d8f0bb42a0b8e94 Mon Sep 17 00:00:00 2001 From: charleshofer Date: Tue, 10 Mar 2026 02:07:42 -0500 Subject: [PATCH 10/14] Add AWS CLI to manylinux image (#346) (cherry picked from commit fd00e03a69b5f3f0248ad33edfdd5f472dcccef9) --- docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm | 11 +++++++++++ .../manylinux/Dockerfile.jax-manylinux_2_28-therock | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm b/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm index a1ba7f71e0..d983de57fc 100644 --- a/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm +++ b/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm @@ -28,3 +28,14 @@ COPY ./docker/manylinux/clang.cfg /usr/lib/llvm-18/bin/clang++.cfg COPY ./docker/manylinux/clang.cfg /usr/lib/llvm-18/bin/clang.cfg COPY ./docker/manylinux/clang.cfg /opt/rocm/llvm/bin/clang++.cfg COPY ./docker/manylinux/clang.cfg /opt/rocm/llvm/bin/clang.cfg + +# Install AWS CLI v2 +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && \ + apt-get install -y --no-install-recommends curl unzip && \ + curl -fsSL "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o /tmp/awscliv2.zip && \ + unzip -q /tmp/awscliv2.zip -d /tmp && \ + /tmp/aws/install && \ + rm -rf /tmp/aws /tmp/awscliv2.zip && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + diff --git a/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock b/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock index 6a3b5f2032..41b3c4487d 100644 --- a/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock +++ b/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock @@ -29,3 +29,13 @@ COPY ./docker/manylinux/clang.cfg /usr/lib/llvm-18/bin/clang.cfg COPY ./docker/manylinux/clang.cfg /opt/rocm/llvm/bin/clang++.cfg COPY ./docker/manylinux/clang.cfg /opt/rocm/llvm/bin/clang.cfg +# Install AWS CLI v2 +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && \ + apt-get install -y --no-install-recommends curl unzip && \ + curl -fsSL "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o /tmp/awscliv2.zip && \ + unzip -q /tmp/awscliv2.zip -d /tmp && \ + /tmp/aws/install && \ + rm -rf /tmp/aws /tmp/awscliv2.zip && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + From ec0c319ffa1b06679b0f71ed1db47651bf71ad40 Mon Sep 17 00:00:00 2001 From: Pakize Sanal Date: Tue, 10 Mar 2026 12:43:12 -0500 Subject: [PATCH 11/14] Update default FILTER_REPO for scheduled S3 search (#347) (cherry picked from commit 1c21e421fe79531f50c819e1e542a551ec7a3fa4) --- ci/ingest_jax_ci_logs.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/ingest_jax_ci_logs.sh b/ci/ingest_jax_ci_logs.sh index 32e99eadae..60f08f48ec 100755 --- a/ci/ingest_jax_ci_logs.sh +++ b/ci/ingest_jax_ci_logs.sh @@ -2,7 +2,7 @@ set -euo pipefail : "${INPUT_GPU_ARCH:?}" -: "${FILTER_REPO:=ROCm/jax}" +: "${FILTER_REPO:=jax-ml/jax}" : "${FILTER_RUN_ID:=}" ROOT="jax-ci-test-logs/${FILTER_REPO}/" From 3ed721791a4bca35c4c02c1a5c1efe7667d3a0c1 Mon Sep 17 00:00:00 2001 From: Alexandros Theodoridis Date: Tue, 10 Mar 2026 18:55:48 +0100 Subject: [PATCH 12/14] Fix manylinux base image builds (#348) (cherry picked from commit 1c21914d9d27440d608396bae2834cf10f7448bf) --- docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm | 7 +++---- docker/manylinux/Dockerfile.jax-manylinux_2_28-therock | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm b/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm index d983de57fc..409374e280 100644 --- a/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm +++ b/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm @@ -30,12 +30,11 @@ COPY ./docker/manylinux/clang.cfg /opt/rocm/llvm/bin/clang++.cfg COPY ./docker/manylinux/clang.cfg /opt/rocm/llvm/bin/clang.cfg # Install AWS CLI v2 -RUN --mount=type=cache,target=/var/cache/apt \ - apt-get update && \ - apt-get install -y --no-install-recommends curl unzip && \ +RUN --mount=type=cache,target=/var/cache/dnf \ + dnf install -y curl unzip && \ curl -fsSL "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o /tmp/awscliv2.zip && \ unzip -q /tmp/awscliv2.zip -d /tmp && \ /tmp/aws/install && \ rm -rf /tmp/aws /tmp/awscliv2.zip && \ - apt-get clean && rm -rf /var/lib/apt/lists/* + dnf clean all diff --git a/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock b/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock index 41b3c4487d..cd25ab964c 100644 --- a/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock +++ b/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock @@ -30,12 +30,11 @@ COPY ./docker/manylinux/clang.cfg /opt/rocm/llvm/bin/clang++.cfg COPY ./docker/manylinux/clang.cfg /opt/rocm/llvm/bin/clang.cfg # Install AWS CLI v2 -RUN --mount=type=cache,target=/var/cache/apt \ - apt-get update && \ - apt-get install -y --no-install-recommends curl unzip && \ +RUN --mount=type=cache,target=/var/cache/dnf \ + dnf install -y curl unzip && \ curl -fsSL "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o /tmp/awscliv2.zip && \ unzip -q /tmp/awscliv2.zip -d /tmp && \ /tmp/aws/install && \ rm -rf /tmp/aws /tmp/awscliv2.zip && \ - apt-get clean && rm -rf /var/lib/apt/lists/* + dnf clean all From 7ba724bb5474c08c2f81aea01c56d44ec477fee5 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 10 Mar 2026 15:45:58 -0500 Subject: [PATCH 13/14] reduce wheel size - use --offload-compress --- jax_rocm_plugin/.bazelrc | 1 + jax_rocm_plugin/build/rocm/jax.bazelrc | 1 + jax_rocm_plugin/build/rocm/tools/build_wheels.py | 1 + 3 files changed, 3 insertions(+) diff --git a/jax_rocm_plugin/.bazelrc b/jax_rocm_plugin/.bazelrc index d3bbea3341..bf353a5a04 100644 --- a/jax_rocm_plugin/.bazelrc +++ b/jax_rocm_plugin/.bazelrc @@ -99,6 +99,7 @@ build:rocm --action_env=TF_ROCM_CLANG="1" build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:rocm --copt=-Wno-gnu-offsetof-extensions build:rocm --copt=-Qunused-arguments +build:rocm --action_env=HIPCC_COMPILE_FLAGS_APPEND=--offload-compress build:rocm --action_env=TF_HIPCC_CLANG="1" build:asan --strip=never diff --git a/jax_rocm_plugin/build/rocm/jax.bazelrc b/jax_rocm_plugin/build/rocm/jax.bazelrc index 52e4abbdfd..92993337fb 100644 --- a/jax_rocm_plugin/build/rocm/jax.bazelrc +++ b/jax_rocm_plugin/build/rocm/jax.bazelrc @@ -10,6 +10,7 @@ build:rocm --config=native_arch_posix build:rocm --action_env=CLANG_COMPILER_PATH="/lib/llvm-18/bin/clang-18" build:rocm --action_env=ROCM_PATH="/opt/rocm" build:rocm --repo_env=TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942,gfx950" +build:rocm --action_env=HIPCC_COMPILE_FLAGS_APPEND=--offload-compress build:rocm --repo_env=HERMETIC_PYTHON_VERSION=3.12 build:rocm_mgpu --test_tag_filters=jax_test_gpu,multiaccelerator,-config-cuda-only,-manual diff --git a/jax_rocm_plugin/build/rocm/tools/build_wheels.py b/jax_rocm_plugin/build/rocm/tools/build_wheels.py index 5246b01d29..30917e9223 100644 --- a/jax_rocm_plugin/build/rocm/tools/build_wheels.py +++ b/jax_rocm_plugin/build/rocm/tools/build_wheels.py @@ -180,6 +180,7 @@ def build_plugin_wheel( "--output_path=%s" % output_dir, # Use roctracer (v1) instead of rocprofiler-sdk (v3) for profiling. "--bazel_options=--define=xla_rocm_profiler=v1", + "--bazel_options=--action_env=HIPCC_COMPILE_FLAGS_APPEND=--offload-compress", ] # Add clang path if clang is used. From b93ccb82d10e050c5cfff68331884e0317f8484d Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 10 Mar 2026 15:46:35 -0500 Subject: [PATCH 14/14] reduce wheel size - use generic targets --- ci/jax_rbe/pr_test.sh | 2 +- docker/Dockerfile.base-ubu24 | 2 +- .../Dockerfile.jax-manylinux_2_28-rocm | 2 +- .../Dockerfile.jax-manylinux_2_28-therock | 2 +- jax_rocm_plugin/.bazelrc | 4 ++-- jax_rocm_plugin/build/build.py | 17 +---------------- jax_rocm_plugin/build/rocm/jax.bazelrc | 2 +- jax_rocm_plugin/build/rocm/setup.rocm.sh | 2 +- stack.py | 2 +- 9 files changed, 10 insertions(+), 25 deletions(-) diff --git a/ci/jax_rbe/pr_test.sh b/ci/jax_rbe/pr_test.sh index 2df75dca06..6e59a2e9fc 100755 --- a/ci/jax_rbe/pr_test.sh +++ b/ci/jax_rbe/pr_test.sh @@ -47,7 +47,7 @@ python3 build/build.py build --wheels=jax-rocm-plugin --configure_only --python_ --config=rocm_rbe \ --noremote_accept_cached \ --//jax:build_jaxlib=false \ - --action_env=TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" \ + --action_env=TF_ROCM_AMDGPU_TARGETS="gfx9-generic,gfx9-4-generic,gfx1030,gfx11-generic,gfx12-generic" \ --test_verbose_timeout_warnings \ --test_output=errors \ //tests:core_test_gpu \ diff --git a/docker/Dockerfile.base-ubu24 b/docker/Dockerfile.base-ubu24 index 9251a5f52e..f2a8cfd6f4 100644 --- a/docker/Dockerfile.base-ubu24 +++ b/docker/Dockerfile.base-ubu24 @@ -2,7 +2,7 @@ FROM ubuntu:24.04 ### Container Build Arguments: # The list of target devices to be supported by the JAX ROCm plugin and pjrt. -ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx942 gfx950 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201" +ARG GPU_DEVICE_TARGETS="gfx9-generic gfx9-4-generic gfx1030 gfx11-generic gfx12-generic" # The ROCm version to be used inside the container. ARG ROCM_VERSION # The installation path for ROCm. diff --git a/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm b/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm index 409374e280..fb644353d0 100644 --- a/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm +++ b/docker/manylinux/Dockerfile.jax-manylinux_2_28-rocm @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux_2_28_x86_64 ARG ROCM_VERSION ARG ROCM_BUILD_JOB ARG ROCM_BUILD_NUM -ENV GPU_DEVICE_TARGETS="gfx906 gfx908 gfx90a gfx942 gfx950 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201" +ENV GPU_DEVICE_TARGETS="gfx9-generic gfx9-4-generic gfx1030 gfx11-generic gfx12-generic" # Install patchelf and headers for numactl RUN --mount=type=cache,target=/var/cache/dnf \ diff --git a/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock b/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock index cd25ab964c..328a53cf63 100644 --- a/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock +++ b/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux_2_28_x86_64 ARG ROCM_VERSION ARG THEROCK_PATH -ENV GPU_DEVICE_TARGETS="gfx906 gfx908 gfx90a gfx942 gfx950 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201" +ENV GPU_DEVICE_TARGETS="gfx9-generic gfx9-4-generic gfx1030 gfx11-generic gfx12-generic" # Install patchelf and headers for numactl RUN --mount=type=cache,target=/var/cache/dnf \ diff --git a/jax_rocm_plugin/.bazelrc b/jax_rocm_plugin/.bazelrc index bf353a5a04..fbae3912b3 100644 --- a/jax_rocm_plugin/.bazelrc +++ b/jax_rocm_plugin/.bazelrc @@ -91,7 +91,7 @@ build:rocm_base --config=clang_local build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true build:rocm_base --repo_env TF_NEED_ROCM=1 -build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" +build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx9-generic,gfx9-4-generic,gfx1030,gfx11-generic,gfx12-generic" # Build with hipcc for ROCm and clang for the host. build:rocm --config=rocm_base @@ -119,7 +119,7 @@ build:asan --linkopt="-lclang_rt.asan-x86_64" build:asan --linkopt="-lclang_rt.asan_cxx-x86_64" build:asan --//build/rocm:sanitizer=asan build:asan --run_under=//build/rocm:sanitizer_wrapper -build:asan --action_env TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942" +build:asan --action_env TF_ROCM_AMDGPU_TARGETS="gfx9-generic,gfx9-4-generic" ############################################################################# # Configuration for running RBE builds and tests diff --git a/jax_rocm_plugin/build/build.py b/jax_rocm_plugin/build/build.py index b8cb454b19..4dc03cd564 100755 --- a/jax_rocm_plugin/build/build.py +++ b/jax_rocm_plugin/build/build.py @@ -240,7 +240,7 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): rocm_group.add_argument( "--rocm_amdgpu_targets", type=str, - default="gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201", + default="gfx9-generic,gfx9-4-generic,gfx1030,gfx11-generic,gfx12-generic", help="A comma-separated list of ROCm amdgpu targets to support.", ) @@ -622,21 +622,6 @@ async def main(): f'--action_env=ROCM_PATH="{args.rocm_path}"' ) if args.rocm_amdgpu_targets: - rocm_version_str = get_rocm_version(args.rocm_path) - rocm_version = ( - tuple(map(int, rocm_version_str.split("."))) - if rocm_version_str - else None - ) - - targets = args.rocm_amdgpu_targets.split(",") - if rocm_version and rocm_version < (7, 0, 0): - if "gfx950" in targets: - logging.debug("Removing gfx950 since ROCm version is < 7.0.0") - targets.remove("gfx950") - - args.rocm_amdgpu_targets = ",".join(targets) - logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) wheel_build_command_base.append( f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" diff --git a/jax_rocm_plugin/build/rocm/jax.bazelrc b/jax_rocm_plugin/build/rocm/jax.bazelrc index 92993337fb..3e2832c550 100644 --- a/jax_rocm_plugin/build/rocm/jax.bazelrc +++ b/jax_rocm_plugin/build/rocm/jax.bazelrc @@ -9,7 +9,7 @@ build:rocm --config=mkl_open_source_only build:rocm --config=native_arch_posix build:rocm --action_env=CLANG_COMPILER_PATH="/lib/llvm-18/bin/clang-18" build:rocm --action_env=ROCM_PATH="/opt/rocm" -build:rocm --repo_env=TF_ROCM_AMDGPU_TARGETS="gfx908,gfx90a,gfx942,gfx950" +build:rocm --repo_env=TF_ROCM_AMDGPU_TARGETS="gfx9-generic,gfx9-4-generic" build:rocm --action_env=HIPCC_COMPILE_FLAGS_APPEND=--offload-compress build:rocm --repo_env=HERMETIC_PYTHON_VERSION=3.12 diff --git a/jax_rocm_plugin/build/rocm/setup.rocm.sh b/jax_rocm_plugin/build/rocm/setup.rocm.sh index 11fd5948d6..f980eed0df 100755 --- a/jax_rocm_plugin/build/rocm/setup.rocm.sh +++ b/jax_rocm_plugin/build/rocm/setup.rocm.sh @@ -94,6 +94,6 @@ echo "$ROCM_PATH" echo "$GPU_DEVICE_TARGETS" # Ensure the ROCm target list is set up -GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS:-"gfx908 gfx90a gfx942 gfx950 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"} +GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS:-"gfx9-generic gfx9-4-generic gfx1030 gfx11-generic gfx12-generic"} printf '%s\n' "${GPU_DEVICE_TARGETS}" | tee -a "$ROCM_PATH/bin/target.lst" touch "${ROCM_PATH}/.info/version" diff --git a/stack.py b/stack.py index 797c6480e2..669587501f 100644 --- a/stack.py +++ b/stack.py @@ -17,7 +17,7 @@ MAKE_TEMPLATE = r""" # gfx targets for which XLA and jax custom call kernels are built for -# AMDGPU_TARGETS ?= "gfx908,gfx90a,gfx942,gfx950,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201" +# AMDGPU_TARGETS ?= "gfx9-generic,gfx9-4-generic,gfx1030,gfx11-generic,gfx12-generic" # customize to a single arch for local dev builds to reduce compile time AMDGPU_TARGETS ?= "$(shell rocminfo | grep -o -m 1 'gfx.*')"