44 commits
097ff78
Remove nvidia_wheel_versions
charleshofer Nov 12, 2025
19a87c5
Make jaxlib targets visible
charleshofer Nov 12, 2025
35b2368
hipblas typedef fix
charleshofer Nov 12, 2025
9bf2dbf
No GPU fail
charleshofer Nov 13, 2025
7d1708e
Wrap HIP inline functions in anonymous namespaces in vendor.h
mminutoli Feb 12, 2026
30d7f94
SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError
dsicarov-amd Jun 10, 2025
a5377e5
Add shared utility function get_rocm_version to test_util.py
charleshofer Nov 14, 2025
db30afa
Fix hipSparse CSR algorithm mappings for ROCm 7
phambinhfin Nov 17, 2025
a44f942
Fix v_pages quantization and adjust test params for ROCm compatibilit…
phambinhfin Nov 19, 2025
01746ea
Address LLVM assertion failure due to a multithreaded use. Update .gi…
Arech8 Nov 26, 2025
f555563
Add skip of test_is_finite() on Cuda (#565)
Arech8 Nov 26, 2025
8cf787a
Add rocm test requirements file (#570)
AratiGanesh Dec 15, 2025
17e6022
Let the unit tests use build.py for setting up Bazel commands for uni…
charleshofer Dec 15, 2025
b600136
adding abort logic to rocm/jax (#590)
gulsumgudukbay Jan 13, 2026
02399d0
Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (…
phambinhfin Jan 14, 2026
0959b0f
Fix shared memory limit check for ROCm in test_dot (#596)
phambinhfin Jan 14, 2026
b43ca18
Fix Numpy signatures test (#598)
magaonka-amd Jan 14, 2026
cdb5bcb
fix merge arts
Ruturaj4 Jan 18, 2026
de1ef41
Enable RngShardingTests (#644)
gulsumgudukbay Jan 22, 2026
d8179cd
Enable test_variadic_reduce_window on ROCm (#647)
mminutoli Feb 12, 2026
c5016ef
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
4e6626e
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
694e861
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
12e07fb
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
76e576f
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
237e5ad
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
da3a3cc
Enable testMultivariateNormalSingularCovariance on ROCm (#666)
AratiGanesh Jan 28, 2026
06d459e
Skip test_tridiagonal_solve on ROCm due to hipSPARSE numerical errors…
AratiGanesh Jan 28, 2026
c30a449
Update Skip Reason Outputs (#663)
gulsumgudukbay Jan 28, 2026
58ce4e1
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
e8307d2
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
64ee74e
Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677)
magaonka-amd Jan 29, 2026
f144132
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
9dd1698
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
fd1195e
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
4af5327
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
bfe0208
Add ROCm encoding for test_struct_encoding_determinism (#683)
AratiGanesh Feb 5, 2026
44b8a6c
Remove 'mean' from unsupported params for jnp.var (#689)
magaonka-amd Feb 6, 2026
7bd4a13
Implement approx_tanh for ROCm using OCML tanh function (#691)
magaonka-amd Feb 6, 2026
95ae9fa
Skipping testEighTinyNorm due to hipSolver issues (#697)
AratiGanesh Feb 9, 2026
e355fcd
Abort detection CI workflow (#688)
gulsumgudukbay Feb 20, 2026
d36ebc2
Abort-Detection: Fix halt-for-connection input (#712)
gulsumgudukbay Feb 24, 2026
e793527
Drop call-id input from reusable workflow
psanal35 Feb 25, 2026
c58f3e9
Add upload_rocm_logs.sh to push CI logs and manifest to S3
psanal35 Feb 28, 2026
33 changes: 32 additions & 1 deletion .github/workflows/pytest_rocm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,16 @@ on:
description: "GCS location prefix from where the artifacts should be downloaded"
default: 'gs://jax-nightly-artifacts/latest'
type: string
permissions: {}
secrets:
AWS_ACCESS_KEY_ID:
required: true
AWS_SECRET_ACCESS_KEY:
required: true
S3_BUCKET_NAME:
required: true
permissions:
actions: read
contents: read

env:
UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"
Expand Down Expand Up @@ -148,3 +157,25 @@ jobs:
      - name: Run Pytest ROCm tests
        timeout-minutes: 120
        run: ./ci/run_pytest_rocm.sh
      - name: Archive test logs
        if: always()
        run: |
          set -euo pipefail
          tar -czf logs.tar.gz logs
      - name: Configure AWS Credentials
        run: |
          echo "config AWS cred."
      - name: Upload test-artifacts to AMD S3
        if: always()
        env:
          S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }}
          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          GITHUB_TOKEN: ${{ github.token }}
          INPUT_PYTHON: ${{ inputs.python }}
          INPUT_ROCM_VERSION: ${{ inputs.rocm-version }}
          INPUT_RUNNER: ${{ inputs.runner }}
          INPUT_ROCM_TAG: ${{ inputs.rocm-tag }}
          IS_NIGHTLY: ${{ contains(github.workflow, 'Nightly/Release') && 'nightly' || 'continuous' }}
        run: |
          ./ci/upload_rocm_logs.sh
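Note on `IS_NIGHTLY`: the expression uses the `cond && a || b` idiom, which GitHub Actions expressions use in place of a ternary. A minimal Python sketch of the same decision follows. It assumes `contains` performs a case-insensitive substring match on strings, as the GitHub Actions expression documentation describes. `run_kind` is a hypothetical name used only for illustration and does not appear in the PR.

```python
def run_kind(workflow_name: str) -> str:
    """Mirror `contains(github.workflow, 'Nightly/Release') && 'nightly' || 'continuous'`."""
    # GitHub's contains() ignores case when both arguments are strings.
    is_nightly = "nightly/release" in workflow_name.lower()
    return "nightly" if is_nightly else "continuous"


if __name__ == "__main__":
    print(run_kind("CI - Wheel Tests (Nightly/Release)"))
    print(run_kind("CI - Wheel Tests (Continuous)"))
```

The idiom is safe here because `'nightly'` is a non-empty string. If the middle operand were falsy, the expression would fall through to the right-hand side regardless of the condition.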
175 changes: 175 additions & 0 deletions .github/workflows/pytest_rocm_abort.yml
@@ -0,0 +1,175 @@
# CI - Pytest ROCm (Abort Support)
#
# This workflow runs the ROCm tests with Pytest in ROCm GHCR containers,
# using the ROCm `pytest-abort` retry wrapper to detect/retry aborts/crashes.
#
# It can be triggered manually via workflow_dispatch or called by other workflows
# via workflow_call.
#
# It consists of the following job:
# run-tests:
#    - Runs in ROCm container (ghcr.io/rocm/jax-base-ubu24.rocm*:latest)
#    - Downloads the JAX and jaxlib wheels from GCS, and ROCm plugins from latest release.
#    - Executes the `run_pytest_rocm_abort.sh` script, which installs wheel artifacts and
#      runs the ROCm tests with Pytest under `pytest-abort-retry`.
name: CI - Pytest ROCm (Abort Support)

on:
  workflow_dispatch:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: choice
        default: "linux-x86-64-4gpu-amd"
        options:
          - "linux-x86-64-1gpu-amd"
          - "linux-x86-64-4gpu-amd"
          - "linux-x86-64-8gpu-amd"
      python:
        description: "Which Python version to use?"
        type: choice
        default: "3.11"
        options:
          - "3.11"
          - "3.12"
      rocm-version:
        description: "Which ROCm version to test?"
        type: choice
        default: "7.2.0"
        options:
          - "7.2.0"
      rocm-tag:
        description: "ROCm tag for container image (e.g., rocm720)"
        type: string
        default: "rocm720"
      jaxlib-version:
        description: "Which jaxlib version to use? (head/pypi_latest)"
        type: choice
        default: "head"
        options:
          - "head"
          - "pypi_latest"
      skip-download-jaxlib-and-plugins-from-gcs:
        description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g., for testing a jax-only release)"
        type: choice
        default: '0'
        options:
          - '0'
          - '1'
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        type: string
        default: 'gs://jax-nightly-artifacts/latest'
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: string
        default: 'no'
      max-worker-restart:
        description: "Max xdist worker restarts (passed to pytest --max-worker-restart)"
        type: string
        default: '50'
  workflow_call:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        default: "linux-x86-64-4gpu-amd"
      python:
        description: "Which Python version to use?"
        type: string
        default: "3.11"
      rocm-version:
        description: "Which ROCm version to test?"
        type: string
        default: "7.2.0"
      rocm-tag:
        description: "ROCm tag for container image (e.g., rocm720)"
        type: string
        default: "rocm720"
      jaxlib-version:
        description: "Which jaxlib version to use? (head/pypi_latest)"
        type: string
        default: "head"
      skip-download-jaxlib-and-plugins-from-gcs:
        description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g., for testing a jax-only release)"
        default: '0'
        type: string
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        default: 'gs://jax-nightly-artifacts/latest'
        type: string
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: string
        default: 'no'
      max-worker-restart:
        description: "Max xdist worker restarts (passed to pytest --max-worker-restart)"
        type: string
        default: '50'

permissions: {}

env:
  UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"

jobs:
  run-tests:
    defaults:
      run:
        # Set the shell to bash explicitly; steps in container jobs otherwise default to sh
        shell: bash
    runs-on: ${{ inputs.runner }}
    continue-on-error: true
    # Run in ROCm GHCR container with GPU access
    container:
      image: ghcr.io/rocm/jax-base-ubu24.${{ inputs.rocm-tag }}:latest
      credentials:
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
      options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --shm-size 64G --env-file /etc/podinfo/gha-gpu-isolation-settings
    name: "${{ (contains(inputs.runner, '1gpu') && '1gpu') ||
      (contains(inputs.runner, '4gpu') && '4gpu') ||
      (contains(inputs.runner, '8gpu') && '8gpu') }}, ROCm ${{ inputs.rocm-version }}, py${{ inputs.python }}"

    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_ENABLE_X64: "0"
      MAX_WORKER_RESTART: "${{ inputs['max-worker-restart'] }}"

    steps:
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
        with:
          persist-credentials: false
      - name: Download JAX ROCm wheels
        uses: ./.github/actions/download-jax-rocm-wheels
        with:
          python: ${{ inputs.python }}
          rocm-version: ${{ inputs.rocm-version }}
          jaxlib-version: ${{ inputs.jaxlib-version }}
          skip-download-jaxlib-and-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }}
          gcs_download_uri: ${{ inputs.gcs_download_uri }}
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: Install Python dependencies
        run: |
          $JAXCI_PYTHON -m pip install uv~=0.5.30
          $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Pytest ROCm tests (abort support)
        timeout-minutes: 180
        run: ./ci/run_pytest_rocm_abort.sh
      - name: Upload pytest results to artifact
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: logs_abort
          path: |
            logs_abort/
          if-no-files-found: warn
          retention-days: 2
          overwrite: true
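Note on the abort/retry idea: `run_pytest_rocm_abort.sh` and the `pytest-abort-retry` wrapper are not part of the diff shown here, so the following is only a sketch of the general technique, not the wrapper's actual implementation. It assumes a POSIX host, where Python's `subprocess` reports a process killed by a signal (for example SIGABRT from a GPU runtime crash) as a negative return code. `run_with_abort_retry` and `max_retries` are hypothetical names.

```python
import subprocess
import sys


def run_with_abort_retry(cmd, max_retries=2):
    """Run cmd, rerunning it only when the process was killed by a signal.

    Returns (last_return_code, attempts). Illustration only; this is not
    the pytest-abort implementation.
    """
    attempts = 0
    while True:
        attempts += 1
        result = subprocess.run(cmd)
        # On POSIX, death by signal shows up as a negative return code
        # (for example -6 for SIGABRT). Ordinary test failures exit with a
        # positive code and are not retried.
        aborted = result.returncode < 0
        if not aborted or attempts > max_retries:
            return result.returncode, attempts


if __name__ == "__main__":
    # A child that exits with a normal failure code is run exactly once.
    failing = [sys.executable, "-c", "raise SystemExit(3)"]
    print(run_with_abort_retry(failing))
```

Separating "crashed" from "failed" is the point of the design: a plain test failure should surface immediately, while a crash is worth another attempt up to a fixed bound.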
4 changes: 4 additions & 0 deletions .github/workflows/wheel_tests_nightly_release.yml
Expand Up @@ -173,7 +173,11 @@ jobs:
      clone_main_xla: 0

  run-pytest-rocm:
    permissions:
      contents: read
      actions: read
    uses: ./.github/workflows/pytest_rocm.yml
    secrets: inherit
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
Expand Down
53 changes: 53 additions & 0 deletions .github/workflows/wheel_tests_nightly_release_abort.yml
@@ -0,0 +1,53 @@
# CI - Wheel Tests (Nightly/Release) (ROCm abort only)
#
# This workflow runs only the ROCm wheel tests using the abort/retry wrapper workflow.
name: CI - Wheel Tests (Nightly/Release) (ROCm abort only)

on:
  workflow_dispatch:
    inputs:
      gcs_download_uri:
        description: "GCS location URI from where the artifacts should be downloaded"
        required: true
        default: 'gs://jax-nightly-artifacts/latest'
        type: string
      skip-download-jaxlib-and-plugins-from-gcs:
        description: "Whether to download only the jax wheel from GCS (e.g., for testing a jax-only release)"
        required: true
        default: '0'
        type: string
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection? (yes/no)'
        required: false
        default: 'no'
        type: string

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
permissions: {}

env:
  UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"

jobs:
  run-pytest-rocm:
    uses: ./.github/workflows/pytest_rocm_abort.yml
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        runner: ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"]
        python: ["3.11", "3.12", "3.13", "3.14"]
        rocm: [
          {version: "7.2.0", tag: "rocm720"},
        ]
    name: "Pytest ROCm abort (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})"
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      rocm-version: ${{ matrix.rocm.version }}
      rocm-tag: ${{ matrix.rocm.tag }}
      jaxlib-version: "head"
      skip-download-jaxlib-and-plugins-from-gcs: ${{inputs.skip-download-jaxlib-and-plugins-from-gcs}}
      gcs_download_uri: ${{inputs.gcs_download_uri}}
      halt-for-connection: ${{inputs.halt-for-connection}}
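Note on the matrix size: the strategy is a Cartesian product of `runner`, `python`, and `rocm`, so the values above fan out into 3 x 4 x 1 = 12 calls to `pytest_rocm_abort.yml`. A small Python sketch of that expansion is below. `expand_matrix` is a hypothetical helper for illustration, and the dictionary keys mirror the `with:` inputs.

```python
from itertools import product

runners = ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"]
pythons = ["3.11", "3.12", "3.13", "3.14"]
rocms = [{"version": "7.2.0", "tag": "rocm720"}]


def expand_matrix(runners, pythons, rocms):
    """Return one dict of workflow_call inputs per matrix combination."""
    return [
        {
            "runner": runner,
            "python": python,
            "rocm-version": rocm["version"],
            "rocm-tag": rocm["tag"],
        }
        for runner, python, rocm in product(runners, pythons, rocms)
    ]


jobs = expand_matrix(runners, pythons, rocms)
print(len(jobs))  # 12
```

Because the called workflow declares `python` as a plain string under `workflow_call`, the 3.13 and 3.14 entries are accepted even though the manual `workflow_dispatch` form only offers 3.11 and 3.12.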
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -31,3 +31,7 @@ jax.iml
/include/
/lib/
/share/

/compile_commands.json
/strace.txt
/external
4 changes: 4 additions & 0 deletions build/build.py
Expand Up @@ -638,6 +638,10 @@ async def main():
)

  if "rocm" in args.wheels:
    if not args.configure_only:
      # Report the error on stderr and exit with a non-zero status.
      print("ERROR: This repo is not used for building the ROCm JAX plugins. Please use the new plugin repo: https://github.com/ROCm/rocm-jax", file=sys.stderr)
      sys.exit(1)

    wheel_build_command_base.append("--config=rocm_base")
    wheel_build_command_base.append("--config=rocm")
    if clang_local:
Expand Down
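Note on the ROCm guard in `build/build.py`: the check allows `--configure_only` runs and rejects any real ROCm wheel build. A standalone sketch of that logic, written as a function so it can be exercised without Bazel, is below. `check_rocm_wheel_request` and `PLUGIN_REPO_URL` are hypothetical names, and `wheels` is assumed to be whatever `args.wheels` holds (the `in` test works for a list of names or a comma-separated string).

```python
import sys

PLUGIN_REPO_URL = "https://github.com/ROCm/rocm-jax"


def check_rocm_wheel_request(wheels, configure_only):
    """Exit with status 1 if a ROCm plugin build is requested from this repo."""
    if "rocm" in wheels and not configure_only:
        print(
            "ERROR: This repo is not used for building the ROCm JAX plugins. "
            f"Please use the new plugin repo: {PLUGIN_REPO_URL}",
            file=sys.stderr,
        )
        sys.exit(1)


if __name__ == "__main__":
    # Configure-only runs pass through the guard without exiting.
    check_rocm_wheel_request(["jaxlib", "rocm"], configure_only=True)
    print("configure-only run allowed")
```

Using `sys.exit` rather than the interactive `exit` helper keeps the behavior the same when Python runs without the `site` module, and writing to stderr keeps the message out of captured stdout.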
24 changes: 24 additions & 0 deletions build/rocm-test-requirements.txt
@@ -0,0 +1,24 @@
absl-py
build
cloudpickle
colorama>=0.4.4
filelock
flatbuffers
hypothesis
mpmath>=1.3
pillow>=10.4.0
# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t
portpicker; python_version<"3.13"
pytest-xdist
pytest-json-report
pytest-html
pytest-csv
pytest-rerunfailures
pytest-html-merger
pytest-reportlog
wheel
rich
setuptools
matplotlib
opt-einsum
auditwheel
1 change: 1 addition & 0 deletions build/test-requirements.txt
Expand Up @@ -8,6 +8,7 @@ pillow>=11.3
portpicker
pytest<9.0 # Works around https://github.com/pytest-dev/pytest/issues/13895
pytest-xdist
pytest-json-report
rich
matplotlib
auditwheel
Expand Down
4 changes: 4 additions & 0 deletions ci/run_pytest_rocm.sh
Expand Up @@ -114,9 +114,13 @@ echo "Running ROCm tests..."
# TODO: Add examples directory to test suite (CUDA tests both: tests examples)
# TODO: Verify if CSV/HTML report generation should be kept (unique to ROCm)
# TODO: Verify if log file output should be kept (unique to ROCm)
LOGS_DIR="logs"
mkdir -p "${LOGS_DIR}"
export NPROC=32
"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short \
  --json-report --json-report-file="${LOGS_DIR}/pytest_results.json" \
  tests \
  --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
  --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \
  --deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
  #TODO: --log-file=pytest_output.log
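Note on consuming `logs/pytest_results.json`: the `--json-report` option comes from the `pytest-json-report` plugin, whose report contains a top-level `summary` object and a `tests` list in which each entry carries a `nodeid` and an `outcome`. A short Python sketch for pulling the failing tests out of such a report is below. `failed_tests` and `load_report` are hypothetical helper names, and the sample data is made up for illustration; the actual `upload_rocm_logs.sh` script is not shown in this diff.

```python
import json
from pathlib import Path


def failed_tests(report: dict) -> list[str]:
    """Return node IDs of tests whose outcome is 'failed' or 'error'."""
    return [
        test["nodeid"]
        for test in report.get("tests", [])
        if test.get("outcome") in ("failed", "error")
    ]


def load_report(path: str = "logs/pytest_results.json") -> dict:
    """Load the JSON report written via --json-report-file."""
    return json.loads(Path(path).read_text())


if __name__ == "__main__":
    # Made-up sample in the plugin's report layout.
    sample = {
        "summary": {"passed": 1, "failed": 1, "total": 2},
        "tests": [
            {"nodeid": "tests/a_test.py::ATest::test_ok", "outcome": "passed"},
            {"nodeid": "tests/b_test.py::BTest::test_bad", "outcome": "failed"},
        ],
    }
    print(failed_tests(sample))
```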