43 commits
2faed86
Remove nvidia_wheel_versions
charleshofer Nov 12, 2025
6ba78d2
Make jaxlib targets visible
charleshofer Nov 12, 2025
d3f32d9
hipblas typedef fix
charleshofer Nov 12, 2025
e874827
No GPU fail
charleshofer Nov 13, 2025
a3338bb
Wrap HIP inline functions in anonymous namespaces in vendor.h
mminutoli Feb 12, 2026
ce43299
SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError
dsicarov-amd Jun 10, 2025
c3e2357
Add shared utility function get_rocm_version to test_util.py
charleshofer Nov 14, 2025
d6acb72
Fix hipSparse CSR algorithm mappings for ROCm 7
phambinhfin Nov 17, 2025
c3cf5d3
Fix v_pages quantization and adjust test params for ROCm compatibilit…
phambinhfin Nov 19, 2025
04b3d82
Address LLVM assertion failure due to a multithreaded use. Update .gi…
Arech8 Nov 26, 2025
d9103b3
Add skip of test_is_finite() on Cuda (#565)
Arech8 Nov 26, 2025
3a63ef3
Add rocm test requirements file (#570)
AratiGanesh Dec 15, 2025
11a084b
Let the unit tests use build.py for setting up Bazel commands for uni…
charleshofer Dec 15, 2025
5a2d899
adding abort logic to rocm/jax (#590)
gulsumgudukbay Jan 13, 2026
becc59e
Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (…
phambinhfin Jan 14, 2026
cbfb842
Fix shared memory limit check for ROCm in test_dot (#596)
phambinhfin Jan 14, 2026
c979baf
Fix Numpy signatures test (#598)
magaonka-amd Jan 14, 2026
e5fddf8
fix merge arts
Ruturaj4 Jan 18, 2026
ce0783a
Enable RngShardingTests (#644)
gulsumgudukbay Jan 22, 2026
472b227
Enable test_variadic_reduce_window on ROCm (#647)
mminutoli Feb 12, 2026
69381b1
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
2988d2a
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
3f3518f
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
2045fd0
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
b798b3b
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
c7ccce2
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
2e76cae
Enable testMultivariateNormalSingularCovariance on ROCm (#666)
AratiGanesh Jan 28, 2026
25bc14d
Skip test_tridiagonal_solve on ROCm due to hipSPARSE numerical errors…
AratiGanesh Jan 28, 2026
55c86bc
Update Skip Reason Outputs (#663)
gulsumgudukbay Jan 28, 2026
87d7111
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
3f26dfe
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
9df214c
Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677)
magaonka-amd Jan 29, 2026
7382469
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
f3efc4d
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
991bc2f
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
abfda2d
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
30f61f4
Add ROCm encoding for test_struct_encoding_determinism (#683)
AratiGanesh Feb 5, 2026
c632678
Remove 'mean' from unsupported params for jnp.var (#689)
magaonka-amd Feb 6, 2026
cd072cd
Implement approx_tanh for ROCm using OCML tanh function (#691)
magaonka-amd Feb 6, 2026
90a3578
Skipping testEighTinyNorm due to hipSolver issues (#697)
AratiGanesh Feb 9, 2026
f7a6407
Abort detection CI workflow (#688)
gulsumgudukbay Feb 20, 2026
1604453
Implement Mosaic GPU detection and Auto-Skips
gulsumgudukbay Feb 21, 2026
95a4d4a
Fix indentation for adding pytest marker
gulsumgudukbay Feb 21, 2026
162 changes: 162 additions & 0 deletions .github/workflows/pytest_rocm_abort.yml
@@ -0,0 +1,162 @@
# CI - Pytest ROCm (Abort Support)
#
# This workflow runs the ROCm tests with Pytest in ROCm GHCR containers,
# using the ROCm `pytest-abort` retry wrapper to detect/retry aborts/crashes.
#
# It can be triggered manually via workflow_dispatch or called by other workflows
# via workflow_call.
#
# It consists of the following job:
# run-tests:
#    - Runs in ROCm container (ghcr.io/rocm/jax-base-ubu24.rocm*:latest)
#    - Downloads the JAX and jaxlib wheels from GCS, and ROCm plugins from latest release.
#    - Executes the `run_pytest_rocm_abort.sh` script, which installs wheel artifacts and
#      runs the ROCm tests with Pytest under `pytest-abort-retry`.
name: CI - Pytest ROCm (Abort Support)

on:
  workflow_dispatch:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: choice
        default: "linux-x86-64-4gpu-amd"
        options:
          - "linux-x86-64-1gpu-amd"
          - "linux-x86-64-4gpu-amd"
          - "linux-x86-64-8gpu-amd"
      python:
        description: "Which Python version to use?"
        type: choice
        default: "3.11"
        options:
          - "3.11"
          - "3.12"
      rocm-version:
        description: "Which ROCm version to test?"
        type: choice
        default: "7.2.0"
        options:
          - "7.2.0"
      rocm-tag:
        description: "ROCm tag for container image (e.g., rocm720)"
        type: string
        default: "rocm720"
      jaxlib-version:
        description: "Which jaxlib version to use? (head/pypi_latest)"
        type: choice
        default: "head"
        options:
          - "head"
          - "pypi_latest"
      skip-download-jaxlib-and-plugins-from-gcs:
        description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g., for testing a jax-only release)"
        type: choice
        default: '0'
        options:
          - '0'
          - '1'
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        type: string
        default: 'gs://jax-nightly-artifacts/latest'
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: string
        default: 'no'
  workflow_call:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        default: "linux-x86-64-4gpu-amd"
      python:
        description: "Which Python version to use?"
        type: string
        default: "3.11"
      rocm-version:
        description: "Which ROCm version to test?"
        type: string
        default: "7.2.0"
      rocm-tag:
        description: "ROCm tag for container image (e.g., rocm720)"
        type: string
        default: "rocm720"
      jaxlib-version:
        description: "Which jaxlib version to use? (head/pypi_latest)"
        type: string
        default: "head"
      skip-download-jaxlib-and-plugins-from-gcs:
        description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g., for testing a jax-only release)"
        default: '0'
        type: string
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        default: 'gs://jax-nightly-artifacts/latest'
        type: string

permissions: {}

env:
  UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"

jobs:
  run-tests:
    defaults:
      run:
        # Set the shell to bash as GitHub actions run with /bin/sh by default
        shell: bash
    runs-on: ${{ inputs.runner }}
    continue-on-error: true
    # Run in ROCm GHCR container with GPU access
    container:
      image: ghcr.io/rocm/jax-base-ubu24.${{ inputs.rocm-tag }}:latest
      credentials:
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --shm-size 64G --env-file /etc/podinfo/gha-gpu-isolation-settings
    name: "${{ (contains(inputs.runner, '1gpu') && '1gpu') ||
      (contains(inputs.runner, '4gpu') && '4gpu') ||
      (contains(inputs.runner, '8gpu') && '8gpu') }}, ROCm ${{ inputs.rocm-version }}, py${{ inputs.python }}"

    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_ENABLE_X64: "0"

    steps:
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
        with:
          persist-credentials: false
      - name: Download JAX ROCm wheels
        uses: ./.github/actions/download-jax-rocm-wheels
        with:
          python: ${{ inputs.python }}
          rocm-version: ${{ inputs.rocm-version }}
          jaxlib-version: ${{ inputs.jaxlib-version }}
          skip-download-jaxlib-and-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }}
          gcs_download_uri: ${{ inputs.gcs_download_uri }}
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: Install Python dependencies
        run: |
          $JAXCI_PYTHON -m pip install uv~=0.5.30
          $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Pytest ROCm tests (abort support)
        timeout-minutes: 180
        run: ./ci/run_pytest_rocm_abort.sh
      - name: Upload pytest results to artifact
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: logs_abort
          path: |
            logs_abort/
          if-no-files-found: warn
          retention-days: 2
          overwrite: true
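The job's `name` and container `image` in pytest_rocm_abort.yml are both derived from the workflow inputs. The following is a minimal Python sketch of that derivation, not code from the PR: the function names `gpu_label`, `job_name`, and `container_image` are hypothetical, and the sketch does not model what the expression renders when the runner label contains none of the three GPU markers.

```python
from typing import Optional

# Hypothetical helpers that mirror the expressions in pytest_rocm_abort.yml.
# They are illustrative only and do not exist in the repository.

GPU_LABELS = ("1gpu", "4gpu", "8gpu")


def gpu_label(runner: str) -> Optional[str]:
    """Mirror the chained `contains(inputs.runner, 'Ngpu') && 'Ngpu' || ...`.

    GitHub's `contains` ignores case for strings, so the runner label is
    lowercased before matching. Returns None when nothing matches; the
    workflow's own fall-through rendering is not modelled here.
    """
    lowered = runner.lower()
    for label in GPU_LABELS:
        if label in lowered:
            return label
    return None


def job_name(runner: str, rocm_version: str, python: str) -> str:
    """Build the display name: '<gpus>, ROCm <version>, py<python>'."""
    return f"{gpu_label(runner)}, ROCm {rocm_version}, py{python}"


def container_image(rocm_tag: str) -> str:
    """Build the GHCR image reference from the `rocm-tag` input."""
    return f"ghcr.io/rocm/jax-base-ubu24.{rocm_tag}:latest"


if __name__ == "__main__":
    print(job_name("linux-x86-64-4gpu-amd", "7.2.0", "3.11"))
    print(container_image("rocm720"))
```

With the workflow defaults, the sketch yields the job name `4gpu, ROCm 7.2.0, py3.11` and the image `ghcr.io/rocm/jax-base-ubu24.rocm720:latest`.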
53 changes: 53 additions & 0 deletions .github/workflows/wheel_tests_nightly_release_abort.yml
@@ -0,0 +1,53 @@
# CI - Wheel Tests (Nightly/Release) (ROCm abort only)
#
# This workflow runs only the ROCm wheel tests using the abort/retry wrapper workflow.
name: CI - Wheel Tests (Nightly/Release) (ROCm abort only)

on:
  workflow_dispatch:
    inputs:
      gcs_download_uri:
        description: "GCS location URI from where the artifacts should be downloaded"
        required: true
        default: 'gs://jax-nightly-artifacts/latest'
        type: string
      skip-download-jaxlib-and-plugins-from-gcs:
        description: "Whether to download only the jax wheel from GCS (e.g., for testing a jax-only release)"
        required: true
        default: '0'
        type: string
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection? (yes/no)'
        required: false
        default: 'no'
        type: string

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
permissions: {}

env:
  UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"

jobs:
  run-pytest-rocm:
    uses: ./.github/workflows/pytest_rocm_abort.yml
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        runner: ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"]
        python: ["3.11", "3.12", "3.13", "3.14"]
        rocm: [
          {version: "7.2.0", tag: "rocm720"},
        ]
    name: "Pytest ROCm abort (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      rocm-version: ${{ matrix.rocm.version }}
      rocm-tag: ${{ matrix.rocm.tag }}
      jaxlib-version: "head"
      skip-download-jaxlib-and-plugins-from-gcs: ${{inputs.skip-download-jaxlib-and-plugins-from-gcs}}
      gcs_download_uri: ${{inputs.gcs_download_uri}}
      halt-for-connection: ${{inputs.halt-for-connection}}
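The `strategy.matrix` in wheel_tests_nightly_release_abort.yml is a cross product of its three axes, and each combination becomes one call to the reusable workflow. A rough Python sketch of that expansion follows, assuming plain cross-product semantics with no `include`/`exclude` entries (none appear in the diff). The names `expand_matrix` and `artifacts_label` are hypothetical.

```python
from itertools import product

# Matrix axes copied from wheel_tests_nightly_release_abort.yml.
RUNNERS = [
    "linux-x86-64-1gpu-amd",
    "linux-x86-64-4gpu-amd",
    "linux-x86-64-8gpu-amd",
]
PYTHONS = ["3.11", "3.12", "3.13", "3.14"]
ROCM = [{"version": "7.2.0", "tag": "rocm720"}]


def expand_matrix():
    """Return one dict of `with:` inputs per matrix combination."""
    return [
        {
            "runner": runner,
            "python": python,
            "rocm-version": rocm["version"],
            "rocm-tag": rocm["tag"],
        }
        for runner, python, rocm in product(RUNNERS, PYTHONS, ROCM)
    ]


def artifacts_label(ref_name: str) -> str:
    """Mirror `startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly'`.

    GitHub's `startsWith` ignores case, hence the lowercasing.
    """
    if ref_name.lower().startswith("release/"):
        return "latest release"
    return "nightly"


if __name__ == "__main__":
    jobs = expand_matrix()
    print(len(jobs), "jobs")  # 3 runners x 4 Pythons x 1 ROCm entry
    print(jobs[0])
    print(artifacts_label("main"))
```

Under these assumptions the matrix produces 12 jobs, and `fail-fast: false` lets the remaining ones continue when one fails.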
4 changes: 4 additions & 0 deletions .gitignore
@@ -31,3 +31,7 @@ jax.iml
/include/
/lib/
/share/

/compile_commands.json
/strace.txt
/external
4 changes: 4 additions & 0 deletions build/build.py
@@ -638,6 +638,10 @@ async def main():
      )

  if "rocm" in args.wheels:
    if not args.configure_only:
      print("ERROR: This repo is not used for building the ROCm JAX plugins. Please use the new plugin repo: https://github.com/ROCm/rocm-jax")
      exit(1)

    wheel_build_command_base.append("--config=rocm_base")
    wheel_build_command_base.append("--config=rocm")
    if clang_local:
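The build.py hunk above adds a guard: a ROCm wheel request is rejected unless the script is run in configure-only mode. Below is a self-contained sketch of that decision, separated from argument parsing so it can be exercised directly. The function `rocm_build_error` is hypothetical, and the sketch assumes `wheels` is the comma-separated string form of the option, so `"rocm" in wheels` is a substring test.

```python
import sys
from typing import Optional

ROCM_PLUGIN_REPO = "https://github.com/ROCm/rocm-jax"


def rocm_build_error(wheels: str, configure_only: bool) -> Optional[str]:
    """Return the error text when a real ROCm wheel build is requested.

    Hypothetical helper: `wheels` is assumed to be a comma-separated string
    such as "jaxlib,jax-rocm-plugin". Returns None when the request is
    allowed (no ROCm wheel, or configure-only mode).
    """
    if "rocm" in wheels and not configure_only:
        return (
            "ERROR: This repo is not used for building the ROCm JAX plugins. "
            f"Please use the new plugin repo: {ROCM_PLUGIN_REPO}"
        )
    return None


def main(wheels: str, configure_only: bool) -> int:
    """Print the error and return a non-zero status, as the guard does."""
    error = rocm_build_error(wheels, configure_only)
    if error is not None:
        print(error)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main("jaxlib,jax-rocm-plugin", configure_only=False))
```

Returning a status code from `main` rather than calling `exit` inside the check is a choice made for this sketch so the guard can be tested without terminating the interpreter.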
24 changes: 24 additions & 0 deletions build/rocm-test-requirements.txt
@@ -0,0 +1,24 @@
absl-py
build
cloudpickle
colorama>=0.4.4
filelock
flatbuffers
hypothesis
mpmath>=1.3
pillow>=10.4.0
# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t
portpicker; python_version<"3.13"
pytest-xdist
pytest-json-report
pytest-html
pytest-csv
pytest-rerunfailures
pytest-html-merger
pytest-reportlog
wheel
rich
setuptools
matplotlib
opt-einsum
auditwheel
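The `portpicker; python_version<"3.13"` entry in build/rocm-test-requirements.txt uses an environment marker, so the package is only selected on interpreters older than 3.13. A small standard-library sketch of that comparison follows. The function `portpicker_selected` is hypothetical and only approximates the marker by comparing the (major, minor) pair; an installer evaluates the full marker grammar itself.

```python
import sys
from typing import Sequence


def portpicker_selected(version_info: Sequence[int] = sys.version_info) -> bool:
    """Approximate the marker `python_version < "3.13"`.

    `python_version` is the "major.minor" string, so only the first two
    components take part in the comparison.
    """
    return tuple(version_info[:2]) < (3, 13)


if __name__ == "__main__":
    for version in [(3, 11), (3, 12), (3, 13), (3, 14)]:
        print(version, portpicker_selected(version))
```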