Draft
121 commits
067dd46
[hijax] optimize code, much better
mattjj Feb 26, 2026
dd11809
Even more Pyrefly fixes
superbobry Mar 2, 2026
3db21b7
[Mosaic GPU] Add support for s8 and u8 warp-level MMAs
apaszke Mar 2, 2026
c097ebe
Add TPU v7 to `test_vmem_oom_error_message_basics`.
yueshengys Mar 2, 2026
1bd107b
Bump actions/upload-artifact from 6.0.0 to 7.0.0
dependabot[bot] Mar 2, 2026
a4dd1dd
Bump actions/checkout from 6.0.0 to 6.0.2
dependabot[bot] Mar 2, 2026
c0c9843
[NFC] Remove and modify some out-of-date test skips.
yueshengys Mar 2, 2026
81bb26a
Merge pull request #35497 from superbobry:pyrefly
Google-ML-Automation Mar 2, 2026
6bbd83d
Skip lax_numpy_reducers_test reducer tests with where.
danielsuo Mar 2, 2026
2260303
Skip installing collecting profile requirements under 3.13-nogil.
danielsuo Mar 2, 2026
c9de85a
[pyrefly] fix typing errors in jax/_src/array.py
jakevdp Mar 2, 2026
1379250
[Mosaic:GPU] Fix the barrier value creation.
PatriosTheGreat Mar 2, 2026
1cc1755
Merge pull request #35518 from jax-ml:dependabot/github_actions/actio…
Google-ML-Automation Mar 2, 2026
5ef2c50
[pyrefly] fix errors in jax/_src/shard_map.py
jakevdp Mar 2, 2026
23d35a9
Merge pull request #35522 from jax-ml:dependabot/github_actions/actio…
Google-ML-Automation Mar 2, 2026
b80f8e8
[pyrefly] fix pyrefly errors in jax._src.core
jakevdp Mar 2, 2026
926bf97
[typing] fix pyi type signature for argsort
jakevdp Mar 2, 2026
9cff4d4
Merge pull request #35532 from jakevdp:argsort-type
Google-ML-Automation Mar 2, 2026
a39b86f
Merge pull request #35450 from jakevdp:pyrefly-core
Google-ML-Automation Mar 2, 2026
e3b1f06
Merge pull request #35442 from mattjj:fix-ugly-code
Google-ML-Automation Mar 2, 2026
153f214
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Mar 2, 2026
208b29e
Reverts 0b7d6672a7e69616e84048752b6d5e94ffeac8ce
Google-ML-Automation Mar 2, 2026
a162bb8
Merge pull request #35530 from jakevdp:pyrefly-array
Google-ML-Automation Mar 2, 2026
c318de0
[Pallas/SC] Fix bug where axis name for SC mesh is always "core"
sharadmv Mar 2, 2026
9e0dbb4
Use `itertools.chain` directly when the iterable is a literal
superbobry Mar 3, 2026
c8890f6
[pyrefly] fix errors in jax._src.interpreters.pxla
jakevdp Mar 3, 2026
ac89644
Make main process_*() arguments position-only.
jakevdp Mar 2, 2026
7fbcdbb
Merge pull request #35531 from jakevdp:pyrefly-shard-map
Google-ML-Automation Mar 3, 2026
c2aba31
[mosaic:gpu] Guard nvvm.elect_sync in blackwell examples.
danielsuo Mar 3, 2026
c2570db
Merge pull request #35527 from jakevdp:pyrefly-pxla
Google-ML-Automation Mar 3, 2026
82c848d
Merge pull request #35513 from jakevdp:process-args
Google-ML-Automation Mar 3, 2026
c5bf3ae
[pallas] Fix: Pallas dot_general on TPU gives wrong results for unsig…
rdyro Mar 3, 2026
69a609b
[Pallas/SC] Enable closing over scalars in SCS kernels
sharadmv Mar 3, 2026
5bb5f77
[mgpu] Fix race condition in `GetOrCreateKernel`.
chr1sj0nes Mar 3, 2026
ea4cf33
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Mar 3, 2026
dd5aeab
[mGPU] Fix race condition in CachedInit.
khasanovaa Mar 3, 2026
c1b231f
[pallas:triton] `debug_barrier` now has an effect to prevent it from …
superbobry Mar 3, 2026
ccdbf06
[Mosaic GPU] Ensure that WGStridedFragLayout picks a vector length th…
apaszke Mar 3, 2026
2bbaf4c
Fix a test.
yueshengys Mar 3, 2026
7a3fb9d
[pallas:sc] Fixed how `scheduler.grid_env` is used when `emit_pipelin…
superbobry Mar 3, 2026
5a3accc
[pallas:triton] Moved Triton specific primitives from Pallas core int…
superbobry Mar 3, 2026
8f56da1
[pallas:sc] Use `pltpu.CoreType` instead of the deprecated `pltpu.Ker…
superbobry Mar 3, 2026
5a07b03
[jaxlib] Force the use of `typing_extensions.CapsuleType` in the _jax…
superbobry Mar 3, 2026
ba16266
[jaxlib] Fixed the type of `Traceback.__add__`
superbobry Mar 3, 2026
b41f583
Fixed all remaining Pyrefly errors in jax/_src
superbobry Mar 3, 2026
108a11a
Run build_cleaner on jaxlib.
hawkinsp Mar 3, 2026
0b31d9c
[Mosaic:GPU] Enable cuda-graphs with collective metadata.
PatriosTheGreat Mar 3, 2026
5c3035c
Reverts c5bf3aea3148c25e753f20b512d2efb5a741ffe9
Google-ML-Automation Mar 3, 2026
1d8fba4
[pallas:mgpu] Use two barriers for try-cancel barriers in `dynamic_sc…
chr1sj0nes Mar 3, 2026
f45c878
[hijax] enable VJPHiPrimitive to define transpose rules
mattjj Feb 26, 2026
158c20a
Fix a pallas dot test on TPU using unsupported unsigned int dot.
rdyro Mar 3, 2026
b0efa20
Rely on IFRT types implementing AbslStringify rather than calling the…
ICGog Mar 3, 2026
f08bb99
Add support for multiple weak keys to WeakrefLRUCache.
hawkinsp Mar 3, 2026
79bdd4b
Merge pull request #35558 from superbobry:pyrefly
Google-ML-Automation Mar 3, 2026
ecc04a8
Merge pull request #35446 from mattjj:hi-davis
Google-ML-Automation Mar 3, 2026
3ec221c
PR #35567: Bumped Pyrefly to 0.55
superbobry Mar 3, 2026
031c1e9
Use driver linking as default provider.
ermilovmaxim Mar 3, 2026
b275aec
Use Pyrefly instead of mypy
superbobry Mar 3, 2026
d777b52
Integrate TorchTPU SDPA with optimized flash attention kernel.
jcc-google Mar 3, 2026
8a73c61
[Mosaic TPU] Support reshape which unfolds the minormost dim into two…
yueshengys Mar 3, 2026
0ef0cfa
Merge pull request #35569 from superbobry:maint
Google-ML-Automation Mar 3, 2026
844bdee
Re-apply pallas dot_general check for disallowed unsigned integer inp…
rdyro Mar 3, 2026
a1328df
[doc] update developer docs for pyrefly
jakevdp Mar 3, 2026
0711c08
[pyrefly] test with opt-einsum
jakevdp Mar 3, 2026
79d8807
Merge pull request #35575 from jakevdp:pyrefly-docs
Google-ML-Automation Mar 3, 2026
c0b2687
Migrate to SafeStatic instead of SafeStaticInit.
hawkinsp Mar 3, 2026
718ac01
Couple of changes in this PR:
yashk2810 Mar 4, 2026
165e54c
Fix shard_map_partial_eval and partial_eval_custom to use `nospec` in…
yashk2810 Mar 4, 2026
42d97c8
Merge pull request #35576 from jakevdp:pyrefly-opt-einsum
Google-ML-Automation Mar 4, 2026
a5ed9d2
[export] Add a test for shape polymorphism with invalid constraints
gnecula Mar 3, 2026
76cbef0
Merge pull request #35587 from gnecula:poly_bug1
Google-ML-Automation Mar 4, 2026
e05dc3c
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Mar 4, 2026
c307523
Removed unnecessary Pyrefly suppressions
superbobry Mar 4, 2026
1184dba
Fix TensorStore implementation in Jax to use typed initialization.
Google-ML-Automation Mar 4, 2026
7ad75ac
Merge pull request #35591 from superbobry:pyrefly
Google-ML-Automation Mar 4, 2026
41f6727
Removed a few `type: ignore`s which were only necessary for mypy
superbobry Mar 4, 2026
d36ae3b
[Mosaic GPU][NFC] Add a `is_wg_semantics` method to `mosaic_gpu_test`.
dimitar-asenov Mar 4, 2026
9ff8500
Merge pull request #35592 from superbobry:pyrefly
Google-ML-Automation Mar 4, 2026
a9a9ce4
[Mosaic GPU] Add basic support for atomic reductions while storing
apaszke Mar 4, 2026
1c48067
Add ROCm wheel build and test pipeline to continuous CI
alekstheod Mar 3, 2026
4ec06d4
[Pallas][TPU kernel interpreter] Add optional logging to memory opera…
Google-ML-Automation Mar 4, 2026
c732a77
Restore docker
alekstheod Mar 4, 2026
40e9111
[NFC] Use `assert_never` for unreachable case in layout inference.
allanrenucci Mar 4, 2026
aa979f9
Trigger CI/CD pipeline
alekstheod Mar 4, 2026
3562aa8
[Mosaic TPU] Stop using canonicalization rules for tiling propagation
apaszke Mar 4, 2026
9e6b25f
[Mosaic GPU] Refactor `FragmentedArray.broadcast` to use `broadcast_i…
allanrenucci Mar 4, 2026
f2b145c
Trigger CI/CD pipeline
alekstheod Mar 4, 2026
5555a76
Trigger CI/CD pipeline
alekstheod Mar 4, 2026
4b93a66
Trigger change
alekstheod Mar 4, 2026
f37f484
[pallas:mgpu] Removed deprecated `Layout` aliases
superbobry Mar 4, 2026
4a37416
Trigger CI/CD pipeline
alekstheod Mar 4, 2026
7746f50
[Mosaic GPU] Add support for f16/bf16 atomics
apaszke Mar 4, 2026
4762077
[0.9.1] Update jaxlib version guards post-release.
danielsuo Mar 4, 2026
250595b
Switch docker
alekstheod Mar 4, 2026
9fa9c3b
[Mosaic:GPU] Rename is_comm_used as is_nvshmem_used
PatriosTheGreat Mar 4, 2026
9b75f92
Add missing file
alekstheod Mar 4, 2026
76ac8c0
Add missing wildcard
alekstheod Mar 4, 2026
3965854
Remove audit step
alekstheod Mar 4, 2026
9aa8b23
[pallas-triton] Explicitly disallow unsigned int operands in dot.
rdyro Mar 4, 2026
9905f55
[Mosaic] Allow mask relayout with changed bitwidth and 1d tiling.
WindQAQ Mar 4, 2026
1cd3ad5
[Mosaic] Add a `full_range` option to reciprocal.
WindQAQ Mar 4, 2026
3564ddc
[Mosaic TPU] Disallow DMAs with regular semaphores when src or tgt is…
apaszke Mar 4, 2026
232efa2
Add step installing aws cli
alekstheod Mar 4, 2026
87db9ab
Integrate LLVM at llvm/llvm-project@5ff5a1f14761
Google-ML-Automation Mar 4, 2026
fb41d65
Jax updates for rules_python 1.8.4
ecalubaquib Mar 4, 2026
42673a9
Clean up some static initializer usage.
hawkinsp Mar 4, 2026
ca6fda2
Reverts 1cd3ad579a92f2c25f3502b9d0683d5903afce78
WindQAQ Mar 4, 2026
93f24e8
[refactor] directly define array methods
jakevdp Mar 4, 2026
c35c1ff
Add variable for rocm presubmit's to clone xla repo
ecalubaquib Mar 4, 2026
8b14bfd
Merge pull request #35581 from jakevdp:reduce-stack-frames
Google-ML-Automation Mar 4, 2026
46f9beb
Fix buggy unreduced add lowering.
mwhittaker Mar 4, 2026
bc94500
Reshape which unfolds the last dim into two will only be supported on…
yueshengys Mar 4, 2026
2709440
[hijax] generalize shmap closure handling for hitypes, add HipSpec.to…
mattjj Mar 4, 2026
0251953
Merge pull request #35615 from mattjj:shmap-no-p
Google-ML-Automation Mar 5, 2026
8d439f5
Rename to_cotangent_aval -> to_ct_aval. This is to match it to `to_ct…
yashk2810 Mar 5, 2026
1ab4f73
Automated Code Change
Google-ML-Automation Mar 5, 2026
e527f48
Properly mirror CompilationProviderOptions for jax
ermilovmaxim Mar 5, 2026
a4958bb
Merge branch 'main' of github.com:jax-ml/jax into implement_periodic_…
alekstheod Mar 5, 2026
23bf7d9
Trigger CI/CD pipeline
alekstheod Mar 5, 2026
9a3bb3b
Switch to oidc
alekstheod Mar 6, 2026
b4f4b37
Trigger CI/CD pipeline
alekstheod Mar 6, 2026
4 changes: 4 additions & 0 deletions .bazelrc
@@ -347,6 +347,10 @@ common:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win2022/20241
common:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
common:ci_windows_amd64 --color=yes

# Used for rules_python bootstrap 1.8.0+
build --@rules_python//python/config_settings:bootstrap_impl=script --repo_env=RULES_PYTHON_ENABLE_PIPSTAR=0
test --@rules_python//python/config_settings:bootstrap_impl=script --repo_env=RULES_PYTHON_ENABLE_PIPSTAR=0

# #############################################################################
# RBE config options below. These inherit the CI configs above and set the
# remote execution backend and authentication options required to run builds
3 changes: 2 additions & 1 deletion .github/workflows/bazel_rocm.yml
@@ -113,11 +113,12 @@ jobs:
JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: ${{ inputs.write_to_bazel_remote_cache }}
JAXCI_BUILD_JAX: ${{ inputs.build_jax }}
JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
JAXCI_CLONE_MAIN_XLA: ${{ inputs.clone_main_xla }}
# Begin Presubmit Naming Check - name modification requires internal check to be updated
name: "linux x86, jaxlib=${{ inputs.jaxlib-version }}, ROCM=${{ inputs.rocm-version }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"
# End Presubmit Naming Check github-rocm-presubmits
steps:
- - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: ROCm Info
180 changes: 180 additions & 0 deletions .github/workflows/build_rocm_artifacts.yml
@@ -0,0 +1,180 @@
# CI - Build ROCm Artifacts
# This workflow builds ROCm wheels (jax-rocm-plugin, jax-rocm-pjrt) in a ROCm container
# and uploads them to an S3 bucket. It can be triggered manually via workflow_dispatch or
# called by other workflows via workflow_call.
name: CI - Build ROCm Artifacts

on:
workflow_dispatch:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: choice
default: "linux-x86-64-1gpu-amd"
options:
- "linux-x86-64-1gpu-amd"
artifact:
description: "Which ROCm artifact to build?"
type: choice
default: "jax-rocm-plugin"
options:
- "jax-rocm-plugin"
- "jax-rocm-pjrt"
python:
description: "Which python version should the artifact be built for?"
type: choice
default: "3.12"
options:
- "3.11"
- "3.12"
- "3.13"
- "3.14"
rocm-version:
description: "Which ROCm version to build for?"
type: string
default: "7"
clone_main_xla:
description: "Should latest XLA be used?"
type: choice
default: "1"
options:
- "1"
- "0"
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: choice
default: 'no'
options:
- 'yes'
- 'no'
workflow_call:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: string
default: "linux-x86-64-1gpu-amd"
artifact:
description: "Which ROCm artifact to build?"
type: string
default: "jax-rocm-plugin"
python:
description: "Which python version should the artifact be built for?"
type: string
default: "3.12"
rocm-version:
description: "Which ROCm version to build for?"
type: string
default: "7"
clone_main_xla:
description: "Should latest XLA be used?"
type: string
default: "1"
upload_artifacts_to_s3:
description: "Should the artifacts be uploaded to S3?"
default: true
type: boolean
s3_upload_uri:
description: "S3 location prefix to where the artifacts should be uploaded"
default: 's3://jax-ci-amd/jax-rocm-plugin-wheels'
type: string
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: string
default: 'no'
secrets:
AWS_ACCESS_KEY_ID:
required: true
AWS_SECRET_ACCESS_KEY:
required: true
S3_BUCKET_NAME:
required: true
outputs:
s3_upload_uri:
description: "S3 location prefix to where the artifacts were uploaded"
value: ${{ jobs.build-artifacts.outputs.s3_upload_uri }}

permissions:
id-token: write
contents: read
actions: read

jobs:
build-artifacts:
defaults:
run:
shell: bash
runs-on: ${{ inputs.runner }}
container:
image: "ghcr.io/rocm/jax-manylinux_2_28-rocm-7.2.0:latest"
volumes:
- /data:/data
options: >-
--device=/dev/kfd
--device=/dev/dri
--security-opt seccomp=unconfined
--group-add video
--shm-size 64G

env:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_OUTPUT_DIR: "${{ github.workspace }}/dist"

name: "${{ inputs.artifact }}, py ${{ inputs.python }}, ROCm ${{ inputs.rocm-version }}"

outputs:
s3_upload_uri: ${{ steps.store-s3-upload-uri.outputs.s3_upload_uri }}

steps:
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
with:
persist-credentials: false
- name: Install Bazelisk
run: |
curl -fSsL -o /usr/local/bin/bazel https://github.com/bazelbuild/bazelisk/releases/latest/download/bazelisk-linux-amd64
chmod +x /usr/local/bin/bazel
- name: Install AWS CLI
run: |
curl -fSsL "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o /tmp/awscliv2.zip
unzip -q /tmp/awscliv2.zip -d /tmp
/tmp/aws/install
rm -rf /tmp/awscliv2.zip /tmp/aws
- name: ROCm Info
run: rocminfo
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Create dist dir
run: mkdir -p $JAXCI_OUTPUT_DIR/
- name: Build ${{ inputs.artifact }}
timeout-minutes: 120
run: |
bazel --bazelrc=build/rocm/rocm.bazelrc run \
--config=rocm_release_wheel \
--config=rocm_rbe \
--repo_env=HERMETIC_PYTHON_VERSION="${{ inputs.python }}" \
$DEPLOY_TARGET -- $JAXCI_OUTPUT_DIR/
env:
DEPLOY_TARGET: ${{ inputs.artifact == 'jax-rocm-plugin' && '//jaxlib/tools:deploy_rocm_plugin_wheel' || '//jaxlib/tools:deploy_rocm_pjrt_wheel' }}
- name: Configure AWS Credentials
if: ${{ inputs.upload_artifacts_to_s3 }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::661452401056:role/jax-ci-rocm-jax-s3-oidc
aws-region: us-east-1
- name: Upload artifacts to S3
if: ${{ inputs.upload_artifacts_to_s3 }}
run: |
echo "Uploading wheels to S3..."
ls -lh $JAXCI_OUTPUT_DIR/*.whl
aws s3 cp --only-show-errors --recursive $JAXCI_OUTPUT_DIR/ "${INPUTS_S3_UPLOAD_URI}"/
echo "Upload complete."
env:
INPUTS_S3_UPLOAD_URI: ${{ inputs.s3_upload_uri }}
- name: Store the S3 upload URI as an output
if: ${{ inputs.upload_artifacts_to_s3 }}
id: store-s3-upload-uri
run: echo "s3_upload_uri=${INPUTS_S3_UPLOAD_URI}" >> "$GITHUB_OUTPUT"
env:
INPUTS_S3_UPLOAD_URI: ${{ inputs.s3_upload_uri }}
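Note on the `DEPLOY_TARGET` expression in the build step above: the `cond && a || b` form acts as a ternary, so any `artifact` value other than `jax-rocm-plugin` falls through to the PJRT wheel target. A minimal shell sketch of the same selection follows; the function name `select_deploy_target` is hypothetical and is not part of this PR.

```bash
# Hypothetical helper mirroring the workflow's DEPLOY_TARGET expression.
# Only the two Bazel targets are taken from the diff; the function is illustrative.
select_deploy_target() {
  if [ "$1" = "jax-rocm-plugin" ]; then
    echo "//jaxlib/tools:deploy_rocm_plugin_wheel"
  else
    # Same fallback as the `||` branch: every other value selects the PJRT wheel.
    echo "//jaxlib/tools:deploy_rocm_pjrt_wheel"
  fi
}

select_deploy_target "jax-rocm-pjrt"
```

Because the fallback is unconditional, a misspelled artifact name passed through `workflow_call` (where the input is a free-form string) would presumably build the PJRT wheel rather than fail.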
33 changes: 0 additions & 33 deletions .github/workflows/pyrefly.yml

This file was deleted.

10 changes: 5 additions & 5 deletions .github/workflows/pytest_tpu.yml
@@ -93,7 +93,11 @@ jobs:
gcs_download_uri: ${{ inputs.gcs_download_uri }}
- name: Install Python dependencies
run: |
- $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt
+ $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
+ # xprof depends on cffi, which doesn't support Python 3.13 free-threaded.
+ if [[ "${JAXCI_HERMETIC_PYTHON_VERSION}" != "3.13-nogil" ]]; then
+ $JAXCI_PYTHON -m uv pip install -r build/collect-profile-requirements.txt
+ fi
- name: Set up libtpu wheels
run: |
if [[ "${INPUTS_LIBTPU_VERSION_TYPE}" == "nightly" ]]; then
@@ -124,10 +128,6 @@ jobs:
- name: Run Pytest TPU tests
timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 210 }}
run: |
- if [[ ${INPUTS_PYTHON} == "3.13-nogil" ]]; then
- echo "Uninstalling xprof as it is not compatible with python 3.13t."
- $JAXCI_PYTHON -m uv pip uninstall xprof
- fi
./ci/run_pytest_tpu.sh
env:
INPUTS_PYTHON: ${{ inputs.python }}
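Note on the dependency-install change in this file: the profile requirements are now skipped up front for the free-threaded build instead of being installed and then uninstalled before the test run. A small sketch of the resulting decision, assuming the version string is compared exactly as in the diff; the helper name `requirement_files_for` is hypothetical.

```bash
# Hypothetical helper: prints the requirement files the install step would use
# for a given hermetic Python version string.
requirement_files_for() {
  echo "build/test-requirements.txt"
  # Per the diff's comment, xprof depends on cffi, which does not support
  # Python 3.13 free-threaded, so the profile requirements are skipped there.
  if [ "$1" != "3.13-nogil" ]; then
    echo "build/collect-profile-requirements.txt"
  fi
}

requirement_files_for "3.13-nogil"
```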
2 changes: 1 addition & 1 deletion .github/workflows/rocm-ci.yml
@@ -46,7 +46,7 @@ jobs:
dist_docker \
--image-tag $TEST_IMAGE
- name: Archive jax wheels
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
+ uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
path: ${{ env.WORKSPACE_DIR }}/dist/*.whl
86 changes: 83 additions & 3 deletions .github/workflows/wheel_tests_continuous.yml
@@ -1,6 +1,6 @@
# CI - Wheel Tests (Continuous)
#
- # This workflow builds JAX artifacts and runs CPU/TPU/CUDA tests.
+ # This workflow builds JAX artifacts and runs CPU/TPU/CUDA/ROCm tests.
#
# It orchestrates the following:
# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
@@ -22,15 +22,26 @@
# that was built in the previous step and runs TPU tests.
# 9. run-bazel-test-tpu: Calls the `bazel_test_tpu.yml` workflow which
# runs Bazel TPU tests with py_import.
# 10. build-rocm-artifacts: Calls the `build_rocm_artifacts.yml` workflow to build ROCm plugin/pjrt
# wheels and uploads them to an S3 bucket.
# 11. run-pytest-rocm: Calls the `pytest_rocm.yml` workflow which downloads the jaxlib and
# ROCm artifacts and runs the ROCm tests.
# 12. run-bazel-test-rocm: Calls the `bazel_rocm.yml` workflow which runs the ROCm Bazel tests.

name: CI - Wheel Tests (Continuous)
permissions:
- contents: read
+ id-token: write
+ contents: read
+ actions: read

on:
schedule:
- cron: "0 */3 * * *" # Run once every 3 hours
workflow_dispatch: # allows triggering the workflow run manually
pull_request:
paths:
- '.github/workflows/wheel_tests_continuous.yml'
- '.github/workflows/build_rocm_artifacts.yml'

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@@ -285,4 +296,73 @@ jobs:
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
build_jaxlib: "wheel"
build_jax: "wheel"
- clone_main_xla: 1
+ clone_main_xla: 1

build-rocm-artifacts:
uses: ./.github/workflows/build_rocm_artifacts.yml
secrets: inherit
permissions:
id-token: write
contents: read
actions: read
strategy:
fail-fast: false
matrix:
runner: ["linux-x86-64-1gpu-amd"]
artifact: ["jax-rocm-plugin", "jax-rocm-pjrt"]
python: ["3.11"]
rocm-version: ["7"]
name: "Build ${{ format('{0}', 'ROCm') }} artifacts"
with:
runner: ${{ matrix.runner }}
artifact: ${{ matrix.artifact }}
python: ${{ matrix.python }}
rocm-version: ${{ matrix.rocm-version }}
clone_main_xla: 1
upload_artifacts_to_s3: true
s3_upload_uri: 's3://jax-ci-amd/rocm-wheels/wheel-tests-continuous/${{ github.run_number }}/${{ github.run_attempt }}'

run-pytest-rocm:
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-rocm-artifacts]
uses: ./.github/workflows/pytest_rocm.yml
strategy:
fail-fast: false
matrix:
runner: ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"]
python: ["3.11"]
rocm: [
{version: "7.2.0", tag: "rocm720"},
]
name: "Pytest ROCm (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
rocm-version: ${{ matrix.rocm.version }}
rocm-tag: ${{ matrix.rocm.tag }}
jaxlib-version: "head"
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

run-bazel-test-rocm:
if: ${{ !cancelled() }}
needs: [build-jax-artifact, build-jaxlib-artifact, build-rocm-artifacts]
uses: ./.github/workflows/bazel_rocm.yml
strategy:
fail-fast: false
matrix:
runner: ["linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"]
python: ["3.11"]
rocm-version: ["7"]
enable-x64: [0]
name: "Bazel ROCm tests (JAX artifacts version = ${{ format('{0}', 'head') }})"
with:
runner: ${{ matrix.runner }}
python: ${{ matrix.python }}
rocm-version: ${{ matrix.rocm-version }}
enable-x64: ${{ matrix.enable-x64 }}
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
build_jaxlib: "false"
build_jax: "false"
jaxlib-version: "head"
run_multiaccelerator_tests: "false"
clone_main_xla: 1
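Note on the `s3_upload_uri` passed to `build-rocm-artifacts` above: the prefix embeds both `github.run_number` and `github.run_attempt`, so a re-run of the same workflow run uploads under a different prefix from the first attempt. A minimal sketch of how the prefix is composed; the function name `continuous_s3_upload_uri` is hypothetical, and only the bucket path is taken from the diff.

```bash
# Hypothetical helper: rebuilds the upload prefix from a run number and attempt.
continuous_s3_upload_uri() {
  run_number="$1"
  run_attempt="$2"
  echo "s3://jax-ci-amd/rocm-wheels/wheel-tests-continuous/${run_number}/${run_attempt}"
}

continuous_s3_upload_uri 1234 2
```

The upload step in `build_rocm_artifacts.yml` appends a trailing `/` to this value before calling `aws s3 cp --recursive`.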