Closed
Changes from all commits
7513 commits
523b667
Bump actions/checkout from 6.0.0 to 6.0.1
dependabot[bot] Feb 20, 2026
3b6fd04
Propagate input sharding to output in full_like if it's fully replica…
yashk2810 Feb 21, 2026
79f3519
[hijax] add is_vjp bit to LinearizeTrace
mattjj Feb 21, 2026
4181fe4
Merge pull request #35285 from mattjj:is-vjp-bit
Google-ML-Automation Feb 21, 2026
4f22f9d
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Feb 21, 2026
7fbbe01
Canonicalize indexer.slice_sharding. Fixes https://github.com/jax-ml/…
yashk2810 Feb 21, 2026
b1981cb
[hijax] improved dce interface
mattjj Feb 22, 2026
7f80f43
Merge pull request #35301 from mattjj:better-hijax-dce-interface
Google-ML-Automation Feb 22, 2026
acad300
[hijax] fix followup on #35301
mattjj Feb 22, 2026
0cff9fe
Merge pull request #35302 from mattjj:hijax-dce-fix
Google-ML-Automation Feb 22, 2026
5f43a72
[hijax][remat3] start implementing remat on hijax
mattjj Feb 21, 2026
53d6b35
Merge pull request #35304 from mattjj:remat-transform
Google-ML-Automation Feb 22, 2026
b17e41d
[hijax][remat3] add basic remat_jaxpr
mattjj Feb 22, 2026
dea31b6
Merge pull request #35306 from mattjj:remat-jaxpr
Google-ML-Automation Feb 22, 2026
8dac162
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Feb 22, 2026
9ce7295
[pyrefly] Fixed or suppressed type errors in Pallas backends
superbobry Feb 21, 2026
4d3395f
Merge pull request #35309 from superbobry:pyrefly
Google-ML-Automation Feb 23, 2026
f47f0c2
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Feb 23, 2026
cd01971
[pallas:mgpu] Fix handling of `AbstractRef` in `_commute_transform`.
chr1sj0nes Feb 23, 2026
a58b181
Merge pull request #33802 from jax-ml:dependabot/github_actions/actio…
Google-ML-Automation Feb 23, 2026
1155884
Skip scaled matmul tests on CPU
phambinhfin Feb 23, 2026
b5c9201
[test] fix deprecated mpmath import
jakevdp Feb 23, 2026
0beb837
More Pyrefly fixes in Pallas
superbobry Feb 23, 2026
adc7836
Merge pull request #35323 from jakevdp:mpmath-fix
Google-ML-Automation Feb 23, 2026
ad8c8c4
Merge pull request #35317 from superbobry:pyrefly
Google-ML-Automation Feb 23, 2026
b1c5e11
[typing] more pyrefly fixes
jakevdp Feb 23, 2026
e673505
Another batch of Pyrefly fixes
superbobry Feb 23, 2026
db34816
[Mosaic TPU] Give more time for the changes to propagate to libtpu
apaszke Feb 23, 2026
01c3083
Fix dangling std::string_view in keys of GetModuleImage() cache
Arech8 Feb 23, 2026
2287580
[jaxlib] Bundle _{tpu,triton,mosaic_gpu}_ext.pyi with jaxlib
superbobry Feb 23, 2026
8347183
Pin mpmath to v1.3.0.
danielsuo Feb 23, 2026
b486aac
Merge pull request #35332 from superbobry:pyrefly
Google-ML-Automation Feb 23, 2026
4dc41c3
Fixed `self.accum_ref` assertion in mosaic/pipeline.py
superbobry Feb 23, 2026
920f666
Roll back adding __eq__ and __hash__ to PyDevice.
danielsuo Feb 23, 2026
fc5ab2c
[pmap] Replace `functools.lru_cache` with `util.cache` to fix stale `…
danielsuo Feb 23, 2026
74b70ef
[lax_numpy] Adjust tolerance for np.float32 in trapezoid tests.
danielsuo Feb 23, 2026
ef24700
Bump actions/setup-python from 6.1.0 to 6.2.0
dependabot[bot] Feb 23, 2026
9c33cd1
Bump actions/cache from 4.3.0 to 5.0.3
dependabot[bot] Feb 23, 2026
45e29f6
Merge pull request #35344 from jax-ml:dependabot/github_actions/actio…
Google-ML-Automation Feb 23, 2026
5d01b1f
[typing] suppress mypy redundant-cast
jakevdp Feb 23, 2026
8ac9b74
Merge pull request #35347 from jax-ml:dependabot/github_actions/actio…
Google-ML-Automation Feb 23, 2026
cc7365f
Re-enable some disabled tests in pjit_test.py
yashk2810 Feb 23, 2026
677d7b0
Merge pull request #35343 from jakevdp:fix-mypy-lint
Google-ML-Automation Feb 23, 2026
e12cf3d
Merge pull request #35331 from jakevdp:more-pyrefly-fixes
Google-ML-Automation Feb 23, 2026
6dc4419
[mlir] Fixed the return type of `*{Attr,Type}.get` staticmethods
superbobry Feb 23, 2026
d6d975c
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Feb 24, 2026
b55efd8
[Mosaic GPU] Add support for swizzle=16 in async SMEM->TMEM copies
apaszke Feb 24, 2026
776d3c3
Reverts 42c458cc318706bbbefbb45a752da6615a7123d3
apaszke Feb 24, 2026
98f0ca2
More Pyrefly fixes
superbobry Feb 24, 2026
1807a49
[Pallas:MGPU] Allow SMEM->TMEM copies with unswizzled (but tiled) refs
apaszke Feb 24, 2026
192601c
[Mosaic GPU] Relax N multiplicity restrictions for tcgen05 block-scal…
apaszke Feb 24, 2026
d037727
[Mosaic GPU] Add support for fp4 tcgen05 MMAs that are both block-sca…
apaszke Feb 24, 2026
c139db7
Merge pull request #35133 from mwhittaker:fault_tolerance_blog_tweaks
Google-ML-Automation Feb 24, 2026
ab5ef25
[Mosaic TPU] Expose the latest version from the TPU MosaicSerdePass
apaszke Feb 24, 2026
f0f74df
Add a light, blocking Mosaic B200 presubmit.
belitskiy Feb 24, 2026
fc9061d
[Pallas:MGPU] Add support for writes to WGMMA accumulator refs
apaszke Feb 24, 2026
e16cf23
[Mosaic:GPU] Temporary fix of ASAN test.
PatriosTheGreat Feb 24, 2026
afa13c4
PR #35345: Bump actions/upload-artifact from 5.0.0 to 6.0.0
dependabot[bot] Feb 24, 2026
012d665
Merge pull request #35367 from superbobry:pyrefly
Google-ML-Automation Feb 24, 2026
f59d821
[mosaic] Verify that `tpu.memref_slice` always produces tile-aligned …
superbobry Feb 24, 2026
b69d1c0
[hijax] mlir.lower_fun should have lo-jaxprs
mattjj Feb 24, 2026
e76e398
[NFC] Remove some unnecessary skips to have better test coverage.
yueshengys Feb 25, 2026
6b89f7c
Add Pallas SparseCore guide
IvyZX Feb 13, 2026
2cbf65c
Merge pull request #35379 from mattjj:mlir-lower-fun-to-lojax
Google-ML-Automation Feb 25, 2026
7b56a86
Reverts f59d8213d42d124048de9ad93792bc17152134c7
Google-ML-Automation Feb 25, 2026
518b0d1
[Mosaic TPU] Add more complete support for 16-bit `cmpf` on TPU v5. P…
yueshengys Feb 25, 2026
0f31a1b
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Feb 25, 2026
2e3c840
Add SC support for pltpu.trace_value
brianwa84 Feb 25, 2026
f6af222
Merge pull request #35334 from ROCm:ci_arech_fix_dangling_string_view
Google-ML-Automation Feb 25, 2026
90fdba0
[Pallas:MGPU] Add a test for tcgen05 MMA that's both sparse and block…
apaszke Feb 25, 2026
eda5ec9
[windows] Match triton importing in pallas_shape_poly_test to how it'…
danielsuo Feb 25, 2026
9169891
[jaxlib] Every MLIR dialect now comes with a .pyi
superbobry Feb 23, 2026
ab5f197
[lax_scipy] Adjust tolerance for np.float32 in trapezoid tests.
danielsuo Feb 25, 2026
dc11248
[lax_numpy] Skip ReducerWhere tests temporarily.
danielsuo Feb 25, 2026
ace7a79
[Mosaic TPU] Allow non-DMA semaphores in EnqueueDMA
apaszke Feb 25, 2026
1ae8be1
Merge pull request #35357 from superbobry:more-piys
Google-ML-Automation Feb 25, 2026
122abbb
[lax_numpy] Skip AdvancedIndexing tests temporarily.
danielsuo Feb 25, 2026
c6d2109
Enable some tests with TC-tiling with newer libtpu.
brianwa84 Feb 25, 2026
9c9dd77
Fix a bug in how input/output memory spaces are detected and propagat…
brianwa84 Feb 25, 2026
709086f
[typing] add type annotation for shard_map
Feb 25, 2026
84a33c0
[Pallas:MGPU] Fix a race in dynamic_scheduling_loop
apaszke Feb 25, 2026
a211a31
[ci] Update pyrefly pre-commit to v0.54.0
jakevdp Feb 25, 2026
5c0ec3e
[windows] Skip ops_test.test_binary_scalar if no TPU.
danielsuo Feb 25, 2026
ac537c7
[lax_numpy] Skip setxor1 tests temporarily.
danielsuo Feb 25, 2026
f5c39fe
Merge pull request #35404 from jakevdp:pyrefly-054
Google-ML-Automation Feb 25, 2026
100b05c
Merge pull request #35015 from IvyZX:scguide
Google-ML-Automation Feb 25, 2026
19e71a5
[pyrefly] remove ignore statements for issues fixed in v0.54
jakevdp Feb 25, 2026
b8171e3
[typing] add type annotation for `jax.smap`.
Feb 25, 2026
cda2930
Merge pull request #35405 from jakevdp:pyrefly-ignore-statements
Google-ML-Automation Feb 25, 2026
1f3139e
Merge pull request #34601 from ROCm:phambinh/ci_rocm_nn_tests
Google-ML-Automation Feb 25, 2026
015f1cb
Merge pull request #35178 from ezhulenev:jax_sort_devices_by_process_…
Google-ML-Automation Feb 25, 2026
7780531
[Pallas][Mosaic TPU] Add flag to disable semaphore checks per kernel.
bythew3i Feb 25, 2026
0461a3b
[typing] more pyrefly fixes in jax/_src/lax/
jakevdp Feb 25, 2026
742ad02
Merge pull request #35325 from jakevdp:pyrefly-fixes
Google-ML-Automation Feb 25, 2026
480847f
Removed more dead code
superbobry Feb 25, 2026
1e102ae
[pyrefly] fix remaining jax/_src/lax errors
jakevdp Feb 25, 2026
7f250e0
[windows] Skip triton export tests in pallas_shape_poly_test if trito…
danielsuo Feb 25, 2026
6a635c9
Explicitly check for disallowing TransformedRefs in higher-order JAX …
rdyro Feb 25, 2026
b0d9134
[typing] pyrefly fixes in jax/_src/scipy/stats
levskaya Feb 25, 2026
3b399a8
[typing] pyrefly fixes in jax/_src/scipy/optimize
levskaya Feb 25, 2026
6869e28
[mosaic] Removed previously deprecated tpu.wait_dma
superbobry Feb 26, 2026
dc3f50c
Merge pull request #35414 from levskaya:pyrefly3
Google-ML-Automation Feb 26, 2026
efa6bd9
Reverts 6869e28a4bde0f6d3a06a7c39858fe40cc44ff60
Google-ML-Automation Feb 26, 2026
6f06ab8
[Mosaic:GPU] Simplify collective metadata initialization.
PatriosTheGreat Feb 26, 2026
d0d6f42
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Feb 26, 2026
ee4cea1
Upgrade Abseil to LTS 20260107.1
akuegel Feb 26, 2026
4826da4
Merge pull request #35412 from levskaya:pyrefly
Google-ML-Automation Feb 26, 2026
7192cf7
[Mosaic GPU] Add support for warp-level f8 MMAs
apaszke Feb 26, 2026
9e11f1c
[Pallas:TPU] Add tests for DMAs with regular semaphores
apaszke Feb 26, 2026
f2017f7
[Pallas:TPU] Add a test checking that dynamic shape exports require t…
apaszke Feb 26, 2026
af59a64
Set rpath based on the type (link_only)
alekstheod Feb 16, 2026
c141c0f
Add int1/uint1 dtypes.
WindQAQ Feb 26, 2026
0b7d667
[mosaic] Another go at removing tpu.wait_dma
superbobry Feb 26, 2026
3f667fd
[Mosaic GPU] Relax restrictions on N for sparse tcgen05 MMA
apaszke Feb 26, 2026
7af1783
Raise a better error if jit(shard_map) is bound with an AbstractMesh …
yashk2810 Feb 26, 2026
6df7333
Fix hijax `VmapOf.batch_dim_rule` with `None` map dims.
chr1sj0nes Feb 26, 2026
4b87f03
[hijax] cleaner scan inc_rank/dec_rank api
mattjj Feb 26, 2026
6f2f0f3
[JAX] Bound tracebacks up to the enclosing `jit`.
hawkinsp Feb 26, 2026
66b6141
Add a pre-commit hook that checks for Apache licence headers.
hawkinsp Feb 26, 2026
8f96afa
Suppress still-broken test.
brianwa84 Feb 26, 2026
873bebe
Roll back skipped tests that were failing on LLVM bit/byte-packing is…
danielsuo Feb 26, 2026
c3b53a2
Update XLA dependency to use revision http://github.com/openxla/xla/c…
danielsuo Feb 26, 2026
17cfbb8
Add jaxlib version guard for updated make_c_api_client.
danielsuo Feb 26, 2026
6428a91
[jaxlib] Ported normalize_stubs.sh to Python
danielsuo Feb 26, 2026
12344fc
Reverts 7af1783eb619f7e0236aa8c25fe1fc72881bb6b3
Google-ML-Automation Feb 26, 2026
b612292
Tie LinearizeTrace.tag to tangent_trace.tag
jakevdp Feb 26, 2026
3b48263
Merge pull request #35439 from mattjj:hijax-cleaner-scan-axis-spec
Google-ML-Automation Feb 26, 2026
ed62898
Merge pull request #35413 from jakevdp:pyrefly-lax-fixes
Google-ML-Automation Feb 26, 2026
fa36954
[hijax] lower jaxpr during discharge_state
jakevdp Feb 26, 2026
d237804
[Pallas/TPU] Add tpu_info.get_tpu_info_for_chip
Google-ML-Automation Feb 26, 2026
f82e281
Skip pallas lowering determinism testOrderAgnostic test if not on TPU.
danielsuo Feb 26, 2026
85f6ef5
[typing] miscellaneous pyrefly typing fixes
levskaya Feb 26, 2026
47291b9
[typing] fixed remaining scipy pyrefly typing issues
levskaya Feb 26, 2026
008467a
Generalize how op verifiers checking issuing core
naummo Feb 26, 2026
e81437f
Fix undefined name in ad_checkpoint
jakevdp Feb 26, 2026
6d14b20
Merge pull request #35219 from jakevdp:fix-tracer-tag
Google-ML-Automation Feb 26, 2026
c08a170
Merge pull request #35452 from jakevdp:fix-ad-checkpoint
Google-ML-Automation Feb 26, 2026
153937a
Merge pull request #35448 from jakevdp:discharge-hijax
Google-ML-Automation Feb 26, 2026
d873c96
[linux-arm64] Reduce parallelism to 32 for nogil pytest cpu to avoid …
danielsuo Feb 27, 2026
019a87b
Merge pull request #35421 from levskaya:pyrefly4
Google-ML-Automation Feb 27, 2026
d57e7c8
[pyrefly] fix errors in jax/interpreters/ad
jakevdp Feb 27, 2026
11af98b
[pmap] Accelerate deprecation of PmapSharding.
danielsuo Feb 27, 2026
e008d9d
Merge pull request #35422 from levskaya:pyrefly5
Google-ML-Automation Feb 27, 2026
a6df1f2
Add an experimental `top_level_all_gather` API (private for now).
yashk2810 Feb 27, 2026
d506abd
Add a test for replicated -> sharded+unreduced reshard
yashk2810 Feb 27, 2026
08a51ce
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Feb 27, 2026
717f99a
Merge pull request #35102 from ROCm:set_release_rpaths_to_rocm_so_tar…
Google-ML-Automation Feb 27, 2026
df6415b
[Pallas:MGPU] Allow inline_mgpu inside a warp context
apaszke Feb 27, 2026
caf1324
[Pallas][TPU kernel interpreter] Consistently use `pltu.HBM` to refer…
Google-ML-Automation Feb 27, 2026
0d37f59
Update XLA dependency to use revision http://github.com/openxla/xla/c…
danielsuo Feb 27, 2026
a573e86
Skip pallas lowering determinism testOrderAgnostic test if not on TPU.
danielsuo Feb 26, 2026
969da1b
[linux-arm64] Reduce parallelism to 32 for nogil pytest cpu to avoid …
danielsuo Feb 27, 2026
143c236
Reverts ee4cea11e08b92f23f464be2e7795f5eb9fe9557
akuegel Feb 27, 2026
8946986
Remove references to tsl/platform/logging.h header.
hawkinsp Feb 27, 2026
6400fa1
[pallas:sc] Moved `jaxpr_call` into core Pallas
superbobry Feb 27, 2026
fd7c2ac
Fix typo in function parameter documentation
joshuapjacob Feb 26, 2026
3ad9484
[mosaic] Added pattern canonicalizing tpu.memref_squeeze(memref.cast(…
superbobry Feb 27, 2026
10a59a5
[Mosaic GPU] fix annotation typo: s/pytype/pylint/
cota Feb 27, 2026
b9926c3
[Mosaic TPU] Make sure to take the export flag into account when not …
apaszke Feb 27, 2026
e3adc6a
[mosaic_gpu] Fixed mosaic-gpu-resolve-trivial-locations
superbobry Feb 27, 2026
1609c18
[Mosaic:GPU] Re-use the barrier buffers between executions to prevent…
PatriosTheGreat Feb 27, 2026
f3bf01a
[Mosaic:GPU] Move collective kernel loading to the prepare stage.
PatriosTheGreat Feb 27, 2026
a25a24d
[Mosaic:GPU] Add test-case for several mosaic ops.
PatriosTheGreat Feb 27, 2026
b2ccb32
Reimplement weakref_lru_cache on top of a custom hash map type.
hawkinsp Feb 27, 2026
0a513d3
[Mosaic TPU] Make sure to allow unregistered dialects before deserial…
apaszke Feb 27, 2026
a30d6bc
Integrate LLVM at llvm/llvm-project@1053047a4be7
Google-ML-Automation Feb 27, 2026
8350f36
Skip tests that are failing at libtpu=0.0.35, but passing at libtpu-0…
danielsuo Feb 27, 2026
438e694
Bump jaxlib_extension_version to 412 due to LLVM integrate.
danielsuo Feb 27, 2026
0eb2314
Skip tests that are failing at libtpu=0.0.35, but passing at libtpu-0…
danielsuo Feb 27, 2026
1899c32
[Mosaic] Remove patterns to fold memref.cast into tpu.enqueue_dma and…
tlongeri Feb 27, 2026
abd765d
Rename `pypi_latest` libtpu-version-type to `pinned` and use pinned l…
ybaturina Feb 27, 2026
bec3b79
Add workflow for build wheels and running bazel tests on ROCm
charleshofer Jan 21, 2026
85ec15d
Raise an error on `f32[4]{R:x} * f32[]`. For example `x: f32[4]{R:x} …
yashk2810 Feb 27, 2026
70a8149
Rename `pypi_latest` libtpu-version-type to `pinned` and use pinned l…
ybaturina Feb 27, 2026
41ae1eb
Reverts b2ccb32cabd9da9c99a4a9b7f163d4059ff448d5
hawkinsp Feb 27, 2026
f154eb0
[Mosaic] Simplify GetCoreTypeOfParentOp
naummo Feb 28, 2026
cfe83cb
Fix release workflow by using the folder with pre-downloaded jax whee…
ybaturina Feb 28, 2026
162b2ed
Fix release workflow by using the folder with pre-downloaded jax whee…
ybaturina Feb 28, 2026
5c96b9e
[Pallas][Mosaic TPU] More complete 1D block support in Pallas/Mosaic.
yueshengys Feb 28, 2026
0b56550
[pallas] rewrite pull_block_spec rule to work in terms of block_index…
levskaya Feb 28, 2026
b66a3c3
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Feb 28, 2026
e5ab805
Skip layout_test.test_layout_donation_mismatching_in_and_out_fails fo…
danielsuo Feb 28, 2026
7aa10a0
Skip layout_test.test_layout_donation_mismatching_in_and_out_fails fo…
danielsuo Feb 28, 2026
500bbc2
Skip sparse random and dot general sampled ad tests.
danielsuo Feb 28, 2026
bb3362e
Skip sparse random and dot general sampled ad tests.
danielsuo Feb 28, 2026
58cb6e5
Prepare for JAX release 0.9.1
danielsuo Feb 26, 2026
eba07b3
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Mar 1, 2026
b0d504a
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Mar 2, 2026
d4204d5
Add C addition to hopper matmul kernel.
Google-ML-Automation Mar 2, 2026
29ce624
Merge remote-tracking branch 'origin/release/0.9.1' into postrelease/…
danielsuo Mar 2, 2026
5b9d90a
Postrelease JAX v0.9.1.
danielsuo Mar 2, 2026
c5d5575
[jaxlib] :normalize_stubs now also fixes improper ``get`` classmethods
superbobry Mar 2, 2026
95d95e3
[mgpu] Minor cleanup in `try_cluster_cancel`.
chr1sj0nes Mar 2, 2026
0cd1592
[Mosaic GPU][NFC] Remove duplicate layout inference rule for `vector.…
allanrenucci Mar 2, 2026
b4e692e
Merge pull request #35502 from superbobry:stubgen
Google-ML-Automation Mar 2, 2026
e20119d
[Mosaic GPU][NFC] Use more precise type for `op` in `_optimization_ba…
allanrenucci Mar 2, 2026
41430d3
[pallas:sc] Switched to using `pltpu.emit_pipeline` in the lowering
superbobry Mar 2, 2026
31fc0bd
[Mosaic] Add assembly format to tpu.delay_op
vsytch Mar 2, 2026
44736b9
[Mosaic GPU] Fix for the latest LLVM version
apaszke Mar 2, 2026
a3ffef5
Add serialization support for Mosaic GPU kernels.
khasanovaa Mar 2, 2026
0ded2fe
Regenerate type stubs for nanobind v2.12.0.
hawkinsp Mar 2, 2026
ed6d58f
[Mosaic GPU][NFC] Remove dead `constraints.BroadcastInDim`.
allanrenucci Mar 2, 2026
edf82d1
Merge pull request #35501 from jax-ml:postrelease/0.9.1
Google-ML-Automation Mar 2, 2026
68b2565
Reverts 41ae1eb6ea50739feb91589705aa078b1495d40f
hawkinsp Mar 2, 2026
ba54b34
Merge pull request #34702 from ROCm:add-run-bazel-test-rocm-rbe
Google-ML-Automation Mar 2, 2026
a192b43
Merge pull request #35417 from joshuapjacob:patch-1
Google-ML-Automation Mar 2, 2026
a1b958b
[Mosaic GPU] Add support for untiled CP_ASYNC
apaszke Mar 2, 2026
4b423a6
Merge pull request #35453 from jakevdp:pyrefly-ad
Google-ML-Automation Mar 2, 2026
2fe5318
Remove nvidia_wheel_versions
charleshofer Nov 12, 2025
c7ec794
Make jaxlib targets visible
charleshofer Nov 12, 2025
789df45
hipblas typedef fix
charleshofer Nov 12, 2025
d0e3207
No GPU fail
charleshofer Nov 13, 2025
caf56b6
Wrap HIP inline functions in anonymous namespaces in vendor.h
mminutoli Feb 12, 2026
d8deb09
SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError
dsicarov-amd Jun 10, 2025
6eee40d
Add shared utility function get_rocm_version to test_util.py
charleshofer Nov 14, 2025
3128e0e
Fix hipSparse CSR algorithm mappings for ROCm 7
phambinhfin Nov 17, 2025
cfddfa6
Fix v_pages quantization and adjust test params for ROCm compatibilit…
phambinhfin Nov 19, 2025
cb904b3
Address LLVM assertion failure due to a multithreaded use. Update .gi…
Arech8 Nov 26, 2025
b78fe9a
Add skip of test_is_finite() on Cuda (#565)
Arech8 Nov 26, 2025
514fec3
Add rocm test requirements file (#570)
AratiGanesh Dec 15, 2025
b64a5e4
Let the unit tests use build.py for setting up Bazel commands for uni…
charleshofer Dec 15, 2025
640e4a2
adding abort logic to rocm/jax (#590)
gulsumgudukbay Jan 13, 2026
6eb4fdf
Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (…
phambinhfin Jan 14, 2026
14b0125
Fix shared memory limit check for ROCm in test_dot (#596)
phambinhfin Jan 14, 2026
92211a5
Fix Numpy signatures test (#598)
magaonka-amd Jan 14, 2026
b202e1f
fix merge arts
Ruturaj4 Jan 18, 2026
4c60b2b
Enable RngShardingTests (#644)
gulsumgudukbay Jan 22, 2026
22ac32a
Enable test_variadic_reduce_window on ROCm (#647)
mminutoli Feb 12, 2026
8ecdd8b
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
cc54366
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
1b6c9e7
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
f0e9039
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
0cc3b98
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
4732196
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
e571492
Enable testMultivariateNormalSingularCovariance on ROCm (#666)
AratiGanesh Jan 28, 2026
84ad452
Skip test_tridiagonal_solve on ROCm due to hipSPARSE numerical errors…
AratiGanesh Jan 28, 2026
3c29577
Update Skip Reason Outputs (#663)
gulsumgudukbay Jan 28, 2026
7648e11
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
e0d2d1e
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
a186d62
Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677)
magaonka-amd Jan 29, 2026
9983dd0
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
c45c49a
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
14d1b17
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
ec39433
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
d18c757
Add ROCm encoding for test_struct_encoding_determinism (#683)
AratiGanesh Feb 5, 2026
717cd20
Remove 'mean' from unsupported params for jnp.var (#689)
magaonka-amd Feb 6, 2026
02ae6c5
Implement approx_tanh for ROCm using OCML tanh function (#691)
magaonka-amd Feb 6, 2026
f2f9dd1
Skipping testEighTinyNorm due to hipSolver issues (#697)
AratiGanesh Feb 9, 2026
6957e42
Abort detection CI workflow (#688)
gulsumgudukbay Feb 20, 2026
7d684aa
Abort-Detection: Fix halt-for-connection input (#712)
gulsumgudukbay Feb 24, 2026
518 changes: 319 additions & 199 deletions .bazelrc

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion .bazelversion
Original file line number Diff line number Diff line change
@@ -1 +1 @@
7.4.1
7.7.0
15 changes: 2 additions & 13 deletions .editorconfig
@@ -5,18 +5,7 @@ indent_style = space
end_of_line = lf
trim_trailing_whitespace = true
insert_final_newline = true

[*.py]
max_line_length = 79
indent_size = 2

[*.rst]
max_line_length = 79
indent_size = 2

[*.md]
max_line_length = 79
indent_size = 2

[*.yml]
indent_size = 2
[*.py]
max_line_length = 80
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
@@ -24,7 +24,7 @@ body:

[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues

[Raw report]: http://github.com/jax-ml/jax/issues/new
[Raw report]: https://github.com/jax-ml/jax/issues/new?template=none
- type: textarea
attributes:
label: Description
11 changes: 11 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,11 @@
<!--

Thanks for taking the time to contribute to JAX! A couple things to keep in mind:

* Contributing to JAX requires signing the Contributor License Agreement: https://docs.jax.dev/en/latest/contributing.html#google-contributor-license-agreement

* Please run lint checks and tests locally: https://docs.jax.dev/en/latest/contributing.html#contributing-code-using-pull-requests

* If applicable, read our policy on AI generated code: https://docs.jax.dev/en/latest/contributing.html#can-i-contribute-ai-generated-code

-->
20 changes: 20 additions & 0 deletions .github/actionlint.yaml
@@ -0,0 +1,20 @@
# Configuration related to self-hosted runner.
self-hosted-runner:
labels:
- "linux-x86-n4-16" # Linux X86 runner using the 16 vcpu n4-standard-16 machine.
- "linux-x86-n4-32" # Linux X86 runner using the 32 vcpu n4-standard-32 machine.
- "linux-x86-n4-64" # Linux X86 runner using the 64 vcpu n4-standard-64 machine.
- "linux-x86-g2-16-l4-1gpu" # Linux X86 GPU runner using g2-standard-16 machine with 1 NVIDIA L4 GPU attached.
- "linux-x86-g2-48-l4-4gpu" # Linux X86 GPU runner using g2-standard-48 machine with 4 NVIDIA L4 GPUs attached.
- "linux-x86-ct5lp-224-8tpu" # Linux X86 TPU runner using ct5lp-hightpu-8t machine with 2x4 topology.
- "linux-arm64-c4a-16" # Linux ARM64 CPU Runner using the 16 vcpu c4a-standard-16 machine.
- "linux-arm64-c4a-64" # Linux ARM64 CPU Runner using the 64 vcpu c4a-standard-64 machine.
- "windows-x86-n2-16" # Windows X86 runner using n2-standard-16 machine.
- "windows-x86-n2-64" # Windows X86 runner using n2-standard-64 machine.
- "linux-x86-a4-224-b200-1gpu" # Linux X86 GPU runner using 1 B200 GPU and 1/8 the resources of a a4-highgpu-8g machine
- "linux-x86-a3-8g-h100-8gpu" # Linux X86 GPU runner using a3-highgpu-8g machine with 8 NVIDIA H100 GPUs attached.
- "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology.
- "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology.
- "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology.
- "linux-x86-tpu7x-224-4tpu" # Linux X86 TPU runner using tpu7x-224 machine with 4 TPU chips (8 cores) and 2x2x1 topology.
- "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner
113 changes: 113 additions & 0 deletions .github/actions/download-jax-cpu-wheels/action.yml
@@ -0,0 +1,113 @@
# Composite action to download the jax and jaxlib wheels
name: Download JAX CPU wheels

inputs:
runner:
description: "Which runner type should the wheels be downloaded for?"
type: string
default: "linux-x86-n4-16"
python:
description: "Which python version should the artifact be downloaded for?"
required: true
type: string
jaxlib-version:
description: "Which jaxlib version to download? (head/pypi_latest)"
type: string
default: "head"
skip-download-jaxlib-from-gcs:
description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)"
default: '0'
type: string
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
type: string
permissions: {}
runs:
using: "composite"

steps:
# Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow.
- name: Set env vars for use in artifact download URL
shell: bash
run: |
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)

# Adjust os and arch for Windows
if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then
os="win"
arch="amd64"
fi

# Get the major and minor version of Python.
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

echo "OS=${os}" >> $GITHUB_ENV
echo "ARCH=${arch}" >> $GITHUB_ENV
# Python wheels follow a naming convention: standard wheels use the pattern
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
- name: Download wheels from GCS (non-Windows runs)
shell: bash
id: download-wheel-artifacts-nw
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the step below so that we can print a more
# informative error message.
continue-on-error: true
if: ${{ !contains(inputs.runner, 'windows-x86') }}
run: |
mkdir -p $(pwd)/dist
gcloud storage cp -r "${INPUTS_GCS_DOWNLOAD_URI}"/jax*py3*none*any.whl $(pwd)/dist/

if [[ "${INPUTS_SKIP_DOWNLOAD_JAXLIB_FROM_GCS}" == "1" ]]; then
echo "JAX only release. Only downloading the jax wheel from the release bucket."
else
if [[ ${INPUTS_JAXLIB_VERSION} == "head" ]]; then
gcloud storage cp -r "${INPUTS_GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
elif [[ ${INPUTS_JAXLIB_VERSION} == "pypi_latest" ]]; then
PYTHON=python${INPUTS_PYTHON}
$PYTHON -m pip download jaxlib --dest $(pwd)/dist/
else
echo "Invalid jaxlib version: ${INPUTS_JAXLIB_VERSION}"
exit 1
fi
fi
env:
INPUTS_GCS_DOWNLOAD_URI: ${{ inputs.gcs_download_uri }}
INPUTS_SKIP_DOWNLOAD_JAXLIB_FROM_GCS: ${{ inputs.skip-download-jaxlib-from-gcs }}
INPUTS_JAXLIB_VERSION: ${{ inputs.jaxlib-version }}
INPUTS_PYTHON: ${{ inputs.python }}
- name: Download wheels from GCS (Windows runs)
shell: cmd
id: download-wheel-artifacts-w
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the step below so that we can print a more
# informative error message.
continue-on-error: true
if: ${{ contains(inputs.runner, 'windows-x86') }}
run: |
mkdir dist
@REM Use `call` so that we can run sequential gcloud storage commands on Windows
@REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
call gcloud storage cp -r "%INPUTS_GCS_DOWNLOAD_URI%"/jax*py3*none*any.whl dist/

if "%INPUTS_SKIP_DOWNLOAD_JAXLIB_FROM_GCS%"=="1" (
echo "JAX only release. Only downloading the jax wheel from the release bucket."
) else (
call gcloud storage cp -r "%INPUTS_GCS_DOWNLOAD_URI%/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
)
env:
INPUTS_GCS_DOWNLOAD_URI: ${{ inputs.gcs_download_uri }}
INPUTS_SKIP_DOWNLOAD_JAXLIB_FROM_GCS: ${{ inputs.skip-download-jaxlib-from-gcs }}
- name: Skip the test run if the wheel artifacts were not downloaded successfully
shell: bash
if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
run: |
echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
echo "built successfully by the artifact build jobs and are available in the GCS bucket."
echo "Skipping the test run."
exit 1
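
The Python-version-to-wheel-tag derivation in the "Set env vars" step above can be sketched as a standalone script. This is an illustrative reconstruction, not part of the diff; it uses `sed` in place of the action's bash-only `${var//-nogil/t}` expansion, with the same effect described by the action's comments (e.g. `3.13-nogil` yields the free-threaded tag `cp313-cp313t-`):

```shell
#!/bin/sh
# Sketch of the wheel-tag derivation used by the download actions.
derive_wheel_tag() {
  version="$1"
  # A "-nogil" suffix marks a free-threaded build; map it to the "t" ABI
  # suffix, then drop the dot: "3.13-nogil" -> "313t", "3.10" -> "310".
  python_major_minor=$(echo "$version" | sed 's/-nogil/t/' | tr -d '.')
  # Standard wheels are named *-cp<ver>-cp<ver>-*, free-threaded wheels
  # *-cp<ver>-cp<ver>t-*; "%t" strips a trailing "t" from the first tag only.
  echo "cp${python_major_minor%t}-cp${python_major_minor}-"
}

derive_wheel_tag "3.10"        # -> cp310-cp310-
derive_wheel_tag "3.13-nogil"  # -> cp313-cp313t-
```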
108 changes: 108 additions & 0 deletions .github/actions/download-jax-cuda-wheels/action.yml
@@ -0,0 +1,108 @@
# Composite action to download the jax, jaxlib, and the CUDA plugin wheels
name: Download JAX CUDA wheels

inputs:
python:
description: "Which python version should the artifact be downloaded for?"
type: string
required: true
cuda-version:
description: "Which cuda version should the artifact be downloaded for?"
type: string
default: "12"
use-nvidia-pip-wheels:
description: "Whether to download Nvidia CUDA packages from PyPI?"
type: boolean
default: false
jaxlib-version:
description: "Which jaxlib version to download? (head/pypi_latest)"
type: string
default: "head"
download-jax-from-gcs:
description: "Whether to download the jax wheel from GCS"
default: '1'
type: string
skip-download-jaxlib-and-plugins-from-gcs:
description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)"
default: '0'
type: string
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
type: string
permissions: {}
runs:
using: "composite"

steps:
# Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow.
- name: Set env vars for use in artifact download URL
shell: bash
run: |
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)

# Get the major and minor version of Python.
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

echo "OS=${os}" >> $GITHUB_ENV
echo "ARCH=${arch}" >> $GITHUB_ENV
# Python wheels follow a naming convention: standard wheels use the pattern
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV

# Get the CUDA major version only
full_cuda_version="${INPUTS_CUDA_VERSION}"
echo "JAXCI_CUDA_VERSION=${full_cuda_version%%.*}" >> $GITHUB_ENV
env:
INPUTS_CUDA_VERSION: ${{ inputs.cuda-version }}
- name: Download wheels
shell: bash
id: download-wheel-artifacts
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the next step so that we can print a more
# informative error message.
continue-on-error: true
run: |
mkdir -p $(pwd)/dist
if [[ "${INPUTS_DOWNLOAD_JAX_FROM_GCS}" == "1" ]]; then
gcloud storage cp -r "${INPUTS_GCS_DOWNLOAD_URI}"/jax*py3*none*any.whl $(pwd)/dist/
else
echo "JAX wheel won't be downloaded, only jaxlib pre-built wheel is tested."
fi

# Do not download the jaxlib and CUDA plugin artifacts if we are testing a jax only
# release.
if [[ "${INPUTS_SKIP_DOWNLOAD_JAXLIB_AND_PLUGINS_FROM_GCS}" == "1" ]]; then
echo "JAX only release. Only downloading the jax wheel from the release bucket."
else
if [[ ${INPUTS_JAXLIB_VERSION} == "head" ]]; then
gcloud storage cp -r "${INPUTS_GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
gcloud storage cp -r "${INPUTS_GCS_DOWNLOAD_URI}/jax*cuda${JAXCI_CUDA_VERSION}*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
gcloud storage cp -r "${INPUTS_GCS_DOWNLOAD_URI}/jax*cuda${JAXCI_CUDA_VERSION}*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
elif [[ ${INPUTS_JAXLIB_VERSION} == "pypi_latest" ]]; then
PYTHON=python${INPUTS_PYTHON}
$PYTHON -m pip download jaxlib jax-cuda${JAXCI_CUDA_VERSION}-pjrt jax-cuda${JAXCI_CUDA_VERSION}-plugin --dest $(pwd)/dist/
else
echo "Invalid jaxlib version: ${INPUTS_JAXLIB_VERSION}"
exit 1
fi
fi
env:
INPUTS_DOWNLOAD_JAX_FROM_GCS: ${{ inputs.download-jax-from-gcs }}
INPUTS_GCS_DOWNLOAD_URI: ${{ inputs.gcs_download_uri }}
INPUTS_SKIP_DOWNLOAD_JAXLIB_AND_PLUGINS_FROM_GCS: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }}
INPUTS_JAXLIB_VERSION: ${{ inputs.jaxlib-version }}
INPUTS_PYTHON: ${{ inputs.python }}
- name: Skip the test run if the wheel artifacts were not downloaded successfully
shell: bash
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |
echo "Failed to download wheel artifacts. Please check if the wheels were"
echo "built successfully by the artifact build jobs and are available in the GCS bucket if"
echo "downloading from GCS."
echo "Skipping the test run."
exit 1
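
The CUDA handling in this action keeps only the major version (`JAXCI_CUDA_VERSION`) and uses it to select the plugin wheel names. A minimal sketch of that POSIX parameter expansion, with an illustrative input value:

```shell
#!/bin/sh
# Sketch of the CUDA major-version extraction from the "Set env vars" step.
full_cuda_version="12.3"               # illustrative; the action reads inputs.cuda-version
cuda_major="${full_cuda_version%%.*}"  # strip everything from the first "." onward
echo "$cuda_major"                     # -> 12
# The major version then parameterizes the wheel names fetched or pip-downloaded,
# e.g. jax-cuda12-pjrt and jax-cuda12-plugin.
echo "jax-cuda${cuda_major}-pjrt jax-cuda${cuda_major}-plugin"
```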