Changes from all commits
6584 commits
54cf88d
Merge pull request #34204 from jakevdp:doc-sidebar
Google-ML-Automation Jan 7, 2026
227cb22
Merge pull request #32268 from samanklesaria:issues/32267
Google-ML-Automation Jan 7, 2026
27b3bff
Prefer isinstance(x, type) over type.isinstance
boomanaiden154 Jan 7, 2026
2486a5e
Add jax.experimental.random to the wheel build
jakevdp Jan 7, 2026
bf11cd8
fix shard_map transpose explicit sharding zero unsharding bug
mattjj Jan 7, 2026
f6b6edc
add optional `explain` callback for weakref_lru_cache misses
mattjj Jan 7, 2026
75fd86b
Merge pull request #34209 from jakevdp:fix-wheel
Google-ML-Automation Jan 7, 2026
75130d4
Merge pull request #34211 from mattjj:andy-customvjp-none
Google-ML-Automation Jan 8, 2026
caaad6e
Handle ad.Zero cotangents in _reshard_transpose_fancy.
yashk2810 Jan 8, 2026
fd57ef6
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Jan 8, 2026
1199fd7
[Pallas/Mosaic GPU] Enable more `WarpSpecializedPipelineWGTest`s.
bchetioui Jan 8, 2026
de96f7b
[Mosaic GPU] Add a `bitwidth` field to `Relayout` constraints in layo…
bchetioui Jan 8, 2026
2a1a0a9
[pallas:sc] Allowed specifying tiling in `pltpu.emit_pipeline`
superbobry Jan 8, 2026
60a32d3
[XLA:MGPU] Port Tiling to C++.
golechwierowicz Jan 8, 2026
8c2555a
deviceless aot test
keshavb96 Jan 8, 2026
07bb0f6
Add a thread guard config option.
emilyfertig Jan 8, 2026
1c3aa2f
[Mosaic] Move Float8EXMYType to tpu.td.
WindQAQ Jan 8, 2026
4b25896
Merge pull request #33542 from keshavb96:deviceless_aot_test
Google-ML-Automation Jan 8, 2026
4e89ce4
Remove redundant test targets that are already executed as a part of …
ybaturina Jan 8, 2026
9fad589
respect self.statics in FlatTree.__eq__
mattjj Jan 8, 2026
b02123b
Fix precommit breakage
boomanaiden154 Jan 8, 2026
5279000
sick
mattjj Jan 8, 2026
0636587
revive as many cache miss explanations as reasonably possible
mattjj Jan 9, 2026
12dce9c
[pallas:sc] Skip a few tests failing when the compiler uses tiled mem…
superbobry Jan 9, 2026
b10cee6
Colocated python: Use a wrapper when storing remote objects at the ba…
Google-ML-Automation Jan 9, 2026
d6101d8
skip on jaxlib version
mattjj Jan 9, 2026
909e638
Add `halt-for-connection` to `build_artifacts.yml` workflow call.
ybaturina Jan 9, 2026
a8ca82a
Merge pull request #33839 from jax-ml:pjit-without-linear-util
Google-ML-Automation Jan 9, 2026
774597a
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Jan 9, 2026
5a50361
[Mosaic GPU] Use `isinstance(x, mlir_ty)` instead of the deprecated `…
bchetioui Jan 9, 2026
568bca1
Remove unnecessary return from placement new in Mosaic GPU extension.
beckerhe Jan 9, 2026
35ffd40
[mosaic] Added a canonicalization rule for memref.dim of tpu.memref_s…
superbobry Jan 9, 2026
8947c9a
[Mosaic GPU] Add basic support for DSMEM
apaszke Jan 9, 2026
73ad4e5
[Mosaic GPU] Add support for using the redux instructions to speed up…
apaszke Jan 9, 2026
e27b89b
[Pallas/Mosaic GPU] Disable `test_ragged_dot_transposed` temporarily.
bchetioui Jan 9, 2026
741ecce
Remove some more stale version guards.
hawkinsp Jan 9, 2026
0b4a440
[bug] fix grad-of-vmap-of-dynamic_slice with out-of-bound indices
jakevdp Jan 8, 2026
c1f8ccf
Allow max_size to be None (infinite).
pschuh Jan 9, 2026
09c3f1c
Don't pass axis_index_groups to prim.bind in psum batching rule if pr…
yashk2810 Jan 9, 2026
d191861
Migrates `builder.create<Op>()` => `Op::create()`
Google-ML-Automation Jan 9, 2026
046b70c
Merge pull request #34229 from jakevdp:dynamic-slice-grad-fix
Google-ML-Automation Jan 9, 2026
9528be8
Fix Array `__format__` and `__str__` to handle McJAX arrays that are …
emilyfertig Jan 9, 2026
d557a06
[hijax] start implementing custom_vjp on top of hijax primitives
mattjj Jan 6, 2026
705171a
Merge pull request #34161 from mattjj:custom-vjp3-youandme
Google-ML-Automation Jan 9, 2026
077574f
Delete the jax_collectives_common_channel_id flag.
hawkinsp Jan 9, 2026
1447d89
[indexing] implement strategy='dynamic_slice'
jakevdp Jan 9, 2026
eab73db
Merge pull request #34225 from jakevdp:index-dynamic-slice
Google-ML-Automation Jan 9, 2026
6acf46c
[TPU] Reenable a disabled test that now passes.
hawkinsp Jan 9, 2026
2f3063c
[Pallas MGPU] Use the `cp.async.bulk` instruction for large contiguou…
Rifur13 Jan 9, 2026
d800c3d
[Mosaic GPU] Expose the `mbarrier.complete_tx` instruction to manuall…
Rifur13 Jan 9, 2026
22b3a0c
Shorten splash attention kernel name to address https://github.com/ja…
rdyro Jan 9, 2026
b2217c3
[shmap] fix shmap error logic when subclasses of pspec are used
mattjj Jan 9, 2026
6744656
Run the GetKeys inside GetOrCreate.
pschuh Jan 9, 2026
603cef6
Merge pull request #34262 from mattjj:shmap-error-isleaf
Google-ML-Automation Jan 9, 2026
e184e21
Run g4 fix on the weakref_lru_cache code.
pschuh Jan 9, 2026
c5e5736
Update rules_ml_toolchain version to accommodate custom redistributio…
ybaturina Jan 9, 2026
a930721
Fix deserialization with specified layouts via a ShapeDtypeStruct bei…
rdyro Jan 10, 2026
11abf3b
Remove obsolete filegroup
18praveenb Jan 10, 2026
b00dc36
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Jan 10, 2026
c0fcb0b
don't warn on complex->real cast in dot transpose
mattjj Dec 3, 2025
05d4abb
Merge pull request #33708 from mattjj:issue33521
Google-ML-Automation Jan 10, 2026
902cee5
[hijax] add vmap support to CustomVJP hijax primitive
mattjj Jan 9, 2026
6f73035
Merge pull request #34268 from mattjj:custom-vjp3-youandme-2
Google-ML-Automation Jan 11, 2026
d40b4c9
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Jan 11, 2026
447b004
[Mosaic] Add abs, sign, erf, atan2, reduce_min, reduce_prod support t…
oulgen Dec 30, 2025
0974805
[JAX] Add IFRT SerDes to jaxlib deps
hyeontaek Jan 12, 2026
5815456
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Jan 12, 2026
1ef117f
Merge pull request #34120 from oulgen:mgpu-ops
Google-ML-Automation Jan 12, 2026
0ad1117
[Mosaic GPU] Test approximate math functions properly + eta reduce args
apaszke Jan 12, 2026
098e953
[Mosaic GPU] Enable redux.sync.f32 on Blackwell
apaszke Jan 12, 2026
658e650
[XLA:MGPU] Port Replicated wrapper to C++.
golechwierowicz Jan 12, 2026
dcb011a
[mosaic] infer-memref-layout now accepts target shape as a span
superbobry Jan 12, 2026
e4fa025
[Pallas:MGPU] Lower `lax.sign` consistently for LANE and WG semantics.
allanrenucci Jan 12, 2026
a070482
[Pallas:MGPU] Add Pallas lowering for `pl.debug_check` under WG seman…
allanrenucci Jan 12, 2026
1242ba3
Add `thread_guard` to the public API.
emilyfertig Jan 12, 2026
7cee4db
[Mosaic GPU] Add support for f8 types for WGMMA with lhs in registers…
allanrenucci Jan 12, 2026
27a568a
[Pallas MGPU] Clip the size of contiguous TMA transfers to the size o…
Rifur13 Jan 12, 2026
6ce4314
lax: ensure padtype_to_pads returns Python ints
Prakharprasun Dec 16, 2025
a4036e4
[jax.collect_profile] Allow arbitrary options to be passed to XProf
Matt-Hurd Jan 12, 2026
aad0e01
[indexing] support newaxis in static/dynamic slice strategies
jakevdp Jan 12, 2026
96787dd
Plumb prim params for call discharge rule (to handle named_computation_p
sharadmv Jan 12, 2026
06d7542
[sc] Generalized infer-memref-layout to support SC tiling
superbobry Jan 12, 2026
9e8fae6
[Pallas/TPU] Don't lower eqns that have all dropvar outputs (aka DCE at
sharadmv Jan 12, 2026
c984def
[sc] Removed `infer_kernel_arguments` from infer-memref-layout
superbobry Jan 12, 2026
a10149b
Merge pull request #33974 from Prakharprasun:fix-padtype-to-pads
Google-ML-Automation Jan 13, 2026
c5e4240
[Pallas] Allowlist semaphore/prng effects under remat and custom
sharadmv Jan 13, 2026
c396a36
Add lax.tile_p
levskaya Jan 13, 2026
bd07769
Automated Code Change
Google-ML-Automation Jan 13, 2026
4691dfe
[Mosaic:TPU] Clean up tpu.memref_slice verifier
tlongeri Jan 13, 2026
9a243d7
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Jan 13, 2026
f0ca449
Add pallas lowering for jnp.tile using tpu.repeat.
levskaya Jan 13, 2026
96ecec4
[export] Add support for explicit sharding.
gnecula Nov 28, 2025
b7700e4
Handle the case when shardings are GSPMDSharding.
gnecula Dec 11, 2025
e5faf1d
[Mosaic GPU] Support more reduction kinds and layouts in `MultiDimRed…
allanrenucci Jan 13, 2026
6eb71e9
[Pallas/interpreter] Add a prototype for a GPU kernel interpreter.
Google-ML-Automation Jan 13, 2026
f25011f
[Pallas:MGPU][NFC] Update outdated docstring for `scratch_view`.
allanrenucci Jan 13, 2026
678adf5
[Mosaic GPU] Fix issue in `vector_dim` reduction with `vec_len > 2`.
bchetioui Jan 13, 2026
35a14b8
[Mosaic GPU][NFC] Replace usages of `_gpu_ops_gen` with `gpu` dialect…
allanrenucci Jan 13, 2026
ed12030
[Mosaic GPU] Fix the return type of `_lift_fast_packed_instr` to matc…
bchetioui Jan 13, 2026
4b76bd8
Set the proper memory kinds for output in call_exported
gnecula Jan 10, 2026
fe56c71
Merge pull request #33597 from gnecula:export_shit
Google-ML-Automation Jan 13, 2026
4fdb82c
Merge pull request #34290 from jakevdp:indexing-none
Google-ML-Automation Jan 13, 2026
3e7d4dd
[Pallas:MGPU] Remove broadcast in scalar reduce test.
allanrenucci Jan 13, 2026
1a74daa
Change configurations for nightly TPU tests:
ybaturina Jan 13, 2026
4732c73
[hijax] more custom_vjp support: remat and optimize_remat
mattjj Jan 12, 2026
5267d8a
Update the correct core count for TPU v7 runners
quoctruong Jan 13, 2026
1bbbe8c
Fix numpy dependency.
ybaturina Jan 13, 2026
940d241
Merge pull request #34320 from mattjj:custom-vjp3-youandme-3
Google-ML-Automation Jan 13, 2026
99f6558
Add "jax/_src/pallas/mosaic_gpu/interpret:interpret_pallas_call" to j…
ybaturina Jan 13, 2026
fcbffe9
Revert switch to using `cp.async.bulk`.
Rifur13 Jan 13, 2026
0b7fc6a
Breaking existing tests
JW1992 Jan 13, 2026
2a5f637
[Pallas] Add delay effect
sharadmv Jan 13, 2026
4d76b71
Add "jax_force_dcn_cross_host_transfers" flag to use DCN instead of t…
emilyfertig Jan 14, 2026
5f9c795
[Pallas TPU] Fix bug where in kernels that only contain inter-core
sharadmv Jan 14, 2026
4f642fc
Use pe._default_dce_rule if a prim doesn't have a `dce` method on it …
yashk2810 Jan 14, 2026
b66cb45
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Jan 14, 2026
ab6c95c
[Mosaic GPU] Support signed/unsigned min reductions in `MultiDimReduc…
allanrenucci Jan 14, 2026
7980031
[Pallas:MGPU] Implement min/max scalar reductions under WG semantics.
allanrenucci Jan 14, 2026
974972b
[pallas:mosaic] tpu.memref_{slice,squeeze} no longer support strided …
superbobry Jan 14, 2026
e9d2dee
Update XLA commit to https://github.com/openxla/xla/commit/c358f68c2a…
danielsuo Jan 14, 2026
12fb9d2
[pmap] Suppress PmapSharding deprecation warning in multihost_utils_t…
danielsuo Jan 14, 2026
b8c00a8
Don't use convert_element_type to cast in dot_general bwd pass. Inste…
yashk2810 Jan 14, 2026
4b9bbb5
[JAX] Populate Send/Recv frontend attributes.
georgepaw Jan 14, 2026
de8e4e8
[pmap] Remove the `jax_pmap_no_rank_reduction` config state.
danielsuo Jan 14, 2026
34a4cbc
replace uses of pltpu.repeat with jnp.tile in pallas.
levskaya Jan 14, 2026
9efab3c
[pmap] Expand docs for migrating from pmap to shard_map.
danielsuo Jan 14, 2026
697b18d
Remove problematic warnings filter
jakevdp Jan 14, 2026
ae58223
Add mapped_aval and unmapped_aval to jax.extend.core
jakevdp Jan 13, 2026
09684bc
[dep] Deprecate jax.numpy.fix
jakevdp Jan 14, 2026
69e2fa5
Merge pull request #34310 from jakevdp:mapped-aval
Google-ML-Automation Jan 14, 2026
52ea5e8
Merge pull request #34293 from jakevdp:dep-fix
Google-ML-Automation Jan 14, 2026
8218324
[Pallas:MGPU] Enable `test_load_store_wgmma_transposed` under WG sema…
allanrenucci Jan 14, 2026
540671e
Reverts 0b7fc6a58ec36929ff5e98b46a359551839604eb
gnecula Jan 14, 2026
e4e22fc
[indexing] fully support indexing & normalization modes within static…
jakevdp Jan 14, 2026
20627f6
[Pallas:MGPU] Simplify Mosaic GPU tests by removing intermediate SMEM…
allanrenucci Jan 14, 2026
2e29fd9
Merge pull request #34327 from jakevdp:warning-filter
Google-ML-Automation Jan 14, 2026
ddddd11
Merge pull request #34326 from jakevdp:indexing-normalize-indices
Google-ML-Automation Jan 14, 2026
30e528a
relax fp8 sdpa test tolerance
Cjkkkk Jan 14, 2026
68e2033
[Pallas/interpreter] Add support for `run_scoped` in the GPU kernel i…
Google-ML-Automation Jan 14, 2026
a4befcd
[Pallas] Fix broken pallas distributed test
sharadmv Jan 14, 2026
6bb594a
Remove support for some old format in `jax/tests/export_back_compat_t…
ZixuanJiang Jan 14, 2026
3422bfc
Reverts b8c00a84ddc1c463fcfe7b7bbdaf2eaa670886df
Google-ML-Automation Jan 15, 2026
7cd6b1a
Only reset preferred_element_type = x.aval.dtype if explicitly set
pschuh Jan 15, 2026
267a9a5
Fix backwards compatibility with jaxlib.
emilyfertig Jan 15, 2026
a13f53e
[NFC] Use getDefiningOp<Op>
tlongeri Jan 15, 2026
643d2e4
Update XLA dependency to use revision http://github.com/openxla/xla/c…
Google-ML-Automation Jan 15, 2026
7a7fd1e
[export] Add backwards compatibility tests for v6.
gnecula Dec 10, 2025
bcc959b
Merge pull request #33853 from gnecula:export_v6_tests
Google-ML-Automation Jan 15, 2026
9ac60a3
[XLA:MGPU] Port TiledLayout's construction logic to C++.
golechwierowicz Jan 15, 2026
ef84d38
[Mosaic GPU] Disable redux ops until properly benchmarked and optimized
Google-ML-Automation Jan 15, 2026
92f0810
[XLA:MGPU] Port Partitioned.*Dims and VectorLength methods to C++.
golechwierowicz Jan 15, 2026
15a3e1e
Reverts 7cd6b1a9c711068f4f05312d24ef598fe40d4e50
Google-ML-Automation Jan 15, 2026
602f212
[pallas] Make the LoweringDynamicShapeEnv use local mappings
gnecula Jan 15, 2026
2df10cb
[export] Cleanup the handling of has_named_sharding
gnecula Jan 15, 2026
31ef45e
Merge pull request #34403 from gnecula:export_fix1
Google-ML-Automation Jan 15, 2026
d66c2a6
Merge pull request #34400 from gnecula:shape_poly_pallas_cache_reuse
Google-ML-Automation Jan 15, 2026
0c80468
[pallas:sc] The lowering can now use an empty grid directly
superbobry Jan 15, 2026
bb1017b
[pmap] Deprecate setting `jax_pmap_shmap_merge` config state.
danielsuo Jan 15, 2026
a892789
Update `rules_ml_toolchain` version to incorporate new GPU folders st…
ybaturina Jan 15, 2026
e16ca65
Add a version guard around TPU lowering rule changes.
hawkinsp Jan 15, 2026
8d97179
Add the test rule `compare_srcs_and_test_deps_test` that compares the…
ybaturina Jan 15, 2026
695eac8
[Pallas:MGPU] Fix scratch size for cross-warp reductions.
allanrenucci Jan 15, 2026
c5831d1
Merge pull request #34375 from Cjkkkk:relax_fp8_sdpa_tolerance
Google-ML-Automation Jan 15, 2026
002c3dc
[hijax] fix up custom_vjp3 error messages
mattjj Jan 14, 2026
7f6fe29
[Pallas] Turn off tpu_7x configs for distributed test to avoid hangs
sharadmv Jan 15, 2026
ee85b82
[Pallas] Remove PRNG/Semaphore effects from basic JAX allowlist since
sharadmv Jan 15, 2026
9007630
[indexing] use strides=None for all unit strides
jakevdp Jan 15, 2026
e64dccf
[typ] annotate lax.padtype_to_pads
jakevdp Jan 15, 2026
429fcfc
Merge pull request #34394 from mattjj:custom-vjp3-youandme-4
Google-ML-Automation Jan 15, 2026
fe7676d
Uninstall xprof on python3.13-nogil always.
mwhittaker Jan 15, 2026
9e48147
[Pallas] Disable non-sharded multi-chip splash attention test config,…
rdyro Jan 15, 2026
2ecafee
[hijax] fix shard_map of hijax primitive
jakevdp Jan 15, 2026
3d40094
[pmap] Suppress PmapSharding deprecation warning in (more) multihost_…
danielsuo Jan 15, 2026
1143012
Merge pull request #34414 from jakevdp:padtype-to-pads
Google-ML-Automation Jan 15, 2026
f0ff1b8
Merge pull request #34410 from jakevdp:indexing-strides
Google-ML-Automation Jan 15, 2026
acdd012
Migrate TSAN workflow to use RBE.
belitskiy Jan 15, 2026
91223b9
Follow up pr after triton integration cl and tokamax pr.
loislo Jan 15, 2026
6a9f3d9
Merge pull request #34418 from jakevdp:hijax-shard-map
Google-ML-Automation Jan 15, 2026
02939fe
[Pallas] Create trace_value primitive for dynamic value logging
Google-ML-Automation Jan 15, 2026
a84ca0c
Add `testonly=True` to py_import targets that depend on testonly `whe…
ybaturina Jan 15, 2026
cdfe904
Bump XLA version.
mwhittaker Jan 16, 2026
dd745d7
Prepare for JAX release 0.9.0
mwhittaker Jan 15, 2026
8f4389a
Increase shard count for TPU and GPU tests back to 5 for api_test.py.
belitskiy Jan 16, 2026
fae5def
Fix wheel sources tests for Windows platform.
ybaturina Jan 16, 2026
d97eed6
Disable `tests/multiprocess:socket_transfer_test` since it's failing …
emilyfertig Jan 16, 2026
18b34a1
Re-enable `socket_transfer_test` internally.
emilyfertig Jan 16, 2026
9e7b005
Add libtpu date guard for failing test.
mwhittaker Jan 16, 2026
77d9ffb
Skip tpu_pallas_distributed_test on 7x.
mwhittaker Jan 16, 2026
80b1ef6
Add libtpu guard to failing tpu_trace_value_test.
mwhittaker Jan 16, 2026
28799d5
Disable `tpu_splash_attention_kernel_test` on TPU v7x.
mwhittaker Jan 16, 2026
98d6b4b
Remove failing test_itof_dot_canonicalization_fails_without_compat_mo…
18praveenb Jan 16, 2026
2de5b8b
Use maxsize=None with trace_to_jaxpr's weakref_lru_cache to get more …
yashk2810 Jan 19, 2026
c706c2f
Remove nvidia_wheel_versions
charleshofer Nov 12, 2025
de91c59
Make jaxlib targets visible
charleshofer Nov 12, 2025
6e0eb3e
hipblas typedef fix
charleshofer Nov 12, 2025
4f3741b
No GPU fail
charleshofer Nov 13, 2025
2b08a20
Wrap HIP inline functions in anonymous namespaces in vendor.h
AratiGanesh Oct 13, 2025
224b5ba
SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError
dsicarov-amd Jun 10, 2025
1bac5e7
Add shared utility function get_rocm_version to test_util.py
charleshofer Nov 14, 2025
81cdeb4
Fix hipSparse CSR algorithm mappings for ROCm 7
phambinhfin Nov 17, 2025
d88ca19
Make nvidia version data optional for ROCm builds
phambinhfin Oct 31, 2025
61d2832
Fix v_pages quantization and adjust test params for ROCm compatibilit…
phambinhfin Nov 19, 2025
8b0b3d1
Address LLVM assertion failure due to a multithreaded use. Update .gi…
Arech8 Nov 26, 2025
6ab3502
Add skip of test_is_finite() on Cuda (#565)
Arech8 Nov 26, 2025
835856e
Add rocm test requirements file (#570)
AratiGanesh Dec 15, 2025
6bd1a15
Let the unit tests use build.py for setting up Bazel commands for uni…
charleshofer Dec 15, 2025
a83cc9e
adding abort logic to rocm/jax (#590)
gulsumgudukbay Jan 13, 2026
17c7560
Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (…
phambinhfin Jan 14, 2026
e5b6613
Fix shared memory limit check for ROCm in test_dot (#596)
phambinhfin Jan 14, 2026
1a823af
Fix Numpy signatures test (#598)
magaonka-amd Jan 14, 2026
fd0e42c
Fix GPU lowering rule for SVD on ROCm devices (#600)
tsrw2048 Jan 14, 2026
aaac58c
Fixed merge conflicts when moving GESVDJ commit from JAX 0.8.0 to 0.8…
tsrw2048 Jan 16, 2026
e8ac89b
fix merge arts
Ruturaj4 Jan 18, 2026
ee2c254
Enabled testTridiagonal to run on ROCm devices (#607)
tsrw2048 Jan 16, 2026
455f6b1
Enabled ToeplitzSymmetricConstruction unit tests for ROCm devices (#608)
tsrw2048 Jan 17, 2026
fc3bc80
Enable testSvdSubsetByIndex for ROCm with subset_by_index skip (#603)
phambinhfin Jan 17, 2026
0956a23
Fix KeyError for bytes_reservable_limit on ROCm
magaonka-amd Jan 6, 2026
ca93234
Backport cuda_array_interface , testDotAlgorithm, Optimizer test fixe…
magaonka-amd Jan 22, 2026
4a075b9
Enable RngShardingTests (#644)
gulsumgudukbay Jan 22, 2026
2557ff0
Enable reduce_window tests on ROCm (#643)
magaonka-amd Jan 22, 2026
134825e
Enable test_variadic_reduce_window on ROCm (#647)
magaonka-amd Jan 22, 2026
ee5581c
Unskip supported dtypes for testConvolutionsPreferredElementType (#649)
gulsumgudukbay Jan 23, 2026
10c7e75
Enabled test for condition number on ROCm devices. (#613)
tsrw2048 Jan 20, 2026
708d732
Enabled RNN unit test: test_no_workspace_overflow for ROCm devices (#…
tsrw2048 Jan 20, 2026
98cc66c
Added changes from PR #626 and PR #645. This also fixes merge conflic…
tsrw2048 Jan 23, 2026
b77014f
[Pallas] Fix ROCm GPU architecture detection and route to Triton backend
Ruturaj4 Jan 28, 2026
91be368
Enable array interoperability tests on ROCm platform (#660)
magaonka-amd Jan 28, 2026
95e1df0
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
e0739e6
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
a988aeb
Port skip for "test_prim_tridiagonal_solve" tests from JAX v0.8.0 to …
tsrw2048 Jan 26, 2026
3467620
Fix test_cuda_array_interface test skip condition (#657)
magaonka-amd Jan 26, 2026
ed14fc6
Enable testMultivariateNormalSingularCovariance on ROCm (#666)
AratiGanesh Jan 28, 2026
fc10735
Skip test_batch_axis_sharding_jvp because of hipSPARSE issue (#667)
AratiGanesh Jan 28, 2026
93f5fdc
Skip test_tridiagonal_solve on ROCm due to hipSPARSE numerical errors…
AratiGanesh Jan 28, 2026
3747d88
Update Skip Reason Outputs (#663)
gulsumgudukbay Jan 28, 2026
103da59
Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677)
magaonka-amd Jan 29, 2026
d3cce97
Enable lobpcg tests on ROCm platform (#681)
magaonka-amd Feb 3, 2026
498f735
Enable lax backend scipy tests on ROCm GPUs (#687)
magaonka-amd Feb 4, 2026
6701f5d
Add ROCm encoding for test_struct_encoding_determinism (#683)
AratiGanesh Feb 5, 2026
4803ba8
Remove 'mean' from unsupported params for jnp.var (#689)
magaonka-amd Feb 6, 2026
2c7ef61
Enable memory space export tests on ROCm GPUs (#690)
magaonka-amd Feb 6, 2026
7071f71
Implement approx_tanh for ROCm using OCML tanh function (#691)
magaonka-amd Feb 6, 2026
76c5165
Enable test deviceless aot compile test on ROCm (#694)
magaonka-amd Feb 6, 2026
6df5f17
Skipping testEighTinyNorm due to hipSolver issues (#697)
AratiGanesh Feb 9, 2026
ac4ba1c
Modified memory space export test to run on ROCm (for some tests). (#…
tsrw2048 Feb 9, 2026
0988e6d
Add `device test` unit tests for ROCm (JAX v0.9.0) (#705)
tsrw2048 Feb 10, 2026
57c0080
Skip test_tridiagonal_solve_grad test 0.9.0 (#703)
AratiGanesh Feb 11, 2026
8cb19cf
Skip test_batch_axis_sharding_jvp13 test 0.9.0 (#709)
AratiGanesh Feb 11, 2026
32ba3e9
Update skip message version from 0.8.0 to 0.9.0 for test_is_finite on…
phambinhfin Feb 12, 2026
48d2ef1
Fix HIP memory leaks in RNN kernels (#726)
magaonka-amd Mar 5, 2026
506 changes: 310 additions & 196 deletions .bazelrc

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion .bazelversion
@@ -1 +1 @@
-7.4.1
+7.7.0
15 changes: 2 additions & 13 deletions .editorconfig
@@ -5,18 +5,7 @@ indent_style = space
end_of_line = lf
trim_trailing_whitespace = true
insert_final_newline = true

[*.py]
max_line_length = 79
indent_size = 2

[*.rst]
max_line_length = 79
indent_size = 2

[*.md]
max_line_length = 79
indent_size = 2

[*.yml]
indent_size = 2
[*.py]
max_line_length = 80
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug-report.yml
@@ -24,7 +24,7 @@ body:

[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues

-[Raw report]: http://github.com/jax-ml/jax/issues/new
+[Raw report]: https://github.com/jax-ml/jax/issues/new?template=none
- type: textarea
attributes:
label: Description
11 changes: 11 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,11 @@
<!--

Thanks for taking the time to contribute to JAX! A couple things to keep in mind:

* Contributing to JAX requires signing the Contributor License Agreement: https://docs.jax.dev/en/latest/contributing.html#google-contributor-license-agreement

* Please run lint checks and tests locally: https://docs.jax.dev/en/latest/contributing.html#contributing-code-using-pull-requests

* If applicable, read our policy on AI generated code: https://docs.jax.dev/en/latest/contributing.html#can-i-contribute-ai-generated-code

-->
20 changes: 20 additions & 0 deletions .github/actionlint.yaml
@@ -0,0 +1,20 @@
# Configuration related to self-hosted runner.
self-hosted-runner:
  labels:
    - "linux-x86-n4-16" # Linux X86 runner using the 16 vcpu n4-standard-16 machine.
    - "linux-x86-n4-32" # Linux X86 runner using the 32 vcpu n4-standard-32 machine.
    - "linux-x86-n4-64" # Linux X86 runner using the 64 vcpu n4-standard-64 machine.
    - "linux-x86-g2-16-l4-1gpu" # Linux X86 GPU runner using g2-standard-16 machine with 1 NVIDIA L4 GPU attached.
    - "linux-x86-g2-48-l4-4gpu" # Linux X86 GPU runner using g2-standard-48 machine with 4 NVIDIA L4 GPUs attached.
    - "linux-x86-ct5lp-224-8tpu" # Linux X86 TPU runner using ct5lp-hightpu-8t machine with 2x4 topology.
    - "linux-arm64-c4a-16" # Linux ARM64 CPU runner using the 16 vcpu c4a-standard-16 machine.
    - "linux-arm64-c4a-64" # Linux ARM64 CPU runner using the 64 vcpu c4a-standard-64 machine.
    - "windows-x86-n2-16" # Windows X86 runner using n2-standard-16 machine.
    - "windows-x86-n2-64" # Windows X86 runner using n2-standard-64 machine.
    - "linux-x86-a4-224-b200-1gpu" # Linux X86 GPU runner using 1 B200 GPU and 1/8 the resources of an a4-highgpu-8g machine.
    - "linux-x86-a3-8g-h100-8gpu" # Linux X86 GPU runner using a3-highgpu-8g machine with 8 NVIDIA H100 GPUs attached.
    - "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology.
    - "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology.
    - "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology.
    - "linux-x86-tpu7x-224-4tpu" # Linux X86 TPU runner using tpu7x-224 machine with 4 TPU chips (8 cores) and 2x2x1 topology.
    - "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner
105 changes: 105 additions & 0 deletions .github/actions/download-jax-cpu-wheels/action.yml
@@ -0,0 +1,105 @@
# Composite action to download the jax and jaxlib wheels
name: Download JAX CPU wheels

inputs:
  runner:
    description: "Which runner type should the wheels be downloaded for?"
    type: string
    default: "linux-x86-n4-16"
  python:
    description: "Which python version should the artifact be downloaded for?"
    required: true
    type: string
  jaxlib-version:
    description: "Which jaxlib version to download? (head/pypi_latest)"
    type: string
    default: "head"
  skip-download-jaxlib-from-gcs:
    description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)"
    default: '0'
    type: string
  gcs_download_uri:
    description: "GCS location prefix from where the artifacts should be downloaded"
    default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
    type: string
permissions: {}
runs:
  using: "composite"

  steps:
    # Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow.
    - name: Set env vars for use in artifact download URL
      shell: bash
      run: |
        os=$(uname -s | awk '{print tolower($0)}')
        arch=$(uname -m)

        # Adjust os and arch for Windows
        if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then
          os="win"
          arch="amd64"
        fi

        # Get the major and minor version of Python.
        # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
        # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
        python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

        echo "OS=${os}" >> $GITHUB_ENV
        echo "ARCH=${arch}" >> $GITHUB_ENV
        # Python wheels follow a naming convention: standard wheels use the pattern
        # `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
        # `*-cp<py_version>-cp<py_version>t-*`.
        echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV
    - name: Download wheels from GCS (non-Windows runs)
      shell: bash
      id: download-wheel-artifacts-nw
      # Set continue-on-error to true to prevent actions from failing the workflow if this step
      # fails. Instead, we verify the outcome in the step below so that we can print a more
      # informative error message.
      continue-on-error: true
      if: ${{ !contains(inputs.runner, 'windows-x86') }}
      run: |
        mkdir -p $(pwd)/dist
        gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/

        if [[ "${{ inputs.skip-download-jaxlib-from-gcs }}" == "1" ]]; then
          echo "JAX only release. Only downloading the jax wheel from the release bucket."
        else
          if [[ ${{ inputs.jaxlib-version }} == "head" ]]; then
            gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
          elif [[ ${{ inputs.jaxlib-version }} == "pypi_latest" ]]; then
            PYTHON=python${{ inputs.python }}
            $PYTHON -m pip download jaxlib --dest $(pwd)/dist/
          else
            echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}"
            exit 1
          fi
        fi
    - name: Download wheels from GCS (Windows runs)
      shell: cmd
      id: download-wheel-artifacts-w
      # Set continue-on-error to true to prevent actions from failing the workflow if this step
      # fails. Instead, we verify the outcome in the step below so that we can print a more
      # informative error message.
      continue-on-error: true
      if: ${{ contains(inputs.runner, 'windows-x86') }}
      run: |
        mkdir dist
        @REM Use `call` so that we can run sequential gcloud storage commands on Windows
        @REM See https://github.com/GoogleCloudPlatform/gsutil/issues/233#issuecomment-196150652
        call gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl dist/

        if "${{ inputs.skip-download-jaxlib-from-gcs }}"=="1" (
          echo "JAX only release. Only downloading the jax wheel from the release bucket."
        ) else (
          call gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*%PYTHON_MAJOR_MINOR%*%OS%*%ARCH%*.whl" dist/
        )
    - name: Skip the test run if the wheel artifacts were not downloaded successfully
      shell: bash
      if: steps.download-wheel-artifacts-nw.outcome == 'failure' || steps.download-wheel-artifacts-w.outcome == 'failure'
      run: |
        echo "Failed to download wheel artifacts from GCS. Please check if the wheels were"
        echo "built successfully by the artifact build jobs and are available in the GCS bucket."
        echo "Skipping the test run."
        exit 1
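The version-string munging in the "Set env vars" step above (mapping JAXCI_HERMETIC_PYTHON_VERSION to the cp…-cp…- wheel-tag prefix) can be exercised standalone. This is an illustrative sketch, not part of the action; the function name wheel_tag_prefix is hypothetical:

```shell
#!/usr/bin/env bash
# Reproduce the wheel-tag computation from the composite action above.
# "3.10"       -> "310"  -> "cp310-cp310-"  (standard build)
# "3.13-nogil" -> "313t" -> "cp313-cp313t-" (free-threaded build)
wheel_tag_prefix() {
  local v="$1"
  local python_major_minor
  # Map the -nogil suffix to the free-threaded ABI marker "t",
  # then drop the dot: 3.13-nogil -> 3.13t -> 313t.
  python_major_minor=$(echo "${v//-nogil/t}" | tr -d '.')
  # Standard wheels use cp<ver>-cp<ver>-; free-threaded wheels keep the
  # trailing "t" only in the ABI half of the tag.
  echo "cp${python_major_minor%t}-cp${python_major_minor}-"
}

wheel_tag_prefix "3.10"        # -> cp310-cp310-
wheel_tag_prefix "3.13-nogil"  # -> cp313-cp313t-
```

The `${v%t}` strip is what makes the two wheel flavors share one glob-friendly prefix format in the gcloud download commands.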
100 changes: 100 additions & 0 deletions .github/actions/download-jax-cuda-wheels/action.yml
@@ -0,0 +1,100 @@
# Composite action to download the jax, jaxlib, and the CUDA plugin wheels
name: Download JAX CUDA wheels

inputs:
python:
description: "Which python version should the artifact be downloaded for?"
type: string
required: true
cuda-version:
description: "Which cuda version should the artifact be downloaded for?"
type: string
default: "12"
use-nvidia-pip-wheels:
description: "Whether to download Nvidia CUDA packages from PyPI?"
type: boolean
default: false
jaxlib-version:
description: "Which jaxlib version to download? (head/pypi_latest)"
type: string
default: "head"
download-jax-from-gcs:
description: "Whether to download the jax wheel from GCS"
default: '1'
type: string
skip-download-jaxlib-and-cuda-plugins-from-gcs:
description: "Whether to skip downloading the jaxlib and cuda plugins from GCS (e.g for testing a jax only release)"
default: '0'
type: string
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
type: string
permissions: {}
runs:
using: "composite"

steps:
# Note that certain envs such as JAXCI_HERMETIC_PYTHON_VERSION are set by the calling workflow.
- name: Set env vars for use in artifact download URL
shell: bash
run: |
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)

# Get the major and minor version of Python.
# E.g. if JAXCI_HERMETIC_PYTHON_VERSION=3.11, then python_major_minor=311
# E.g. if JAXCI_HERMETIC_PYTHON_VERSION=3.13-nogil, then python_major_minor=313t
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

echo "OS=${os}" >> $GITHUB_ENV
echo "ARCH=${arch}" >> $GITHUB_ENV
# Python wheels follow a naming convention: standard wheels use the pattern
# `*-cp<py_version>-cp<py_version>-*`, while free-threaded wheels use
# `*-cp<py_version>-cp<py_version>t-*`.
echo "PYTHON_MAJOR_MINOR=cp${python_major_minor%t}-cp${python_major_minor}-" >> $GITHUB_ENV

# Get the CUDA major version only
full_cuda_version="${{ inputs.cuda-version }}"
echo "JAXCI_CUDA_VERSION=${full_cuda_version%%.*}" >> $GITHUB_ENV
- name: Download wheels
shell: bash
id: download-wheel-artifacts
# Set continue-on-error to true to prevent actions from failing the workflow if this step
# fails. Instead, we verify the outcome in the next step so that we can print a more
# informative error message.
continue-on-error: true
run: |
mkdir -p $(pwd)/dist
if [[ "${{ inputs.download-jax-from-gcs }}" == "1" ]]; then
gcloud storage cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any.whl $(pwd)/dist/
else
echo "JAX wheel won't be downloaded; only the pre-built jaxlib wheel will be tested."
fi

# Do not download the jaxlib and CUDA plugin artifacts if we are testing a jax only
# release.
if [[ "${{ inputs.skip-download-jaxlib-and-cuda-plugins-from-gcs }}" == "1" ]]; then
echo "JAX-only release. Only downloading the jax wheel from the release bucket."
else
if [[ "${{ inputs.jaxlib-version }}" == "head" ]]; then
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda${JAXCI_CUDA_VERSION}*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/
gcloud storage cp -r "${{ inputs.gcs_download_uri }}/jax*cuda${JAXCI_CUDA_VERSION}*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
elif [[ "${{ inputs.jaxlib-version }}" == "pypi_latest" ]]; then
PYTHON=python${{ inputs.python }}
$PYTHON -m pip download jaxlib jax-cuda${JAXCI_CUDA_VERSION}-pjrt jax-cuda${JAXCI_CUDA_VERSION}-plugin --dest $(pwd)/dist/
else
echo "Invalid jaxlib version: ${{ inputs.jaxlib-version }}"
exit 1
fi
fi
- name: Skip the test run if the wheel artifacts were not downloaded successfully
shell: bash
if: steps.download-wheel-artifacts.outcome == 'failure'
run: |
echo "Failed to download wheel artifacts. Please check if the wheels were"
echo "built successfully by the artifact build jobs and are available in the"
echo "GCS bucket when downloading from GCS."
echo "Skipping the test run."
exit 1
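The two parameter expansions in the action above (the Python wheel-tag mapping and the CUDA major-version extraction) can be exercised in isolation. A minimal sketch — the version strings below are made-up example inputs, not values the action pins:

```shell
#!/usr/bin/env bash
# Sketch of the version-string transforms used in the action above.
# The input values here are arbitrary examples.

# "3.13-nogil" -> "313t": replace the -nogil suffix with the free-threaded
# "t" ABI marker, then strip the dot.
JAXCI_HERMETIC_PYTHON_VERSION="3.13-nogil"
python_major_minor=$(echo "${JAXCI_HERMETIC_PYTHON_VERSION//-nogil/t}" | tr -d '.')

# Standard wheels are tagged cp<ver>-cp<ver>-; free-threaded wheels get a
# trailing "t" on the ABI half. "%t" trims the marker from the first half only.
echo "cp${python_major_minor%t}-cp${python_major_minor}-"   # cp313-cp313t-

# "12.3.1" -> "12": %%.* drops everything from the first dot onward.
full_cuda_version="12.3.1"
echo "${full_cuda_version%%.*}"   # 12
```

For a standard (non-free-threaded) input such as `3.11`, the same pipeline yields `cp311-cp311-`, which is why one expression covers both wheel naming conventions.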
17 changes: 12 additions & 5 deletions .github/workflows/asan.yaml
@@ -13,12 +13,17 @@ on:
- main
paths:
- '**/workflows/asan.yaml'
permissions: {}

env:
UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"
PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"

jobs:
asan:
# Don't execute in fork due to runner type
if: github.repository == 'jax-ml/jax'
runs-on: linux-x86-n2-64
runs-on: linux-x86-n4-64
container:
image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04
strategy:
@@ -38,14 +43,16 @@ jobs:
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
libffi-dev liblzma-dev
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
with:
path: jax
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
persist-credentials: false
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
with:
repository: python/cpython
path: cpython
ref: v3.13.0
persist-credentials: false
- name: Build CPython with ASAN enabled
env:
ASAN_OPTIONS: detect_leaks=0
@@ -60,6 +67,7 @@
env:
ASAN_OPTIONS: detect_leaks=0
run: |
apt install -y xxd
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
pip install uv~=0.5.30
@@ -72,8 +80,7 @@
cd jax
python build/build.py build --wheels=jaxlib --verbose \
--bazel_options=--color=yes \
--bazel_options=--copt=-fsanitize=address \
--clang_path=/usr/bin/clang-18
--bazel_options=--copt=-fsanitize=address
uv pip install dist/jaxlib-*.whl \
-e .
- name: Run tests
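The asan.yaml changes above amount to three steps: build an ASAN-instrumented CPython, build jaxlib with `-fsanitize=address`, and run the test suite under that interpreter. A hedged local sketch of those steps as a CI command fragment — the directory layout, prefix path, and the final pytest invocation are assumptions, not taken from the workflow:

```shell
#!/usr/bin/env bash
set -euo pipefail
# CPython deliberately does not free everything at interpreter exit, so leak
# detection is disabled, as in the workflow.
export ASAN_OPTIONS=detect_leaks=0

# 1. Build an ASAN-instrumented CPython (the job checks out v3.13.0).
cd cpython
./configure --with-address-sanitizer --prefix="$PWD/install"
make -j"$(nproc)" && make install
"$PWD/install/bin/python3" -m venv ../venv

# 2. Build jaxlib with the sanitizer enabled and install it into that venv,
#    matching the build.py flags shown in the diff.
source ../venv/bin/activate
cd ../jax
python build/build.py build --wheels=jaxlib --verbose \
  --bazel_options=--color=yes \
  --bazel_options=--copt=-fsanitize=address
pip install dist/jaxlib-*.whl -e .

# 3. Run the tests under the sanitized interpreter (invocation is assumed;
#    the workflow's test command is not shown in this hunk).
python -m pytest tests
```

Note the diff also drops the explicit `--clang_path=/usr/bin/clang-18` flag, so the build relies on whatever compiler `build.py` discovers on the `PATH`.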