Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5aba9fd
fix the name of the container image (#283)
kliegeois Jan 26, 2026
6337f29
Fix rocm performance script and prepare for 0.9.0 jax plugin release
Ruturaj4 Jan 25, 2026
b6250a3
Update MaxText performance workload to support JAX 0.8.2+
psanal35 Jan 28, 2026
c7c6379
Update Llama custom performance workload to support JAX 0.8.2+
psanal35 Jan 30, 2026
e06da0a
Fix metrics parsing in Llama custom workload and DB upload
psanal35 Feb 2, 2026
2f888f8
Use wheel 0.46.3 in fixwheel.py (#290)
charleshofer Feb 4, 2026
d43caec
Add pytest-results-to-db workflow (#296)
psanal35 Feb 7, 2026
32deb51
Create jax plugin wheels on the fly (#275)
alekstheod Feb 9, 2026
455676c
Phase 1.1: Remove Ubuntu 22.04 base Docker image
mminutoli Jan 26, 2026
07e89fe
Phase 1.3: Update Ubuntu 24.04 base image for multi-Python support
mminutoli Jan 26, 2026
0804697
Phase 2.1: Remove Ubuntu 22.04 JAX Docker image
mminutoli Jan 26, 2026
7c806cb
Phase 2.2: Update Ubuntu 24.04 JAX image for multi-Python support
mminutoli Jan 26, 2026
6df948f
Phase 3: Update wheel building default Python versions
mminutoli Jan 27, 2026
18ebabe
Phase 4: Update CI workflows for Ubuntu 24.04 and multi-Python support
mminutoli Jan 30, 2026
4c143c5
Phase 5: Update build scripts for Ubuntu 24.04 and multi-Python support
mminutoli Jan 30, 2026
297e4cd
Add GitHub CLI and Google Cloud CLI to base container
mminutoli Jan 30, 2026
2b230fb
Configure gcloud to access public GCS buckets without authentication
mminutoli Jan 31, 2026
f84a969
Add Python development packages for all Python versions
mminutoli Jan 31, 2026
9ca620b
Introduce ci pr check tests pipeline (#297)
alekstheod Feb 11, 2026
ed4fa9b
Update runner labels in Llama perf workflow
psanal35 Feb 11, 2026
bbc226d
Fix test ignore list (#307)
alekstheod Feb 12, 2026
c95522a
Add flaky and timeouts handling (#308)
alekstheod Feb 12, 2026
9218490
Optimize build_wheels.py: build PJRT wheel only once
Ruturaj4 Feb 11, 2026
b0ca098
Introduce asan build (#303)
alekstheod Feb 16, 2026
4462a9c
Pre-build manylinux and devsetup docker images (#209)
charleshofer Feb 16, 2026
a7bb47c
Fix broken nightly workflow file (#313)
charleshofer Feb 17, 2026
17921d2
Add fake nvidia_versions repo and remove a patch (#314)
alekstheod Feb 18, 2026
d69a73f
Log in before wheel builds to avoid public GHCR rate limit (#316)
charleshofer Feb 18, 2026
bf1e6ed
Use git_repository for XLA and JAX dependencies
Ruturaj4 Feb 16, 2026
05732c7
Restore patches (#312)
alekstheod Feb 18, 2026
522e97e
fixes
Ruturaj4 Feb 18, 2026
35db335
reintroduce the env vars
Ruturaj4 Feb 18, 2026
e0861c0
env fixes
Ruturaj4 Feb 18, 2026
04f0dcd
Merge branch 'master' into rocm-jaxlib-v0.9.0
Ruturaj4 Feb 18, 2026
aa52c69
fix sha
Ruturaj4 Feb 18, 2026
8d95c2e
remove the plugin wheels patch
Ruturaj4 Feb 18, 2026
37c008a
fix build-and-test workflow
Ruturaj4 Feb 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .dockerignore

This file was deleted.

82 changes: 77 additions & 5 deletions .github/workflows/build-base-docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,22 @@ on:
- 'build/ci_build'
- '.github/workflows/build-base-docker.yml'
- 'docker/Dockerfile*'
- 'docker/manylinux/*'
push:
branches:
- 'master'
paths:
- 'build/ci_build'
- '.github/workflows/build-base-docker.yml'
- 'docker/Dockerfile*'
- 'docker/manylinux/*'

jobs:
build-base-images:
runs-on: ${{ matrix.runner-label }}
strategy:
fail-fast: false
matrix:
ubuntu-version: ["22", "24"]
rocm-version: ["7.1.1", "7.2.0"]
install-llvm: [true, false]
include:
Expand Down Expand Up @@ -64,7 +72,7 @@ jobs:
--rocm-version="${{ matrix.rocm-version }}" \
$BUILD_ARGS \
build_base_dockers \
--filter="ubu${{ matrix.ubuntu-version }}" \
--filter="ubu24" \
${{ matrix.install-llvm && '--install-llvm --llvm-version 18' || '' }}
- name: Authenticate to GitHub Container Registry
run: |
Expand All @@ -73,16 +81,15 @@ jobs:
- name: Push docker images
env:
ROCM_VERSION: ${{ matrix.rocm-version }}
UBUNTU_VERSION: ${{ matrix.ubuntu-version }}
INSTALL_LLVM: ${{ matrix.install-llvm }}
run: |
# Construct image tag based on matrix values
# ROCm version tag removes dots (7.1.1 -> 711, 7.2.0 -> 720)
rocm_tag="rocm${ROCM_VERSION//.}"
if [ "$INSTALL_LLVM" = "true" ]; then
image_tag="jax-dev-ubu${UBUNTU_VERSION}.${rocm_tag}"
image_tag="jax-dev-ubu24.${rocm_tag}"
else
image_tag="jax-base-ubu${UBUNTU_VERSION}.${rocm_tag}"
image_tag="jax-base-ubu24.${rocm_tag}"
fi

# Push with commit SHA tag
Expand All @@ -98,3 +105,68 @@ jobs:
docker tag "${image_tag}" "${ghcr_image_latest}"
docker push "${ghcr_image_latest}"
fi

build-manylinux-builder-images:
runs-on: ${{ matrix.runner-label }}
strategy:
fail-fast: false
matrix:
rocm-version: ["7.1.1", "7.2.0"]
include:
- rocm-version: "7.1.1"
runner-label: "linux-x86-64-1gpu-amd"
- rocm-version: "7.2.0"
runner-label: "linux-x86-64-1gpu-amd"
steps:
- name: Clean up old runs
run: |
ls -lah
# Make sure that we own all of the files so that we have permissions to delete them
docker run --rm -v "./:/rocm-jax" ubuntu \
/bin/bash -c "shopt -s dotglob; chown -R $UID /rocm-jax/* || true"
# Remove any old work directories from this machine
rm -rf * || true
ls -lah
# Clean up any docker stuff that's more than a week old
docker system prune -a --filter "until=168h"
# Stop any containers running for more than 12 hours. No CI job should take this long.
docker ps --format="{{.RunningFor}} {{.Names}}" | grep hours \
| awk -F: '{if($1>12)print$1}' | awk ' {print $4} ' | xargs docker stop || true
- uses: actions/checkout@v4
- name: Build docker images
run: |
python3 build/ci_build \
--rocm-version="${{ matrix.rocm-version }}" \
--rocm-build-job="${{ matrix.rocm-build-job }}" \
--rocm-build-num="${{ matrix.rocm-build-num }}" \
build_manylinux_dockers
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Push docker images
env:
ROCM_VERSION: ${{ matrix.rocm-version }}
ROCM_BUILD_JOB: >-
${{ matrix.rocm-build-job && format('-{0}', inputs.rocm-build-job) || '' }}
ROCM_BUILD_NUM: >-
${{ matrix.rocm-build-num && format('-{0}', inputs.rocm-build-num) || '' }}
run: |
image_tag="ghcr.io/rocm/jax-manylinux_2_28-rocm-${ROCM_VERSION}${ROCM_BUILD_JOB}${ROCM_BUILD_NUM}"

# Push with commit SHA tag
sha_image_tag="${image_tag}:${GITHUB_SHA}"
echo "Image name (SHA): ${sha_image_tag}"
docker tag "${image_tag}" "${sha_image_tag}"
docker push "${sha_image_tag}"

# Push with latest tag (only for schedule and workflow_dispatch, not PRs)
if [ "${{ github.event_name }}" != "pull_request" ]; then
latest_image_tag="${image_tag}:latest"
echo "Image Name (latest): ${latest_image_tag}"
docker tag "${image_tag}" "${latest_image_tag}"
docker push "${latest_image_tag}"
fi

8 changes: 0 additions & 8 deletions .github/workflows/build-docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,18 @@ jobs:
EXTRA_CR_TAG: ${{ inputs.extra-cr-tag }}
ROCM_VERSION: ${{ inputs.rocm-version }}
run: |
ubu22_img="ghcr.io/rocm/jax-ubu22.rocm${ROCM_VERSION//.}:${GITHUB_SHA}"
ubu24_img="ghcr.io/rocm/jax-ubu24.rocm${ROCM_VERSION//.}:${GITHUB_SHA}"
echo "Ubuntu 22 image name: ${ubu22_img}"
echo "Ubuntu 24 image name: ${ubu24_img}"
docker tag "jax-ubu22.rocm${ROCM_VERSION//.}" "${ubu22_img}"
docker tag "jax-ubu24.rocm${ROCM_VERSION//.}" "${ubu24_img}"
docker push "${ubu22_img}"
docker push "${ubu24_img}"
- name: Push extra tags
if: ${{ inputs.extra-cr-tag }}
env:
EXTRA_CR_TAG: ${{ inputs.extra-cr-tag }}
ROCM_VERSION: ${{ inputs.rocm-version }}
run: |
ubu22_img="ghcr.io/rocm/jax-ubu22.rocm${ROCM_VERSION//.}:${EXTRA_CR_TAG}"
ubu24_img="ghcr.io/rocm/jax-ubu24.rocm${ROCM_VERSION//.}:${EXTRA_CR_TAG}"
echo "Ubuntu 22 image name: ${ubu22_img}"
echo "Ubuntu 24 image name: ${ubu24_img}"
docker tag "jax-ubu22.rocm${ROCM_VERSION//.}" "${ubu22_img}"
docker tag "jax-ubu24.rocm${ROCM_VERSION//.}" "${ubu24_img}"
docker push "${ubu22_img}"
docker push "${ubu24_img}"

11 changes: 10 additions & 1 deletion .github/workflows/build-wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ on:
required: false
type: string
default: 'rocm-jaxlib-v0.9.0'
builder-image:
required: false
type: string
default: ''
secrets:
rbe_ci_key:
required: true
Expand Down Expand Up @@ -71,6 +75,10 @@ jobs:
repository: ${{ inputs.jax-repo }}
ref: ${{ inputs.jax-ref }}
path: jax
- name: Authenticate to GitHub Container Registry
run: |
echo "${{ secrets.GITHUB_TOKEN }}" \
| docker login ghcr.io -u ${{ github.actor }} --password-stdin
- name: Get RBE cluster keys
env:
RBE_CI_CERT: ${{ secrets.rbe_ci_cert }}
Expand All @@ -88,7 +96,8 @@ jobs:
--rocm-build-num="${{ inputs.rocm-build-num }}" \
--jax-source-dir="./jax" \
dist_wheels \
--rbe
--rbe \
--builder-image="${{ inputs.builder-image }}"
- name: Archive plugin wheels
uses: actions/upload-artifact@v4
with:
Expand Down
96 changes: 96 additions & 0 deletions .github/workflows/ci-ut.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
name: CI Unit Tests

on:
push:
branches:
- master
- 'rocm-jaxlib-v*'
pull_request:
branches:
- master
- 'rocm-jaxlib-v*'
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

permissions:
contents: read

jobs:
build-and-test:
name: build-and-test (${{ matrix.mode.name }})
runs-on: linux-x86-64-4gpu-amd-gfx942
strategy:
fail-fast: false
matrix:
mode:
- {name: "py3.11", python_version: "3.11", config: ""}
- {name: "py3.12", python_version: "3.12", config: ""}
- {name: "py3.13", python_version: "3.13", config: ""}
- {name: "py3.14", python_version: "3.14", config: ""}
- {name: "asan", python_version: "3.11", config: "--config=asan"}
container:
# note this image shall match the one defined in platform/linux:tf_linux_gpu
image: rocm/tensorflow-build@sha256:7fcfbd36b7ac8f6b0805b37c4248e929e31cf5ee3af766c8409dd70d5ab65faa
options: >-
-w ${{ github.workspace }}/jax_rocm_plugin
--device=/dev/kfd
--device=/dev/dri
--group-add video
--cap-add=SYS_PTRACE
--security-opt seccomp=unconfined
--shm-size 16G
defaults:
run:
working-directory: jax_rocm_plugin
steps:
- name: Checkout plugin repo
uses: actions/checkout@v4

- name: Get RBE cluster keys
env:
RBE_CI_CERT: ${{ secrets.RBE_CI_CERT }}
RBE_CI_KEY: ${{ secrets.RBE_CI_KEY }}
run: |
echo "$RBE_CI_CERT" >> ci-cert.crt
echo "$RBE_CI_KEY" >> ci-cert.key

- name: Run single-GPU unit tests
if: always()
run: |
bash build/rocm/ci_run_jax_ut.sh \
--config=rocm_sgpu \
--config=rocm_rbe \
--repo_env=HERMETIC_PYTHON_VERSION=${{ matrix.mode.python_version }} \
${{ matrix.mode.config }} \
--curses=no \
--color=yes \
-- \
@jax//tests:gpu_tests \
@jax//tests:backend_independent_tests \
$(build/rocm/targets_to_ignore.sh)

- name: Run multi-GPU unit tests
if: always()
run: |
bash build/rocm/ci_run_jax_ut.sh \
--config=rocm_mgpu \
--config=rocm_rbe \
--repo_env=HERMETIC_PYTHON_VERSION=${{ matrix.mode.python_version }} \
${{ matrix.mode.config }} \
--curses=no \
--color=yes \
--strategy=TestRunner=local \
-- \
@jax//tests:gpu_tests \
@jax//tests:backend_independent_tests \
$(build/rocm/targets_to_ignore.sh)

- name: Upload logs to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: logs-rbe-py${{ matrix.mode.name }}
path: jax_rocm_plugin/logs/
9 changes: 4 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ jobs:
rocm-version: ["7.1.1"]
uses: ./.github/workflows/build-wheels.yml
with:
# TODO: Add back Python 3.13 when we're ready to move to a more recent version of XLA. 3.13
# fails with a complaint abou the pipes module.
python-versions: "3.11,3.12"
python-versions: "3.11,3.12,3.13,3.14"
rocm-version: ${{ matrix.rocm-version }}
runner-label: '["linux-x86-64-1gpu-amd"]'
builder-image: "search"
secrets:
rbe_ci_cert: ${{ secrets.RBE_CI_CERT }}
rbe_ci_key: ${{ secrets.RBE_CI_KEY }}
Expand Down Expand Up @@ -86,7 +85,7 @@ jobs:
ROCM_VERSION: ${{ matrix.rocm-version }}
run: |
docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \
"ghcr.io/rocm/jax-ubu22.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
"ghcr.io/rocm/jax-ubu24.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
rocm-smi -a || true
- name: Download wheel artifacts
uses: actions/download-artifact@v4
Expand All @@ -105,6 +104,6 @@ jobs:
# TODO: Add the tests/linalg_test.py test back once we fix the XLAClient thing.
run: |
python3 build/ci_build test \
"ghcr.io/rocm/jax-ubu22.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
"ghcr.io/rocm/jax-ubu24.rocm${ROCM_VERSION//.}:${GITHUB_SHA}" \
--test-cmd "bash ci/jax_rbe/pr_setup.sh && ci/jax_rbe/pr_test.sh 0.9.0 3.12"

Loading
Loading