[Testing do not merge] Add ROCm wheel build and test pipeline to continuous CI#723
Draft
alekstheod wants to merge 121 commits intoamd-mainfrom
Draft
[Testing do not merge] Add ROCm wheel build and test pipeline to continuous CI#723alekstheod wants to merge 121 commits intoamd-mainfrom
alekstheod wants to merge 121 commits intoamd-mainfrom
Conversation
alekstheod
commented
Mar 4, 2026
- Add jax-rocm-plugin and jax-rocm-pjrt to allowed artifacts in build_artifacts.sh with ROCm version flag passthrough.
- Create build_rocm_artifacts.yml reusable workflow that builds ROCm wheels in a ROCm container and uploads them to S3 via OIDC.
- Extend wheel_tests_continuous.yml with build-rocm-artifacts, run-pytest-rocm, and run-bazel-test-rocm jobs.
Co-authored-by: <yashkatariya@google.com>
PiperOrigin-RevId: 877469331
PiperOrigin-RevId: 877475185
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 6.0.0 to 7.0.0. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](actions/upload-artifact@b7c566a...bbbca2d) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: 7.0.0 dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] <support@github.com>
Bumps [actions/checkout](https://github.com/actions/checkout) from 6.0.0 to 6.0.2. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](actions/checkout@v6...de0fac2) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: 6.0.2 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] <support@github.com>
PiperOrigin-RevId: 877489290
PiperOrigin-RevId: 877503580
Requires a new nightly libtpu (latest is libtpu-0.0.37.dev20260224). PiperOrigin-RevId: 877504588
xprof pulls in cffi, which is not supported under 3.13-nogil. PiperOrigin-RevId: 877504615
PiperOrigin-RevId: 877509481
…/actions/upload-artifact-7.0.0 PiperOrigin-RevId: 877527030
…/actions/checkout-6.0.2 PiperOrigin-RevId: 877527224
PiperOrigin-RevId: 877544391
PiperOrigin-RevId: 877544489
PiperOrigin-RevId: 877548502
…ommit/ddf15ca00ef5693e02e2d870c8d720b7d8d060f6 PiperOrigin-RevId: 877558591
PiperOrigin-RevId: 877610348
PiperOrigin-RevId: 877616830
PiperOrigin-RevId: 877621017
This matches how they are currently used, and allows child classes flexibility in choice of argument names.
PiperOrigin-RevId: 877630232
Recent LLVM change didn't appear to update the examples. PiperOrigin-RevId: 877630264
PiperOrigin-RevId: 877633633
f172dc1 to
3965854
Compare
PiperOrigin-RevId: 878543875
PiperOrigin-RevId: 878543940
Only setting it to true when inputs/results are known to be not weird (+-inf, NaN, +-1/flt_min). Otherwise, the results might not be IEEE compliant. PiperOrigin-RevId: 878546673
… in SMEM DMAs involving SMEM require us to use sfence, but tpu.sem_wait has no way of knowing if it should emit one. PiperOrigin-RevId: 878549889
Updates LLVM usage to match [5ff5a1f14761](llvm/llvm-project@5ff5a1f14761) PiperOrigin-RevId: 878586624
1. Update runfile imports and its related python files 2. Add flags to .bazelrc specific for flags used for rules_python 3. Update all .sh runners for ci test to use rules_python_bootstrap except for older python 3.11 4. Fix wheel_build_command to use wheel_build_command_based 5. Make numpy visible 6. Update all the patches for rules_python for scoping and freethreading PiperOrigin-RevId: 878592042
PiperOrigin-RevId: 878594835
Why? Previously these were defined by a function factory – while this leads to shorter code, dynamically-generated functions are more difficult to follow for developers, and lead to longer stack traces for users.
PiperOrigin-RevId: 878636461
PiperOrigin-RevId: 878646845
I noticed the following program was producing the wrong gradients:
```python
import jax
import jax.numpy as jnp
jax.config.update('jax_num_cpu_devices', 2)
mesh = jax.make_mesh((2,) , ("i",))
jax.set_mesh(mesh)
x = jax.reshard(jnp.ones((2, 2)), jax.P(None, None, reduced={'i'}))
def f(x):
return jnp.sum(x + x)
y, vjp = jax.vjp(f, x)
g = vjp(1.)[0]
print(g.sharding)
for shard in g.addressable_shards:
print(shard.data)
```
The gradient should be an unreduced array where each device-local shard is [[1,
1], [1, 1]], but we were seeing [[2, 2], [2, 2]] on each shard.
Yash found and fixed the bug.
PiperOrigin-RevId: 878663233
… TPU v5+ if the last dim is not divisible by 128. PiperOrigin-RevId: 878667118
…_tangent_spec Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 878755710
…_spec`. PiperOrigin-RevId: 878763496
PiperOrigin-RevId: 878799690
PiperOrigin-RevId: 878805891
ab6bb69 to
3ea5e36
Compare
…build_wheels_pipeline_rocm_internal_v2
3ea5e36 to
a4958bb
Compare
6289d0f to
bc61e1e
Compare
bc61e1e to
9a3bb3b
Compare
714ba20 to
b4f4b37
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.