Skip to content

Conversation

@dependabot
Copy link
Contributor

@dependabot dependabot bot commented on behalf of github Oct 16, 2025

Bumps jax from 0.5.2 to 0.8.0.

Release notes

Sourced from jax's releases.

JAX v0.8.0

  • Breaking changes:

    • JAX is changing the default jax.pmap implementation to one implemented in terms of jax.jit and jax.shard_map. jax.pmap is in maintenance mode and we encourage all new code to use jax.shard_map directly. See the migration guide for more information.
    • The auto= parameter of jax.experimental.shard_map.shard_map has been removed. This means that jax.experimental.shard_map.shard_map no longer supports nesting. If you want to nest shard_map calls, please use jax.shard_map.
    • JAX no longer allows passing objects that support __jax_array__ directly to, e.g. jit-ed functions. Call jax.numpy.asarray on them first.
    • jax.numpy.cov is now returns NaN for empty arrays ({jax-issue}[#32305](https://github.com/jax-ml/jax/issues/32305)), and matches NumPy 2.2 behavior for single-row design matrices ({jax-issue}[#32308](https://github.com/jax-ml/jax/issues/32308)).
    • JAX no longer accepts Array values where a dtype value is expected. Call .dtype on these values first.
    • The deprecated function jax.interpreters.mlir.custom_call was removed.
    • The jax.util, jax.extend.ffi, and jax.experimental.host_callback modules have been removed. All public APIs within these modules were deprecated and removed in v0.7.0 or earlier.
    • The deprecated symbol jax.custom_derivatives.custom_jvp_call_jaxpr_p was removed.
    • jax.experimental.multihost_utils.process_allgather raises an error when the input is a jax.Array and not fully-addressable and tiled=False. To fix this, pass tiled=True to your process_allgather invocation.
    • from jax.experimental.compilation_cache, the deprecated symbols is_initialized and initialize_cache were removed.
    • The deprecated function jax.interpreters.xla.canonicalize_dtype was removed.
    • jaxlib.hlo_helpers has been removed. Use jax.ffi instead.
    • The option jax_cpu_enable_gloo_collectives has been removed. Use jax_cpu_collectives_implementation instead.
    • The previously-deprecated interpolation argument to jax.numpy.percentile and jax.numpy.quantile has been removed; use method instead.
    • The JAX-internal for_loop primitive was removed. Its functionality, reading from and writing to refs in the loop body, is now directly supported by jax.lax.fori_loop. If you need help updating your code, please file a bug.
    • jax.numpy.trimzeros now errors for non-1D input.
    • The where argument to jax.numpy.sum and other reductions is now required to be boolean. Non-boolean values have resulted in a DeprecationWarning since JAX v0.5.0.
    • The deprecated functions in jax.dlpack, jax.errors, jax.lib.xla_bridge, jax.lib.xla_client, and jax.lib.xla_extension were removed.
    • jax.interpreters.mlir.dense_bool_array was removed. Use MLIR APIs to

... (truncated)

Changelog

Sourced from jax's changelog.

JAX 0.8.0 (October 15, 2025)

  • Breaking changes:

    • JAX is changing the default jax.pmap implementation to one implemented in terms of jax.jit and jax.shard_map. jax.pmap is in maintenance mode and we encourage all new code to use jax.shard_map directly. See the migration guide for more information.
    • The auto= parameter of jax.experimental.shard_map.shard_map has been removed. This means that jax.experimental.shard_map.shard_map no longer supports nesting. If you want to nest shard_map calls, please use jax.shard_map.
    • JAX no longer allows passing objects that support __jax_array__ directly to, e.g. jit-ed functions. Call jax.numpy.asarray on them first.
    • {func}jax.numpy.cov is now returns NaN for empty arrays ({jax-issue}[#32305](https://github.com/jax-ml/jax/issues/32305)), and matches NumPy 2.2 behavior for single-row design matrices ({jax-issue}[#32308](https://github.com/jax-ml/jax/issues/32308)).
    • JAX no longer accepts Array values where a dtype value is expected. Call .dtype on these values first.
    • The deprecated function {func}jax.interpreters.mlir.custom_call was removed.
    • The jax.util, jax.extend.ffi, and jax.experimental.host_callback modules have been removed. All public APIs within these modules were deprecated and removed in v0.7.0 or earlier.
    • The deprecated symbol {obj}jax.custom_derivatives.custom_jvp_call_jaxpr_p was removed.
    • jax.experimental.multihost_utils.process_allgather raises an error when the input is a jax.Array and not fully-addressable and tiled=False. To fix this, pass tiled=True to your process_allgather invocation.
    • from {mod}jax.experimental.compilation_cache, the deprecated symbols is_initialized and initialize_cache were removed.
    • The deprecated function {func}jax.interpreters.xla.canonicalize_dtype was removed.
    • {mod}jaxlib.hlo_helpers has been removed. Use {mod}jax.ffi instead.
    • The option jax_cpu_enable_gloo_collectives has been removed. Use jax_cpu_collectives_implementation instead.
    • The previously-deprecated interpolation argument to {func}jax.numpy.percentile and {func}jax.numpy.quantile has been removed; use method instead.
    • The JAX-internal for_loop primitive was removed. Its functionality, reading from and writing to refs in the loop body, is now directly supported by {func}jax.lax.fori_loop. If you need help updating your code, please file a bug.
    • {func}jax.numpy.trimzeros now errors for non-1D input.
    • The where argument to {func}jax.numpy.sum and other reductions is now required to be boolean. Non-boolean values have resulted in a DeprecationWarning since JAX v0.5.0.
    • The deprecated functions in {mod} jax.dlpack, {mod} jax.errors, {mod} jax.lib.xla_bridge, {mod} jax.lib.xla_client, and {mod} jax.lib.xla_extension were removed.

... (truncated)

Commits
  • 403977d Reverts f44bc6e37f770a9d79d24061f7960235a2892407
  • b1ce4c5 Prepare for JAX release 0.8.0
  • 07dad8e [Pallas] Fix 64-bit multiplication emulation in Philox kernel
  • 4f06e7a Update XLA hash for JAX release
  • 1bd3402 Fix race when dialects were registered in an MLIR context when it was being a...
  • b942e89 Clean up dead code.
  • ac87cbe Remove cusolver_potrf_ffi from the list of covered backwards compatibility ta...
  • 69fd205 [Pallas:MGPU] Expose multicast TMA stores
  • 8eed7dc [Pallas:MGPU] Add an all-gather implementation
  • 6aae6f4 Add tpu_sc.subcore_barrier primitive.
  • Additional commits viewable in compare view

Dependabot compatibility score

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


Dependabot commands and options

You can trigger Dependabot actions by commenting on this PR:

  • @dependabot rebase will rebase this PR
  • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
  • @dependabot merge will merge this PR after your CI passes on it
  • @dependabot squash and merge will squash and merge this PR after your CI passes on it
  • @dependabot cancel merge will cancel a previously requested merge and block automerging
  • @dependabot reopen will reopen this PR if it is closed
  • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
  • @dependabot show <dependency name> ignore conditions will show all of the ignore conditions of the specified dependency
  • @dependabot ignore this major version will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
  • @dependabot ignore this minor version will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
  • @dependabot ignore this dependency will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)

@dependabot dependabot bot added dependencies Pull requests that update a dependency file python Pull requests that update Python code labels Oct 16, 2025
@tushuhei
Copy link
Member

@dependabot rebase

Bumps [jax](https://github.com/jax-ml/jax) from 0.5.2 to 0.8.0.
- [Release notes](https://github.com/jax-ml/jax/releases)
- [Changelog](https://github.com/jax-ml/jax/blob/main/CHANGELOG.md)
- [Commits](jax-ml/jax@jax-v0.5.2...jax-v0.8.0)

---
updated-dependencies:
- dependency-name: jax
  dependency-version: 0.8.0
  dependency-type: direct:development
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
@dependabot dependabot bot force-pushed the dependabot/pip/jax-0.8.0 branch from 2518927 to 5905487 Compare October 20, 2025 02:34
@tushuhei
Copy link
Member

This update will be taken care of by #1062

@tushuhei tushuhei closed this Oct 20, 2025
@dependabot @github
Copy link
Contributor Author

dependabot bot commented on behalf of github Oct 20, 2025

OK, I won't notify you again about this release, but will get in touch when a new version is available. If you'd rather skip all updates until the next major or minor version, let me know by commenting @dependabot ignore this major version or @dependabot ignore this minor version. You can also ignore all major, minor, or patch releases for a dependency by adding an ignore condition with the desired update_types to your config file.

If you change your mind, just re-open this PR and I'll resolve any conflicts on it.

@dependabot dependabot bot deleted the dependabot/pip/jax-0.8.0 branch October 20, 2025 03:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

dependencies Pull requests that update a dependency file python Pull requests that update Python code

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant