Skip to content

[Testing do not merge] Add ROCm wheel build and test pipeline to continuous CI#723

Draft
alekstheod wants to merge 121 commits intoamd-mainfrom
implement_periodic_build_wheels_pipeline_rocm_internal_v2
Draft

[Testing do not merge] Add ROCm wheel build and test pipeline to continuous CI#723
alekstheod wants to merge 121 commits intoamd-mainfrom
implement_periodic_build_wheels_pipeline_rocm_internal_v2

Conversation

@alekstheod
Copy link

  • Add jax-rocm-plugin and jax-rocm-pjrt to allowed artifacts in build_artifacts.sh with ROCm version flag passthrough.
  • Create build_rocm_artifacts.yml reusable workflow that builds ROCm wheels in a ROCm container and uploads them to S3 via OIDC.
  • Extend wheel_tests_continuous.yml with build-rocm-artifacts, run-pytest-rocm, and run-bazel-test-rocm jobs.

mattjj and others added 30 commits February 27, 2026 00:41
Co-authored-by: <yashkatariya@google.com>
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 6.0.0 to 7.0.0.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](actions/upload-artifact@b7c566a...bbbca2d)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: 7.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Bumps [actions/checkout](https://github.com/actions/checkout) from 6.0.0 to 6.0.2.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](actions/checkout@v6...de0fac2)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: 6.0.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Requires a new nightly libtpu (latest is libtpu-0.0.37.dev20260224).

PiperOrigin-RevId: 877504588
xprof pulls in cffi, which is not supported under 3.13-nogil.

PiperOrigin-RevId: 877504615
…/actions/upload-artifact-7.0.0

PiperOrigin-RevId: 877527030
…/actions/checkout-6.0.2

PiperOrigin-RevId: 877527224
PiperOrigin-RevId: 877558703
This matches how they are currently used, and allows child classes flexibility in choice of argument names.
Recent LLVM change didn't appear to update the examples.

PiperOrigin-RevId: 877630264
@alekstheod alekstheod force-pushed the implement_periodic_build_wheels_pipeline_rocm_internal_v2 branch from f172dc1 to 3965854 Compare March 4, 2026 17:57
rdyro and others added 19 commits March 4, 2026 10:06
Only setting it to true when inputs/results are known to be not weird (+-inf, NaN, +-1/flt_min). Otherwise, the results might not be IEEE compliant.

PiperOrigin-RevId: 878546673
… in SMEM

DMAs involving SMEM require us to use sfence, but tpu.sem_wait has no way of knowing
if it should emit one.

PiperOrigin-RevId: 878549889
Updates LLVM usage to match
[5ff5a1f14761](llvm/llvm-project@5ff5a1f14761)

PiperOrigin-RevId: 878586624
1. Update runfile imports and its related python files
2. Add flags to .bazelrc specific for flags used for rules_python
3. Update all .sh runners for ci test to use rules_python_bootstrap except for older python 3.11
4. Fix wheel_build_command to use wheel_build_command_based
5. Make numpy visible
6. Update all the patches for rules_python for scoping and freethreading

PiperOrigin-RevId: 878592042
PiperOrigin-RevId: 878594835
PiperOrigin-RevId: 878629269
Why? Previously these were defined by a function factory – while this leads
to shorter code, dynamically-generated functions are more difficult to
follow for developers, and lead to longer stack traces for users.
I noticed the following program was producing the wrong gradients:

```python
import jax
import jax.numpy as jnp

jax.config.update('jax_num_cpu_devices', 2)
mesh = jax.make_mesh((2,) , ("i",))
jax.set_mesh(mesh)
x = jax.reshard(jnp.ones((2, 2)), jax.P(None, None, reduced={'i'}))
def f(x):
  return jnp.sum(x + x)
y, vjp = jax.vjp(f, x)
g = vjp(1.)[0]
print(g.sharding)
for shard in g.addressable_shards:
  print(shard.data)
```

The gradient should be an unreduced array where each device-local shard is [[1,
1], [1, 1]], but we were seeing [[2, 2], [2, 2]] on each shard.

Yash found and fixed the bug.

PiperOrigin-RevId: 878663233
… TPU v5+ if the last dim is not divisible by 128.

PiperOrigin-RevId: 878667118
…_tangent_spec

Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 878799690
@alekstheod alekstheod force-pushed the implement_periodic_build_wheels_pipeline_rocm_internal_v2 branch from ab6bb69 to 3ea5e36 Compare March 5, 2026 09:49
@alekstheod alekstheod force-pushed the implement_periodic_build_wheels_pipeline_rocm_internal_v2 branch from 3ea5e36 to a4958bb Compare March 5, 2026 09:55
@alekstheod alekstheod force-pushed the implement_periodic_build_wheels_pipeline_rocm_internal_v2 branch from 6289d0f to bc61e1e Compare March 6, 2026 07:53
@alekstheod alekstheod force-pushed the implement_periodic_build_wheels_pipeline_rocm_internal_v2 branch from bc61e1e to 9a3bb3b Compare March 6, 2026 08:14
@alekstheod alekstheod force-pushed the implement_periodic_build_wheels_pipeline_rocm_internal_v2 branch from 714ba20 to b4f4b37 Compare March 6, 2026 08:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.