+ `;
+
+ // Create the cluster.
+ let network = [];
+ let server = new CoordinationService(network, options);
+ const processes = [
+ new Process(network, options, 0), new Process(network, options, 1),
+ new Process(network, options, 2)
+ ];
+
+ // Set up the live_devices button.
+ for (let i = 0; i < 3; ++i) {
+ const button = container.getElementsByClassName(`p${i}-ld-button`)[0];
+ if (options.live_devices) {
+ button.addEventListener('click', () => processes[i].send_live_devices());
+ } else {
+ button.remove();
+ }
+ }
+
+ // Set up the fail button.
+ const button = container.querySelectorAll('.p2-fail-button')[0];
+ button.addEventListener('click', () => processes[2].fail());
+
+ // Remove live_devices display if needed.
+ if (!options.live_devices) {
+ for (let i = 0; i < 3; ++i) {
+ container.getElementsByClassName(`p${i}-live-devices`)[0].remove();
+ }
+ }
+ if (!options.barrier) {
+ container.getElementsByClassName('in-barrier')[0].remove();
+ }
+
+ // Periodically process network messages.
+ setInterval(() => {
+ while (network.length > 0) {
+ const msg = network.shift();
+ const tall = options.live_devices;
+ send(container, tall, msg.payload, `p${msg.src}`, `p${msg.dst}`, () => {
+ if (msg.dst == 'server') {
+ server.receive(msg);
+ } else {
+ processes[msg.dst].receive(msg);
+ }
+ });
+ }
+  }, 10);
+
+ // Periodically update HTML.
+ setInterval(() => {
+ server.update_html(container);
+ for (let proc of processes) {
+ proc.update_html(container);
+ }
+ }, 50);
+}
diff --git a/docs/_static/fault_tolerance/live_devices.py b/docs/_static/fault_tolerance/live_devices.py
new file mode 100644
index 000000000000..9f41a2bdac6a
--- /dev/null
+++ b/docs/_static/fault_tolerance/live_devices.py
@@ -0,0 +1,64 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
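+# These environment variables must be set before JAX initializes its backends,
+# which is why they are set before the `import jax` below.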
+os.environ['XLA_FLAGS'] = ' '.join([
+ '--xla_gpu_nccl_terminate_on_error=false',
+ '--xla_gpu_nccl_async_execution=true',
+ '--xla_gpu_nccl_blocking_communicators=false',
+])
+os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1'
+os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1'
+
+from absl import app
+from absl import flags
+from collections.abc import Sequence
+from jax.experimental.multihost_utils import live_devices
+import jax
+import jax.numpy as jnp
+import time
+
+_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
+_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")
+
+
+def main(_: Sequence[str]) -> None:
+ jax.config.update("jax_enable_recoverability", True)
+ jax.distributed.initialize(
+ coordinator_address="localhost:9000",
+ num_processes=_NUM_PROCESSES.value,
+ process_id=_PROCESS_ID.value,
+ local_device_ids=[_PROCESS_ID.value],
+ heartbeat_timeout_seconds=10,
+ )
+ print(f'{jax.devices()=}')
+ print(f'{jax.local_devices()=}')
+
+ while True:
+ try:
+ with live_devices(jax.devices()) as devices:
+ print(f'{devices=}')
+ n = len(devices)
+ jax.set_mesh(jax.make_mesh((n,), ("i",), devices=devices))
+ x = jax.device_put(jnp.arange(n), jax.P("i"))
+ print(jnp.sum(x))
+ except Exception as e:
+ print('FAIL:', e)
+ else:
+ print('PASS')
+ time.sleep(1)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/docs/_static/fault_tolerance/while_loop.py b/docs/_static/fault_tolerance/while_loop.py
new file mode 100644
index 000000000000..0dbac58b528d
--- /dev/null
+++ b/docs/_static/fault_tolerance/while_loop.py
@@ -0,0 +1,41 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl import app
+from absl import flags
+from collections.abc import Sequence
+import jax
+import time
+
+_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id")
+_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes")
+
+
+def main(_: Sequence[str]) -> None:
+ jax.distributed.initialize(
+ coordinator_address="localhost:9000",
+ num_processes=_NUM_PROCESSES.value,
+ process_id=_PROCESS_ID.value,
+ local_device_ids=[_PROCESS_ID.value],
+ heartbeat_timeout_seconds=10,
+ )
+ print(f'{jax.devices()=}')
+ print(f'{jax.local_devices()=}')
+ while True:
+ print(time.time())
+ time.sleep(1)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md
deleted file mode 100644
index f8b5000c2b47..000000000000
--- a/docs/advanced-autodiff.md
+++ /dev/null
@@ -1,1777 +0,0 @@
----
-jupytext:
- formats: md:myst
- text_representation:
- extension: .md
- format_name: myst
- format_version: 0.13
- jupytext_version: 1.16.4
-kernelspec:
- display_name: Python 3
- language: python
- name: python3
----
-
-(advanced-autodiff)=
-# Advanced automatic differentiation
-
-
-
-In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful.
-
-Make sure to check out the {ref}`automatic-differentiation` tutorial to go over the JAX autodiff basics, if you haven't already.
-
-## Setup
-
-```{code-cell}
-import jax
-import jax.numpy as jnp
-from jax import grad, jit, vmap
-from jax import random
-
-key = random.key(0)
-```
-
-## Taking gradients (part 2)
-
-### Higher-order derivatives
-
-JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.
-
-The single-variable case was covered in the {ref}`automatic-differentiation` tutorial, where the example showed how to use {func}`jax.grad` to compute the derivative of $f(x) = x^3 + 2x^2 - 3x + 1$.
-
-In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to:
-
-$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$
-
-The Hessian of a real-valued function of several variables, $f: \mathbb R^n\to\mathbb R$, can be identified with the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of its gradient.
-
-JAX provides two transformations for computing the Jacobian of a function, {func}`jax.jacfwd` and {func}`jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – refer to the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY).
-
-```{code-cell}
-def hessian(f):
- return jax.jacfwd(jax.grad(f))
-```
-
-Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}$.
-
-If $i=j$, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2$. Otherwise, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0$.
-
-```{code-cell}
-def f(x):
- return jnp.dot(x, x)
-
-hessian(f)(jnp.array([1., 2., 3.]))
-```
-
-## Higher-order optimization
-
-Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier:
-
-```python
-def meta_loss_fn(params, data):
- """Computes the loss after one step of SGD."""
- grads = jax.grad(loss_fn)(params, data)
- return loss_fn(params - lr * grads, data)
-
-meta_grads = jax.grad(meta_loss_fn)(params, data)
-```
-
-(stopping-gradients)=
-### Stopping gradients
-
-Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph.
-
-Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1})$ in a state $s_{t-1}$ is parameterised by a linear function.
-
-```{code-cell}
-# Value function and initial parameters
-value_fn = lambda theta, state: jnp.dot(theta, state)
-theta = jnp.array([0.1, -0.1, 0.])
-```
-
-Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which you observed the reward $r_t$:
-
-```{code-cell}
-# An example transition.
-s_tm1 = jnp.array([1., 2., -1.])
-r_t = jnp.array(1.)
-s_t = jnp.array([2., 1., 0.])
-```
-
-The TD(0) update to the network parameters is:
-
-$$
-\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1})
-$$
-
-This update is not the gradient of any loss function.
-
-However, it can be **written** as the gradient of the pseudo loss function
-
-$$
-L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2
-$$
-
-if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored.
-
-How can you implement this in JAX? If you write the pseudo loss naively, you get:
-
-```{code-cell}
-def td_loss(theta, s_tm1, r_t, s_t):
- v_tm1 = value_fn(theta, s_tm1)
- target = r_t + value_fn(theta, s_t)
- return -0.5 * ((target - v_tm1) ** 2)
-
-td_update = jax.grad(td_loss)
-delta_theta = td_update(theta, s_tm1, r_t, s_t)
-
-delta_theta
-```
-
-But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$.
-
-You can use {func}`jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$:
-
-```{code-cell}
-def td_loss(theta, s_tm1, r_t, s_t):
- v_tm1 = value_fn(theta, s_tm1)
- target = r_t + value_fn(theta, s_t)
- return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)
-
-td_update = jax.grad(td_loss)
-delta_theta = td_update(theta, s_tm1, r_t, s_t)
-
-delta_theta
-```
-
-This will treat `target` as if it did **not** depend on the parameters $\theta$ and compute the correct update to the parameters.
-
-Now, let's also calculate $\Delta \theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using {func}`jax.grad` and your knowledge so far. Here's our solution:
-
-```{code-cell}
-s_grad = jax.grad(value_fn)(theta, s_tm1)
-delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad
-
-delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`
-```
-
-`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss).
-
-
-### Straight-through estimator using `stop_gradient`
-
-The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`:
-
-```{code-cell}
-def f(x):
- return jnp.round(x) # non-differentiable
-
-def straight_through_f(x):
- # Create an exactly-zero expression with Sterbenz lemma that has
- # an exactly-one gradient.
- zero = x - jax.lax.stop_gradient(x)
- return zero + jax.lax.stop_gradient(f(x))
-
-print("f(x): ", f(3.2))
-print("straight_through_f(x):", straight_through_f(3.2))
-
-print("grad(f)(x):", jax.grad(f)(3.2))
-print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
-```
-
-### Per-example gradients
-
-While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.
-
-For instance, this is needed to prioritize data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis.
-
-In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients, are typically very inefficient.
-
-In JAX, you can define the code to compute the gradient per-sample in an easy but efficient way.
-
-Just combine the {func}`jax.jit`, {func}`jax.vmap` and {func}`jax.grad` transformations together:
-
-```{code-cell}
-perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
-
-# Test it:
-batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
-batched_r_t = jnp.stack([r_t, r_t])
-batched_s_t = jnp.stack([s_t, s_t])
-
-perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
-```
-
-Let's go through this one transformation at a time.
-
-First, you apply {func}`jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs:
-
-```{code-cell}
-dtdloss_dtheta = jax.grad(td_loss)
-
-dtdloss_dtheta(theta, s_tm1, r_t, s_t)
-```
-
-This function computes one row of the array above.
-
-Then, you vectorise this function using {func}`jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, you produce a batch of outputs — each output in the batch corresponds to the gradient for the corresponding member of the input batch.
-
-```{code-cell}
-almost_perex_grads = jax.vmap(dtdloss_dtheta)
-
-batched_theta = jnp.stack([theta, theta])
-almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
-```
-
-This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the {func}`jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want:
-
-```{code-cell}
-inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
-
-inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
-```
-
-This does what we want, but is slower than it has to be. Now, you wrap the whole thing in a {func}`jax.jit` to get the compiled, efficient version of the same function:
-
-```{code-cell}
-perex_grads = jax.jit(inefficient_perex_grads)
-
-perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
-```
-
-```{code-cell}
-%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
-%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
-```
-
-### Hessian-vector products with `jax.grad`-of-`jax.grad`
-
-One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
-
-A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)).
-
-For a scalar-valued function $f : \mathbb{R}^n \to \mathbb{R}$ with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point $x \in \mathbb{R}^n$ is written as $\partial^2 f(x)$. A Hessian-vector product function is then able to evaluate
-
-$\qquad v \mapsto \partial^2 f(x) \cdot v$
-
-for any $v \in \mathbb{R}^n$.
-
-The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.
-
-Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:
-
-$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$,
-
-where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient.
-
-In JAX code, you can just write this:
-
-```{code-cell}
-def hvp(f, x, v):
- return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
-```
-
-This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.
-
-You will check this implementation a few cells down, once you learn how to compute dense Hessian matrices. You'll also write an even better version that uses both forward-mode and reverse-mode.
-
-
-### Jacobians and Hessians using `jax.jacfwd` and `jax.jacrev`
-
-You can compute full Jacobian matrices using the {func}`jax.jacfwd` and {func}`jax.jacrev` functions:
-
-```{code-cell}
-from jax import jacfwd, jacrev
-
-# Define a sigmoid function.
-def sigmoid(x):
- return 0.5 * (jnp.tanh(x / 2) + 1)
-
-# Outputs probability of a label being true.
-def predict(W, b, inputs):
- return sigmoid(jnp.dot(inputs, W) + b)
-
-# Build a toy dataset.
-inputs = jnp.array([[0.52, 1.12, 0.77],
- [0.88, -1.08, 0.15],
- [0.52, 0.06, -1.30],
- [0.74, -2.49, 1.39]])
-
-# Initialize random model coefficients
-key, W_key, b_key = random.split(key, 3)
-W = random.normal(W_key, (3,))
-b = random.normal(b_key, ())
-
-# Isolate the function from the weight matrix to the predictions
-f = lambda W: predict(W, b, inputs)
-
-J = jacfwd(f)(W)
-print("jacfwd result, with shape", J.shape)
-print(J)
-
-J = jacrev(f)(W)
-print("jacrev result, with shape", J.shape)
-print(J)
-```
-
-These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`.
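-
-As a rough, hedged illustration of that rule of thumb, here is a small sketch (the function shapes are made up for the example) contrasting the two modes on deliberately "tall" and "wide" maps:
-
-```python
-import jax
-import jax.numpy as jnp
-
-tall_f = lambda x: jnp.outer(x, x).ravel()  # R^50 -> R^2500: many outputs, few inputs
-wide_f = lambda x: jnp.sum(jnp.sin(x))      # R^2500 -> R: one output, many inputs
-
-x_tall = jnp.ones(50)
-x_wide = jnp.ones(2500)
-
-J_tall = jax.jacfwd(tall_f)(x_tall)  # forward-mode: ~one pass per input, suits tall Jacobians
-J_wide = jax.jacrev(wide_f)(x_wide)  # reverse-mode: ~one pass per output, suits wide Jacobians
-print(J_tall.shape, J_wide.shape)    # (2500, 50) (2500,)
-```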
-
-You can also use {func}`jax.jacfwd` and {func}`jax.jacrev` with container types:
-
-```{code-cell}
-def predict_dict(params, inputs):
- return predict(params['W'], params['b'], inputs)
-
-J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
-for k, v in J_dict.items():
- print("Jacobian from {} to logits is".format(k))
- print(v)
-```
-
-For more details on forward- and reverse-mode, as well as how to implement {func}`jax.jacfwd` and {func}`jax.jacrev` as efficiently as possible, read on!
-
-Using a composition of two of these functions gives us a way to compute dense Hessian matrices:
-
-```{code-cell}
-def hessian(f):
- return jacfwd(jacrev(f))
-
-H = hessian(f)(W)
-print("hessian, with shape", H.shape)
-print(H)
-```
-
-This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ you expect to get the shapes:
-
-* $f(x) \in \mathbb{R}^m$, the value of $f$ at $x$,
-* $\partial f(x) \in \mathbb{R}^{m \times n}$, the Jacobian matrix at $x$,
-* $\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$, the Hessian at $x$,
-
-and so on.
-
-To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function with a wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out.
-
-
-## How it's made: Two foundational autodiff functions
-
-### Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff)
-
-JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
-
-
-#### JVPs in math
-
-Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian of $f$ evaluated at an input point $x \in \mathbb{R}^n$, denoted $\partial f(x)$, is often thought of as a matrix in $\mathbb{R}^m \times \mathbb{R}^n$:
-
-$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$.
-
-But you can also think of $\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\mathbb{R}^m$):
-
-$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$.
-
-This map is called the [pushforward map](https://en.wikipedia.org/wiki/Pushforward_(differential)) of $f$ at $x$. The Jacobian matrix is just the matrix for this linear map on a standard basis.
-
-If you don't commit to one specific input point $x$, then you can think of the function $\partial f$ as first taking an input point and returning the Jacobian linear map at that input point:
-
-$\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$.
-
-In particular, you can uncurry things so that given input point $x \in \mathbb{R}^n$ and a tangent vector $v \in \mathbb{R}^n$, you get back an output tangent vector in $\mathbb{R}^m$. We call that mapping, from $(x, v)$ pairs to output tangent vectors, the *Jacobian-vector product*, and write it as:
-
-$\qquad (x, v) \mapsto \partial f(x) v$
-
-
-#### JVPs in JAX code
-
-Back in Python code, JAX's {func}`jax.jvp` function models this transformation. Given a Python function that evaluates $f$, JAX's {func}`jax.jvp` is a way to get a Python function for evaluating $(x, v) \mapsto (f(x), \partial f(x) v)$.
-
-```{code-cell}
-from jax import jvp
-
-# Isolate the function from the weight matrix to the predictions
-f = lambda W: predict(W, b, inputs)
-
-key, subkey = random.split(key)
-v = random.normal(subkey, W.shape)
-
-# Push forward the vector `v` along `f` evaluated at `W`
-y, u = jvp(f, (W,), (v,))
-```
-
-In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), you could write:
-
-```haskell
-jvp :: (a -> b) -> a -> T a -> (b, T b)
-```
-
-where `T a` is used to denote the type of the tangent space for `a`.
-
-In other words, `jvp` takes as arguments a function of type `a -> b`, a value of type `a`, and a tangent vector value of type `T a`. It gives back a pair consisting of a value of type `b` and an output tangent vector of type `T b`.
-
-The `jvp`-transformed function is evaluated much like the original function, but paired up with each primal value of type `a` it pushes along tangent values of type `T a`. For each primitive numerical operation that the original function would have applied, the `jvp`-transformed function executes a "JVP rule" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values.
-
-That evaluation strategy has some immediate implications about computational complexity. Since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example `sin(x)`; one unit for linearizing, like `cos(x)`; and one unit for applying the linearized function to a vector, like `cos_x * v`). Put another way, for a fixed primal point $x$, we can evaluate $v \mapsto \partial f(x) \cdot v$ for about the same marginal cost as evaluating $f$.
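-
-To make that accounting concrete, here is a minimal sketch (with made-up inputs) of what the JVP rule for `sin` does, written out by hand and checked against {func}`jax.jvp`:
-
-```python
-import jax
-import jax.numpy as jnp
-
-x, v = 0.7, 1.3
-
-primal_out = jnp.sin(x)    # one unit of work: evaluate the primitive
-cos_x = jnp.cos(x)         # one unit: linearize at the primal point
-tangent_out = cos_x * v    # one unit: apply the linearization to the tangent
-
-print(jax.jvp(jnp.sin, (x,), (v,)))  # same (primal, tangent) pair
-print((primal_out, tangent_out))
-```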
-
-That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning?
-
-To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians.
-
-If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.
-
-To do better for functions like this, you just need to use reverse-mode.
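-
-Before moving on, here is a small sketch of the column-at-a-time construction described above (using a made-up toy function `g`), assembling a full Jacobian from one JVP call per input dimension and checking it against {func}`jax.jacfwd`:
-
-```python
-import jax
-import jax.numpy as jnp
-
-def g(x):
-  return jnp.array([x[0] * x[1], jnp.sum(x ** 2)])  # R^3 -> R^2
-
-x = jnp.array([1.0, 2.0, 3.0])
-
-# Each one-hot tangent vector reveals one column of the Jacobian.
-cols = [jax.jvp(g, (x,), (jnp.eye(3)[i],))[1] for i in range(3)]
-J = jnp.stack(cols, axis=1)
-
-print(jnp.allclose(J, jax.jacfwd(g)(x)))  # True
-```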
-
-
-### Vector-Jacobian products (VJPs, a.k.a. reverse-mode autodiff)
-
-Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.
-
-
-#### VJPs in math
-
-Let's again consider a function $f : \mathbb{R}^n \to \mathbb{R}^m$.
-Starting from our notation for JVPs, the notation for VJPs is pretty simple:
-
-$\qquad (x, v) \mapsto v \partial f(x)$,
-
-where $v$ is an element of the cotangent space of $f$ at $x$ (isomorphic to another copy of $\mathbb{R}^m$). When being rigorous, we should think of $v$ as a linear map $v : \mathbb{R}^m \to \mathbb{R}$, and when we write $v \partial f(x)$ we mean function composition $v \circ \partial f(x)$, where the types work out because $\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. But in the common case we can identify $v$ with a vector in $\mathbb{R}^m$ and use the two almost interchangeably, just like we might sometimes flip between "column vectors" and "row vectors" without much comment.
-
-With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:
-
-$\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v$.
-
-For a given point $x$, we can write the signature as
-
-$\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n$.
-
-The corresponding map on cotangent spaces is often called the [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry))
-of $f$ at $x$. The key for our purposes is that it goes from something that looks like the output of $f$ to something that looks like the input of $f$, just like we might expect from a transposed linear function.
-
-#### VJPs in JAX code
-
-Switching from math back to Python, the JAX function `vjp` can take a Python function for evaluating $f$ and give us back a Python function for evaluating the VJP $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$.
-
-```{code-cell}
-from jax import vjp
-
-# Isolate the function from the weight matrix to the predictions
-f = lambda W: predict(W, b, inputs)
-
-y, vjp_fun = vjp(f, W)
-
-key, subkey = random.split(key)
-u = random.normal(subkey, y.shape)
-
-# Pull back the covector `u` along `f` evaluated at `W`
-v = vjp_fun(u)
-```
-
-In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), we could write
-
-```haskell
-vjp :: (a -> b) -> a -> (b, CT b -> CT a)
-```
-
-where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`.
-
-This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
-
-There's a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).
-
-For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/).
-
-
-### Vector-valued gradients with VJPs
-
-If you're interested in taking vector-valued gradients (like `tf.gradients`):
-
-```{code-cell}
-def vgrad(f, x):
- y, vjp_fn = vjp(f, x)
- return vjp_fn(jnp.ones(y.shape))[0]
-
-print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
-```
-
-### Hessian-vector products using both forward- and reverse-mode
-
-In a previous section, you implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives):
-
-```{code-cell}
-def hvp(f, x, v):
- return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
-```
-
-That's efficient, but you can do even better and save some memory by using forward-mode together with reverse-mode.
-
-Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}$ to differentiate, a point $x \in \mathbb{R}^n$ at which to linearize the function, and a vector $v \in \mathbb{R}^n$, the Hessian-vector product function we want is:
-
-$(x, v) \mapsto \partial^2 f(x) v$
-
-Consider the helper function $g : \mathbb{R}^n \to \mathbb{R}^n$ defined to be the derivative (or gradient) of $f$, namely $g(x) = \partial f(x)$. All you need is its JVP, since that will give us:
-
-$(x, v) \mapsto \partial g(x) v = \partial^2 f(x) v$.
-
-We can translate that almost directly into code:
-
-```{code-cell}
-# forward-over-reverse
-def hvp(f, primals, tangents):
- return jvp(grad(f), primals, tangents)[1]
-```
-
-Even better, since you didn't have to call {func}`jnp.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on {mod}`jax.numpy`.
-
-Here's an example of how to use it:
-
-```{code-cell}
-def f(X):
- return jnp.sum(jnp.tanh(X)**2)
-
-key, subkey1, subkey2 = random.split(key, 3)
-X = random.normal(subkey1, (30, 40))
-V = random.normal(subkey2, (30, 40))
-
-ans1 = hvp(f, (X,), (V,))
-ans2 = jnp.tensordot(hessian(f)(X), V, 2)
-
-print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
-```
-
-Another way you might consider writing this is using reverse-over-forward:
-
-```{code-cell}
-# Reverse-over-forward
-def hvp_revfwd(f, primals, tangents):
- g = lambda primals: jvp(f, primals, tangents)[1]
- return grad(g)(primals)
-```
-
-That's not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:
-
-```{code-cell}
-# Reverse-over-reverse, only works for single arguments
-def hvp_revrev(f, primals, tangents):
- x, = primals
- v, = tangents
- return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
-
-
-print("Forward over reverse")
-%timeit -n10 -r3 hvp(f, (X,), (V,))
-print("Reverse over forward")
-%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
-print("Reverse over reverse")
-%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
-
-print("Naive full Hessian materialization")
-%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
-```
-
-## Composing VJPs, JVPs, and `jax.vmap`
-
-### Jacobian-Matrix and Matrix-Jacobian products
-
-Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:
-
-```{code-cell}
-# Isolate the function from the weight matrix to the predictions
-f = lambda W: predict(W, b, inputs)
-
-# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
-# First, use a list comprehension to loop over rows in the matrix M.
-def loop_mjp(f, x, M):
- y, vjp_fun = vjp(f, x)
- return jnp.vstack([vjp_fun(mi) for mi in M])
-
-# Now, use vmap to build a computation that does a single fast matrix-matrix
-# multiply, rather than an outer loop over vector-matrix multiplies.
-def vmap_mjp(f, x, M):
- y, vjp_fun = vjp(f, x)
- outs, = vmap(vjp_fun)(M)
- return outs
-
-key = random.key(0)
-num_covecs = 128
-U = random.normal(key, (num_covecs,) + y.shape)
-
-loop_vs = loop_mjp(f, W, M=U)
-print('Non-vmapped Matrix-Jacobian product')
-%timeit -n10 -r3 loop_mjp(f, W, M=U)
-
-print('\nVmapped Matrix-Jacobian product')
-vmap_vs = vmap_mjp(f, W, M=U)
-%timeit -n10 -r3 vmap_mjp(f, W, M=U)
-
-assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
-```
-
-```{code-cell}
-def loop_jmp(f, W, M):
- # jvp immediately returns the primal and tangent values as a tuple,
- # so we'll compute and select the tangents in a list comprehension
- return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
-
-def vmap_jmp(f, W, M):
- _jvp = lambda s: jvp(f, (W,), (s,))[1]
- return vmap(_jvp)(M)
-
-num_vecs = 128
-S = random.normal(key, (num_vecs,) + W.shape)
-
-loop_vs = loop_jmp(f, W, M=S)
-print('Non-vmapped Jacobian-Matrix product')
-%timeit -n10 -r3 loop_jmp(f, W, M=S)
-vmap_vs = vmap_jmp(f, W, M=S)
-print('\nVmapped Jacobian-Matrix product')
-%timeit -n10 -r3 vmap_jmp(f, W, M=S)
-
-assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
-```
-
-### The implementation of `jax.jacfwd` and `jax.jacrev`
-
-Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write {func}`jax.jacfwd` and {func}`jax.jacrev`. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once.
-
-```{code-cell}
-from jax import jacrev as builtin_jacrev
-
-def our_jacrev(f):
- def jacfun(x):
- y, vjp_fun = vjp(f, x)
- # Use vmap to do a matrix-Jacobian product.
- # Here, the matrix is the Euclidean basis, so we get all
- # entries in the Jacobian at once.
- J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
- return J
- return jacfun
-
-assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
-```
-
-```{code-cell}
-from jax import jacfwd as builtin_jacfwd
-
-def our_jacfwd(f):
- def jacfun(x):
- _jvp = lambda s: jvp(f, (x,), (s,))[1]
- Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
- return jnp.transpose(Jt)
- return jacfun
-
-assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
-```
-
-Interestingly, the [Autograd](https://github.com/hips/autograd) library couldn't do this. The [implementation](https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) of reverse-mode `jacobian` in Autograd had to pull back one vector at a time with an outer-loop `map`. Pushing one vector at a time through the computation is much less efficient than batching it all together with {func}`jax.vmap`.
-
-Another thing that Autograd couldn't do is {func}`jax.jit`. Interestingly, no matter how much Python dynamism you use in your function to be differentiated, we could always use {func}`jax.jit` on the linear part of the computation. For example:
-
-```{code-cell}
-def f(x):
- try:
- if x < 3:
- return 2 * x ** 3
- else:
- raise ValueError
- except ValueError:
- return jnp.pi * x
-
-y, f_vjp = vjp(f, 4.)
-print(jit(f_vjp)(1.))
-```
-
-## Complex numbers and differentiation
-
-JAX is great at complex numbers and differentiation. To support both [holomorphic and non-holomorphic differentiation](https://en.wikipedia.org/wiki/Holomorphic_function), it helps to think in terms of JVPs and VJPs.
-
-Consider a complex-to-complex function $f: \mathbb{C} \to \mathbb{C}$ and identify it with a corresponding function $g: \mathbb{R}^2 \to \mathbb{R}^2$,
-
-```{code-cell}
-def f(z):
- x, y = jnp.real(z), jnp.imag(z)
- return u(x, y) + v(x, y) * 1j
-
-def g(x, y):
- return (u(x, y), v(x, y))
-```
-
-That is, we've decomposed $f(z) = u(x, y) + v(x, y) i$ where $z = x + y i$, and identified $\mathbb{C}$ with $\mathbb{R}^2$ to get $g$.
-
-Since $g$ only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector $(c, d) \in \mathbb{R}^2$, namely:
-
-$\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}
-\begin{bmatrix} c \\ d \end{bmatrix}$.
-
-To get a JVP for the original function $f$ applied to a tangent vector $c + di \in \mathbb{C}$, we just use the same definition and identify the result as another complex number,
-
-$\partial f(x + y i)(c + d i) =
-\begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix}
-\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}
-\begin{bmatrix} c \\ d \end{bmatrix}$.
-
-That's our definition of the JVP of a $\mathbb{C} \to \mathbb{C}$ function! Notice it doesn't matter whether or not $f$ is holomorphic: the JVP is unambiguous.
-
-Here's a check:
-
-```{code-cell}
-def check(seed):
- key = random.key(seed)
-
- # random coeffs for u and v
- key, subkey = random.split(key)
- a, b, c, d = random.uniform(subkey, (4,))
-
- def fun(z):
- x, y = jnp.real(z), jnp.imag(z)
- return u(x, y) + v(x, y) * 1j
-
- def u(x, y):
- return a * x + b * y
-
- def v(x, y):
- return c * x + d * y
-
- # primal point
- key, subkey = random.split(key)
- x, y = random.uniform(subkey, (2,))
- z = x + y * 1j
-
- # tangent vector
- key, subkey = random.split(key)
- c, d = random.uniform(subkey, (2,))
- z_dot = c + d * 1j
-
- # check jvp
- _, ans = jvp(fun, (z,), (z_dot,))
- expected = (grad(u, 0)(x, y) * c +
- grad(u, 1)(x, y) * d +
- grad(v, 0)(x, y) * c * 1j+
- grad(v, 1)(x, y) * d * 1j)
- print(jnp.allclose(ans, expected))
-```
-
-```{code-cell}
-check(0)
-check(1)
-check(2)
-```
-
-What about VJPs? We do something pretty similar: for a cotangent vector $c + di \in \mathbb{C}$ we define the VJP of $f$ as
-
-$(c + di)^* \; \partial f(x + y i) =
-\begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix}
-\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}
-\begin{bmatrix} 1 \\ -i \end{bmatrix}$.
-
-What's with the negatives? They're just to take care of complex conjugation, and the fact that we're working with covectors.
-
-Here's a check of the VJP rules:
-
-```{code-cell}
-def check(seed):
- key = random.key(seed)
-
- # random coeffs for u and v
- key, subkey = random.split(key)
- a, b, c, d = random.uniform(subkey, (4,))
-
- def fun(z):
- x, y = jnp.real(z), jnp.imag(z)
- return u(x, y) + v(x, y) * 1j
-
- def u(x, y):
- return a * x + b * y
-
- def v(x, y):
- return c * x + d * y
-
- # primal point
- key, subkey = random.split(key)
- x, y = random.uniform(subkey, (2,))
- z = x + y * 1j
-
- # cotangent vector
- key, subkey = random.split(key)
- c, d = random.uniform(subkey, (2,))
- z_bar = jnp.array(c + d * 1j) # for dtype control
-
- # check vjp
- _, fun_vjp = vjp(fun, z)
- ans, = fun_vjp(z_bar)
- expected = (grad(u, 0)(x, y) * c +
- grad(v, 0)(x, y) * (-d) +
- grad(u, 1)(x, y) * c * (-1j) +
- grad(v, 1)(x, y) * (-d) * (-1j))
- assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
-```
-
-```{code-cell}
-check(0)
-check(1)
-check(2)
-```
-
-What about convenience wrappers like {func}`jax.grad`, {func}`jax.jacfwd`, and {func}`jax.jacrev`?
-
-For $\mathbb{R} \to \mathbb{R}$ functions, recall we defined `grad(f)(x)` as being `vjp(f, x)[1](1.0)`, which works because applying a VJP to a `1.0` value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for $\mathbb{C} \to \mathbb{R}$ functions: we can still use `1.0` as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian:
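-
-As a quick sanity check of that definition, here is a small sketch with a made-up scalar function:
-
-```python
-import jax
-
-f = lambda x: x ** 3
-x = 2.0
-print(jax.grad(f)(x))            # 12.0
-print(jax.vjp(f, x)[1](1.0)[0])  # 12.0, the same value
-```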
-
-```{code-cell}
-def f(z):
- x, y = jnp.real(z), jnp.imag(z)
- return x**2 + y**2
-
-z = 3. + 4j
-grad(f)(z)
-```
-
-For general $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.
-
-Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when {func}`jax.grad` is used for a complex-output function:
-
-```{code-cell}
-def f(z):
- return jnp.sin(z)
-
-z = 3. + 4j
-grad(f, holomorphic=True)(z)
-```
-
-All the `holomorphic=True` promise does is disable the error when the output is complex-valued. We can still write `holomorphic=True` when the function isn't holomorphic, but the answer we get out won't represent the full Jacobian. Instead, it'll be the Jacobian of the function where we just discard the imaginary part of the output:
-
-```{code-cell}
-def f(z):
- return jnp.conjugate(z)
-
-z = 3. + 4j
-grad(f, holomorphic=True)(z) # f is not actually holomorphic!
-```
-
-There are some useful upshots for how {func}`jax.grad` works here:
-
-1. We can use {func}`jax.grad` on holomorphic $\mathbb{C} \to \mathbb{C}$ functions.
-2. We can use {func}`jax.grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)` (a small sketch follows this list).
-3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then {func}`jax.grad` still works and we get the same result that an implementation using only real values would have given.
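-
-To illustrate the second point, here is a small sketch (the target value and step size are made up) that minimizes a real-valued loss of a complex parameter by stepping along the conjugate of the gradient:
-
-```python
-import jax
-import jax.numpy as jnp
-
-target = 1.0 + 2.0j
-
-def loss(z):  # C -> R: a real-valued loss of a complex parameter
-  return jnp.abs(z - target) ** 2
-
-z = jnp.array(0.0 + 0.0j)
-for _ in range(100):
-  g = jax.grad(loss)(z)      # complex number packaging the partials in x and y
-  z = z - 0.1 * jnp.conj(g)  # step along the conjugate direction
-
-print(z)  # approaches 1 + 2j
-```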
-
-In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\mathbb{C} \to \mathbb{C}$ function, we can do it with JVPs or VJPs!
-
-
-You should expect complex numbers to work everywhere in JAX. Here's differentiating through a Cholesky decomposition of a complex matrix:
-
-```{code-cell}
-A = jnp.array([[5., 2.+3j, 5j],
- [2.-3j, 7., 1.+7j],
- [-5j, 1.-7j, 12.]])
-
-def f(X):
- L = jnp.linalg.cholesky(X)
- return jnp.sum((L - jnp.sin(L))**2)
-
-grad(f, holomorphic=True)(A)
-```
-
-(advanced-autodiff-custom-derivative-rules)=
-## Custom derivative rules for JAX-transformable Python functions
-
-There are two ways to define differentiation rules in JAX:
-
-1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and
-2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
-
-This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).
-
-
-### TL;DR: Custom JVPs with {func}`jax.custom_jvp`
-
-```{code-cell}
-from jax import custom_jvp
-
-@custom_jvp
-def f(x, y):
- return jnp.sin(x) * y
-
-@f.defjvp
-def f_jvp(primals, tangents):
- x, y = primals
- x_dot, y_dot = tangents
- primal_out = f(x, y)
- tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
- return primal_out, tangent_out
-```
-
-```{code-cell}
-print(f(2., 3.))
-y, y_dot = jvp(f, (2., 3.), (1., 0.))
-print(y)
-print(y_dot)
-print(grad(f)(2., 3.))
-```
-
-```{code-cell}
-# Equivalent alternative using the `defjvps` convenience wrapper
-
-@custom_jvp
-def f(x, y):
- return jnp.sin(x) * y
-
-f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
- lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
-```
-
-```{code-cell}
-print(f(2., 3.))
-y, y_dot = jvp(f, (2., 3.), (1., 0.))
-print(y)
-print(y_dot)
-print(grad(f)(2., 3.))
-```
-
-### TL;DR: Custom VJPs with `jax.custom_vjp`
-
-```{code-cell}
-from jax import custom_vjp
-
-@custom_vjp
-def f(x, y):
- return jnp.sin(x) * y
-
-def f_fwd(x, y):
-  # Returns primal output and residuals to be used in backward pass by `f_bwd`.
- return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-def f_bwd(res, g):
- cos_x, sin_x, y = res # Gets residuals computed in `f_fwd`
- return (cos_x * g * y, sin_x * g)
-
-f.defvjp(f_fwd, f_bwd)
-```
-
-```{code-cell}
-print(grad(f)(2., 3.))
-```
-
-### Example problems
-
-To get an idea of what problems {func}`jax.custom_jvp` and {func}`jax.custom_vjp` are meant to solve, let's go over a few examples. A more thorough introduction to the {func}`jax.custom_jvp` and {func}`jax.custom_vjp` APIs is in the next section.
-
-
-#### Example: Numerical stability
-
-One application of {func}`jax.custom_jvp` is to improve the numerical stability of differentiation.
-
-Say we want to write a function called `log1pexp`, which computes $x \mapsto \log ( 1 + e^x )$. We can write that using `jax.numpy`:
-
-```{code-cell}
-def log1pexp(x):
- return jnp.log(1. + jnp.exp(x))
-
-log1pexp(3.)
-```
-
-Since it's written in terms of `jax.numpy`, it's JAX-transformable:
-
-```{code-cell}
-print(jit(log1pexp)(3.))
-print(jit(grad(log1pexp))(3.))
-print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
-```
-
-But there's a numerical stability problem lurking here:
-
-```{code-cell}
-print(grad(log1pexp)(100.))
-```
-
-That doesn't seem right! After all, the derivative of $x \mapsto \log (1 + e^x)$ is $x \mapsto \frac{e^x}{1 + e^x}$, and so for large values of $x$ we'd expect the value to be about 1.
-
-We can get a bit more insight into what's going on by looking at the jaxpr for the gradient computation:
-
-```{code-cell}
-from jax import make_jaxpr
-
-make_jaxpr(grad(log1pexp))(100.)
-```
-
-Stepping through how the jaxpr would be evaluated, notice that the last line would involve multiplying values that floating point math will round to 0 and $\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)` for large `x`, which turns into `0. * jnp.inf`.
-
-Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \frac{1}{1 + e^x}$, with no cancellation in sight.
-
-This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with {func}`jax.jit`, {func}`jax.vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.
-
-This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like {func}`jax.jit`, {func}`jax.vmap`, ...).
-
-Here's a solution using {func}`jax.custom_jvp`:
-
-```{code-cell}
-@custom_jvp
-def log1pexp(x):
- return jnp.log(1. + jnp.exp(x))
-
-@log1pexp.defjvp
-def log1pexp_jvp(primals, tangents):
- x, = primals
- x_dot, = tangents
- ans = log1pexp(x)
- ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
- return ans, ans_dot
-```
-
-```{code-cell}
-print(grad(log1pexp)(100.))
-```
-
-```{code-cell}
-print(jit(log1pexp)(3.))
-print(jit(grad(log1pexp))(3.))
-print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
-```
-
-Here's a `defjvps` convenience wrapper to express the same thing:
-
-```{code-cell}
-@custom_jvp
-def log1pexp(x):
- return jnp.log(1. + jnp.exp(x))
-
-log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
-```
-
-```{code-cell}
-print(grad(log1pexp)(100.))
-print(jit(log1pexp)(3.))
-print(jit(grad(log1pexp))(3.))
-print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
-```
-
-#### Example: Enforcing a differentiation convention
-
-A related application is to enforce a differentiation convention, perhaps at a boundary.
-
-Consider the function $f : \mathbb{R}_+ \to \mathbb{R}_+$ with $f(x) = \frac{x}{1 + \sqrt{x}}$, where we take $\mathbb{R}_+ = [0, \infty)$. We might implement $f$ as a program like this:
-
-```{code-cell}
-def f(x):
- return x / (1 + jnp.sqrt(x))
-```
-
-As a mathematical function on $\mathbb{R}$ (the full real line), $f$ is not differentiable at zero (because the limit defining the derivative doesn't exist from the left). Correspondingly, autodiff produces a `nan` value:
-
-```{code-cell}
-print(grad(f)(0.))
-```
-
-But mathematically if we think of $f$ as a function on $\mathbb{R}_+$ then it is differentiable at 0 [Rudin's Principles of Mathematical Analysis Definition 5.1, or Tao's Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function `grad(f)` to return at `0.0`, namely `1.0`. By default, JAX's machinery for differentiation assumes all functions are defined over $\mathbb{R}$ and thus doesn't produce `1.0` here.
-
-We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function $x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}$ on $\mathbb{R}_+$,
-
-```{code-cell}
-@custom_jvp
-def f(x):
- return x / (1 + jnp.sqrt(x))
-
-@f.defjvp
-def f_jvp(primals, tangents):
- x, = primals
- x_dot, = tangents
- ans = f(x)
- ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
- return ans, ans_dot
-```
-
-```{code-cell}
-print(grad(f)(0.))
-```
-
-Here's the convenience wrapper version:
-
-```{code-cell}
-@custom_jvp
-def f(x):
- return x / (1 + jnp.sqrt(x))
-
-f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
-```
-
-```{code-cell}
-print(grad(f)(0.))
-```
-
-#### Example: Gradient clipping
-
-While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.
-
-For gradient clipping, we can use {func}`jnp.clip` together with a {func}`jax.custom_vjp` reverse-mode-only rule:
-
-```{code-cell}
-from functools import partial
-
-@custom_vjp
-def clip_gradient(lo, hi, x):
- return x # identity function
-
-def clip_gradient_fwd(lo, hi, x):
- return x, (lo, hi) # save bounds as residuals
-
-def clip_gradient_bwd(res, g):
- lo, hi = res
- return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi
-
-clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
-```
-
-```{code-cell}
-import matplotlib.pyplot as plt
-
-t = jnp.linspace(0, 10, 1000)
-
-plt.plot(jnp.sin(t))
-plt.plot(vmap(grad(jnp.sin))(t))
-```
-
-```{code-cell}
-def clip_sin(x):
- x = clip_gradient(-0.75, 0.75, x)
- return jnp.sin(x)
-
-plt.plot(clip_sin(t))
-plt.plot(vmap(grad(clip_sin))(t))
-```
-
-#### Example: Python debugging
-
-Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff.
-
-When trying to track down the source of a `nan` runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with {func}`jax.custom_vjp`.
-
-We'll defer an example until the next section.
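-
-In the meantime, here is a minimal sketch of the shape such a rule takes; the `debug_identity` name is made up, and {func}`jax.debug.print` stands in for a `pdb.set_trace()` call:
-
-```python
-import jax
-import jax.numpy as jnp
-
-@jax.custom_vjp
-def debug_identity(x):
-  return x  # identity on the forward pass
-
-def debug_fwd(x):
-  return x, None  # no residuals needed
-
-def debug_bwd(_, g):
-  # Inspect (or break on) the cotangent flowing back through this point.
-  jax.debug.print("cotangent: {}", g)
-  return (g,)
-
-debug_identity.defvjp(debug_fwd, debug_bwd)
-
-print(jax.grad(lambda x: jnp.sin(debug_identity(x)))(1.0))
-```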
-
-
-
-#### Example: Implicit function differentiation of iterative implementations
-
-This example gets pretty deep in the mathematical weeds!
-
-Another application for {func}`jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by {func}`jax.jit`, {func}`jax.vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve {func}`jax.lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without "side-effecting" interactions through infeed/outfeed.)
-
-For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`:
-
-```{code-cell}
-from jax.lax import while_loop
-
-def fixed_point(f, a, x_guess):
- def cond_fun(carry):
- x_prev, x = carry
- return jnp.abs(x_prev - x) > 1e-6
-
- def body_fun(carry):
- _, x = carry
- return x, f(a, x)
-
- _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
- return x_star
-```
-
-This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \mapsto x^*(a)$ that is implicitly defined by equation $x = f(a, x)$.
-
-We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides:
-
-```{code-cell}
-def newton_sqrt(a):
- update = lambda a, x: 0.5 * (x + a / x)
- return fixed_point(update, a, a)
-```
-
-```{code-cell}
-print(newton_sqrt(2.))
-```
-
-We can {func}`jax.vmap` or {func}`jax.jit` the function as well:
-
-```{code-cell}
-print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
-```
-
-We can't apply reverse-mode automatic differentiation because of the `while_loop`, but it turns out we wouldn't want to anyway: instead of differentiating through the implementation of `fixed_point` and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas's Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we're about to use. In essence, we linearize the solution and solve those linear equations iteratively to compute the derivatives we want.
-
-Consider again the equation $x = f(a, x)$ and the function $x^*$. We want to evaluate vector-Jacobian products like $v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)$.
-
-At least in an open neighborhood around the point $a_0$ at which we want to differentiate, let's assume that the equation $x^*(a) = f(a, x^*(a))$ holds for all $a$. Since the two sides are equal as functions of $a$, their derivatives must be equal as well, so let's differentiate both sides:
-
-$\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)$.
-
-Setting $A = \partial_1 f(a_0, x^*(a_0))$ and $B = \partial_0 f(a_0, x^*(a_0))$, we can write the quantity we're after more simply as:
-
-$\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)$,
-
-or, by rearranging,
-
-$\qquad \partial x^*(a_0) = (I - A)^{-1} B$.
-
-That means we can evaluate vector-Jacobian products, such as:
-
-$\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B$,
-
-where $w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}$, or equivalently $w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A$, or equivalently $w^\mathsf{T}$ is the fixed point of the map $u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A$. That last characterization gives us a way to write the VJP for `fixed_point` in terms of a call to `fixed_point`! Moreover, after expanding $A$ and $B$ back out, you can conclude you need only to evaluate VJPs of $f$ at $(a_0, x^*(a_0))$.
-
-Here's the upshot:
-
-```{code-cell}
-@partial(custom_vjp, nondiff_argnums=(0,))
-def fixed_point(f, a, x_guess):
- def cond_fun(carry):
- x_prev, x = carry
- return jnp.abs(x_prev - x) > 1e-6
-
- def body_fun(carry):
- _, x = carry
- return x, f(a, x)
-
- _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
- return x_star
-
-def fixed_point_fwd(f, a, x_init):
- x_star = fixed_point(f, a, x_init)
- return x_star, (a, x_star)
-
-def fixed_point_rev(f, res, x_star_bar):
- a, x_star = res
- _, vjp_a = vjp(lambda a: f(a, x_star), a)
- a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
- (a, x_star, x_star_bar),
- x_star_bar))
- return a_bar, jnp.zeros_like(x_star)
-
-def rev_iter(f, packed, u):
- a, x_star, x_star_bar = packed
- _, vjp_x = vjp(lambda x: f(a, x), x_star)
- return x_star_bar + vjp_x(u)[0]
-
-fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
-```
-
-```{code-cell}
-print(newton_sqrt(2.))
-```
-
-```{code-cell}
-print(grad(newton_sqrt)(2.))
-print(grad(grad(newton_sqrt))(2.))
-```
-
-We can check our answers by differentiating {func}`jnp.sqrt`, which uses a totally different implementation:
-
-```{code-cell}
-print(grad(jnp.sqrt)(2.))
-print(grad(grad(jnp.sqrt))(2.))
-```
-
-A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions.
-
-
-### Basic usage of `jax.custom_jvp` and `jax.custom_vjp` APIs
-
-#### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules
-
-Here's a canonical basic example of using {func}`jax.custom_jvp`, where the comments use
-[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):
-
-```{code-cell}
-# f :: a -> b
-@custom_jvp
-def f(x):
- return jnp.sin(x)
-
-# f_jvp :: (a, T a) -> (b, T b)
-def f_jvp(primals, tangents):
- x, = primals
- t, = tangents
- return f(x), jnp.cos(x) * t
-
-f.defjvp(f_jvp)
-```
-
-```{code-cell}
-print(f(3.))
-
-y, y_dot = jvp(f, (3.,), (1.,))
-print(y)
-print(y_dot)
-```
-
-In other words, we start with a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it a JVP rule function `f_jvp` that takes a pair of inputs representing the primal inputs of type `a` and the corresponding tangent inputs of type `T a`, and produces a pair of outputs representing the primal outputs of type `b` and tangent outputs of type `T b`. The tangent outputs should be a linear function of the tangent inputs.
-
-You can also use `f.defjvp` as a decorator, as in
-
-```python
-@custom_jvp
-def f(x):
- ...
-
-@f.defjvp
-def f_jvp(primals, tangents):
- ...
-```
-
-Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on `f`. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand:
-
-```{code-cell}
-print(grad(f)(3.))
-print(grad(grad(f))(3.))
-```
-
-For automatic transposition to work, the JVP rule's output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised.
-
-Multiple arguments work like this:
-
-```{code-cell}
-@custom_jvp
-def f(x, y):
- return x ** 2 * y
-
-@f.defjvp
-def f_jvp(primals, tangents):
- x, y = primals
- x_dot, y_dot = tangents
- primal_out = f(x, y)
- tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
- return primal_out, tangent_out
-```
-
-```{code-cell}
-print(grad(f)(2., 3.))
-```
-
-The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:
-
-```{code-cell}
-@custom_jvp
-def f(x):
- return jnp.sin(x)
-
-f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
-```
-
-```{code-cell}
-print(grad(f)(3.))
-```
-
-Here's a `defjvps` example with multiple arguments:
-
-```{code-cell}
-@custom_jvp
-def f(x, y):
- return x ** 2 * y
-
-f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
- lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
-```
-
-```{code-cell}
-print(grad(f)(2., 3.))
-print(grad(f, 0)(2., 3.)) # same as above
-print(grad(f, 1)(2., 3.))
-```
-
-As a shorthand, with `defjvps` you can pass a `None` value to indicate that the JVP for a particular argument is zero:
-
-```{code-cell}
-@custom_jvp
-def f(x, y):
- return x ** 2 * y
-
-f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
- None)
-```
-
-```{code-cell}
-print(grad(f)(2., 3.))
-print(grad(f, 0)(2., 3.)) # same as above
-print(grad(f, 1)(2., 3.))
-```
-
-Calling a {func}`jax.custom_jvp` function with keyword arguments, or writing a {func}`jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.
-
-When you're not performing differentiation, the function `f` is called just as if it weren't decorated by {func}`jax.custom_jvp`:
-
-```{code-cell}
-@custom_jvp
-def f(x):
- print('called f!') # a harmless side-effect
- return jnp.sin(x)
-
-@f.defjvp
-def f_jvp(primals, tangents):
- print('called f_jvp!') # a harmless side-effect
- x, = primals
- t, = tangents
- return f(x), jnp.cos(x) * t
-```
-
-```{code-cell}
-print(f(3.))
-```
-
-```{code-cell}
-print(vmap(f)(jnp.arange(3.)))
-print(jit(f)(3.))
-```
-
-The custom JVP rule is invoked during differentiation, whether forward or reverse:
-
-```{code-cell}
-y, y_dot = jvp(f, (3.,), (1.,))
-print(y_dot)
-```
-
-```{code-cell}
-print(grad(f)(3.))
-```
-
-Notice that `f_jvp` calls `f` to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original `f` to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of `f` in our rule _and also_ have the rule apply in all orders of higher-order differentiation.)
-
-```{code-cell}
-grad(grad(f))(3.)
-```
-
-You can use Python control flow with {func}`jax.custom_jvp`:
-
-```{code-cell}
-@custom_jvp
-def f(x):
- if x > 0:
- return jnp.sin(x)
- else:
- return jnp.cos(x)
-
-@f.defjvp
-def f_jvp(primals, tangents):
- x, = primals
- x_dot, = tangents
- ans = f(x)
- if x > 0:
- return ans, 2 * x_dot
- else:
- return ans, 3 * x_dot
-```
-
-```{code-cell}
-print(grad(f)(1.))
-print(grad(f)(-1.))
-```
-
-#### Use `jax.custom_vjp` to define custom reverse-mode-only rules
-
-While {func}`jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with {func}`jax.custom_vjp`:
-
-```{code-cell}
-from jax import custom_vjp
-
-# f :: a -> b
-@custom_vjp
-def f(x):
- return jnp.sin(x)
-
-# f_fwd :: a -> (b, c)
-def f_fwd(x):
- return f(x), jnp.cos(x)
-
-# f_bwd :: (c, CT b) -> CT a
-def f_bwd(cos_x, y_bar):
- return (cos_x * y_bar,)
-
-f.defvjp(f_fwd, f_bwd)
-```
-
-```{code-cell}
-print(f(3.))
-print(grad(f)(3.))
-```
-
-In other words, we again start with a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it two functions, `f_fwd` and `f_bwd`, which describe how to perform the forward- and backward-passes of reverse-mode autodiff, respectively.
-
-The function `f_fwd` describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function `f`, in that it takes a primal input of type `a`. But as output it produces a pair, where the first element is the primal output `b` and the second element is any "residual" data of type `c` to be stored for use by the backward pass. (This second output is analogous to [PyTorch's save_for_backward mechanism](https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html).)
-
-The function `f_bwd` describes the backward pass. It takes two inputs, where the first is the residual data of type `c` produced by `f_fwd` and the second is the output cotangents of type `CT b` corresponding to the output of the primal function. It produces an output of type `CT a` representing the cotangents corresponding to the input of the primal function. In particular, the output of `f_bwd` must be a sequence (e.g. a tuple) of length equal to the number of arguments to the primal function.
-
-So multiple arguments work like this:
-
-```{code-cell}
-@custom_vjp
-def f(x, y):
- return jnp.sin(x) * y
-
-def f_fwd(x, y):
- return f(x, y), (jnp.cos(x), jnp.sin(x), y)
-
-def f_bwd(res, g):
- cos_x, sin_x, y = res
- return (cos_x * g * y, sin_x * g)
-
-f.defvjp(f_fwd, f_bwd)
-```
-
-```{code-cell}
-print(grad(f)(2., 3.))
-```
-
-Calling a {func}`jax.custom_vjp` function with keyword arguments, or writing a {func}`jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.
-
-As with {func}`jax.custom_jvp`, the custom VJP rule composed of `f_fwd` and `f_bwd` is not invoked if differentiation is not applied. If the function is evaluated, or transformed with {func}`jax.jit`, {func}`jax.vmap`, or other non-differentiation transformations, then only `f` is called.
-
-```{code-cell}
-@custom_vjp
-def f(x):
- print("called f!")
- return jnp.sin(x)
-
-def f_fwd(x):
- print("called f_fwd!")
- return f(x), jnp.cos(x)
-
-def f_bwd(cos_x, y_bar):
- print("called f_bwd!")
- return (cos_x * y_bar,)
-
-f.defvjp(f_fwd, f_bwd)
-```
-
-```{code-cell}
-print(f(3.))
-```
-
-```{code-cell}
-print(grad(f)(3.))
-```
-
-```{code-cell}
-y, f_vjp = vjp(f, 3.)
-print(y)
-```
-
-```{code-cell}
-print(f_vjp(1.))
-```
-
-**Forward-mode autodiff cannot be used on the** {func}`jax.custom_vjp` **function** and will raise an error:
-
-```{code-cell}
-:tags: [raises-exception]
-
-from jax import jvp
-
-try:
- jvp(f, (3.,), (1.,))
-except TypeError as e:
- print('ERROR! {}'.format(e))
-```
-
-If you want to use both forward- and reverse-mode, use {func}`jax.custom_jvp` instead.
-
-We can use {func}`jax.custom_vjp` together with `pdb` to insert a debugger trace in the backward pass:
-
-```{code-cell}
-import pdb
-
-@custom_vjp
-def debug(x):
- return x # acts like identity
-
-def debug_fwd(x):
- return x, x
-
-def debug_bwd(x, g):
- import pdb; pdb.set_trace()
- return g
-
-debug.defvjp(debug_fwd, debug_bwd)
-```
-
-```{code-cell}
-def foo(x):
- y = x ** 2
- y = debug(y) # insert pdb in corresponding backward pass step
- return jnp.sin(y)
-```
-
-```python
-jax.grad(foo)(3.)
-
-> (12)debug_bwd()
--> return g
-(Pdb) p x
-Array(9., dtype=float32)
-(Pdb) p g
-Array(-0.91113025, dtype=float32)
-(Pdb) q
-```
-
-
-### More features and details
-
-#### Working with `list` / `tuple` / `dict` containers (and other pytrees)
-
-You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints.
-
-Here's a contrived example with {func}`jax.custom_jvp`:
-
-```{code-cell}
-from collections import namedtuple
-Point = namedtuple("Point", ["x", "y"])
-
-@custom_jvp
-def f(pt):
- x, y = pt.x, pt.y
- return {'a': x ** 2,
- 'b': (jnp.sin(x), jnp.cos(y))}
-
-@f.defjvp
-def f_jvp(primals, tangents):
- pt, = primals
- pt_dot, = tangents
- ans = f(pt)
- ans_dot = {'a': 2 * pt.x * pt_dot.x,
- 'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
- return ans, ans_dot
-
-def fun(pt):
- dct = f(pt)
- return dct['a'] + dct['b'][0]
-```
-
-```{code-cell}
-pt = Point(1., 2.)
-
-print(f(pt))
-```
-
-```{code-cell}
-print(grad(fun)(pt))
-```
-
-And an analogous contrived example with {func}`jax.custom_vjp`:
-
-```{code-cell}
-@custom_vjp
-def f(pt):
- x, y = pt.x, pt.y
- return {'a': x ** 2,
- 'b': (jnp.sin(x), jnp.cos(y))}
-
-def f_fwd(pt):
- return f(pt), pt
-
-def f_bwd(pt, g):
- a_bar, (b0_bar, b1_bar) = g['a'], g['b']
- x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
- y_bar = -jnp.sin(pt.y) * b1_bar
- return (Point(x_bar, y_bar),)
-
-f.defvjp(f_fwd, f_bwd)
-
-def fun(pt):
- dct = f(pt)
- return dct['a'] + dct['b'][0]
-```
-
-```{code-cell}
-pt = Point(1., 2.)
-
-print(f(pt))
-```
-
-```{code-cell}
-print(grad(fun)(pt))
-```
-
-#### Handling non-differentiable arguments
-
-Some use cases, like the final example problem, call for non-differentiable arguments like function-valued arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of `fixed_point`, the function argument `f` was such a non-differentiable argument. A similar situation arises with `jax.experimental.odeint`.
-
-##### `jax.custom_jvp` with `nondiff_argnums`
-
-Use the optional `nondiff_argnums` parameter to {func}`jax.custom_jvp` to indicate arguments like these. Here's an example with {func}`jax.custom_jvp`:
-
-```{code-cell}
-from functools import partial
-
-@partial(custom_jvp, nondiff_argnums=(0,))
-def app(f, x):
- return f(x)
-
-@app.defjvp
-def app_jvp(f, primals, tangents):
- x, = primals
- x_dot, = tangents
- return f(x), 2. * x_dot
-```
-
-```{code-cell}
-print(app(lambda x: x ** 3, 3.))
-```
-
-```{code-cell}
-print(grad(app, 1)(lambda x: x ** 3, 3.))
-```
-
-Notice the gotcha here: no matter where in the argument list these parameters appear, they're placed at the *start* of the signature of the corresponding JVP rule. Here's another example:
-
-```{code-cell}
-@partial(custom_jvp, nondiff_argnums=(0, 2))
-def app2(f, x, g):
- return f(g((x)))
-
-@app2.defjvp
-def app2_jvp(f, g, primals, tangents):
- x, = primals
- x_dot, = tangents
- return f(g(x)), 3. * x_dot
-```
-
-```{code-cell}
-print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
-```
-
-```{code-cell}
-print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
-```
-
-##### `jax.custom_vjp` with `nondiff_argnums`
-
-A similar option exists for {func}`jax.custom_vjp`, and, similarly, the convention is that the non-differentiable arguments are passed as the first arguments to the `_bwd` rule, no matter where they appear in the signature of the original function. The signature of the `_fwd` rule remains unchanged - it is the same as the signature of the primal function. Here's an example:
-
-```{code-cell}
-@partial(custom_vjp, nondiff_argnums=(0,))
-def app(f, x):
- return f(x)
-
-def app_fwd(f, x):
- return f(x), x
-
-def app_bwd(f, x, g):
- return (5 * g,)
-
-app.defvjp(app_fwd, app_bwd)
-```
-
-```{code-cell}
-print(app(lambda x: x ** 2, 4.))
-```
-
-```{code-cell}
-print(grad(app, 1)(lambda x: x ** 2, 4.))
-```
-
-Refer to `fixed_point` above for another usage example.
-
-**You don't need to use** `nondiff_argnums` **with array-valued arguments**, such as, for example, ones with the integer dtype. Instead, `nondiff_argnums` should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by `nondiff_argnums` contains a JAX Tracer, then an error is raised. The `clip_gradient` function above is a good example of not using `nondiff_argnums` for integer-dtype array arguments.
-
-## Next steps
-
-There's a whole world of other autodiff tricks and functionality out there. Topics that weren't covered in this tutorial but can be worth pursuing include:
-
- - Gauss-Newton Vector Products, linearizing once
- - Custom VJPs and JVPs
- - Efficient derivatives at fixed-points
- - Estimating the trace of a Hessian using random Hessian-vector products
- - Forward-mode autodiff using only reverse-mode autodiff
- - Taking derivatives with respect to custom data types
- - Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting)
- - Optimizing VJPs with Jacobian pre-accumulation
diff --git a/docs/advanced_autodiff.md b/docs/advanced_autodiff.md
new file mode 100644
index 000000000000..43b62b0e0d5c
--- /dev/null
+++ b/docs/advanced_autodiff.md
@@ -0,0 +1,11 @@
+# Advanced Automatic Differentiation
+
+```{toctree}
+:caption: Advanced automatic differentiation
+:maxdepth: 1
+
+higher-order
+jacobian-vector-products
+complex-differentiation
+notebooks/Custom_derivative_rules_for_Python_code
+```
diff --git a/docs/advanced_guides.rst b/docs/advanced_guides.rst
index 7a9d3d95a58d..4a7624e08262 100644
--- a/docs/advanced_guides.rst
+++ b/docs/advanced_guides.rst
@@ -17,6 +17,7 @@ operations.
notebooks/layout
notebooks/host-offloading
multi_process
+ fault_tolerance
distributed_data_loading
notebooks/colocated-python
@@ -31,9 +32,8 @@ operations.
:maxdepth: 1
notebooks/autodiff_cookbook
- notebooks/Custom_derivative_rules_for_Python_code
notebooks/autodiff_remat
- advanced-autodiff
+ advanced_autodiff
.. toctree::
:maxdepth: 1
@@ -42,7 +42,6 @@ operations.
errors
debugging
debugging/index
- debugging/flags
transfer_guard
.. toctree::
diff --git a/docs/automatic-differentiation.md b/docs/automatic-differentiation.md
index 07af05e3d973..221dd19c5121 100644
--- a/docs/automatic-differentiation.md
+++ b/docs/automatic-differentiation.md
@@ -26,7 +26,7 @@ Computing gradients is a critical part of modern machine learning methods, and t
- {ref}`automatic-differentiation-evaluating-using-jax-value_and_grad`
- {ref}`automatic-differentiation-checking-against-numerical-differences`
-Make sure to also check out the {ref}`advanced-autodiff` tutorial for more advanced topics.
+Make sure to also check out the {ref}`"Advanced automatic differentiation" guides ` for more advanced topics.
While understanding how automatic differentiation works "under the hood" isn't crucial for using JAX in most contexts, you are encouraged to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on.
@@ -230,4 +230,4 @@ check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
## Next steps
-The {ref}`advanced-autodiff` tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`advanced-autodiff-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section in the {ref}`advanced-autodiff` tutorial if you are interested.
+The {ref}`"Advanced automatic differentiation" guides ` provide more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`advanced-autodiff-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section if you are interested.
diff --git a/docs/complex-differentiation.md b/docs/complex-differentiation.md
new file mode 100644
index 000000000000..cf31b90a45ef
--- /dev/null
+++ b/docs/complex-differentiation.md
@@ -0,0 +1,207 @@
+---
+jupytext:
+ formats: md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ name: python3
+---
+
+# Complex numbers and differentiation
+
+JAX is great at complex numbers and differentiation. To support both [holomorphic and non-holomorphic differentiation](https://en.wikipedia.org/wiki/Holomorphic_function), it helps to think in terms of JVPs and VJPs.
+
+Consider a complex-to-complex function $f: \mathbb{C} \to \mathbb{C}$ and identify it with a corresponding function $g: \mathbb{R}^2 \to \mathbb{R}^2$,
+
+```{code-cell}
+import jax.numpy as jnp
+
+def f(z):
+ x, y = jnp.real(z), jnp.imag(z)
+ return u(x, y) + v(x, y) * 1j
+
+def g(x, y):
+ return (u(x, y), v(x, y))
+```
+
+That is, we've decomposed $f(z) = u(x, y) + v(x, y) i$ where $z = x + y i$, and identified $\mathbb{C}$ with $\mathbb{R}^2$ to get $g$.
+
+Since $g$ only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector $(c, d) \in \mathbb{R}^2$, namely:
+
+$\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}
+\begin{bmatrix} c \\ d \end{bmatrix}$.
+
+To get a JVP for the original function $f$ applied to a tangent vector $c + di \in \mathbb{C}$, we just use the same definition and identify the result as another complex number,
+
+$\partial f(x + y i)(c + d i) =
+\begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix}
+\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}
+\begin{bmatrix} c \\ d \end{bmatrix}$.
+
+That's our definition of the JVP of a $\mathbb{C} \to \mathbb{C}$ function! Notice it doesn't matter whether or not $f$ is holomorphic: the JVP is unambiguous.
+
+Here's a check:
+
+```{code-cell}
+from jax import random, grad, jvp
+
+def check(seed):
+ key = random.key(seed)
+
+ # random coeffs for u and v
+ key, subkey = random.split(key)
+ a, b, c, d = random.uniform(subkey, (4,))
+
+ def fun(z):
+ x, y = jnp.real(z), jnp.imag(z)
+ return u(x, y) + v(x, y) * 1j
+
+ def u(x, y):
+ return a * x + b * y
+
+ def v(x, y):
+ return c * x + d * y
+
+ # primal point
+ key, subkey = random.split(key)
+ x, y = random.uniform(subkey, (2,))
+ z = x + y * 1j
+
+ # tangent vector
+ key, subkey = random.split(key)
+ c, d = random.uniform(subkey, (2,))
+ z_dot = c + d * 1j
+
+ # check jvp
+ _, ans = jvp(fun, (z,), (z_dot,))
+ expected = (grad(u, 0)(x, y) * c +
+ grad(u, 1)(x, y) * d +
+ grad(v, 0)(x, y) * c * 1j +
+ grad(v, 1)(x, y) * d * 1j)
+ print(jnp.allclose(ans, expected))
+```
+
+```{code-cell}
+check(0)
+check(1)
+check(2)
+```
+
+What about VJPs? We do something pretty similar: for a cotangent vector $c + di \in \mathbb{C}$ we define the VJP of $f$ as
+
+$(c + di)^* \; \partial f(x + y i) =
+\begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix}
+\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix}
+\begin{bmatrix} 1 \\ -i \end{bmatrix}$.
+
+What's with the negatives? They're just to take care of complex conjugation, and the fact that we're working with covectors.
+
+Here's a check of the VJP rules:
+
+```{code-cell}
+from jax import vjp
+
+def check(seed):
+ key = random.key(seed)
+
+ # random coeffs for u and v
+ key, subkey = random.split(key)
+ a, b, c, d = random.uniform(subkey, (4,))
+
+ def fun(z):
+ x, y = jnp.real(z), jnp.imag(z)
+ return u(x, y) + v(x, y) * 1j
+
+ def u(x, y):
+ return a * x + b * y
+
+ def v(x, y):
+ return c * x + d * y
+
+ # primal point
+ key, subkey = random.split(key)
+ x, y = random.uniform(subkey, (2,))
+ z = x + y * 1j
+
+ # cotangent vector
+ key, subkey = random.split(key)
+ c, d = random.uniform(subkey, (2,))
+ z_bar = jnp.array(c + d * 1j) # for dtype control
+
+ # check vjp
+ _, fun_vjp = vjp(fun, z)
+ ans, = fun_vjp(z_bar)
+ expected = (grad(u, 0)(x, y) * c +
+ grad(v, 0)(x, y) * (-d) +
+ grad(u, 1)(x, y) * c * (-1j) +
+ grad(v, 1)(x, y) * (-d) * (-1j))
+ assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
+```
+
+```{code-cell}
+check(0)
+check(1)
+check(2)
+```
+
+What about convenience wrappers like {func}`jax.grad`, {func}`jax.jacfwd`, and {func}`jax.jacrev`?
+
+For $\mathbb{R} \to \mathbb{R}$ functions, recall we defined `grad(f)(x)` as being `vjp(f, x)[1](1.0)`, which works because applying a VJP to a `1.0` value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for $\mathbb{C} \to \mathbb{R}$ functions: we can still use `1.0` as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian:
+
+```{code-cell}
+def f(z):
+ x, y = jnp.real(z), jnp.imag(z)
+ return x**2 + y**2
+
+z = 3. + 4j
+grad(f)(z)
+```
+
+For general $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.
+
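+Here's a quick numerical check of that scale-and-rotate structure, using the holomorphic function $\sin$: pushing forward the direction $i v$ gives exactly $i$ times the pushforward of $v$.
+
+```{code-cell}
+# For a holomorphic function, the 2x2 real Jacobian acts like multiplication by
+# a single complex number, so scaling the input tangent by i scales the JVP by i.
+z = 3. + 4j
+_, d1 = jvp(jnp.sin, (z,), (1. + 0j,))  # pushforward of the direction 1
+_, di = jvp(jnp.sin, (z,), (0. + 1j,))  # pushforward of the direction i
+print(jnp.allclose(di, 1j * d1))
+```
+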
+Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when {func}`jax.grad` is used for a complex-output function:
+
+```{code-cell}
+def f(z):
+ return jnp.sin(z)
+
+z = 3. + 4j
+grad(f, holomorphic=True)(z)
+```
+
+All the `holomorphic=True` promise does is disable the error when the output is complex-valued. We can still write `holomorphic=True` when the function isn't holomorphic, but the answer we get out won't represent the full Jacobian. Instead, it'll be the Jacobian of the function where we just discard the imaginary part of the output:
+
+```{code-cell}
+def f(z):
+ return jnp.conjugate(z)
+
+z = 3. + 4j
+grad(f, holomorphic=True)(z) # f is not actually holomorphic!
+```
+
+There are some useful upshots for how {func}`jax.grad` works here:
+
+1. We can use {func}`jax.grad` on holomorphic $\mathbb{C} \to \mathbb{C}$ functions.
+2. We can use {func}`jax.grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)` (see the sketch just after this list).
+3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then {func}`jax.grad` still works and we get the same result that an implementation using only real values would have given.
+
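+As a small sketch of point 2 (the loss function and learning rate below are made up purely for illustration), here's gradient descent on a real-valued loss of a complex parameter, stepping along the conjugate of the gradient:
+
+```{code-cell}
+def loss(w):
+  # real-valued loss of a complex parameter, minimized at w = 1 + 2j
+  return jnp.abs(w - (1. + 2j)) ** 2
+
+w = 0. + 0j
+learning_rate = 0.1
+for _ in range(100):
+  w = w - learning_rate * jnp.conj(grad(loss)(w))  # step along conj(grad)
+print(w)  # close to (1+2j)
+```
+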
+In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\mathbb{C} \to \mathbb{C}$ function, we can do it with JVPs or VJPs!
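+
+For example, here's one way to assemble the full 2x2 real Jacobian of the (non-holomorphic) conjugation function from two JVP calls, one for the real direction and one for the imaginary direction:
+
+```{code-cell}
+def f(z):
+  return jnp.conjugate(z)  # not holomorphic
+
+z = 3. + 4j
+
+# The real and imaginary parts of the two pushforwards are the columns of the
+# Jacobian of the underlying R^2 -> R^2 map.
+_, dx = jvp(f, (z,), (1. + 0j,))  # pushforward of the real direction
+_, dy = jvp(f, (z,), (0. + 1j,))  # pushforward of the imaginary direction
+
+jac = jnp.array([[jnp.real(dx), jnp.real(dy)],
+                 [jnp.imag(dx), jnp.imag(dy)]])
+print(jac)  # conjugation flips the sign of the imaginary part: [[1, 0], [0, -1]]
+```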
+
+
+You should expect complex numbers to work everywhere in JAX. Here's differentiating through a Cholesky decomposition of a complex matrix:
+
+```{code-cell}
+A = jnp.array([[5., 2.+3j, 5j],
+ [2.-3j, 7., 1.+7j],
+ [-5j, 1.-7j, 12.]])
+
+def f(X):
+ L = jnp.linalg.cholesky(X)
+ return jnp.sum((L - jnp.sin(L))**2)
+
+grad(f, holomorphic=True)(A)
+```
diff --git a/docs/conf.py b/docs/conf.py
index 9c3845800bac..ee7cafdf2aaa 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -143,6 +143,7 @@ def _do_not_evaluate_in_jax(
'pallas/tpu/distributed.md',
'pallas/tpu/sparse.md',
'pallas/tpu/matmul.md',
+ 'pallas/tpu/core_map.md',
'jep/9407-type-promotion.md',
'autodidax.md',
'autodidax2_part1.md',
@@ -245,6 +246,7 @@ def _do_not_evaluate_in_jax(
'pallas/tpu/distributed.*',
'pallas/tpu/sparse.*',
'pallas/tpu/matmul.*',
+ 'pallas/tpu/core_map.*',
'distributed_data_loading.*',
'notebooks/host-offloading.*',
]
diff --git a/docs/contributing.md b/docs/contributing.md
index 0c85f83d8b80..40334bb9599a 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -6,8 +6,8 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are
ways to contribute, including:
- Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions)
-- Improving or expanding JAX's [documentation](http://docs.jax.dev/)
-- Contributing to JAX's [code-base](http://github.com/jax-ml/jax/)
+- Improving or expanding JAX's [documentation](https://docs.jax.dev)
+- Contributing to JAX's [code-base](https://github.com/jax-ml/jax)
- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries)
The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
@@ -49,7 +49,7 @@ Follow these steps to contribute code:
For more information, see the {ref}`pr-checklist` below.
2. Fork the JAX repository by clicking the **Fork** button on the
- [repository page](http://www.github.com/jax-ml/jax). This creates
+ [repository page](https://github.com/jax-ml/jax). This creates
a copy of the JAX repository in your own account.
3. Install Python >= 3.11 locally in order to run tests.
@@ -68,7 +68,7 @@ Follow these steps to contribute code:
changes.
```bash
- git remote add upstream https://www.github.com/jax-ml/jax
+ git remote add upstream https://github.com/jax-ml/jax.git
```
6. Create a branch where you will develop from:
diff --git a/docs/debugging.md b/docs/debugging.md
index b86e32cb6522..9aa646f1ecb3 100644
--- a/docs/debugging.md
+++ b/docs/debugging.md
@@ -275,11 +275,3 @@ Read more in [](debugging/flags).
## Next steps
Check out the {ref}`advanced-debugging` to learn more about debugging in JAX.
-
-```{toctree}
-:hidden:
-
-debugging/print_breakpoint
-debugging/checkify_guide
-debugging/flags
-```
diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md
index 53009500f8fe..a879fb69e16e 100644
--- a/docs/debugging/flags.md
+++ b/docs/debugging/flags.md
@@ -1,3 +1,17 @@
+---
+jupytext:
+ formats: md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
(debugging-flags)=
# JAX debugging flags
@@ -9,13 +23,13 @@ JAX offers flags and context managers that enable catching errors more easily.
**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code.
-`jax_debug_nans` is a JAX flag that when enabled, will cause computations to error-out immediately on production of a NaN. Switching this option on adds a NaN check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jax.jit`.
+`jax_debug_nans` is a JAX flag that when enabled, will cause computations to error-out immediately on production of a NaN. Switching this option on adds a NaN check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jax.jit`.
For code under an `@jax.jit`, the output of every `@jax.jit` function is checked and if a NaN is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jax.jit` at a time.
There could be tricky situations that arise, like NaNs that only occur under a `@jax.jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute.
-If the NaNs are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse.
+If the NaNs are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse.
### Usage
@@ -27,7 +41,7 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo
### Example(s)
-```python
+```{code-cell}
import jax
import jax.numpy as jnp
import traceback
@@ -46,7 +60,7 @@ except FloatingPointError as e:
The NaN generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jax.jit`, as the example below shows.
-```python
+```{code-cell}
:tags: [raises-exception]
jax.jit(f)(5.)
@@ -56,7 +70,7 @@ When this code sees a NaN in the output of an `@jax.jit` function, it calls into
The `jax.debug_nans` context manager can be used to activate/deactivate NaN debugging. Since we activated it above with `jax.config.update`, let's deactivate it:
-```python
+```{code-cell}
with jax.debug_nans(False):
print(jax.jit(f)(5.))
```
diff --git a/docs/debugging/index.md b/docs/debugging/index.md
index 936c701f0c00..724d29af34b6 100644
--- a/docs/debugging/index.md
+++ b/docs/debugging/index.md
@@ -138,9 +138,8 @@ ENTRY main.5 {
:caption: Read more
:maxdepth: 1
+flags
print_breakpoint
checkify_guide
-./flags
xla_metadata
```
-
diff --git a/docs/faq.rst b/docs/faq.rst
index 2d3c920498f6..5653ff1cbb26 100644
--- a/docs/faq.rst
+++ b/docs/faq.rst
@@ -137,195 +137,14 @@ on GitHub.
How to use ``jit`` with methods?
--------------------------------
-Most examples of :func:`jax.jit` concern decorating stand-alone Python functions,
-but decorating a method within a class introduces some complication. For example,
-consider the following simple class, where we've used a standard :func:`~jax.jit`
-annotation on a method::
-
- >>> import jax.numpy as jnp
- >>> from jax import jit
-
- >>> class CustomClass:
- ... def __init__(self, x: jnp.ndarray, mul: bool):
- ... self.x = x
- ... self.mul = mul
- ...
- ... @jit # <---- How to do this correctly?
- ... def calc(self, y):
- ... if self.mul:
- ... return self.x * y
- ... return y
-
-However, this approach will result in an error when you attempt to call this method::
-
- >>> c = CustomClass(2, True)
- >>> c.calc(3) # doctest: +SKIP
- ---------------------------------------------------------------------------
- TypeError Traceback (most recent call last)
- File "", line 1, in ' of type is not a valid JAX type.
-
-The problem is that the first argument to the function is ``self``, which has type
-``CustomClass``, and JAX does not know how to handle this type.
-There are three basic strategies we might use in this case, and we'll discuss
-them below.
-
-Strategy 1: JIT-compiled helper function
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-The most straightforward approach is to create a helper function external to the class
-that can be JIT-decorated in the normal way. For example::
-
- >>> from functools import partial
-
- >>> class CustomClass:
- ... def __init__(self, x: jnp.ndarray, mul: bool):
- ... self.x = x
- ... self.mul = mul
- ...
- ... def calc(self, y):
- ... return _calc(self.mul, self.x, y)
-
- >>> @partial(jit, static_argnums=0)
- ... def _calc(mul, x, y):
- ... if mul:
- ... return x * y
- ... return y
-
-The result will work as expected::
-
- >>> c = CustomClass(2, True)
- >>> print(c.calc(3))
- 6
-
-The benefit of such an approach is that it is simple, explicit, and it avoids the need
-to teach JAX how to handle objects of type ``CustomClass``. However, you may wish to
-keep all the method logic in the same place.
-
-Strategy 2: Marking ``self`` as static
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Another common pattern is to use ``static_argnums`` to mark the ``self`` argument as static.
-But this must be done with care to avoid unexpected results.
-You may be tempted to simply do this::
-
- >>> class CustomClass:
- ... def __init__(self, x: jnp.ndarray, mul: bool):
- ... self.x = x
- ... self.mul = mul
- ...
- ... # WARNING: this example is broken, as we'll see below. Don't copy & paste!
- ... @partial(jit, static_argnums=0)
- ... def calc(self, y):
- ... if self.mul:
- ... return self.x * y
- ... return y
-
-If you call the method, it will no longer raise an error::
-
- >>> c = CustomClass(2, True)
- >>> print(c.calc(3))
- 6
-
-However, there is a catch: if you mutate the object after the first method call, the
-subsequent method call may return an incorrect result::
-
- >>> c.mul = False
- >>> print(c.calc(3)) # Should print 3
- 6
-
-Why is this? When you mark an object as static, it will effectively be used as a dictionary
-key in JIT's internal compilation cache, meaning its hash (i.e. ``hash(obj)``) equality
-(i.e. ``obj1 == obj2``) and object identity (i.e. ``obj1 is obj2``) will be assumed to have
-consistent behavior. The default ``__hash__`` for a custom object is its object ID, and so
-JAX has no way of knowing that a mutated object should trigger a re-compilation.
-
-You can partially address this by defining an appropriate ``__hash__`` and ``__eq__`` methods
-for your object; for example::
-
- >>> class CustomClass:
- ... def __init__(self, x: jnp.ndarray, mul: bool):
- ... self.x = x
- ... self.mul = mul
- ...
- ... @partial(jit, static_argnums=0)
- ... def calc(self, y):
- ... if self.mul:
- ... return self.x * y
- ... return y
- ...
- ... def __hash__(self):
- ... return hash((self.x, self.mul))
- ...
- ... def __eq__(self, other):
- ... return (isinstance(other, CustomClass) and
- ... (self.x, self.mul) == (other.x, other.mul))
-
-(see the :meth:`object.__hash__` documentation for more discussion of the requirements
-when overriding ``__hash__``).
-
-This should work correctly with JIT and other transforms **so long as you never mutate
-your object**. Mutations of objects used as hash keys lead to several subtle problems,
-which is why for example mutable Python containers (e.g. :class:`dict`, :class:`list`)
-don't define ``__hash__``, while their immutable counterparts (e.g. :class:`tuple`) do.
-
-If your class relies on in-place mutations (such as setting ``self.attr = ...`` within its
-methods), then your object is not really "static" and marking it as such may lead to problems.
-Fortunately, there's another option for this case.
-
-Strategy 3: Making ``CustomClass`` a PyTree
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-The most flexible approach to correctly JIT-compiling a class method is to register the
-type as a custom PyTree object; see :ref:`pytrees-custom-pytree-nodes`. This lets you specify
-exactly which components of the class should be treated as static and which should be
-treated as dynamic. Here's how it might look::
-
- >>> class CustomClass:
- ... def __init__(self, x: jnp.ndarray, mul: bool):
- ... self.x = x
- ... self.mul = mul
- ...
- ... @jit
- ... def calc(self, y):
- ... if self.mul:
- ... return self.x * y
- ... return y
- ...
- ... def _tree_flatten(self):
- ... children = (self.x,) # arrays / dynamic values
- ... aux_data = {'mul': self.mul} # static values
- ... return (children, aux_data)
- ...
- ... @classmethod
- ... def _tree_unflatten(cls, aux_data, children):
- ... return cls(*children, **aux_data)
-
- >>> from jax import tree_util
- >>> tree_util.register_pytree_node(CustomClass,
- ... CustomClass._tree_flatten,
- ... CustomClass._tree_unflatten)
-
-This is certainly more involved, but it solves all the issues associated with the simpler
-approaches used above::
-
- >>> c = CustomClass(2, True)
- >>> print(c.calc(3))
- 6
-
- >>> c.mul = False # mutation is detected
- >>> print(c.calc(3))
- 3
-
- >>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported
- >>> print(c.calc(3))
- 6
-
-So long as your ``tree_flatten`` and ``tree_unflatten`` functions correctly handle all
-relevant attributes in the class, you should be able to use objects of this type directly
-as arguments to JIT-compiled functions, without any special annotations.
+
+Moved to :ref:`jax-jit-class-methods`.
.. _faq-jax-vs-numpy:
Is JAX faster than NumPy?
-~~~~~~~~~~~~~~~~~~~~~~~~~
+-------------------------
+
One question users frequently attempt to answer with such benchmarks is whether JAX
is faster than NumPy; due to the difference in the two packages, there is not a
simple answer.
diff --git a/docs/fault_tolerance.rst b/docs/fault_tolerance.rst
new file mode 100644
index 000000000000..153b3c159399
--- /dev/null
+++ b/docs/fault_tolerance.rst
@@ -0,0 +1,1524 @@
+.. raw:: html
+
+
+
+
+
+Fault Tolerant Distributed JAX
+==============================
+
+Recall that `multi-controller JAX`_ allows you to run a JAX program distributed
+across multiple machines. By default, if *any* of these machines fails, then
+*every* machine will fail. That is, multi-controller JAX is not
+**fault-tolerant** by default.
+
+This article has three parts. In the first part, we'll explain the basics of
+how to write fault tolerant multi-controller JAX programs. In the second part,
+we'll show some example fault-tolerant multi-controller JAX programs. In the
+third part, we'll take a look under the covers at how multi-controller JAX
+implements fault tolerance.
+
+.. warning::
+
+ JAX's support for fault tolerance is still experimental. It currently only
+ works fully on GPUs. It has rough edges, is probably buggy, and is subject
+ to change. Use at your own risk.
+
+
+.. _part1:
+
+Part 1: Fault Tolerance Basics
+------------------------------
+
+Fault Intolerant By Default
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+By default, multi-controller JAX programs are not fault tolerant. If *any*
+process crashes, then *all* other processes will also intentionally crash. To
+make this concrete, consider the following trivial script, ``example.py``, that
+initializes multi-controller JAX by calling ``jax.distributed.initialize`` and
+then enters an infinite loop.
+
+.. literalinclude:: _static/fault_tolerance/while_loop.py
+ :language: python
+ :emphasize-lines: 12-18
+ :lines: 15-
+ :linenos:
+ :caption: ``example.py``
+
+Run ``example.py`` across four processes on a VM with four GPUs by running
+the following four commands, each in a different terminal. The
+``local_device_ids`` argument to ``jax.distributed.initialize`` ensures each
+process is assigned only one of the four GPUs. We'll explain the
+``heartbeat_timeout_seconds`` argument in just a second.
+
+.. code-block:: shell
+
+ python example.py --i=0 --n=4 # in terminal 1
+ python example.py --i=1 --n=4 # in terminal 2
+ python example.py --i=2 --n=4 # in terminal 3
+ python example.py --i=3 --n=4 # in terminal 4
+
+When you run these commands, you'll see the processes dutifully printing out
+the current time every second. Next, fail the fourth process: ``pkill -9 -f
+'python example.py --i=3 --n=4'``. After about ten seconds, the other
+processes will also terminate and spit out error messages that look something
+like this:
+
+.. code-block::
+
+ E0926 17:26:32.075402 157988 coordination_service_agent.cc:332] Polled an error from coordination service (this can be an error from this or another task).
+ F0926 17:26:32.075587 157988 client.h:77] Terminating process because the JAX distributed service detected fatal errors. This most likely indicates that another task died; see the other task logs for more details. Disable Python buffering, i.e. `python -u`, to be sure to see all the previous output. absl::Status: UNAVAILABLE: The following tasks are unhealthy (stopped sending heartbeats):
+ /job:jax_worker/replica:0/task:3
+ The tasks have crashed. Check the task logs for an earlier error, or scheduler events (e.g. preemption, eviction) to debug further.
+
+ RPC: /tensorflow.CoordinationService/PollForError [type.googleapis.com/tensorflow.CoordinationServiceError='']
+
+When a process in a multi-controller JAX program notices that a peer process
+has crashed, it decides to crash as well. The processes `share fate`_. The
+``heartbeat_timeout_seconds`` argument to ``jax.distributed.initialize``
+determines how long a process waits before concluding a peer process has died.
+The first three processes crash about ten seconds after you kill the fourth
+because we passed ``heartbeat_timeout_seconds=10`` as an argument to
+``jax.distributed.initialize``.
+
+Surviving Faults
+^^^^^^^^^^^^^^^^
+
+We can disable fate-sharing by adding the
+``--xla_gpu_nccl_terminate_on_error=false`` flag and the
+``jax_enable_recoverability`` configuration option to ``example.py``, as shown
+below:
+
+.. literalinclude:: _static/fault_tolerance/dont_fail.py
+ :language: python
+ :emphasize-lines: 1-2,15
+ :linenos:
+ :lines: 15-
+
+Again run the script across four processes and then kill the fourth. Notice
+that now, the other three processes happily continue executing.
+
+Next, try failing process 0. Notice that all four processes terminate with
+error messages that look something like the following:
+
+.. code-block::
+
+ E0929 17:42:48.594192 1044529 coordination_service_agent.cc:332] Polled an error from coordination service (this can be an error from this or another task).
+ F0929 17:42:48.594200 1044529 client.h:77] Terminating process because the JAX distributed service detected fatal errors. This most likely indicates that another task died; see the other task logs for more details. Disable Python buffering, i.e. `python -u`, to be sure to see all the previous output. absl::Status: UNAVAILABLE: Failed to send RPC to coordination service. Either the leader task was preempted/died/restarted unexpectedly or this task is experiencing network issues. Check earlier logs from 1) this task, 2) the leader (usually slice 0 task 0), and 3) cluster scheduler to debug further.
+ Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/PollForError:
+ :UNKNOWN:Error received from peer {grpc_message:"Socket closed", grpc_status:14}
+
+Process 0 is special. If process 0 fails, every process will fail, even with
+fate-sharing disabled. Why? Process 0 runs an RPC service called the
+coordination service that all processes use to coordinate with each other. If
+the coordination service fails, all other processes have no choice but to fail.
+See :ref:`part3` for more details.
+
+Getting Stuck in Collectives
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``example.py`` is now able to survive faults, but the processes do not
+communicate with each other at all. Any realistic multi-controller JAX program
+would involve communication between the processes (otherwise, what's the point
+of using multi-controller JAX?). Let's edit ``example.py`` so that the
+processes perform a collective ``jnp.sum`` in every iteration of the loop.
+
+.. literalinclude:: _static/fault_tolerance/collectives.py
+ :language: python
+ :emphasize-lines: 27-32
+ :linenos:
+ :lines: 15-
+
+In the highlighted code above, the processes create an array ``x`` sharded
+across the four processes and then perform a distributed ``jnp.sum``. Again run
+the program and fail the fourth process. You'll notice that the first three
+processes do not crash, but they do get *stuck*. By default, if a process fails
+while participating in a distributed computation (like ``jnp.sum``), then the
+rest of the processes participating in the computation will get stuck
+*forever*.
+
+.. _`canceling_collectives`:
+
+Cancelling Collectives
+^^^^^^^^^^^^^^^^^^^^^^
+
+We can avoid getting stuck by cancelling collectives that have a failed
+participant. Collective cancelling is enabled by providing a few more flags and
+environment variables, highlighted below.
+
+.. literalinclude:: _static/fault_tolerance/cancel_collectives.py
+ :language: python
+ :emphasize-lines: 1-8,22,33-35
+ :linenos:
+ :lines: 15-
+
+We also need to insert a call to
+``jax.experimental.multihost_utils._live_devices`` to make the script work. You
+should normally not do this. You should instead use the ``live_devices`` API
+that we'll introduce momentarily. For now, ``_live_devices`` is a hack to get
+the script working before we explain the proper API.
+
+Again run the script and fail the fourth process. The first three processes
+will be stuck in their call to ``jnp.sum``, but after about ten seconds, the
+call will be cancelled and ``jnp.sum`` will raise an exception that looks
+something like this:
+
+.. code-block::
+
+ jaxlib._jax.XlaRuntimeError: FAILED_PRECONDITION: Task with incarnation id 3446767950926952685 is not connected
+
+
+Knowing Who's Alive
+^^^^^^^^^^^^^^^^^^^
+
+After a process dies, the remaining *alive* processes need to learn who is dead
+and who is alive. For this, we can use the core JAX fault tolerance API:
+``live_devices``. ``live_devices`` is a context manager that takes a list of
+devices as an argument and returns the subset of these devices that are alive.
+Below, we edit ``example.py`` to call ``live_devices``.
+
+.. literalinclude:: _static/fault_tolerance/live_devices.py
+ :language: python
+ :emphasize-lines: 34-46
+ :linenos:
+ :lines: 15-
+
+In the highlighted code above, we call ``live_devices`` with all devices
+(``jax.devices()``) to get the set ``devices`` of live devices. We then shard
+array ``x`` over these devices and perform a ``jnp.sum``. If a process fails
+while executing the ``jnp.sum``, then ``jnp.sum`` will be cancelled and raise
+an exception on the remaining live devices. Technically, the collective is not
+guaranteed to fail. We'll revisit this in :ref:`atomicity`. For now, assume it
+will fail.
+
+.. note::
+
+ ``jax.devices()`` always returns the set of *all* devices, even if some of
+ these devices are on failed processes. Use
+ ``jax.experimental.multihost_utils.live_devices`` to learn which of these
+ devices are live.
+
+Again run the script and fail the fourth process. Notice that the remaining
+three alive processes catch the exception raised by ``jnp.sum`` and continue to
+the next iteration of the while loop. In this next iteration, ``devices`` does
+not include the device on the failed fourth process. The three alive processes
+continue to execute correctly even though the fourth process is dead.
+
+Next, restart the fourth process. Notice that after the fourth process
+restarts, its device is again included in the set of alive devices returned by
+``live_devices``. All four processes then continue executing normally.
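+
+To restart it, re-run the same launch command in the fourth terminal:
+
+.. code-block:: shell
+
+ python example.py --i=3 --n=4 # in terminal 4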
+
+At first blush, ``live_devices`` seems trivial. You give it a list of devices,
+and it returns the ones that are alive. How complicated can that be?
+Unfortunately, as with `many things in distributed systems`_, there are a lot of
+subtleties to iron out. Next, we explain the **barrier** semantics and
+**atomicity** properties of ``live_devices``.
+
+Barrier Semantics
+^^^^^^^^^^^^^^^^^
+
+Recall that every process in a `multi-controller JAX`_ program should run in
+lockstep. The processes should execute the same instructions in the same order.
+Failing to do so will *almost certainly* lead to deadlocks, crashes, or
+anomalous behavior.
+
+In the context of ``live_devices``, we need to ensure that every process agrees
+on which processes are currently alive. This is difficult to ensure because
+every process is executing independently at potentially different speeds and
+processes can fail at any time. Consider again the ``example.py`` script from
+above running on four processes. Imagine process 1 and 2 call ``live_devices``,
+then process 4 fails, and then process 3 calls ``live_devices``. Process 1 and
+2 might think process 4 is alive while process 3 thinks it is dead.
+
+To avoid situations like these, ``live_devices`` guarantees that it returns the
+same set of live devices to every process. It accomplishes this using a
+barrier. A call to ``live_devices(devices)`` blocks until every live process
+hosting a device in ``devices`` has also called ``live_devices``. Once every
+live process is in the ``live_devices`` barrier, ``live_devices`` returns the
+same set of live devices to every process.
+
+.. important::
+
+ ``live_devices`` uses a barrier to ensure that it will *always* return the
+ same set of live devices to every live process.
+
+Because ``live_devices`` implements a barrier, it is susceptible to deadlock if
+used improperly. We recommend having only a single ``with live_devices`` block
+in a program. Multiple calls to ``live_devices`` are hard to reason about and
+can lead to deadlock.
+
+See :ref:`part3` for details on how the ``live_devices`` barrier is implemented
+as well as a formal semantics based on `linearizability`_.
+
+.. _atomicity:
+
+Atomicity
+^^^^^^^^^
+
+A distributed computation is **atomic** if every participant in the computation
+agrees on whether the operation succeeds or fails. In the ``example.py`` script
+above, we saw that when a process failed during the execution of a ``jnp.sum``,
+the ``jnp.sum`` aborted and raised an exception on the remaining live
+processes. So, is ``jnp.sum`` atomic?
+
+Unfortunately, it's not.
+
+When a process fails during the execution of a collective operation (like
+``jnp.sum``), the remaining processes may cancel the operation and raise an
+exception or they may complete the operation successfully. Collective
+operations in JAX do not have any inherent atomicity properties.
+
+If collective operations are not atomic, however, then multi-controller JAX
+processes might diverge. For example, if a process fails during a training step
+of a machine learning model, some processes might detect the failure and roll
+the model back to a checkpoint while other processes might think the step
+succeeded and keep training.
+
+To avoid the complexities of non-atomic execution, ``live_devices`` provides
+its own atomicity guarantees despite the fact that collectives are not atomic.
+Specifically, the body of a ``with live_devices`` block is guaranteed to either
+complete successfully on all processes or raise an exception on all processes.
+More concretely, if we consider the code snippet below, either every process
+executes branch A or every process executes branch B. It is impossible for some
+processes to execute A while others execute B.
+
+.. code-block:: python
+
+ try:
+ with live_devices(jax.devices()) as devices:
+ ...
+ except Exception as e:
+ ... # Branch A
+ else:
+ ... # Branch B
+
+.. warning::
+
+ A ``with live_devices`` block does not guarantee atomicity if the code
+ block non-deterministically raises exceptions for reasons other than
+ collectives that fail because of a crashed process. For example, if one
+ process raises an exception because it runs out of memory, this exception
+ will not be propagated to the other processes.
+
+Recall that JAX uses `asynchronous dispatch`_. Operations like ``jnp.sum`` do
+not block until the operation is complete. Instead, they return ``jax.Arrays``
+that act as futures. This asynchrony can interact with ``live_devices`` in
+unexpected ways. For example, consider the following code that performs a
+``jnp.sum``, assigns the result to ``y``, and then prints ``y``:
+
+.. code-block:: python
+
+ x = ...
+ y = ...
+ try:
+ with live_devices(jax.devices()) as devices:
+ y = jnp.sum(x)
+ except Exception as e:
+ ... # Branch A
+ else:
+ ... # Branch B
+ print(y)
+
+Imagine that the ``with live_devices`` block executes successfully on all
+processes. That is, all processes execute branch B. This only guarantees that
+every process successfully created a future and assigned it to ``y``. The
+actual computation of the ``jnp.sum`` may be delayed until outside the block.
+Thus, some processes might successfully complete the ``jnp.sum`` and print the
+value of ``y`` while other processes fail to complete the ``jnp.sum`` and raise
+an exception when trying to print ``y``.
+
+To avoid this, use ``jax.block_until_ready`` to ensure that computations are
+performed within the ``with live_devices`` block. The code snippet below, which
+now calls ``jax.block_until_ready`` when assigning to ``y``, guarantees that
+every process will successfully execute the ``jnp.sum`` or every process will
+raise an exception.
+
+.. code-block:: python
+
+ x = ...
+ y = ...
+ try:
+     with live_devices(jax.devices()) as devices:
+ y = jax.block_until_ready(jnp.sum(x))
+ except Exception as e:
+ ... # Branch A
+ else:
+ ... # Branch B
+ print(y)
+
+See :ref:`part3` for details on how atomicity is implemented.
+
+Part 2: Examples
+----------------
+
+``live_devices`` is not a panacea; it is a tool. It does not magically make
+multi-controller JAX programs fault tolerant. Rather, it allows you to
+implement fault tolerance yourself in the way that is best for your
+application.
+
+The exact details of how you implement fault tolerance will vary greatly based
+on the nature of your application. In this section, we present some examples of
+how to use ``live_devices``. The examples are meant to be illustrative but not
+prescriptive. There are many other ways to implement fault tolerance.
+
+Example 1: Fault Tolerant Data Parallel Training
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In this example, we train a trivial single-parameter linear model (:math:`y =
+\alpha x`) with data parallelism across four processes. The example is
+contrived---you would never train a model with a single parameter across four
+machines---but we intentionally keep the model simple to focus on fault
+tolerance.
+
+Data parallelism makes implementing fault tolerance relatively straightforward.
+Because every process has a full copy of the model weights, if a process fails,
+we can simply ignore it and continue training. This example tolerates an
+arbitrary number of process failures (excluding process 0), but once a process
+fails, we assume it does not recover. The next example shows how to handle
+process recovery.
+
+First, we set some flags to disable fate-sharing and enable collective
+cancelling. We also make the necessary imports and define some flags.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism.py
+ :language: python
+ :lines: 15-33
+ :lineno-start: 1
+
+Next, we define a ``replicated`` function that returns an array replicated
+across a set of devices. Note that ``replicated`` doesn't actually move any
+data. It assumes the argument ``x`` already has equal value across all
+processes. It simply returns a new view of that data, in a process-spanning
+``jax.Array`` with a replicated sharding.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism.py
+ :language: python
+ :lines: 35-49
+ :lineno-start: 21
+
+We define a similar ``sharded`` function that returns an array sharded across a
+set of devices. Again, ``sharded`` is not actually moving any data between
+processes.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism.py
+ :language: python
+ :lines: 52-64
+ :lineno-start: 38
+
+Now, we're ready to start writing our training loop. We begin by initializing
+multi-controller JAX by calling ``jax.distributed.initialize``.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism.py
+ :language: python
+ :lines: 67-76
+ :lineno-start: 53
+
+Then, we define our simple linear model, generate some random training data,
+and initialize some basic hyperparameters.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism.py
+ :language: python
+ :lines: 78-97
+ :lineno-start: 64
+
+Finally, we enter the main training loop.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism.py
+ :language: python
+ :lines: 99-125
+ :lineno-start: 85
+
+- On every iteration of the loop, we call ``live_devices`` to learn which devices
+ are currently alive.
+- We then ensure that the model weights are replicated across these devices and
+ ensure that the training data is sharded across these devices. Note that this
+ doesn't actually move any data between the devices; it simply creates JAX
+ arrays with the appropriate replication and sharding metadata.
+- We call ``loss_and_grad`` to compute the gradient of the weights with respect
+ to the current batch of data and then compute the new weights. Notice that we
+ assign the new weights to ``new_weights`` rather than assigning to
+ ``weights`` in case the training step fails. We also call
+ ``jax.block_until_ready`` to ensure that every process has computed the new
+ weights when we exit the ``live_devices`` block.
+- If no processes failed during the execution of the training step, then the
+ ``else`` branch is taken. The step is incremented, and ``weights`` is
+ updated. Otherwise, an exception will be raised and the ``except`` branch is
+ taken. In this case, we do not update ``step`` or ``weights`` and retry the
+ step on the next iteration with the new set of live devices.
+
+Here is the full example:
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism.py
+ :language: python
+ :linenos:
+ :lines: 15-
+
+Example 2: Fault Tolerant Data Parallel Training With Recovery
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Now, we modify the example above to allow failed processes to recover. When a
+process recovers, it needs to receive the current step and model weights.
+Because we assume process 0 never fails---recall that if process 0 fails, every
+process will fail---we have process 0 send the current step and weights to
+recovering processes.
+
+First, we define ``send`` and ``recv`` functions that use a ``shard_map`` to
+send data from one device to another. The sender calls ``send``, and the
+receiver calls ``recv``.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py
+ :language: python
+ :lines: 69-90
+ :lineno-start: 55
+
+``allgather`` performs an AllGather of a single float across a set of devices.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py
+ :language: python
+ :lines: 93-100
+ :lineno-start: 79
+
+Finally, we modify the training loop to handle recovering processes, as shown
+in the highlighted code below.
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py
+ :language: python
+ :lines: 135-178
+ :lineno-start: 121
+ :emphasize-lines: 7-22
+
+Recovery is a two-step process. First, we need to detect which processes are
+recovering. Second, we need process 0 to send the step and weights to the
+recovering processes.
+
+1. To detect which processes are recovering, we perform an AllGather on all
+ live processes' steps. When a failed process recovers, its ``step`` will be
+ ``0``, while the ``step`` on process ``0`` will be some positive number, so
+ if a process' step is not equal to process 0's step, then it is recovering.
+2. Then, we call the ``send`` and ``recv`` functions we defined above to
+ transfer the current step and model weights from process 0 to the recovering
+ processes.
+
+Here is the full example:
+
+.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py
+ :language: python
+ :linenos:
+ :lines: 15-
+
+.. _part3:
+
+
+Part 3: Implementation Details
+------------------------------
+
+We now take a deep dive into the architecture of multi-controller JAX and the
+semantics and implementation of ``live_devices``. If you're only interested in
+writing fault-tolerant multi-controller JAX programs, the first two parts of
+this article suffice.
+
+The Coordination Service
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+When you launch a multi-controller JAX program, the first process (i.e. process
+0) runs a standalone RPC server called the **coordination service**. Moreover,
+all processes (including process 0) create an RPC client to the coordination
+service. Concretely, the ``coordinator_address`` argument of
+:func:`jax.distributed.initialize` is the address of the coordination service.
+This argument tells process 0 which address to run the server on, and it tells
+all processes which address to connect to.
+
+The coordination service implements the multi-controller JAX **control plane**.
+For example, it can perform a distributed barrier across all processes, and it
+implements a key-value store that processes can use to exchange small amounts
+of metadata. Note, however, that the **data plane** (e.g., all collective
+operations on program data) is implemented directly between the processes and
+does not involve the coordination service.
+
+One of the most important functionalities of the coordination service is health
+checking. Every process periodically sends a heartbeat to the coordination
+service. If a process fails, it stops sending heartbeats. If the coordination
+service hasn't received a heartbeat from a process for a while, it assumes the
+process has failed.
+
+This is shown in the interactive visualization below. The coordination service
+is shown at the top and three multi-controller JAX processes are shown at the
+bottom. Note how the processes periodically send heartbeats to the coordination
+service, and how it keeps track of the health of each process based on when it
+last received a heartbeat. Try failing process 2 by clicking the "Fail" button.
+Observe how the process stops sending heartbeats and the coordination service
+eventually considers the process dead.
+
+.. raw:: html
+
+
+
+
+By default, when the coordination service detects that a process has failed, it
+sends a message to all other processes requesting that they self-terminate. In
+other words, all processes in a multi-controller JAX program `share fate`_.
+Again fail process 2 in the visualization below by clicking the "Fail" button
+and observe how the coordination service notifies the other processes to fail.
+
+.. raw:: html
+
+
+
+
+This fate sharing means that multi-controller JAX programs are not at all
+fault-tolerant. They are fault-*intolerant*. To enable fault-tolerance, we
+need to do two things:
+
+- First, we need to remove fate sharing and allow processes to continue
+ executing even when a peer process has died. This can be enabled using the
+ ``jax_enable_recoverability`` option, as described in :ref:`part1`. We'll
+ assume that this option is set.
+- Second, we need to provide an API that processes can use to learn which
+ processes are alive and which have failed. This is the ``live_devices`` API
+ introduced in :ref:`part1`.
+
+There is a surprising amount of technical depth and subtlety in implementing
+the ``live_devices`` API. We'll walk through the design and implementation of
+the API step-by-step. We'll begin by introducing a simpler ``live_processes``
+API and slowly improve it until we arrive at the ``live_devices`` API.
+
+Live Processes
+^^^^^^^^^^^^^^
+
+Let's try to design a new hypothetical JAX API: ``jax.live_processes``. As the
+name suggests, we want ``jax.live_processes()`` to return the set of all
+currently alive processes. Here is a naive but (as we'll see momentarily)
+incorrect implementation. When a process calls ``jax.live_processes()``, it
+sends an RPC request to the coordination service. Remember that the
+coordination service already uses heartbeats to keep track of which processes
+are dead and which are alive, so when it receives a ``jax.live_processes``
+request, it responds with the set of processes it thinks are alive.
+
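+The sketch below illustrates this naive design in Python. It is only an
+illustration (the class, method names, and timeout are made up, and this is
+not JAX's actual implementation), but it captures the key point: the service
+replies immediately with whatever it currently believes.
+
+.. code-block:: python
+
+   import time
+
+   class NaiveCoordinationService:
+     """A sketch of the naive design, not JAX's actual implementation."""
+
+     def __init__(self, num_processes, timeout_s=10.0):
+       self.timeout_s = timeout_s
+       # The time at which each process last sent a heartbeat.
+       self.last_heartbeat = {i: None for i in range(num_processes)}
+
+     def record_heartbeat(self, process_id):
+       self.last_heartbeat[process_id] = time.monotonic()
+
+     def live_processes(self):
+       # Reply right away with the current belief. Two processes calling at
+       # nearly the same time may receive different answers.
+       now = time.monotonic()
+       return {
+         p for p, t in self.last_heartbeat.items()
+         if t is not None and now - t < self.timeout_s
+       }
+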
+This is illustrated below. Below each process is a "Call live_processes"
+button. You can click this button to make the process call
+``jax.live_processes``. Note how the coordination service replies to a
+``live_processes`` request with the set of alive processes. Fail process 2 by
+clicking the "Fail" button and see how it affects later calls to
+``jax.live_processes``.
+
+.. raw:: html
+
+
+
+
+This naive implementation is simple but incorrect. It is crucial that all
+processes in a multi-controller JAX job execute the same instructions in the
+same order. If the processes start to diverge, by executing different code
+paths in the JAX program, the job will behave erratically. Most likely, it will
+crash or hang or produce garbage values, and most certainly it will be very
+hard to reason about.
+
+Our naive implementation of ``jax.live_processes`` can very easily lead to
+divergence. For example, consider a multi-controller JAX job with three
+processes. If process 0 and 1 both call ``jax.live_processes`` around the same
+time that process 2 fails, the coordination service might report to process 0
+that all processes are alive but report to process 1 that only processes 0 and
+1 are alive. Try to produce this scenario in the visualization below:
+
+.. raw:: html
+
+
+
+
+If processes disagree on which processes are alive, they will almost certainly
+diverge. Thankfully, we can avoid this divergence by augmenting
+``jax.live_processes`` with barrier semantics.
+
+Barrier Semantics
+^^^^^^^^^^^^^^^^^
+
+Let's change the implementation of ``jax.live_processes`` so that when the
+coordination service receives a ``jax.live_processes()`` request, it does not
+reply right away. Instead, the coordination service only replies once *every*
+live process has called ``jax.live_processes()``. Once every alive process has
+entered the ``jax.live_processes()`` barrier, the coordination service returns
+the set of live processes. Crucially, the coordination service returns the
+*same* set of live processes to all processes, which prevents the processes
+from diverging.
+
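+Continuing the illustrative sketch from the previous section (again, made-up
+names, not JAX's actual implementation), the barrier-based variant buffers
+requests and replies only once every process it currently believes to be alive
+has entered the barrier, handing the *same* set to everyone.
+
+.. code-block:: python
+
+   class BarrierCoordinationService(NaiveCoordinationService):
+     """A sketch of the barrier-based design, not JAX's implementation."""
+
+     def __init__(self, num_processes, timeout_s=10.0):
+       super().__init__(num_processes, timeout_s)
+       # Maps a waiting process id to the callback that delivers its reply.
+       self.waiting = {}
+
+     def live_processes_request(self, process_id, reply):
+       self.waiting[process_id] = reply
+       self.maybe_release()
+
+     def maybe_release(self):
+       # Also called whenever a process is declared dead, since that can
+       # complete the barrier for the surviving processes.
+       live = self.live_processes()
+       if live and live.issubset(self.waiting):
+         # Every live process is in the barrier; everyone gets the same set.
+         for p in live:
+           self.waiting.pop(p)(sorted(live))
+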
+This is illustrated below. Note that the coordination service now keeps track
+of which processes are in the ``live_processes`` barrier. Try calling
+``live_processes`` from every process. Notice how the coordination service
+doesn't respond until every process has entered the barrier. Then fail process
+2 and call ``live_processes`` from process 0 and process 1.
+
+.. raw:: html
+
+
+
+
+Formal Semantics
+^^^^^^^^^^^^^^^^
+
+Distributed systems are notoriously complex. Machines can fail at arbitrary
+times, and network messages can be dropped, delayed, and reordered. In this
+section, we introduce a formal semantics of the ``jax.live_processes`` API to
+help tame this complexity. Thinking rigorously about the semantics of
+``jax.live_processes`` will help us understand the behavior of the API even in
+pathological executions.
+
+We'll base the formal semantics of ``jax.live_processes`` on
+`linearizability`_: a popular formalism used to define the semantics of many
+distributed APIs. Concretely, we model our distributed system as a number of
+processes. Each process serially performs a number of events. There are four
+types of events:
+
+1. A process can **start** (👶). We'll assume that when a process starts, it
+ connects to the coordination service, so the coordination service is aware
+   that it has started.
+2. A process can **fail** (💀). Unlike starting, the coordination service may
+ not immediately be aware that a process has failed.
+3. A process can **send** a ``jax.live_processes`` request to the coordination
+ service.
+4. A process can **receive** a reply to a ``jax.live_processes`` request from
+ the coordination service.
+
+Below is a diagram of an execution of three processes: 0, 1, and 2. Time
+progresses from left to right. First, all three processes start. This is shown
+with the baby emojis. Then all three processes send ``jax.live_processes``
+requests to the coordination service. This is shown as the start of the thick
+colored regions. Later, all three processes receive a reply from the
+coordination service with ``0,1,2`` as the set of live processes.
+
+.. raw:: html
+
+
+
+
+
+In this simple execution, it is clear that ``jax.live_processes`` is behaving
+correctly. We can formalize this intuition with the following formal semantics.
+
+.. attention::
+
+ An execution is valid if whenever ``jax.live_processes`` returns a set ``P``
+ of live processes, there exists an instantaneous moment in time at which
+ every process in ``P`` was in the ``live_processes`` barrier and every other
+ process was dead. An implementation of ``live_processes`` is correct if
+ it only allows for valid executions.
+
+Later, we will amend these formal semantics to cover some subtle corner cases,
+but assume this simplified semantics for now.
+
+In the example above, ``live_processes`` returns ``0,1,2``. In the
+visualization below, we show that there does exist an instantaneous moment in
+time in which processes 0, 1, and 2 are all in the barrier and all other
+processes (there are none) are dead. The moment in time is drawn as a vertical
+red bar.
+
+.. raw:: html
+
+
+
+
+
+There is nothing special about the specific moment in time we chose in the
+visualization above. All that's important is that *there exists some* moment in
+time where all processes in ``P`` are in the barrier and all other processes are
+dead. There are many moments in time that satisfy this property, as shown
+below.
+
+.. raw:: html
+
+
+
+
+
+In the next example, processes 0 and 1 start, call ``jax.live_processes``, and
+receive ``0,1`` as a reply. Process 2 is dead throughout the execution.
+
+.. raw:: html
+
+
+
+
+
+This is a valid execution under our formal semantics because there exists a
+moment in time in which processes 0 and 1 are in the barrier and process 2 is
+dead.
+
+.. raw:: html
+
+
+
+
+
+In the following execution, process 0 calls ``jax.live_processes`` and receives
+a reply of ``0``. Process 1 calls ``jax.live_processes``, but dies before
+receiving a reply.
+
+.. raw:: html
+
+
+
+
+
+Is this a valid execution? Yes. There exists a moment in time at which process
+0 is in the barrier and process 1 is dead, as shown below. Even though process
+1 called ``jax.live_processes``, it is not guaranteed that process 1 will be
+included in the coordination service's response.
+
+For example, process 1's ``jax.live_processes`` request may have been dropped
+by the network and never received by the coordination service. So from the
+coordination service's perspective, process 1 is thoroughly dead and never even
+entered the ``live_processes`` barrier.
+
+.. raw:: html
+
+
+
+
+
+What about the same exact execution, except that process 0 now receives the
+reply ``0,1`` from the coordination service?
+
+.. raw:: html
+
+
+
+
+
+Again, this is a valid execution, as witnessed below. Intuitively, the
+coordination service could have received ``jax.live_processes`` requests from
+both processes 0 and 1 and sent the reply ``0,1`` to both. While this reply was
+in the network, process 1 failed. Thus, even though process 1 is dead when
+process 0 receives a reply, the execution is still valid.
+
+.. raw:: html
+
+
+
+
+
+This point bears repeating. If ``jax.live_processes`` returns a set ``P`` of
+processes, it does not mean that all processes in ``P`` are *currently* alive
+and all other processes are *currently* dead. It only means that *there existed
+a point in time* when this was true.
+
+In the following execution, process 1 calls ``jax.live_processes`` and fails.
+Later, process 0 starts, calls ``jax.live_processes``, and receives ``0,1`` as
+a reply.
+
+.. raw:: html
+
+
+
+
+
+Using the formal semantics described thus far, this is *not* a valid execution.
+There is never a point in time where process 0 and 1 are both alive. However,
+this *should* be a valid execution.
+
+The reason has to do with the unavoidable fact that in a distributed system, it
+is impossible to detect failures with 100% accuracy. If the coordination
+service hasn't received heartbeats from a process in a while, it considers the
+process dead. But, the coordination service cannot determine with 100%
+certainty when the process died or if the process is actually dead at all.
+Maybe the process died a long time ago, or maybe it died very recently, or
+maybe it is alive but on the other side of a network partition.
+
+Let's return to the execution above for a concrete example. Imagine the
+coordination service successfully received process 1's ``live_processes``
+request. Then, process 1 failed but the coordination service didn't detect the
+failure immediately. In the meantime, the coordination service received process
+0's ``live_processes`` request. At this point, the coordination service thought
+both processes were alive and saw that both processes were in the barrier, so
+it naturally returned ``0,1`` to both processes (though only process 0 received
+the reply because process 1 was dead).
+
+The coordination service thought process 1 was alive when it was dead. And
+sometimes the coordination service might think a process is dead when it is
+alive. Though not ideal, we need to accommodate executions like this because
+they are unavoidable.
+
+We amend our formal semantics and allow ourselves to move a failure either
+earlier or later in time, though we cannot move a failure past a different
+event from the same process. Intuitively, we can move a failure from when it
+actually happened to the point in time when the coordination service thought it
+happened. Continuing the example above, we can delay the failure of process 1
+to create a moment in time in which both processes 0 and 1 are in the barrier,
+witnessing the fact that the execution is valid.
+
+.. raw:: html
+
+
+
+
+
+Consider a similar execution below.
+
+.. raw:: html
+
+
+
+
+
+As is, there is no moment in time in which process 0 is alive and process 1 is
+dead. However, if we move the failure of process 1 leftwards, there is. How
+might such an execution arise? Imagine process 1 is partitioned from the
+coordination service. The coordination service doesn't receive any messages
+from process 1, including its heartbeats. This leads the coordination service
+to conclude that process 1 is dead, even though it isn't. Then, the
+coordination service receives process 0's ``live_processes`` request and
+responds with ``0``.
+
+.. raw:: html
+
+
+
+
+
+We cannot move a process failure past the process' other events, however. For
+example, the following execution is *invalid* because no matter where we move
+the failure of process 1, there is never a moment in time where both processes
+are in the barrier.
+
+.. raw:: html
+
+
+
+
+
+With these formal semantics, we can make sense of even complex executions. For
+example, consider the following execution.
+
+.. raw:: html
+
+
+
+
+
+
+After moving some process failures, we see the execution is valid.
+
+.. raw:: html
+
+
+
+
+
+The following execution, on the other hand, is invalid.
+
+.. raw:: html
+
+
+
+
+
+
+Atomicity
+^^^^^^^^^
+
+Equipped with ``jax.live_processes``, let's try to write some fault-tolerant
+multi-controller JAX code.
+
+.. code-block:: python
+
+ step = 0
+ while True:
+ # Get the devices on all live processes.
+ procs = jax.live_processes()
+ devices = [d for d in jax.devices() if d.process_index in procs]
+
+ # Shard array x over these devices.
+ mesh = jax.make_mesh((len(devices),), ("i",), devices=devices)
+ spec = jax.sharding.PartitionSpec("i")
+ sharding = jax.sharding.NamedSharding(mesh, spec)
+ x = jax.make_array_from_process_local_data(sharding, np.ones(1))
+
+ # Try to perform a jnp.sum.
+ try:
+ print(jnp.sum(x))
+ except:
+ # jnp.sum failed.
+ pass
+ else:
+ # jnp.sum succeeded.
+ step += 1
+
+The code repeatedly
+
+- calls ``jax.live_processes`` to learn which processes are alive,
+- computes the set of devices on the healthy processes,
+- shards an array across these healthy devices,
+- performs a ``jnp.sum`` (i.e. AllReduce) on the array, and
+- increments ``step`` if the ``jnp.sum`` succeeds.
+
+This code *looks* correct, but it has a very subtle bug. Assume the ``jnp.sum``
+is being performed across a set of processes ``P``. If one (or more) of the
+processes in ``P`` fails during the execution of the ``jnp.sum``, then
+``jnp.sum`` can behave differently on different processes. Some processes in
+``P`` might see ``jnp.sum`` return the correct result. Other processes might
+see ``jnp.sum`` raise an exception. Others might see ``jnp.sum`` return an
+incorrect result.
+
+.. warning::
+
+ If a process fails during a collective operation, the operation may behave
+ differently on different processes.
+
+This means that the processes executing the code example above might diverge.
+Some might increment ``step``, and some might not. In the trivial code example
+above, this divergence is benign, but in a real program, the divergence would
+likely lead to a crash, a deadlock, or garbage outputs. For example, if a
+multi-controller JAX program is training a model with data parallelism and
+starts to diverge, some processes might roll back their model weights to a
+previous checkpoint while others continue training, leading to a
+"franken-model" where nobody agrees on what the model weights are supposed to
+be.
+
+To write fault-tolerant code that does not diverge, we want **atomicity**. When
+executing a block of code (like the ``jnp.sum`` above), we either want *every*
+process to run the code successfully, or *every* process to learn that the code
+failed to execute successfully. We don't want some processes succeeding and
+others failing.
+
+Thankfully, we can achieve atomicity with a very simple trick: call
+``live_processes`` twice, once before a code block and once after. If all the
+processes that were alive before the block are also alive after the block, then
+the code block executed successfully on all live processes. On the other hand,
+if any process died, then all remaining processes can agree the code block
+failed to execute properly. Here's a sketch of what that might look like:
+
+.. code-block:: python
+
+ # Get the set of live processes before the code block.
+ procs_before = jax.live_processes()
+
+ # Execute the code block.
+ ...
+
+ # Get the set of live processes after the code block
+ procs_after = jax.live_processes()
+ if procs_before == procs_after:
+ # The code block executed successfully on all processes in
+ # procs_before.
+ pass
+ else:
+ # The code block did not execute successfully. All processes will
+ # agree it failed.
+ pass
+
+The code above should give you a rough idea of how to use two calls to
+``live_processes`` to achieve atomicity, but there are still a handful of small
+issues we need to address before it is fully correct. For example,
+
+- What if the code block throws an exception? We need to catch the exception
+  and still call ``live_processes`` the second time and then re-raise the
+ exception.
+- What if a process fails after the first call to ``live_processes`` and
+ recovers before the second call? Wouldn't the code block fail but the
+ processes before and after be the same? Every time a process starts, it
+ generates a random **incarnation id**. In addition to checking that the set
+ of processes hasn't changed, we also check that their incarnation ids haven't
+ changed.
+- What if a process recovers and its first call to ``live_processes`` matches
+ up with a different process' second call to ``live_processes``? Couldn't this
+ lead to a deadlock? Yes. We can avoid the problem by only calling
+ ``live_processes`` at a single program point. We can be clever and use a
+ single call to ``live_processes`` for two purposes. It can be used to check
+ that the set of processes hasn't changed since the previous call to
+ ``live_processes``, and it can be used to generate the set of live processes
+  that should be used the next time the atomic code block is executed, as
+  sketched below.
+
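+The sketch below illustrates the single-call-site pattern from the last
+bullet. The ``live_processes_and_incarnations`` function is hypothetical; it
+stands in for a barrier-style ``live_processes`` call that also returns each
+process' incarnation id.
+
+.. code-block:: python
+
+   previous = None
+   while True:
+     # The single call site. Its result both closes the previous atomic block
+     # (did the same processes, with the same incarnation ids, survive?) and
+     # names the participants for the next block.
+     current = live_processes_and_incarnations()  # hypothetical
+     if previous is not None:
+       if current == previous:
+         ...  # the previous block committed on every surviving process
+       else:
+         ...  # the previous block failed; every surviving process agrees
+     previous = current
+     try:
+       ...  # the atomic code block, run across the processes in `current`
+     except Exception:
+       pass  # exceptions from aborted collectives are expected here
+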
+All these details are handled and abstracted away by the ``jax.live_devices``
+API introduced in :ref:`part1`. ``jax.live_devices`` is a context manager that
+guarantees the atomic execution of a block of code. In the code snippet below,
+``devices`` is a list of the devices on all live processes. The code block
+``A`` will execute atomically across these processes. That is, either every
+process will see the code raise an exception (branch ``B``) or every process
+will see the code succeed (branch ``C``).
+
+.. code-block:: python
+
+ try:
+ with live_devices() as devices:
+ pass # A
+ except Exception as e:
+ pass # B
+ else:
+ pass # C
+
+Cancelling Collectives
+^^^^^^^^^^^^^^^^^^^^^^
+
+As mentioned in :ref:`canceling_collectives`, if a process participating in a
+collective fails, then the other participating processes get stuck forever. We
+need to explicitly cancel these collectives to allow the alive participants to
+make progress. While the ``live_devices`` API is supported on all JAX backends
+(i.e. CPU, GPU, TPU), cancelling collectives is only supported by the GPU
+backend. Here, we briefly explain some of the implementation details behind
+collective cancelling.
+
+The GPU backend implements collectives using `NCCL`_, NVIDIA's collective
+communication library. When a set of processes wants to perform a collective,
+they form a **NCCL communicator**. Processes can then repeatedly perform
+collectives using this communicator. Creating a communicator is expensive---it
+requires network communication---so the JAX backend caches communicators keyed
+by the set of participating processes and their incarnation ids.
+
+Internally, a JAX client polls the coordination service for the current status
+of every process. If a client ever detects that a process is dead or has
+restarted with a new incarnation id, then the client aborts all communicators
+with the failed incarnation id in its cache key.
+
+.. _asynchronous dispatch: https://docs.jax.dev/en/latest/async_dispatch.html
+.. _linearizability: https://cs.brown.edu/~mph/HerlihyW90/p463-herlihy.pdf
+.. _many things in distributed systems: https://en.wikipedia.org/wiki/Fallacies_of_distributed_computing
+.. _multi-controller JAX: https://docs.jax.dev/en/latest/multi_process.html
+.. _NCCL: https://developer.nvidia.com/nccl
+.. _reference: https://docs.jax.dev/en/latest/config_options.html#jax_enable_recoverability
+.. _share fate: https://en.wikipedia.org/wiki/Fate-sharing
diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md
index 929956ad6e31..1b6463f65024 100644
--- a/docs/gradient-checkpointing.md
+++ b/docs/gradient-checkpointing.md
@@ -19,7 +19,7 @@ kernelspec:
In this tutorial, you will learn how to control JAX automatic differentiation's saved values using {func}`jax.checkpoint` (also known as {func}`jax.remat`), which can be particularly helpful in machine learning.
-If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has {ref}`automatic-differentiation` and {ref}`advanced-autodiff` tutorials.
+If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has an {ref}`automatic-differentiation` tutorial and several {ref}`Advanced automatic differentiation guides `.
**TL;DR** Use the {func}`jax.checkpoint` decorator (aliased as {func}`jax.remat`) with {func}`jax.grad` to control which intermediates are saved on the forward pass versus the recomputed intermediates on the backward pass, trading off memory and FLOPs.
@@ -144,7 +144,7 @@ print_fwd_bwd(f3, W1, W2, W3, x)
### Let's think step by step
-**Note:** It may help to check out the {ref}`advanced-autodiff` tutorial prior to continuing here.
+**Note:** It may help to check out the {ref}`"Advanced automatic differentiation" guides ` prior to continuing here.
#### `jax.checkpoint` fundamentals
diff --git a/docs/higher-order.md b/docs/higher-order.md
new file mode 100644
index 000000000000..e835d3af82cc
--- /dev/null
+++ b/docs/higher-order.md
@@ -0,0 +1,336 @@
+---
+jupytext:
+ formats: md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+# Higher-order derivatives
+
+## Taking gradients (part 2)
+
+JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.
+
+The single-variable case was covered in the {ref}`automatic-differentiation` tutorial, where the example showed how to use {func}`jax.grad` to compute the derivative of $f(x) = x^3 + 2x^2 - 3x + 1$.
+
+In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to:
+
+$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$
+
+The Hessian of a real-valued function of several variables, $f: \mathbb R^n\to\mathbb R$, can be identified with the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of its gradient.
+
+JAX provides two transformations for computing the Jacobian of a function, {func}`jax.jacfwd` and {func}`jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – refer to the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY).
+
+```{code-cell}
+import jax
+
+def hessian(f):
+ return jax.jacfwd(jax.grad(f))
+```
+
+Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}$.
+
+If $i=j$, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2$; otherwise, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0$. In other words, the Hessian of this $f$ is twice the identity matrix.
+
+```{code-cell}
+import jax.numpy as jnp
+
+def f(x):
+ return jnp.dot(x, x)
+
+hessian(f)(jnp.array([1., 2., 3.]))
+```
+
+## Higher-order derivative applications
+
+Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier:
+
+```python
+def meta_loss_fn(params, data):
+ """Computes the loss after one step of SGD."""
+ grads = jax.grad(loss_fn)(params, data)
+ return loss_fn(params - lr * grads, data)
+
+meta_grads = jax.grad(meta_loss_fn)(params, data)
+```
+
+(stopping-gradients)=
+### Stopping gradients
+
+Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph.
+
+Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1})$ in a state $s_{t-1}$ is parameterised by a linear function.
+
+```{code-cell}
+# Value function and initial parameters
+value_fn = lambda theta, state: jnp.dot(theta, state)
+theta = jnp.array([0.1, -0.1, 0.])
+```
+
+Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which you observed the reward $r_t$:
+
+```{code-cell}
+# An example transition.
+s_tm1 = jnp.array([1., 2., -1.])
+r_t = jnp.array(1.)
+s_t = jnp.array([2., 1., 0.])
+```
+
+The TD(0) update to the network parameters is:
+
+$$
+\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1})
+$$
+
+This update is not the gradient of any loss function.
+
+However, it can be **written** as the gradient of the pseudo loss function
+
+$$
+L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2
+$$
+
+if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored.
+
+How can you implement this in JAX? If you write the pseudo loss naively, you get:
+
+```{code-cell}
+def td_loss(theta, s_tm1, r_t, s_t):
+ v_tm1 = value_fn(theta, s_tm1)
+ target = r_t + value_fn(theta, s_t)
+ return -0.5 * ((target - v_tm1) ** 2)
+
+td_update = jax.grad(td_loss)
+delta_theta = td_update(theta, s_tm1, r_t, s_t)
+
+delta_theta
+```
+
+But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$.
+
+You can use {func}`jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$:
+
+```{code-cell}
+def td_loss(theta, s_tm1, r_t, s_t):
+ v_tm1 = value_fn(theta, s_tm1)
+ target = r_t + value_fn(theta, s_t)
+ return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)
+
+td_update = jax.grad(td_loss)
+delta_theta = td_update(theta, s_tm1, r_t, s_t)
+
+delta_theta
+```
+
+This will treat `target` as if it did **not** depend on the parameters $\theta$ and compute the correct update to the parameters.
+
+Now, let's also calculate $\Delta \theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using {func}`jax.grad` and your knowledge so far. Here's our solution:
+
+```{code-cell}
+s_grad = jax.grad(value_fn)(theta, s_tm1)
+delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad
+
+delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`
+```
+
+`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss).
+
+### Straight-through estimator using `stop_gradient`
+
+The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`:
+
+```{code-cell}
+def f(x):
+ return jnp.round(x) # non-differentiable
+
+def straight_through_f(x):
+ # Create an exactly-zero expression with Sterbenz lemma that has
+ # an exactly-one gradient.
+ zero = x - jax.lax.stop_gradient(x)
+ return zero + jax.lax.stop_gradient(f(x))
+
+print("f(x): ", f(3.2))
+print("straight_through_f(x):", straight_through_f(3.2))
+
+print("grad(f)(x):", jax.grad(f)(3.2))
+print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
+```
+
+### Per-example gradients
+
+While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.
+
+For instance, this is needed to prioritize data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis.
+
+In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients, are typically very inefficient.
+
+In JAX, you can define the code to compute the gradient per-sample in an easy but efficient way.
+
+Just combine the {func}`jax.jit`, {func}`jax.vmap` and {func}`jax.grad` transformations together:
+
+```{code-cell}
+perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
+
+# Test it:
+batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
+batched_r_t = jnp.stack([r_t, r_t])
+batched_s_t = jnp.stack([s_t, s_t])
+
+perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
+```
+
+Let's go through this one transformation at a time.
+
+First, you apply {func}`jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs:
+
+```{code-cell}
+dtdloss_dtheta = jax.grad(td_loss)
+
+dtdloss_dtheta(theta, s_tm1, r_t, s_t)
+```
+
+This function computes one row of the array above.
+
+Then, you vectorise this function using {func}`jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, you produce a batch of outputs — each output in the batch corresponds to the gradient for the corresponding member of the input batch.
+
+```{code-cell}
+almost_perex_grads = jax.vmap(dtdloss_dtheta)
+
+batched_theta = jnp.stack([theta, theta])
+almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
+```
+
+This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the {func}`jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want:
+
+```{code-cell}
+inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
+
+inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
+```
+
+This does what we want, but is slower than it has to be. Now, you wrap the whole thing in a {func}`jax.jit` to get the compiled, efficient version of the same function:
+
+```{code-cell}
+perex_grads = jax.jit(inefficient_perex_grads)
+
+perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
+```
+
+```{code-cell}
+%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
+%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
+```
+
+### Hessian-vector products with `jax.grad`-of-`jax.grad`
+
+One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
+
+A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)).
+
+For a scalar-valued function $f : \mathbb{R}^n \to \mathbb{R}$ with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point $x \in \mathbb{R}^n$ is written as $\partial^2 f(x)$. A Hessian-vector product function is then able to evaluate
+
+$\qquad v \mapsto \partial^2 f(x) \cdot v$
+
+for any $v \in \mathbb{R}^n$.
+
+The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.
+
+Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity:
+
+$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$,
+
+where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient.
+
+In JAX code, you can just write this:
+
+```{code-cell}
+def hvp(f, x, v):
+ return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)
+```
+
+This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.
+
+You will check this implementation a few cells down, once you learn how to compute dense Hessian matrices. You'll also write an even better version that uses both forward-mode and reverse-mode.
+
+### Jacobians and Hessians using `jax.jacfwd` and `jax.jacrev`
+
+You can compute full Jacobian matrices using the {func}`jax.jacfwd` and {func}`jax.jacrev` functions:
+
+```{code-cell}
+from jax import jacfwd, jacrev
+
+# Define a sigmoid function.
+def sigmoid(x):
+ return 0.5 * (jnp.tanh(x / 2) + 1)
+
+# Outputs probability of a label being true.
+def predict(W, b, inputs):
+ return sigmoid(jnp.dot(inputs, W) + b)
+
+# Build a toy dataset.
+inputs = jnp.array([[0.52, 1.12, 0.77],
+ [0.88, -1.08, 0.15],
+ [0.52, 0.06, -1.30],
+ [0.74, -2.49, 1.39]])
+
+# Initialize random model coefficients
+key = jax.random.key(0)
+key, W_key, b_key = jax.random.split(key, 3)
+W = jax.random.normal(W_key, (3,))
+b = jax.random.normal(b_key, ())
+
+# Isolate the function from the weight matrix to the predictions
+f = lambda W: predict(W, b, inputs)
+
+J = jacfwd(f)(W)
+print("jacfwd result, with shape", J.shape)
+print(J)
+
+J = jacrev(f)(W)
+print("jacrev result, with shape", J.shape)
+print(J)
+```
+
+These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`.
+
+You can also use {func}`jax.jacfwd` and {func}`jax.jacrev` with container types:
+
+```{code-cell}
+def predict_dict(params, inputs):
+ return predict(params['W'], params['b'], inputs)
+
+J_dict = jax.jacrev(predict_dict)({'W': W, 'b': b}, inputs)
+for k, v in J_dict.items():
+ print("Jacobian from {} to logits is".format(k))
+ print(v)
+```
+
+For more details on forward- and reverse-mode, as well as how to implement {func}`jax.jacfwd` and {func}`jax.jacrev` as efficiently as possible, read on!
+
+Using a composition of two of these functions gives us a way to compute dense Hessian matrices:
+
+```{code-cell}
+def hessian(f):
+ return jax.jacfwd(jax.jacrev(f))
+
+H = hessian(f)(W)
+print("hessian, with shape", H.shape)
+print(H)
+```
+
+This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ you expect to get the shapes:
+
+* $f(x) \in \mathbb{R}^m$, the value of $f$ at $x$,
+* $\partial f(x) \in \mathbb{R}^{m \times n}$, the Jacobian matrix at $x$,
+* $\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$, the Hessian at $x$,
+
+and so on.
+
+To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function with a wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out.
diff --git a/docs/jacobian-vector-products.md b/docs/jacobian-vector-products.md
new file mode 100644
index 000000000000..bbc678d02d1c
--- /dev/null
+++ b/docs/jacobian-vector-products.md
@@ -0,0 +1,358 @@
+---
+jupytext:
+ formats: md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ name: python3
+---
+
+(advanced-guides-jvp-vjp)=
+# Forward- and reverse-mode autodiff in JAX
+
+## Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff)
+
+JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background.
+
+### JVPs in math
+
+Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian of $f$ evaluated at an input point $x \in \mathbb{R}^n$, denoted $\partial f(x)$, is often thought of as a matrix in $\mathbb{R}^{m \times n}$:
+
+$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$.
+
+But you can also think of $\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\mathbb{R}^m$):
+
+$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$.
+
+This map is called the [pushforward map](https://en.wikipedia.org/wiki/Pushforward_(differential)) of $f$ at $x$. The Jacobian matrix is just the matrix for this linear map on a standard basis.
+
+If you don't commit to one specific input point $x$, then you can think of the function $\partial f$ as first taking an input point and returning the Jacobian linear map at that input point:
+
+$\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$.
+
+In particular, you can uncurry things so that given input point $x \in \mathbb{R}^n$ and a tangent vector $v \in \mathbb{R}^n$, you get back an output tangent vector in $\mathbb{R}^m$. We call that mapping, from $(x, v)$ pairs to output tangent vectors, the *Jacobian-vector product*, and write it as:
+
+$\qquad (x, v) \mapsto \partial f(x) v$
+
+### JVPs in JAX code
+
+Back in Python code, JAX's {func}`jax.jvp` function models this transformation. Given a Python function that evaluates $f$, JAX's {func}`jax.jvp` is a way to get a Python function for evaluating $(x, v) \mapsto (f(x), \partial f(x) v)$.
+
+```{code-cell}
+import jax
+import jax.numpy as jnp
+
+key = jax.random.key(0)
+
+# Initialize random model coefficients
+key, W_key, b_key = jax.random.split(key, 3)
+W = jax.random.normal(W_key, (3,))
+b = jax.random.normal(b_key, ())
+
+# Define a sigmoid function.
+def sigmoid(x):
+ return 0.5 * (jnp.tanh(x / 2) + 1)
+
+# Outputs probability of a label being true.
+def predict(W, b, inputs):
+ return sigmoid(jnp.dot(inputs, W) + b)
+
+# Build a toy dataset.
+inputs = jnp.array([[0.52, 1.12, 0.77],
+ [0.88, -1.08, 0.15],
+ [0.52, 0.06, -1.30],
+ [0.74, -2.49, 1.39]])
+
+# Isolate the function from the weight matrix to the predictions
+f = lambda W: predict(W, b, inputs)
+
+key, subkey = jax.random.split(key)
+v = jax.random.normal(subkey, W.shape)
+
+# Push forward the vector `v` along `f` evaluated at `W`
+y, u = jax.jvp(f, (W,), (v,))
+```
+
+In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), you could write:
+
+```haskell
+jvp :: (a -> b) -> a -> T a -> (b, T b)
+```
+
+where `T a` is used to denote the type of the tangent space for `a`.
+
+In other words, `jvp` takes as arguments a function of type `a -> b`, a value of type `a`, and a tangent vector value of type `T a`. It gives back a pair consisting of a value of type `b` and an output tangent vector of type `T b`.
+
+The `jvp`-transformed function is evaluated much like the original function, but paired up with each primal value of type `a` it pushes along tangent values of type `T a`. For each primitive numerical operation that the original function would have applied, the `jvp`-transformed function executes a "JVP rule" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values.
+
+That evaluation strategy has some immediate implications about computational complexity. Since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example `sin(x)`; one unit for linearizing, like `cos(x)`; and one unit for applying the linearized function to a vector, like `cos_x * v`). Put another way, for a fixed primal point $x$, we can evaluate $v \mapsto \partial f(x) \cdot v$ for about the same marginal cost as evaluating $f$.
+
+That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning?
+
+To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians.
+
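+As a rough sketch of that column-at-a-time strategy (illustrative only, and not how {func}`jax.jacfwd` is actually implemented), you could push one-hot tangent vectors through {func}`jax.jvp`, reusing the `f` and `W` defined above:
+
+```python
+def jacobian_by_columns(f, x):
+  columns = []
+  for i in range(x.size):
+    one_hot = jnp.zeros_like(x).at[i].set(1.0)
+    # One JVP per input dimension reveals one column of the Jacobian.
+    _, column = jax.jvp(f, (x,), (one_hot,))
+    columns.append(column)
+  return jnp.stack(columns, axis=-1)
+
+# For the `f` and `W` above, this matches jax.jacfwd(f)(W).
+print(jacobian_by_columns(f, W))
+```
+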
+If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.
+
+To do better for functions like this, you just need to use reverse-mode.
+
+## Vector-Jacobian products (VJPs, a.k.a. reverse-mode autodiff)
+
+Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.
+
+### VJPs in math
+
+Let's again consider a function $f : \mathbb{R}^n \to \mathbb{R}^m$.
+Starting from our notation for JVPs, the notation for VJPs is pretty simple:
+
+$\qquad (x, v) \mapsto v \partial f(x)$,
+
+where $v$ is an element of the cotangent space of $f$ at $x$ (isomorphic to another copy of $\mathbb{R}^m$). When being rigorous, we should think of $v$ as a linear map $v : \mathbb{R}^m \to \mathbb{R}$, and when we write $v \partial f(x)$ we mean function composition $v \circ \partial f(x)$, where the types work out because $\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. But in the common case we can identify $v$ with a vector in $\mathbb{R}^m$ and use the two almost interchangeably, just like we might sometimes flip between "column vectors" and "row vectors" without much comment.
+
+With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP:
+
+$\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v$.
+
+For a given point $x$, we can write the signature as
+
+$\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n$.
+
+The corresponding map on cotangent spaces is often called the [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry))
+of $f$ at $x$. The key for our purposes is that it goes from something that looks like the output of $f$ to something that looks like the input of $f$, just like we might expect from a transposed linear function.
+
+### VJPs in JAX code
+
+Switching from math back to Python, the JAX function `vjp` can take a Python function for evaluating $f$ and give us back a Python function for evaluating the VJP $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$.
+
+```{code-cell}
+from jax import vjp
+
+# Isolate the function from the weight matrix to the predictions
+f = lambda W: predict(W, b, inputs)
+
+y, vjp_fun = vjp(f, W)
+
+key, subkey = jax.random.split(key)
+u = jax.random.normal(subkey, y.shape)
+
+# Pull back the covector `u` along `f` evaluated at `W`
+v = vjp_fun(u)
+```
+
+In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), we could write
+
+```haskell
+vjp :: (a -> b) -> a -> (b, CT b -> CT a)
+```
+
+where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`.
+
+This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
+
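+To make that concrete, here's a hedged sketch (the scalar-valued `loss` helper below is illustrative, not part of the original example): seeding the VJP with a cotangent of ones recovers exactly what {func}`jax.grad` computes.
+
+```{code-cell}
+# Gradient of a scalar-valued function via a single VJP call.
+loss = lambda W: jnp.sum(predict(W, b, inputs)**2)
+
+loss_value, loss_vjp = jax.vjp(loss, W)
+grad_via_vjp, = loss_vjp(jnp.ones_like(loss_value))
+
+print(jnp.allclose(grad_via_vjp, jax.grad(loss)(W)))
+```
+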
+There's a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).
+
+For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/).
+
+## Vector-valued gradients with VJPs
+
+If you're interested in taking vector-valued gradients (like `tf.gradients`, which implicitly sums over the output), you can seed a VJP with a vector of ones:
+
+```{code-cell}
+def vgrad(f, x):
+ y, vjp_fn = jax.vjp(f, x)
+ return vjp_fn(jnp.ones(y.shape))[0]
+
+print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
+```
+
+## Hessian-vector products using both forward- and reverse-mode
+
+In a previous section, you implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives):
+
+```{code-cell}
+def hvp(f, x, v):
+ return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)
+```
+
+That's efficient, but you can do even better and save some memory by using forward-mode together with reverse-mode.
+
+Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}$ to differentiate, a point $x \in \mathbb{R}^n$ at which to linearize the function, and a vector $v \in \mathbb{R}^n$, the Hessian-vector product function we want is:
+
+$(x, v) \mapsto \partial^2 f(x) v$
+
+Consider the helper function $g : \mathbb{R}^n \to \mathbb{R}^n$ defined to be the derivative (or gradient) of $f$, namely $g(x) = \partial f(x)$. All you need is its JVP, since that will give you:
+
+$(x, v) \mapsto \partial g(x) v = \partial^2 f(x) v$.
+
+We can translate that almost directly into code:
+
+```{code-cell}
+# forward-over-reverse
+def hvp(f, primals, tangents):
+ return jax.jvp(jax.grad(f), primals, tangents)[1]
+```
+
+Even better, since you didn't have to call {func}`jnp.vdot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on {mod}`jax.numpy`.
+
+Here's an example of how to use it:
+
+```{code-cell}
+def f(X):
+ return jnp.sum(jnp.tanh(X)**2)
+
+key, subkey1, subkey2 = jax.random.split(key, 3)
+X = jax.random.normal(subkey1, (30, 40))
+V = jax.random.normal(subkey2, (30, 40))
+
+def hessian(f):
+ return jax.jacfwd(jax.jacrev(f))
+
+ans1 = hvp(f, (X,), (V,))
+ans2 = jnp.tensordot(hessian(f)(X), V, 2)
+
+print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
+```
+
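+As a quick illustration of the container-type claim above, here's a hedged sketch where the input is a dict pytree (the `tree_f` and `params` names are illustrative, not part of the original example):
+
+```{code-cell}
+def tree_f(params):
+  return jnp.sum(jnp.tanh(params['a'])**2) + jnp.sum(params['b']**2)
+
+params = {'a': jnp.arange(3.0), 'b': jnp.ones(2)}
+tangents = {'a': jnp.ones(3), 'b': jnp.full(2, 0.5)}
+
+# The same `hvp` works unchanged on nested containers.
+print(hvp(tree_f, (params,), (tangents,)))
+```
+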
+Another way you might consider writing this is using reverse-over-forward:
+
+```{code-cell}
+# Reverse-over-forward
+def hvp_revfwd(f, primals, tangents):
+ g = lambda primals: jax.jvp(f, primals, tangents)[1]
+ return jax.grad(g)(primals)
+```
+
+That's not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best:
+
+```{code-cell}
+# Reverse-over-reverse, only works for single arguments
+def hvp_revrev(f, primals, tangents):
+ x, = primals
+ v, = tangents
+ return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x)
+
+
+print("Forward over reverse")
+%timeit -n10 -r3 hvp(f, (X,), (V,))
+print("Reverse over forward")
+%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
+print("Reverse over reverse")
+%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
+
+print("Naive full Hessian materialization")
+%timeit -n10 -r3 jnp.tensordot(jax.hessian(f)(X), V, 2)
+```
+
+## Composing VJPs, JVPs, and `jax.vmap`
+
+## Jacobian-Matrix and Matrix-Jacobian products
+
+Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:
+
+```{code-cell}
+# Isolate the function from the weight matrix to the predictions
+f = lambda W: predict(W, b, inputs)
+
+# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
+# First, use a list comprehension to loop over rows in the matrix M.
+def loop_mjp(f, x, M):
+ y, vjp_fun = jax.vjp(f, x)
+ return jnp.vstack([vjp_fun(mi) for mi in M])
+
+# Now, use vmap to build a computation that does a single fast matrix-matrix
+# multiply, rather than an outer loop over vector-matrix multiplies.
+def vmap_mjp(f, x, M):
+ y, vjp_fun = jax.vjp(f, x)
+ outs, = jax.vmap(vjp_fun)(M)
+ return outs
+
+key = jax.random.key(0)
+num_covecs = 128
+U = jax.random.normal(key, (num_covecs,) + y.shape)
+
+loop_vs = loop_mjp(f, W, M=U)
+print('Non-vmapped Matrix-Jacobian product')
+%timeit -n10 -r3 loop_mjp(f, W, M=U)
+
+print('\nVmapped Matrix-Jacobian product')
+vmap_vs = vmap_mjp(f, W, M=U)
+%timeit -n10 -r3 vmap_mjp(f, W, M=U)
+
+assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
+```
+
+```{code-cell}
+def loop_jmp(f, W, M):
+ # jvp immediately returns the primal and tangent values as a tuple,
+ # so we'll compute and select the tangents in a list comprehension
+ return jnp.vstack([jax.jvp(f, (W,), (mi,))[1] for mi in M])
+
+def vmap_jmp(f, W, M):
+ _jvp = lambda s: jax.jvp(f, (W,), (s,))[1]
+ return jax.vmap(_jvp)(M)
+num_vecs = 128
+S = jax.random.normal(key, (num_vecs,) + W.shape)
+
+loop_vs = loop_jmp(f, W, M=S)
+print('Non-vmapped Jacobian-Matrix product')
+%timeit -n10 -r3 loop_jmp(f, W, M=S)
+vmap_vs = vmap_jmp(f, W, M=S)
+print('\nVmapped Jacobian-Matrix product')
+%timeit -n10 -r3 vmap_jmp(f, W, M=S)
+
+assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
+```
+
+## The implementation of `jax.jacfwd` and `jax.jacrev`
+
+Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write {func}`jax.jacfwd` and {func}`jax.jacrev`. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once.
+
+```{code-cell}
+from jax import jacrev as builtin_jacrev
+
+def our_jacrev(f):
+ def jacfun(x):
+ y, vjp_fun = jax.vjp(f, x)
+ # Use vmap to do a matrix-Jacobian product.
+ # Here, the matrix is the Euclidean basis, so we get all
+ # entries in the Jacobian at once.
+ J, = jax.vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
+ return J
+ return jacfun
+
+assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
+```
+
+```{code-cell}
+from jax import jacfwd as builtin_jacfwd
+
+def our_jacfwd(f):
+ def jacfun(x):
+ _jvp = lambda s: jax.jvp(f, (x,), (s,))[1]
+ Jt = jax.vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
+ return jnp.transpose(Jt)
+ return jacfun
+
+assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
+```
+
+Interestingly, the [Autograd](https://github.com/hips/autograd) library couldn't do this. The [implementation](https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) of reverse-mode `jacobian` in Autograd had to pull back one vector at a time with an outer-loop `map`. Pushing one vector at a time through the computation is much less efficient than batching it all together with {func}`jax.vmap`.
+
+Another thing that Autograd couldn't do is {func}`jax.jit`. Interestingly, no matter how much Python dynamism you use in the function to be differentiated, you can always apply {func}`jax.jit` to the linear part of the computation. For example:
+
+```{code-cell}
+def f(x):
+ try:
+ if x < 3:
+ return 2 * x ** 3
+ else:
+ raise ValueError
+ except ValueError:
+ return jnp.pi * x
+
+y, f_vjp = jax.vjp(f, 4.)
+print(jax.jit(f_vjp)(1.))
+```
diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md
index 43dea9bced1a..92579e06ab16 100644
--- a/docs/jax-primitives.md
+++ b/docs/jax-primitives.md
@@ -321,7 +321,7 @@ assert api.jit(lambda x, y: square_add_prim(x, y),
### Forward differentiation
-JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in {ref}`advanced-autodiff`).
+JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in {ref}`advanced-guides-jvp-vjp`).
If you attempt to compute the `jvp` function, you'll get an error because you have not yet told JAX how to differentiate the `multiply_add` primitive.
diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst
index 0a639f1e9d07..bb3340f248ee 100644
--- a/docs/jax.experimental.pallas.mosaic_gpu.rst
+++ b/docs/jax.experimental.pallas.mosaic_gpu.rst
@@ -14,6 +14,7 @@ Classes
CompilerParams
MemorySpace
Layout
+ SemaphoreType
SwizzleTransform
TilingTransform
TransposeTransform
diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst
index 12d5129a9d34..17242e1551f5 100644
--- a/docs/jax.experimental.pallas.rst
+++ b/docs/jax.experimental.pallas.rst
@@ -47,6 +47,7 @@ Functions
debug_check
debug_print
dot
+ get_global
loop
max_contiguous
multiple_of
diff --git a/docs/jax.experimental.sparse.rst b/docs/jax.experimental.sparse.rst
index 37ef8ae43d67..f021cce3452f 100644
--- a/docs/jax.experimental.sparse.rst
+++ b/docs/jax.experimental.sparse.rst
@@ -1,12 +1,6 @@
``jax.experimental.sparse`` module
==================================
-.. note::
-
- The methods in ``jax.experimental.sparse`` are experimental reference
- implementations, and not recommended for use in performance-critical
- applications.
-
.. automodule:: jax.experimental.sparse
.. currentmodule:: jax.experimental.sparse
diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst
index aad146eb71c5..6cf14389adcd 100644
--- a/docs/jax.scipy.rst
+++ b/docs/jax.scipy.rst
@@ -463,6 +463,7 @@ jax.scipy.stats.poisson
logpmf
pmf
cdf
+ entropy
jax.scipy.stats.t
~~~~~~~~~~~~~~~~~
diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst
index 12760d62ddb3..0146398cb12d 100644
--- a/docs/jax.sharding.rst
+++ b/docs/jax.sharding.rst
@@ -16,9 +16,6 @@ Classes
.. autoclass:: NamedSharding
:members:
:show-inheritance:
-.. autoclass:: PmapSharding
- :members:
- :show-inheritance:
.. autoclass:: PartitionSpec
:members:
.. autoclass:: Mesh
diff --git a/docs/jep/263-prng.md b/docs/jep/263-prng.md
index 7ef10ae0e9c4..ff58d1b7b94e 100644
--- a/docs/jep/263-prng.md
+++ b/docs/jep/263-prng.md
@@ -12,7 +12,7 @@ We want a PRNG design that
As a corollary of these we believe the design should be functional. Another corollary is that, at least given current hardware constraints, we’re going to do the PRNG in software.
> TLDR
-> **JAX PRNG = [Threefry counter PRNG](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + a functional array-oriented [splitting model](https://dl.acm.org/citation.cfm?id=2503784)**
+> **JAX PRNG = [Threefry counter PRNG](https://thesalmons.org/john/random123/papers/random123sc11.pdf) + a functional array-oriented [splitting model](https://dl.acm.org/doi/10.1145/2503778.2503784)**
## Contents
* [Three programming models and toy example programs](#three-programming-models-and-toy-example-programs)
@@ -79,7 +79,7 @@ Explicit threading is inconvenient for the programmer. But worse, it hasn’t ac
In short, making the code functional by explicitly threading state isn’t enough to achieve our expressiveness (#1) and performance (#5, #6) goals.
-The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use **functional [splittable](https://dl.acm.org/citation.cfm?id=2503784) PRNGs**. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like [multistreams](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)).
+The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use **functional [splittable](https://dl.acm.org/doi/10.1145/2503778.2503784) PRNGs**. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like [multistreams](https://thesalmons.org/john/random123/papers/random123sc11.pdf)).
```python
def foo(rng_1):
@@ -105,7 +105,7 @@ The example doesn’t show it, but as a consequence of the choice (2) the only w
## Design
-We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [splittable PRNGs](https://dl.acm.org/citation.cfm?id=2503784): that is, splitting is a way to generate two new keys from an existing one.
+We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](https://thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [splittable PRNGs](https://dl.acm.org/doi/10.1145/2503778.2503784): that is, splitting is a way to generate two new keys from an existing one.
```haskell
type Sample = Int256
diff --git a/docs/migrate_pmap.md b/docs/migrate_pmap.md
index d48aa8fb28cc..be0080577c7e 100644
--- a/docs/migrate_pmap.md
+++ b/docs/migrate_pmap.md
@@ -92,6 +92,49 @@ Mesh('y': 4, axis_types=(Auto,))
## Performance implications
+### `int` indexing into sharded arrays
+
+The new implementation of `jax.pmap` uses `NamedSharding` instead of the legacy
+`PmapSharding`. We've observed a common pattern with the old `jax.pmap` where
+users shard stacked copies of an array in order to replicate it (e.g., via
+`jax.device_put_replicated`). These "sharded-but-really-replicated" arrays
+incur unnecessary communication overhead when indexed with an `int` (e.g.,
+`x[0]`) because JAX does not know the arrays are actually replicated. For a more
+thorough discussion, please see [Appendix A](#appendix-a).
+
+#### Option 1: Prevent unintended sharding (recommended)
+Avoid creating the leading sharded dimension entirely.
+
+- Use `jax.pmap`'s `out_axes=None` for outputs that should remain replicated.
+The output will be fully replicated (e.g., `P(None, None)`), making access
+cheap. A sketch of both recommendations follows this list.
+- For inputs: When using `jax.device_put`, specify `jax.P()` (fully replicated)
+in the partition spec rather than relying on utilities that stack and shard.
+(Note: `jax.device_put_replicated` and `jax.device_put_sharded` are deprecated
+because they confusingly produce sharded arrays rather than replicated ones).
+
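+Below is a hedged sketch of Option 1 (the `mesh`, `params`, and `mean_x` names
+are illustrative, not from this guide): keep replicated values replicated on the
+way in with `jax.P()`, and on the way out with `out_axes=None`.
+
+```python
+import jax
+import jax.numpy as jnp
+
+# Inputs: place a fully replicated copy on every device instead of stacking
+# copies and sharding them along a leading dimension.
+mesh = jax.make_mesh((jax.device_count(),), ('x',))
+params = jax.device_put(
+    jnp.zeros((3, 4)), jax.sharding.NamedSharding(mesh, jax.P()))
+
+# Outputs: out_axes=None returns one replicated value instead of a
+# (num_devices, ...) array sharded along its leading dimension.
+per_device = jnp.arange(jax.device_count(), dtype=jnp.float32)
+mean_x = jax.pmap(lambda v: jax.lax.pmean(v, 'x'),
+                  axis_name='x', out_axes=None)(per_device)
+print(mean_x)  # a single value, fully replicated rather than stacked/sharded
+```
+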
+#### Option 2: Access local data directly
+If you must work with a sharded array (or want potentially fewer changes to
+code), you can access the local data shard directly without triggering JAX's
+distributed consistency checks. Note that this is only recommended when bringing
+data back to host (e.g., for logging, checkpointing). Instead of `x[0]`, use
+`addressable_shards`:
+
+```python
+# Old slow way:
+# result = x[0]
+
+# New fast way:
+# x.addressable_shards is a list of shards on the current process.
+# We grab the first one, extract the data, and remove the leading dimension.
+result = x.addressable_shards[0].data.squeeze(0)
+```
+
+In the example of `x` with shape `(8, 3, 4)`, `x.addressable_shards[0].data`
+returns the local chunk of shape `(1, 3, 4)`. Calling `.squeeze(0)` results in
+the desired `(3, 4)` shape without any cross-device communication. Both
+solutions will eliminate the `_gather` operations seen in profiling.
+
### Host local array to global array round-trip conversion
In multi-process JAX programs (i.e., `jax.process_count() > 1`), arrays might be
@@ -104,23 +147,6 @@ host-local array when returning to user code.
This round-trip conversion cannot be avoided, so if the performance penalty is
too great, we recommend migrating your code to `jax.shard_map`.
-### `int` array indexing
-
-Indexing into a sharded array with an int (e.g., `arr[0]`) may now execute a
-rank reduction computation. Depending on your use case, there may be
-workarounds:
-
-1. In a typical training loop, we might use a `jax.pmap`ed update function to
- operate on / carry training state and grab resulting metrics from the first
- `jax.pmap`'ed device for logging. In this case, it may be possible to
- use `None` for the relevant `in_axes` and `out_axes` passed to `jax.pmap`.
- This lets `jax.pmap` handle replication and will return an
- appropriately-shaped result that looks like it's from a single device for,
- say, logging metrics.
-2. More generally, you can get the first shard of data without a reshape via
- `arr[0:1]` or `arr.addressable_shards[0].data`. Note that this will have a
- leading `(1,)` dimension that your code will need to handle.
-
## Migrating to `jax.shard_map`
In many cases, users can migrate from `jax.pmap` to `jax.jit(jax.shard_map)` by
@@ -132,4 +158,111 @@ dispatch path as in the `jax.shard_map` implementation of `jax.pmap` and can
often be overlapped with compute or be called infrequently (i.e., before a train
loop and for occasionally grabbing metrics).
+(appendix-a)=
+## Appendix A: More details about `int` indexing into sharded arrays
+
+### What should `x[0]` return?
+
+In **NumPy**, `x[0]` returns a rank-reduced array representing the first slice
+along the first dimension. For example, if `x = np.ones((8, 3, 4))`, then `x[0]`
+returns an array of shape `(3, 4)`.
+
+In **JAX** (`jax.numpy`), `x[0]` semantically works the same way: it returns the
+rank-reduced slice of the logical array `x`. However, performance depends on how
+`x` is sharded or replicated across devices. Consider an array `x` with shape
+`(8, 3, 4)` distributed across 8 devices (using `jax.P` as the short name for
+`jax.sharding.PartitionSpec`):
+
+1. **Fully Replicated:** `jax.P(None, None, None)`
+ If `x` is fully replicated, every device holds a complete copy of the `(8,
+ 3, 4)` array. `x[0]` will have the shape `(3, 4)` and a partition spec
+ `jax.P(None, None)`. Since every device already has `x`, this operation will
+ slice on each device independently and requires **no communication**.
+
+2. **Sharded on Non-Leading Dimension:** `jax.P(None, 'x', None)`
+ If `x` is sharded along the second dimension, `x[0]` results in shape `(3,
+ 4)` with partition spec `jax.P('x', None)`. Since the first dimension (the
+ one being sliced) is unsharded, this operation also requires **no
+ communication**.
+
+3. **Sharded on Leading Dimension:** `jax.P('x', None, None)`
+ If `x` is sharded along the first dimension, `x[0]` results in shape `(3,
+ 4)` with partition spec `jax.P(None, None)`.
+ * **The Issue:** Because the first dimension is sharded, the data for
+ `x[0]` physically resides *only* on the first device. To satisfy the
+ output sharding `jax.P(None, None)` (which implies replication), JAX
+ must broadcast the data from the first device to all other devices. This
+ requires **communication**; JAX will gather the *entire* array of shape
+ `(8, 3, 4)` to each device and then take a slice.
+
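+A hedged sketch of case 3 above (variable names are illustrative; this assumes
+the number of devices divides 8):
+
+```python
+import jax
+import jax.numpy as jnp
+
+mesh = jax.make_mesh((jax.device_count(),), ('x',))
+x = jax.device_put(
+    jnp.ones((8, 3, 4)),
+    jax.sharding.NamedSharding(mesh, jax.P('x', None, None)))
+
+y = x[0]           # shape (3, 4); JAX must gather to replicate the result
+print(y.shape)     # (3, 4)
+print(y.sharding)  # expected: a fully replicated sharding, per case 3 above
+```
+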
+### The Common Performance Pitfall
+
+A common pattern among `jax.pmap` users involves arrays that are **semantically
+replicated** (the user intends for them to be identical everywhere) but are
+**physically sharded** (stacked along the leading dimension).
+
+This happens implicitly (e.g., via `jax.pmap(..., out_axes=0)`) or explicitly
+(e.g., via `jax.device_put_replicated`). Users often try to retrieve metrics or
+checkpoints by calling `unreplicate` or `x[0]`, assuming it is a cheap
+operation.
+
+#### Example: The "Unreplicate" Anti-Pattern
+
+```python
+from flax import jax_utils
+import jax.numpy as jnp
+import jax
+
+# jax_utils.replicate calls jax.device_put_replicated.
+# This stacks num_devices copies and SHARDS them over the stacked dimension.
+# Logical Shape: (8, 3, 4) | Sharding: P('x', None, None)
+train_state = jax_utils.replicate({'params': jnp.zeros((3, 4))})
+
+# out_axes=0 by default, so the output remains sharded along dim 0.
+train_step_pmapped = jax.pmap(lambda x: x)
+
+# jax_utils.unreplicate performs a jax.tree_map(lambda x: x[0], tree).
+# Users do this to grab metrics, log param statistics, checkpoint, etc.
+train_state = jax_utils.unreplicate(train_step_pmapped(train_state))
+```
+
+#### The Consequence
+Even though the user knows `train_state` contains identical data on every
+device, JAX sees an array with shape `(8, 3, 4)` and spec `jax.P('x', None,
+None)`, i.e., an array that is sharded along its leading dimension. JAX cannot
+safely assume the data is identical on each device. Therefore, `x[0]` triggers a
+gather of the entire array to all devices before slicing to ensure correctness.
+This unnecessary communication causes performance degradation (visible as
+`_gather` operations in a stack trace).
+
+```
+train
+ └─ jax_utils.py:48 unreplicate
+ └─ tree_util.py:354 tree_map
+ └─ jax_utils.py:50 (performing x[0])
+ └─ array.py:335 __getitem__
+ └─ indexing.py:734 rewriting_take
+ │
+ ▼
+ └─ indexing.py:784 _gather
+ └─ slicing.py:324 gather
+ └─ PjitFunction(gather)
+```
+
+### Why was "Old Pmap" Fast?
+Historically, `pmap` used `PmapSharding`, which had a fast-path optimization in
+`jax.Array`'s `__getitem__` allowing it to return an array with a
+`SingleDeviceSharding` (data residing on only one device).
+
+However, current JAX uses `NamedSharding`. We do not strictly replicate the
+legacy behavior because it breaks the semantics of array indexing. If we allowed
+`x[0]` to return a `SingleDeviceSharding` array in a general context (e.g., in
+the middle of a train step instead of when trying to bring data back to host for
+reporting), only one device would have data while others would have nothing.
+This is computationally problematic for subsequent operations.
+
+The slowdown users experience now is JAX enforcing correct semantics: if you ask
+for `x[0]` from an array sharded along its leading dimension, you get a fully
+replicated result available on all devices, which requires communication.
+
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
index bbf8a29fa286..38d2a0e84383 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb
+++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -564,6 +564,377 @@
"For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)."
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "(jax-jit-class-methods)=\n",
+ "## 🔪 Using `jax.jit` with class methods\n",
+ "\n",
+ "Most examples of [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) concern decorating stand-alone Python functions, but decorating a method within a class introduces some complication. For example, consider the following simple class, where we've used a standard `jax.jit` annotation on a method:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax.numpy as jnp\n",
+ "from jax import jit\n",
+ "\n",
+ "class CustomClass:\n",
+ " def __init__(self, x: jnp.ndarray, mul: bool):\n",
+ " self.x = x\n",
+ " self.mul = mul\n",
+ "\n",
+ " @jit # <---- How to do this correctly?\n",
+ " def calc(self, y):\n",
+ " if self.mul:\n",
+ " return self.x * y\n",
+ " return y"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "However, this approach will result in an error when you attempt to call this method:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "tags": [
+ "raises-exception"
+ ]
+ },
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "Error interpreting argument to as an abstract array. The problematic value is of type and was passed to the function at path self.\nThis typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+ "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m c = CustomClass(\u001b[32m2\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mc\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcalc\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n",
+ " \u001b[31m[... skipping hidden 5 frame]\u001b[39m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/mamba/envs/jax-dev/lib/python3.12/site-packages/jax/_src/pjit.py:659\u001b[39m, in \u001b[36m_infer_input_type\u001b[39m\u001b[34m(fun, dbg_fn, explicit_args)\u001b[39m\n\u001b[32m 657\u001b[39m dbg = dbg_fn()\n\u001b[32m 658\u001b[39m arg_description = \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mpath \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdbg.arg_names[i]\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mif\u001b[39;00m\u001b[38;5;250m \u001b[39mdbg.arg_names\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01mis\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01melse\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[33m'\u001b[39m\u001b[33munknown\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m \u001b[38;5;66;03m# pytype: disable=name-error\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m659\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 660\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mError interpreting argument to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m as an abstract array.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 661\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m The problematic value is of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(x)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m and was passed to\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;66;03m# pytype: disable=name-error\u001b[39;00m\n\u001b[32m 662\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m the function at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg_description\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 663\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mThis typically means that a jit-wrapped function was called with a non-array\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 664\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m argument, and this argument was not marked as static using the\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 665\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m static_argnums or static_argnames parameters of jax.jit.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 666\u001b[39m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 667\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m config.mutable_array_checks.value:\n\u001b[32m 668\u001b[39m check_no_aliased_ref_args(dbg_fn, avals, explicit_args)\n",
+ "\u001b[31mTypeError\u001b[39m: Error interpreting argument to as an abstract array. The problematic value is of type and was passed to the function at path self.\nThis typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit."
+ ]
+ }
+ ],
+ "source": [
+ "c = CustomClass(2, True)\n",
+ "c.calc(3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The problem is that the first argument to the function is `self`, which has type `CustomClass`, and JAX does not know how to handle this type. There are three basic strategies we might use in this case, and we'll discuss them below."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Strategy 1: JIT-compiled helper function\n",
+ "\n",
+ "The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "class CustomClass:\n",
+ " def __init__(self, x: jnp.ndarray, mul: bool):\n",
+ " self.x = x\n",
+ " self.mul = mul\n",
+ "\n",
+ " def calc(self, y):\n",
+ " return _calc(self.mul, self.x, y)\n",
+ "\n",
+ "@partial(jit, static_argnums=0)\n",
+ "def _calc(mul, x, y):\n",
+ " if mul:\n",
+ " return x * y\n",
+ " return y"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The result will work as expected:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6\n"
+ ]
+ }
+ ],
+ "source": [
+ "c = CustomClass(2, True)\n",
+ "print(c.calc(3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The benefit of such an approach is that it is simple, explicit, and it avoids the need to teach JAX how to handle objects of type `CustomClass`. However, you may wish to keep all the method logic in the same place."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Strategy 2: Marking `self` as static\n",
+ "\n",
+ "Another common pattern is to use `static_argnums` to mark the `self` argument as static. But this must be done with care to avoid unexpected results. You may be tempted to simply do this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class CustomClass:\n",
+ " def __init__(self, x: jnp.ndarray, mul: bool):\n",
+ " self.x = x\n",
+ " self.mul = mul\n",
+ "\n",
+ " # WARNING: this example is broken, as we'll see below. Don't copy & paste!\n",
+ " @partial(jit, static_argnums=0)\n",
+ " def calc(self, y):\n",
+ " if self.mul:\n",
+ " return self.x * y\n",
+ " return y"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you call the method, it will no longer raise an error:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6\n"
+ ]
+ }
+ ],
+ "source": [
+ "c = CustomClass(2, True)\n",
+ "print(c.calc(3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6\n"
+ ]
+ }
+ ],
+ "source": [
+ "c.mul = False\n",
+ "print(c.calc(3)) # Should print 3"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Why is this? When you mark an object as static, it will effectively be used as a dictionary key in JIT's internal compilation cache, meaning its hash (i.e. `hash(obj)`) equality (i.e. `obj1 == obj2`) and object identity (i.e. `obj1 is obj2`) will be assumed to have consistent behavior. The default `__hash__` for a custom object is its object ID, and so JAX has no way of knowing that a mutated object should trigger a re-compilation.\n",
+ "\n",
+ "You can partially address this by defining an appropriate `__hash__` and `__eq__` methods for your object; for example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class CustomClass:\n",
+ " def __init__(self, x: jnp.ndarray, mul: bool):\n",
+ " self.x = x\n",
+ " self.mul = mul\n",
+ "\n",
+ " @partial(jit, static_argnums=0)\n",
+ " def calc(self, y):\n",
+ " if self.mul:\n",
+ " return self.x * y\n",
+ " return y\n",
+ "\n",
+ " def __hash__(self):\n",
+ " return hash((self.x, self.mul))\n",
+ "\n",
+ " def __eq__(self, other):\n",
+ " return (isinstance(other, CustomClass) and\n",
+ " (self.x, self.mul) == (other.x, other.mul))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "(see the [`object.__hash__`](https://docs.python.org/3/reference/datamodel.html#object.__hash__) documentation for more discussion of the requirements\n",
+ "when overriding `__hash__`).\n",
+ "\n",
+ "This should work correctly with JIT and other transforms **so long as you never mutate your object**. Mutations of objects used as hash keys lead to several subtle problems, which is why for example mutable Python containers (e.g. [`dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`list`](https://docs.python.org/3/library/stdtypes.html#list)) don't define `__hash__`, while their immutable counterparts (e.g. [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple)) do.\n",
+ "\n",
+ "If your class relies on in-place mutations (such as setting `self.attr = ...` within its methods), then your object is not really \"static\" and marking it as such may lead to problems. Fortunately, there's another option for this case."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Strategy 3: Making `CustomClass` a PyTree\n",
+ "\n",
+ "The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see [Custom pytree nodes](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes). This lets you specify exactly which components of the class should be treated as static and which should be\n",
+ "treated as dynamic. Here's how it might look:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class CustomClass:\n",
+ " def __init__(self, x: jnp.ndarray, mul: bool):\n",
+ " self.x = x\n",
+ " self.mul = mul\n",
+ "\n",
+ " @jit\n",
+ " def calc(self, y):\n",
+ " if self.mul:\n",
+ " return self.x * y\n",
+ " return y\n",
+ "\n",
+ " def _tree_flatten(self):\n",
+ " children = (self.x,) # arrays / dynamic values\n",
+ " aux_data = {'mul': self.mul} # static values\n",
+ " return (children, aux_data)\n",
+ "\n",
+ " @classmethod\n",
+ " def _tree_unflatten(cls, aux_data, children):\n",
+ " return cls(*children, **aux_data)\n",
+ "\n",
+ "from jax import tree_util\n",
+ "tree_util.register_pytree_node(CustomClass,\n",
+ " CustomClass._tree_flatten,\n",
+ " CustomClass._tree_unflatten)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This is certainly more involved, but it solves all the issues associated with the simpler approaches used above:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6\n"
+ ]
+ }
+ ],
+ "source": [
+ "c = CustomClass(2, True)\n",
+ "print(c.calc(3))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "3\n"
+ ]
+ }
+ ],
+ "source": [
+ "c.mul = False # mutation is detected\n",
+ "print(c.calc(3))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6\n"
+ ]
+ }
+ ],
+ "source": [
+ "c = CustomClass(jnp.array(2), True) # non-hashable x is supported\n",
+ "print(c.calc(3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "So long as your `tree_flatten` and `tree_unflatten` functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations."
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -1231,7 +1602,7 @@
"formats": "ipynb,md:myst"
},
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "jax-dev",
"language": "python",
"name": "python3"
},
@@ -1245,15 +1616,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.2 (v3.8.2:7b3ab5921f, Feb 24 2020, 17:52:18) \n[Clang 6.0 (clang-600.0.57)]"
+ "version": "3.12.12"
},
"mystnb": {
"render_error_lexer": "none"
- },
- "vscode": {
- "interpreter": {
- "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
- }
}
},
"nbformat": 4,
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md
index f675b65d9f45..40a40640d8ef 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.md
+++ b/docs/notebooks/Common_Gotchas_in_JAX.md
@@ -7,7 +7,7 @@ jupytext:
format_version: 0.13
jupytext_version: 1.16.4
kernelspec:
- display_name: Python 3
+ display_name: jax-dev
language: python
name: python3
---
@@ -285,6 +285,191 @@ print(new_jax_array)
For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at).
++++
+
+(jax-jit-class-methods)=
+## 🔪 Using `jax.jit` with class methods
+
+Most examples of [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) concern decorating stand-alone Python functions, but decorating a method within a class introduces some complication. For example, consider the following simple class, where we've used a standard `jax.jit` annotation on a method:
+
+```{code-cell} ipython3
+import jax.numpy as jnp
+from jax import jit
+
+class CustomClass:
+ def __init__(self, x: jnp.ndarray, mul: bool):
+ self.x = x
+ self.mul = mul
+
+ @jit # <---- How to do this correctly?
+ def calc(self, y):
+ if self.mul:
+ return self.x * y
+ return y
+```
+
+However, this approach will result in an error when you attempt to call this method:
+
+```{code-cell} ipython3
+:tags: [raises-exception]
+
+c = CustomClass(2, True)
+c.calc(3)
+```
+
+The problem is that the first argument to the function is `self`, which has type `CustomClass`, and JAX does not know how to handle this type. There are three basic strategies we might use in this case, and we'll discuss them below.
+
++++
+
+### Strategy 1: JIT-compiled helper function
+
+The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example:
+
+```{code-cell} ipython3
+from functools import partial
+
+class CustomClass:
+ def __init__(self, x: jnp.ndarray, mul: bool):
+ self.x = x
+ self.mul = mul
+
+ def calc(self, y):
+ return _calc(self.mul, self.x, y)
+
+@partial(jit, static_argnums=0)
+def _calc(mul, x, y):
+ if mul:
+ return x * y
+ return y
+```
+
+The result will work as expected:
+
+```{code-cell} ipython3
+c = CustomClass(2, True)
+print(c.calc(3))
+```
+
+The benefit of such an approach is that it is simple, explicit, and it avoids the need to teach JAX how to handle objects of type `CustomClass`. However, you may wish to keep all the method logic in the same place.
+
++++
+
+### Strategy 2: Marking `self` as static
+
+Another common pattern is to use `static_argnums` to mark the `self` argument as static. But this must be done with care to avoid unexpected results. You may be tempted to simply do this:
+
+```{code-cell} ipython3
+class CustomClass:
+ def __init__(self, x: jnp.ndarray, mul: bool):
+ self.x = x
+ self.mul = mul
+
+ # WARNING: this example is broken, as we'll see below. Don't copy & paste!
+ @partial(jit, static_argnums=0)
+ def calc(self, y):
+ if self.mul:
+ return self.x * y
+ return y
+```
+
+If you call the method, it will no longer raise an error:
+
+```{code-cell} ipython3
+c = CustomClass(2, True)
+print(c.calc(3))
+```
+
+However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result:
+
+```{code-cell} ipython3
+c.mul = False
+print(c.calc(3)) # Should print 3
+```
+
+Why is this? When you mark an object as static, it will effectively be used as a dictionary key in JIT's internal compilation cache, meaning its hash (i.e. `hash(obj)`), equality (i.e. `obj1 == obj2`), and object identity (i.e. `obj1 is obj2`) will be assumed to have consistent behavior. The default `__hash__` for a custom object is its object ID, and so JAX has no way of knowing that a mutated object should trigger a re-compilation.
+
+You can partially address this by defining appropriate `__hash__` and `__eq__` methods for your object; for example:
+
+```{code-cell} ipython3
+class CustomClass:
+ def __init__(self, x: jnp.ndarray, mul: bool):
+ self.x = x
+ self.mul = mul
+
+ @partial(jit, static_argnums=0)
+ def calc(self, y):
+ if self.mul:
+ return self.x * y
+ return y
+
+ def __hash__(self):
+ return hash((self.x, self.mul))
+
+ def __eq__(self, other):
+ return (isinstance(other, CustomClass) and
+ (self.x, self.mul) == (other.x, other.mul))
+```
+
+(see the [`object.__hash__`](https://docs.python.org/3/reference/datamodel.html#object.__hash__) documentation for more discussion of the requirements
+when overriding `__hash__`).
+
+This should work correctly with JIT and other transforms **so long as you never mutate your object**. Mutations of objects used as hash keys lead to several subtle problems, which is why for example mutable Python containers (e.g. [`dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`list`](https://docs.python.org/3/library/stdtypes.html#list)) don't define `__hash__`, while their immutable counterparts (e.g. [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple)) do.
+
+If your class relies on in-place mutations (such as setting `self.attr = ...` within its methods), then your object is not really "static" and marking it as such may lead to problems. Fortunately, there's another option for this case.
+
++++
+
+### Strategy 3: Making `CustomClass` a PyTree
+
+The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see [Custom pytree nodes](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes). This lets you specify exactly which components of the class should be treated as static and which should be
+treated as dynamic. Here's how it might look:
+
+```{code-cell} ipython3
+class CustomClass:
+ def __init__(self, x: jnp.ndarray, mul: bool):
+ self.x = x
+ self.mul = mul
+
+ @jit
+ def calc(self, y):
+ if self.mul:
+ return self.x * y
+ return y
+
+ def _tree_flatten(self):
+ children = (self.x,) # arrays / dynamic values
+ aux_data = {'mul': self.mul} # static values
+ return (children, aux_data)
+
+ @classmethod
+ def _tree_unflatten(cls, aux_data, children):
+ return cls(*children, **aux_data)
+
+from jax import tree_util
+tree_util.register_pytree_node(CustomClass,
+ CustomClass._tree_flatten,
+ CustomClass._tree_unflatten)
+```
+
+This is certainly more involved, but it solves all the issues associated with the simpler approaches used above:
+
+```{code-cell} ipython3
+c = CustomClass(2, True)
+print(c.calc(3))
+```
+
+```{code-cell} ipython3
+c.mul = False # mutation is detected
+print(c.calc(3))
+```
+
+```{code-cell} ipython3
+c = CustomClass(jnp.array(2), True) # non-hashable x is supported
+print(c.calc(3))
+```
+
+So long as your `tree_flatten` and `tree_unflatten` functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations.
+
+++ {"id": "oZ_jE2WAypdL"}
## 🔪 Out-of-bounds indexing
diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
index e80c7ae94687..27f53cf32778 100644
--- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
+++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
@@ -6,20 +6,21 @@
"id": "LqiaKasFjH82"
},
"source": [
- "# Custom derivative rules\n",
+ "(advanced-autodiff-custom-derivative-rules)=\n",
+ "# Custom derivative rules for JAX-transformable Python functions\n",
"\n",
- "\n",
+ "\n",
"\n",
"[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"\n",
"There are two ways to define differentiation rules in JAX:\n",
"\n",
- "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n",
+ "1. using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html) and [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) to define custom differentiation rules for Python functions that are already JAX-transformable; and\n",
"2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n",
"\n",
"This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).\n",
"\n",
- "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs."
+ "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [jax.grad](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), and the mathematical meaning of JVPs and VJPs."
]
},
{
@@ -28,16 +29,7 @@
"id": "9Fg3NFNY-2RY"
},
"source": [
- "## Summary"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ZgMNRtXyWIW8"
- },
- "source": [
- "### Custom JVPs with `jax.custom_jvp`"
+ "### TL;DR: Custom JVPs with `jax.custom_jvp`"
]
},
{
@@ -144,7 +136,7 @@
"id": "N2DOGCREWXFj"
},
"source": [
- "### Custom VJPs with `jax.custom_vjp`"
+ "### TL;DR: Custom VJPs with `jax.custom_vjp`"
]
},
{
@@ -209,7 +201,7 @@
"id": "AR02eyd1GQhC"
},
"source": [
- "### Numerical stability\n",
+ "### Example: Numerical stability\n",
"\n",
"One application of `jax.custom_jvp` is to improve the numerical stability of differentiation."
]
@@ -370,7 +362,7 @@
"\n",
"Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \\frac{1}{1 + e^x}$, with no cancellation in sight.\n",
"\n",
- "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n",
+ "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with [`jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), [`vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n",
"\n",
"This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...).\n",
"\n",
@@ -450,7 +442,7 @@
"id": "9sVUGbGkUOqO"
},
"source": [
- "Here's a `defjvps` convenience wrapper to express the same thing:"
+ "Here's a [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper to express the same thing:"
]
},
{
@@ -500,7 +492,7 @@
"id": "V9tHAfrSF1N-"
},
"source": [
- "### Enforcing a differentiation convention\n",
+ "### Example: Enforcing a differentiation convention\n",
"\n",
"A related application is to enforce a differentiation convention, perhaps at a boundary."
]
@@ -657,11 +649,11 @@
"id": "7J2A85wbSAmF"
},
"source": [
- "### Gradient clipping\n",
+ "### Example: Gradient clipping\n",
"\n",
"While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.\n",
"\n",
- "For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule:"
+ "For gradient clipping, we can use [`jnp.clip`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html) together with a [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) reverse-mode-only rule:"
]
},
{
@@ -782,7 +774,7 @@
"id": "CICQuI86WK4_"
},
"source": [
- "### Python debugging\n",
+ "### Example: Python debugging\n",
"\n",
"Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff."
]
@@ -804,7 +796,7 @@
"id": "IC7tEcr1-Fc5"
},
"source": [
- "### Implicit function differentiation of iterative implementations\n",
+ "### Example: Implicit function differentiation of iterative implementations\n",
"\n",
"This example gets pretty deep in the mathematical weeds!"
]
@@ -815,7 +807,7 @@
"id": "szAt97t80hew"
},
"source": [
- "Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve `lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)\n",
+ "Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve [`lax.while_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html). (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)\n",
"\n",
"For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`:"
]
@@ -1069,7 +1061,7 @@
"id": "HowvqayEuy-H"
},
"source": [
- "A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for deriviatives in closed-over variables with custom root-finding functions."
+ "A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions."
]
},
{
@@ -1089,7 +1081,7 @@
"source": [
"### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules\n",
"\n",
- "Here's a canonical basic example of using `jax.custom_jvp`, where the comments use\n",
+ "Here's a canonical basic example of using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html), where the comments use\n",
"[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):"
]
},
@@ -1272,7 +1264,7 @@
"id": "YPsPS3rdaGo2"
},
"source": [
- "The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:"
+ "The [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:"
]
},
{
@@ -1656,7 +1648,7 @@
"source": [
"### Use `jax.custom_vjp` to define custom reverse-mode-only rules\n",
"\n",
- "While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with `jax.custom_vjp`:"
+ "While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html):"
]
},
{
@@ -2200,7 +2192,7 @@
"id": "JKTNivxbmKWO"
},
"source": [
- "### Handling non-differentiable arguments"
+ "### Handling non-differentiable arguments"
]
},
{
diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md
index 82b97e195bd9..ccdc709bd48b 100644
--- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md
+++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md
@@ -13,28 +13,25 @@ kernelspec:
+++ {"id": "LqiaKasFjH82"}
-# Custom derivative rules
+(advanced-autodiff-custom-derivative-rules)=
+# Custom derivative rules for JAX-transformable Python functions
-
+
[](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
There are two ways to define differentiation rules in JAX:
-1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and
+1. using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html) and [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) to define custom differentiation rules for Python functions that are already JAX-transformable; and
2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.
This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).
-For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs.
+For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [jax.grad](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), and the mathematical meaning of JVPs and VJPs.
+++ {"id": "9Fg3NFNY-2RY"}
-## Summary
-
-+++ {"id": "ZgMNRtXyWIW8"}
-
-### Custom JVPs with `jax.custom_jvp`
+### TL;DR: Custom JVPs with `jax.custom_jvp`
```{code-cell} ipython3
:id: zXic8tr--1PK
@@ -94,7 +91,7 @@ print(grad(f)(2., 3.))
+++ {"id": "N2DOGCREWXFj"}
-### Custom VJPs with `jax.custom_vjp`
+### TL;DR: Custom VJPs with `jax.custom_vjp`
```{code-cell} ipython3
:id: 35ScHqhrBwPh
@@ -131,7 +128,7 @@ To get an idea of what problems `jax.custom_jvp` and `jax.custom_vjp` are meant
+++ {"id": "AR02eyd1GQhC"}
-### Numerical stability
+### Example: Numerical stability
One application of `jax.custom_jvp` is to improve the numerical stability of differentiation.
@@ -197,7 +194,7 @@ Stepping through how the jaxpr would be evaluated, we can see that the last line
Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \frac{1}{1 + e^x}$, with no cancellation in sight.
-This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.
+This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with [`jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), [`vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.
This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...).
@@ -239,7 +236,7 @@ print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
+++ {"id": "9sVUGbGkUOqO"}
-Here's a `defjvps` convenience wrapper to express the same thing:
+Here's a [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper to express the same thing:
```{code-cell} ipython3
:id: xfQTp8F7USEM
@@ -263,7 +260,7 @@ print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
+++ {"id": "V9tHAfrSF1N-"}
-### Enforcing a differentiation convention
+### Example: Enforcing a differentiation convention
A related application is to enforce a differentiation convention, perhaps at a boundary.
@@ -341,11 +338,11 @@ print(grad(f)(0.))
+++ {"id": "7J2A85wbSAmF"}
-### Gradient clipping
+### Example: Gradient clipping
While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.
-For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule:
+For gradient clipping, we can use [`jnp.clip`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html) together with a [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) reverse-mode-only rule:
```{code-cell} ipython3
:id: 8jfjSanIW_tJ
@@ -394,7 +391,7 @@ plt.plot(vmap(grad(clip_sin))(t))
+++ {"id": "CICQuI86WK4_"}
-### Python debugging
+### Example: Python debugging
Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff.
@@ -406,13 +403,13 @@ We'll defer an example until the next section.
+++ {"id": "IC7tEcr1-Fc5"}
-### Implicit function differentiation of iterative implementations
+### Example: Implicit function differentiation of iterative implementations
This example gets pretty deep in the mathematical weeds!
+++ {"id": "szAt97t80hew"}
-Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve `lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)
+Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve [`lax.while_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html). (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)
For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`:
@@ -559,7 +556,7 @@ print(grad(grad(jnp.sqrt))(2.))
+++ {"id": "HowvqayEuy-H"}
-A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for deriviatives in closed-over variables with custom root-finding functions.
+A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions.
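+
+As a concrete illustration, here is a minimal sketch (not part of the original notebook) of `lax.custom_root` applied to a square-root solver whose objective closes over the parameter `a`; the Newton iteration, the initial guess of `1.`, and the 10-step count are illustrative assumptions:
+
+```{code-cell} ipython3
+from jax import grad, lax
+
+def newton_sqrt(a):
+  def f(x):                 # objective closes over `a`; its root is sqrt(a)
+    return x ** 2 - a
+
+  def solve(f, x0):         # forward solve: a few unrolled Newton steps
+    x = x0
+    for _ in range(10):
+      x = x - f(x) / (2. * x)
+    return x
+
+  def tangent_solve(g, y):  # solve the scalar linear system g(x) = y
+    return y / g(1.)
+
+  return lax.custom_root(f, 1., solve, tangent_solve)
+
+print(newton_sqrt(2.))        # ~1.4142
+print(grad(newton_sqrt)(2.))  # ~0.3536, i.e. 1 / (2 * sqrt(2))
+```
+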
+++ {"id": "Dr0aNkBslfQf"}
@@ -569,7 +566,7 @@ A limitation to this approach is that the argument `f` can't close over any valu
### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules
-Here's a canonical basic example of using `jax.custom_jvp`, where the comments use
+Here's a canonical basic example of using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html), where the comments use
[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):
```{code-cell} ipython3
@@ -670,7 +667,7 @@ print(grad(f)(2., 3.))
+++ {"id": "YPsPS3rdaGo2"}
-The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:
+The [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:
```{code-cell} ipython3
:id: CsQIUhUkajua
@@ -845,7 +842,7 @@ print(grad(f)(-1.))
### Use `jax.custom_vjp` to define custom reverse-mode-only rules
-While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with `jax.custom_vjp`:
+While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html):
```{code-cell} ipython3
:id: zAZk1n3dUw76
@@ -1141,7 +1138,7 @@ print(grad(fun)(pt))
+++ {"id": "JKTNivxbmKWO"}
-### Handling non-differentiable arguments
+### Handling non-differentiable arguments
+++ {"id": "7g9sXSp_uc36"}
diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb
index 5538b70dac93..46f887f8986f 100644
--- a/docs/notebooks/autodiff_cookbook.ipynb
+++ b/docs/notebooks/autodiff_cookbook.ipynb
@@ -1637,7 +1637,7 @@
"source": [
"## More advanced autodiff\n",
"\n",
- "In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. \n",
+ "In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. For more details, check out the [\"Advanced automatic differentiation\" section in the JAX advanced guides](https://jax.readthedocs.io/en/latest/advanced_guides.html).\n",
"\n",
"There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an \"Advanced Autodiff Cookbook\" include:\n",
"\n",
diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md
index db6fde8051d1..d2cb091bc0e8 100644
--- a/docs/notebooks/autodiff_cookbook.md
+++ b/docs/notebooks/autodiff_cookbook.md
@@ -960,7 +960,7 @@ grad(f, holomorphic=True)(A)
## More advanced autodiff
-In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful.
+In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. For more details, check out the ["Advanced automatic differentiation" section in the JAX advanced guides](https://jax.readthedocs.io/en/latest/advanced_guides.html).
There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an "Advanced Autodiff Cookbook" include:
diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb
index 3fd8913459ea..26809769c981 100644
--- a/docs/notebooks/thinking_in_jax.ipynb
+++ b/docs/notebooks/thinking_in_jax.ipynb
@@ -797,8 +797,76 @@
},
{
"cell_type": "markdown",
+ "id": "b79e0c62",
"metadata": {},
"source": [
+ "## Debugging\n",
+ "\n",
+ "Debugging JAX code can be challenging due to its functional programming model and the fact that JAX code is often transformed via JIT compilation or vectorization. However, JAX provides several tools to help with debugging.\n",
+ "\n",
+ "### `jax.debug.print`\n",
+ "\n",
+ "For simple inspection, use [`jax.debug.print`](https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html).\n",
+ "\n",
+ "Python's built-in `print` executes at trace-time, before the runtime values exist. Because of this, `print` will only show tracer values within `jax.jit`-decorated code."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "61675ec9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "@jax.jit\n",
+ "def f(x):\n",
+ " print(\"print(x) ->\", x)\n",
+ " y = jnp.sin(x)\n",
+ " print(\"print(y) ->\", y)\n",
+ " return y\n",
+ "\n",
+ "result = f(2.)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a34c34bb",
+ "metadata": {},
+ "source": [
+ "If you want to print the actual runtime values, you can use `jax.debug.print`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "49b5cb05",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@jax.jit\n",
+ "def f(x):\n",
+ " jax.debug.print(\"jax.debug.print(x) -> {x}\", x=x)\n",
+ " y = jnp.sin(x)\n",
+ " jax.debug.print(\"jax.debug.print(y) -> {y}\", y=y)\n",
+ " return y\n",
+ "\n",
+ "result = f(2.)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "515495d4",
+ "metadata": {},
+ "source": [
+ "### Debugging flags\n",
+ "\n",
+ "JAX offers flags and context managers that enable catching errors more easily. For example, you can enable the `jax.debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code. You can also enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.\n",
+ "\n",
+ "For more details, see [Introduction to debugging](https://docs.jax.dev/en/latest/debugging.html).\n",
+ "\n",
"---\n",
"\n",
"This is just a taste of what JAX can do. We're really excited to see what you do with it!"
diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md
index 71cbe8a58e54..77f8797c4f5f 100644
--- a/docs/notebooks/thinking_in_jax.md
+++ b/docs/notebooks/thinking_in_jax.md
@@ -474,6 +474,49 @@ For more on pseudo random numbers in JAX, see the [Pseudorandom numbers tutorial
+++
+## Debugging
+
+Debugging JAX code can be challenging due to its functional programming model and the fact that JAX code is often transformed via JIT compilation or vectorization. However, JAX provides several tools to help with debugging.
+
+### `jax.debug.print`
+
+For simple inspection, use [`jax.debug.print`](https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html).
+
+Python's built-in `print` executes at trace-time, before the runtime values exist. Because of this, `print` will only show tracer values within `jax.jit`-decorated code.
+
+```{code-cell} ipython3
+import jax
+import jax.numpy as jnp
+
+@jax.jit
+def f(x):
+ print("print(x) ->", x)
+ y = jnp.sin(x)
+ print("print(y) ->", y)
+ return y
+
+result = f(2.)
+```
+
+If you want to print the actual runtime values, you can use `jax.debug.print`:
+
+```{code-cell} ipython3
+@jax.jit
+def f(x):
+ jax.debug.print("jax.debug.print(x) -> {x}", x=x)
+ y = jnp.sin(x)
+ jax.debug.print("jax.debug.print(y) -> {y}", y=y)
+ return y
+
+result = f(2.)
+```
+
+### Debugging flags
+
+JAX offers flags and context managers that make it easier to catch errors. For example, you can enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code, or the `jax_disable_jit` flag to disable JIT compilation, which lets you use traditional Python debugging tools like `print` and `pdb`.
+
+For more details, see [Introduction to debugging](https://docs.jax.dev/en/latest/debugging.html).
+
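+As a minimal sketch (one way to set these, assuming global toggling is acceptable in your workflow), both options can be enabled through `jax.config`:
+
+```{code-cell} ipython3
+# Sketch: enable the debugging flags globally via jax.config.
+jax.config.update("jax_debug_nans", True)    # raise an error as soon as a NaN is produced
+jax.config.update("jax_disable_jit", True)   # run code eagerly so print() and pdb work
+
+# Restore the defaults so the rest of the notebook runs compiled.
+jax.config.update("jax_debug_nans", False)
+jax.config.update("jax_disable_jit", False)
+```
+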
---
This is just a taste of what JAX can do. We're really excited to see what you do with it!
diff --git a/docs/pallas/design/async_note.md b/docs/pallas/design/async_note.md
index b255a91d3ec8..b21725f7a29e 100644
--- a/docs/pallas/design/async_note.md
+++ b/docs/pallas/design/async_note.md
@@ -18,7 +18,7 @@ def f(x):
In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by:
-1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future.
+1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future.
2. scheduling the `x + 1` between the `ppermute_start` and `ppermute_done`,
resulting in the following program:
@@ -107,12 +107,12 @@ def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
),
),
in_specs=[
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
],
out_specs=(
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
),
)(x)
return send_sem, recv_sem, out
@@ -139,11 +139,11 @@ def ppermute_done(send_sem, recv_sem, out) ->Array:
),
),
in_specs=[
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
],
- out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ out_specs=pl.BlockSpec(memory_space=pl.ANY),
input_output_aliases={0:0}
)(out, send_sem, recv_sem)
return out
@@ -167,9 +167,9 @@ def f(x):
There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level.
-1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX.
-2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory.
-3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness.
+1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX.
+2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory.
+3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness.
We will go over these issues one by one and suggest fixes.
@@ -292,13 +292,13 @@ def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]
),
),
in_specs=[
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
],
out_specs=(
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
- pl.BlockSpec(memory_space=pltpu.ANY),
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
),
input_output_aliases={0:2}
)(x)
@@ -322,12 +322,12 @@ def ppermute_done(send_sem, recv_sem, x, out) ->Array:
),
),
in_specs=[
- pl.BlockSpec(memory_space=pltpu.ANY),
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
],
- out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ out_specs=pl.BlockSpec(memory_space=pl.ANY),
input_output_aliases={1:0}
)(x, out, send_sem, recv_sem)
return out
@@ -485,7 +485,7 @@ def f(x):
def body(i, x):
*sems, x, x2 = ppermute_start(x)
x2 = ppermute_done((*sems, x, x2))
-
+
*sems, x2, y = ppermute_start(x2)
y = ppermute_done((*sems, x2, y))
return y
@@ -574,10 +574,10 @@ our program should now be correct.
So we’ve come up with some rules of thumb:
-1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value.
+1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value.
2. Use `unroll >= 2` when doing `ppermute`s in a loop body.
-Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result.
+Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result.
```py
def f(x):
@@ -641,7 +641,7 @@ def f(x):
return y_ref[...]
```
-Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved wrt to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO.
+Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved wrt to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO.
The final key difference is evident when we try our loop examples.
@@ -665,8 +665,8 @@ To handle this without the manual unrolling, we’d create a scratch buffer with
The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away\!
-1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` wrt to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints.
-2. Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops.
+1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` wrt to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints.
+2. Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops.
3. Defensive copies \- Using `Ref`s forces us to handle buffer assignment “manually” and the lowering can ensure the aliasing works out to avoid any copies.
Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about buffer lifetimes or which memory spaces they live in for these intermediate values. *Therefore, it is possible XLA can copy array values even if they are actively being copied into by Pallas kernels.* This is easy to verify in HLO but it is a sharp edge of using custom calls to represent asynchronous operations in HLO.
diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb
index 8ed5cac076d3..3fefc2cbc157 100644
--- a/docs/pallas/quickstart.ipynb
+++ b/docs/pallas/quickstart.ipynb
@@ -343,7 +343,7 @@
"\n",
"def iota(size: int):\n",
" return pl.pallas_call(iota_kernel,\n",
- " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),\n",
+ " out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),\n",
" out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
" grid=(size,))()\n",
"iota(8)"
diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md
index 3ff0801db965..f18225a589d5 100644
--- a/docs/pallas/quickstart.md
+++ b/docs/pallas/quickstart.md
@@ -230,7 +230,7 @@ from jax.experimental.pallas import tpu as pltpu
def iota(size: int):
return pl.pallas_call(iota_kernel,
- out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),
+ out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
grid=(size,))()
iota(8)
diff --git a/docs/pallas/tpu/core_map.ipynb b/docs/pallas/tpu/core_map.ipynb
new file mode 100644
index 000000000000..38be63d61cb4
--- /dev/null
+++ b/docs/pallas/tpu/core_map.ipynb
@@ -0,0 +1,628 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "last_runtime": {
+ "build_target": "//third_party/py/jax_triton/google/pallas_tpu:notebook",
+ "kind": "private"
+ }
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Pallas Core-specific Programming"
+ ],
+ "metadata": {
+ "id": "YIt0Za36LYg9"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In this guide, we explore using `pl.core_map` to write Pallas kernels. Compared with `pallas_call`, `core_map` offers a few key characteristics:\n",
+ "\n",
+        "* **Per-core level programming**: You write code for a TPU/GPU core, not for a JAX device. This gives you full control over what runs on every core and how cores communicate and distribute work among one another.\n",
+ "\n",
+ "* **Collectives**: `core_map` explicitly models physical cores, so inter-core communication can be expressed safely.\n",
+ "\n",
+        "* **Platform generic**: The `core_map` programming model works for TPU (TensorCore and SparseCore) and GPU with minimal boilerplate changes."
+ ],
+ "metadata": {
+ "id": "khDWSc7aOVts"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "This guide focuses on TPU. For how to use `core_map` on GPU to achieve higher thread flexibility, check out our [Pallas GPU `core_map` tutorial](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#using-core-map)."
+ ],
+ "metadata": {
+ "id": "i8pl0CLqTVvL"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Environment setup\n",
+ "\n",
+        "Modern accelerators often have multiple cores under a device. For recent TPU chips (v4, v5p), every JAX device may contain 2 TensorCores (a.k.a. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). Some TPUs (v5p, v6e, 7x) also contain [SparseCores](https://openxla.org/xla/sparsecore#specifications_at_a_glance), each of which consists of many subcores.\n",
+ "\n",
+        "This guide was written on a TPU v5p machine with 4 devices (2 TensorCores each); each device also has 4 SparseCores, each with 16 subcores."
+ ],
+ "metadata": {
+ "id": "bsOPXdJkzC-x"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "14PNaMVsLUur",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795546418,
+ "user_tz": 480,
+ "elapsed": 2087,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ },
+ "outputId": "01976bb1-2f2f-40e9-ca23-f0e480a82ab3"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Running on 4 TPU v5p devices.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "import jax\n",
+ "from jax.sharding import NamedSharding\n",
+ "from jax.experimental import pallas as pl\n",
+ "from jax.experimental.pallas import tpu as pltpu\n",
+ "from jax.experimental.pallas import tpu_sc as plsc\n",
+ "import jax.numpy as jnp\n",
+ "import numpy as np\n",
+ "\n",
+ "\n",
+ "num_devices = jax.local_device_count()\n",
+ "assert num_devices > 1, \"Please run this notebook with more than one device.\"\n",
+ "\n",
+ "tpu_info = pltpu.get_tpu_info() # This notebook only runs on TPU.\n",
+ "print(f\"Running on {num_devices} TPU {tpu_info.chip_version} devices.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+        "In addition to the typical TPU device mesh, you need to make a mesh of cores. Think of this as an additional dimension called `core`, of length 2, on top of the 4-device mesh you work with. That is 8 cores in total."
+ ],
+ "metadata": {
+ "id": "3f0XEhaYnGyk"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Mesh of devices\n",
+ "mesh = jax.make_mesh((jax.device_count(),), ('device',))\n",
+ "print(mesh)\n",
+ "\n",
+ "# Mesh of cores, within a JAX device\n",
+ "tc_mesh = pltpu.create_tensorcore_mesh('core')\n",
+ "print(tc_mesh)\n",
+ "\n",
+ "num_devices = mesh.size\n",
+ "num_cores = len(tc_mesh.devices)\n",
+ "print(f\"There are {num_devices} devices, and {num_cores} cores each.\")"
+ ],
+ "metadata": {
+ "id": "jr5MARD-mIlC",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795546665,
+ "user_tz": 480,
+ "elapsed": 57,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ },
+ "outputId": "1ea63c2f-3aec-4cdd-9674-d0e2df32460c"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Mesh('device': 4, axis_types=(Explicit,))\n",
+ "TensorCoreMesh(devices=array([TensorCore(id=0), TensorCore(id=1)], dtype=object), axis_names=('core',))\n",
+ "There are 4 devices, and 2 cores each.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## A simple per-core kernel\n",
+ "\n",
+ "`pl.core_map` allows you to write per-core local code, just as `jax.shard_map` allows you to write per-device code.\n",
+ "\n",
+        "In the example kernel below, each core has its own VMEM and semaphore allocations. As with a normal kernel, you can initiate copies between HBM and VMEM refs using `pltpu.async_copy`.\n",
+ "\n",
+ "**Communication between cores**\n",
+ "\n",
+        "Before communicating between cores, it is good practice to perform a barrier (using `pltpu.semaphore_signal`) to ensure resources have been allocated and both cores are at the same point in the program.\n",
+ "\n",
+        "Once the cores are synchronized, use `pltpu.make_async_remote_copy` to send data between them. The `device_id` keyword argument generically allows sending to any core on any device, but if you just pass in `{'core': other_core_id}`, it will perform an intra-device inter-core copy (the other axis names are held constant).\n"
+ ],
+ "metadata": {
+ "id": "CYxwiULfndlh"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# This runs on every core\n",
+ "def swap_cores_kernel(in_hbm, out_hbm,\n",
+ " in_vmem, scratch_vmem, out_vmem,\n",
+ " sem, send_sem, recv_sem):\n",
+ " core_index = jax.lax.axis_index('core')\n",
+ " num_cores = jax.lax.axis_size('core')\n",
+ " slc_size = in_hbm.shape[-1] // num_cores\n",
+ " slc = pl.ds(core_index * slc_size, slc_size)\n",
+ "\n",
+ " # Copy in a core-dependent slice of the input\n",
+ " pltpu.async_copy(in_hbm.at[:, slc], in_vmem, sem).wait()\n",
+ "\n",
+ " # A barrier to make sure all cores have entered run_scoped.\n",
+ " # You won't need this if not doing inter-core communications.\n",
+ " dst_core = (core_index + 1) % num_cores\n",
+ " sem0 = pltpu.get_barrier_semaphore()\n",
+ " pltpu.semaphore_signal(sem0, 1, device_id={'core': dst_core})\n",
+ " pltpu.semaphore_wait(sem0, 1)\n",
+ "\n",
+ " # Swap data between core 0 and core 1\n",
+ " the_copy = pltpu.make_async_remote_copy(\n",
+ " in_vmem, scratch_vmem, send_sem, recv_sem, device_id={'core': dst_core},\n",
+ " )\n",
+ " the_copy.start()\n",
+ " the_copy.wait()\n",
+ "\n",
+ " # Core-local compute\n",
+ " out_vmem[...] = scratch_vmem[...] * 2\n",
+ "\n",
+ " # Copy out the output\n",
+ " pltpu.async_copy(out_vmem, out_hbm.at[:, slc], sem).wait()\n"
+ ],
+ "metadata": {
+ "id": "GkGRT2HRJOUU",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795546946,
+ "user_tz": 480,
+ "elapsed": 53,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ }
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Once you have the local kernel:\n",
+ "\n",
+ " * Start your top-level JAX code with HBM refs, and allocate output refs if needed.\n",
+ "\n",
+ " * Use `pl.core_map`, which takes the TensorCore mesh, to start per-core programming.\n",
+ "\n",
+ " * You will need `collective_id` for the barrier semaphore.\n",
+ "\n",
+ " * Inside `pl.core_map`, invoke `pl.run_scoped` to allocate per-core scratch spaces (VMEM and semaphores) and run the local kernel."
+ ],
+ "metadata": {
+ "id": "2T0tSkFmoFLI"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input_shape = (32, 256)\n",
+ "local_vmem_shape = (32 // num_devices, 256 // num_cores)\n",
+ "in_spec = jax.P('device', None)\n",
+ "sharding = NamedSharding(mesh, in_spec)\n",
+ "\n",
+ "@jax.jit\n",
+ "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec,\n",
+ " check_vma=False)\n",
+ "def swap_cores(x):\n",
+ " # Get buffers out of the input and output\n",
+ " x_hbm_ref = jax.new_ref(x)\n",
+ " o_hbm_ref = jax.new_ref(jax.lax.empty(x.shape, x.dtype))\n",
+ "\n",
+ " @pl.core_map(tc_mesh, compiler_params=pltpu.CompilerParams(collective_id=0))\n",
+ " def _():\n",
+ " pl.run_scoped(\n",
+ " partial(swap_cores_kernel, x_hbm_ref, o_hbm_ref),\n",
+ " *([pltpu.VMEM(local_vmem_shape, x.dtype)] * 3), # VMEM allocations\n",
+ " *([pltpu.SemaphoreType.DMA] * 3), # semaphores\n",
+ " )\n",
+ " return o_hbm_ref[...]\n",
+ "\n",
+ "\n",
+ "x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n",
+ "x = jax.device_put(x, sharding)\n",
+ "y = swap_cores(x)\n",
+ "\n",
+ "np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)\n",
+ "np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)"
+ ],
+ "metadata": {
+ "id": "KT6zkEKi1Sbc",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795548996,
+ "user_tz": 480,
+ "elapsed": 1800,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ }
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Save the boilerplate\n",
+ "\n",
+ "You can use the `pl.kernel` decorator to wrap boilerplate such as `core_map`, `run_scoped`, and output buffer allocation.\n",
+ "\n",
+ "Note that this should run inside any `jax.shard_map` you may have at the top level."
+ ],
+ "metadata": {
+ "id": "dLV8sKa4HuSX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "@jax.jit\n",
+ "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n",
+ "def swap_cores(x):\n",
+ " scratch_shapes = [pltpu.VMEM(local_vmem_shape, x.dtype)] * 3 + [pltpu.SemaphoreType.DMA] * 3\n",
+ " return pl.kernel(swap_cores_kernel, out_shape=x, mesh=tc_mesh,\n",
+ " scratch_shapes=scratch_shapes,\n",
+ " compiler_params=pltpu.CompilerParams(collective_id=0))(x)\n",
+ "\n",
+ "y = swap_cores(x)\n",
+ "np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)\n",
+ "np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)"
+ ],
+ "metadata": {
+ "id": "7cHnsRHPHyfH",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795549347,
+ "user_tz": 480,
+ "elapsed": 106,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ }
+ },
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Pipelining with `core_map`\n",
+ "\n",
+ "Note that the kernel above only does simple copies and compute, without automatic pipelining via Pallas `grid` and `BlockSpec`. To do pipelining inside `core_map`, use `pltpu.emit_pipeline` inside the core-local kernel.\n",
+ "\n",
+ "**Automatically parallelize work amongst cores**\n",
+ "\n",
+        "The simplest way is to annotate a grid axis as `pltpu.PARALLEL`, and Pallas will automatically parallelize work along this axis. Both `pl.pallas_call` and `pltpu.emit_pipeline` support this, via the `core_axis` and `dimension_semantics` arguments. The `pallas_call` example is [in another guide](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration), and the `emit_pipeline` case is shown below.\n",
+ "\n",
+ "When the `PARALLEL` annotation is provided, the corresponding grid dimension will be logically split and executed on separate cores. (The exact semantics of which grid dimensions are executed on which core is guaranteed).\n",
+ "\n",
+ "**Scratch shapes allocation**\n",
+ "\n",
+        "Note that in the example below, the top-level `pl.run_scoped` (wrapped inside `kernel`) does not allocate any VMEM scratch buffers. Instead, `pltpu.emit_pipeline` allocates its own scratch buffers in VMEM and uses them for its multiple buffering.\n"
+ ],
+ "metadata": {
+ "id": "4-G--Wnysdjs"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def add_one_body(in_vmem, out_vmem):\n",
+ " out_vmem[...] = in_vmem[...] + 1\n",
+ "\n",
+ "input_shape = (1024, 1024)\n",
+ "in_spec = jax.P('device', None)\n",
+ "\n",
+ "def add_one_kernel(x_hbm_ref, o_hbm_ref):\n",
+ " in_shape = x_hbm_ref.shape\n",
+ " pltpu.emit_pipeline(\n",
+ " add_one_body,\n",
+ " grid=(in_shape[0] // 8, in_shape[1] // 128),\n",
+ " in_specs=[pl.BlockSpec(\n",
+ " block_shape=(8, 128), index_map=lambda i, j: (i, j),\n",
+ " )],\n",
+ " out_specs=[pl.BlockSpec(\n",
+ " block_shape=(8, 128), index_map=lambda i, j: (i, j),\n",
+ " )],\n",
+ " core_axis_name='core',\n",
+ " dimension_semantics=(pltpu.PARALLEL, pltpu.ARBITRARY),\n",
+ " )(x_hbm_ref, o_hbm_ref)\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n",
+ "def add_one(x):\n",
+ " return pl.kernel(add_one_kernel, out_shape=x, mesh=tc_mesh, scratch_shapes=[])(x)\n",
+ "\n",
+ "\n",
+ "x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n",
+ "x = jax.device_put(x, NamedSharding(mesh, in_spec))\n",
+ "y = add_one(x)\n",
+ "\n",
+ "np.testing.assert_array_equal(y, x + 1)"
+ ],
+ "metadata": {
+ "id": "xUMRPLxb1rEH",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795550106,
+ "user_tz": 480,
+ "elapsed": 518,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ }
+ },
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Scalar prefetch\n",
+ "\n",
+ "The code below extends the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n",
+ "\n",
+ "This involves pre-allocating an SMEM buffer (via the `pl.run_scoped` call inside `kernel`) and populating the buffer using a `sync_copy` before the pipeline starts. Close over the dynamic index value inside the `index_map` to use it.\n",
+ "\n",
+ "**Manually delegate work amongst cores**\n",
+ "\n",
+ "The code example below also shows how `core_map` allows you to customize exactly how the work is split between cores, without relying on the automatic API shown above.\n",
+ "\n",
+ "To achieve that, customize your `index_map` to use the core index to work on different slices on different cores.\n"
+ ],
+ "metadata": {
+ "id": "Cq5rYyvL2Tte"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input_shape = (1024, 1024)\n",
+ "in_spec = jax.P('device', None)\n",
+ "output_shape = (1024, 512)\n",
+ "\n",
+ "def indexed_add_one_kernel(in_refs, out_refs, i_smem_ref):\n",
+ " (x_hbm_ref, i_hbm_ref), o_hbm_ref = in_refs, out_refs\n",
+ " in_shape = x_hbm_ref.shape\n",
+ " pltpu.sync_copy(i_hbm_ref, i_smem_ref)\n",
+ "\n",
+ " core_idx = jax.lax.axis_index('core')\n",
+ " core_slc_size = in_shape[0] // num_cores\n",
+ " i_map = lambda i: core_idx * core_slc_size // 8 + i # split work among cores\n",
+ " j_map = lambda j: i_smem_ref[0] // 128 + j # use the prefetched offset\n",
+ "\n",
+ " pltpu.emit_pipeline(\n",
+ " add_one_body,\n",
+ " grid=(core_slc_size // 8, output_shape[1] // 128),\n",
+ " in_specs=[pl.BlockSpec(\n",
+ " block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j_map(j)),\n",
+ " )],\n",
+ " out_specs=[pl.BlockSpec(\n",
+ " block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j),\n",
+ " )]\n",
+ " )(x_hbm_ref, o_hbm_ref)\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "@partial(jax.shard_map, mesh=mesh,\n",
+ " in_specs=(in_spec, jax.P()), out_specs=in_spec, check_vma=False)\n",
+ "def indexed_add_one(x, index):\n",
+ " out_shape = jax.ShapeDtypeStruct((x.shape[0], x.shape[1] // 2), x.dtype)\n",
+ " return pl.kernel(indexed_add_one_kernel,\n",
+ " out_shape=out_shape, mesh=tc_mesh,\n",
+ " scratch_shapes=[pltpu.SMEM((1,), jnp.int32)])((x, index))\n",
+ "\n",
+ "\n",
+ "xs = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n",
+ "xs = jax.device_put(xs, NamedSharding(mesh, in_spec))\n",
+ "idx = 256\n",
+ "y = indexed_add_one(xs, jnp.array([idx]))\n",
+ "\n",
+ "np.testing.assert_array_equal(y, xs[:, idx:(idx+512)] + 1)"
+ ],
+ "metadata": {
+ "id": "SE8pTStHeSWB",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795550778,
+ "user_tz": 480,
+ "elapsed": 378,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ }
+ },
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Mapping over SparseCores\n",
+ "\n",
+ "TPU v5p contains 4 [SparseCores](https://openxla.org/xla/sparsecore), which are specialized for sparse memory access and operations. This guide will not dive into the full capabilities of SparseCore, but rather show how to run a program on SparseCore with the same semantics and minimal changes from the TensorCore code.\n",
+ "\n",
+        "Start by checking the basic SparseCore specs of your chip, and create a `VectorSubcoreMesh` for vector operations. Note that each SparseCore on TPU v5p has 16 subcores (the count varies across chip versions), and `core_map` will run your code SPMD on each of them."
+ ],
+ "metadata": {
+ "id": "B8qeo-4A2KRm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "sc_info = pltpu.get_tpu_info().sparse_core\n",
+ "assert sc_info is not None\n",
+ "print(sc_info)\n",
+ "\n",
+ "sc_mesh = plsc.VectorSubcoreMesh(\n",
+ " core_axis_name=\"core\", subcore_axis_name=\"subcore\",\n",
+ " num_cores=sc_info.num_cores\n",
+ ")\n",
+ "sc_num_cores = sc_info.num_cores\n",
+ "sc_num_subcores = sc_info.num_subcores"
+ ],
+ "metadata": {
+ "id": "AHurx-yyYVvs",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795551102,
+ "user_tz": 480,
+ "elapsed": 55,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ },
+ "outputId": "aa4a45da-dd9a-4f57-de1a-bc9b5872b2df"
+ },
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The code below is very similar to the `add_one_kernel` we wrote earlier, except for a few differences:\n",
+ "\n",
+        "1. You need to split the work amongst all subcores, so there are a few extra lines to compute the specific slice for each subcore.\n",
+ "\n",
+        "1. SparseCore register computation only operates on smaller slices (at most `4x16` for int32), so you need nested loops to iterate over the slice during the computation phase."
+ ],
+ "metadata": {
+ "id": "n2_dfsUWFgwU"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "input_shape = (4096, 128)\n",
+ "SC_REG_OP_SHAPE = (4, 16)\n",
+ "\n",
+ "def sc_add_one_body(in_vmem, out_vmem):\n",
+ " @pl.loop(0, in_vmem.shape[0], step=SC_REG_OP_SHAPE[0])\n",
+ " def _reg_loop_0(c0):\n",
+ " @pl.loop(0, in_vmem.shape[1], step=SC_REG_OP_SHAPE[1])\n",
+ " def _reg_loop_1(c1):\n",
+ " slc = (pl.ds(c0, SC_REG_OP_SHAPE[0]), pl.ds(c1, SC_REG_OP_SHAPE[1]))\n",
+ " out_vmem[slc] = in_vmem[slc] + 1\n",
+ "\n",
+ "\n",
+ "def sc_add_one_kernel(x_hbm_ref, o_hbm_ref):\n",
+ " in_shape = x_hbm_ref.shape\n",
+ " core_idx = jax.lax.axis_index('core')\n",
+ " subcore_idx = jax.lax.axis_index(\"subcore\")\n",
+ " cm_idx = core_idx * sc_num_subcores + subcore_idx # index on the core_map\n",
+ " slc_size = in_shape[0] // (sc_num_subcores * sc_num_cores)\n",
+ " index_map = lambda i, j: (\n",
+ " pl.ds(pl.multiple_of(cm_idx * slc_size + i * 8, 8), 8), j)\n",
+ "\n",
+ " pltpu.emit_pipeline(\n",
+ " sc_add_one_body,\n",
+ " grid=(slc_size // 8, in_shape[1] // 128),\n",
+ " in_specs=[pl.BlockSpec(\n",
+ " block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,\n",
+ " )],\n",
+ " out_specs=[pl.BlockSpec(\n",
+ " block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,\n",
+ " )]\n",
+ " )(x_hbm_ref, o_hbm_ref)\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n",
+ "def sc_add_one(x):\n",
+ " return pl.kernel(sc_add_one_kernel, out_shape=x, mesh=sc_mesh, scratch_shapes=[])(x)\n",
+ "\n",
+ "\n",
+ "x = jax.random.randint(jax.random.key(0), input_shape, 0, 64, jnp.int32)\n",
+ "x = jax.device_put(x, NamedSharding(mesh, in_spec))\n",
+ "y = sc_add_one(x)\n",
+ "\n",
+ "np.testing.assert_array_equal(y, x + 1)"
+ ],
+ "metadata": {
+ "id": "6fNShx6k2kxi",
+ "executionInfo": {
+ "status": "ok",
+ "timestamp": 1764795552411,
+ "user_tz": 480,
+ "elapsed": 1117,
+ "user": {
+ "displayName": "Ivy Zheng",
+ "userId": "15297372265856137303"
+ }
+ }
+ },
+ "execution_count": 9,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/pallas/tpu/core_map.md b/docs/pallas/tpu/core_map.md
new file mode 100644
index 000000000000..4e00399b6fe8
--- /dev/null
+++ b/docs/pallas/tpu/core_map.md
@@ -0,0 +1,363 @@
+# Pallas Core-specific Programming
+
+In this guide, we explore using `pl.core_map` to write Pallas kernels. Compared with `pallas_call`, `core_map` offers a few key characteristics:
+
+* **Per-core level programming**: You write code for a TPU/GPU core, not for a JAX device. This gives you full control over what runs on every core and how cores communicate and distribute work among one another.
+
+* **Collectives**: `core_map` explicitly models physical cores, so inter-core communication can be expressed safely.
+
+* **Platform generic**: The `core_map` programming model works for TPU (TensorCore and SparseCore) and GPU with minimal boilerplate changes.
+
+This guide focuses on TPU. For how to use `core_map` on GPU to achieve higher thread flexibility, check out our [Pallas GPU `core_map` tutorial](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#using-core-map).
+
+## Environment setup
+
+Modern accelerators often have multiple cores under a device. For recent TPU chips (v4, v5p), every JAX device may contain 2 TensorCores (a.k.a. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). Some TPUs (v5p, v6e, 7x) also contain [SparseCores](https://openxla.org/xla/sparsecore#specifications_at_a_glance), each of which consists of many subcores.
+
+This guide was written on a TPU v5p machine with 4 devices (2 TensorCores each); each device also has 4 SparseCores, each with 16 subcores.
+
+
+```python
+from functools import partial
+
+import jax
+from jax.sharding import NamedSharding
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+from jax.experimental.pallas import tpu_sc as plsc
+import jax.numpy as jnp
+import numpy as np
+
+
+num_devices = jax.local_device_count()
+assert num_devices > 1, "Please run this notebook with more than one device."
+
+tpu_info = pltpu.get_tpu_info() # This notebook only runs on TPU.
+print(f"Running on {num_devices} TPU {tpu_info.chip_version} devices.")
+```
+
+ Running on 4 TPU v5p devices.
+
+
+In addition to the typical TPU device mesh, you need to make a mesh of cores. Think of this as an additional dimension called `core`, of length 2, on top of the 4-device mesh you work with. That is 8 cores in total.
+
+
+```python
+# Mesh of devices
+mesh = jax.make_mesh((jax.device_count(),), ('device',))
+print(mesh)
+
+# Mesh of cores, within a JAX device
+tc_mesh = pltpu.create_tensorcore_mesh('core')
+print(tc_mesh)
+
+num_devices = mesh.size
+num_cores = len(tc_mesh.devices)
+print(f"There are {num_devices} devices, and {num_cores} cores each.")
+```
+
+ Mesh('device': 4, axis_types=(Explicit,))
+ TensorCoreMesh(devices=array([TensorCore(id=0), TensorCore(id=1)], dtype=object), axis_names=('core',))
+ There are 4 devices, and 2 cores each.
+
+
+## A simple per-core kernel
+
+`pl.core_map` allows you to write per-core local code, just as `jax.shard_map` allows you to write per-device code.
+
+In the example kernel below, each core has its own VMEM and semaphore allocations. As with a normal kernel, you can initiate copies between HBM and VMEM refs using `pltpu.async_copy`.
+
+**Communication between cores**
+
+Before communicating between cores, it is good practice to perform a barrier (using `pltpu.semaphore_signal`) to ensure resources have been allocated and both cores are at the same point in the program.
+
+Once the cores are synchronized, use `pltpu.make_async_remote_copy` to send data between them. The `device_id` keyword argument generically allows sending to any core on any device, but if you just pass in `{'core': other_core_id}`, it will perform an intra-device inter-core copy (the other axis names are held constant).
+
+
+
+```python
+# This runs on every core
+def swap_cores_kernel(in_hbm, out_hbm,
+ in_vmem, scratch_vmem, out_vmem,
+ sem, send_sem, recv_sem):
+ core_index = jax.lax.axis_index('core')
+ num_cores = jax.lax.axis_size('core')
+ slc_size = in_hbm.shape[-1] // num_cores
+ slc = pl.ds(core_index * slc_size, slc_size)
+
+ # Copy in a core-dependent slice of the input
+ pltpu.async_copy(in_hbm.at[:, slc], in_vmem, sem).wait()
+
+ # A barrier to make sure all cores have entered run_scoped.
+ # You won't need this if not doing inter-core communications.
+ dst_core = (core_index + 1) % num_cores
+ sem0 = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(sem0, 1, device_id={'core': dst_core})
+ pltpu.semaphore_wait(sem0, 1)
+
+ # Swap data between core 0 and core 1
+ the_copy = pltpu.make_async_remote_copy(
+ in_vmem, scratch_vmem, send_sem, recv_sem, device_id={'core': dst_core},
+ )
+ the_copy.start()
+ the_copy.wait()
+
+ # Core-local compute
+ out_vmem[...] = scratch_vmem[...] * 2
+
+ # Copy out the output
+ pltpu.async_copy(out_vmem, out_hbm.at[:, slc], sem).wait()
+
+```
+
+Once you have the local kernel:
+
+ * Start your top-level JAX code with HBM refs, and allocate output refs if needed.
+
+ * Use `pl.core_map`, which takes the TensorCore mesh, to start per-core programming.
+
+ * You will need `collective_id` for the barrier semaphore.
+
+ * Inside `pl.core_map`, invoke `pl.run_scoped` to allocate per-core scratch spaces (VMEM and semaphores) and run the local kernel.
+
+
+```python
+input_shape = (32, 256)
+local_vmem_shape = (32 // num_devices, 256 // num_cores)
+in_spec = jax.P('device', None)
+sharding = NamedSharding(mesh, in_spec)
+
+@jax.jit
+@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec,
+ check_vma=False)
+def swap_cores(x):
+ # Get buffers out of the input and output
+ x_hbm_ref = jax.new_ref(x)
+ o_hbm_ref = jax.new_ref(jax.lax.empty(x.shape, x.dtype))
+
+ @pl.core_map(tc_mesh, compiler_params=pltpu.CompilerParams(collective_id=0))
+ def _():
+ pl.run_scoped(
+ partial(swap_cores_kernel, x_hbm_ref, o_hbm_ref),
+ *([pltpu.VMEM(local_vmem_shape, x.dtype)] * 3), # VMEM allocations
+ *([pltpu.SemaphoreType.DMA] * 3), # semaphores
+ )
+ return o_hbm_ref[...]
+
+
+x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)
+x = jax.device_put(x, sharding)
+y = swap_cores(x)
+
+np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)
+np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)
+```
+
+### Save the boilerplate
+
+You can use the `pl.kernel` decorator to wrap boilerplate such as `core_map`, `run_scoped`, and output buffer allocation.
+
+Note that this should run inside any `jax.shard_map` you may have at the top level.
+
+
+```python
+@jax.jit
+@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)
+def swap_cores(x):
+ scratch_shapes = [pltpu.VMEM(local_vmem_shape, x.dtype)] * 3 + [pltpu.SemaphoreType.DMA] * 3
+ return pl.kernel(swap_cores_kernel, out_shape=x, mesh=tc_mesh,
+ scratch_shapes=scratch_shapes,
+ compiler_params=pltpu.CompilerParams(collective_id=0))(x)
+
+y = swap_cores(x)
+np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)
+np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)
+```
+
+## Pipelining with `core_map`
+
+Note that the kernel above only does simple copies and compute, without automatic pipelining via Pallas `grid` and `BlockSpec`. To do pipelining inside `core_map`, use `pltpu.emit_pipeline` inside the core-local kernel.
+
+**Automatically parallelize work amongst cores**
+
+The simplest way is to annotate a grid axis as `pltpu.PARALLEL`, and Pallas will automatically parallelize work along this axis. Both `pl.pallas_call` and `pltpu.emit_pipeline` support this, via the `core_axis` and `dimension_semantics` arguments. The `pallas_call` example is [in another guide](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration), and the `emit_pipeline` case is shown below.
+
+When the `PARALLEL` annotation is provided, the corresponding grid dimension will be logically split and executed on separate cores. (The exact semantics of which grid dimensions are executed on which core is guaranteed).
+
+**Scratch shapes allocation**
+
+Note that in the example below, the top-level `pl.run_scoped` (wrapped inside `kernel`) does not allocate any VMEM scratch buffers. Instead, `pltpu.emit_pipeline` allocates its own scratch buffers in VMEM and uses them for its multiple buffering.
+
+
+
+```python
+def add_one_body(in_vmem, out_vmem):
+ out_vmem[...] = in_vmem[...] + 1
+
+input_shape = (1024, 1024)
+in_spec = jax.P('device', None)
+
+def add_one_kernel(x_hbm_ref, o_hbm_ref):
+ in_shape = x_hbm_ref.shape
+ pltpu.emit_pipeline(
+ add_one_body,
+ grid=(in_shape[0] // 8, in_shape[1] // 128),
+ in_specs=[pl.BlockSpec(
+ block_shape=(8, 128), index_map=lambda i, j: (i, j),
+ )],
+ out_specs=[pl.BlockSpec(
+ block_shape=(8, 128), index_map=lambda i, j: (i, j),
+ )],
+ core_axis_name='core',
+ dimension_semantics=(pltpu.PARALLEL, pltpu.ARBITRARY),
+ )(x_hbm_ref, o_hbm_ref)
+
+
+@jax.jit
+@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)
+def add_one(x):
+ return pl.kernel(add_one_kernel, out_shape=x, mesh=tc_mesh, scratch_shapes=[])(x)
+
+
+x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)
+x = jax.device_put(x, NamedSharding(mesh, in_spec))
+y = add_one(x)
+
+np.testing.assert_array_equal(y, x + 1)
+```
+
+## Scalar prefetch
+
+The code below extends the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.
+
+This involves pre-allocating an SMEM buffer (via the `pl.run_scoped` call inside `kernel`) and populating the buffer using a `sync_copy` before the pipeline starts. Close over the dynamic index value inside the `index_map` to use it.
+
+**Manually delegate work amongst cores**
+
+The code example below also shows how `core_map` allows you to customize exactly how the work is split between cores, without relying on the automatic API shown above.
+
+To achieve that, customize your `index_map` to use the core index to work on different slices on different cores.
+
+
+
+```python
+input_shape = (1024, 1024)
+in_spec = jax.P('device', None)
+output_shape = (1024, 512)
+
+def indexed_add_one_kernel(in_refs, out_refs, i_smem_ref):
+ (x_hbm_ref, i_hbm_ref), o_hbm_ref = in_refs, out_refs
+ in_shape = x_hbm_ref.shape
+ pltpu.sync_copy(i_hbm_ref, i_smem_ref)
+
+ core_idx = jax.lax.axis_index('core')
+ core_slc_size = in_shape[0] // num_cores
+ i_map = lambda i: core_idx * core_slc_size // 8 + i # split work among cores
+ j_map = lambda j: i_smem_ref[0] // 128 + j # use the prefetched offset
+
+ pltpu.emit_pipeline(
+ add_one_body,
+ grid=(core_slc_size // 8, output_shape[1] // 128),
+ in_specs=[pl.BlockSpec(
+ block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j_map(j)),
+ )],
+ out_specs=[pl.BlockSpec(
+ block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j),
+ )]
+ )(x_hbm_ref, o_hbm_ref)
+
+
+@jax.jit
+@partial(jax.shard_map, mesh=mesh,
+ in_specs=(in_spec, jax.P()), out_specs=in_spec, check_vma=False)
+def indexed_add_one(x, index):
+ out_shape = jax.ShapeDtypeStruct((x.shape[0], x.shape[1] // 2), x.dtype)
+ return pl.kernel(indexed_add_one_kernel,
+ out_shape=out_shape, mesh=tc_mesh,
+ scratch_shapes=[pltpu.SMEM((1,), jnp.int32)])((x, index))
+
+
+xs = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)
+xs = jax.device_put(xs, NamedSharding(mesh, in_spec))
+idx = 256
+y = indexed_add_one(xs, jnp.array([idx]))
+
+np.testing.assert_array_equal(y, xs[:, idx:(idx+512)] + 1)
+```
+
+## Mapping over SparseCores
+
+TPU v5p contains 4 [SparseCores](https://openxla.org/xla/sparsecore), which are specialized for sparse memory access and operations. This guide does not dive into the full capabilities of the SparseCore; rather, it shows how to run a program on the SparseCore with the same semantics and minimal changes relative to the TensorCore code.
+
+Start by querying the basic SparseCore specs of your chip, then create a `VectorSubcoreMesh` for vector operations. Note that on TPU v5p each SparseCore has 16 subcores (other chips may have a different number), and `core_map` will run your code SPMD on each of them.
+
+
+```python
+sc_info = pltpu.get_tpu_info().sparse_core
+assert sc_info is not None
+print(sc_info)
+
+sc_mesh = plsc.VectorSubcoreMesh(
+ core_axis_name="core", subcore_axis_name="subcore",
+ num_cores=sc_info.num_cores
+)
+sc_num_cores = sc_info.num_cores
+sc_num_subcores = sc_info.num_subcores
+```
+
+ SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8)
+
+
+The code below is very similar to the `add_one_kernel` we wrote earlier, except for a few differences:
+
+1. You need to split the work amongst all subcores, so a few extra lines compute the specific slice for each subcore.
+
+1. SparseCore register computation operates on smaller slices (at most `4x16` for int32), so you need nested loops to iterate over the slice during the computation phase.
+
+
+```python
+input_shape = (4096, 128)
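+# Maximum register tile for int32 on the SparseCore (see point 2 above).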
+SC_REG_OP_SHAPE = (4, 16)
+
+def sc_add_one_body(in_vmem, out_vmem):
+ @pl.loop(0, in_vmem.shape[0], step=SC_REG_OP_SHAPE[0])
+ def _reg_loop_0(c0):
+ @pl.loop(0, in_vmem.shape[1], step=SC_REG_OP_SHAPE[1])
+ def _reg_loop_1(c1):
+ slc = (pl.ds(c0, SC_REG_OP_SHAPE[0]), pl.ds(c1, SC_REG_OP_SHAPE[1]))
+ out_vmem[slc] = in_vmem[slc] + 1
+
+
+def sc_add_one_kernel(x_hbm_ref, o_hbm_ref):
+ in_shape = x_hbm_ref.shape
+ core_idx = jax.lax.axis_index('core')
+ subcore_idx = jax.lax.axis_index("subcore")
+ cm_idx = core_idx * sc_num_subcores + subcore_idx # index on the core_map
+ slc_size = in_shape[0] // (sc_num_subcores * sc_num_cores)
+ index_map = lambda i, j: (
+ pl.ds(pl.multiple_of(cm_idx * slc_size + i * 8, 8), 8), j)
+
+ pltpu.emit_pipeline(
+ sc_add_one_body,
+ grid=(slc_size // 8, in_shape[1] // 128),
+ in_specs=[pl.BlockSpec(
+ block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,
+ )],
+ out_specs=[pl.BlockSpec(
+ block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,
+ )]
+ )(x_hbm_ref, o_hbm_ref)
+
+
+@jax.jit
+@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)
+def sc_add_one(x):
+ return pl.kernel(sc_add_one_kernel, out_shape=x, mesh=sc_mesh, scratch_shapes=[])(x)
+
+
+x = jax.random.randint(jax.random.key(0), input_shape, 0, 64, jnp.int32)
+x = jax.device_put(x, NamedSharding(mesh, in_spec))
+y = sc_add_one(x)
+
+np.testing.assert_array_equal(y, x + 1)
+```
diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb
index 434f610a0a79..feebb7c2f8e7 100644
--- a/docs/pallas/tpu/distributed.ipynb
+++ b/docs/pallas/tpu/distributed.ipynb
@@ -273,9 +273,9 @@
" num_scalar_prefetch=0,\n",
" # MemorySpace.ANY will (usually) place the tensor in HBM.\n",
" in_specs=[\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+ " pl.BlockSpec(memory_space=pl.ANY),\n",
" ],\n",
- " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+ " out_specs=pl.BlockSpec(memory_space=pl.ANY),\n",
" scratch_shapes=(\n",
" # We allocate DMA semaphores in scratch memory.\n",
" [pltpu.SemaphoreType.DMA] * 2\n",
@@ -421,9 +421,9 @@
" num_scalar_prefetch=0,\n",
" in_specs=[\n",
" # MemorySpace.ANY will (usually) place the tensor in HBM.\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+ " pl.BlockSpec(memory_space=pl.ANY),\n",
" ],\n",
- " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+ " out_specs=pl.BlockSpec(memory_space=pl.ANY),\n",
" scratch_shapes=(\n",
" # DMA semaphores are allocated in scratch memory.\n",
" # We allocated one semaphore for a local HBM-VMEM copy,\n",
@@ -644,7 +644,7 @@
"\n",
"The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).\n",
"\n",
- "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artificially hang a device.\n",
+    "A subtle race condition can occur if one device runs one loop ahead of its right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref`, waiting until the neighbor has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but it can be triggered explicitly, for example by using a `pl.delay` instruction to artificially hang a device.\n",
"\n",
"Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections."
]
@@ -809,13 +809,13 @@
" num_scalar_prefetch=0,\n",
" in_specs=[\n",
" # Our input lives in VMEM\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n",
+ " pl.BlockSpec(memory_space=pltpu.VMEM),\n",
" ],\n",
" out_specs=[\n",
" # Our output lives in VMEM\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n",
+ " pl.BlockSpec(memory_space=pltpu.VMEM),\n",
" # Our double-buffer lives in HBM\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+ " pl.BlockSpec(memory_space=pl.ANY),\n",
" ],\n",
" grid=(num_devices,),\n",
" scratch_shapes=(\n",
@@ -1146,11 +1146,11 @@
"grid_spec = pltpu.PrefetchScalarGridSpec(\n",
" num_scalar_prefetch=0,\n",
" in_specs=[\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n",
+ " pl.BlockSpec(memory_space=pltpu.VMEM),\n",
" ],\n",
" out_specs=[\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+ " pl.BlockSpec(memory_space=pltpu.VMEM),\n",
+ " pl.BlockSpec(memory_space=pl.ANY),\n",
" ],\n",
" grid=(num_devices, 2),\n",
" scratch_shapes=(\n",
@@ -1576,11 +1576,11 @@
"grid_spec = pltpu.PrefetchScalarGridSpec(\n",
" num_scalar_prefetch=0,\n",
" in_specs=[\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+ " pl.BlockSpec(memory_space=pl.ANY),\n",
" ],\n",
" out_specs=[\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
- " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+ " pl.BlockSpec(memory_space=pl.ANY),\n",
+ " pl.BlockSpec(memory_space=pl.ANY),\n",
" ],\n",
" grid=(num_devices, 2),\n",
" scratch_shapes=(\n",
diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md
index 678bd98f4470..e8bbdb3089cc 100644
--- a/docs/pallas/tpu/distributed.md
+++ b/docs/pallas/tpu/distributed.md
@@ -235,9 +235,9 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
# MemorySpace.ANY will (usually) place the tensor in HBM.
in_specs=[
- pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
],
- out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+ out_specs=pl.BlockSpec(memory_space=pl.ANY),
scratch_shapes=(
# We allocate DMA semaphores in scratch memory.
[pltpu.SemaphoreType.DMA] * 2
@@ -357,9 +357,9 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
# MemorySpace.ANY will (usually) place the tensor in HBM.
- pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
],
- out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+ out_specs=pl.BlockSpec(memory_space=pl.ANY),
scratch_shapes=(
# DMA semaphores are allocated in scratch memory.
# We allocated one semaphore for a local HBM-VMEM copy,
@@ -556,7 +556,7 @@ The prologue (executed when `outer_step==0`) first initiates a barrier with both
The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).
-A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artificially hang a device.
+A subtle race condition can occur if one device runs one loop ahead of its right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref`, waiting until the neighbor has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but it can be triggered explicitly, for example by using a `pl.delay` instruction to artificially hang a device.
Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections.
@@ -703,13 +703,13 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
# Our input lives in VMEM
- pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
+ pl.BlockSpec(memory_space=pltpu.VMEM),
],
out_specs=[
# Our output lives in VMEM
- pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
+ pl.BlockSpec(memory_space=pltpu.VMEM),
# Our double-buffer lives in HBM
- pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
],
grid=(num_devices,),
scratch_shapes=(
@@ -1019,11 +1019,11 @@ out_shape = (
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
- pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
+ pl.BlockSpec(memory_space=pltpu.VMEM),
],
out_specs=[
- pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
- pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+ pl.BlockSpec(memory_space=pltpu.VMEM),
+ pl.BlockSpec(memory_space=pl.ANY),
],
grid=(num_devices, 2),
scratch_shapes=(
@@ -1410,11 +1410,11 @@ out_shape = (
grid_spec = pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
- pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
],
out_specs=[
- pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
- pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
],
grid=(num_devices, 2),
scratch_shapes=(
diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst
index 1aca99f0dda9..784e037bca06 100644
--- a/docs/pallas/tpu/index.rst
+++ b/docs/pallas/tpu/index.rst
@@ -11,5 +11,6 @@ TPU specific documentation.
matmul
sparse
distributed
+ core_map
prng
diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb
index 64f095a4cb89..12a4b852e84a 100644
--- a/docs/pallas/tpu/pipelining.ipynb
+++ b/docs/pallas/tpu/pipelining.ipynb
@@ -123,10 +123,10 @@
"\n",
"| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |\n",
"| --- | --- | --- |\n",
- "| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM |\n",
- "| `pltpu.MemorySpace.VMEM` | VMEM | SRAM |\n",
- "| `pltpu.MemorySpace.SMEM` | SMEM | SRAM |\n",
- "| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM |\n",
+ "| `pl.ANY` | HBM (usually) or VMEM | DRAM |\n",
+ "| `pltpu.VMEM` | VMEM | SRAM |\n",
+ "| `pltpu.SMEM` | SMEM | SRAM |\n",
+ "| `pltpu.SEMAPHORE` | Semaphore | SRAM |\n",
"\n",
"- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.\n",
"- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.\n",
@@ -164,9 +164,9 @@
"\n",
"x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n",
"out = pl.pallas_call(hbm_vmem_kernel,\n",
- " in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],\n",
+ " in_specs=[pl.BlockSpec(memory_space=pl.ANY)],\n",
" out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),\n",
- " scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)\n",
+ " scratch_shapes=(pltpu.VMEM(shape=(1, 128), dtype=jnp.float32),)\n",
")(x)\n",
"\n",
"np.testing.assert_allclose(out, x[0:1] + 1)"
@@ -283,12 +283,12 @@
"x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n",
"slices = jnp.array([[0, 2], [2, 3], [3, 5], [5, 8]], dtype=jnp.int32)\n",
"\n",
- "hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)\n",
+ "hbm_block_spec = pl.BlockSpec(memory_space=pl.ANY)\n",
"out = pl.pallas_call(dynamic_block_example_kernel,\n",
" in_specs=[hbm_block_spec, hbm_block_spec],\n",
" out_specs=hbm_block_spec,\n",
" out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),\n",
- " scratch_shapes=(pltpu.MemorySpace.SMEM(slices.shape, jnp.int32),)\n",
+ " scratch_shapes=(pltpu.SMEM(slices.shape, jnp.int32),)\n",
" )(x, slices)\n",
"\n",
"np.testing.assert_allclose(x, out)"
diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md
index d91ebdd63f65..02c9187edd2e 100644
--- a/docs/pallas/tpu/pipelining.md
+++ b/docs/pallas/tpu/pipelining.md
@@ -95,10 +95,10 @@ Pallas exposes all levels of the TPU memory hierarchy to users. The following ta
| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |
| --- | --- | --- |
-| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM |
-| `pltpu.MemorySpace.VMEM` | VMEM | SRAM |
-| `pltpu.MemorySpace.SMEM` | SMEM | SRAM |
-| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM |
+| `pl.ANY` | HBM (usually) or VMEM | DRAM |
+| `pltpu.VMEM` | VMEM | SRAM |
+| `pltpu.SMEM` | SMEM | SRAM |
+| `pltpu.SEMAPHORE` | Semaphore | SRAM |
- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.
- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.
@@ -129,9 +129,9 @@ def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):
x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)
out = pl.pallas_call(hbm_vmem_kernel,
- in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],
+ in_specs=[pl.BlockSpec(memory_space=pl.ANY)],
out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
- scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)
+ scratch_shapes=(pltpu.VMEM(shape=(1, 128), dtype=jnp.float32),)
)(x)
np.testing.assert_allclose(out, x[0:1] + 1)
@@ -229,12 +229,12 @@ def dynamic_block_example_kernel(x_hbm, slices_hbm, o_hbm, slices_smem):
x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)
slices = jnp.array([[0, 2], [2, 3], [3, 5], [5, 8]], dtype=jnp.int32)
-hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)
+hbm_block_spec = pl.BlockSpec(memory_space=pl.ANY)
out = pl.pallas_call(dynamic_block_example_kernel,
in_specs=[hbm_block_spec, hbm_block_spec],
out_specs=hbm_block_spec,
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
- scratch_shapes=(pltpu.MemorySpace.SMEM(slices.shape, jnp.int32),)
+ scratch_shapes=(pltpu.SMEM(slices.shape, jnp.int32),)
)(x, slices)
np.testing.assert_allclose(x, out)
diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md
index ae3d6ddfcad0..6e82d995b782 100644
--- a/docs/persistent_compilation_cache.md
+++ b/docs/persistent_compilation_cache.md
@@ -132,6 +132,34 @@ Cloud Storage (GCS) bucket. We recommend the following configuration:
* All encryption policies are supported.
+It is **recommended** to use
+[Google Cloud Storage Fuse](https://cloud.google.com/storage/docs/cloud-storage-fuse)
+to mount the GCS bucket as a local directory. When running JAX in a
+multi-node setup, multiple nodes might try to write to the cache
+simultaneously, leading to GCS rate-limit errors. GCSFuse prevents this by
+ensuring that only one process writes to a given file at a time.
+
+To set up GCSFuse, follow instructions for
+[GCE](https://cloud.google.com/storage/docs/cloud-storage-fuse/mount-bucket) or
+[GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-setup).
+For better performance, enable file caching
+([GCE](https://cloud.google.com/storage/docs/cloud-storage-fuse/file-caching) and
+[GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-perf#enable-and-use-file-caching)).
+
+Once GCSFuse is configured, set the JAX cache directory to the GCSFuse mount
+point:
+
+```python
+# Example assuming the GCS bucket is mounted at /gcs/my-bucket
+jax.config.update("jax_compilation_cache_dir", "/gcs/my-bucket/jax-cache")
+```
+
+**Direct GCS access:**
+
+If you choose not to use GCSFuse, you can point the cache directly to a GCS
+bucket.
+
Assuming that `gs://jax-cache` is the GCS bucket, set cache location as
follows:
diff --git a/docs/the-training-cookbook.rst b/docs/the-training-cookbook.rst
index 3f8018f41490..4059ec5f8ed2 100644
--- a/docs/the-training-cookbook.rst
+++ b/docs/the-training-cookbook.rst
@@ -109,7 +109,7 @@ Examining the call signature of the function ``adam_apply`` gives us a hint:
.. tagged-block:: the-training-cookbook.py adam-apply
-Because ``train_state.params`` is the first argument, :func:`jax.tree.map` uses its tree structure to guide the mapping process.[#prefix_tree]_ This means that ``train_state.opt`` is traversed only as deep as the leaves of ``train_state.params``. The optimizer state for each parameter is therefore passed in as a complete subtree, which allows us to easily access all relevant states (like ``mu`` and ``nu``) for a given ``param`` inside ``adam_apply``.
+Because ``train_state.params`` is the first argument, :func:`jax.tree.map` uses its tree structure to guide the mapping process. [#prefix_tree]_ This means that ``train_state.opt`` is traversed only as deep as the leaves of ``train_state.params``. The optimizer state for each parameter is therefore passed in as a complete subtree, which allows us to easily access all relevant states (like ``mu`` and ``nu``) for a given ``param`` inside ``adam_apply``.
.. tip::
@@ -286,7 +286,7 @@ The drawback of data-parallel sharding is that we have to keep multiple, full, r
.. code-block:: python
- mesh = jax.sharding.Mesh(jax.devices(), ('fsdp',))
+ mesh = jax.make_mesh((128*4,), ("fsdp",))
*Parameter Shardings:*
@@ -323,8 +323,8 @@ If our model is large enough and structured appropriately, it becomes beneficial
*Mesh:*
.. code-block:: python
-
- mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(128, 4), ("fsdp", "tensor"))
+
+ mesh = jax.make_mesh((128,4), ("fsdp", "tensor"))
*Parameter Shardings:*
diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc
index 27681e41fdad..b911711ad53f 100644
--- a/examples/jax_cpp/main.cc
+++ b/examples/jax_cpp/main.cc
@@ -106,7 +106,7 @@ int main(int argc, char** argv) {
// Get result.
std::shared_ptr result_literal =
- results[0][0]->ToLiteralSync().value();
+ results[0][0]->ToLiteral().Await().value();
LOG(INFO) << "result = " << *result_literal;
return 0;
}
diff --git a/jax/BUILD b/jax/BUILD
index ed8f35c60529..6431c08656d1 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -20,6 +20,7 @@ load(
"jax_extend_internal_users",
"jax_extra_deps",
"jax_internal_packages",
+ "jax_visibility",
"py_deps",
"py_library_providing_imports_info",
"pytype_library",
@@ -111,6 +112,7 @@ exports_files([
"LICENSE",
"version.py",
"py.typed",
+ "oss/pyproject.toml",
])
# Packages that have access to JAX-internal implementation details.
@@ -152,6 +154,7 @@ py_library_providing_imports_info(
],
# TODO(dsuo): Consider moving these files out of experimental if they're in the public API.
) + ["//jax/experimental:jax_public"],
+ lazy_imports = True,
lib_rule = pytype_library,
pytype_srcs = glob(
[
@@ -254,220 +257,36 @@ pytype_strict_library(
)
# Public JAX libraries below this point.
-# TODO(phawkins): remove this target in favor of the finer-grained targets in jax/extend/...
-pytype_strict_library(
- name = "extend",
- visibility = [":jax_extend_users"],
- deps = [
- "//jax/extend",
- "//jax/extend:backend",
- "//jax/extend:core",
- "//jax/extend:linear_util",
- "//jax/extend:random",
- "//jax/extend:source_info_util",
- ],
-)
# Aliases of experimental targets.
# TODO(dsuo): remove these aliases/targets.
-py_library_providing_imports_info(
+pytype_strict_library(
name = "experimental",
- srcs = [
- "//jax/example_libraries:jax_example_libraries",
- "//jax/experimental:jax_experimental",
- ],
- visibility = ["//visibility:public"],
- # NOTE: Exclude mosaic_gpu, serialize_executable, and buffer_callback.
+ visibility = jax_visibility("experimental_deprecated_alias"),
deps = [
":jax",
- "//jax/_src:buffer_callback",
- ] + py_deps("absl/logging") + py_deps("numpy"),
-)
-
-alias(
- name = "experimental_buffer_callback",
- actual = "//jax/experimental:buffer_callback",
- visibility = ["//jax/experimental:buffer_callback_users"],
-)
-
-alias(
- name = "experimental_colocated_python",
- actual = "//jax/experimental:colocated_python",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "experimental_compute_on",
- actual = "//jax/experimental:compute_on",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "compilation_cache",
- actual = "//jax/experimental:compilation_cache",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "jet",
- actual = "//jax/experimental:jet",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "mesh_utils",
- actual = "//jax/experimental:mesh_utils",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "experimental_mesh_utils",
- actual = "//jax/experimental:mesh_utils",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "mosaic",
- actual = "//jax/experimental:mosaic",
- visibility = ["//jax/experimental:mosaic_users"],
-)
-
-alias(
- name = "mosaic_gpu",
- actual = "//jax/experimental:mosaic_gpu",
- visibility = ["//jax/experimental:mosaic_gpu_users"],
-)
-
-alias(
- name = "experimental_multihost_utils",
- actual = "//jax/experimental:multihost_utils",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "ode",
- actual = "//jax/experimental:ode",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "pallas",
- actual = "//jax/experimental:pallas",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "pallas_fuser",
- actual = "//jax/experimental:pallas_fuser",
- visibility = ["//jax/experimental:pallas_fuser_users"],
-)
-
-alias(
- name = "pallas_gpu",
- actual = "//jax/experimental:pallas_gpu",
- visibility = ["//jax/experimental:pallas_gpu_users"],
-)
-
-alias(
- name = "pallas_gpu_ops",
- actual = "//jax/experimental:pallas_gpu_ops",
- visibility = ["//jax/experimental:pallas_gpu_users"],
-)
-
-alias(
- name = "pallas_mosaic_gpu",
- actual = "//jax/experimental:pallas_mosaic_gpu",
- visibility = ["//jax/experimental:mosaic_gpu_users"],
-)
-
-alias(
- name = "pallas_tpu",
- actual = "//jax/experimental:pallas_tpu",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "pallas_tpu_ops",
- actual = "//jax/experimental:pallas_tpu_ops",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "pallas_triton",
- actual = "//jax/experimental:pallas_triton",
- visibility = ["//jax/experimental:pallas_gpu_users"],
-)
-
-alias(
- name = "pallas_experimental_gpu_ops",
- actual = "//jax/experimental:pallas_experimental_gpu_ops",
- visibility = ["//jax/experimental:mosaic_gpu_users"],
-)
-
-alias(
- name = "experimental_profiler",
- actual = "//jax/experimental:profiler",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "experimental_pjit",
- actual = "//jax/experimental:pjit",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "rnn",
- actual = "//jax/experimental:rnn",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "experimental_serialize_executable",
- actual = "//jax/experimental:serialize_executable",
- visibility = ["//jax/experimental:serialize_executable_users"],
-)
-
-alias(
- name = "source_mapper",
- actual = "//jax/experimental:source_mapper",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "experimental_sparse",
- actual = "//jax/experimental:sparse",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "sparse_test_util",
- actual = "//jax/experimental:sparse_test_util",
- visibility = [":internal"],
-)
-
-alias(
- name = "experimental_topologies",
- actual = "//jax/experimental:topologies",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "experimental_transfer",
- actual = "//jax/experimental:transfer",
- visibility = [":internal"],
-)
-
-# Aliases of example_library targets.
-# TODO(dsuo): remove these aliases.
-alias(
- name = "optimizers",
- actual = "//jax/example_libraries:optimizers",
- visibility = ["//visibility:public"],
-)
-
-alias(
- name = "stax",
- actual = "//jax/example_libraries:stax",
- visibility = ["//visibility:public"],
+ "//jax/example_libraries:optimizers",
+ "//jax/example_libraries:stax",
+ "//jax/experimental",
+ "//jax/experimental:checkify",
+ "//jax/experimental:compute_on",
+ "//jax/experimental:custom_dce",
+ "//jax/experimental:custom_partitioning",
+ "//jax/experimental:fused",
+ "//jax/experimental:hijax",
+ "//jax/experimental:jet",
+ "//jax/experimental:layout",
+ "//jax/experimental:mesh_utils",
+ "//jax/experimental:multihost_utils",
+ "//jax/experimental:ode",
+ "//jax/experimental:pjit",
+ "//jax/experimental:profiler",
+ "//jax/experimental:rnn",
+ "//jax/experimental:scheduling_groups",
+ "//jax/experimental:shard_alike",
+ "//jax/experimental:shard_map",
+ "//jax/experimental:topologies",
+ "//jax/experimental:transfer",
+ "//jax/experimental:xla_metadata",
+ ],
)
diff --git a/jax/__init__.py b/jax/__init__.py
index 945fb9f46374..874c5f119fdf 100644
--- a/jax/__init__.py
+++ b/jax/__init__.py
@@ -136,6 +136,7 @@
from jax._src.sharding_impls import make_mesh as make_mesh
from jax._src.sharding_impls import set_mesh as set_mesh
from jax._src.partition_spec import P as P
+from jax._src.pjit import reshard as reshard
from jax._src.shard_map import shard_map as shard_map
from jax._src.shard_map import smap as smap
@@ -207,41 +208,6 @@
"jax.device_put_sharded is deprecated; use jax.device_put instead.",
_deprecated_device_put_sharded
),
- # Finalized 2025-03-25; remove after 2025-06-25
- "treedef_is_leaf": (
- "jax.treedef_is_leaf was removed in JAX v0.6.0: use jax.tree_util.treedef_is_leaf.",
- None
- ),
- "tree_flatten": (
- "jax.tree_flatten was removed in JAX v0.6.0: use jax.tree.flatten (jax v0.4.25 or newer) "
- "or jax.tree_util.tree_flatten (any JAX version).",
- None
- ),
- "tree_leaves": (
- "jax.tree_leaves was removed in JAX v0.6.0: use jax.tree.leaves (jax v0.4.25 or newer) "
- "or jax.tree_util.tree_leaves (any JAX version).",
- None
- ),
- "tree_structure": (
- "jax.tree_structure was removed in JAX v0.6.0: use jax.tree.structure (jax v0.4.25 or newer) "
- "or jax.tree_util.tree_structure (any JAX version).",
- None
- ),
- "tree_transpose": (
- "jax.tree_transpose was removed in JAX v0.6.0: use jax.tree.transpose (jax v0.4.25 or newer) "
- "or jax.tree_util.tree_transpose (any JAX version).",
- None
- ),
- "tree_unflatten": (
- "jax.tree_unflatten was removed in JAX v0.6.0: use jax.tree.unflatten (jax v0.4.25 or newer) "
- "or jax.tree_util.tree_unflatten (any JAX version).",
- None
- ),
- "tree_map": (
- "jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) "
- "or jax.tree_util.tree_map (any JAX version).",
- None
- ),
}
import typing as _typing
diff --git a/jax/_src/BUILD b/jax/_src/BUILD
index f514c931c3bd..bb815a0c45a2 100644
--- a/jax/_src/BUILD
+++ b/jax/_src/BUILD
@@ -437,7 +437,6 @@ py_library_providing_imports_info(
":core",
":custom_derivatives",
":custom_partitioning_sharding_rule",
- ":deprecations",
":dtypes",
":effects",
":ffi",
@@ -1187,6 +1186,7 @@ pytype_strict_library(
":core",
":dtypes",
":effects",
+ ":partial_eval",
":tree_util",
":util",
],
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index a74bb7b7828e..62002cd8eb51 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -731,27 +731,30 @@ def remat_partial_eval_custom_params_updater(*args):
partial(pe.call_partial_eval_custom_rule, 'jaxpr',
remat_partial_eval_custom_params_updater)
-def remat_transpose(out_cts, *in_primals, jaxpr, prevent_cse, **params):
+def remat_transpose(out_cts, *args, jaxpr, prevent_cse, **params):
+ # TODO(mattjj): avoid round-tripping into UndefinedPrimals
+ args_ = [ad.UndefinedPrimal(x.aval) if isinstance(x, ad.ValAccum) else x
+ for x in args]
+ if any(isinstance(x, ad.GradAccum) for x in args_): raise NotImplementedError
+
assert not jaxpr.constvars
- in_linear = [ad.is_undefined_primal(x) for x in in_primals]
+ in_linear = [ad.is_undefined_primal(x) for x in args_]
out_zeros = [type(ct) is ad_util.Zero for ct in out_cts]
transposed_jaxpr_, in_zeros = transpose_jaxpr(
pe.close_jaxpr(jaxpr), in_linear, out_zeros)
transposed_jaxpr, consts = transposed_jaxpr_.jaxpr, transposed_jaxpr_.consts
transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr)
- args, _ = tree_flatten((in_primals, out_cts))
+ flat_args, _ = tree_flatten((args_, out_cts))
if isinstance(prevent_cse, tuple):
prevent_cse_, _ = partition_list(in_linear, prevent_cse)
prevent_cse = tuple(prevent_cse_) + (True,) * (len(out_zeros) - sum(out_zeros))
- in_cts_nz = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr,
+ in_cts_nz = remat_p.bind(*consts, *flat_args, jaxpr=transposed_jaxpr,
prevent_cse=prevent_cse, **params)
in_cts_nz_, in_zeros_ = iter(in_cts_nz), iter(in_zeros)
- in_cts = [None if not ad.is_undefined_primal(x) else
- ad_util.Zero(x.aval) if next(in_zeros_) else next(in_cts_nz_)
- for x in in_primals]
- assert next(in_cts_nz_, None) is next(in_zeros_, None) is None
- return in_cts
-ad.primitive_transposes[remat_p] = remat_transpose
+ for x in args:
+ if isinstance(x, ad.ValAccum) and not next(in_zeros_):
+ x.accum(next(in_cts_nz_))
+ad.fancy_transposes[remat_p] = remat_transpose
# TODO(mattjj): move this to ad.py
def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool],
@@ -845,7 +848,8 @@ def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
pe.dce_rules[remat_p] = remat_dce
def _has_effects(effects) -> bool:
- return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
+ not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect)
+ return any(not isinstance(e, not_really_effects) for e in effects)
def remat_expansion(
diff --git a/jax/_src/api.py b/jax/_src/api.py
index 15577426e2d7..01926d919860 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -68,7 +68,7 @@
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
-from jax._src.mesh import get_concrete_mesh
+from jax._src.mesh import get_concrete_mesh, get_abstract_mesh, Mesh
from jax._src.sharding_impls import (PmapSharding, PartitionSpec as P,
NamedSharding)
from jax._src.layout import Format
@@ -172,7 +172,6 @@ def jit(
device: xc.Device | None = ...,
backend: str | None = ...,
inline: bool = ...,
- abstracted_axes: Any | None = ...,
compiler_options: dict[str, Any] | None = ...,
) -> pjit.JitWrapped: ...
@@ -189,7 +188,6 @@ def jit(
device: xc.Device | None = ...,
backend: str | None = ...,
inline: bool = ...,
- abstracted_axes: Any | None = ...,
compiler_options: dict[str, Any] | None = ...,
) -> Callable[[Callable], pjit.JitWrapped]: ...
@@ -205,7 +203,6 @@ def jit(
device: xc.Device | None = None,
backend: str | None = None,
inline: bool = False,
- abstracted_axes: Any | None = None,
compiler_options: dict[str, Any] | None = None,
) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]:
"""Sets up ``fun`` for just-in-time compilation with XLA.
@@ -350,8 +347,7 @@ def jit(
static_argnums=static_argnums, static_argnames=static_argnames,
donate_argnums=donate_argnums, donate_argnames=donate_argnames,
keep_unused=keep_unused, device=device, backend=backend, inline=inline,
- abstracted_axes=abstracted_axes, compiler_options=compiler_options,
- use_resource_env=False)
+ compiler_options=compiler_options, use_resource_env=False)
if isinstance(fun, NotSpecified):
return lambda fun: pjit.make_jit(fun, **kwds)
else:
@@ -393,7 +389,7 @@ def disable_jit(disable: bool = True):
... return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
- Value of y is JitTracer
+ Value of y is JitTracer(int32[3])
[5 7 9]
Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,
@@ -1195,6 +1191,9 @@ def vmap_f(*args, **kwargs):
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
if spmd_axis_name is not None and explicit_mesh_axis is not None:
+ if config.remove_size_one_mesh_axis_from_type.value:
+ mesh = get_abstract_mesh()
+ spmd_axis_name = tuple(i for i in spmd_axis_name if mesh.shape[i] != 1)
if spmd_axis_name == explicit_mesh_axis:
spmd_axis_name = None
else:
@@ -1345,7 +1344,13 @@ def pmap(
donate_argnums: int | Iterable[int] = (),
global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
) -> Any:
- """Parallel map with support for collective operations.
+ """Old way of doing parallel map. Use :py:func:`jax.shard_map` instead.
+
+ .. note::
+ While :py:func:`jax.pmap` works, you should probably use
+ :py:func:`jax.shard_map` or ``jax.smap`` instead. shard_map supports more
+ efficient autodiff, and is more composable in the multi-controller setting.
+ See https://docs.jax.dev/en/latest/notebooks/shard_map.html for examples.
.. note::
:py:func:`pmap` is now implemented in terms of :py:func:`jit` and
@@ -1510,26 +1515,6 @@ def pmap(
are important particularly in the case of nested :py:func:`pmap` functions,
where collective operations can operate over distinct axes:
- >>> from functools import partial
- >>> import jax
- >>>
- >>> @partial(pmap, axis_name='rows')
- ... @partial(pmap, axis_name='cols')
- ... def normalize(x):
- ... row_normed = x / jax.lax.psum(x, 'rows')
- ... col_normed = x / jax.lax.psum(x, 'cols')
- ... doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
- ... return row_normed, col_normed, doubly_normed
- >>>
- >>> x = jnp.arange(8.).reshape((4, 2))
- >>> row_normed, col_normed, doubly_normed = normalize(x) # doctest: +SKIP
- >>> print(row_normed.sum(0)) # doctest: +SKIP
- [ 1. 1.]
- >>> print(col_normed.sum(1)) # doctest: +SKIP
- [ 1. 1. 1. 1.]
- >>> print(doubly_normed.sum((0, 1))) # doctest: +SKIP
- 1.0
-
On multi-process platforms, collective operations operate over all devices,
including those on other processes. For example, assuming the following code
runs on two processes with 4 XLA devices each:
@@ -2227,120 +2212,12 @@ def vjp(
fun, debug_info=debug_info("vjp", fun, primals, {}))
return _vjp(wrapped_fun, *primals, has_aux=has_aux)
-def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
- """Variant of vjp() that takes an lu.WrappedFun."""
- if config.vjp3.value:
- return _vjp3(fun, *primals, has_aux=has_aux)
- primals_flat, in_tree = tree_flatten(primals)
- primals_flat = [canonicalize_value(v) if not isinstance(v, core.Tracer) else v
- for v in primals_flat]
- for arg in primals_flat: dispatch.check_arg(arg)
- if not has_aux:
- flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
- out_primals, vjp = ad.vjp(flat_fun, primals_flat)
- out_tree = out_tree()
- else:
- flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
- out_primals, vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
- out_tree, aux_tree = out_aux_trees()
- out_primal_avals = map(shaped_abstractify, out_primals)
- out_primal_py = tree_unflatten(out_tree, out_primals)
- vjp_py = Partial(partial(_vjp_pullback_wrapper, fun.__name__,
- out_primal_avals, (out_tree, in_tree)), vjp)
- if not has_aux:
- return out_primal_py, vjp_py
- else:
- return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux)
-
-@partial(api_boundary, repro_api_name="jax.experimental.saved_input_vjp")
-def saved_input_vjp(f: Callable, which: Sequence[bool], *primals,
- allow_unused: bool = True, allow_opaque: bool = True):
- if len(which) != len(primals):
- raise ValueError(
- "length of 'which' argument must equal the number of primal input values, "
- f"but got {len(which)=} and {len(primals)=}")
-
- dbg = debug_info("saved_input_vjp", f, primals, {})
- fun = lu.wrap_init(f, debug_info=dbg)
- primals_flat, in_tree = tree_flatten(primals)
- fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
- out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(fun, *primals_flat)
- out_known = [pval.is_known() for pval in out_pvals]
- primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w))
- id_map = {id(x): i for i, x in enumerate(primals_filt)}
- opaque_residuals = []
- res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else
- RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore
- for r in residuals]
- f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree,
- out_tree(), out_known, jaxpr), opaque_residuals)
-
- if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}):
- unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which))
- if w and id(x) not in res_ids]
- assert unused
- if len(unused) == 1:
- (i, a), = unused
- start, was = "an input value", "was"
- msg = f" {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'} of type {a.str_short()}"
- else:
- start, was = "multiple input values", "were"
- msg = "\n" + "\n".join(f" * {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'} of type {a.str_short()}"
- for i, a in unused)
- raise Exception(f"with {allow_unused=}, {start} marked to be saved {was} "
- f"not used by the backward pass:{msg}")
-
- if not allow_opaque and opaque_residuals:
- msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals)
- raise Exception(f"with {allow_opaque=}, the backward pass requires opaque "
- f"(non-input) residuals: {msg}")
-
- out_primals = tree_unflatten(out_tree(), out_primals_flat)
- return out_primals, f_vjp
-
-def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known,
- jaxpr, opaque_residuals, ct, *saved_primals):
- primals_filtered, filtered_tree_ = tree_flatten(saved_primals)
- if filtered_tree != filtered_tree_:
- raise ValueError(
- "inputs passed to f_vjp must be a tuple of (pytrees of) "
- "arrays with the same structure as\n"
- " tuple(x for x, w in zip(inputs, which) if w)\n"
- "given the original call\n"
- " _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n"
- "but the structures differ:\n" +
- "\n".join(f" * inputs{keystr(path)} was a {thing1} in the original "
- f"call, but a {thing2} here, so {explanation}"
- for path, thing1, thing2, explanation
- in equality_errors_pytreedef(filtered_tree, filtered_tree_)))
-
- residuals = [primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx]
- for i in res_spec]
- dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars]
- cts_flat, out_tree_ = tree_flatten(ct)
- assert out_tree_ == out_tree
- cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k]
- arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat)
- return tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts))
-
-@dataclasses.dataclass(frozen=True)
-class RSpec:
- idx: int
- primal: bool
-
-si_vjp = saved_input_vjp
-
-
-def vjp3(f, *primals, has_aux=False):
- dbg = debug_info("vjp", f, primals, {})
- fun = lu.wrap_init(f, debug_info=dbg)
- return _vjp3(fun, *primals, has_aux=has_aux)
-
-def _vjp3(fun, *primals, has_aux=False):
+def _vjp(fun, *primals, has_aux=False):
canon = lambda x: x if isinstance(x, core.Tracer) else canonicalize_value(x)
primals = tree_map(canon, primals)
primals_flat, in_tree = tree_flatten(primals)
- for arg in primals_flat: dispatch.check_arg(arg)
+ for arg in primals_flat:
+ dispatch.check_arg(arg)
if not has_aux:
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(
@@ -2369,22 +2246,14 @@ def _vjp3(fun, *primals, has_aux=False):
else:
return out_primals, f_vjp, tree_unflatten(aux_tree, aux)
-def tuptree_map(f, treedef, x):
- return treedef.walk(lambda xs, _: tuple(xs), f, x)
-
-
-def _is_ref(x):
- from jax._src.state.types import AbstractRef
- try: return isinstance(typeof(x), AbstractRef)
- except: return False
-
def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
args_res, opaque_res, *maybe_ct_refs):
if not maybe_ct_refs:
maybe_ct_refs_flat = [GradValue()] * in_tree.num_leaves
else:
maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs)
- if in_tree != in_tree_: raise Exception # TODO accept isomorph tuple tree
+ if in_tree != in_tree_:
+ raise Exception # TODO accept isomorph tuple tree
args_res_ = tree_leaves(args_res, is_leaf=lambda x: isinstance(x, NotNeeded))
residuals = [args_res_[i.idx] if i.primal else opaque_res[i.idx] for i in spec]
maybe_refs = [ad.RefAccum(v.aval, x) if _is_ref(x) else ad.ValAccum(v.aval)
@@ -2395,7 +2264,8 @@ def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals,
maybe_refs, out_ct):
cts_flat, out_tree_ = tree_flatten(out_ct)
- if out_tree != out_tree_: _vjp_ct_tree_error(jaxpr, out_tree, out_tree_)
+ if out_tree != out_tree_:
+ _vjp_ct_tree_error(jaxpr, out_tree, out_tree_)
_vjp_check_ct_avals(cts_flat, out_primal_avals)
cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k]
ad.backward_pass3(jaxpr, True, residuals, maybe_refs, cts_flat)
@@ -2404,6 +2274,23 @@ def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals,
arg_cts = map(ad.instantiate_zeros, arg_cts)
return tree_unflatten(in_tree, arg_cts)
+
+@dataclasses.dataclass(frozen=True)
+class RSpec:
+ idx: int
+ primal: bool
+
+def tuptree_map(f, treedef, x):
+ return treedef.walk(lambda xs, _: tuple(xs), f, x)
+
+def _is_ref(x):
+ from jax._src.state.types import AbstractRef
+ try:
+ return isinstance(typeof(x), AbstractRef)
+ except:
+ return False
+
+
_vjp_too_many_args = """
The function returned by `jax.vjp` applied to {} was called with {} arguments,
but functions returned by `jax.vjp` must be called with a single argument
@@ -2425,6 +2312,7 @@ def f(x):
arguments rather than in a tuple, this error can arise.
""".format
+
def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree):
msg = f"""unexpected tree structure.
@@ -2439,13 +2327,12 @@ def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree):
in equality_errors_pytreedef(out_tree, ct_tree))
raise ValueError(msg)
+
def _vjp_check_ct_avals(cts, primal_avals):
# TODO(mattjj): improve this error by flattening with keys in the first place
for ct, aval in zip(cts, primal_avals):
ct_aval = typeof(ct)
- ct_aval_expected = (
- aval.to_cotangent_aval() if hasattr(aval, 'to_cotangent_aval') else
- aval.to_tangent_aval())
+ ct_aval_expected = aval.to_cotangent_aval()
if (not core.typecompat(ct_aval, ct_aval_expected) and
not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
raise ValueError(
@@ -2454,6 +2341,7 @@ def _vjp_check_ct_avals(cts, primal_avals):
"because the corresponding output of the differentiated function had JAX type "
f"{aval.str_short()}")
+
@register_dataclass
@dataclasses.dataclass(frozen=True)
class NotNeeded:
@@ -2577,13 +2465,13 @@ def transposed_fun(const, out_cotangent):
return Partial(transposed_fun, const)
-def _flat_axes_specs(abstracted_axes, *args, **kwargs
+def _flat_axes_specs(*args, **kwargs
) -> list[pe.AbstractedAxesSpec]:
if kwargs: raise NotImplementedError
def ax_leaf(l):
return (isinstance(l, dict) and all_leaves(l.values()) or
isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
- return broadcast_prefix(abstracted_axes, args, ax_leaf)
+ return broadcast_prefix(args, ax_leaf)
@overload
@@ -2592,7 +2480,6 @@ def make_jaxpr(
static_argnums: int | Iterable[int] = (),
axis_env: Sequence[tuple[AxisName, int]] | None = None,
return_shape: Literal[False] = ...,
- abstracted_axes: Any | None = None,
) -> Callable[..., core.ClosedJaxpr]:
...
@@ -2602,7 +2489,6 @@ def make_jaxpr(
static_argnums: int | Iterable[int] = (),
axis_env: Sequence[tuple[AxisName, int]] | None = None,
return_shape: Literal[True] = ...,
- abstracted_axes: Any | None = None,
) -> Callable[..., tuple[core.ClosedJaxpr, Any]]:
...
@@ -2612,7 +2498,6 @@ def make_jaxpr(
static_argnums: int | Iterable[int] = (),
axis_env: Sequence[tuple[AxisName, int]] | None = None,
return_shape: bool = False,
- abstracted_axes: Any | None = None,
) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]:
"""Create a function that returns the jaxpr of ``fun`` given example args.
@@ -2680,8 +2565,7 @@ def make_jaxpr(
@api_boundary
def make_jaxpr_f(*args, **kwargs):
with core.extend_axis_env_nd(axis_env or []):
- traced = jit(fun, static_argnums=static_argnums,
- abstracted_axes=abstracted_axes).trace(*args, **kwargs)
+ traced = jit(fun, static_argnums=static_argnums).trace(*args, **kwargs)
# `jit` converts tracers in consts to args but `make_jaxpr` callers expect
# consts not to be converted.
num_consts = traced._num_consts
@@ -2916,8 +2800,12 @@ def _device_put_sharded(*xs):
raise ValueError("the shards passed to device_put_sharded must have "
f"consistent shape and dtype, but got {a1} and {a2}.")
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
- sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
- sharding = PmapSharding(np.array(devices), sharding_spec)
+ if config.pmap_shmap_merge.value:
+ mesh = Mesh(np.array(devices), ('_device_put_sharded',))
+ sharding = NamedSharding(mesh, P('_device_put_sharded'))
+ else:
+ sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape)
+ sharding = PmapSharding(np.array(devices), sharding_spec)
if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended):
return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices)
if config.pmap_no_rank_reduction.value:
@@ -2972,7 +2860,6 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
def _device_put_replicated(x):
aval = core.unmapped_aval(len(devices), 0, core.get_aval(x))
assert isinstance(aval, ShapedArray)
- sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
if config.pmap_no_rank_reduction.value:
if isinstance(x, (np.ndarray, basearray.Array)):
buf = device_put(x[None], devices[0])
@@ -2980,7 +2867,12 @@ def _device_put_replicated(x):
buf = device_put(x, devices[0])[None]
else:
buf = device_put(x, devices[0])
- sharding = PmapSharding(np.array(devices), sharding_spec)
+ if config.pmap_shmap_merge.value:
+ mesh = Mesh(np.array(devices), ('_device_put_replicated',))
+ sharding = NamedSharding(mesh, P('_device_put_replicated'))
+ else:
+ sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
+ sharding = PmapSharding(np.array(devices), sharding_spec)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices)
return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices)
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 77e9e1c11d72..4f8a378a14f2 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -27,16 +27,17 @@
from jax._src.state.types import AbstractRef
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, treedef_children,
- generate_key_paths, broadcast_prefix, prefix_errors, none_leaf_registry,
- broadcast_flattened_prefix_with_treedef)
+ tree_flatten_with_path, generate_key_paths, broadcast_prefix, prefix_errors,
+ none_leaf_registry, broadcast_flattened_prefix_with_treedef)
from jax._src import linear_util as lu
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
- Unhashable, safe_zip as zip)
+ Unhashable, safe_zip, unzip2)
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
-map = safe_map
+map, unsafe_map = safe_map, map
+zip, unsafe_zip = safe_zip, zip
def _ensure_index(x: Any) -> int | tuple[int, ...]:
"""Ensure x is either an index or a tuple of indices."""
@@ -75,6 +76,16 @@ def flatten_fun(f: Callable, store: lu.Store,
store.store(out_tree)
return ans
+@lu.transformation_with_aux2
+def flatten_fun3(f: Callable, store: lu.Store,
+ in_tree: PyTreeDef, *args_flat):
+ py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
+ ans = f(*py_args, **py_kwargs)
+ paths_and_ans, out_tree = tree_flatten_with_path(ans)
+ paths, ans = unzip2(paths_and_ans)
+ store.store((out_tree, paths))
+ return ans
+
def apply_flat_fun(fun, io_tree, *py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten((py_args, {}))
diff --git a/jax/_src/array.py b/jax/_src/array.py
index 13ca89fb25d7..50185186306f 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -39,12 +39,13 @@
from jax._src.interpreters import pxla
from jax._src.layout import AutoLayout, Format, Layout
from jax._src.lib import _jax
+from jax._src.lib import jaxlib_extension_version
from jax._src.lib import xla_client as xc
from jax._src.mesh import empty_concrete_mesh
from jax._src.sharding import Sharding
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten
from jax._src.sharding_impls import (
- PmapSharding, SingleDeviceSharding,
+ PmapSharding, SingleDeviceSharding, NamedSharding,
device_replica_id_map, hashed_index, num_addressable_indices,
local_to_global_shape, _internal_use_concrete_mesh) # pyformat: disable
from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike, ExtendedDType
@@ -284,9 +285,6 @@ def weak_type(self):
def committed(self) -> bool:
return self._committed
- def __str__(self):
- return str(self._value)
-
def __len__(self):
try:
return self.shape[0]
@@ -394,11 +392,13 @@ def is_fully_replicated(self) -> bool:
def __repr__(self):
prefix = 'Array('
if self.aval is not None and self.aval.weak_type:
- dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
+ dtype_str = f'dtype={self.dtype.name}, weak_type=True'
else:
- dtype_str = f'dtype={self.dtype.name})'
+ dtype_str = f'dtype={self.dtype.name}'
- if self.is_fully_addressable or self.is_fully_replicated:
+ if isinstance(self.sharding, NamedSharding) and self.sharding.spec.unreduced:
+ return f"Array(shape={self.shape}, {dtype_str}, sharding={self.sharding})"
+ elif self.is_fully_addressable or self.is_fully_replicated:
line_width = np.get_printoptions()["linewidth"]
if self.size == 0:
s = f"[], shape={self.shape}"
@@ -409,11 +409,19 @@ def __repr__(self):
separator=', ', max_line_width=line_width)
last_line_len = len(s) - s.rfind('\n') + 1
sep = ' '
- if last_line_len + len(dtype_str) + 1 > line_width:
+ if last_line_len + len(dtype_str) + 2 > line_width:
sep = ' ' * len(prefix)
- return f"{prefix}{s},{sep}{dtype_str}"
+ return f"{prefix}{s},{sep}{dtype_str})"
+ else:
+ return f"{prefix}shape={self.shape}, {dtype_str})"
+
+ def __str__(self):
+ if isinstance(self.sharding, NamedSharding) and self.sharding.spec.unreduced:
+ return repr(self)
+ elif self.is_fully_addressable or self.is_fully_replicated:
+ return str(self._value) # doesn't print Array(...)
else:
- return f"{prefix}shape={self.shape}, {dtype_str}"
+ return repr(self)
@property
def is_fully_addressable(self) -> bool:
@@ -903,14 +911,14 @@ def make_array_from_process_local_data(
>>> assert output_global_array.addressable_data(0).shape == per_device_shape
>>> assert output_global_array.shape == global_shape
- NB: While most shardings are uniform, It is possible to design am exotic
+  NB: While most shardings are uniform, it is possible to design an exotic
sharding mesh where each process's devices will be arranged in a non-grid
like pattern in some dimensions, or for indices to overlap non-trivially.
Such sharding is called "non-uniform" in those dimensions. In that case,
the global shape along those directions must match local shape as there is
no meaningful way to represent all needed
per-process data in non-overlapping fashion. For example for global_shape 4x4
- if sharding looks like this:
+ if sharding looks like this::
0123
2103
@@ -918,7 +926,7 @@ def make_array_from_process_local_data(
4567
with 4 processes, containing devices (0,1), (2, 3), (4, 5), (6, 7) respectively.
- Then the data for each host look like
+ Then the data for each host look like::
xx.. ..xx .... ....
.xx. x..x .... ....
@@ -932,7 +940,7 @@ def make_array_from_process_local_data(
In this case user must provide global_shape explicitly and for
local_shape=(2, 4), potentially valid global shapes are (2, 4) and (4, 4).
- On the other hand for sharding:
+ On the other hand for sharding::
0213 x.x. .x.x. .... ....
0213 x.x. .x.x. .... ....
@@ -1277,9 +1285,15 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
def _array_global_result_handler(global_aval, out_sharding, committed):
- global_aval = core.update_aval_with_sharding(global_aval, out_sharding)
if global_aval.dtype == dtypes.float0:
- return lambda _: np.zeros(global_aval.shape, dtypes.float0)
+ def handler(xs):
+ return np.zeros(global_aval.shape, dtypes.float0)
+ if jaxlib_extension_version >= 390:
+ phys_aval = core.physical_aval(global_aval)
+ return xc.array_result_handler(phys_aval, out_sharding, committed=committed,
+ _skip_checks=True).wrap(handler)
+ else:
+ return handler
if dtypes.issubdtype(global_aval.dtype, dtypes.extended):
return global_aval.dtype._rules.global_sharded_result_handler(
global_aval, out_sharding, committed)
@@ -1291,7 +1305,14 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
# Only used for Arrays that come out of pmap.
def _array_local_result_handler(aval, sharding, indices):
if aval.dtype == dtypes.float0:
- return lambda _: np.zeros(aval.shape, dtypes.float0)
+ def handler(xs):
+ return np.zeros(aval.shape, dtypes.float0)
+ if jaxlib_extension_version >= 390:
+ phys_aval = core.physical_aval(aval)
+ return xc.array_result_handler(phys_aval, sharding, committed=True,
+ _skip_checks=True).wrap(handler)
+ else:
+ return handler
if dtypes.issubdtype(aval.dtype, dtypes.extended):
return aval.dtype._rules.local_sharded_result_handler(
aval, sharding, indices)
@@ -1320,9 +1341,13 @@ def _token_shard_arg(xs, shardings, layouts, copy_semantics):
def _token_global_result_handler(global_aval, out_sharding, committed):
array_handler = _array_global_result_handler(
core.get_token_aval(), out_sharding, committed)
-
- def wrapper(*args, **kwargs):
- out_buf = array_handler(*args, **kwargs)
- return core.Token(out_buf)
- return wrapper
+ if jaxlib_extension_version >= 390:
+ def wrapper(array):
+ return core.Token(array)
+ return array_handler.wrap(wrapper) # type: ignore
+ else:
+ def old_wrapper(*args, **kwargs):
+ out_buf = array_handler(*args, **kwargs)
+ return core.Token(out_buf)
+ return old_wrapper
pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
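# Editor's note (not part of the patch): a minimal sketch of the str/repr split
# that the __str__ hunk above preserves -- fully addressable or fully replicated
# arrays print only their values via str(), while repr() keeps the Array(...)
# wrapper. A plain single-host array is enough to see the difference.
import jax.numpy as jnp

x = jnp.arange(4)
print(str(x))   # "[0 1 2 3]" -- no Array(...) wrapper
print(repr(x))  # "Array([0, 1, 2, 3], dtype=int32)"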
diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py
index 493e49a2e086..296f65b5ed3f 100644
--- a/jax/_src/cache_key.py
+++ b/jax/_src/cache_key.py
@@ -334,7 +334,6 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
def _hash_platform(hash_obj, backend):
_hash_string(hash_obj, backend.platform)
_hash_string(hash_obj, backend.platform_version)
- _hash_string(hash_obj, backend.runtime_type)
def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]):
diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py
index 1fc01e00cbe6..62fa882dd006 100644
--- a/jax/_src/checkify.py
+++ b/jax/_src/checkify.py
@@ -47,7 +47,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec as P
from jax._src.tree_util import tree_flatten
-from jax._src.tree_util import tree_map
+from jax._src.tree_util import tree_map, FlatTree
from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
@@ -753,21 +753,28 @@ def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
# HOP error check rules
+@jtu.register_static
+class ErrorEffects:
+ def __init__(self, val):
+ self.val = val
+
@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
*flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
- checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
- jaxpr.consts, enabled_errors,
- err_tree)
- fun = lu.wrap_init(checkify_jaxpr_partial,
- debug_info=jaxpr.jaxpr.debug_info.with_unknown_names())
- fun, metadata = _flatten_and_get_error_metadata_thunk(fun)
-
- new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
- checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
- out_tree, error_effects = metadata()
- return checked_jaxpr, out_tree, error_effects
+
+ def fun_wrapped(*invals):
+ error, out = checkify_jaxpr_flat(
+ jaxpr.jaxpr, jaxpr.consts, enabled_errors, err_tree, *invals)
+ error_effects = ErrorEffects(set(error._pred.keys()))
+ return (error, out), error_effects
+
+ debug_info = jaxpr.jaxpr.debug_info.with_unknown_names()
+ args_avals = FlatTree.flatten((flat_err_and_in_vals, {}))
+ checked_jaxpr, full_out_avals = pe.trace_to_jaxpr(fun_wrapped, args_avals, debug_info)
+ out_avals, error_effects = full_out_avals.unpack()
+ error_effects = error_effects.unflatten().val
+ return checked_jaxpr, out_avals.tree, error_effects
def cond_error_check(error: Error, enabled_errors, index, *ops,
branches, **params):
@@ -848,18 +855,17 @@ def new_body_f(*c_consts_and_vals):
# This checks if the next cond application will error
lax.dce_sink(cond_f(*c_consts, *out))
return out
- new_body_f_ = lu.wrap_init(
+ c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
+
+ jaxpr, _ = pe.trace_to_jaxpr(
new_body_f,
+ FlatTree.flatten(((*c_consts_avals, *body_jaxpr.in_avals), {})),
debug_info=body_jaxpr.jaxpr.debug_info.with_unknown_names())
- c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
- jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
- new_body_f_, [*c_consts_avals, *body_jaxpr.in_avals])
- closed_jaxpr = pe.close_jaxpr(jaxpr)
err_vals, err_tree = jtu.tree_flatten(error)
err_vals = map(core.get_aval, err_vals)
flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
- closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
+ jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
return jaxpr, out_tree, error_effects
@@ -1004,12 +1010,10 @@ def expand_errors_leading_dim(*xs):
return *errs, *outs
with core.extend_axis_env_nd(mesh.shape.items()), config._check_vma(check_vma):
- jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
- lu.wrap_init(expand_errors_leading_dim,
- debug_info=checked_jaxpr.jaxpr.debug_info),
- checked_jaxpr.in_avals
- )
- checked_jaxpr = core.ClosedJaxpr(jaxpr, consts)
+ checked_jaxpr, _ = pe.trace_to_jaxpr(
+ expand_errors_leading_dim,
+ FlatTree.flatten((tuple(checked_jaxpr.in_avals), {})),
+ debug_info=checked_jaxpr.jaxpr.debug_info)
# Update shard_map params to account for extra error values.
# Use fully sharded partitioning for out errors.
@@ -1235,17 +1239,15 @@ def checkify(f: Callable[..., Out],
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
# close over all arguments so they're not turned into abstract values.
- in_tree = jtu.tree_structure(((), {}))
+ in_avals = FlatTree.flatten(((), {}))
closed_f = lambda: f(*args, **kwargs)
# stage:
- debug = api_util.debug_info("checkify", f, args, kwargs)
- fun_, out_tree = api_util.flatten_fun(
- lu.wrap_init(closed_f, debug_info=debug.with_unknown_names()), in_tree)
- jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, ())
- jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
+ debug_info = api_util.debug_info("checkify", f, args, kwargs).with_unknown_names()
+ jaxpr_, out_avals = pe.trace_to_jaxpr(closed_f, in_avals, debug_info)
+ jaxpr, consts = pe.separate_consts(jaxpr_)
# checkify:
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
- return error, jtu.tree_unflatten(out_tree(), out_flat)
+ return error, out_avals.update_from_list(out_flat).unflatten()
return checked_fun
def check(pred: Bool, msg: str,
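# Editor's note (not part of the patch): the checkify() refactor above only
# changes how the function is staged; the public behavior sketched here is
# unchanged. Uses the jax.experimental.checkify API.
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  checkify.check(jnp.all(x > 0), "x must be positive")
  return jnp.log(x)

checked_f = checkify.checkify(f, errors=checkify.user_checks)
err, out = jax.jit(checked_f)(jnp.array([1.0, 2.0]))
err.throw()  # no-op here; raises if the check had failed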
diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py
index 59d11bc95b06..bb416fd1c633 100644
--- a/jax/_src/compilation_cache.py
+++ b/jax/_src/compilation_cache.py
@@ -261,7 +261,10 @@ def put_executable_and_time(
" since cache is disabled/not initialized", cache_key)
return
- serialized_executable = backend.serialize_executable(executable)
+ if hasattr(executable, "serialize") or xla_client._version >= 389:
+ serialized_executable = executable.serialize()
+ else:
+ serialized_executable = backend.serialize_executable(executable)
executable_and_time = combine_executable_and_time(
serialized_executable, compile_time)
executable_and_time = compress_executable(executable_and_time)
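# Editor's note (not part of the patch): the hunk above prefers the newer
# executable.serialize() method and keeps the backend-level call as a fallback
# for older jaxlibs. A generic sketch of that feature-detection pattern, with a
# hypothetical helper name:
def serialize_compiled(executable, backend):
  # Prefer the method on the executable itself when available.
  if hasattr(executable, "serialize"):
    return executable.serialize()
  # Otherwise fall back to the older backend entry point.
  return backend.serialize_executable(executable)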
diff --git a/jax/_src/config.py b/jax/_src/config.py
index 212902268d78..70826d6bf82a 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -23,6 +23,7 @@
import os
import sys
from typing import Any, Generic, NoReturn, Optional, Protocol, Type, TypeVar, cast
+import warnings
from jax._src import deprecations
from jax._src.lib import _jax
@@ -1157,6 +1158,15 @@ def _validate_jax_pjrt_client_create_options(new_val):
'to disable any debuggers while leak checking is enabled.'))
checking_leaks = functools.partial(check_tracer_leaks, True)
+check_static_indices = bool_state(
+ name='jax_check_static_indices',
+ default=False,
+ help=('Turn on bounds checks for static indices during array indexing operations.'
+ ' These will only be checked when indexing mode is PROMISE_IN_BOUNDS, which'
+ ' is the default for gather-type operations.'),
+ include_in_jit_key=True,
+ include_in_trace_context=True,
+)
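# Editor's note (not part of the patch): how the new option would typically be
# toggled from user code; the flag name comes from the definition above, the
# rest is standard jax.config usage. The exact failure mode for an
# out-of-bounds static index is an assumption here.
import jax
import jax.numpy as jnp

jax.config.update("jax_check_static_indices", True)
x = jnp.arange(4)
# With the flag on, a static index that is obviously out of bounds (e.g. x[10])
# is expected to be reported at trace time instead of silently clamped.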
captured_constants_warn_bytes = int_state(
name='jax_captured_constants_warn_bytes',
@@ -1240,27 +1250,6 @@ def _validate_jax_pjrt_client_create_options(new_val):
include_in_trace_context=True,
)
-def _safer_randint_deprecation(new_val):
- if not new_val:
- deprecations.warn(
- 'safer-randint-config',
- (
- 'The jax_safer_randint configuration is deprecated in JAX v0.7.2'
- ' and will be removed in JAX v0.9.0.'
- ),
- stacklevel=4
- )
-
-# TODO(jakevdp): remove this flag.
-safer_randint = bool_state(
- name='jax_safer_randint',
- default=True,
- help='Use a safer randint algorithm for 8-bit and 16-bit dtypes.',
- include_in_jit_key=True,
- upgrade=True,
- validator=_safer_randint_deprecation
-)
-
class LegacyPrngKeyState(enum.StrEnum):
ALLOW = 'allow'
WARN = 'warn'
@@ -1539,25 +1528,23 @@ class LegacyPrngKeyState(enum.StrEnum):
'what they are trying to achieve should set it.'),
)
-def _default_dtype_bits_deprecation(new_val):
- if new_val != '64':
- deprecations.warn(
- 'default-dtype-bits-config',
- (
- 'The jax_default_dtype_bits configuration is deprecated in JAX v0.7.1'
- ' and will be removed in JAX v0.9.0.'
- ),
- stacklevel=4
- )
+def _default_dtype_bits_deprecation(val):
+ if val != '_default':
+ warnings.warn(
+ (
+ 'The jax_default_dtype_bits configuration is deprecated in JAX v0.7.1'
+ ' and has no effect as of JAX v0.9.0. It will be removed in JAX v0.10.0.'
+ ),
+ category=DeprecationWarning,
+ stacklevel=4)
default_dtype_bits = enum_state(
name='jax_default_dtype_bits',
- enum_values=['32', '64'],
- default='64',
- help=('[deprecated]. This flag was an experiment in allowing users to specify the'
- ' default bit width. It was never fully supported or tested. It will '
- ' have no effect after JAX v0.9.0, and be removed entirely in JAX v0.10.0.'),
+ enum_values=['_default', '32', '64'],
+ default='_default',
+ help=('[deprecated]. This has no effect starting with JAX v0.9.0, and'
+ ' will be removed in JAX v0.10.0.'),
extra_validator=_default_dtype_bits_deprecation)
@@ -1776,16 +1763,6 @@ def _validate_default_device(val):
default=False,
help=('Enables lowering BCOO ops to cuSparse.'))
-# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging
-# if the intended backend can handle lowering the result
-dynamic_shapes = bool_state(
- name='jax_dynamic_shapes',
- default=False,
- help=('Enables experimental features for staging out computations with '
- 'dynamic shapes.'),
- include_in_jit_key=True,
- include_in_trace_context=True)
-
# This is for stackless backward compat with e.g. equinox
eager_constant_folding = bool_state(
name='eager_constant_folding',
@@ -1820,13 +1797,6 @@ def _validate_default_device(val):
upgrade=False,
help='Temporary workaround to disable an error check in vmap-of-shmap.')
-# TODO(mattjj): remove once we land mutable array plumbing, or face great shame
-custom_vjp_disable_shape_check = bool_state(
- name='jax_custom_vjp_disable_shape_check',
- default=False,
- upgrade=True,
- help='Disable the check from #19009 to enable some custom_vjp hacks.')
-
mutable_array_checks = bool_state(
name='jax_mutable_array_checks',
default=True,
@@ -1834,18 +1804,19 @@ def _validate_default_device(val):
help='Enable error checks for mutable arrays that rule out aliasing.',
include_in_trace_context=True)
-vjp3 = bool_state(
- name='jax_vjp3',
- default=True,
- upgrade=True,
- help='Use new backward-pass code in jax.vjp')
-
refs_to_pins = bool_state(
name='jax_refs_to_pins',
default=False,
upgrade=True,
help='Lower refs to pinned buffers in HLO.')
+# TODO(mattjj, yashkatariya): remove once we land box plumbing
+disable_bwd_checks = bool_state(
+ name='jax_disable_bwd_checks',
+ default=False,
+ upgrade=True,
+ help='Disables all bwd pass checks')
+
xla_runtime_errors = bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 7950f281f33a..d86dc6f7ce08 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -731,7 +731,7 @@ def read(v: Atom) -> Any:
return v.val if isinstance(v, Literal) else env[v]
def write(v: Var, val: Any) -> None:
- if config.enable_checks.value and not config.dynamic_shapes.value:
+ if config.enable_checks.value:
assert typecheck(v.aval, val), (v.aval, get_aval(val), val)
env[v] = val
@@ -895,7 +895,7 @@ def _aval_property(name):
return property(lambda self: getattr(self.aval, name))
-if TYPE_CHECKING or jaxlib_extension_version < 388:
+if TYPE_CHECKING:
# We want Python type checkers to accept `some_tracer: jax.Array`, even though
# tracers can represent non-arrays. That is, ideally we would only accept that
# annotation when the Tracer instance has a ShapedArray aval, but we can't
@@ -1064,8 +1064,8 @@ def __getattr__(self, name):
if name == 'sharding':
raise AttributeError(
- f"The 'sharding' attribute is not available on {self._error_repr()}."
- f"{self._origin_msg()}")
+ f"The 'sharding' attribute is not available on {self._error_repr()}. "
+ "To query sharding information on tracers, use `jax.typeof(x)`.")
try:
attr = getattr(self.aval, name)
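# Editor's note (not part of the patch): the new error text above points users
# at jax.typeof, which works on tracers where .sharding does not.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print(jax.typeof(x))   # abstract type (shape/dtype/sharding) of the traced value
  # x.sharding           # would raise the AttributeError shown above
  return x * 2

f(jnp.ones(4))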
@@ -1654,6 +1654,9 @@ class AbstractValue:
def to_tangent_aval(self):
raise NotImplementedError("must override")
+ def to_cotangent_aval(self):
+ raise NotImplementedError("must override")
+
# TODO(dougalm): deprecate this alias
def at_least_vspace(self):
return self.to_tangent_aval()
@@ -1689,38 +1692,8 @@ def lo_ty_qdd(self, qdd):
def str_short(self, short_dtypes=False, mesh_axis_types=False):
return str(self)
-# For type signatures involving dynamic shapes, we use lists of abstract values
-# which may contain (reverse) de Bruijn indices in their shapes.
-class DBIdx(NamedTuple):
- val: int
-
-@dataclass(frozen=True)
-class InDBIdx:
- val: int
-
-@dataclass(frozen=True)
-class OutDBIdx:
- val: int
-
-# For annotating input types of callables (i.e. linear_util.WrappedFuns), we use
-# a sequence of pairs where the first element of each pair is an AbstractValue
-# (possibly containing DBIdx instances in its shape) and the second is a boolean
-# indicating whether that argument is explicit (i.e. passed to the callable).
-InputType = tuple[tuple[AbstractValue, bool], ...] # DBIdx in shapes
-
-# For annotating jaxpr output types, we use a sequence of pairs where the first
-# element of each pair is an AbstractValue (possibly containing InDBIdx and/or
-# OutDBIdx instances in its shape) and the second is a boolean indicating
-# whether that argument is explicit (i.e. returned by the callable).
-OutputType = tuple[tuple[AbstractValue, bool], ...] # InDBIdx / OutDBIdx shapes
-
-
-def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
- idxs = {v: DBIdx(i) for i, v in enumerate((*jaxpr.constvars, *jaxpr.invars))}
- out = [(v.aval.update(shape=tuple(idxs.get(d, d) for d in v.aval.shape)) # type: ignore
- if type(v.aval) is DShapedArray else v.aval, True)
- for v in jaxpr.invars]
- return tuple(out)
+InputType = tuple[AbstractValue]
+OutputType = tuple[AbstractValue]
# For use in typing annotations to denote either a Tracer or a `valid_jaxtype`.
Value = Any
@@ -1759,15 +1732,12 @@ def mem_space_to_kind(mem_space: MemorySpace) -> str:
@cache(max_size=4096,
trace_context_in_key=lambda: config.remove_size_one_mesh_axis_from_type.value)
def update_aval_with_sharding(aval, sharding, vma=None):
- if vma is None:
- vma = aval.vma
if isinstance(sharding, NamedSharding):
- return aval.update(
- sharding=NamedSharding(
- sharding.mesh.abstract_mesh,
- sharding.spec._normalized_spec_for_aval(aval.ndim)),
- vma=vma, memory_space=mem_kind_to_space(sharding.memory_kind))
- return aval.update(vma=vma)
+ s = NamedSharding(sharding.mesh.abstract_mesh,
+ sharding.spec._normalized_spec_for_aval(aval.ndim))
+ return aval.update(sharding=s, vma=aval.vma if vma is None else vma,
+ memory_space=mem_kind_to_space(sharding.memory_kind))
+ return aval if vma is None else aval.update(vma=vma)
# We have three flavors of abstractification APIs here which each used to have
@@ -1960,21 +1930,17 @@ def cur_aval_qdd(x):
@overload
def physical_aval(aval: ShapedArray) -> ShapedArray: ...
-@overload
-def physical_aval(aval: DShapedArray) -> DShapedArray: ...
@overload # TODO(frostig): remove this case
def physical_aval(aval: AbstractValue) -> AbstractValue: ...
def physical_aval(aval):
- if (isinstance(aval, (ShapedArray, DShapedArray)) and
+ if (isinstance(aval, ShapedArray) and
isinstance(aval.dtype, dtypes.ExtendedDType)):
elt_aval = physical_element_aval(aval.dtype)
- if isinstance(aval, ShapedArray):
- from jax._src.sharding_impls import physical_sharding # type: ignore
- return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype,
- sharding=physical_sharding(aval, aval.sharding),
- vma=aval.vma)
- return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype)
+ from jax._src.sharding_impls import physical_sharding # type: ignore
+ return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype,
+ sharding=physical_sharding(aval, aval.sharding),
+ vma=aval.vma)
return aval
def physical_shape(logical_shape, dtype):
@@ -1996,15 +1962,7 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize:
return operator.index(dim)
except TypeError as e:
type_error = e
- if isinstance(dim, Tracer) and config.dynamic_shapes.value:
- if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
- or isinstance(dim.dtype, bint))):
- raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
- return dim
- elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
- type(dim._aval.dtype) is bint and not dim._aval.shape):
- return dim
- elif is_dim(dim):
+ if is_dim(dim):
return dim
else:
raise type_error
@@ -2038,16 +1996,11 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
return canonicalize_shape((d,), context)[0]
def _invalid_shape_error(shape: Shape, context: str=""):
- if config.dynamic_shapes.value:
- msg = ("Shapes must be 1D sequences of integer scalars, "
- f"got {shape}")
- else:
- msg = ("Shapes must be 1D sequences of concrete values of integer type, "
- f"got {shape}.")
+ msg = ("Shapes must be 1D sequences of concrete values of integer type, "
+ f"got {shape}.")
if context:
msg += f" {context}."
- if not config.dynamic_shapes.value and any(
- isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
+ if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
and not is_concrete(x) for x in shape):
msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
"smaller subfunctions.")
@@ -2488,149 +2441,6 @@ def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]:
'workaround pass the check_vma=False argument to `jax.shard_map`')
return vma
-# Dynamic shape stuff below here! We keep the abstract values distinct just so
-# as not to interfere with any static shape machinery.
-
-# We have a convention of reusing AbsractValues as types, even though we could
-# make a distinction and use abstract values during tracing only. This reuse
-# becomes a bit more extreme with DShapedArrays. A DShapedArray's shape
-# attribute is a tuple which can contain several different types: int, DArray
-# (scalar and with dtype of bint type), Tracer (while tracing), Var (when used
-# as jaxpr type annotations), or DBIdx/InDBIdx/OutDBIdx (when used in InputType
-# or OutputType). We could reduce this polymorphism if it seems cleaner, though
-# it's kind of convenient!
-class DShapedArray(AbstractValue):
- __slots__ = ['shape', 'dtype', 'weak_type']
- shape: tuple[AxisSize, ...] # noqa: F821
- array_abstraction_level: int = 3
-
- def __init__(self, shape, dtype, weak_type=False):
- assert not any(isinstance(d, Literal) for d in shape)
- self.shape = shape
- self.dtype = dtype
- self.weak_type = weak_type
-
- ndim = property(lambda self: len(self.shape))
- size = property(lambda self:
- 0 if any(type(d) is int and d == 0 for d in self.shape)
- else math.prod(self.shape))
-
- def str_short(self, short_dtypes=False, mesh_axis_types=False) -> str:
- del short_dtypes # ignored
- shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
- dtype = dtypes.short_dtype_name(self.dtype)
- return f'{dtype}[{shape}]'
- __str__ = __repr__ = str_short
-
- def update(self, shape=None, dtype=None, weak_type=None):
- if shape is None:
- shape = self.shape
- if dtype is None:
- dtype = self.dtype
- if weak_type is None:
- weak_type = self.weak_type
- return DShapedArray(shape, dtype, weak_type)
-
- @property
- def sharding(self):
- return NamedSharding(mesh_lib.empty_abstract_mesh, P())
-
- @property
- def vma(self):
- return frozenset()
-
- def _len(self, tracer):
- return self.shape[0]
-
- def __eq__(self, other):
- return (type(self) is type(other)
- and self.dtype == other.dtype and self.shape == other.shape
- and self.weak_type == other.weak_type)
-
- def __hash__(self):
- # We don't hash the contents of the shape because it may contain tracers.
- return hash((len(self.shape), self.dtype, self.weak_type))
-
- def __ne__(self, other):
- return not self == other
-
- def to_tangent_aval(self):
- return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
- self.weak_type)
-
- def update_vma(self, vma):
- return self
-
- def update_weak_type(self, weak_type):
- return self.update(weak_type=weak_type)
-
- _bool = concretization_function_error(bool)
- _int = concretization_function_error(int, True)
- _float = concretization_function_error(float, True)
- _complex = concretization_function_error(complex, True)
- _hex = concretization_function_error(hex)
- _oct = concretization_function_error(oct)
- _index = concretization_function_error(operator.index)
-
-
-class DArray:
- _aval: DShapedArray
- _data: Any # standard array type
- def __init__(self, aval, data):
- pad_shape = tuple(d.dtype.bound if type(d) is DArray and
- type(d.dtype) is bint else d for d in aval.shape)
- assert data.shape == pad_shape
- self._aval = aval
- self._data = data
-
- shape = property(lambda self: self._aval.shape)
- dtype = property(lambda self: self._aval.dtype)
- aval = property(lambda self: self._aval)
- def __repr__(self) -> str:
- if not self.shape and type(self.dtype) is bint:
- # special-case scalar bints
- return f'{int(self._data)}{{≤{self.dtype.bound}}}'
-
- dtypestr = dtypes.short_dtype_name(self._aval.dtype)
- shapestr = ','.join(map(str, self.shape))
- data = self.data
- return f'{dtypestr}[{shapestr}] with value: {data}'
-
- def __hash__(self) -> int:
- if not self.shape:
- return hash((self._aval, int(self._data)))
- raise TypeError("unhashable type: DArray")
-
- def __eq__(self, other):
- if isinstance(other, DArray) and self._aval == other._aval:
- return self._data == other._data
- return False
-
- def __len__(self):
- return self.shape[0]
-
- @property
- def data(self):
- if not self.shape and type(self.dtype) is bint:
- # special-case scalar bints
- return self._data
-
- slices = tuple(
- slice(int(d._data))
- if type(d) is DArray and type(d.dtype) is bint
- else slice(None)
- for d in self.shape
- )
- data = self._data[slices]
- return data
-
-def _darray_aval(x):
- return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
-
-pytype_aval_mappings[DArray] = _darray_aval
-dtypes.canonicalize_value_handlers[DArray] = lambda x: x
-
-
@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
bound: int
@@ -2646,7 +2456,7 @@ def name(self) -> str:
def __str__(self) -> str:
return self.name
-AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx]
+AxisSize = Union[int, Tracer, Var]
class RefMeta(type):
@@ -2812,6 +2622,7 @@ def accum_grad_in_ref(x):
class AbstractToken(AbstractValue):
def str_short(self, short_dtypes=False, mesh_axis_types=False): return 'Tok'
def to_tangent_aval(self): return self
+ def to_cotangent_aval(self): return self
abstract_token: AbstractToken = AbstractToken()
# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
@@ -3100,8 +2911,6 @@ def get_bind_params(self, params):
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.hashable_partial(
lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ())
- if config.dynamic_shapes.value:
- subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr))
return [subfun], new_params
def call_impl(f: lu.WrappedFun, *args, **params):
@@ -3197,25 +3006,8 @@ def _unmap_shaped_array(
else:
raise TypeError(axis)
-def _map_dshaped_array(
- size: AxisSize, axis: int | None, aval: DShapedArray) -> DShapedArray:
- if axis is None: return aval
- return DShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
- aval.weak_type)
-
-def _unmap_dshaped_array(
- size: AxisSize, axis: int | None, explicit_mesh_axis, aval: DShapedArray
- ) -> DShapedArray:
- if axis is None: return aval
- elif type(axis) is int:
- return DShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
- weak_type=aval.weak_type)
- else:
- raise TypeError(axis)
-
AvalMapHandlerPair = tuple[Callable, Callable]
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
- DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
AbstractToken: (lambda _, __, a: a, lambda _, __, ____, a: a)
}
@@ -3292,36 +3084,39 @@ def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
except TypeError:
return False
-def typematch(t1: AbstractValue, t2: AbstractValue) -> bool:
+def typematch(t1: AbstractValue, t2: AbstractValue,
+ only_shape_shd_check: bool = False) -> bool:
"""Determine whether `t1` and `t2` are equivalent. Ignores weak_type."""
t1 = t1.normalize()
t2 = t2.normalize()
from jax._src.state.types import AbstractRef # pytype: disable=import-error
if t1 == t2:
return True
- elif (isinstance(t1, (ShapedArray, DShapedArray)) and
- isinstance(t2, (ShapedArray, DShapedArray))):
- # This case handles DShapedArray and shape polynomials. Alternatively we
- # could try normalizing first and then doing simple equality.
- cmp = (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape)
- and t1.vma == t2.vma and t1.memory_space == t2.memory_space) # type: ignore
- # TODO(yashkatariya): Expand this to Manual and Auto mode.
- # See https://github.com/jax-ml/jax/issues/26474
- if (not t1.sharding.mesh.empty and not t2.sharding.mesh.empty and
- (t1.sharding.mesh._any_axis_explicit or
- t2.sharding.mesh._any_axis_explicit)):
- sh_eq = t1.sharding == t2.sharding
- else:
- sh_eq = True
- return cmp and sh_eq
+ elif isinstance(t1, ShapedArray) and isinstance(t2, ShapedArray):
+ if only_shape_shd_check:
+ return cmp_shape_sharding_vma(t1, t2)
+ return (t1.dtype == t2.dtype and cmp_shape_sharding_vma(t1, t2) and
+ t1.memory_space == t2.memory_space)
elif isinstance(t1, AbstractRef) and isinstance(t2, AbstractRef):
# We want to use the regular typecheck for ShapedArray here.
- return (typematch(t1.inner_aval, t2.inner_aval) and # type: ignore
+ return (typematch(t1.inner_aval, t2.inner_aval, only_shape_shd_check) and # type: ignore
(t1.memory_space is None or t2.memory_space is None or # type: ignore
t1.memory_space == t2.memory_space)) # type: ignore
else:
return False
+def cmp_shape_sharding_vma(t1, t2):
+ # TODO(yashkatariya): Expand this to Manual and Auto mode.
+ # See https://github.com/jax-ml/jax/issues/26474
+ if (not t1.sharding.mesh.empty and not t2.sharding.mesh.empty and
+ (t1.sharding.mesh._any_axis_explicit or
+ t2.sharding.mesh._any_axis_explicit)):
+ shd_eq = t1.sharding == t2.sharding
+ else:
+ shd_eq = True
+ return (shd_eq and definitely_equal_shape(t1.shape, t2.shape) and
+ t1.vma == t2.vma)
+
def aval_mismatch_extra(a1: AbstractValue, a2: AbstractValue) -> str:
assert not typematch(a1, a2)
if isinstance(a1, ShapedArray) and isinstance(a2, ShapedArray):
@@ -3521,7 +3316,6 @@ def write(v: Var, a: AvalQDD) -> None:
f"Jaxpr effects: {jaxpr.effects}")
# Check out_type matches the let-binders' annotation (after substitution).
- out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars)
out_type = [t if isinstance(t, AvalQDD) else AvalQDD(t, None) for t in out_type]
foreach(write, eqn.outvars, out_type)
@@ -3548,51 +3342,7 @@ def check_type(
env: dict[Var, Atom | MutableTypecheckVal],
ty: AbstractValue,
) -> None:
- if isinstance(ty, DShapedArray):
- # Check all elements in the shape tuple are well-typed.
- for d in ty.shape:
- if (isinstance(d, int) or
- isinstance(d, DArray) and not d.shape and type(d.dtype) == bint):
- continue
- elif isinstance(d, Var):
- if d not in env:
- ctx, _ = ctx_factory()
- raise JaxprTypeError(f"unbound axis size: '{pp_var(d, ctx)}'")
- if not isinstance(d.aval, (ShapedArray, DShapedArray)):
- raise JaxprTypeError(f"axis size with unexpected type annotation: "
- f"{d.aval} of type {type(d.aval)}")
- if isinstance(d.aval, ShapedArray):
- shape, dtype = d.aval.shape, d.aval.dtype
- if shape: raise JaxprTypeError(f"axis size nonscalar: {d.aval}")
- if not dtypes.issubdtype(dtype, np.integer):
- raise JaxprTypeError(f"axis size with non-integer dtype: {d.aval}")
- else:
- assert isinstance(d.aval, DShapedArray)
- shape, dtype = d.aval.shape, d.aval.dtype
- if shape: raise JaxprTypeError(f"axis size nonscalar: {d.aval}")
- if type(dtype) is not bint:
- raise JaxprTypeError(
- f"DArray axis size with non-bint dtype: {d.aval}")
- else:
- raise JaxprTypeError(f"unexpected type in shape: {type(d)}")
- else:
- return # Except in above case(s), all syntactic forms are valid
-
-def substitute_vars_in_output_ty(
- out_type: Sequence[AbstractValue], # shapes may contain InDBIdx / OutDBIdx
- in_atoms: Sequence[Atom],
- out_binders: Sequence[Var],
- ) -> list[AbstractValue]: # shapes may contain Vars
- in_atoms = [x.val if type(x) is Literal else x for x in in_atoms]
- result = []
- for aval in out_type:
- if type(aval) is DShapedArray:
- shape = [in_atoms[d.val] if type(d) is InDBIdx else
- out_binders[d.val] if type(d) is OutDBIdx else
- d for d in aval.shape]
- aval = aval.update(shape=tuple(shape))
- result.append(aval)
- return result
+ return  # All remaining syntactic forms are valid; nothing to check.
def check_eqn(prim, in_avals, params):
for jaxpr in jaxprs_in_params(params):
@@ -3619,29 +3369,19 @@ def _check_call(ctx_factory, prim, in_atoms, params):
# Check `call_jaxpr` can be applied to in_atoms.
env: dict[Var, Atom | MutableTypecheckVal] = {}
- def substitute(aval: AbstractValue):
- if isinstance(aval, DShapedArray):
- aval = aval.update(shape=tuple(env.get(d, d) for d in aval.shape)) # type: ignore
- return aval
for v, x in zip(call_jaxpr.invars, in_atoms):
- if not typecompat(substitute(v.aval), x.aval):
+ if not typecompat(v.aval, x.aval):
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
- f"{substitute(v.aval)}")
+ f"{v.aval}")
env[v] = x.val if type(x) is Literal else x
check_jaxpr(call_jaxpr)
invars, outvars = call_jaxpr.invars, call_jaxpr.outvars
- in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
- out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
- if type(x) is Var}
out_avals = [x.aval for x in call_jaxpr.outvars]
- out_type = [a.update(shape=tuple(in_map.get(d, out_map.get(d))
- if type(d) is Var else d for d in a.shape))
- if type(a) is DShapedArray else a for a in out_avals]
-
+ out_type = out_avals
# jaxpr input effects are indexed to include jaxpr.constvars, but the eqn
# should have effects indexed only on its explicit arguments
effs = {e.replace(input_index=e.input_index - len(call_jaxpr.constvars))
@@ -3698,6 +3438,14 @@ def eqn_effects(jaxpr):
# ------------------- ShapeDtypeStruct -------------------
+def _check_sharding(sharding, shape):
+ if sharding is None:
+ return
+ if isinstance(sharding, P):
+ sharding._check_compatible_wrt_shape(shape)
+ else:
+ sharding.check_compatible_aval(shape)
+
@set_module("jax")
class ShapeDtypeStruct:
"""A container for the shape, dtype, and other static attributes of an array.
@@ -3731,6 +3479,7 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False,
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
self._sharding = (sharding.sharding if isinstance(sharding, Format)
else sharding)
+ _check_sharding(self._sharding, self.shape)
self._dll = sharding.layout if isinstance(sharding, Format) else None
self.weak_type = weak_type
if vma is not None and not isinstance(vma, (set, frozenset)):
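# Editor's note (not part of the patch): the _check_sharding call added above
# validates the sharding against the declared shape at construction time. A
# sketch with the public API; the exact error raised for a mismatched spec is
# an assumption.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((1,), ("x",))
sds = jax.ShapeDtypeStruct((8, 4), jnp.float32,
                           sharding=NamedSharding(mesh, P("x", None)))
# A spec that names more dimensions than the shape has, e.g. P("x", None, None)
# for this 2-D shape, is now rejected here rather than later during tracing.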
@@ -3955,12 +3704,7 @@ def pp_var(v: Var | Literal, context: JaxprPpContext, *,
return v.pretty_print(context, print_dtype=print_literal_dtype)
def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str:
- if isinstance(a, DShapedArray):
- shape = [pp_var(d, context) if type(d) is Var else str(d) for d in a.shape]
- dtype = dtypes.short_dtype_name(a.dtype)
- return f'{dtype}[{",".join(shape)}]'
- else:
- return a.str_short(short_dtypes=True)
+ return a.str_short(short_dtypes=True)
def pp_vars(vs: Sequence[Atom], context: JaxprPpContext,
*, separator="", print_shapes: bool = False) -> pp.Doc:
diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py
index 9583b0fce0f9..43599c264cbc 100644
--- a/jax/_src/cudnn/fused_attention_stablehlo.py
+++ b/jax/_src/cudnn/fused_attention_stablehlo.py
@@ -379,8 +379,8 @@ def check_is_flash_attention(
else:
# bf16/fp16 attention conditions
# Check the head dim.
- is_on_hopper = is_cuda_compute_capability_equal("9.0")
- H_max = 256 if is_on_hopper else 128
+ is_hopper_or_later = check_compute_capability("9.0")
+ H_max = 256 if is_hopper_or_later else 128
# check if multi-head latent attention is needed
is_mla = qH != vH
if not (qH <= H_max and qH % 8 == 0):
diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py
index 48c5d27c2678..b6b80748fb24 100644
--- a/jax/_src/cudnn/scaled_matmul_stablehlo.py
+++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py
@@ -66,7 +66,7 @@ def _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type):
)
-def _scaled_matmul_cuda_lowering(
+def _scaled_matmul_gpu_lowering(
ctx, a, b, a_scales, b_scales, preferred_element_type
):
lhs_type = ir.RankedTensorType(a.type)
@@ -119,9 +119,14 @@ def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
mlir.register_lowering(
_scaled_matmul_p,
- _scaled_matmul_cuda_lowering,
+ _scaled_matmul_gpu_lowering,
platform="cuda",
)
+mlir.register_lowering(
+ _scaled_matmul_p,
+ _scaled_matmul_gpu_lowering,
+ platform="rocm",
+)
_scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper")
_scaled_matmul_p_wrapper.multiple_results = True
diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py
index 8c66d3da942c..e394e8273260 100644
--- a/jax/_src/custom_batching.py
+++ b/jax/_src/custom_batching.py
@@ -260,7 +260,8 @@ def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree, out_tree):
def custom_vmap_abstract_eval(*in_avals, call, **_):
- return call.out_avals
+ del in_avals
+ return call.out_avals, call.effects
def custom_vmap_jvp(primals, tangents, *,
@@ -347,7 +348,7 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
custom_vmap_p = core.Primitive('custom_vmap_call')
custom_vmap_p.multiple_results = True
custom_vmap_p.def_impl(custom_vmap_impl)
-custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
+custom_vmap_p.def_effectful_abstract_eval(custom_vmap_abstract_eval)
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
pxla.register_initial_style_primitive(custom_vmap_p)
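# Editor's note (not part of the patch): minimal public custom_vmap usage; the
# switch to def_effectful_abstract_eval above is what lets such calls carry
# effects through abstract evaluation.
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def f(x):
  return jnp.sin(x)

@f.def_vmap
def f_vmap_rule(axis_size, in_batched, x):
  (x_batched,) = in_batched
  return jnp.sin(x), x_batched

print(jax.vmap(f)(jnp.arange(3.0)))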
diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py
index ed9b051d5777..94528250d67d 100644
--- a/jax/_src/custom_derivatives.py
+++ b/jax/_src/custom_derivatives.py
@@ -449,14 +449,6 @@ def _custom_jvp_vjp_call_lowering(ctx: mlir.LoweringRuleContext, *args,
return out
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_vjp_call_lowering)
-# If a (multi)linear function is defined with a custom jvp, then
-# custom_jvp_call_ can appear in jaxprs to be transposed. Since it's already
-# been linearized, we can drop the jvp rule.
-def _custom_jvp_call_transpose(params, jaxpr, args, ct, _):
- del params
- return ad.backward_pass(jaxpr.jaxpr, False, jaxpr.consts, args, ct)
-ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose
-
def _custom_jvp_call_transpose_fancy(params, jaxpr, args, ct, _):
del params
return ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, ct)
@@ -953,7 +945,7 @@ def append(x, d):
if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0:
results.append(Zero(a.to_tangent_aval()))
elif type(ct) is SymbolicZero:
- if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval):
+ if not core.typecompat(a.to_cotangent_aval(), a_ := ct.aval):
msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype "
"that does not match the corresponding input tangent shape/dtype: "
f"at output{keystr(kp)} the SymbolicZero had shape/dtype "
@@ -964,15 +956,15 @@ def append(x, d):
raise ValueError(msg)
results.append(Zero(ct.aval))
else:
- if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct)) and
- not _ref_typecompat(a.to_tangent_aval(), a_) and
- not (_temporary_dtype_exception(a, a_) or
- _temporary_shape_exception(a, a_))):
+ if (not config.disable_bwd_checks.value and
+ not core.typecompat(a.to_cotangent_aval(), a_ := core.get_aval(ct))
+ and not _ref_typecompat(a.to_cotangent_aval(), a_)
+ and not _temporary_dtype_exception(a.to_cotangent_aval(), a_)):
msg = ("Custom VJP bwd rule must produce an output with the same "
- "shape/dtypes as the args tuple of the primal function, but at "
+ "type as the args tuple of the primal function, but at "
f"output{keystr(kp)} the bwd rule produced an output of "
- f"shape/dtype {a_.str_short()} corresponding "
- f"to an input of shape/dtype {a.str_short()}"
+ f"type {a_.str_short()} corresponding "
+ f"to an input of type {a.str_short()}"
f"{core.aval_mismatch_extra(a, a_)}")
raise ValueError(msg)
results.append(ct)
@@ -980,19 +972,17 @@ def append(x, d):
def _ref_typecompat(a, a_):
return (isinstance(a, AbstractRef) and
- core.typecompat(a.to_tangent_aval().inner_aval, a_))
+ core.typecompat(a.to_cotangent_aval().inner_aval, a_))
# TODO(mattjj): remove both these exceptions to cotangent compatibility check
def _temporary_dtype_exception(a, a_) -> bool:
if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray):
return (a.shape == a_.shape and
+ core.typematch(a, a_, only_shape_shd_check=True) and
(dtypes.issubdtype(a_.dtype, dtypes.extended) or
dtypes.issubdtype(a.dtype, dtypes.np.inexact)))
return False
-# TODO(mattjj): remove both these exceptions to cotangent compatibility check
-def _temporary_shape_exception(a, a_) -> bool:
- return config.custom_vjp_disable_shape_check.value
class CustomVJPCallPrimitive(core.Primitive):
multiple_results = True
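# Editor's note (not part of the patch): the tightened check above requires a
# custom_vjp bwd rule to return one cotangent per primal input, with matching
# types. The standard pattern:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * y * g, sin_x * g)   # same types as the primal inputs (x, y)

f.defvjp(f_fwd, f_bwd)
print(jax.grad(f)(0.5, 2.0))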
diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py
index 3ae3e19cb8c1..39894348e557 100644
--- a/jax/_src/deprecations.py
+++ b/jax/_src/deprecations.py
@@ -125,12 +125,10 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
# always registered by the time `accelerate` and `is_acelerated` are called.
register('default-dtype-bits-config')
register('jax-checkpoint-concrete')
-register('jax-lax-dot-positional-args')
-register('jax-lib-module')
register('jax-nn-one-hot-float-input')
-register("jax-numpy-astype-complex-to-real")
+register('jax-numpy-arange-complex')
+register('jax-numpy-astype-complex-to-real')
register('jax-numpy-clip-args')
register('jax-scipy-special-sph-harm')
-register('safer-randint-config')
register('jax-pmap-no-rank-reduction')
-register('jax-make-mesh-default-explicit')
+register('pltpu-memory-space-any')
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index cdd9d3462027..d3fca5107b01 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -19,7 +19,6 @@
from collections.abc import Sequence
import dataclasses
from functools import partial
-import itertools
import logging
import threading
import time
@@ -286,24 +285,6 @@ def get_intermediate_shardings(
return out
-def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
- return (any(type(v.aval.dtype) is core.bint for v in jaxpr.invars
- if isinstance(v.aval, (core.ShapedArray, core.DShapedArray))) or
- any(_is_bint_axis_size(d)
- for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
- for e in j.eqns for v in e.outvars
- if isinstance(v.aval, core.DShapedArray) for d in v.aval.shape))
-
-def _is_bint_axis_size(d: core.AxisSize) -> bool:
- if isinstance(d, core.DArray):
- assert not d.shape
- return type(d.dtype) is core.bint
- elif isinstance(d, core.Var):
- return (isinstance(d.aval, core.DShapedArray) and
- type(d.aval.dtype) is core.bint)
- return False
-
-
def check_arg(arg: Any):
if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)):
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
@@ -318,6 +299,15 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
for buf in bufs:
_check_special(name, buf.dtype, buf)
+
+def check_special_array(name: str, arr: array.ArrayImpl) -> array.ArrayImpl:
+ if needs_check_special():
+ if dtypes.issubdtype(arr.dtype, np.inexact):
+ for buf in arr._arrays:
+ _check_special(name, buf.dtype, buf)
+ return arr
+
+
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
if dtypes.issubdtype(dtype, np.inexact):
if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
@@ -406,6 +396,25 @@ def result_handler(self, shard_arg_result):
return pxla.global_aval_to_result_handler(
self.aval, self.s, self.committed)(shard_arg_result)
+@dataclasses.dataclass(frozen=True)
+class _DeferredCrossHostTransferArg:
+ """Deferred call to `xc.batched_copy_array_to_devices_with_sharding` for
+ cross-host data transfers.
+
+ Per-array impls return this object instead of a result array to indicate a
+ deferred `batched_copy_array_to_devices_with_sharding` call for a cross-host
+ data transfer. `_batched_device_put_impl` then batches all
+ `_DeferredCrossHostTransferArg` objects into a single
+ `_batched_device_put_impl` call.
+
+ For any _DeferredCrossHostTransferArg, _is_supported_cross_host_transfer(
+ x.ndim, x.sharding, dst_sharding) == True.
+ """
+
+ x: array.ArrayImpl
+ dst_sharding: Sharding
+ copy_semantics: ArrayCopySemantics
+
def _device_put_sharding_impl(
x: Any,
@@ -444,9 +453,7 @@ def _device_put_sharding_impl(
if (x_is_jax_array and x._committed and xla_bridge.process_count() > 1
and _is_supported_cross_host_transfer(x.ndim, x_sharding, s)):
- return xc.batched_copy_array_to_devices_with_sharding(
- [x], [s._internal_device_list], [s], # pytype: disable=attribute-error
- [copy])[0]
+ return _DeferredCrossHostTransferArg(x, s, copy)
if not s_is_fully_addressable:
# If both the source and target shardings are not fully addressable and
@@ -562,7 +569,14 @@ def _batched_device_put_impl(
copy_semantics: Sequence[ArrayCopySemantics],
dst_avals: Sequence[core.ShapedArray | None]):
ys = []
+
+ # Used to batch transfers when _device_put_impl returns a _DeferredShardArg.
dsa_indices, dsa_xs, dsa_shardings, dsa_copy_semantics = [], [], [], []
+ # Used to batch transfers when _device_put_impl returns a
+ # _DeferredCrossHostTransferArg.
+ dca_indices, dca_xs, dca_shardings, dca_device_lists, dca_copy_semantics = \
+ [], [], [], [], []
+
for i, (x, device, src, cp, aval) in enumerate(
zip(xs, devices, srcs, copy_semantics, dst_avals)):
y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
@@ -571,11 +585,17 @@ def _batched_device_put_impl(
dsa_xs.append(y.x)
dsa_shardings.append(y.s)
dsa_copy_semantics.append(y.copy_semantics)
+ elif isinstance(y, _DeferredCrossHostTransferArg):
+ dca_indices.append(i)
+ dca_xs.append(y.x)
+ dca_shardings.append(y.dst_sharding)
+ dca_device_lists.append(y.dst_sharding._internal_device_list) # pytype: disable=attribute-error
+ dca_copy_semantics.append(y.copy_semantics)
ys.append(y)
+ # Batch shard_arg / batched_copy_array_to_devices_with_sharding calls. Helps
+ # improve efficiency for backends that support efficient batch transfer.
if dsa_xs:
- # Batch shard_arg calls. Helps improve efficiency for backends that support
- # efficient batch transfer.
# device_put handles `Format` via a different path, so just pass `None` as
# the layout here.
shard_arg_results = pxla.shard_args(dsa_shardings, [None] * len(dsa_xs),
@@ -583,6 +603,13 @@ def _batched_device_put_impl(
for i, shard_arg_result in zip(dsa_indices, shard_arg_results):
assert isinstance(ys[i], _DeferredShardArg)
ys[i] = ys[i].result_handler(shard_arg_result)
+ if dca_xs:
+ copy_array_results = xc.batched_copy_array_to_devices_with_sharding(
+ dca_xs, dca_device_lists, dca_shardings, dca_copy_semantics)
+ for i, copy_array_result in zip(dca_indices, copy_array_results):
+ assert isinstance(ys[i], _DeferredCrossHostTransferArg)
+ ys[i] = copy_array_result
+
return ys
def batched_device_put_impl(
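# Editor's note (not part of the patch): the deferral/batching machinery above
# is what services a single jax.device_put call over several arrays, letting
# per-array work be grouped into one batched transfer.
import jax
import jax.numpy as jnp

dev = jax.devices()[0]
xs = [jnp.ones(4), jnp.zeros((2, 3))]
ys = jax.device_put(xs, dev)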
diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py
index 715d831e4e9c..b6f1ab7443d6 100644
--- a/jax/_src/distributed.py
+++ b/jax/_src/distributed.py
@@ -44,6 +44,15 @@
),
)
+_ENABLE_PREEMPTION_SERVICE = config.bool_state(
+ name='jax_enable_preemption_service',
+ default=True,
+ help=(
+ "Enables the preemption service. See"
+ " multihost_utils.reached_preemption_sync_point for details."
+ ),
+)
+
class State:
process_id: int = 0
num_processes: int = 1
@@ -188,6 +197,12 @@ def shutdown(self):
self.service = None
def initialize_preemption_sync_manager(self):
+ if not _ENABLE_PREEMPTION_SERVICE.value:
+ logger.info(
+ 'The JAX preemption service is disabled. You can enable it using the'
+ ' jax_enable_preemption_service configuration option.'
+ )
+ return
if self.preemption_sync_manager is not None:
raise RuntimeError(
'Preemption sync manager should only be initialized once.')
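# Editor's note (not part of the patch): disabling the preemption service with
# the new option; the flag name comes from the definition above.
import jax

jax.config.update("jax_enable_preemption_service", False)
# jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)
# With the flag off, initialize_preemption_sync_manager() above becomes a no-op.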
diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py
index fbf4c2fd86e1..03fbfe91fd1b 100644
--- a/jax/_src/dtypes.py
+++ b/jax/_src/dtypes.py
@@ -184,20 +184,10 @@ def supports_inf(dtype: DTypeLike) -> bool:
# Default types.
bool_ = np.bool_
-int_: type[Any]
-uint: type[Any]
-float_: type[Any]
-complex_: type[Any]
-if config.default_dtype_bits.value == '32':
- int_ = np.int32
- uint = np.uint32
- float_ = np.float32
- complex_ = np.complex64
-else:
- int_ = np.int64
- uint = np.uint64
- float_ = np.float64
- complex_ = np.complex128
+int_: type[Any] = np.int64
+uint: type[Any] = np.uint64
+float_: type[Any] = np.float64
+complex_: type[Any] = np.complex128
# Default dtypes. These are intended to have the same semantics as, say,
@@ -206,33 +196,23 @@ def supports_inf(dtype: DTypeLike) -> bool:
def default_int_dtype() -> DType:
- return (
- np.dtype(np.int64)
- if config.enable_x64.value and config.default_dtype_bits.value == '64'
- else np.dtype(np.int32)
- )
+ return np.dtype(np.int64) if config.enable_x64.value else np.dtype(np.int32)
def default_uint_dtype() -> DType:
- return (
- np.dtype(np.uint64)
- if config.enable_x64.value and config.default_dtype_bits.value == '64'
- else np.dtype(np.uint32)
- )
+ return np.dtype(np.uint64) if config.enable_x64.value else np.dtype(np.uint32)
def default_float_dtype() -> DType:
return (
- np.dtype(np.float64)
- if config.enable_x64.value and config.default_dtype_bits.value == '64'
- else np.dtype(np.float32)
+ np.dtype(np.float64) if config.enable_x64.value else np.dtype(np.float32)
)
def default_complex_dtype() -> DType:
return (
np.dtype(np.complex128)
- if config.enable_x64.value and config.default_dtype_bits.value == '64'
+ if config.enable_x64.value
else np.dtype(np.complex64)
)
@@ -995,7 +975,11 @@ def dtype(x: Any) -> DType:
dt = x.dtype
else:
try:
- dt = np.result_type(x)
+ with warnings.catch_warnings():
+ # Ignore warning associated with __numpy_dtype__ change in NumPy 2.4.
+ # TODO(jakevdp): remove this warning context after change is finalized.
+ warnings.simplefilter("ignore", DeprecationWarning)
+ dt = np.result_type(x)
except TypeError as err:
raise TypeError(f"Cannot determine dtype of {x}") from err
if dt not in _jax_dtype_set and not issubdtype(dt, extended):
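# Editor's note (not part of the patch): with jax_default_dtype_bits out of the
# picture, the simplified default_*_dtype helpers above depend only on the x64
# flag.
import jax
import jax.numpy as jnp

print(jnp.arange(3).dtype)   # int32 while x64 is disabled (the default)
jax.config.update("jax_enable_x64", True)
print(jnp.arange(3).dtype)   # int64 once x64 is enabled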
diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py
index 4dd4e5755ee4..4c8a65b89856 100644
--- a/jax/_src/export/serialization.py
+++ b/jax/_src/export/serialization.py
@@ -53,12 +53,7 @@
# Version 5, November 23rd, 2025, adds serialization for aval memory_space,
# upgrade num_devices to a 32 bit value.
# This version is backwards compatible with Version 2 to 4.
-# TODO(necula): we cannot really store the actual serialization_version
-# in the flatbuffer because prior to 11/25/2025 deserializers checked
-# if the version is 2 or 3. I have now removed that check, but for the
-# sake of old deserializers we can only store version 3. Starting
-# on January 2026 we can store the actual version.
-_SERIALIZATION_VERSION = 3
+_SERIALIZATION_VERSION = 5
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
"""Serializes an Exported.
@@ -125,13 +120,19 @@ def _serialize_exported(
vjp = _serialize_exported(builder, exp.vjp(), vjp_order - 1)
ser_flatbuf.ExportedStart(builder)
- ser_flatbuf.ExportedAddSerializationVersion(builder, _SERIALIZATION_VERSION)
+ # TODO(necula): we cannot really store the actual serialization_version
+ # in the flatbuffer because prior to 11/25/2025 deserializers checked
+ # if the version is 2 or 3. I have now removed that check, but for the
+ # sake of old deserializers we can only store version 3. Starting
+ # on January 2026 we can store the actual version.
+ ser_flatbuf.ExportedAddSerializationVersion(builder, 3)
ser_flatbuf.ExportedAddFunctionName(builder, fun_name)
ser_flatbuf.ExportedAddInTree(builder, in_tree)
ser_flatbuf.ExportedAddInAvals(builder, in_avals)
ser_flatbuf.ExportedAddOutTree(builder, out_tree)
ser_flatbuf.ExportedAddOutAvals(builder, out_avals)
ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices)
+ ser_flatbuf.ExportedAddNrDevicesShort(builder, exp.nr_devices) # For forward compatibility, can remove after January 2026
ser_flatbuf.ExportedAddInShardings(builder, in_shardings)
ser_flatbuf.ExportedAddOutShardings(builder, out_shardings)
ser_flatbuf.ExportedAddPlatforms(builder, platforms)
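# Editor's note (not part of the patch): the public round trip that exercises
# the serialization version handling above, sketched with the jax.export API.
import jax
import jax.numpy as jnp
from jax import export

exp = export.export(jax.jit(lambda x: x * 2))(jnp.ones((4,), jnp.float32))
blob = exp.serialize()
restored = export.deserialize(blob)
print(restored.call(jnp.arange(4.0)))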
diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py
index 80729e9b53d2..4eaca75131a1 100644
--- a/jax/_src/hijax.py
+++ b/jax/_src/hijax.py
@@ -24,9 +24,12 @@
from jax._src import effects
from jax._src.interpreters import ad
from jax._src.interpreters import batching
+from jax._src.interpreters import partial_eval as pe
from jax._src import ad_util
from jax._src.util import safe_zip, safe_map, split_list
-from jax._src.tree_util import tree_flatten, tree_unflatten, tree_leaves, tree_map
+from jax._src.tree_util import (
+ tree_map, tree_flatten, tree_unflatten, tree_leaves, tree_leaves_checked,
+ broadcast_prefix)
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@@ -92,9 +95,16 @@ def raise_val(self, *lo_vals: LoVal) -> HiVal:
# autodiff interface
def to_tangent_aval(self) -> HiType:
assert False, "must override"
+
+ # Subclasses should override if the cotangent type is a function of primal
+ # type. For example, CT unreduced = reduced and vice-versa.
+ def to_cotangent_aval(self) -> HiType:
+ return self.to_tangent_aval()
+
# the next two are required if this type is itself a tangent type
def vspace_zero(self) -> HiVal:
assert False, "must override"
+
def vspace_add(self, x: HiVal, y: HiVal) -> HiVal:
assert False, "must override"
@@ -127,6 +137,11 @@ def update_from_loval(self, state: QDD, val: HiVal, *lo_vals: LoVal) -> None:
def to_tangent_aval(self) -> HiType:
assert False, "must override"
+ # Subclasses should override if the cotangent type is a function of primal
+ # type. For example, CT unreduced = reduced and vice-versa.
+ def to_cotangent_aval(self) -> HiType:
+ return self.to_tangent_aval()
+
def register_hitype(val_cls, typeof_fn) -> None:
core.pytype_aval_mappings[val_cls] = typeof_fn
dtypes.canonicalize_value_handlers[val_cls] = lambda x: x
@@ -349,7 +364,7 @@ def __init__(self):
def expand(self, *args):
raise NotImplementedError(f"subclass {type(self)} must implement `expand`")
- def vjp_fwd(self, *args):
+ def vjp_fwd(self, nzs_in, *args):
raise NotImplementedError(f"for grad support, subclass {type(self)} must "
"implement `vjp_fwd`")
@@ -376,8 +391,13 @@ def __call__(self, *args):
return tree_unflatten(self.out_tree, ans_flat)
def check(self, *arg_tys):
- # subclass can optionally override this to add checking logic
- return
+ return # subclass can optionally override this to add checking logic
+
+ def staging(self, trace, source_info, *args):
+ args_flat = tree_leaves_checked(self.in_tree, args)
+ ans_flat = trace.default_process_primitive(
+ call_hi_primitive_p, args_flat, dict(prim=self), source_info)
+ return tree_unflatten(self.out_tree, ans_flat)
def __repr__(self):
return f"{self.__class__.__name__}[{self.params}]"
@@ -388,11 +408,6 @@ def __hash__(self):
def __eq__(self, other):
return type(self) is type(other) and self.params == other.params
-def tree_leaves_checked(treedef_expected, tree):
- flat_vals, treedef_actual = tree_flatten(tree)
- assert treedef_actual == treedef_expected
- return flat_vals
-
call_hi_primitive_p = core.Primitive("call_hi_primitive")
call_hi_primitive_p.multiple_results = True
call_hi_primitive_p.is_high = lambda *args, prim: True # type: ignore
@@ -400,9 +415,17 @@ def tree_leaves_checked(treedef_expected, tree):
def _call_hi_primitive_abstract_eval(*_args, prim):
return prim.out_avals_flat
+def _call_hi_primitive_staging(trace, source_info, *args_flat, prim):
+ trace.frame.is_high = True
+ args = tree_unflatten(prim.in_tree, args_flat)
+ ans = prim.staging(trace, source_info, *args)
+ return tree_leaves_checked(prim.out_tree, ans)
+pe.custom_staging_rules[call_hi_primitive_p] = _call_hi_primitive_staging
+
def _call_hi_primitive_to_lojax(*args_flat, prim):
args = tree_unflatten(prim.in_tree, args_flat)
- return tree_leaves_checked(prim.out_tree, prim.expand(*args))
+ ans = prim.expand(*args)
+ return tree_leaves_checked(prim.out_tree, ans)
call_hi_primitive_p.to_lojax = _call_hi_primitive_to_lojax
def _call_hi_primitive_batcher(axis_data, args_flat, dims_flat, prim):
@@ -414,34 +437,39 @@ def _call_hi_primitive_batcher(axis_data, args_flat, dims_flat, prim):
return ans_flat, dims_flat
batching.fancy_primitive_batchers[call_hi_primitive_p] = _call_hi_primitive_batcher
-def _call_hi_primitive_linearize(nz_in, *args_flat, prim):
- assert all(nz_in)
+def _call_hi_primitive_linearize(nz_in_flat, *args_flat, prim):
args = tree_unflatten(prim.in_tree, args_flat)
- ans, residuals = prim.vjp_fwd(*args)
- # TODO(dougalm): does the fwd/bwd API force us to assume the nzs_out are all False
- # (except in the case that all the nzs_in are True, which is handled in
- # LinearizeTrace.ProcessPrimitive)?
+ nzs_in = tree_unflatten(prim.in_tree, nz_in_flat)
+ ans, residuals, *maybe_nzs_out = prim.vjp_fwd(nzs_in, *args)
ans_flat = tree_leaves_checked(prim.out_tree, ans)
- nzs_out = [True for _ in ans_flat]
- return (ans_flat, nzs_out, residuals, partial(fake_linear_op, prim))
+ nzs_out = True if maybe_nzs_out == [] else maybe_nzs_out[0]
+ nzs_out_flat = broadcast_prefix(nzs_out, ans)
+ linearized = partial(fake_linear_op, prim, nz_in_flat)
+ return (ans_flat, nzs_out_flat, residuals, linearized)
-def fake_linear_op(prim, rs, *tangents):
+def fake_linear_op(prim, nz_in_flat, rs, *tangents):
residuals_flat, residuals_tree = tree_flatten(rs)
- return call_hi_primitive_linearized_p.bind(*residuals_flat, *tangents,
- residuals_tree=residuals_tree, prim=prim)
+ assert nz_in_flat == [not isinstance(t, ad_util.Zero) for t in tangents]
+ nz_tangents = tree_leaves(tangents)
+ return call_hi_primitive_linearized_p.bind(
+ *residuals_flat, *nz_tangents, residuals_tree=residuals_tree, prim=prim,
+ nz_in_flat=tuple(nz_in_flat))
ad.primitive_linearizations[call_hi_primitive_p] = _call_hi_primitive_linearize
call_hi_primitive_linearized_p = core.Primitive("call_hi_primitive_linearized")
call_hi_primitive_linearized_p.multiple_results = True
-call_hi_primitive_linearized_p.is_high = lambda *args, prim, residuals_tree: True # type: ignore
+call_hi_primitive_linearized_p.is_high = lambda *args, prim, **_: True # type: ignore
@call_hi_primitive_linearized_p.def_abstract_eval
-def _call_hi_primitive_linearized_abstract_eval(*_args, prim, residuals_tree):
+def _call_hi_primitive_linearized_abstract_eval(*_args, prim, residuals_tree, nz_in_flat):
return [t.to_tangent_aval() for t in prim.out_avals_flat] # TODO(dougalm): handle nonzeros
-def _call_hi_primitive_linearized_transpose(cts_flat, *args, prim, residuals_tree):
+def _call_hi_primitive_linearized_transpose(cts_flat, *args, prim, residuals_tree, nz_in_flat):
residuals_flat, accums_flat = split_list(args, [residuals_tree.num_leaves])
residuals = tree_unflatten(residuals_tree, residuals_flat)
+ accums_flat_ = iter(accums_flat)
+ accums_flat = [next(accums_flat_) if nz else ad.NullAccum() for nz in nz_in_flat]
+ assert next(accums_flat_, None) is None
accums = tree_unflatten(prim.in_tree, accums_flat)
cts = tree_unflatten(prim.out_tree, cts_flat)
none = prim.vjp_bwd(residuals, cts, *accums)
diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py
new file mode 100644
index 000000000000..3d89168e0004
--- /dev/null
+++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py
@@ -0,0 +1,23 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ruff: noqa
+
+# Pasted from the test output (see export_serialization_back_compat_test.py module docstring)
+serializations = [
+ dict(
+ serialization_version=5,
+ exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00L\x00J\x00D\x00@\x00<\x008\x004\x00.\x00(\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0e\x00\x08\x00\x07\x00\x00\x000\x00*\x00\x00\x00\x00\x00\x00\x01D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00X\x02\x00\x00\x80\x02\x00\x00\x88\x02\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00\xa0\x02\x00\x00\xcc\x02\x00\x00\xcc\x02\x00\x00\x04\x03\x00\x00X\x03\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00 \x02\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01\x1f\x07\x01\x05\t\t\x01\x03\x0f\x03\x03\x13\x05\x05\x17\x1b\x03kE\x0f\x01\x1b\x07\x0b#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x1f\x0b\x0b\x13\x13\x0b\x0b\x1b\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x05\x0b\x17\x0f\x1b\x07\x07\x02\xf9\x1f\x05\t\x03\x07\x07\t\x0b\r\x0f\x11\x05\x0f\x11\x03\x01\x05\x11\x11\x01\t\x05\x13\x11\x01\x05\x05\x15\t\x03\x1d\x19\x01\x05\x17\x05\x03\x1d\x01\x03\x17\t\r\x15\x05!%\x01\x0b\x03#\x01\x01\t\x17\x01\x0b\x01\x01\x01\x1d\x19\x1d\x1b\x03\x05-3\r\x03/1\x1d\x1d\x1d\x1f\r\x05\')5\x1f\x1d!#\t\x03\x03;\r\x05=?\')\x1d#\x1d%\x1d\'\x1d)\x01\x02\x02\x01\t)\x05\t\r\r)\x01\x0b\x11\x05\x07\x05\x03\x05\x1b\t\x04I\x05\x01Q\x01\x05\x01\x07\x047\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04\x1b\x03\x05\x07\x05\r\x0b\x17\x00\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00\x1a\x04+\x0f\x0b\x0f!\x1b!)\x19#\x05\x19%)9\x15\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda\x00x\x00mhlo.memory_kind\x00pinned_host\x00jax.global_constant\x00_platform_index\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08\x1b\x07\x05\'\x01\x05\x1b\x03\x0b+79AC\x02\x00\x00\x00\x14\x00\x00\x00\x04\x00\x00\x00\x04\x00\x00\x00cuda\x00\x00\x00\x00\x03\x00\x00\x00tpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00"),
+ ),
+]
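For context, serialized payloads like the one above come from JAX's public export API; the exact harness lives in export_serialization_back_compat_test.py (not shown here), so the sketch below is only illustrative of how such a blob is produced.

# Illustrative only: export a jitted function and serialize it. The real test
# additionally records the serialization_version and pastes the bytes into
# files like the one above.
import jax
import jax.numpy as jnp
from jax import export

@jax.jit
def f(x):
  return x

x = jnp.arange(12.0).reshape(3, 4)
exported = export.export(f)(x)
serialized = bytearray(exported.serialize())
print(len(serialized))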
diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py
new file mode 100644
index 000000000000..4cb08b42f268
--- /dev/null
+++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ruff: noqa
+
+# Pasted from the test output (see export_serialization_back_compat_test.py module docstring)
+serializations = [
+ dict(
+ serialization_version=4,
+ exported_serialized=bytearray(b"(\x00\x00\x00$\x00D\x00B\x00<\x008\x004\x000\x00,\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00$\x00\x00\x00@\x00\x00\x00\x00\x00\n\x00@\x00\x00\x00H\x06\x00\x00H\x06\x00\x00H\x06\x00\x00,\x06\x00\x00D\x06\x00\x00d\x06\x00\x00\x00\x00\x02\x00\x80\x06\x00\x00\xac\x06\x00\x00\xac\x06\x00\x00\xe4\x06\x00\x008\x07\x00\x00\x00\x00\x02\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf6\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc7\x9f\x11\x01y\x07\x0b\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\x0f\x17\x13\x13\x0f\x1b\x0f\x1b\x05\x19\x0b\x0f\x13\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02~\x04\x1f\x05\x15\x05\x17\x05\t\t\x07\x1d!#\x03\x07\x0f\x11\x13\x15\x17\x19\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xb1\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xb7!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x0b\x03\x83\x01\x01\x0b\x01\x01\x01\x05\x03\x7f\x01\x03A\t\r\t\x05y{\x01\tA\x01\r\t\x05{y\x01\x1dC\x03\x03\x8b\r\x03\x87\x81#\x0b\x03\x03\x91\r\x05\x93\x95\x87\x85\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\r\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\x0b\t\x03\x05\x03\x03\x0b\x06\x0b\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xba\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13QAg\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_specified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_specified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x057\x01\x05}\x07\x0b\x89\x8d\x8f\x97\x99\x03\x9b\x03\x9d\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x1c\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x01\x02J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\
x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xca\xff\xff\xff\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x00\x00\x08\x00\x07\x00\n\x00\x00\x00\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"),
+ ), # End paste
+
+
+ dict(
+ serialization_version=5,
+ exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00H\x00F\x00@\x00<\x008\x004\x000\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00\x00\x00\x00\x00,\x00*\x00\x00\x00D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00L\x06\x00\x00L\x06\x00\x00L\x06\x00\x000\x06\x00\x00H\x06\x00\x00h\x06\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00\x80\x06\x00\x00\xac\x06\x00\x00\xac\x06\x00\x00\xe4\x06\x00\x008\x07\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf6\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc7\x9f\x11\x01y\x07\x0b\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\x0f\x17\x13\x13\x0f\x1b\x0f\x1b\x05\x19\x0b\x0f\x13\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02~\x04\x1f\x05\x15\x05\x17\x05\t\t\x07\x1d!#\x03\x07\x0f\x11\x13\x15\x17\x19\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xa5\x1f+\x15-3\x1d/1\x05\'-\x03\x07k\x15]\x155;\x1d79\x05)-\x03\x07\xa9\x1d[\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x0b\x03\x83\x01\x01\x0b\x01\x01\x01\x05\x03\x7f\x01\x03A\t\r\t\x05y{\x01\tA\x01\r\t\x05{y\x01\x1dC\x03\x03\x8b\r\x03\x87\x81#\x0b\x03\x03\x91\r\x05\x93\x95\x87\x85\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\r\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\x0b\t\x03\x05\x03\x03\x0b\x06\x0b\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xba\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13QAg\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_specified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_specified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x057\x01\x05}\x07\x0b\x89\x8d\x8f\x97\x99\x03\x9b\x03\x9d\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x1c\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x01\x02J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00@\xff\
xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"),
+ ), # End paste
+]
diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py
new file mode 100644
index 000000000000..4a0da0b85ee9
--- /dev/null
+++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ruff: noqa
+
+
+# Pasted from the test output (see export_serialization_back_compat_test.py module docstring)
+serializations = [
+ dict(
+ serialization_version=4,
+ exported_serialized=bytearray(b"(\x00\x00\x00$\x00D\x00B\x00<\x008\x004\x000\x00,\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00$\x00\x00\x00@\x00\x00\x00\x00\x00\n\x00@\x00\x00\x00D\x06\x00\x00D\x06\x00\x00D\x06\x00\x00(\x06\x00\x00@\x06\x00\x00H\x06\x00\x00\x00\x00\x02\x00d\x06\x00\x00\x90\x06\x00\x00\x90\x06\x00\x00\xc8\x06\x00\x00\x1c\x07\x00\x00\x00\x00\x02\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf1\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc5\x9d\x11\x01y\x07\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x19\x0f\x13\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02^\x04\x1f\x05\x15\x05\x17\x05\t\x1d!#\x03\x07\r\x0f\x11\x13\x15\x17\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\t\x07\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xed\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xf5!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x05\x03{\x01\x03A\t\r\x1b\x05\x7f\x83\x01\x0b\x03\x81\x01\x01\tA\x01\x0b\x01\x01\x01\x03\x03\x87\r\x03\x89}\x1dC#\x0b\x03\x03\x8f\r\x03\x91\x93\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\x0b\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\t\t\x03\x05\x03\x03\x0b\x06\t\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xca\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13UAk\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_unspecified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_unspecified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x053\x01\x05y\x07\x0b\x85\x8b\x8d\x95\x97\x03\x99\x03\x9b\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xca\xff\xff\xff\x00\x00\x00\n\
x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x00\x00\x08\x00\x07\x00\n\x00\x00\x00\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"),
+ ),
+
+ dict(
+ serialization_version=5,
+ exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00H\x00F\x00@\x00<\x008\x004\x000\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00\x00\x00\x00\x00,\x00*\x00\x00\x00D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00H\x06\x00\x00H\x06\x00\x00H\x06\x00\x00,\x06\x00\x00D\x06\x00\x00L\x06\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00d\x06\x00\x00\x90\x06\x00\x00\x90\x06\x00\x00\xc8\x06\x00\x00\x1c\x07\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf1\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc5\x9d\x11\x01y\x07\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x19\x0f\x13\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02^\x04\x1f\x05\x15\x05\x17\x05\t\x1d!#\x03\x07\r\x0f\x11\x13\x15\x17\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\t\x07\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xdf\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xe7!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x05\x03{\x01\x03A\t\r\x1b\x05\x7f\x83\x01\x0b\x03\x81\x01\x01\tA\x01\x0b\x01\x01\x01\x03\x03\x87\r\x03\x89}\x1dC#\x0b\x03\x03\x8f\r\x03\x91\x93\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\x0b\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\t\t\x03\x05\x03\x03\x0b\x06\t\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xca\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13UAk\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_unspecified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_unspecified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x053\x01\x05y\x07\x0b\x85\x8b\x8d\x95\x97\x03\x99\x03\x9b\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x
04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"),
+ ),
+]
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 7e5aacba3d80..5d3af4d9ced8 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -26,7 +26,7 @@
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_flatten, tree_unflatten,
- register_pytree_node, Partial, PyTreeDef)
+ register_pytree_node, PyTreeDef)
from jax._src import mesh as mesh_lib
from jax._src import core
from jax._src import source_info_util
@@ -51,18 +51,12 @@ def identity(x): return x
def _update_annotation(
f: lu.WrappedFun,
- orig_type: tuple[tuple[core.AbstractValue, bool], ...] | None,
- explicit_nonzeros: list[bool]
+ orig_type: tuple[core.AbstractValue, ...] | None,
+ nonzeros: list[bool]
) -> lu.WrappedFun:
if orig_type is None:
return f
- # By convention, `explicit_nonzeros` only accounts for explicit arguments.
- assert len(explicit_nonzeros) == sum(explicit for _, explicit in orig_type)
- # Implicit arguments never have tangents, so generate the tangent part of the
- # type annotation from explicit arguments only.
- explicit_avals = [aval for aval, explicit in orig_type if explicit]
- tan_types = [(aval.to_tangent_aval(), True)
- for nz, aval in zip(explicit_nonzeros, explicit_avals) if nz]
+ tan_types = [aval.to_tangent_aval() for nz, aval in zip(nonzeros, orig_type) if nz]
return lu.annotate(f, (*orig_type, *tan_types))
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
@@ -77,7 +71,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
- and isinstance(core.typeof(t), core.ShapedArray)
+ and isinstance(typeof(t), core.ShapedArray)
and dtype(t) == float0 else t for t in tangents]
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
else contextlib.nullcontext())
@@ -243,7 +237,7 @@ def direct_linearize(traceable: lu.WrappedFun, primals, kwargs, *,
tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals]
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
- and isinstance(core.typeof(t), core.ShapedArray)
+ and isinstance(typeof(t), core.ShapedArray)
and dtype(t) == float0 else t for t in tangents]
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
tangent_trace.tag = linearize_trace.tag
@@ -308,158 +302,6 @@ def linearize(traceable: lu.WrappedFun, *primals, **kwargs):
else:
return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
-def vjp(traceable: lu.WrappedFun, primals, has_aux=False):
- if not has_aux:
- out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
- else:
- out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
-
- def unbound_vjp(pvals, jaxpr, consts, *cts):
- cts = tuple(ct for ct, pval in zip(cts, pvals) if not pval.is_known())
- dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
- arg_cts = backward_pass(jaxpr, True, consts, dummy_args, cts)
- return map(instantiate_zeros, arg_cts)
-
- vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts)
- if not has_aux:
- return out_primals, vjp_
- else:
- return out_primals, vjp_, aux
-
-# NOTE: The FIXMEs below are caused by primal/tangent mixups (type
-# errors if you will)
-def backward_pass(jaxpr: core.Jaxpr, transform_stack,
- consts, primals_in, cotangents_in):
- if all(type(ct) is Zero for ct in cotangents_in) and not jaxpr.effects:
- return map(lambda v: Zero(v.aval), jaxpr.invars)
-
- def write_cotangent(prim, v, ct):
- # assert v not in primal_env
- assert ct is not Zero, (prim, v.aval) # check for an old harmless type error
- if ct is None or type(v) is Literal:
- return
- if type(ct) is Zero:
- # FIXME: This triggers a lot of failures!
- # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
- return
- ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
- # TODO(mattjj): add back these checks for dynamic shapes
- # if config.enable_checks.value:
- # ct_aval = core.get_aval(ct_env[v])
- # joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
- # assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)
-
- def read_cotangent(v):
- return ct_env.pop(v, Zero(v.aval.to_tangent_aval()))
-
- def read_primal(v):
- if type(v) is Literal:
- return v.val
- else:
- a = v.aval
- if type(a) is core.DShapedArray:
- shape = [primal_env[d] if type(d) is core.Var else d for d in a.shape]
- a = a.update(shape=tuple(shape))
- return primal_env.get(v, UndefinedPrimal(a))
-
- def write_primal(v, val):
- if not is_undefined_primal(val):
- primal_env[v] = val
-
- primal_env: dict[Any, Any] = {}
- foreach(write_primal, jaxpr.constvars, consts)
- foreach(write_primal, jaxpr.invars, primals_in)
-
- # Start with a forward pass to evaluate any side-effect-free JaxprEqns that
- # only operate on primals. This is required to support primitives with
- # linearization rules that include computations on the residuals.
- lin_eqns = []
- dangling_refs = set()
- for eqn in jaxpr.eqns:
- if eqn.primitive is core.ref_p:
- dangling_refs.add(eqn.outvars[0])
- if eqn.primitive is core.freeze_p:
- dangling_refs.remove(eqn.invars[0]) # type: ignore
- # TODO (dfm): The effects check is probably stricter than necessary.
- # Consider adding an allowlist of effects here.
- if jaxpr.effects or any(
- type(x) is not Literal and x not in primal_env for x in eqn.invars):
- lin_eqns.append(eqn)
- continue
- subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
- name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
- with source_info_util.user_context(
- eqn.source_info.traceback, name_stack=name_stack), eqn.ctx.manager:
- ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params)
- if eqn.primitive.multiple_results:
- foreach(write_primal, eqn.outvars, ans)
- else:
- write_primal(eqn.outvars[0], ans)
-
- for v in dangling_refs:
- write_primal(v, core.new_ref(zeros_like_aval(v.aval.inner_aval))) # type: ignore
-
- ct_env: dict[Any, Any] = {}
- ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
- else contextlib.nullcontext())
- with ctx:
- foreach(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
- for eqn in lin_eqns[::-1]:
- if eqn.primitive.ref_primitive:
- if eqn.primitive is core.ref_p:
- val_var, = eqn.invars
- ref_var, = eqn.outvars
- ref = read_primal(ref_var)
- ct_out = core.freeze(ref)
- write_cotangent(eqn.primitive, val_var, ct_out)
- elif eqn.primitive is core.freeze_p:
- val_var, = eqn.outvars
- ref_var, = eqn.invars # type: ignore
- ct_in = instantiate_zeros(read_cotangent(val_var))
- write_primal(ref_var, core.new_ref(ct_in))
- continue
-
- invals = map(read_primal, eqn.invars)
- if eqn.primitive.multiple_results:
- cts_in = map(read_cotangent, eqn.outvars)
- else:
- cts_in, = map(read_cotangent, eqn.outvars)
- name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
- with source_info_util.user_context(
- eqn.source_info.traceback, name_stack=name_stack), eqn.ctx.manager:
- if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
- cts_in_avals = [v.aval for v in eqn.outvars]
- params = dict(eqn.params)
- call_jaxpr = params.pop('call_jaxpr')
- cts_out = get_primitive_transpose(eqn.primitive)(
- params, call_jaxpr, invals, cts_in, cts_in_avals)
- else:
- try:
- cts_out = get_primitive_transpose(eqn.primitive)(
- cts_in, *invals, **eqn.params)
- except core.ShardingTypeError as e:
- extra_msg = ("This is a potential JAX bug. Please file an issue at"
- " https://github.com/jax-ml/jax/issues")
- if extra_msg in str(e):
- raise
- raise core.ShardingTypeError(f"{str(e)}\n{extra_msg}") from e
- except (FloatingPointError, ZeroDivisionError) as e:
- msg = "When differentiating the code at the top of the callstack:"
- if msg not in e.args[0]:
- e.args = e.args[0] + f'\n{msg}',
- e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}',
- raise e from None
- cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
- # FIXME: Some invars correspond to primals!
- foreach(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
-
- cotangents_out = map(read_cotangent, jaxpr.invars)
- return cotangents_out
-
-def closed_backward_pass(jaxpr: core.ClosedJaxpr, transform_stack,
- primals_in, cotangents_in):
- return backward_pass(jaxpr.jaxpr, transform_stack, jaxpr.consts,
- primals_in, cotangents_in)
class UndefinedPrimal:
__slots__ = ['aval']
@@ -581,9 +423,10 @@ def accum(self, x):
assert x is not Zero
if isinstance(x, Zero) or x is None:
return
- elif self.ref is None:
+ if self.ref is None:
self.ref = core.new_ref(x)
else:
+ ct_check(self, x)
self.ref.addupdate(x)
def freeze(self):
@@ -607,16 +450,30 @@ def __init__(self, aval, val=None):
def accum(self, x):
if x is not None:
+ ct_check(self, x)
self.val = add_tangents(self.val, x)
def freeze(self):
return self.val
-# class NullAccum(GradAccum):
-# aval: core.AbstractValue
-# def __init__(self, aval): self.aval = aval
-# def accum(self, x): return
-# def freeze(self): assert False
+def ct_check(primal, ct):
+ if config.disable_bwd_checks.value:
+ return
+ ct_aval = ct.aval if type(ct) is Zero else typeof(ct)
+ ct_aval_expected = primal.aval.to_cotangent_aval() # type: ignore
+ if not core.typematch(ct_aval, ct_aval_expected, only_shape_shd_check=True):
+ # TODO(yashkatariya, mattjj): Add primitive name here for
+ # better error message?
+ raise ValueError(
+ "The input primal JAX type to the VJP function is"
+ f" {primal.aval.str_short()}, so the expected"
+ f" cotangent type is {ct_aval_expected.str_short()}, but"
+ f" got {ct_aval.str_short()}")
+
+class NullAccum(GradAccum):
+ def __init__(self): pass
+ def accum(self, x): return
+ def freeze(self): assert False
fancy_transposes: dict[core.Primitive, Callable] = {}
@@ -651,7 +508,20 @@ def accum_typeof(x):
if isinstance(x, GradAccum):
return x.aval
else:
- return core.typeof(x)
+ return typeof(x)
+
+# TODO(mattjj): this is for backward (get it?) compatibility. Remove, maybe.
+def backward_pass(jaxpr, transform_stack: bool, consts, primals_in, cts_in):
+ primals_in = [ValAccum(x.aval) if isinstance(x, UndefinedPrimal) else x
+ for x in primals_in]
+ backward_pass3(jaxpr, transform_stack, consts, primals_in, cts_in)
+ return [x.freeze() if isinstance(x, ValAccum) else None
+ for x in primals_in]
+
+def closed_backward_pass(jaxpr: core.ClosedJaxpr, transform_stack,
+ primals_in, cotangents_in):
+ return backward_pass(jaxpr.jaxpr, transform_stack, jaxpr.consts,
+ primals_in, cotangents_in)
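The ct_check/NullAccum additions above guard against cotangents whose type does not match the primal input's expected cotangent type. A public-API illustration of that kind of mismatch is a custom_vjp whose bwd rule returns a wrongly shaped cotangent; whether the error is surfaced by ct_check or by the custom_vjp machinery is an internal detail, so treat this only as a sketch of the failure mode.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sum(x * x)

def f_fwd(x):
  return jnp.sum(x * x), x

def f_bwd(x, g):
  return (g * 2.0 * x[:1],)  # wrong shape: (1,) instead of x.shape == (3,)

f.defvjp(f_fwd, f_bwd)

try:
  jax.grad(f)(jnp.arange(3.0))
except Exception as e:  # we only want to show the mismatch message
  print(e)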
@lu.transformation_with_aux2
@@ -679,7 +549,7 @@ def process_primitive(self, primitive, tracers, params):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
if (all(type(t) is Zero for t in tangents_in) and
primitive is not core.ref_p and
- not any(isinstance(core.typeof(x), AbstractRef) for x in primals_in)):
+ not any(isinstance(typeof(x), AbstractRef) for x in primals_in)):
return primitive.bind_with_trace(self.parent_trace, primals_in, params)
jvp = primitive_jvps.get(primitive)
if not jvp:
@@ -811,7 +681,7 @@ def process_custom_transpose(self, prim, call, tracers, **params):
def maybe_jvp_tracer(trace, primal, tangent):
if (type(tangent) is Zero or
- isinstance(core.typeof(tangent), core.ShapedArray)
+ isinstance(typeof(tangent), core.ShapedArray)
and dtype(tangent) == float0):
return primal
else:
@@ -828,7 +698,9 @@ def __init__(self, trace, primal, tangent):
self.tangent = tangent
def _short_repr(self):
- return f"GradTracer<{self.aval}>"
+ pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x)
+ primal, tangent = pp(self.primal), pp(self.tangent)
+ return f'JVPTracer({primal=!s}, {tangent=!s})'
@property
def aval(self):
@@ -901,7 +773,7 @@ def process_primitive(self, primitive, args, params):
tangent_nzs = [type(t) is not Zero for t in tangents_in]
if (all(type(t) is Zero for t in tangents_in) and
primitive is not core.ref_p and
- not any(isinstance(core.typeof(x), AbstractRef) for x in primals_in)):
+ not any(isinstance(typeof(x), AbstractRef) for x in primals_in)):
return primitive.bind_with_trace(self.parent_trace, primals_in, params)
fallback = partial(fallback_linearize_rule, primitive)
lin = primitive_linearizations.get(primitive, fallback)
@@ -1173,6 +1045,11 @@ def __init__(self, trace, primal, tangent):
self.primal = primal
self.tangent = tangent
+ def _short_repr(self):
+ pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x)
+ primal, tangent = pp(self.primal), typeof(self.tangent).str_short(True)
+ return f"GradTracer({primal=!s}, typeof(tangent)={tangent!s})"
+
@property
def aval(self):
return get_aval(self.primal)
@@ -1319,51 +1196,53 @@ def traceable(f, store, in_tree, *primals_and_tangents):
store.store(out_tree)
return out_flat
+def call_transpose_fancy(primitive, cts, *args, call_jaxpr, **params):
+ if call_jaxpr.constvars: raise NotImplementedError
+ primals_ctrefs, specs = project_accums(args)
+ flat_args, treedef = tree_flatten((primals_ctrefs, cts))
+ cell = lambda: None
+
+ @partial(lu.wrap_init, debug_info=call_jaxpr.debug_info.with_unknown_names())
+ def transposed(*flat_args):
+ primals_ctrefs, cts = tree_unflatten(treedef, flat_args)
+ args = unproject_accums(specs, primals_ctrefs)
+ backward_pass3(call_jaxpr, False, (), args, cts)
+ cts_out = [x.freeze() if isinstance(x, ValAccum) else None for x in args]
+ cts_out, cell.out_tree = tree_flatten(cts_out) # type: ignore
+ return cts_out
-def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
- if isinstance(call_jaxpr, core.ClosedJaxpr):
- call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
- else:
- consts = ()
- all_args, in_treedef = tree_flatten((consts, args, ct))
- fun = lu.hashable_partial(
- lu.wrap_init(backward_pass, debug_info=call_jaxpr.debug_info),
- call_jaxpr, False)
- fun, out_tree = flatten_fun_nokwargs(fun, in_treedef)
update_params = call_transpose_param_updaters.get(primitive)
if update_params:
- params = update_params(params, map(is_undefined_primal, args),
- [type(x) is not Zero for x in ct])
- if config.dynamic_shapes.value:
- # TODO(mattjj,dougalm): handle consts, for now assume just args
- which_lin = [is_undefined_primal(x) for x in args]
- res_invars, _ = partition_list(which_lin, call_jaxpr.invars)
- new_invars = [*res_invars, *call_jaxpr.outvars]
- dbidx_map = {v: core.DBIdx(i) for i, v in enumerate(new_invars)}
- in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape)) # type: ignore[arg-type]
- if type(v.aval) is core.DShapedArray else v.aval, True) for v in new_invars]
- fun = lu.annotate(fun, tuple(in_type))
- out_flat = primitive.bind(fun, *all_args, **params)
- return tree_unflatten(out_tree(), out_flat)
-primitive_transposes[core.call_p] = partial(call_transpose, call_p)
+ params = update_params(params, [isinstance(x, GradAccum) for x in args],
+ [type(x) is not Zero for x in cts])
+ out_flat = primitive.bind(transposed, *flat_args, **params)
+ for x, ct in zip(args, tree_unflatten(cell.out_tree, out_flat)): # type: ignore
+ if isinstance(x, ValAccum): x.accum(ct)
+fancy_transposes[core.call_p] = partial(call_transpose_fancy, call_p)
-def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals):
- jaxpr_, consts = jaxpr.jaxpr, jaxpr.consts
+def _closed_call_transpose(ct, *args, call_jaxpr, **params):
+ jaxpr_, consts = call_jaxpr.jaxpr, call_jaxpr.consts
jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_)
- return call_transpose(core.closed_call_p, params, jaxpr_, (*consts, *args),
- ct, cts_in_avals)
-primitive_transposes[core.closed_call_p] = _closed_call_transpose
+ call_transpose_fancy(core.closed_call_p, ct, *consts, *args,
+ call_jaxpr=jaxpr_, **params)
+fancy_transposes[core.closed_call_p] = _closed_call_transpose
@lu.transformation_with_aux2
def nonzero_outputs(f, store, *args, **kwargs):
results = f(*args, **kwargs)
- store.store([type(r) is not Zero for r in results])
+ store.store([not isinstance(r, (Zero, type(None))) for r in results])
return results
+# TODO(mattjj): delete this when the original pmap implementation is removed
def map_transpose(primitive: core.Primitive, params,
call_jaxpr: core.Jaxpr, args, ct, _):
+ # TODO(mattjj): we should unmap any Zeros in ct according to out_axes, but
+ # this code path is not long for this world...
+ args = [x if type(x) is not UndefinedPrimal else
+ UndefinedPrimal(core.mapped_aval(params['axis_size'], ax, x.aval))
+ for x, ax in zip(args, params['in_axes'])]
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
# TODO(necula): use the right debug_info for the backwards pass
fun = lu.hashable_partial(lu.wrap_init(
@@ -1400,7 +1279,7 @@ def out_axes_thunk():
print("Invalid nan value encountered in the backward pass of a jax.jit "
"function. Calling the de-optimized backward pass.")
try:
- _ = backward_pass(call_jaxpr, False, {}, args, ct)
+ _ = backward_pass(call_jaxpr, False, (), args, ct)
except (FloatingPointError, ZeroDivisionError) as e2:
raise e2 from None
else:
@@ -1416,7 +1295,7 @@ def unmap_zero(zero, in_axis):
return (zero if in_axis is None else
Zero(core.unmapped_aval(params['axis_size'], in_axis, zero.aval)))
arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
- arg_ct if in_axis is not None else
+ arg_ct if in_axis is not None or arg_ct is None else
arg_ct.sum(0)
for arg_ct, in_axis in zip(arg_cts, in_axes))
return tuple(arg_cts)
@@ -1549,3 +1428,21 @@ def __init__(self):
# TODO(mattjj): remove this vestigial dict
reducing_transposes: dict[core.Primitive, Callable] = {}
+
+# TODO(mattjj): remove this old code, used by something downstream
+def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
+ if isinstance(call_jaxpr, core.ClosedJaxpr):
+ call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
+ else:
+ consts = ()
+ all_args, in_treedef = tree_flatten((consts, args, ct))
+ fun = lu.hashable_partial(
+ lu.wrap_init(backward_pass, debug_info=call_jaxpr.debug_info),
+ call_jaxpr, False)
+ fun, out_tree = flatten_fun_nokwargs(fun, in_treedef)
+ update_params = call_transpose_param_updaters.get(primitive)
+ if update_params:
+ params = update_params(params, map(is_undefined_primal, args),
+ [type(x) is not Zero for x in ct])
+ out_flat = primitive.bind(fun, *all_args, **params)
+ return tree_unflatten(out_tree(), out_flat)
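The refactor above replaces the environment-threading backward_pass with accumulator-based transposition, but the public reverse-mode contract is unchanged; a minimal sanity check of that contract, using only public API:

import jax
import jax.numpy as jnp

def f(x, y):
  return jnp.sin(x) * y

x, y = jnp.float32(1.0), jnp.float32(2.0)
out, f_vjp = jax.vjp(f, x, y)
dx, dy = f_vjp(jnp.float32(1.0))
# d/dx sin(x)*y = cos(x)*y, d/dy sin(x)*y = sin(x)
print(jnp.allclose(dx, jnp.cos(x) * y), jnp.allclose(dy, jnp.sin(x)))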
diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py
index 260a43988018..69f26c14eeac 100644
--- a/jax/_src/interpreters/batching.py
+++ b/jax/_src/interpreters/batching.py
@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
-import collections
from collections.abc import Callable, Sequence
import dataclasses
from functools import partial
@@ -23,6 +22,7 @@
from jax._src import config
from jax._src import core
+from jax._src.core import typeof
from jax._src import source_info_util
from jax._src import linear_util as lu
from jax._src.partition_spec import PartitionSpec as P
@@ -31,8 +31,7 @@
from jax._src.ad_util import Zero, SymbolicZero, add_jaxvals, add_jaxvals_p
from jax._src.core import Trace, Tracer, TraceTag, AxisName
from jax._src.interpreters import partial_eval as pe
-from jax._src.tree_util import (tree_unflatten, tree_flatten,
- register_pytree_node, PyTreeDef)
+from jax._src.tree_util import (tree_unflatten, tree_flatten, PyTreeDef)
from jax._src.typing import Array
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
@@ -43,198 +42,6 @@
zip, unsafe_zip = safe_zip, zip
-# Jumbles
-
-# i:(Fin 3) => f32[[3, 1, 4].i]
-@dataclasses.dataclass(frozen=True)
-class JumbleTy:
- binder: core.Var
- length: int | Tracer | core.Var
- elt_ty: core.DShapedArray
- def __repr__(self) -> str:
- return f'Var{id(self.binder)}:{self.length} => {self.elt_ty}'
- replace = dataclasses.replace
-
-# [3, 1, 4].i
-@dataclasses.dataclass(frozen=True)
-class IndexedAxisSize:
- idx: core.Var
- lengths: Array | core.Var | Tracer
- def __repr__(self) -> str:
- return f'{self.lengths}.Var{id(self.idx)}'
- replace = dataclasses.replace
-
-# Jumble(aval=a:3 => f32[[3 1 4].a],
-# data=Array([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
-@dataclasses.dataclass(frozen=True)
-class Jumble:
- aval: JumbleTy
- data: Array
-
-# To vmap over a jumble, one must specify the axis as JumbleAxis.
-class JumbleAxis: pass
-jumble_axis = JumbleAxis()
-
-# As a temporary measure before we have more general JITable / ADable interfaces
-# (analogues to vmappable), to enable Jumbles to be used with other
-# transformations and higher-order primitives (primarily jit, though also grad
-# with allow_int=True) we register them as pytrees.
-# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration
-def _jumble_flatten(jumble):
- lengths = []
- new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths))
- if type(d) is IndexedAxisSize else d
- for d in jumble.aval.elt_ty.shape]
- elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape))
- aval = jumble.aval.replace(elt_ty=elt_ty)
- return (lengths, jumble.data), aval
-
-
-def _ragged_axis_parts(dim: RaggedAxis) -> tuple[int, int, int]:
- stacked_axis = dim.stacked_axis
- ragged_axes = dim.ragged_axes
- if len(ragged_axes) != 1:
- raise ValueError('Multiple ragged axes not yet implemented.')
- ragged_axis_dim = ragged_axes[0][0]
- ragged_axis_length = ragged_axes[0][1]
- return stacked_axis, ragged_axis_dim, ragged_axis_length
-
-
-def _jumble_unflatten(aval, x):
- lengths, data = x
- new_shape = [d.replace(lengths=lengths[d.lengths - 1])
- if type(d) is IndexedAxisSize else d
- for d in aval.elt_ty.shape]
- elt_ty = aval.elt_ty.update(shape=tuple(new_shape))
- aval = aval.replace(elt_ty=elt_ty)
- return Jumble(aval, data)
-register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten)
-
-def _jumble_result(axis_size, stacked_axis, ragged_axes, x):
- binder = core.Var(core.ShapedArray((), np.dtype('int32')))
- if stacked_axis != 0:
- raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0
- shape = list(x.shape)
- del shape[0]
- for ragged_axis, segment_lens in ragged_axes:
- shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens)
- elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
- return Jumble(JumbleTy(binder, axis_size, elt_ty), x)
-
-
-@dataclasses.dataclass(frozen=True)
-class RaggedAxis:
- stacked_axis: int
- # For each axis, we store its index and the corresponding segment lengths.
- # For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i]
- # would be represented with ragged_axes = [(1, lens1), (3, lens2)]
- ragged_axes: tuple[tuple[int, Any], ...]
-
- @property
- def size(self):
- # TODO(mattjj, axch): All the segment lengths arrays better be the
- # same length!
- return len(self.ragged_axes[0][1])
-
- def move_stacked_axis(self: RaggedAxis, dst: int) -> RaggedAxis:
- # Assumes that all stored and incoming axes are already canonicalized
- def move_axis(ax):
- if self.stacked_axis > ax and ax >= dst:
- return ax + 1
- if self.stacked_axis < ax and ax <= dst:
- return ax - 1
- return ax
- new_axes = tuple((move_axis(ax), sizes) for ax, sizes in self.ragged_axes)
- return RaggedAxis(dst, new_axes)
-
-
-def transpose_ragged_axes(dim: RaggedAxis, perm: tuple[int, ...]) -> RaggedAxis:
- new_ragged_axes = []
- for idx, old_idx in enumerate(perm):
- for ax, size in dim.ragged_axes:
- if old_idx == ax:
- new_ragged_axes.append((idx, size))
- break
- return _sorted_ragged_axis(dim.stacked_axis, new_ragged_axes)
-
-def _sorted_ragged_axis(stacked_axis, ragged_axes):
- return RaggedAxis(stacked_axis, tuple(sorted(ragged_axes, key=lambda p: p[0])))
-
-def make_batch_axis(
- ndim: int,
- stacked_axis: int,
- ragged_axes: list[tuple[int, Array | core.Var]],
-) -> int | RaggedAxis:
- if ragged_axes:
- canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes]
- return _sorted_ragged_axis(canonicalize_axis(stacked_axis, ndim), canonical)
- else:
- return canonicalize_axis(stacked_axis, ndim)
-
-def bdim_as_shape(
- bdim: int | RaggedAxis, data_shape: core.Shape) -> core.Shape:
- if isinstance(bdim, RaggedAxis):
- result = list(data_shape)
- binder = core.Var(core.ShapedArray((), np.dtype('int32')))
- for ragged_axis, segment_lens in bdim.ragged_axes:
- result[ragged_axis] = IndexedAxisSize(binder, segment_lens)
- return tuple(result)
- else:
- return data_shape
-
-def shape_as_bdim(
- stacked_axis: int, data_shape: core.Shape) -> int | RaggedAxis:
- # This assumes that there is only one binder in the data_shape.
- ragged_axes = [(i, size.lengths) for i, size in enumerate(data_shape)
- if isinstance(size, IndexedAxisSize)]
- return make_batch_axis(len(data_shape), stacked_axis, ragged_axes)
-
-
-def _update_annotation(
- f: lu.WrappedFun, orig_type: core.InputType | None,
- axis_size: core.AxisSize, axis_name: AxisName,
- explicit_in_dims: Sequence[int | RaggedAxis | None],
- segment_lens: Sequence[Array],
- ) -> lu.WrappedFun:
- if orig_type is None: return f
- # By convention, `explicit_in_dims` only accounts for explicit arguments.
- assert len(explicit_in_dims) == sum(explicit for _, explicit in orig_type)
- # We need to:
- # * if `axis_size` is dynamic, add a new implicit binder (type) for it;
- # * for each element of `segment_lengths`, add a new explicit binder for it;
- # * drop other implicit binders, replacing DBIdx which refer to them with
- # Name objects;
- # * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
- # size if `axis_size` is int, otherwise Name); if RaggedAxis-valued in_dim,
- # add batch axis (int if corresponding segment_lengths is concrete, Name if
- # not);
- # * generate full in_type with implicit args too.
-
- class Name:
- def __init__(self, a): self.a = a
- names = [Name(a) for a, _ in orig_type]
- avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d
- for d in a.shape))
- if type(a) is core.DShapedArray else a for a, e in orig_type if e]
-
- new_avals = [core.get_aval(s) for s in segment_lens]
- sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
- for a, d in zip(avals, explicit_in_dims):
- if isinstance(d, RaggedAxis):
- raise NotImplementedError
- else:
- new_avals.append(core.unmapped_aval(sz, d, a)) # type: ignore
-
- mentioned = {d for a in new_avals if type(a) is core.DShapedArray
- for d in a.shape if type(d) is Name}
- expl_names = set(map(Name, new_avals))
- impl_names = mentioned - expl_names # type: ignore
- impl_part = [(n.a, False) for n in impl_names] # type: ignore
- name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))}
- expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape))
- if type(a) is core.DShapedArray else a, True) for a in new_avals]
- return lu.annotate(f, (*impl_part, *expl_part))
-
### vmappable typeclass
Vmappable = Any
@@ -251,26 +58,11 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
handler = to_elt_handlers.get(type(x))
if handler:
return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
- elif type(x) is Jumble:
- if spec is not jumble_axis:
- raise TypeError("jumble input without using jumble_axis in_axes spec")
- ias: IndexedAxisSize # Not present in the AxisSize union in core.py
- (d, ias), = ((i, sz) # type: ignore
- for i, sz in enumerate(x.aval.elt_ty.shape)
- if type(sz) is IndexedAxisSize)
- batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)])
- return BatchTracer(trace, x.data, batch_axis)
elif isinstance(spec, int) or spec is None:
spec = spec and canonicalize_axis(spec, len(np.shape(x)))
return (BatchTracer(trace, x, spec, source_info_util.current())
if spec is not None else x)
else:
- if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis):
- # TODO(mvoz): A vaguely questionable assumption that it is always
- # sound to have a 0 axis here. This is true for the current use cases
- # and comes from how we handle intermediary products of jumbles in
- # vmap.
- return BatchTracer(trace, x, 0, source_info_util.current())
# TODO(mvoz): This is a terrible place to fall into if you pass
# a non jumble type in, make it clearer what went wrong.
assert False, f'Unexpected type in ELT? {type(x)}'
@@ -286,17 +78,11 @@ def _cont(axis_size, elt, axis):
return from_elt(trace, axis_size, mesh_axis, i, elt, axis)
return handler(_cont, axis_size, x, spec)
val, bdim = trace.to_batch_info(x)
- if type(bdim) is RaggedAxis:
- if spec is not jumble_axis:
- # TODO(mattjj): improve this error message
- raise TypeError("ragged output without using jumble_axis out_axes spec")
- return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
- else:
- try:
- return matchaxis(trace.axis_data.name, axis_size, mesh_axis,
- bdim, spec, val)
- except SpecMatchError:
- raise SpecMatchError(i, x.batch_dim, spec) from None
+ try:
+ return matchaxis(trace.axis_data.name, axis_size, mesh_axis,
+ bdim, spec, val)
+ except SpecMatchError:
+ raise SpecMatchError(i, x.batch_dim, spec) from None
from_elt_handlers: dict[type, FromEltHandler] = {}
def make_iota(axis_size: AxisSize) -> Array:
@@ -319,7 +105,7 @@ def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
from_elt_handlers[data_type] = from_elt
if make_iota: make_iota_handlers[axis_size_type] = make_iota
vmappables: dict[type, tuple[type, type]] = {}
-spec_types: set[type] = {JumbleAxis}
+spec_types: set[type] = set()
def unregister_vmappable(data_type: type) -> None:
_, axis_size_type = vmappables.pop(data_type)
@@ -329,11 +115,11 @@ def unregister_vmappable(data_type: type) -> None:
del make_iota_handlers[axis_size_type]
global spec_types
spec_types = (
- {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()}
+ set() | {spec_type for spec_type, _ in vmappables.values()}
)
def is_vmappable(x: Any) -> bool:
- return type(x) is Jumble or type(x) in vmappables
+ return type(x) in vmappables
@lu.transformation_with_aux2
def flatten_fun_for_vmap(f: Callable,
@@ -344,44 +130,6 @@ def flatten_fun_for_vmap(f: Callable,
store.store(out_tree)
return ans
-# Propagate ragged masking rules from invars to outvars
-# rule([params], [raggedness_per_invar], outvars) ->
-# [raggedness_per_invar, raggedness_per_outvar]
-RaggedMaskingRule = Callable[
- [list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]]
-]
-
-ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {}
-
-
-def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars):
- # TODO(mvoz): A util for getting the ragged representations
- first_invar_raggedness = invar_raggedness[0]
- for other_invar_raggedness in invar_raggedness[1:]:
- if other_invar_raggedness != first_invar_raggedness:
- raise ValueError(f'{other_invar_raggedness} != {first_invar_raggedness}')
-
- outvar_raggedness = [first_invar_raggedness] * len(outvars)
- return invar_raggedness, outvar_raggedness
-
-
-def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars):
- if any(invar_raggedness):
- raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}')
- return invar_raggedness, [None] * len(outvars)
-
-
-def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars):
- return invar_raggedness, [None] * len(outvars)
-
-
-def ragged_mask_transfer_identity(
- eqn_params, invar_raggedness, outvar_raggedness
-):
- assert len(invar_raggedness) == 1, invar_raggedness
- outvar_raggedness = invar_raggedness
- return invar_raggedness, outvar_raggedness
-
### tracer
@@ -393,10 +141,10 @@ def ragged_mask_transfer_identity(
class BatchTracer(Tracer):
__slots__ = ['val', 'batch_dim', 'source_info']
- def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,
+ def __init__(self, trace, val, batch_dim: NotMapped | int,
source_info: source_info_util.SourceInfo | None = None):
if config.enable_checks.value:
- assert type(batch_dim) in (NotMapped, int, RaggedAxis)
+ assert type(batch_dim) in (NotMapped, int)
if type(batch_dim) is int:
aval = core.get_aval(val)
assert 0 <= batch_dim < len(aval.shape)
@@ -406,7 +154,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis,
self.source_info = source_info
def _short_repr(self):
- return f"VmapTracer<{self.aval}>"
+ return f"VmapTracer(aval={self.aval}, batched={typeof(self.val)})"
@property
def aval(self):
@@ -419,17 +167,8 @@ def aval(self):
return aval
elif type(self.batch_dim) is int:
return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
- elif type(self.batch_dim) is RaggedAxis:
- new_aval = core.mapped_aval(
- aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval)
- shape = list(new_aval.shape) # pytype: disable=attribute-error
- for ragged_axis, segment_lengths in self.batch_dim.ragged_axes:
- size_tracer = BatchTracer(self._trace, segment_lengths, 0)
- if self.batch_dim.stacked_axis < ragged_axis:
- ragged_axis -= 1
- shape[ragged_axis] = size_tracer
- return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype,
- weak_type=aval.weak_type)
+ else:
+ raise Exception("batch dim should be int or `not_mapped`")
def full_lower(self):
if self.batch_dim is not_mapped:
@@ -449,7 +188,7 @@ def _contents(self):
def get_referent(self):
if self.batch_dim is None or type(self.batch_dim) is int:
return core.get_referent(self.val)
- else: # TODO(mattjj): could handle the RaggedAxis case?
+ else:
return self
@dataclasses.dataclass(frozen=True)
@@ -509,8 +248,6 @@ def to_batch_info(self, val):
return val, not_mapped
def process_primitive(self, p, tracers, params):
- if config.dynamic_shapes.value:
- p.abstract_eval(*(map(core.get_aval, tracers)), **params)
vals_in, dims_in = unzip2(map(self.to_batch_info, tracers))
args_not_mapped = all(bdim is not_mapped for bdim in dims_in)
if p in fancy_primitive_batchers:
@@ -518,18 +255,13 @@ def process_primitive(self, p, tracers, params):
and p in skippable_batchers
and not any(self.axis_data.name == axis_name
for axis_name in skippable_batchers[p](params))):
- # no-op shortcut
return p.bind_with_trace(self.parent_trace, vals_in, params)
else:
with core.set_current_trace(self.parent_trace):
val_out, dim_out = fancy_primitive_batchers[p](
self.axis_data, vals_in, dims_in, **params)
elif args_not_mapped:
- # no-op shortcut
return p.bind_with_trace(self.parent_trace, vals_in, params)
- elif p in primitive_batchers:
- with core.set_current_trace(self.parent_trace):
- val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params)
else:
raise NotImplementedError(f"Batching rule for '{p}' not implemented")
src = source_info_util.current()
@@ -545,16 +277,12 @@ def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
params = dict(params, name=params.get('name', f.__name__))
vals, dims = unzip2(map(self.to_batch_info, tracers))
- segment_lens, dims = indirectify_ragged_axes(dims)
f_, dims_out = batch_subtrace(f, self.tag, self.axis_data, tuple(dims))
- f_ = _update_annotation(
- f_, f.in_type, self.axis_data.size, self.axis_data.name, dims, segment_lens)
with core.set_current_trace(self.parent_trace):
- vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
- vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out())
+ vals_out = call_primitive.bind(f_, *vals, **params)
src = source_info_util.current()
- return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]
+ return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())]
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
vals, dims = unzip2(map(self.to_batch_info, tracers))
@@ -707,43 +435,12 @@ def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
trace = BatchTrace(parent_trace, tag, axis_data)
with core.set_current_trace(trace):
in_dims = in_dims() if callable(in_dims) else in_dims
- in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims)
in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
if dim is not None else x for x, dim in zip(in_vals, in_dims)]
outs = f(*in_tracers)
out_vals, out_dims = unzip2(map(trace.to_batch_info, outs))
- segment_lens, out_dims = indirectify_ragged_axes(out_dims)
store.store(out_dims)
- return (*segment_lens, *out_vals)
-
-def indirectify_ragged_axes(dims):
- if not any(type(d) is RaggedAxis for d in dims):
- return [], dims
- axis_map : dict[int, tuple[Array, pe.DBIdx]] = collections.OrderedDict()
- def canonicalize_segment_lengths(d: RaggedAxis) -> RaggedAxis:
- new_ragged_axes = []
- for ragged_axis, segment_lengths in d.ragged_axes:
- _, dbidx = axis_map.setdefault(
- id(core.get_referent(segment_lengths)),
- (segment_lengths, pe.DBIdx(len(axis_map))))
- new_ragged_axes.append((ragged_axis, dbidx))
- return RaggedAxis(d.stacked_axis, tuple(new_ragged_axes))
- new_dims = [canonicalize_segment_lengths(d)
- if isinstance(d, RaggedAxis) else d for d in dims]
- segment_lens = [s for s, _ in axis_map.values()]
- return segment_lens, new_dims
-
-def indirectify_ragged_axes_against_inputs_outputs(dims, in_vals, out_vals):
- def canonicalize_segment_lengths(d: RaggedAxis) -> RaggedAxis:
- new_ragged_axes = []
- for ragged_axis, segment_lengths in d.ragged_axes:
- key = id(core.get_referent(segment_lengths))
- value = _locate_value(key, in_vals, out_vals)
- new_ragged_axes.append((ragged_axis, value))
- return RaggedAxis(d.stacked_axis, tuple(new_ragged_axes))
- new_dims = [canonicalize_segment_lengths(d)
- if isinstance(d, RaggedAxis) else d for d in dims]
- return new_dims
+ return out_vals
def _locate_value(key, in_vals, out_vals):
for ix, candidate in enumerate(in_vals):
@@ -754,58 +451,27 @@ def _locate_value(key, in_vals, out_vals):
return pe.OutDBIdx(ix)
assert False, "Could not find segment lengths"
-def resolve_ragged_axes(vals, dims):
- idxs = {lengths_idx.val for d in dims if isinstance(d, RaggedAxis)
- for (_, lengths_idx) in d.ragged_axes}
- dims = [RaggedAxis(d.stacked_axis,
- tuple((ragged_axis, vals[lengths_idx.val])
- for ragged_axis, lengths_idx in d.ragged_axes))
- if isinstance(d, RaggedAxis) else d for d in dims]
- vals = [x for i, x in enumerate(vals) if i not in idxs]
- return vals, dims
-
-def resolve_ragged_axes_against_inputs_outputs(in_vals, out_vals, dims):
- def fetch(idx):
- if isinstance(idx, pe.InDBIdx):
- return in_vals[idx.val]
- else:
- assert isinstance(idx, pe.OutDBIdx)
- return out_vals[idx.val]
-
- dims = [RaggedAxis(d.stacked_axis,
- tuple((ragged_axis, fetch(lengths_idx))
- for ragged_axis, lengths_idx in d.ragged_axes))
- if isinstance(d, RaggedAxis) else d for d in dims]
- return dims
-
### API for batching jaxprs
-# TODO(axch): parameterize RaggedAxis annotations by a type parameter so as to
-# indicate whether we're dealing with instances that contain Arrays or DBIdx.
-# Can reuse same pattern for all dynamic shape stuff.
def batch_jaxpr2(
closed_jaxpr: core.ClosedJaxpr,
axis_data,
- in_axes: tuple[int | NotMapped | RaggedAxis, ...],
- ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]:
+ in_axes: tuple[int | NotMapped, ...],
+  ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
return _batch_jaxpr2(closed_jaxpr, axis_data, tuple(in_axes))
@weakref_lru_cache
def _batch_jaxpr2(
closed_jaxpr: core.ClosedJaxpr,
axis_data,
- in_axes: tuple[int | NotMapped | RaggedAxis, ...],
+    in_axes: tuple[int | NotMapped, ...],
) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
debug_info=closed_jaxpr.jaxpr.debug_info)
f, out_axes = _batch_jaxpr_inner(f, axis_data)
f = _batch_jaxpr_outer(f, axis_data, in_axes)
- in_axes2, avals_in = unzip2([
- handle_ragged(closed_jaxpr.in_avals, dim, aval)
- if isinstance(dim, RaggedAxis) else (dim, aval)
- for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
avals_in2 = []
- for aval, b in unsafe_zip(avals_in, in_axes2):
+ for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes):
if b is not_mapped:
avals_in2.append(aval)
else:
@@ -818,14 +484,6 @@ def _batch_jaxpr2(
jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in2)
return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
-def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis,
- aval: core.ShapedArray) -> tuple[int, core.ShapedArray]:
- new_shape = list(aval.shape)
- for i, dbi in dim.ragged_axes:
- new_shape[i - (dim.stacked_axis < i)] = in_avals[dbi.val].dtype.bound
- new_aval = aval.update(shape=tuple(new_shape))
- return dim.stacked_axis, new_aval
-
def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate):
inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate
return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), inst)
@@ -863,7 +521,6 @@ def _batch_jaxpr_axes(closed_jaxpr: core.ClosedJaxpr,
def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals):
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
- _, in_axes = resolve_ragged_axes(in_vals, in_axes)
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
for val, dim in zip(in_vals, in_axes)]
# TODO(yashkatariya): Instead of `add_explicit_mesh_axis_names`, we should
@@ -874,9 +531,7 @@ def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals):
core.add_explicit_mesh_axis_names(axis_data.explicit_mesh_axis)):
outs = f(*in_tracers)
out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
- new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
- out_axes, in_vals, out_vals)
- store.store(new_out_axes)
+ store.store(out_axes)
return out_vals
@lu.transformation_with_aux2
@@ -997,8 +652,6 @@ def _matchaxis_symzeros(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
...,
tuple[Any, Union[int, None, tuple[Union[int, None], ...]]]
]
-primitive_batchers : dict[core.Primitive, BatchingRule] = {}
-# "fancy" primitive batchers just take a extra leading `AxisData` and "trace type" args
fancy_primitive_batchers: dict[core.Primitive, Callable] = {}
# backwards compat shim. TODO: delete
@@ -1007,36 +660,47 @@ def __setitem__(self, prim, batcher):
def wrapped(axis_data, vals, dims, **params):
return batcher(axis_data.size, axis_data.name, None, vals, dims, **params)
fancy_primitive_batchers[prim] = wrapped
-
axis_primitive_batchers = AxisPrimitiveBatchersProxy()
+# backwards compat shim. TODO: delete
+class PrimitiveBatchersProxy:
+ def __setitem__(self, prim, batcher):
+ def wrapped(axis_data, vals, dims, **params):
+ del axis_data
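+      # No-op shortcut: if nothing is batched, bind the primitive directly.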
+ if all(d is None for d in dims):
+ o = prim.bind(*vals, **params)
+ return (o, [None] * len(o)) if prim.multiple_results else (o, None)
+ return batcher(vals, dims, **params)
+ fancy_primitive_batchers[prim] = wrapped
+
+ def __delitem__(self, prim):
+ del fancy_primitive_batchers[prim]
+primitive_batchers = PrimitiveBatchersProxy()
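+# Assignments like `batching.primitive_batchers[prim] = rule` (the old API) now
+# register an adapted rule in fancy_primitive_batchers via the proxy above.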
+
# Presence in this table allows fancy batchers to be skipped by batch traces for
-# irrelevant axes. The Callable takes the params and returns a list of relevant
-# axes.
+# irrelevant axes. The Callable takes params and returns a list of relevant axes.
+# TODO(yashkatariya): remove this
skippable_batchers : dict[core.Primitive, Callable] = {}
def defvectorized(prim):
- primitive_batchers[prim] = partial(vectorized_batcher, prim)
+ fancy_primitive_batchers[prim] = partial(vectorized_batcher, prim)
-def vectorized_batcher(prim, batched_args, batch_dims, **params):
+def vectorized_batcher(prim, axis_data, batched_args, batch_dims, **params):
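+  # axis_data is accepted to match the fancy-batcher signature but is not used here.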
+ assert not prim.multiple_results
+ if all(d is None for d in batch_dims):
+ return prim.bind(*batched_args, **params), None
assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
return prim.bind(*batched_args, **params), batch_dims[0]
def defbroadcasting(prim):
- primitive_batchers[prim] = partial(broadcast_batcher, prim)
-
-def broadcast_batcher(prim, args, dims, **params):
- """Process a primitive with built-in broadcasting.
+ fancy_primitive_batchers[prim] = partial(broadcast_batcher, prim)
- Args:
- args: the possibly-batched arguments
- dims: list or tuple of the same length as `args`, where each
- entry indicates the batching state of the corresponding entry to `args`:
- either an int indicating the batch dimension, or else `not_mapped`
- indicating no batching.
- """
+def broadcast_batcher(prim, axis_data, args, dims, **params):
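+  # `dims` holds, per argument, either an int batch dimension or not_mapped (None).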
assert len(args) > 1
+ if all(d is None for d in dims):
+ o = prim.bind(*args, **params)
+ return (o, [None] * len(o)) if prim.multiple_results else (o, None)
shape, dim = next((x.shape, d) for x, d in zip(args, dims)
if d is not not_mapped)
if all(core.definitely_equal_shape(shape, x.shape) and d == dim
@@ -1065,9 +729,12 @@ def _handle_scalar_broadcasting(nd, x, d):
return lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
def defreducer(prim, ident):
- primitive_batchers[prim] = partial(reducer_batcher, prim, ident)
+ fancy_primitive_batchers[prim] = partial(reducer_batcher, prim, ident)
-def reducer_batcher(prim, ident, batched_args, batch_dims, axes, **params):
+def reducer_batcher(prim, ident, axis_data, batched_args, batch_dims, axes,
+ **params):
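+  # No-op shortcut: nothing is batched, so bind the reduction unchanged.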
+ if all(d is None for d in batch_dims):
+ return prim.bind(*batched_args, axes=axes, **params), None
def out_axis(axes, axis):
return int(list(np.delete(np.arange(operand.ndim), axes)).index(axis))
operand, = batched_args
@@ -1078,23 +745,6 @@ def out_axis(axes, axis):
if 'input_shape' in params:
params = dict(params, input_shape=operand.shape)
return prim.bind(operand, axes=axes, **params), bdim_out
- elif isinstance(bdim, RaggedAxis):
- assert ident is not None, "TODO Ragged batching a reduction requires an identity"
- axes = tuple(np.where(np.less(axes, bdim.stacked_axis), axes, np.add(axes, 1)))
- bdim_out = out_axis(axes, bdim.stacked_axis)
- # For each ragged_axis, we either mask the operand there or append
- # it to the set of axes that will be ragged in the result.
- axes_to_mask = []
- ragged_axes_out = []
- for ragged_axis, segment_lengths in bdim.ragged_axes:
- if ragged_axis in axes:
- axes_to_mask.append((ragged_axis, segment_lengths))
- else:
- ragged_axes_out.append((out_axis(axes, ragged_axis), segment_lengths))
- operand = mask_ragged_axes(
- operand, ident, RaggedAxis(bdim.stacked_axis, tuple(axes_to_mask)))
- result = prim.bind(operand, axes=axes, **params)
- return result, make_batch_axis(operand.ndim, bdim_out, ragged_axes_out)
else:
assert False
@@ -1107,42 +757,6 @@ def expand_dims_batcher(prim, args, dims, **params):
out = prim.bind(*args, **params)
return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
-def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array:
- # TODO(mattjj, axch) Can we mask multiple axes more efficiently at
- # once, rather than one at a time?
- for ragged_axis, segment_lengths in axis_spec.ragged_axes:
- this_axis_spec = RaggedAxis(
- axis_spec.stacked_axis, ((ragged_axis, segment_lengths),))
- operand = _mask_one_ragged_axis(operand, ident, this_axis_spec)
- return operand
-
-def _mask_one_ragged_axis(
- operand: Array, ident, axis_spec: RaggedAxis) -> Array:
- # Callers of this utility, via reducer_batcher() or defreducer(),
- # must be in a context where lax is importable.
- from jax import lax # pytype: disable=import-error
- assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time"
- ragged_axis, segment_lengths = axis_spec.ragged_axes[0]
- value = ident(operand.dtype)
- positions = lax.broadcasted_iota('int32', operand.shape, ragged_axis)
- # TODO(mattjj, axch) can't get ._data, need to convert it
- # lengths = lax.convert_element_type(segment_lengths._data, 'int32')
- lengths = lax.convert_element_type(segment_lengths, 'int32')
- limits = lax.broadcast_in_dim(
- lengths, operand.shape, [axis_spec.stacked_axis])
- mask = positions < limits
- return lax.select(mask, operand, lax.broadcast(value, operand.shape))
-
-def move_stacked_axis(operand, bdim, dst):
- dst = canonicalize_axis(dst, operand.ndim)
- if isinstance(bdim, int):
- return moveaxis(operand, bdim, dst), dst
- elif isinstance(bdim, RaggedAxis):
- result = moveaxis(operand, bdim.stacked_axis, dst)
- return result, bdim.move_stacked_axis(dst)
- else:
- raise TypeError(f"Unrecognized batch dimension type {bdim}")
-
### general utilities for manipulating axes on jaxpr types (not vmappables)
def broadcast(x, sz, axis, mesh_axis):
@@ -1174,12 +788,6 @@ def matchaxis2(axis_data, src, dst, x, sum_match=False):
src, dst, x, sum_match)
def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
- if dst == jumble_axis:
- x = bdim_at_front(x, src, sz)
- elt_ty = x.aval.update(shape=x.shape[1:])
- aval = JumbleTy(core.Var(core.ShapedArray((), np.dtype('int32'))),
- x.shape[0], elt_ty)
- return Jumble(aval, x)
try:
_ = core.get_aval(x)
except TypeError as e:
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index f9fca20bfa65..69c50c07bad1 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -184,7 +184,7 @@ def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
f"No dtype_to_ir_type handler for dtype: {dtype}") from err
return ir_type_factory()
-def _array_ir_types(aval: core.ShapedArray | core.DShapedArray) -> ir.Type:
+def _array_ir_types(aval: core.ShapedArray) -> ir.Type:
aval = core.physical_aval(aval) # type: ignore
if not core.is_constant_shape(aval.shape):
return _dynamic_array_ir_types(aval) # type: ignore
@@ -209,7 +209,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes:
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get()
-ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types
# This is a backwards compatibility shim for external users of jax.mlir apis.
def aval_to_ir_types(aval: core.AbstractValue) -> tuple[ir.Type, ...]:
@@ -974,27 +973,23 @@ def sharded_aval(aval: core.AbstractValue,
return aval
if isinstance(aval, core.AbstractToken):
return aval
- if not isinstance(aval, (core.ShapedArray, core.DShapedArray)):
+ if not isinstance(aval, core.ShapedArray):
raise NotImplementedError
return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore
def eval_dynamic_shape(ctx: LoweringRuleContext,
shape: core.Shape) -> tuple[int | Value, ...]:
- if config.dynamic_shapes.value:
- assert ctx.axis_size_env is not None
- return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore
- else:
- ctx = ctx.replace(
- primitive="eval_dynamic_shape",
- avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars),
- tokens_out=None)
+ ctx = ctx.replace(
+ primitive="eval_dynamic_shape",
+ avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars),
+ tokens_out=None)
- res = lower_fun(
- partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
- multiple_results=True)(ctx, *ctx.dim_var_values)
- return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir
- for d, d_ir in zip(shape, flatten_ir_values(res)))
+ res = lower_fun(
+ partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
+ multiple_results=True)(ctx, *ctx.dim_var_values)
+ return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir
+ for d, d_ir in zip(shape, flatten_ir_values(res)))
# TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
@@ -1039,7 +1034,7 @@ class LoweringResult(NamedTuple):
shape_poly_state: ShapePolyLoweringState
-_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu", "neuron"]
+_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu", "neuron", "iree_metal"]
def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):
@@ -1083,7 +1078,7 @@ def _to_physical_op_sharding(
assert isinstance(sharding, JSharding)
if isinstance(aval, AbstractRef):
return _to_physical_op_sharding(ctx, aval.inner_aval, sharding)
- assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
+ assert isinstance(aval, core.ShapedArray)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
sharding = sharding_impls.physical_sharding(aval, sharding)
aval = core.physical_aval(aval)
@@ -1288,16 +1283,12 @@ def lower_jaxpr_to_module(
# Create a keepalives list that will be mutated during the lowering.
keepalives: list[Any] = []
host_callbacks: list[Any] = []
+ # Find the dimension variables
+ all_dim_poly = [d for aval in sharded_in_avals if hasattr(aval, "shape")
+ for d in aval.shape if not core.is_constant_dim(d)]
+ dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new._get_vars()),
+ all_dim_poly, set())))
- dim_vars: Sequence[str]
- if not config.dynamic_shapes.value:
- # Find the dimension variables
- all_dim_poly = [d for aval in sharded_in_avals if hasattr(aval, "shape")
- for d in aval.shape if not core.is_constant_dim(d)]
- dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new._get_vars()),
- all_dim_poly, set())))
- else:
- dim_vars = ()
ctx = ModuleContext(backend=backend,
platforms=platforms, axis_context=axis_context,
@@ -1974,7 +1965,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
# For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
# then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
# The below custom call achieves the sharding like above example.
- assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
+ assert isinstance(aval, core.ShapedArray)
if config.use_shardy_partitioner.value:
physical_ndim = core.physical_aval(aval).ndim
s = SdyArray(
@@ -2065,8 +2056,7 @@ def write(v: core.Var, node: IrValues):
eqn.ctx.manager):
# TODO(mattjj, phawkins): support caching for dynamic shapes.
can_cache_lowering = (
- eqn.primitive not in _uncacheable_primitives and
- not config.dynamic_shapes.value)
+ eqn.primitive not in _uncacheable_primitives)
if can_cache_lowering:
loc = source_info_to_location(ctx, None, eqn_name_stack,
eqn.source_info.traceback)
@@ -2077,10 +2067,6 @@ def write(v: core.Var, node: IrValues):
else:
# If we cannot cache the lowering, lower inline.
axis_size_env = None
- if config.dynamic_shapes.value:
- axis_size_env = {d: read(d)
- for a in avals_in if type(a) is core.DShapedArray
- for d in a.shape if type(d) is core.Var}
rule_ctx = LoweringRuleContext(
module_context=ctx, primitive=eqn.primitive,
name_stack=eqn_name_stack,
@@ -2305,6 +2291,9 @@ def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]:
return tuple(_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or
ctx.platforms or ctx.module_context.platforms)
+def _get_owner(v: ir.Value):
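+  # v.owner may be an ir.OpView wrapper; unwrap it to the underlying ir.Operation.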
+ owner = v.owner
+ return owner.operation if isinstance(owner, ir.OpView) else owner
def lower_per_platform(ctx: LoweringRuleContext,
description: str,
@@ -2381,11 +2370,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
if len(kept_rules) == 1:
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
foreach(
- lambda o: wrap_compute_type_in_place(ctx, o.owner),
+ lambda o: wrap_compute_type_in_place(ctx, _get_owner(o)),
filter(_is_not_block_argument, flatten_ir_values(output)),
)
foreach(
- lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
+ lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)),
flatten_ir_values(output),
)
return output
@@ -2426,11 +2415,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
raise ValueError("Output of translation rule must be iterable: "
f"{description}, got output {output}") from e
foreach(
- lambda o: wrap_compute_type_in_place(ctx, o.owner),
+ lambda o: wrap_compute_type_in_place(ctx, _get_owner(o)),
filter(_is_not_block_argument, out_nodes),
)
foreach(
- lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
+ lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)),
out_nodes,
)
if inner_ctx.tokens_out is not None:
@@ -2465,26 +2454,16 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params):
wrapped_fun = lu.wrap_init(f, params,
debug_info=api_util.debug_info("lower_fun", fun, args, {}))
- if config.dynamic_shapes.value:
- # We might be applying this function to arguments with dynamic shapes,
- # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
- # case, we need to form a jaxpr with leading binders for those axis size
- # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
- # and we need to call jaxpr_subcomp with these arguments made explicit.
- assert ctx.axis_size_env is not None
- args = (*ctx.axis_size_env.values(), *args)
- idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
- i32_aval = core.ShapedArray((), np.dtype('int32'))
- implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
- explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) # type: ignore
- if type(a) is core.DShapedArray else a, True)
- for a in ctx.avals_in]
- wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
- jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
- else:
- jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic(wrapped_fun,
- ctx.avals_in)
- # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
+ jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic(
+ wrapped_fun, ctx.avals_in)
+
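+  # Discharge internal mutable-array refs so the jaxpr can be lowered directly.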
+ if any(isinstance(e, core.InternalMutableArrayEffect) for e in jaxpr.effects):
+ from jax._src.interpreters import pxla # type: ignore
+ closed_jaxpr = core.ClosedJaxpr(jaxpr, consts_for_constvars)
+ closed_jaxpr = pxla._discharge_internal_refs(closed_jaxpr)
+ jaxpr, consts_for_constvars = closed_jaxpr.jaxpr, closed_jaxpr.consts
+
+ # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
if ctx.platforms is not None:
sub_context = ctx.module_context.replace(platforms=ctx.platforms)
@@ -2634,11 +2613,11 @@ def wrap_compute_type_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> No
"_xla_stream_annotation": ir.StringAttr.get(stream),
"inlineable": ir.StringAttr.get("false"),
}
- op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
+ op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
else:
dict_attr = {"_xla_compute_type": ir.StringAttr.get(
map_compute_type(ctx.jaxpr_eqn_ctx.compute_type))}
- op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
+ op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
def wrap_xla_metadata_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> None:
@@ -2683,7 +2662,7 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
out = hlo.broadcast_in_dim(
aval_to_ir_type(aval_out), op,
dense_int_array(broadcast_dimensions))
- wrap_compute_type_in_place(ctx, out.owner)
+ wrap_compute_type_in_place(ctx, _get_owner(out))
return out
def multi_broadcast_in_dim(ctx: LoweringRuleContext,
@@ -2699,9 +2678,15 @@ def multi_broadcast_in_dim(ctx: LoweringRuleContext,
out_aval = core.ShapedArray(
out_shape, op_aval.dtype, sharding=out_sharding) # type: ignore
if core.definitely_equal_shape(op_aval_shape, out_shape):
- out.append(op if op_aval_sharding == out_sharding else
- lower_with_sharding_in_types(ctx, op, out_aval))
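+      # Operands whose sharding spec has unreduced/reduced axes are appended unchanged.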
+ if op_aval_sharding.spec.unreduced or op_aval_sharding.spec.reduced:
+ out.append(op)
+ elif op_aval_sharding == out_sharding:
+ out.append(op)
+ else:
+ out.append(lower_with_sharding_in_types(ctx, op, out_aval))
else:
+ if op_aval_sharding.spec.unreduced or op_aval_sharding.spec.reduced:
+ raise NotImplementedError()
assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape)
broadcast_dimensions = list(range(len(out_shape) - len(op_aval_shape), len(out_shape)))
b_out = broadcast_in_dim(
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index a3dbfc55714f..7fcd9f3645d8 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -16,7 +16,7 @@
from __future__ import annotations
from collections import namedtuple
-from collections.abc import Callable, Sequence, Hashable
+from collections.abc import Callable, Sequence
import contextlib
from dataclasses import dataclass
from functools import partial
@@ -40,12 +40,10 @@
from jax._src.core import (
Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue,
ClosedJaxpr, new_jaxpr_eqn, Var, DropVar, Atom, JaxprEqn, Primitive,
- ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx,
- OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext, typeof)
+ mapped_aval, unmapped_aval, get_referent, JaxprEqnContext, typeof)
from jax._src.source_info_util import SourceInfo
from jax._src.state.types import AbstractRef, ReadEffect
-from jax._src.tree_util import (PyTreeDef, treedef_tuple, register_static,
- tree_flatten, tree_unflatten)
+from jax._src.tree_util import PyTreeDef, treedef_tuple, FlatTree
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache,
@@ -64,42 +62,6 @@ def identity(x): return x
AttrKind = Any
PyTree = Any
-def _update_annotation_known(
- f: lu.WrappedFun,
- orig_type: InputType | None,
- in_knowns: list[bool]
- ) -> lu.WrappedFun:
- if orig_type is None: return f
- # orig_type might contain DBIdx, but we're tossing out some args so we have to
- # re-index. moreover some of the implicit args may not be needed anymore.
- # so we basically just re-infer the lambda input type
- if (all(e for _, e in orig_type) and
- not any(type(d) is DBIdx for a, _ in orig_type for d in a.shape
- if type(a) is DShapedArray)):
- new_type = [ty for ty, known in zip(orig_type, in_knowns) if known]
- return lu.annotate(f, tuple(new_type))
-
- # Replace DBIdx with names, prune down to explicit only.
- class Name:
- def __init__(self, a): self.a = a
- names = [Name(a) for a, _ in orig_type]
- avals = [a.update(shape=tuple(names[d.val] if type(d) is DBIdx else d
- for d in a.shape))
- if type(a) is DShapedArray else a for a, e in orig_type if e]
- avals = [a for a, known in zip(avals, in_knowns) if known]
- # Figure out the implicit part: names which aren't explicit and known.
- expl_names = [o for o, (_, e) in zip(names, orig_type) if e]
- expl_names = [o for o, k in zip(expl_names, in_knowns) if k]
- expl_names_ = set(expl_names)
- impl_names = {d for a in avals if type(a) is DShapedArray for d in a.shape
- if type(d) is Name and d not in expl_names_}
- impl_part = [(n.a, False) for n in impl_names] # type: ignore
- # Figure out the explicit part: known explicit avals, replacing names w/ dbidx
- name_map = {n: DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))}
- expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape))
- if type(a) is DShapedArray else a, True) for a in avals]
- return lu.annotate(f, (*impl_part, *expl_part))
-
class PartialVal(tuple):
"""Partial value: either a known value or an unknown (abstract) value.
@@ -187,14 +149,6 @@ def new_arg(self, pval: PartialVal) -> JaxprTracer:
# known inputs (if it needs them, then they get passed through residuals).
if const is None:
aval = pval.get_aval()
- if type(aval) is DShapedArray:
- # TODO(dougalm): Fix the type error and remove the pytype pragmas.
- # pytype: disable=attribute-error
- shape = [self.new_instantiated_const(d)
- if isinstance(d, Tracer) and d._trace.level < self.level else d
- for d in aval.shape]
- # pytype: enable=attribute-error
- aval = aval.update(shape=tuple(shape))
return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding())
else:
return self.new_const(const)
@@ -282,27 +236,12 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
const_params = update_params(params, in_knowns, 0)
# Run the call, getting known out vals and aux data used for staged-out call
- fun_and_args = (_update_annotation_known(f_, f.in_type, in_knowns),) + tuple(in_consts)
+ fun_and_args = (f_,) + tuple(in_consts)
out = primitive.bind_with_trace(self.parent_trace, fun_and_args, const_params)
fwds, out_knowns, out_type, jaxpr, env = aux()
# Split apart known outputs from the original call and non-fwded residuals.
out_consts, non_fwd_res = split_list(out, [sum(out_knowns)])
-
- # Form the complete list of residuals by forwarding some inputs.
- if config.dynamic_shapes.value:
- # With dynamic shapes, we may need to forward implicit arguments.
- assert f.in_type is not None, "f must be annotated with lu.annotate()"
- in_consts_, in_knowns_ = iter(in_consts), iter(in_knowns)
- in_consts_full = [None] * len(f.in_type)
- for idx, (aval, explicit) in enumerate(f.in_type):
- if explicit and next(in_knowns_):
- c = in_consts_full[idx] = next(in_consts_)
- if aval.shape:
- for d1, d2 in zip(aval.shape, c.shape):
- if type(d1) is DBIdx:
- in_consts_full[d1.val] = d2
- else:
- in_consts_full = in_consts
+ in_consts_full = in_consts
res = subs_list(fwds, in_consts_full, non_fwd_res)
# Create the input tracers for the staged-out (unknown-value) call.
@@ -317,19 +256,8 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params):
staged_params = dict(params, call_jaxpr=new_jaxpr)
staged_params = update_params(staged_params, map(op.not_, in_knowns),
num_new_args)
- # The outputs of the staged-out call are Tracers with the new eqn as recipe.
- if config.dynamic_shapes.value:
- # With dynamic shapes, we may need to substitute Tracers into avals.
- out_tracers = []
- for aval, _ in out_type:
- if type(aval) is DShapedArray:
- shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val]
- if type(d) is InDBIdx else d for d in aval.shape]
- aval = aval.update(shape=tuple(shape))
- out_tracers.append(JaxprTracer(self, PartialVal.unknown(aval), None))
- else:
- out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
- for a in out_type]
+ out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
+ for a in out_type]
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers),
@@ -568,8 +496,6 @@ def parents(self) -> Sequence[JaxprTracer]:
if isinstance(self.recipe, JaxprEqnRecipe):
# TODO broadcast_in_dim can create a new tracer...
return self.recipe.in_tracers
- elif isinstance(self.aval, DShapedArray):
- return [d for d in self.aval.shape if isinstance(d, JaxprTracer)]
else:
return []
@@ -814,19 +740,11 @@ def get_atom(t: JaxprTracer) -> Atom:
def newvar(t: JaxprTracer | None) -> Var:
assert t is not None
- var = gensym(type_substitute(t.aval))
+ var = gensym(t.aval)
var_ = t_to_var.setdefault(id(t), var)
assert var is var_
return var
- def type_substitute(aval: AbstractValue) -> AbstractValue:
- if isinstance(aval, DShapedArray):
- # Replace any Tracers in aval.shape with Vars or Literal values
- shape = [get_atom(d) if type(d) is JaxprTracer else d for d in aval.shape]
- shape = [d.val if type(d) is Literal else d for d in shape]
- aval = aval.update(shape=tuple(shape))
- return aval
-
processed_eqn_ids = set()
eqns: list[core.JaxprEqn] = []
@@ -843,7 +761,7 @@ def sort_key(t):
# TODO broadcast_in_dim can create a new tracer, not present in parents
if r.eqn_id not in processed_eqn_ids:
in_atoms = map(get_atom, r.in_tracers)
- outvars = [DropVar(type_substitute(a)) if rf() is None else newvar(rf())
+ outvars = [DropVar(a) if rf() is None else newvar(rf())
for a, rf in zip(r.out_avals, r.out_tracer_refs)]
eqns.append(new_jaxpr_eqn(in_atoms, outvars, r.primitive, r.params,
r.effects, r.source_info, r.ctx))
@@ -884,6 +802,11 @@ def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr:
constvars, envvars = partition_list(which, jaxpr.constvars)
return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars])
+@weakref_lru_cache
+def separate_consts(jaxpr: ClosedJaxpr) -> tuple[ClosedJaxpr, list[Any]]:
+ """Moves the constvars to the start of invars and returns the consts explicitly."""
+ return ClosedJaxpr(convert_constvars_jaxpr(jaxpr.jaxpr), []), jaxpr.consts
+
@weakref_lru_cache
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
@@ -1127,7 +1050,8 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom:
return x
def has_effects(effects) -> bool:
- return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
+ not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect)
+ return any(not isinstance(e, not_really_effects) for e in effects)
known_eqns, staged_eqns = [], []
foreach(write, in_unknowns, in_inst, jaxpr.invars)
@@ -1682,7 +1606,7 @@ def __init__(self, trace: DynamicJaxprTrace,
self.parent = parent
def _short_repr(self):
- return f"JitTracer<{self.aval}>"
+ return f"JitTracer({self.aval})"
def cur_qdd(self):
return self.mutable_qdd.cur_val
@@ -1869,25 +1793,7 @@ def to_jaxpr(
jaxpr = Jaxpr(constvars, self.invars, outvars, eqns, effs, debug_info, is_high)
return jaxpr, list(constvals)
- def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
- debug_info: core.DebugInfo):
- eqns = self.get_eqns()
- outvars = [t.val for t in out_tracers]
- constvars, constvals = unzip2(self.constvar_to_val.copy().items())
- constvals = [c.canonical for c in constvals]
- constvars, constvals = _drop_unused_vars(constvars, constvals, eqns, outvars)
- effs = make_jaxpr_effects(constvars, self.invars, outvars, eqns)
- jaxpr = Jaxpr(constvars, self.invars, outvars, eqns, effs, debug_info)
- jaxpr, out_type = _add_implicit_outputs(jaxpr)
- config.enable_checks.value and core.check_jaxpr(jaxpr)
- return jaxpr, out_type, constvals
-
def newvar(self, aval):
- if isinstance(aval, DShapedArray):
- # this aval may have tracers in it, so we replace those with variables
- new_shape = [d.val if isinstance(d, Tracer) else d for d in aval.shape]
- new_shape = [d.val if isinstance(d, Literal) else d for d in new_shape]
- aval = aval.update(shape=tuple(new_shape))
if isinstance(aval, core.AvalQDD):
return self.gensym(aval.aval, initial_qdd=aval.qdd)
else:
@@ -1932,8 +1838,6 @@ def vars(atom: Atom) -> list[Var]:
if isinstance(atom, Literal):
return []
aval = atom.aval
- if isinstance(aval, DShapedArray):
- return [atom] + [d for d in aval.shape if isinstance(d, Var)]
return [atom]
used: set[Var] = {v for atom in outvars for v in vars(atom)}
for eqn in eqns[::-1]:
@@ -2067,7 +1971,6 @@ def new_const(self, c, source_info: SourceInfo,
if aval.has_qdd:
with core.set_current_trace(self.parent_trace or core.eval_trace):
aval = core.AvalQDD(aval, core.cur_qdd(c)) # type: ignore
- aval = self._lift_tracers_in_aval(aval, source_info)
tracer = self._new_const(aval, c, source_info)
return tracer
@@ -2104,14 +2007,6 @@ def get_const(self, tracer) -> Any:
const = const.canonical
return const
- def _lift_tracers_in_aval(self, aval, source_info: SourceInfo):
- if (not isinstance(aval, DShapedArray) or
- not any(isinstance(d, Tracer) for d in aval.shape)):
- return aval
- shape = [self.to_jaxpr_tracer(d, source_info) if isinstance(d, Tracer) else d
- for d in aval.shape]
- return aval.update(shape=tuple(shape))
-
def cur_qdd(self, x):
source_info = source_info_util.current()
return self.to_jaxpr_tracer(x, source_info=source_info).mutable_qdd.cur_val
@@ -2139,8 +2034,8 @@ def default_process_primitive(self, primitive, tracers, params,
# TODO(mattjj,dougalm): clean up how we check for new-style hi primitives
if primitive is call_hi_primitive_p:
out_avals, effs = params['prim'].out_avals_flat, set() # TODO effs
- elif (primitive.name == "custom_lin" or config.dynamic_shapes.value or
- primitive.is_effectful and primitive.is_effectful(params)):
+ elif (primitive.name in ("custom_lin", "call_hi_primitive_linearized") or
+ primitive.is_effectful and primitive.is_effectful(params)):
out_avals, effs = primitive.abstract_eval(*aval_qdds, **params)
else:
try:
@@ -2176,37 +2071,33 @@ def default_process_primitive(self, primitive, tracers, params,
self.frame.add_eqn(eqn)
return out_tracers if primitive.multiple_results else out_tracers.pop()
- def process_call(self, call_primitive, f: lu.WrappedFun, explicit_tracers,
+ def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers,
params):
source_info = source_info_util.current()
to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info)
- in_type = (tuple((get_aval(t), True) for t in explicit_tracers)
- if f.in_type is None else f.in_type)
+ in_type = (tuple(get_aval(t) for t in in_tracers) if f.in_type is None
+ else f.in_type)
f.in_type = None
assert in_type is not None
- implicit_tracers = _extract_implicit_args(self, in_type, explicit_tracers,
- source_info)
- in_tracers = map(to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers])
+ in_tracers = map(to_jaxpr_tracer, in_tracers)
# TODO(mattjj): check in_tracers are consistent with f.in_type annotation
- jaxpr, out_type, consts = _cached_trace_to_jaxpr(f, in_type)
+ jaxpr, out_avals, consts = _cached_trace_to_jaxpr(f, in_type)
if params.get('inline', False):
return core.eval_jaxpr(jaxpr, consts, *in_tracers,
propagate_source_info=False)
- out_avals = [aval for aval, _ in out_type]
new_jaxpr = convert_constvars_jaxpr(jaxpr)
if isinstance(call_primitive, core.ClosedCallPrimitive):
new_jaxpr = close_jaxpr(new_jaxpr) # type: ignore
new_params = dict(params, call_jaxpr=new_jaxpr)
update_params = call_param_updaters.get(call_primitive)
if update_params:
- new_params = update_params(new_params, [True] * len(explicit_tracers),
- len(consts) + len(implicit_tracers))
+ new_params = update_params(new_params, [True] * len(in_tracers),
+ len(consts))
const_tracers = map(to_jaxpr_tracer, consts)
- out_tracers = self.emit_eqn(
+ return self.emit_eqn(
[*const_tracers, *in_tracers], out_avals, call_primitive,
new_params, new_params['call_jaxpr'].effects, source_info=source_info)
- return [t for t, (_, keep) in zip(out_tracers, out_type) if keep]
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
source_info = source_info_util.current()
@@ -2245,6 +2136,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
def process_custom_jvp_call(self, prim, fun: lu.WrappedFun,
jvp: lu.WrappedFun, tracers,
symbolic_zeros: bool):
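+    # With eager constant folding on and no Tracer inputs, evaluate eagerly rather than staging out.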
+ if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers):
+ return prim.bind_with_trace(core.eval_trace, (fun, jvp, *tracers),
+ dict(symbolic_zeros=symbolic_zeros))
source_info = source_info_util.current()
to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info)
tracers = map(to_jaxpr_tracer, tracers)
@@ -2279,6 +2173,9 @@ def process_custom_vjp_call(self, prim: core.Primitive,
fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers,
out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]],
symbolic_zeros: bool):
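+    # Same eager-constant-folding shortcut as in process_custom_jvp_call above.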
+ if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers):
+ return prim.bind_with_trace(core.eval_trace, (fun, fwd, bwd, *tracers),
+ dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros))
source_info = source_info_util.current()
to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info)
tracers = map(to_jaxpr_tracer, tracers)
@@ -2354,10 +2251,9 @@ def to_jaxpr(self, out_tracers: Sequence[Tracer],
return self.frame.to_jaxpr(self, out_tracers, debug_info, source_info)
-
@lu.cache
def _cached_trace_to_jaxpr(f, in_type):
- jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(lu.annotate(f, in_type))
+ jaxpr, out_type, consts = trace_to_jaxpr_dynamic(lu.annotate(f, in_type), in_type)
return jaxpr, out_type, consts
@@ -2397,35 +2293,34 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
@weakref_lru_cache
def trace_to_jaxpr(
fun: Callable,
- in_tree: PyTreeDef,
- in_avals_flat: Sequence[AbstractValue | core.AvalQDD],
+ in_avals: FlatTree, # (args, kwargs) pair
debug_info: core.DebugInfo
-) -> tuple[Jaxpr, PyTreeDef, list[Any]]:
- config.enable_checks.value and debug_info.assert_arg_names(len(in_avals_flat))
+) -> tuple[ClosedJaxpr, FlatTree]:
+ config.enable_checks.value and debug_info.assert_arg_names(len(in_avals))
parent_trace = core.trace_ctx.trace
trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace)
# Name stacks are reset because the name stacks on jaxpr equations should be
# rooted at the enclosing jaxpr.
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
source_info = source_info_util.current()
- in_tracers_flat = map(partial(trace.new_arg, source_info=source_info),
- in_avals_flat)
+ in_tracers = in_avals.map(partial(trace.new_arg, source_info=source_info))
with core.set_current_trace(trace):
- in_tracers = tree_unflatten(in_tree, in_tracers_flat)
- ans = fun(*in_tracers)
- debug_info = debug_info.set_result_paths(ans)
- ans_flat, out_tree = tree_flatten(ans)
-
- _check_returned_jaxtypes(debug_info, ans_flat)
- out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans_flat)
- _check_no_returned_refs(debug_info, out_tracers)
- jaxpr, consts = trace.frame.to_jaxpr(trace, out_tracers, debug_info,
+ args, kwargs = in_tracers.unflatten()
+ ans_pytree = fun(*args, **kwargs)
+ debug_info = debug_info.set_result_paths(ans_pytree)
+ ans = FlatTree.flatten(ans_pytree)
+ del ans_pytree, args, kwargs
+
+ _check_returned_jaxtypes(debug_info, list(ans))
+ out_tracers = ans.map(partial(trace.to_jaxpr_tracer, source_info=source_info))
+ out_avals = out_tracers.map(lambda t: t.aval)
+ _check_no_returned_refs(debug_info, list(out_tracers))
+ jaxpr, consts = trace.frame.to_jaxpr(trace, list(out_tracers), debug_info,
source_info)
- del trace, fun, in_tracers_flat, in_tracers, out_tracers, ans, ans_flat
+ del trace, fun, in_tracers, out_tracers, ans
config.enable_checks.value and core.check_jaxpr(jaxpr)
- return jaxpr, out_tree, consts
-
+ return ClosedJaxpr(jaxpr, consts), out_avals
# TODO(dougalm): remove in favor of `trace_to_jaxpr`
@profiler.annotate_function
@@ -2446,8 +2341,7 @@ def trace_to_jaxpr_dynamic(
# rooted at the enclosing jaxpr.
with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
source_info = source_info_util.current()
- in_tracers = _input_type_to_tracers(
- partial(trace.new_arg, source_info=source_info), in_avals)
+ in_tracers = map(partial(trace.new_arg, source_info=source_info), in_avals)
in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
with core.set_current_trace(trace):
@@ -2505,176 +2399,6 @@ def _check_no_returned_refs(
f"a mutable array reference of type {a.str_short()}{loc}, but "
f"mutable array references cannot be returned.{origin_info}")
-@profiler.annotate_function
-def trace_to_jaxpr_dynamic2(
- fun: lu.WrappedFun,
- ) -> tuple[Jaxpr, OutputType, list[Any]]:
- assert fun.in_type is not None, "fun must be annotated with lu.annotate()"
- config.enable_checks.value and fun.debug_info.assert_arg_names(len(fun.in_type))
-
- parent_trace = core.trace_ctx.trace
- trace = DynamicJaxprTrace(fun.debug_info, parent_trace=parent_trace)
- with core.ensure_no_leaks(trace), source_info_util.reset_name_stack():
- source_info = source_info_util.current()
- in_avals, keep_inputs = unzip2(fun.in_type)
- in_tracers = _input_type_to_tracers(
- partial(trace.new_arg, source_info=source_info), in_avals)
- in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
- with core.set_current_trace(trace):
- ans = fun.call_wrapped(*in_tracers)
- out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans)
- jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info)
- del trace, in_tracers, out_tracers, ans
- return jaxpr
-
-AbstractedAxisName = Hashable
-AbstractedAxesSpec = Union[
- dict[int, AbstractedAxisName],
- tuple[AbstractedAxisName, ...],
-]
-
-@register_static
-class DoesNotExist: ...
-dne_sentinel = DoesNotExist()
-
-
-def infer_lambda_input_type(
- axes_specs: Sequence[AbstractedAxesSpec] | None,
- args: Sequence[Any]
- ) -> InputType:
- ndims = [getattr(get_aval(x), 'ndim', 0) for x in args]
- partial_specs = _canonicalize_specs(ndims, axes_specs)
- specs = _complete_specs(args, partial_specs)
- idxs, implicit_types = _collect_implicit(args, specs)
- implicit_sig = [(ty, False) for ty in implicit_types]
- explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)]
- input_type = (*implicit_sig, *explicit_sig)
- lu._check_input_type(input_type)
- return input_type
-
-def _spec_to_dict(spec: AbstractedAxesSpec) -> dict[int, AbstractedAxisName]:
- if isinstance(spec, tuple):
- return {i: d for i, d in enumerate(spec) if d is not None}
- else:
- return spec
-
-def _canonicalize_specs(
- ndims: Sequence[int], specs: Sequence[AbstractedAxesSpec] | None
- ) -> list[dict[int, AbstractedAxisName]]:
- if specs is None:
- return [{}] * len(ndims)
- else:
- return [_spec_to_dict(s) for n, s in zip(ndims, specs)]
-
-def _complete_specs(
- args: Sequence[Any], partial_specs: list[dict[int, AbstractedAxisName]]
- ) -> list[dict[int, AbstractedAxisName]]:
- # The abstracted axes specification in `partial_specs` is partial in the sense
- # that there could be additional axis abstraction represented in `args` due to
- # Tracers existing in the shapes of elements of `args`. The purpose of this
- # function is to produce a full specification, for each argument mapping any
- # abstracted axis positions to a name, introducing new names as needed for
- # Tracers in axis sizes which don't already correspond to abstracted axis
- # names (with one new name per unique Tracer object id).
-
- # Identify each user-supplied name in partial_specs with a size.
- sizes: dict[AbstractedAxisName, int | DynamicJaxprTracer] = {}
- for x, spec in zip(args, partial_specs):
- for i, name in spec.items():
- d = sizes.setdefault(name, x.shape[i])
- if d is not x.shape[i] and d != x.shape[i]:
- raise TypeError(f"Provided size {d} for {name} does not match prior associated name for {name} : {x.shape[i]}")
-
- # Introduce new names as needed for Tracers in shapes.
- named_tracers: dict[TracerId, AbstractedAxisName] = {
- id(d): name for name, d in sizes.items() if isinstance(d, Tracer)}
- specs: list[dict[int, AbstractedAxisName]] = []
- for x, spec in zip(args, partial_specs):
- if isinstance(get_aval(x), DShapedArray):
- spec = dict(spec)
- for i, d in enumerate(x.shape):
- if isinstance(d, Tracer):
- spec[i] = named_tracers.get(id(d), TracerAsName(d))
- specs.append(spec)
-
- # Assert that `specs` is now complete in the sense that there are no Tracers
- # which don't correspond to an AbstractedAxisName.
- assert all(not spec or not any(isinstance(d, Tracer) and i not in spec
- for i, d in enumerate(x.shape))
- for x, spec in zip(args, specs))
- return specs
-
-
-def _collect_implicit(
- args: Sequence[Any], specs: list[dict[int, AbstractedAxisName]]
- ) -> tuple[dict[AbstractedAxisName, DBIdx], list[AbstractValue]]:
- # Given an explicit argument list and a specification of abstracted axes, we
- # want to produce an InputType by identifying AbstractedAxisNames with DBIdxs
- # and figuring out which AbstractedAxisNames correspond to implicit arguments.
-
- idxs: dict[AbstractedAxisName, DBIdx] = {}
- implicit_types: list[AbstractValue] = []
- explicit_tracers: dict[TracerId, int] = {}
- counter = it.count()
-
- # Add implicit arguments to idxs.
- for explicit_idx, (x, spec) in enumerate(zip(args, specs)):
- for i, name in spec.items():
- if name not in idxs and id(x.shape[i]) not in explicit_tracers:
- idxs[name] = DBIdx(next(counter))
- implicit_types.append(get_aval(x.shape[i]))
- if isinstance(x, Tracer):
- explicit_tracers.setdefault(id(x), explicit_idx) # use the first
-
- # Now that we know the implicit args, add explicit args to idxs.
- offset = len(implicit_types)
- for x, spec in zip(args, specs):
- for i, name in spec.items():
- if id(x.shape[i]) in explicit_tracers:
- idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])]))
-
- return idxs, implicit_types
-
-def _arg_type(
- idxs: dict[AbstractedAxisName, DBIdx], x: Any,
- spec: dict[int, AbstractedAxisName]
- ) -> AbstractValue:
- # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames.
- aval = get_aval(x) # aval.shape could contain Tracers
- if not spec: return aval
- shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d
- for i, d in enumerate(aval.shape)]
- assert not any(isinstance(d, Tracer) for d in shape)
- return DShapedArray(tuple(shape), aval.dtype, False)
-
-def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]:
- invars = [*jaxpr.constvars, *jaxpr.invars]
- expl_outvars = jaxpr.outvars
-
- # First do a pass to collect implicit outputs, meaning variables which occur
- # in explicit_outvars types but not in invars or to the left in outvars.
- seen: set[Var] = set(invars)
- impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore
- (seen.add(x) or type(x.aval) is DShapedArray) # type: ignore
- for d in x.aval.shape if type(d) is Var and d not in seen]
- outvars = [*impl_outvars, *expl_outvars]
-
- # Now assemble an OutputType by mapping vars in shapes to InDBIdx/OutDBIdx.
- in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)}
- out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars)
- if type(x) is Var}
- out_avals_ = (x.aval for x in outvars)
- out_avals = [a.update(shape=tuple(in_map.get(d, out_map.get(d))
- if type(d) is Var else d for d in a.shape))
- if type(a) is DShapedArray else a for a in out_avals_]
- kept_outs = [False] * len(impl_outvars) + [True] * len(expl_outvars)
- out_type = tuple(zip(out_avals, kept_outs))
-
- new_jaxpr = jaxpr.replace(outvars=outvars)
- config.enable_checks.value and core.check_jaxpr(jaxpr)
- return new_jaxpr, out_type
-
-
class TracerAsName:
ref: Any
def __init__(self, tracer):
@@ -2684,155 +2408,9 @@ def __eq__(self, other):
def __hash__(self):
return id(self.ref)
-def _extract_implicit_args(
- trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]],
- explicit_tracers: Sequence[DynamicJaxprTracer], source_info: SourceInfo,
- ) -> Sequence[DynamicJaxprTracer]:
- # First, construct a list to represent the full argument list, leaving the
- # implicit arguments as Nones for now.
- explicit_tracers_ = iter(explicit_tracers)
- tracers = [next(explicit_tracers_) if expl else None for _, expl in in_type]
- assert next(explicit_tracers_, None) is None
- del explicit_tracers_
-
- # Next, populate the implicit arguments using DBIdxs in in_type.
- for i, (aval, explicit) in enumerate(in_type):
- if not explicit or not isinstance(aval, DShapedArray):
- continue # can't populate an implicit argument
- tracer = tracers[i]
- assert tracer is not None
- for d1, d2 in zip(aval.shape, tracer.aval.shape):
- if isinstance(d1, DBIdx):
- if tracers[d1.val] is None:
- tracers[d1.val] = trace.to_jaxpr_tracer(d2, source_info)
- assert tracers[d1.val] is trace.to_jaxpr_tracer(d2, source_info)
- assert all(t is not None for t in tracers)
- return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore
-
-def _input_type_to_tracers(
- new_arg: Callable[[AbstractValue | core.AvalQDD], Tracer],
- in_avals: Sequence[AbstractValue | core.AvalQDD]
- ) -> Sequence[Tracer]:
- # Create input Tracers given input AbstractValues, each of which can contain
- # DeBruijn indices which refer to positions in the input argument list. That
- # is, each element `a` of `in_avals` can have DBIdx instances in its shape,
- # which must refer to positions left of `a`'s.
- in_tracers: list[Tracer] = []
-
- def _substitute_tracers_in_aval(a):
- if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape):
- shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape]
- return a.update(shape=tuple(shape))
- return a
-
- for a in in_avals:
- in_tracers.append(new_arg(_substitute_tracers_in_aval(a)))
- return in_tracers
-
Const = Any
Val = Any
-def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
- ) -> tuple[Jaxpr, list[Const]]:
- bounds = {v: v.aval.dtype.bound for v in jaxpr.invars
- if isinstance(v.aval, (core.ShapedArray, core.DShapedArray)) and
- type(v.aval.dtype) is core.bint and not v.aval.shape}
- idxs = {v: DBIdx(i) for i, v in enumerate(jaxpr.invars)}
-
- def substitute(aval: AbstractValue) -> AbstractValue:
- if (isinstance(aval, (core.ShapedArray, core.DShapedArray))
- and type(aval.dtype) is core.bint and not aval.shape):
- return ShapedArray((), dtypes.scalar_type_to_dtype(int))
- elif isinstance(aval, DShapedArray):
- shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape] # type: ignore
- typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray
- return typ(tuple(shape), aval.dtype, aval.weak_type)
- else:
- return aval
-
- in_avals = [substitute(v.aval) for v in jaxpr.invars]
- eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts),
- debug_info=jaxpr.debug_info)
- padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals)
- return padded_jaxpr, padded_consts
-
-class BoundedAxisSize(NamedTuple):
- val: int | DynamicJaxprTracer
- bound: int
-
-def _eval_jaxpr_padded(
- jaxpr: Jaxpr, consts: Sequence[Const], *args: DynamicJaxprTracer
- ) -> list[Const | DynamicJaxprTracer]:
- env: dict[Var, Val] = {}
-
- def read(x):
- return x.val if type(x) is Literal else env[x]
-
- def write(v, val) -> None:
- env[v] = val
-
- foreach(write, jaxpr.constvars, consts)
- foreach(write, jaxpr.invars, args)
- last_used = core.last_used(jaxpr)
- for eqn in jaxpr.eqns:
- in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
- out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
- rule = padding_rules[eqn.primitive]
- outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
- foreach(write, eqn.outvars, outs)
- core.clean_up_dead_vars(eqn, env, last_used)
- return map(read, jaxpr.outvars)
-
-def _substitute_axis_sizes(env: dict, aval: AbstractValue) -> AbstractValue:
- if isinstance(aval, DShapedArray):
- shp = []
- for d in aval.shape:
- if isinstance(d, core.DArray):
- assert not d.shape and type(d.dtype) is core.bint
- shp.append(BoundedAxisSize(int(d._data), int(d.dtype.bound)))
- elif (type(d) is core.Var and isinstance(d.aval, core.DShapedArray) and
- type(d.aval.dtype) is core.bint):
- assert not d.aval.shape
- shp.append(BoundedAxisSize(env[d], d.aval.dtype.bound))
- else:
- shp.append(env.get(d, d))
- return DShapedArray(tuple(shp), aval.dtype, aval.weak_type)
- else:
- return aval
-
-def _is_bint_axis_size(d: int | core.DArray | core.Var) -> bool:
- if isinstance(d, core.DArray):
- assert not d.shape # pytype: disable=attribute-error
- return type(d.dtype) is core.bint # pytype: disable=attribute-error
- elif isinstance(d, core.Var):
- return (isinstance(d.aval, core.DShapedArray) and # pytype: disable=attribute-error
- type(d.aval.dtype) is core.bint) # pytype: disable=attribute-error
- return False
-
-
-padding_rules: dict[Primitive, Callable] = {}
-
-def def_trivial_padding(prim: Primitive) -> None:
- if prim.multiple_results:
- padding_rules[prim] = partial(_trivial_padding_rule_multi, prim)
- else:
- padding_rules[prim] = partial(_trivial_padding_rule, prim)
-
-def _trivial_padding_rule(prim, _, __, *args, **params):
- return [prim.bind(*args, **params)]
-
-def _trivial_padding_rule_multi(prim, _, __, *args, **params):
- return prim.bind(*args, **params)
-
-def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
- if call_jaxpr.constvars: raise NotImplementedError
- padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ())
- if padded_consts: raise NotImplementedError
- new_params = dict(params, call_jaxpr=padded_jaxpr)
- subfuns, bind_params = prim.get_bind_params(new_params)
- return prim.bind(*subfuns, *args, **bind_params)
-
-
def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
if instantiate:
return trace.instantiate_const(tracer)
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index f12a4b518d0b..911d1c22af96 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -51,7 +51,6 @@
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
-from jax._src.core import DShapedArray
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -59,6 +58,7 @@
from jax._src.interpreters import mlir
from jax._src.layout import Layout, AutoLayout, Format
from jax._src.lib import _jax
+from jax._src.lib import jaxlib_extension_version
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
@@ -67,10 +67,8 @@
from jax._src.mesh import (AbstractMesh, Mesh, get_abstract_mesh,
get_concrete_mesh)
from jax._src.sharding_impls import (
- ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue,
- get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources,
- SingleDeviceSharding, GSPMDSharding, NamedSharding,
- PartitionSpec as P)
+ ArrayMapping, AUTO, UnspecifiedValue, array_mapping_to_axis_resources,
+ SingleDeviceSharding, GSPMDSharding, NamedSharding, PartitionSpec as P)
from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
tuple_update, tuple_delete, distributed_debug_log,
unzip2, HashableFunction, weakref_lru_cache,
@@ -230,11 +228,6 @@ def _shard_typed_scalar(xs, shardings, layouts, copy_semantics):
for _t in literals.typed_scalar_types:
shard_arg_handlers[_t] = _shard_typed_scalar
-def _shard_darray(xs, shardings, layouts, copy_semantics):
- bufs = [x._data for x in xs]
- return shard_args(shardings, layouts, copy_semantics, bufs)
-shard_arg_handlers[core.DArray] = _shard_darray
-
def _shard_mutable_array(xs, shardings, layouts, copy_semantics):
bufs = [x._refs._buf for x in xs]
return shard_args(shardings, layouts, copy_semantics, bufs)
@@ -304,7 +297,7 @@ def local_aval_to_result_handler(
raise TypeError(
f"No pxla_result_handler for type: {type(aval)}") from err
-PxlaResultHandler = Callable[..., Callable[[Any], Any]]
+PxlaResultHandler = Callable[..., xc._xla.ResultHandler]
local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
@@ -420,6 +413,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
platform = xb.get_backend(backend).platform
donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else ()
new_outvals = []
+ assert len(out_axes_src) == len(out_axes)
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
with api.disable_jit(False):
donate_argnums_ = donate_argnums
@@ -1371,20 +1365,29 @@ def __call__(self, *args):
input_bufs = self._add_tokens_to_inputs(input_bufs)
results = self.xla_executable.execute_sharded(input_bufs, with_tokens=True)
- result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
- len(self.ordered_effects))
+ if jaxlib_extension_version >= 391:
+ result_token_bufs = results.consume_with_handlers(
+ [lambda xs: xs] * len(self.ordered_effects), strict=False)
+ else:
+ result_token_bufs = results.disassemble_prefix_into_single_device_arrays(
+ len(self.ordered_effects))
sharded_runtime_token = results.consume_token()
self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
else:
results = self.xla_executable.execute_sharded(input_bufs)
- if dispatch.needs_check_special():
+ if jaxlib_extension_version >= 391 or not dispatch.needs_check_special():
+ handlers = self.out_handler.handlers
+ if dispatch.needs_check_special():
+ special_check = functools.partial(
+ dispatch.check_special_array, self.name)
+ handlers = [h.pre_wrap(special_check) for h in handlers]
+ out = results.consume_with_handlers(handlers)
+ else:
out_arrays = results.disassemble_into_single_device_arrays()
for arrays in out_arrays:
dispatch.check_special(self.name, arrays)
out = self.out_handler(out_arrays)
- else:
- out = results.consume_with_handlers(self.out_handler.handlers)
if (self.pgle_profiler is not None and self.pgle_profiler.is_running()
and len(out) > 0):
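# Illustration only (not part of the patch): the hunk above replaces the
# "disassemble into arrays, check for NaN/Inf, then build outputs" path with
# result handlers that carry the special-value check themselves. A minimal
# pure-Python sketch of that composition idea, assuming a handler is simply a
# callable over per-device buffers; the real pre_wrap lives on the runtime's
# ResultHandler objects.
def pre_wrap_sketch(handler, check):
  def wrapped(bufs):
    for buf in bufs:
      check(buf)            # e.g. raise FloatingPointError on NaN/Inf
    return handler(bufs)
  return wrapped

# Hypothetical usage mirroring the new code path:
#   handlers = [pre_wrap_sketch(h, special_check) for h in handlers]
#   out = results.consume_with_handlers(handlers)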
@@ -1886,7 +1889,7 @@ def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
class SemanticallyEqualShardings:
def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
- avals: tuple[core.AbstractValue]):
+ avals: Sequence[core.AbstractValue]):
gspmd_shardings = [
s if (isinstance(s, (UnspecifiedValue, AUTO)) or
(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)))
@@ -1894,7 +1897,6 @@ def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
for s, a in zip(shardings, avals)]
self._gspmd_shardings = gspmd_shardings
self.shardings = shardings
- self.avals = avals
def __hash__(self):
return hash(tuple(
@@ -2374,7 +2376,14 @@ def lower_sharding_computation(
out_shardings, global_out_avals, device_assignment,
propagated_out_mem_kinds)
- # 2. Build up the HLO
+ global_in_avals = [core.update_aval_with_sharding(a, sh)
+ if isinstance(a, core.ShapedArray) else a
+ for a, sh in zip(global_in_avals, in_shardings)]
+ global_out_avals = [core.update_aval_with_sharding(a, sh)
+ if isinstance(a, core.ShapedArray) else a
+ for a, sh in zip(global_out_avals, out_shardings)]
+
+  ############################ Build up the StableHLO ######################
abstract_mesh = None
if prim_requires_devices:
@@ -2456,7 +2465,7 @@ def _to_logical_sharding(
return None
if isinstance(sharding, AUTO):
return sharding
- elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)):
+ elif isinstance(aval, (ShapedArray, AbstractRef)):
assert isinstance(sharding, JSharding)
return sharding
elif isinstance(aval, core.AbstractToken):
@@ -3414,7 +3423,3 @@ def batch_spec(spec, dim, val):
spec += (None,) * too_short
new_partitions = tuple_insert(spec, dim, val) # type: ignore
return PartitionSpec(*new_partitions)
-
-def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
- pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
- return _get_array_mapping(pspec)
diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py
index 44ee94e14ca2..79488d89e3b7 100644
--- a/jax/_src/lax/control_flow/__init__.py
+++ b/jax/_src/lax/control_flow/__init__.py
@@ -50,11 +50,8 @@
# Private utilities used elsewhere in JAX
# TODO(sharadmv): lift them into a more common place
from jax._src.lax.control_flow.common import (
- _initial_style_open_jaxpr as _initial_style_open_jaxpr,
- _initial_style_jaxpr as _initial_style_jaxpr,
- _initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts,
_check_tree_and_avals as _check_tree_and_avals,
-
+ _merge_common_consts as _merge_common_consts,
)
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
from jax._src.lax.lax import optimization_barrier_p as optimization_barrier_p
diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py
index d29746560b46..9518b4484bd9 100644
--- a/jax/_src/lax/control_flow/common.py
+++ b/jax/_src/lax/control_flow/common.py
@@ -15,7 +15,7 @@
from __future__ import annotations
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence
import os
from functools import partial
from typing import Any
@@ -27,7 +27,7 @@
from jax._src.util import weakref_lru_cache, safe_map
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (equality_errors_pytreedef, tree_map,
- tree_unflatten, keystr, PyTreeDef)
+ tree_unflatten, keystr)
map, unsafe_map = safe_map, map
@@ -43,78 +43,54 @@ def _typecheck_param(prim, param, name, msg_required, pred):
msg = sep.join([msg, param_str])
raise core.JaxprTypeError(msg)
-# TODO(dougalm): this is a silly wrapper now. Delete it.
-@weakref_lru_cache
-def _initial_style_open_jaxpr(fun: Callable,
- in_tree: PyTreeDef,
- in_avals: Sequence[core.AbstractValue | core.AvalQDD],
- debug_info: core.DebugInfo):
- jaxpr, out_tree, consts = pe.trace_to_jaxpr(fun, in_tree, in_avals, debug_info)
- return jaxpr, consts, out_tree
-
-# TODO(dougalm): Delete. Make `trace_to_jaxpr` do the jaxpr-closing thing instead.
-@weakref_lru_cache
-def _initial_style_jaxpr(fun: Callable,
- in_tree: PyTreeDef,
- in_avals: Sequence[core.AbstractValue],
- debug_info: core.DebugInfo) -> tuple[core.ClosedJaxpr, Sequence[Any], PyTreeDef]:
- jaxpr, consts, out_tree = _initial_style_open_jaxpr(
- fun, in_tree, in_avals, debug_info)
- closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
- return closed_jaxpr, consts, out_tree
-
-def _initial_style_jaxprs_with_common_consts(
- funs: Sequence[Callable],
- in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue | core.AvalQDD],
- debug_infos: Sequence[core.DebugInfo]):
- jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, debug_info)
- for fn, debug_info in zip(funs, debug_infos)]
- if not jaxpr_data: return [], [], []
- jaxprs, all_consts, all_out_trees = zip(*jaxpr_data)
-
+# TODO(dougalm): this seems way too complicated. Why not allow different consts for each
+# branch of a switch?
+def _merge_common_consts(
+ jaxprs: Sequence[core.ClosedJaxpr],
+ all_consts: Sequence[Sequence[Any]]
+ ) -> tuple[Sequence[core.ClosedJaxpr], Sequence[Any]]:
# Jaxprs must share consts, so we concat consts and pad the jaxprs' constvars.
lens = map(len, all_consts)
consts = [c for cs in all_consts for c in cs]
avalqdds = tuple(map(core.cur_aval_qdd, consts))
- jaxprs = [_pad_constvars(jaxpr, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):])
- for i, jaxpr in enumerate(jaxprs)]
+ num_constss = [len(cs) for cs in all_consts]
+ jaxprs = [_pad_constvars(jaxpr, num_consts, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):])
+ for i, (jaxpr, num_consts) in enumerate(zip(jaxprs, num_constss))]
# De-duplicate shared constants.
const_ids = tuple(id(c) for c in consts)
seen = set()
- consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore
- jaxprs = [_dedup_consts(jaxpr, const_ids) for jaxpr in jaxprs]
-
- closed_jaxprs = [pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
- for jaxpr in jaxprs]
- return closed_jaxprs, consts, all_out_trees
+ dd_consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore
+ jaxprs = [_dedup_consts(jaxpr, len(consts), const_ids) for jaxpr in jaxprs]
+ return jaxprs, dd_consts
@weakref_lru_cache
-def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...],
- right: tuple[core.AbstractValue, ...]) -> core.Jaxpr:
+def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int,
+ left: tuple[core.AvalQDD, ...],
+ right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr:
def make_var(aq):
return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd)
- constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)]
- effs = pe._renumber_effects([*constvars, *jaxpr.invars],
- [*jaxpr.constvars, *jaxpr.invars], jaxpr.effects)
- jaxpr = jaxpr.replace(constvars=constvars, effects=effs)
- config.enable_checks.value and core.check_jaxpr(jaxpr)
+ invars = [*map(make_var, left), *jaxpr.invars[:num_consts],
+ *map(make_var, right), *jaxpr.invars[num_consts:]]
+ effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects)
+ jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs))
+ config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr)
return jaxpr
@weakref_lru_cache
-def _dedup_consts(jaxpr, const_ids):
+def _dedup_consts(jaxpr, num_consts, const_ids):
newvars = {}
canonicalize = {v: newvars.setdefault(constid, v)
- for constid, v in zip(const_ids, jaxpr.constvars)}
+ for constid, v in zip(const_ids, jaxpr.invars[:num_consts])}
eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var)
else x for x in e.invars]) for e in jaxpr.eqns]
outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x
for x in jaxpr.outvars]
- constvars = list(newvars.values())
- effs = pe._renumber_effects(
- [*constvars, *jaxpr.invars],
- [*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects)
- jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars,
- effects=effs)
+ invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]]
+ effs = pe._renumber_effects(invars,
+ [*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]],
+ jaxpr.effects)
+ jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars,
+ effects=effs))
config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr
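# Illustration only: what the const merging above amounts to, sketched on
# plain lists rather than jaxprs. Every branch is padded so that it accepts
# the full concatenated const list, and duplicate const objects (by id) are
# then collapsed to a single entry, as in _merge_common_consts above.
def merge_common_consts_sketch(all_consts):
  consts = [c for cs in all_consts for c in cs]        # concatenate per-branch consts
  seen = set()
  return [c for c in consts
          if id(c) not in seen and not seen.add(id(c))]  # de-duplicate by id

a, b = object(), object()
assert merge_common_consts_sketch([[a], [a, b]]) == [a, b]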
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 668b880ab4dd..a5fc5e0e6232 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -18,14 +18,13 @@
from collections.abc import Callable, Sequence
import functools
from functools import partial
-import inspect
import itertools
import operator
from typing import Any, TypeVar
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_flatten_with_path, keystr,
- equality_errors_pytreedef)
+ equality_errors_pytreedef, FlatTree)
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
@@ -53,7 +52,7 @@
import numpy as np
from jax._src.lax.control_flow.common import (
- _avals_short, _typecheck_param, _initial_style_jaxprs_with_common_consts,
+ _avals_short, _typecheck_param, _merge_common_consts,
_make_closed_jaxpr, _prune_zeros)
map, unsafe_map = safe_map, map
@@ -143,20 +142,23 @@ def _switch_internal(
dbgs = [api_util.debug_info("switch", branch, operands, {})
for branch in branches]
- ops, ops_tree = tree_flatten(operands)
- ops_avals = tuple(map(core.get_aval, ops))
+ args = FlatTree.flatten((operands, {}))
+ avals = args.map(core.get_aval)
if config.mutable_array_checks.value:
- api_util.check_no_aliased_ref_args(lambda: dbgs[0], ops_avals, ops)
+ api_util.check_no_aliased_ref_args(lambda: dbgs[0], list(avals), list(args))
+
+ jaxprs_, out_avalss = zip(*[pe.trace_to_jaxpr(branch, avals, dbg)
+ for branch, dbg in zip(branches, dbgs)])
+ jaxprs_, all_consts = zip(*[pe.separate_consts(j) for j in jaxprs_])
+ jaxprs, consts = _merge_common_consts(jaxprs_, all_consts)
- jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
- branches, ops_tree, ops_avals, dbgs)
if config.mutable_array_checks.value:
- api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops)
- for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
+ api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), list(args))
+ for i, (out_avals, jaxpr) in enumerate(zip(out_avalss[1:], jaxprs[1:])):
_check_branch_outputs(
"switch", "branch 0", f"branch{i+1}", branches[0], branches[i+1],
- out_trees[0], out_tree, jaxprs[0].out_avals, jaxpr.out_avals)
+ out_avalss[0], out_avals)
# prune passthrough outputs
fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs]
in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)]
@@ -172,19 +174,19 @@ def _switch_internal(
params = dict(branches=tuple(jaxprs))
if branches_platforms is not None:
params["branches_platforms"] = branches_platforms
- out = cond_p.bind(index, *consts, *ops, **params)
+ out = cond_p.bind(index, *consts, *args, **params)
out_ = iter(out)
- all_inputs = [*consts, *ops]
+ all_inputs = [*consts, *args]
out = [
next(out_) if fwd is None else lax.asarray(all_inputs[fwd])
for fwd in in_fwd
]
assert next(out_, None) is None
- return tree_unflatten(out_trees[0], out)
+ return out_avalss[0].update_from_list(out).unflatten()
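# Illustration only: the control flow above now threads a FlatTree (an
# internal jax._src.tree_util helper) instead of separate (leaves, treedef)
# pairs. The class below is a hypothetical stand-in for just the subset of
# the interface exercised in this patch (flatten/map/update_from_list/
# unflatten), not the real implementation.
import jax

class FlatTreeSketch:
  def __init__(self, vals, tree):
    self.vals, self.tree = vals, tree

  @classmethod
  def flatten(cls, pytree):
    vals, tree = jax.tree_util.tree_flatten(pytree)
    return cls(vals, tree)

  def map(self, f):
    return FlatTreeSketch([f(v) for v in self.vals], self.tree)

  def update_from_list(self, new_vals):
    return FlatTreeSketch(list(new_vals), self.tree)

  def unflatten(self):
    return jax.tree_util.tree_unflatten(self.tree, self.vals)

args = FlatTreeSketch.flatten(((1, {"x": 2}), {}))
assert FlatTreeSketch.flatten(args.unflatten()).vals == [1, 2]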
@partial(api_boundary, repro_api_name="jax_cond")
-def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
+def cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel):
"""Conditionally apply ``true_fun`` or ``false_fun``.
@@ -224,7 +226,12 @@ def cond(pred, true_fun, false_fun, *operands):
pytree (nested Python tuple/list/dict) thereof.
"""
if not (callable(true_fun) and callable(false_fun)):
- raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
+ # try falling back to the old, deprecated version of `cond`
+ if callable(false_fun) and len(operands) == 2 and callable(operands[1]):
+ x_true, f_true, x_false, f_false = true_fun, false_fun, *operands
+ return cond(pred, lambda x, _: f_true(x), lambda _, x: f_false(x), x_true, x_false)
+ else:
+ raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
if operand is not _no_operand_sentinel:
if operands:
raise TypeError("if 'operand' keyword is passed then no positional "
@@ -260,31 +267,33 @@ def cond(pred, true_fun, false_fun, *operands):
else:
return false_fun(*operands)
- ops, ops_tree = tree_flatten(operands)
- ops_avals = tuple(map(core.get_aval, ops))
- ops_avals = tuple(core.AvalQDD(a, cur_qdd(x)) if a.has_qdd # type: ignore
- else a for a, x in zip(ops_avals, ops))
-
-
- dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {})
+ args = FlatTree.flatten((operands, {}))
+ avals = args.map(core.get_aval)
+ avals = avals.map2(
+ lambda a, x: core.AvalQDD(a, cur_qdd(x)) if a.has_qdd else a,
+ args)
+ dbg_true = api_util.debug_info("cond", true_fun, operands, {})
if config.mutable_array_checks.value:
- api_util.check_no_aliased_ref_args(lambda: dbg_true_fun, ops_avals, ops)
- dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {})
- jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
- (true_fun, false_fun), ops_tree, ops_avals,
- [dbg_true_fun, dbg_false_fun])
- true_jaxpr, false_jaxpr = jaxprs
+ api_util.check_no_aliased_ref_args(lambda: dbg_true, list(avals), list(args))
+ dbg_false = api_util.debug_info("cond", false_fun, operands, {})
+
+ true_jaxpr_, out_avals = pe.trace_to_jaxpr(true_fun, avals, dbg_true)
+ true_jaxpr_, true_consts = pe.separate_consts(true_jaxpr_)
+ false_jaxpr_, false_out_avals = pe.trace_to_jaxpr(false_fun, avals, dbg_false)
+ false_jaxpr_, false_consts = pe.separate_consts(false_jaxpr_)
+ (true_jaxpr, false_jaxpr), consts = _merge_common_consts(
+ (true_jaxpr_, false_jaxpr_), (true_consts, false_consts))
if config.mutable_array_checks.value:
- api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops)
+ api_util._check_no_aliased_closed_over_refs(
+ dbg_true, (*true_jaxpr.consts, *consts), list(args))
- out_tree, false_out_tree = out_trees
if any(isinstance(out_aval, AbstractRef) for out_aval in
true_jaxpr.out_avals + false_jaxpr.out_avals):
raise ValueError("Cannot return `Ref`s from `cond`.")
_check_branch_outputs(
- 'cond', 'true_fun', 'false_fun', true_fun, false_fun, out_tree,
- false_out_tree, true_jaxpr.out_avals, false_jaxpr.out_avals)
+ 'cond', 'true_fun', 'false_fun',
+ true_fun, false_fun, out_avals, false_out_avals)
# prune passthrough outputs
true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr)
@@ -304,24 +313,23 @@ def cond(pred, true_fun, false_fun, *operands):
false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects)
true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects)
- out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr))
+ out = cond_p.bind(index, *consts, *args, branches=(false_jaxpr, true_jaxpr))
out_ = iter(out)
- all_inputs = [*consts, *ops]
+ all_inputs = [*consts, *args]
out = [
next(out_) if fwd is None else lax.asarray(all_inputs[fwd])
for fwd in in_fwd
]
assert next(out_, None) is None
- return tree_unflatten(out_tree, out)
+ return out_avals.update_from_list(out).unflatten()
def _check_branch_outputs(
- api_name, name1, name2, f1, f2, out_tree1, out_tree2, out_avals1,
- out_avals2) -> None:
+ api_name, name1, name2, f1, f2, out_avals1, out_avals2) -> None:
info1 = api_util.fun_sourceinfo(f1)
info2 = api_util.fun_sourceinfo(f2)
try:
- outs1 = tree_unflatten(out_tree1, out_avals1)
+ outs1 = out_avals1.unflatten()
except:
paths = [None] * len(out_avals1)
component = lambda _: ''
@@ -330,11 +338,11 @@ def _check_branch_outputs(
paths, _ = unzip2(leaves_and_paths) # type: ignore
component = lambda p: f' at path {keystr(p)}' if p else ''
- if out_tree1 != out_tree2:
+ if out_avals1.tree != out_avals2.tree:
diffs = [f'{name1} output{component(p)} is a {thing1} but '
f'{name2} output{component(p)} is a {thing2}, so {expl}'
for p, thing1, thing2, expl
- in equality_errors_pytreedef(out_tree1, out_tree2)]
+ in equality_errors_pytreedef(out_avals1.tree, out_avals2.tree)]
if len(diffs) == 0:
return # the trees may have different aux data, but structures are same
@@ -399,48 +407,6 @@ def _capitalize(s):
# s.capitalize() converts s[1:] to lowercase which we don't want.
return s[0].capitalize() + s[1:]
-@api_boundary
-@functools.wraps(_cond)
-def cond(*args, **kwargs):
- # detect an attempt to call the former, deprecated cond
- try:
- ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs)
- except TypeError:
- pass
- else:
- assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch
- _, true_operand, true_fun, false_operand, false_fun = ba.args
- if callable(true_operand) and callable(true_fun):
- # treat this as modern cond (with two operands)
- return _cond(*args, **kwargs)
- if callable(true_fun) and callable(false_fun):
- return _cond_with_per_branch_args(*ba.args)
-
- return _cond(*args, **kwargs)
-
-@partial(api_boundary, repro_api_name="jax_cond_with_per_branch_args")
-def _cond_with_per_branch_args(pred,
- true_operand, true_fun: Callable,
- false_operand, false_fun: Callable):
- """Conditionally apply ``true_fun`` or ``false_fun``.
-
- Has equivalent semantics to this Python implementation::
-
- def cond(pred, true_operand, true_fun, false_operand, false_fun):
- if pred:
- return true_fun(true_operand)
- else:
- return false_fun(false_operand)
-
- Pred has to be a scalar type, collection types (list, tuple) are not supported
- """
- if not (callable(true_fun) and callable(false_fun)):
- raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.")
- return _cond(pred,
- lambda op: true_fun(op[0]),
- lambda op: false_fun(op[1]),
- (true_operand, false_operand))
-
def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects:
joined_effects = set()
for b in branches:
@@ -870,49 +836,6 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
return [True, *used_inputs], new_eqn
-def _transpose_cond_jaxpr(jaxpr: core.ClosedJaxpr,
- num_res: int):
- res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
-
- def transposed(*args):
- res, cts_out = split_list(args, [num_res])
- primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
- cts_in = ad.backward_pass(
- jaxpr.jaxpr, False, jaxpr.consts, primals, cts_out)
- _, cts_in = split_list(cts_in, [num_res])
- return map(ad.instantiate_zeros, cts_in)
-
- return _make_closed_jaxpr(lu.wrap_init(transposed,
- debug_info=jaxpr.jaxpr.debug_info),
- res_avals + jaxpr.out_avals)
-
-def _cond_transpose(cts, *args, branches, **params):
- index, *ops = args
- assert type(index) is not ad.UndefinedPrimal
- linear = [type(x) is ad.UndefinedPrimal for x in ops]
- in_avals = branches[0].in_avals
- num_res = len(ops) - sum(linear)
- if any(isinstance(eff, RefEffect) for branch in branches for eff in
- branch.jaxpr.effects):
- raise NotImplementedError("State effect not supported in cond transpose.")
-
- branches_trans = [_transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches]
- lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l]
- assert all(core.typematch(out_aval, lin_in_aval)
- for jaxpr in branches_trans
- for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))
-
- res = ops[:num_res]
- cts = map(ad.instantiate_zeros, cts)
-
- out = cond_p.bind(index, *res, *cts, branches=tuple(branches_trans), **params)
- assert all(map(core.typecheck, lin_in_avals, out))
-
- out_iter = iter(out)
- out = [next(out_iter) if l else None for l in linear]
- assert next(out_iter, None) is None
- return [None] + out
-
def _cond_transpose_fancy(cts_in, index, *args, branches, **params):
assert not isinstance(index, ad.GradAccum)
primals_ctrefs, specs = ad.project_accums(args)
@@ -1021,7 +944,6 @@ def _cond_typecheck(bind_time, *in_atoms, branches, **params):
cond_p.def_impl(partial(dispatch.apply_primitive, cond_p))
cond_p.def_effectful_abstract_eval(_cond_abstract_eval)
ad.primitive_jvps[cond_p] = _cond_jvp
-ad.primitive_transposes[cond_p] = _cond_transpose
ad.primitive_linearizations[cond_p] = _cond_linearize
ad.fancy_transposes[cond_p] = _cond_transpose_fancy
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
@@ -1030,7 +952,6 @@ def _cond_typecheck(bind_time, *in_atoms, branches, **params):
core.custom_typechecks[cond_p] = partial(_cond_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule
-batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule
def _cond_is_high(*_, branches, **__) -> bool:
return any(j.jaxpr.is_high for j in branches)
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index d5e31b0ca0a2..1b0b9f6874be 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -38,7 +38,8 @@
from jax._src import util
from jax._src.api_util import (
check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
-from jax._src.core import ShapedArray, typeof, cur_qdd, ClosedJaxpr
+from jax._src.core import (
+ ShapedArray, typeof, cur_qdd, ClosedJaxpr, AbstractValue)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@@ -50,7 +51,7 @@
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lax.control_flow.common import (
- _avals_short, _initial_style_jaxpr, _prune_zeros, _typecheck_param,
+ _avals_short, _prune_zeros, _typecheck_param,
_make_closed_jaxpr)
from jax._src.lax.other import logaddexp
from jax._src.pjit import auto_axes, PartitionSpec as P, reshard
@@ -66,8 +67,8 @@
split_list_checked, unzip2, weakref_lru_cache, subs_list)
from jax._src import xla_bridge as xb
from jax._src.tree_util import (
- keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten,
- treedef_is_leaf)
+ keystr, tree_flatten, tree_map, tree_unflatten,
+ treedef_is_leaf, FlatTree, tree_leaves_with_path)
import numpy as np
_map = safe_map
@@ -81,29 +82,14 @@
def _stack(arrs: Sequence[Array], axis: int=0) -> Array:
return lax.concatenate([lax.expand_dims(arr, (axis,)) for arr in arrs], dimension=axis)
-def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
- """Promote weakly-typed in_vals to be compatible with out_avals.
-
- Args:
- in_vals : flattened list of input values.
- in_avals : corresponding list of avals.
- out_avals : list of target output avals.
- Returns:
- in_vals_new : flattened list of modified in_vals with no weak types.
- changed : bool; true if in_vals required modification.
- """
- if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
- # Calling function is responsible for catching this.
- return in_vals, False
- weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
- if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
- if not weak_mismatches:
- return in_vals, False
- for i in weak_mismatches:
- new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
- in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
- return in_vals, True
-
+def _promote_weak_typed_input(
+    in_val: Any, in_aval: AbstractValue, out_aval: AbstractValue
+ ) -> tuple[Any, bool]:
+ if getattr(in_aval, 'weak_type', False) and not core.typematch(in_aval, out_aval):
+ new_dtype = dtypes.result_type(in_val, out_aval)
+ return lax.convert_element_type(in_val, new_dtype), True
+ else:
+ return in_val, False
### scan
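# Illustration only: why the weak-type promotion helper above exists. A
# Python-scalar init is weakly typed, the body produces a strongly typed
# float32 carry, so scan converts the init up front instead of reporting a
# carry-type mismatch.
import jax
import jax.numpy as jnp

def body(carry, x):
  return carry + x, carry        # carry becomes float32 after the first step

carry, ys = jax.lax.scan(body, 0, jnp.arange(4.0, dtype=jnp.float32))
assert carry.dtype == jnp.float32   # the weak-typed init 0 was promoted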
@@ -215,85 +201,44 @@ def scan(f, init, xs, length=None):
"""
if not callable(f):
raise TypeError("lax.scan: f argument should be a callable.")
- xs_flat, xs_tree = tree_flatten(xs)
- try:
- lengths = [x.shape[0] for x in xs_flat]
- except AttributeError as err:
- msg = "scan got value with no leading axis to scan over: {}."
- raise ValueError(
- msg.format(', '.join(str(x) for x in xs_flat
- if not hasattr(x, 'shape')))) from err
+ dbg_body = api_util.debug_info("scan", f, (init, xs), {})
+ init = FlatTree.flatten(init)
+ xs = FlatTree.flatten(xs)
+ args = FlatTree.pack((init, xs))
- xs_avals = [core.get_aval(x) for x in xs_flat]
+ args_avals = args.map(core.get_aval)
+ init_avals, xs_avals = args_avals.unpack()
- if not all(a.sharding.spec[0] is None for a in xs_avals):
- raise ValueError('0th dimension of all xs should be replicated. Got '
- f'{", ".join(str(a.sharding.spec) for a in xs_avals)}')
-
- if length is not None:
- try:
- length = int(length)
- except core.ConcretizationTypeError as err:
- msg = ('The `length` argument to `scan` expects a concrete `int` value.'
- ' For scan-like iteration with a dynamic length, use `while_loop`'
- ' or `fori_loop`.')
- raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type]
- if not all(length == l for l in lengths):
- msg = ("scan got `length` argument of {} which disagrees with "
- "leading axis sizes {}.")
- raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
- else:
- unique_lengths = set(lengths)
- if len(unique_lengths) > 1:
- msg = "scan got values with different leading axis sizes: {}."
- raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
- elif len(unique_lengths) == 0:
- msg = "scan got no values to scan over and `length` not provided."
- raise ValueError(msg)
- else:
- length, = unique_lengths
+ length = _infer_scan_length(list(xs), list(xs_avals), length)
if config.disable_jit.value:
if length == 0:
raise ValueError("zero-length scan is not supported in disable_jit() "
"mode because the output type is unknown.")
- carry = init
+ carry = init.unflatten()
ys = []
maybe_reversed = reversed if reverse else lambda x: x
for i in maybe_reversed(range(length)):
- xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat]
- carry, y = f(carry, tree_unflatten(xs_tree, xs_slice))
+ xs_slice = xs.map(lambda x: slicing.index_in_dim(x, i, keepdims=False))
+ carry, y = f(carry, xs_slice.unflatten())
ys.append(y)
stack = lambda *ys: _stack(ys)
stacked_y = tree_map(stack, *maybe_reversed(ys))
return carry, stacked_y
- x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals]
- dbg_body = api_util.debug_info("scan", f, (init, xs), {})
-
if config.mutable_array_checks.value:
- in_flat, in_tree = tree_flatten((init, xs))
- in_avals = tuple(_map(core.get_aval, in_flat))
- check_no_aliased_ref_args(lambda: dbg_body, in_avals, in_flat)
-
- def _create_jaxpr(init):
- init_flat, init_tree = tree_flatten(init)
- in_flat, in_tree = tree_flatten((init, xs))
- carry_avals = tuple(_map(core.get_aval, init_flat))
- open_jaxpr, out_tree, consts = pe.trace_to_jaxpr(
- f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body)
- jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(open_jaxpr))
- if config.mutable_array_checks.value:
- _check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat)
- out_tree_children = out_tree.children()
- if len(out_tree_children) != 2:
+ check_no_aliased_ref_args(lambda: dbg_body, list(args_avals), list(args))
+
+ x_avals = xs_avals.map(lambda aval: core.mapped_aval(length, 0, aval))
+ def _create_jaxpr(carry_avals):
+ new_arg_avals = FlatTree.pack(((carry_avals, x_avals), {}))
+ jaxpr, out_avals = pe.trace_to_jaxpr(f, new_arg_avals, dbg_body)
+ jaxpr, consts = pe.separate_consts(jaxpr)
+ if len(out_avals.unpack()) != 2:
msg = "scan body output must be a pair, got {}."
- raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals)))
-
- carry_avals_out, _ = split_list(jaxpr.out_avals, [out_tree_children[0].num_leaves])
- return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr,
- consts, out_tree, out_tree_children)
+ raise TypeError(msg.format(out_avals.unflatten()))
+ return jaxpr, out_avals, consts
# The carry input and output avals must match exactly. However, we want to account for
# the case when init contains weakly-typed values (e.g. Python scalars), with avals that
@@ -303,18 +248,23 @@ def _create_jaxpr(init):
# TODO(dougalm): this two-pass stuff is expensive (exponential in scan nesting
# depth) and incomplete (because in the general case it takes more than two passes).
# Let's get rid of it, perhaps after getting rid of weak types altogether.
- init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
- new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out)
- if changed:
- init = tree_unflatten(init_tree, new_init_flat)
- init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
- in_flat, jaxpr, consts, out_tree, out_tree_children = rest
- num_carry = len(init_flat)
- num_xs = len(x_avals)
- num_ys = len(jaxpr.out_avals) - num_carry
- del init_flat
+ jaxpr, out_avals, consts = _create_jaxpr(init_avals)
+ if config.mutable_array_checks.value:
+ _check_no_aliased_closed_over_refs(dbg_body, consts, list(args))
+ carry_out_avals, ys_avals = out_avals.unpack()
+ if len(carry_out_avals) != len(init_avals):
+ _check_carry_type('scan body', f, init_avals, carry_out_avals)
+ init, changed = init.map3(
+ _promote_weak_typed_input,
+ init_avals, carry_out_avals).unzip2()
+ num_carry, num_xs, num_ys = len(init), len(xs), len(ys_avals)
+ if any(changed):
+ init_avals = init.map(core.get_aval)
+ jaxpr, out_avals, consts = _create_jaxpr(init_avals)
+ carry_out_avals, ys_avals = out_avals.unpack()
+
+ _check_carry_type('scan body', f, init_avals, carry_out_avals)
- _check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out)
disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects)
if disallowed_effects:
raise NotImplementedError(
@@ -329,9 +279,12 @@ def _create_jaxpr(init):
if unroll < 0:
raise ValueError("`unroll` must be a `bool` or a non-negative `int`.")
+ args_flat = [*init.vals, *xs.vals]
+
# If the body forwards an input carry to an output carry, that input is
# read-only and can be moved to be a const. Doing so can lead to efficiency
# wins, e.g. if the scan is inside a cond with a batched predicate.
+ num_ys = len(jaxpr.out_avals) - num_carry
carry_fwd, ext_fwd = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry])
move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)]
if any(move_to_const):
@@ -339,7 +292,7 @@ def _create_jaxpr(init):
jaxpr, [not m for m in move_to_const] + [True] * num_ys)
jaxpr = pe.move_binders_to_front(
jaxpr, [False] * len(consts) + move_to_const + [False] * num_xs)
- in_flat, new_consts = partition_list(move_to_const + [False] * num_xs, in_flat)
+ args_flat, new_consts = partition_list(move_to_const + [False] * num_xs, args_flat)
consts = [*new_consts, *consts]
num_carry -= len(new_consts)
@@ -356,30 +309,67 @@ def _create_jaxpr(init):
jaxpr = pe.prune_closed_jaxpr_outputs(
jaxpr, [True] * num_carry + [i is None for i in ext_to_ext_fwd])
- out = scan_p.bind(*consts, *in_flat,
+ out = scan_p.bind(*consts, *args_flat,
reverse=reverse, length=length, jaxpr=jaxpr,
num_consts=len(consts), num_carry=num_carry,
- linear=(False,) * (len(consts) + len(in_flat)),
+ linear=(False,) * (len(consts) + len(args_flat)),
unroll=unroll, _split_transpose=_split_transpose)
# Apply input to output forwarding that was computed above.
carry_out, out = split_list(out, [num_carry])
out_ = iter(out)
- out = [next(out_) if f is None else _maybe_put(in_flat[f]) for f in ext_to_ext_fwd]
+ out = [next(out_) if f is None else _maybe_put(args_flat[f]) for f in ext_to_ext_fwd]
assert next(out_, None) is None
out = [*carry_out, *out]
if any(move_to_const):
out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts)
- return tree_unflatten(out_tree, out)
+ return out_avals.update_from_list(out).unflatten()
+def _infer_scan_length(
+ xs_flat: list[Any], xs_avals: list[AbstractValue],
+ length: int | None) -> int:
+ try:
+ lengths = [x.shape[0] for x in xs_flat]
+ except AttributeError as err:
+ msg = "scan got value with no leading axis to scan over: {}."
+ raise ValueError(
+ msg.format(', '.join(str(x) for x in xs_flat
+ if not hasattr(x, 'shape')))) from err
+
+ if not all(a.sharding.spec[0] is None for a in xs_avals):
+ raise ValueError('0th dimension of all xs should be replicated. Got '
+ f'{", ".join(str(a.sharding.spec) for a in xs_avals)}')
+
+ if length is not None:
+ try:
+ return int(length)
+ except core.ConcretizationTypeError as err:
+ msg = ('The `length` argument to `scan` expects a concrete `int` value.'
+ ' For scan-like iteration with a dynamic length, use `while_loop`'
+ ' or `fori_loop`.')
+ raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type]
+ if not all(length == l for l in lengths):
+ msg = ("scan got `length` argument of {} which disagrees with "
+ "leading axis sizes {}.")
+ raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat]))
+ else:
+ unique_lengths = set(lengths)
+ if len(unique_lengths) > 1:
+ msg = "scan got values with different leading axis sizes: {}."
+ raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
+ elif len(unique_lengths) == 0:
+ msg = "scan got no values to scan over and `length` not provided."
+ raise ValueError(msg)
+ else:
+ return list(unique_lengths)[0]
def _capitalize(s):
# s.capitalize() converts s[1:] to lowercase which we don't want.
return s[0].capitalize() + s[1:]
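# Illustration only: the length inference factored out into _infer_scan_length
# above, exercised through the public API. The length comes from the leading
# axis of xs, or from an explicit `length=` when there is nothing to scan over.
import jax
import jax.numpy as jnp

xs = jnp.arange(6.0).reshape(3, 2)
_, ys = jax.lax.scan(lambda c, x: (c, x.sum()), 0.0, xs)   # length = xs.shape[0]
assert ys.shape == (3,)

_, ys = jax.lax.scan(lambda c, _: (c + 1.0, c), 0.0, None, length=5)
assert ys.shape == (5,)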
-def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
+def _check_carry_type(name, body_fun, in_carry, out_carry):
try:
sig = inspect.signature(body_fun)
except (ValueError, TypeError):
@@ -391,23 +381,20 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
else:
component = lambda p: (f'the input carry at path {keystr(p)}'
if p else 'the input carry')
- leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry)
- paths, in_carry_flat = unzip2(leaves_and_paths)
- in_avals = _map(core.get_aval, in_carry_flat)
- if in_carry_tree != out_carry_tree:
+ if in_carry.tree != out_carry.tree:
try:
- out_carry = tree_unflatten(out_carry_tree, out_avals)
+ out_carry_unflat = out_carry.unflatten()
except:
- out_carry = None
+ out_carry_unflat = None
- if out_carry is None:
- differences = (f'the input tree structure is:\n{in_carry_tree}\n' +
- f'the output tree structure is:\n{out_carry_tree}\n')
+ if out_carry_unflat is None:
+ differences = (f'the input tree structure is:\n{in_carry.tree}\n' +
+ f'the output tree structure is:\n{out_carry.tree}\n')
else:
diffs = [f'{component(path)} is a {thing1} but the corresponding component '
f'of the carry output is a {thing2}, so {explanation}'
for path, thing1, thing2, explanation
- in equality_errors(in_carry, out_carry)]
+ in equality_errors(in_carry.unflatten(), out_carry.unflatten())]
if len(diffs) == 0:
return # the trees may have different aux data, but structures are same
elif len(diffs) == 1:
@@ -421,12 +408,14 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f"{differences}\n"
"Revise the function so that the carry output has the same pytree "
"structure as the carry input.")
- if not all(_map(core.typematch, in_avals, out_avals)):
+ if not all(_map(core.typematch, in_carry, out_carry)):
+    # TODO(dougalm): add a way to get paths without roundtripping
+ paths, _ = unzip2(tree_leaves_with_path(in_carry.unflatten()))
diffs = [f'{component(path)} has type {in_aval.str_short()}'
' but the corresponding output carry component has type '
f'{out_aval.str_short()}'
f'{core.aval_mismatch_extra(in_aval, out_aval)}'
- for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
+ for path, in_aval, out_aval in zip(paths, in_carry, out_carry)
if not core.typematch(in_aval, out_aval)]
if len(diffs) == 0:
@@ -441,7 +430,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'applying `jax.lax.pcast(..., {tuple(out_aval.vma - in_aval.vma)},'
" to='varying')` to the initial carry value corresponding to"
f' {component(path)}'
- for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
+ for path, in_aval, out_aval in zip(paths, in_carry, out_carry)
if not core.typematch(in_aval, out_aval) and
isinstance(in_aval, ShapedArray) and isinstance(out_aval, ShapedArray)
and in_aval.vma != out_aval.vma and out_aval.vma - in_aval.vma]
@@ -934,225 +923,6 @@ def _rearrange_mutable_binders(
if config.enable_checks.value: core.check_jaxpr(new_jaxpr)
return ClosedJaxpr(new_jaxpr, jaxpr.consts)
-def _scan_transpose(cts, *args, reverse, length, num_consts,
- num_carry, jaxpr, linear, unroll, _split_transpose):
- # we've only implemented transposing scans with specific lin/nonlin patterns
- consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
- num_ires = len(consts_lin) - sum(consts_lin)
- num_eres = len(xs_lin) - sum(xs_lin)
- if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires):
- raise NotImplementedError
- if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres:
- raise NotImplementedError
- if not all(init_lin):
- pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963
-
- # We follow a funny convention of passing cotangent refs like primals, so they
- # appear in `args` mixed in with the UndefinedPrimals of `T d` and `T a`.
- # Rearrange jaxpr binders and arguments to put cotangent mutable arrays first:
- # Before: [ires, T d, T c, T a, eres] -> [T c, T b]
- # After: [ires, T d_mut, T d_pure, T c, T a_mut, T a_pure, eres] -> [T c, T b]
- # where
- # * `ires` means intensive (not scanned over / const) residuals
- # * `T d` means the intensive tangents
- # * `T c` means the tangent carry
- # * `T a` means the extensive (scanned over) tangent inputs
- # * `eres` means the extensive residuals
- # * `T b` means the extensive tangent outputs
- ires, consts_dot, carry_dot, xs_dot, eres = split_list(
- args, [num_ires, num_consts - num_ires, num_carry, sum(xs_lin)])
- _, const_avals, _, xs_avals, _ = split_list(
- jaxpr.in_avals, [num_ires, num_consts - num_ires, num_carry, sum(xs_lin)])
- is_mutable = [isinstance(a, AbstractRef) for a in const_avals]
- immut_consts_dot, mut_consts_bar = partition_list(is_mutable, consts_dot)
- jaxpr = _rearrange_mutable_binders(jaxpr, num_ires, num_consts - num_ires)
- del const_avals, consts_dot
- is_mutable_ = [isinstance(a, AbstractRef) for a in xs_avals]
- immut_xs_dot, mut_xs_bar = partition_list(is_mutable_, xs_dot)
- jaxpr = _rearrange_mutable_binders(jaxpr, num_consts + num_carry, sum(xs_lin))
- del xs_avals, xs_dot
- # Check that pure tangent values are all UndefinedPrimals, and mutable
- # 'tangent values' are not (since we actually put cotangent refs there).
- assert not any(ad.is_undefined_primal(r) for r in ires)
- assert not any(ad.is_undefined_primal(x) for x in mut_consts_bar)
- # TODO(mattjj): re-enable these asserts
- # assert all(ad.is_undefined_primal(x) for x in immut_consts_dot)
- # assert all(ad.is_undefined_primal(x) for x in carry_dot)
- # assert all(ad.is_undefined_primal(x) for x in immut_xs_dot)
- assert not any(ad.is_undefined_primal(r) for r in eres)
- del args
-
- # Take apart passed-in cotangents to identify which are sym zeros.
- ct_carry, ct_ys = split_list(cts, [num_carry])
- ct_carry = _map(ad.instantiate_zeros, ct_carry)
- ct_ys_is_zeros = [type(ct_y) is ad.Zero for ct_y in ct_ys]
- ct_ys_nz = [x for x in ct_ys if type(x) is not ad.Zero]
- ct_immut_consts = _map(ad_util.zeros_like_aval,
- jaxpr.in_avals[num_ires+len(mut_consts_bar):num_consts])
-
- jaxpr_trans = _transpose_scan_jaxpr(
- jaxpr, num_ires, len(mut_consts_bar), len(immut_consts_dot),
- len(mut_xs_bar), len(immut_xs_dot), num_eres, tuple(ct_ys_is_zeros))
-
- linear_trans = ([False] * num_ires +
- [True] * (len(mut_consts_bar) + len(immut_consts_dot) +
- len(carry_dot) + len(mut_xs_bar) + len(ct_ys_nz)) +
- [False] * num_eres)
- transpose_inputs = [*ires, *mut_consts_bar, *ct_immut_consts, *ct_carry,
- *mut_xs_bar, *ct_ys_nz, *eres]
-
- if not _split_transpose:
- outs = scan_p.bind(
- *transpose_inputs,
- reverse=not reverse, length=length, jaxpr=jaxpr_trans,
- num_consts=num_ires + len(mut_consts_bar),
- num_carry=len(immut_consts_dot) + len(carry_dot),
- linear=tuple(linear_trans), unroll=unroll,
- _split_transpose=False)
- else:
- if len(mut_consts_bar): raise NotImplementedError
- transpose_num_out_carry = num_consts-num_ires+num_carry
- inst_mask = [False] * transpose_num_out_carry + [True] * (
- len(jaxpr_trans.out_avals) - transpose_num_out_carry)
-
- unknowns_mask = [False] * (len(transpose_inputs) - len(eres)) + [
- True
- ] * len(eres)
-
- # The residuals may contain original parameters (e.g. forwarded extensive
- # array arguments) and residuals from the primal. Hence we iterate and
- # update all values of the mask that we've set to True (i.e. 'unknown') to
- # see if we should actually push them to the known computation in order to
- # perform the scan (known) - map (unknown) split. The test effectively is
- # done by comparing the output masks.
- #
- # TODO(dvytin): improve performance by doing backwards abstract eval.
- #
- # For example, a mask arising from a relu() is an extensive residual, yet
- # only really used in the backpropagation scan, not in the unknown map. But
- # an intermediate activation of a matmul will be used only in the map part.
- # If we were to erroneously push the relu mask to the unknown part, then,
- # in the output, the partial evaluator will also pull the loop-carried state
- # to the unknown, and that is something we can test by comparing the output
- # mask of pe against our intended inst mask.
- for index in range(len(jaxpr_trans.in_avals)):
- if unknowns_mask[index]:
- mask_for_dependence = [False]*len(jaxpr_trans.in_avals)
- mask_for_dependence[index] = True # try moving this to unknown
- _, _, outs_for_dependence, _ = pe.partial_eval_jaxpr_nounits(
- jaxpr_trans, mask_for_dependence, inst_mask)
- if inst_mask != outs_for_dependence:
- unknowns_mask[index] = False
-
- jaxpr_known_body, jaxpr_unknown_body, outs_mask, res_avals = (
- pe.partial_eval_jaxpr_nounits(jaxpr_trans, unknowns_mask, inst_mask)
- )
-
- num_knowns = len(outs_mask) - sum(outs_mask)
-
- linear_list = list(linear_trans)
- known_linear = [
- l for mask, l in zip(unknowns_mask, linear_list) if not mask
- ]
- unknown_linear = [l for mask, l in zip(unknowns_mask, linear_list) if mask]
- unknown_linear = [False] * len(res_avals) + unknown_linear
-
- known_args = [
- arg for mask, arg in zip(unknowns_mask, transpose_inputs) if not mask
- ]
- unknown_args = [
- arg for mask, arg in zip(unknowns_mask, transpose_inputs) if mask
- ]
- # 1. Apply the known scan.
- knowns_and_residual = scan_p.bind(
- *known_args,
- reverse=not reverse,
- length=length,
- num_consts=num_ires,
- num_carry=transpose_num_out_carry,
- jaxpr=jaxpr_known_body,
- linear=tuple(known_linear),
- unroll=unroll,
- _split_transpose=False, # Just generate the loop now.
- )
- known_results, residuals = split_list(knowns_and_residual, [num_knowns])
-
- # 2. Apply the unknown map to residuals and unknown arguments.
- unknown_results = scan_p.bind(
- *residuals, *unknown_args,
- reverse=reverse, # Keep reverse as is for better scheduling.
- length=length,
- num_consts=0,
- num_carry=0,
- jaxpr=jaxpr_unknown_body,
- linear=tuple(unknown_linear),
- unroll=unroll,
- _split_transpose=False, # Just generate the loop now.
- )
- known_results_iter = iter(known_results)
- unknown_results_iter = iter(unknown_results)
- outs = [
- next(known_results_iter) if not mask else next(unknown_results_iter)
- for mask in outs_mask
- ]
-
- ct_immut_consts, ct_init, ct_immut_xs = split_list(outs, [len(immut_consts_dot), len(carry_dot)])
- ct_consts = merge_lists(is_mutable, ct_immut_consts, [None] * len(mut_consts_bar))
- ct_xs = merge_lists(is_mutable_, ct_immut_xs, [None] * len(mut_xs_bar))
- return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres
-
-# transpose_scan_jaxpr converts the jaxpr signature:
-# Before: [(ires, T d_mut T d_pure), T c, (CT a_mut, T a, eres)] -> [T c, T b]
-# ---------- consts ----------- --------- ext -------
-#
-# After: [(ires, CT d_mut), (CT d_pure, CT c), (CT a_mut, CT b, eres)] -> [(CT d_pure, CT c), CT a]
-# --- consts ---- ----- carry ------ --------- ext --------
-@weakref_lru_cache
-def _transpose_scan_jaxpr(
- jaxpr: ClosedJaxpr,
- num_ires: int,
- num_d_mut: int,
- num_d_pure: int,
- num_a_mut: int,
- num_a_pure: int,
- num_eres: int,
- ct_b_is_zeros: Sequence[bool]):
- num_d = num_d_mut + num_d_pure
- num_a = num_a_mut + num_a_pure
- num_b_nz = len(ct_b_is_zeros) - sum(ct_b_is_zeros)
- num_c = len(jaxpr.out_avals) - len(ct_b_is_zeros)
- assert num_a == len(jaxpr.in_avals) - num_ires - num_d - num_c - num_eres
-
- ires_avals, d_mut_avals, d_pure_avals, c_avals, a_mut_avals, a_pure_avals, eres_avals = split_list(
- jaxpr.in_avals, [num_ires, num_d_mut, num_d_pure, num_c, num_a_mut, num_a_pure])
- _, b_avals = split_list(jaxpr.out_avals, [num_c])
- b_avals_nz = [a for a, z in zip(b_avals, ct_b_is_zeros) if not z]
-
- # TODO(mattjj,dougalm): map to cotangent types...
- def transposed(*ct_args):
- ires, d_mut_bar, d_pure, c_bar, a_mut_bar, b_bar, eres = split_list(
- ct_args, [num_ires, num_d_mut, num_d_pure, num_c, num_a_mut, num_b_nz])
- b_bar_ = iter(b_bar)
- b_bar = [ad.Zero(a) if z else next(b_bar_) for a, z in zip(b_avals, ct_b_is_zeros)]
- assert next(b_bar_, None) is None
- primals = (
- ires + d_mut_bar +
- [ad.UndefinedPrimal(aval) for aval in [*d_pure_avals, *c_avals]] +
- a_mut_bar + [ad.UndefinedPrimal(aval) for aval in a_pure_avals] + eres)
- cts_out = ad.backward_pass(
- jaxpr.jaxpr, False, jaxpr.consts, primals, c_bar + b_bar)
- _, new_d_pure, new_c_bar, _, a_bar, _ = split_list(
- cts_out, [num_ires + num_d_mut, num_d_pure, num_c, num_a_mut, num_a_pure])
- d_pure = _map(ad.instantiate_zeros, _map(ad.add_tangents, d_pure, new_d_pure))
- new_c_bar = _map(ad.instantiate_zeros, new_c_bar)
- a_bar = _map(ad.instantiate_zeros, a_bar)
- return [*d_pure, *new_c_bar, *a_bar]
-
- transposed_wrapped = lu.wrap_init(transposed, debug_info=jaxpr.jaxpr.debug_info)
- trans_avals = *ires_avals, *d_mut_avals, *d_pure_avals, *c_avals, *a_mut_avals, *b_avals_nz, *eres_avals
- trans_jaxpr = _make_closed_jaxpr(transposed_wrapped, trans_avals)
- return trans_jaxpr
-
def _scan_transpose_fancy(cts, *args, reverse, length, num_consts,
num_carry, jaxpr, linear, unroll, _split_transpose):
consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
@@ -1299,9 +1069,6 @@ def _scan_batching_rule(axis_data, args,
def _cached_scan_pad_jaxpr(jaxpr):
return ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts))
-def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
- return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params)
-
def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
@@ -1588,7 +1355,6 @@ def rearrange(lst):
scan_p.def_impl(partial(dispatch.apply_primitive, scan_p))
scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
-ad.primitive_transposes[scan_p] = _scan_transpose
ad.fancy_transposes[scan_p] = _scan_transpose_fancy
ad.primitive_linearizations[scan_p] = _scan_linearize
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
@@ -1598,7 +1364,6 @@ def rearrange(lst):
batching.fancy_primitive_batchers[scan_p] = _scan_batching_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
-pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
state_discharge.register_partial_discharge_rule(scan_p)(_scan_state_partial_discharge_rule)
@@ -1708,42 +1473,48 @@ def while_loop(cond_fun, body_fun, init_val):
# transformation on it), so we fall back to the primitive version.
pass
- def _create_jaxpr(init_val):
- init_vals, in_tree = tree_flatten((init_val,))
- init_avals = tuple(_map(core.get_aval, init_vals))
- cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
- cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
- cond_fun, in_tree, init_avals, cond_dbg)
- body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
- body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
- body_fun, in_tree, init_avals, body_dbg)
- if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
+ def _create_jaxpr(init_avals):
+ args_avals = FlatTree.pack(((init_avals,), {}))
+ cond_jaxpr, cond_out_avals = pe.trace_to_jaxpr(cond_fun, args_avals, cond_dbg)
+ body_jaxpr, body_out_avals = pe.trace_to_jaxpr(body_fun, args_avals, body_dbg)
+ if not treedef_is_leaf(cond_out_avals.tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
- raise TypeError(msg.format(cond_tree))
+ raise TypeError(msg.format(cond_out_avals.tree))
+
pred_aval = cond_jaxpr.out_avals[0]
if (not isinstance(pred_aval, ShapedArray)
or ShapedArray(pred_aval.shape, pred_aval.dtype) != ShapedArray((), np.bool_)):
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals))
- return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree
+
+ return cond_jaxpr, body_jaxpr, body_out_avals
+
+ cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {})
+ body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {})
+ init_val = FlatTree.flatten(init_val) # type: ignore
+ init_aval = init_val.map(core.get_aval)
# The body input and output avals must match exactly. However, we want to account for
# the case when init contains weakly-typed values (e.g. Python scalars), with avals that
# may not match the output despite being compatible by virtue of their weak type.
# To do this, we compute the jaxpr in two passes: first with the raw inputs, and if
# necessary, a second time with modified init values.
- init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
- new_init_vals, changed = _promote_weak_typed_inputs(
- init_vals, init_avals, body_jaxpr.out_avals)
- new_init_val, = tree_unflatten(in_tree, new_init_vals)
- if changed:
- init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val)
- cond_jaxpr, cond_consts, body_consts, body_tree = rest
-
- in_tree_children = in_tree.children()
- assert len(in_tree_children) == 1
- _check_carry_type('while_loop body', body_fun, new_init_val, body_tree,
- body_jaxpr.out_avals)
+ cond_jaxpr, body_jaxpr, body_out_avals = _create_jaxpr(init_aval)
+ if len(body_out_avals) != len(init_aval):
+ _check_carry_type('while_loop body', body_fun, init_aval, body_out_avals)
+ assert False, "shouldn't get here"
+
+ init_val, changed = init_val.map3(
+ _promote_weak_typed_input,
+ init_aval, body_out_avals).unzip2()
+ if any(changed):
+ init_aval = init_val.map(core.get_aval)
+ cond_jaxpr, body_jaxpr, body_out_avals = _create_jaxpr(init_aval)
+
+ cond_jaxpr, cond_consts = pe.separate_consts(cond_jaxpr)
+ body_jaxpr, body_consts = pe.separate_consts(body_jaxpr)
+ _check_carry_type('while_loop body', body_fun, init_aval, body_out_avals)
+
if not all(not v.aval.has_qdd or v.initial_qdd == v.final_qdd for v in
body_jaxpr.jaxpr.invars):
raise TypeError("type-changing mutations not allowed in while_loop body")
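For illustration, a minimal sketch of the weak-type promotion handled above (the loop body and bound here are hypothetical): a Python-scalar init is weakly typed, so the loop is traced once, the init is promoted to match the body's output aval, and then traced a second time.

from jax import lax

# The init `0` is a weakly typed Python int; the body returns a carry with a
# concrete integer dtype, so the loop is re-traced once with the promoted init.
out = lax.while_loop(lambda c: c < 10, lambda c: c + 1, 0)
print(out)  # 10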
@@ -1764,6 +1535,7 @@ def _create_jaxpr(init_val):
_, keep_cond_carry = split_list(keep_cond, [len(cond_consts)])
move_to_const = _map(operator.not_, keep_cond_carry)
+ init_vals = list(init_val) # type: ignore
if any(move_to_const):
cond_jaxpr = pe.close_jaxpr(cond_jaxpr_)
body_jaxpr = pe.prune_closed_jaxpr_outputs(
@@ -1780,7 +1552,7 @@ def _create_jaxpr(init_val):
if any(move_to_const):
outs = pe.merge_lists(move_to_const, outs, new_body_consts)
- return tree_unflatten(body_tree, outs)
+ return body_out_avals.update_from_list(outs).unflatten()
def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts
@@ -2516,7 +2288,7 @@ def _pred_bcast_select_hlo(ctx,
pred_aval.shape, x_y_aval)
x_y_aval = core.physical_aval(x_y_aval)
bcast_pred = mlir.broadcast_in_dim(
- ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)),
+ ctx, pred, core.ShapedArray(x_y_aval.shape, np.dtype(np.bool_)),
broadcast_dimensions=list(range(len(pred_aval.shape))))
return hlo.SelectOp(bcast_pred, x, y).results
@@ -2716,7 +2488,10 @@ def _batch_and_remainder(x, batch_size: int):
leaves, treedef = tree_flatten(x)
if not leaves:
return x, None
- num_batches, remainder = divmod(leaves[0].shape[0], batch_size)
+ if batch_size == 0:
+ num_batches, remainder = 0, leaves[0].shape[0]
+ else:
+ num_batches, remainder = divmod(leaves[0].shape[0], batch_size)
batch_elems = num_batches * batch_size
if num_batches == 0:
remainder_leaves = [_remainder_leaf(leaf, batch_elems) for leaf in leaves]
@@ -2757,6 +2532,8 @@ def map(f, xs):
divisible by the batch size, the remainder is processed in a separate ``vmap`` and
concatenated to the result.
+  ``batch_size=0`` is equivalent to applying a single ``vmap`` over the whole input, i.e. one full batch.
+
>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
... print('inner shape:', x.shape)
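For illustration, a minimal sketch of the ``batch_size=0`` behavior documented above, assuming the patched ``jax.lax.map`` (the function ``f`` and input ``xs`` are hypothetical):

import jax
import jax.numpy as jnp

def f(x):
  return x.sum()

xs = jnp.ones((10, 3, 4))
# batch_size=0 processes the whole leading axis in one vmap call,
# so the result should agree with jax.vmap applied directly.
out = jax.lax.map(f, xs, batch_size=0)
ref = jax.vmap(f)(xs)
assert jnp.allclose(out, ref)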
diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py
index 17ed44b69991..e38f51447c8c 100644
--- a/jax/_src/lax/control_flow/solves.py
+++ b/jax/_src/lax/control_flow/solves.py
@@ -27,16 +27,15 @@
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
+from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.traceback_util import api_boundary
-from jax._src.tree_util import (tree_flatten, treedef_children, tree_leaves,
- tree_unflatten, treedef_tuple)
+from jax._src.tree_util import tree_leaves, FlatTree
from jax._src.util import split_list, safe_map
import numpy as np
from jax._src.lax.control_flow.common import (
_check_tree,
- _initial_style_jaxpr,
)
_map = safe_map
@@ -92,21 +91,22 @@ def custom_root(f: Callable,
The result of calling solve(f, initial_guess) with gradients defined via
implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
"""
- guess_flat, in_args_tree = tree_flatten((initial_guess,))
- guess_avals = tuple(_map(core.get_aval, guess_flat))
+ guess_flat = FlatTree.flatten(initial_guess)
+ guess_avals = guess_flat.map(core.get_aval)
f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {})
- f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
- f, in_args_tree, guess_avals, f_debug)
+  args_avals = FlatTree.pack(((guess_avals,), {}))
+ f_jaxpr, out_avals = pe.trace_to_jaxpr(f, args_avals, f_debug)
+ f_jaxpr, f_consts = pe.separate_consts(f_jaxpr)
- in_tree, = treedef_children(in_args_tree)
- _check_tree("f", "initial_guess", out_tree, in_tree, False)
+ _check_tree("f", "initial_guess", out_avals.tree, guess_avals.tree, False)
solve_debug = api_util.debug_info("custom_root solve", solve,
(f, initial_guess), {},
static_argnums=(0,))
- solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr(
- partial(solve, f), in_args_tree, guess_avals, solve_debug)
- _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux)
+ solve_jaxpr, solution_avals = pe.trace_to_jaxpr(
+ partial(solve, f), args_avals, solve_debug)
+ solve_jaxpr, solve_consts = pe.separate_consts(solve_jaxpr)
+ _check_tree("solve", "initial_guess", solution_avals.tree, guess_flat.tree, has_aux)
def linearize_and_solve(x, b):
unchecked_zeros, f_jvp = api.linearize(f, x)
@@ -114,18 +114,21 @@ def linearize_and_solve(x, b):
linearize_and_solve_dbg = api_util.debug_info("custom_root tangent_solve",
tangent_solve, (initial_guess, initial_guess), {})
- l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr(
- linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2,
- linearize_and_solve_dbg)
- _check_tree("tangent_solve", "x", out_tree, in_tree, False)
+
+
+ linearize_and_solve_avals = FlatTree.pack(((guess_avals, guess_avals), {}))
+ l_and_s_jaxpr, out_avals = pe.trace_to_jaxpr(
+ linearize_and_solve, linearize_and_solve_avals, linearize_and_solve_dbg)
+ l_and_s_jaxpr, l_and_s_consts = pe.separate_consts(l_and_s_jaxpr)
+ _check_tree("tangent_solve", "x", out_avals.tree, guess_flat.tree, False)
all_consts = [f_consts, solve_consts, l_and_s_consts]
const_lengths = _RootTuple(*_map(len, all_consts))
jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr)
solution_flat = _custom_root(
- const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat))
- return tree_unflatten(solution_tree, solution_flat)
+ const_lengths, jaxprs, *_flatten(all_consts), *guess_flat)
+ return solution_avals.update_from_list(solution_flat).unflatten()
@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0, 1))
@@ -195,8 +198,8 @@ def _flatten(args):
def _check_shapes(func_name, expected_name, actual, expected):
- actual_shapes = _map(np.shape, tree_leaves(actual))
- expected_shapes = _map(np.shape, tree_leaves(expected))
+ actual_shapes = _map(np.shape, actual)
+ expected_shapes = _map(np.shape, expected)
if actual_shapes != expected_shapes:
raise ValueError(
f"{func_name}() output shapes must match {expected_name}, "
@@ -247,20 +250,19 @@ def custom_linear_solve(
if transpose_solve is None and symmetric:
transpose_solve = solve
- b_flat, in_args_tree = tree_flatten((b,))
- b_avals = tuple(_map(core.get_aval, b_flat))
-
- tree, = treedef_children(in_args_tree)
+ b_flat = FlatTree.flatten(b)
+ b_avals = b_flat.map(core.get_aval)
+ tree = b_flat.tree
def _shape_checked(fun, name, has_aux):
def f(x):
y = fun(x)
- _check_shapes(name, "b", y, b_flat)
+ _check_shapes(name, "b", tree_leaves(y), b_flat)
return y
def f_aux(x):
y, aux = fun(x)
- _check_shapes(name, "b", y, b_flat)
+ _check_shapes(name, "b", tree_leaves(y), b_flat)
return y, aux
return f_aux if has_aux else f
@@ -268,18 +270,21 @@ def f_aux(x):
matvec_debug = api_util.debug_info("custom_linear_solve",
matvec, (b,), {})
# no auxiliary data assumed for matvec
- matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
- _shape_checked(matvec, "matvec", False), in_args_tree, b_avals,
+  args_avals = FlatTree.pack(((b_avals,), {}))
+ matvec_jaxpr, out_avals = pe.trace_to_jaxpr(
+ _shape_checked(matvec, "matvec", False), args_avals,
matvec_debug)
- _check_tree("matvec", "b", out_tree, tree, False)
+ matvec_jaxpr, matvec_consts = pe.separate_consts(matvec_jaxpr)
+ _check_tree("matvec", "b", out_avals.tree, tree, False)
solve_debug = api_util.debug_info("custom_linear_solve solve",
solve, (matvec, b), {},
static_argnums=(0,))
- solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
- _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals,
+ solve_jaxpr, out_avals = pe.trace_to_jaxpr(
+ _shape_checked(partial(solve, matvec), "solve", has_aux), args_avals,
solve_debug)
- _check_tree("solve", "b", out_tree, tree, has_aux)
+ solve_jaxpr, solve_consts = pe.separate_consts(solve_jaxpr)
+ _check_tree("solve", "b", out_avals.tree, tree, has_aux)
if transpose_solve is None:
vecmat_jaxpr = tr_solve_jaxpr = None
@@ -294,25 +299,27 @@ def f_aux(x):
vecmat_consts = matvec_consts
else:
vecmat = _transpose_one_output(matvec, b)
- vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
- vecmat, in_args_tree, b_avals, transpose_solve_debug)
- assert out_tree == tree
+ vecmat_jaxpr, out_avals = pe.trace_to_jaxpr(
+ vecmat, args_avals, transpose_solve_debug)
+ vecmat_jaxpr, vecmat_consts = pe.separate_consts(vecmat_jaxpr)
+ assert out_avals.tree == tree
- tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
+ tr_solve_jaxpr, out_avals = pe.trace_to_jaxpr(
_shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux),
- in_args_tree, b_avals, transpose_solve_debug)
- _check_tree("transpose_solve", "b", out_tree, tree, has_aux)
+ args_avals, transpose_solve_debug)
+ tr_solve_jaxpr, tr_solve_consts = pe.separate_consts(tr_solve_jaxpr)
+ _check_tree("transpose_solve", "b", out_avals.tree, tree, has_aux)
all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
const_lengths = _LinearSolveTuple(*_map(len, all_consts))
jaxprs = _LinearSolveTuple(
matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)
- args = _flatten(all_consts) + b_flat
+ args = _flatten(all_consts) + list(b_flat)
args = core.standard_insert_pvary(*args)
out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs)
- return tree_unflatten(out_tree, out_flat)
+ return out_avals.update_from_list(out_flat).unflatten()
def _linear_solve_abstract_eval(*args, const_lengths, jaxprs):
diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py
index 2cd7dce73dee..8527e3678f2e 100644
--- a/jax/_src/lax/convolution.py
+++ b/jax/_src/lax/convolution.py
@@ -302,6 +302,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
This function directly calculates a fractionally strided conv rather than
indirectly calculating the gradient (transpose) of a forward convolution.
+ Notes:
+    TensorFlow/Keras Compatibility: By default, JAX does NOT reverse the
+    kernel's spatial dimensions. This differs from TensorFlow's ``Conv2DTranspose``
+    and similar frameworks, which flip spatial axes and swap input/output channels.
+
+    To match TensorFlow/Keras behavior, set ``transpose_kernel=True``.
+
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
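For illustration, a minimal sketch of the ``transpose_kernel`` flag described in the note above (the shapes and dimension numbers are hypothetical; equal input/output channel counts are used so the kernel layout stays valid either way):

import jax
import jax.numpy as jnp

lhs = jnp.ones((1, 8, 8, 4))   # NHWC input
rhs = jnp.ones((3, 3, 4, 4))   # HWIO kernel with equal in/out channels
out = jax.lax.conv_transpose(
    lhs, rhs, strides=(2, 2), padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    transpose_kernel=True)  # flip spatial axes and swap I/O channels, as Keras does
print(out.shape)  # expected (1, 16, 16, 4) with stride 2 and SAME padding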
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 6cb687d8f2bd..609367132004 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -34,7 +34,6 @@
from jax._src import array
from jax._src import config
from jax._src import core
-from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
@@ -56,13 +55,13 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
-from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax import utils as lax_utils
from jax._src.mesh import get_abstract_mesh, get_concrete_mesh
from jax._src.lax.utils import (
input_dtype, dtype_to_string, standard_multi_result_abstract_eval,
standard_primitive)
+from jax._src.core import typeof
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@@ -103,11 +102,7 @@ def _check_static_shape(shape: Shape):
raise TypeError(msg)
assert shapes
- if config.dynamic_shapes.value:
- # pass dynamic shapes through unchecked
- return
- else:
- foreach(_check_static_shape, shapes)
+ foreach(_check_static_shape, shapes)
def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
"""
@@ -245,39 +240,6 @@ def broadcast_shardings(*avals):
def _identity(x, **_): return x
-def _extract_tracers_dyn_shape(
- shape: Sequence[int | core.Tracer]
- ) -> tuple[list[core.Tracer], list[int | None]]:
- # Given a sequence representing a shape, pull out Tracers, replacing with None
- if config.dynamic_shapes.value:
- # We must gate this behavior under a flag because otherwise the errors
- # raised are different (and have worse source provenance information).
- dyn_shape = [d for d in shape if isinstance(d, core.Tracer)]
- static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
- return dyn_shape, static_shape
- else:
- return [], list(shape) # type: ignore
-
-def _merge_dyn_shape(
- static_shape: Sequence[int | None],
- dyn_shape: Sequence[Any],
- ) -> tuple[int | mlir.Value | core.Tracer, ...]:
- # Replace Nones in static_shape with elements of dyn_shape, in order
- dyn_shape_it = iter(dyn_shape)
- shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape)
- assert next(dyn_shape_it, None) is None
- return shape
-
-def _dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args,
- **params):
- var = trace.frame.newvar(out_aval)
- eqn = pe.new_jaxpr_eqn([x.val for x in args],
- [var],
- prim, params, core.no_effects, source_info)
- out_tracer = pe.DynamicJaxprTracer(trace, out_aval, var, source_info)
- trace.frame.add_eqn(eqn)
- return out_tracer
-
### traceables
@@ -2448,6 +2410,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike,
preferred_element_type=preferred_element_type, out_sharding=out_sharding)
+# TODO(jakevdp): replace `*args` with `*` in v0.10.0
def dot(lhs: ArrayLike, rhs: ArrayLike, *args,
dimension_numbers: DotDimensionNumbers | None = None,
precision: PrecisionLike = None,
@@ -2502,30 +2465,12 @@ def dot(lhs: ArrayLike, rhs: ArrayLike, *args,
.. _stablehlo.dot_general: https://openxla.org/stablehlo/spec#dot_general
.. _DotGeneral: https://www.openxla.org/xla/operation_semantics#dotgeneral
"""
- # TODO(jakevdp): keyword warning added for JAX v0.7.1; finalize this for v0.9.0.
if args:
- deprecations.warn(
- "jax-lax-dot-positional-args",
- (
- "jax.lax.dot: passing precision or preferred_element_type by position"
- " is deprecated; pass them by keyword instead."
- ),
- stacklevel=2
+ raise TypeError(
+ f"dot() takes 2 positional arguments but {2 + len(args)} were given."
+ " Passing precision or preferred_element_type by position is not allowed"
+ " as of JAX v0.9.0; pass them by keyword instead."
)
- # Prior to merging dot and dot_general, dot() had two additional positional args:
- # `precision` and `preferred_element_type`.
- if len(args) == 1:
- if precision is not None:
- raise TypeError("jax.lax.dot got multiple values for argument 'precision'")
- precision, = args
- elif len(args) == 2:
- if precision is not None:
- raise TypeError("jax.lax.dot got multiple values for argument 'precision'")
- if preferred_element_type is not None:
- raise TypeError("jax.lax.dot got multiple values for argument 'preferred_element_type'")
- precision, preferred_element_type = args
- else:
- raise TypeError("Too many positional arguments passed to jax.lax.dot.")
del args
lhs_shape = np.shape(lhs)
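For illustration, a minimal sketch of the keyword-only calling convention enforced above (array shapes and dtypes are hypothetical):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 8), dtype=jnp.bfloat16)
y = jnp.ones((8, 2), dtype=jnp.bfloat16)
# Positional precision/preferred_element_type now raise a TypeError;
# pass them by keyword instead.
out = lax.dot(x, y, precision=lax.Precision.HIGHEST,
              preferred_element_type=jnp.float32)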
@@ -2734,7 +2679,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
operand: an array
shape: the shape of the target array
broadcast_dimensions: to which dimension in the target shape each dimension
- of the operand shape corresponds to. That is, dimension i of the operand
+ of the operand shape corresponds to. That is, dimension i of the operand
becomes dimension broadcast_dimensions[i] of the result.
Returns:
@@ -2743,21 +2688,18 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
See Also:
jax.lax.broadcast : simpler interface to add new leading dimensions.
"""
- # TODO(dfm): Re-write this as a "reshard" when only the sharding changes.
out_sharding = canonicalize_sharding(out_sharding, 'broadcast_in_dim')
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
isinstance(operand, Array) and out_sharding is None):
return operand
- if config.dynamic_shapes.value:
- # We must gate this behavior under a flag because otherwise the errors
- # raised are different (and have worse source provenance information).
- dyn_shape, static_shape = _extract_tracers_dyn_shape(shape)
- else:
- dyn_shape, static_shape = [], shape # type: ignore
+ operand_aval = typeof(operand)
+ if (operand_aval.shape == shape and
+ list(broadcast_dimensions) == list(range(operand_aval.ndim)) and
+ out_sharding is not None and operand_aval.sharding != out_sharding):
+ return pjit.reshard(operand, out_sharding)
return broadcast_in_dim_p.bind(
- operand, *dyn_shape, shape=tuple(static_shape),
- broadcast_dimensions=tuple(broadcast_dimensions),
- sharding=out_sharding)
+ operand, shape=tuple(shape),
+ broadcast_dimensions=tuple(broadcast_dimensions), sharding=out_sharding)
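For illustration, a minimal sketch of how ``broadcast_dimensions`` maps operand dimensions to result dimensions, as described in the docstring above (the shapes are hypothetical):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)                          # shape (3,)
y = lax.broadcast_in_dim(x, (2, 3), (1,))  # dim 0 of x becomes dim 1 of y
print(y.shape)  # (2, 3)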
def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
"""Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
@@ -2817,15 +2759,14 @@ def reshape(operand: ArrayLike, new_sizes: Shape,
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
out_sharding = canonicalize_sharding(out_sharding, 'reshape')
same_sharding = (out_sharding is None or
- core.typeof(operand).sharding == out_sharding)
+ typeof(operand).sharding == out_sharding)
if (np.shape(operand) and same_shape and same_dims and same_sharding and
isinstance(operand, Array)):
return operand
else:
- dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
return reshape_p.bind(
- operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
+ operand, new_sizes=tuple(new_sizes),
dimensions=None if dims is None or same_dims else dims,
sharding=out_sharding)
@@ -3451,12 +3392,10 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
"""Convenience wrapper around ``iota``."""
dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "broadcasted_iota")
shape = canonicalize_shape(shape)
- dynamic_shape = [d for d in shape if isinstance(d, core.Tracer)]
- static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
dimension = core.concrete_or_error(
int, dimension, "dimension argument of lax.broadcasted_iota")
out_sharding = canonicalize_sharding(out_sharding, 'broadcasted_iota')
- return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
+ return iota_p.bind(dtype=dtype, shape=shape,
dimension=dimension, sharding=out_sharding)
def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array:
@@ -3485,7 +3424,7 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
offset = asarray(core.dimension_as_value(offset))
- if not dtypes.issubdtype(offset, np.integer):
+ if not dtypes.issubdtype(offset.dtype, np.integer):
raise TypeError(f"offset must be an integer, got {offset!r}")
shape_dtype = lax_utils.int_dtype_for_shape(shape, signed=True)
if (
@@ -3546,17 +3485,12 @@ def stop_gradient(x: T) -> T:
the applicability of ``stop_gradient``.
"""
def stop(x):
- # only bind primitive on inexact dtypes, to avoid some staging
if dtypes.issubdtype(core.get_aval(x).dtype, dtypes.extended):
return x
- elif (dtypes.issubdtype(_dtype(x), np.floating) or
- dtypes.issubdtype(_dtype(x), np.complexfloating)):
- # break abstractions to support legacy leaked tracer use cases
- if isinstance(x, ad.JVPTracer):
- return stop(x.primal)
- return ad_util.stop_gradient_p.bind(x)
+ elif isinstance(x, ad.JVPTracer):
+ return stop(x.primal)
else:
- return x
+ return ad_util.stop_gradient_p.bind(x)
return tree_util.tree_map(stop, x)
def reduce_precision(operand: float | ArrayLike,
@@ -3652,7 +3586,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
# TODO(yashkatariya): Maybe use `shaped_abstractify` here instead of
# `typeof` because `x` can be anything that implements the
# `DuckTypedArray` protocol.
- val = core.pvary(val, tuple(core.typeof(x).vma))
+ val = core.pvary(val, tuple(typeof(x).vma))
return val
@@ -3919,7 +3853,6 @@ def _iter(tracer):
else:
return (slicing.index_in_dim(tracer, i, keepdims=False) for i in range(n))
ShapedArray._iter = staticmethod(_iter)
-core.DShapedArray._iter = staticmethod(_iter)
def zeros_like_array(x: ArrayLike) -> Array:
return full_like(x, 0)
@@ -3981,7 +3914,6 @@ def unop(result_dtype, accepted_dtypes, name, supports_narrow_ints=True):
vma_rule=_attrgetter('vma'),
reduced_rule=unop_reduced_rule)
batching.defvectorized(prim)
- pe.def_trivial_padding(prim)
return prim
standard_unop = partial(unop, _identity)
@@ -4106,7 +4038,6 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
vma_rule=partial(core.standard_vma_rule, name),
unreduced_rule=unreduced_rule, reduced_rule=nary_reduced_rule)
batching.defbroadcasting(prim)
- pe.def_trivial_padding(prim)
return prim
standard_naryop = partial(naryop, input_dtype)
@@ -4114,11 +4045,11 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
# Like autograd.numpy.numpy_vjps.unbroadcast, this utility handles transposition
# involving linear primitives with implicit broadcasting.
def _unbroadcast(aval, x):
- if not isinstance(aval, (core.DShapedArray, ShapedArray)):
+ if not isinstance(aval, ShapedArray):
raise TypeError("transpose with implicit broadcasting of unshaped values")
x_shape = np.shape(x)
if (core.definitely_equal_shape(aval.shape, x_shape) and
- aval.sharding == core.typeof(x).sharding):
+ aval.sharding == typeof(x).sharding):
return x
assert not aval.shape or len(x_shape) == len(aval.shape)
if not aval.shape:
@@ -4131,17 +4062,20 @@ def _unbroadcast(aval, x):
x = reduce_sum(x, dims) if dims else x
return reshape(x, aval.shape, out_sharding=aval.to_cotangent_aval().sharding)
-def _maybe_broadcast(target_shape, x):
+def _maybe_broadcast(target_shape, x, target_sharding):
x_shape = np.shape(x)
- if core.definitely_equal_shape(x_shape, target_shape):
+ x_sharding = typeof(x).sharding
+ if (core.definitely_equal_shape(x_shape, target_shape) and
+ x_sharding == target_sharding):
return x
elif not x_shape:
- return broadcast_in_dim(x, target_shape, ())
+ return broadcast_in_dim(x, target_shape, (), out_sharding=target_sharding)
else:
dims = [i for i, (a, b) in enumerate(zip(x_shape, target_shape))
if core.definitely_equal(a, b)]
squeeze_shape = [x_shape[i] for i in dims]
- return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims)
+ return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims,
+ out_sharding=target_sharding)
def broadcast_hlo(
aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray],
@@ -4247,7 +4181,6 @@ def _round_lower(ctx, x, *, rounding_method):
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans))
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
-batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule
core.pp_eqn_rules[exp_p] = _unary_with_accuracy_pp_rule
exp2_p = standard_unop(_float | _complex, 'exp2')
@@ -4358,7 +4291,6 @@ def _sin_lin(nzs, x, accuracy):
ad.primitive_linearizations[sin_p] = _sin_lin
mlir.register_lowering(sin_p, _sin_lowering)
core.pp_eqn_rules[sin_p] = _unary_with_accuracy_pp_rule
-batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule
def _cos_complex(x):
# cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x)))
@@ -4563,8 +4495,9 @@ def _pow_jvp_lhs(g, ans, x, y):
if dtypes.issubdtype(y_dtype, np.integer):
if x.shape != y.shape:
shape = broadcast_shapes(x.shape, y.shape)
- x = _maybe_broadcast(shape, x)
- y = _maybe_broadcast(shape, y)
+ sharding = broadcast_shardings(typeof(x), typeof(y))
+ x = _maybe_broadcast(shape, x, sharding)
+ y = _maybe_broadcast(shape, y, sharding)
jac = select(eq(y, _const(y, 0)), _zeros(y),
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
else:
@@ -4602,7 +4535,6 @@ def _integer_pow_jvp(g, x, *, y):
sharding_rule=_attrgetter('sharding'), vma_rule=_attrgetter('vma'))
batching.defvectorized(integer_pow_p)
ad.defjvp(integer_pow_p, _integer_pow_jvp)
-pe.def_trivial_padding(integer_pow_p)
def _integer_pow(x, *, y):
# This should be kept in sync with the jax2tf translation rule.
@@ -4672,9 +4604,11 @@ def _add_jvp(primals, tangents):
if type(xdot) is type(ydot) is ad_util.Zero:
return primal_out, ad_util.Zero.from_primal_value(primal_out)
if type(xdot) is ad_util.Zero:
- return primal_out, _maybe_broadcast(primal_out.shape, ydot)
+ return (primal_out, _maybe_broadcast(primal_out.shape, ydot,
+ typeof(primal_out).sharding))
elif type(ydot) is ad_util.Zero:
- return primal_out, _maybe_broadcast(primal_out.shape, xdot)
+ return (primal_out, _maybe_broadcast(primal_out.shape, xdot,
+ typeof(primal_out).sharding))
else:
return primal_out, add(xdot, ydot)
@@ -4717,7 +4651,6 @@ def _add_unreduced_rule(out_sharding, x, y):
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add))
-batching.ragged_prop_rules[add_p] = batching.ragged_mask_elementwise_rule
def _sub_jvp(primals, tangents):
x, y = primals
@@ -4726,9 +4659,11 @@ def _sub_jvp(primals, tangents):
if type(xdot) is type(ydot) is ad_util.Zero:
return primal_out, ad_util.Zero.from_primal_value(primal_out)
if type(xdot) is ad_util.Zero:
- return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot))
+ return (primal_out, _maybe_broadcast(primal_out.shape, neg(ydot),
+ typeof(primal_out).sharding))
elif type(ydot) is ad_util.Zero:
- return primal_out, _maybe_broadcast(primal_out.shape, xdot)
+ return (primal_out, _maybe_broadcast(primal_out.shape, xdot,
+ typeof(primal_out).sharding))
else:
return primal_out, sub(xdot, ydot)
@@ -4747,7 +4682,6 @@ def _sub_transpose(t, x, y):
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract))
-batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule
def _mul_unreduced_rule(out_sharding, x, y):
x_ur, y_ur = x.sharding.spec.unreduced, y.sharding.spec.unreduced
@@ -4785,7 +4719,6 @@ def _mul_unreduced_rule(out_sharding, x, y):
ad.defbilinear(mul_p, lambda ct, x, y: _unbroadcast(x.aval, mul(ct, y)),
lambda ct, x, y: _unbroadcast(y.aval, mul(x, ct)))
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply))
-batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule
def _div_transpose_rule(cotangent, x, y):
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@@ -4799,19 +4732,21 @@ def _div_transpose_rule(cotangent, x, y):
lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)))
ad.primitive_transposes[div_p] = _div_transpose_rule
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide))
-batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule
rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
ad.defjvp(
rem_p,
- lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g),
+ lambda g, x, y: _maybe_broadcast(
+ broadcast_shapes(np.shape(x), np.shape(y)), g,
+ broadcast_shardings(typeof(x), typeof(y))),
lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y))))))
mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.remainder))
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
result_shape = broadcast_shapes(np.shape(x), np.shape(y))
- x = _maybe_broadcast(result_shape, x)
- y = _maybe_broadcast(result_shape, y)
+ result_sharding = broadcast_shardings(typeof(x), typeof(y))
+ x = _maybe_broadcast(result_shape, x, result_sharding)
+ y = _maybe_broadcast(result_shape, y, result_sharding)
rx = real(x)
ry = real(y)
pick_x = select(eq(rx, ry), lax_cmp_pick_x(imag(x), imag(y)),
@@ -4823,14 +4758,12 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo))
-batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule
min_p: core.Primitive = standard_naryop([_any, _any], 'min')
ad.defjvp2(min_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
-batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)
@@ -4897,7 +4830,6 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True)
ad.defjvp_zero(eq_p)
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False))
-batching.ragged_prop_rules[eq_p] = batching.ragged_mask_elementwise_rule
ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True)
ad.defjvp_zero(ne_p)
@@ -4918,7 +4850,6 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt')
ad.defjvp_zero(lt_p)
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False))
-batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule
eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to')
ad.defjvp_zero(eq_to_p)
@@ -5071,11 +5002,7 @@ def _convert_element_type_batching_rule(
pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule
pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule
-pe.def_trivial_padding(convert_element_type_p)
core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule
-batching.ragged_prop_rules[convert_element_type_p] = (
- batching.ragged_mask_elementwise_rule
-)
def _real_dtype(dtype): return np.finfo(dtype).dtype
@@ -5098,7 +5025,7 @@ def _to_edtype_abstract_eval(x, *, edtype):
not isinstance(x.dtype, dtypes.ExtendedDType))
# For backward compatibility, if the edtype rules have a `convert_to` method,
# use that rather than looking for an `allow_conversion: bool` attribute.
- if not isinstance(x, (ShapedArray, core.DShapedArray)):
+ if not isinstance(x, ShapedArray):
raise TypeError("can only convert to an extended dtype on an array type,"
f"but got {type(x)}")
if convert_to := getattr(edtype._rules, 'convert_to', None):
@@ -5143,8 +5070,6 @@ def _to_edtype_abstract_eval(x, *, edtype):
f"shape {rep_aval.shape}")
return x.update(shape=shape_prefix, dtype=edtype,
sharding=x.sharding.update(spec=spec_prefix))
- elif isinstance(x, core.DShapedArray):
- return x.update(shape=shape_prefix, dtype=edtype)
else:
assert False # unreachable, see isinstance check above
@@ -5163,7 +5088,7 @@ def _to_edtype_abstract_eval(x, *, edtype):
def _from_edtype_abstract_eval(x, *, dtype):
assert (isinstance(x.dtype, dtypes.ExtendedDType) and
not isinstance(dtype, dtypes.ExtendedDType))
- if not isinstance(x, (ShapedArray, core.DShapedArray)):
+ if not isinstance(x, ShapedArray):
raise TypeError("can only convert from an extended dtype on an array type,"
f"but got {type(x)}")
if convert_from := getattr(x.dtype._rules, 'convert_from', None):
@@ -5184,11 +5109,6 @@ def _from_edtype_abstract_eval(x, *, dtype):
f"{dtype_to_string(rep_aval.dtype)}.")
if isinstance(x, ShapedArray):
return x.update(shape=(*x.shape, *rep_aval.shape), dtype=dtype)
- elif isinstance(x, core.DShapedArray):
- if all(isinstance(d, int) for d in x.shape):
- return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype)
- else:
- raise NotImplementedError
else:
assert False # unreachable, see isinstance check above
@@ -5566,30 +5486,14 @@ def _dot_batch_rule(
lhs, rhs = unpack_args(batched_args)
lbd, rbd = unpack_dims(batch_dims)
- left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
- right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums(
- (np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim),
+ (np.ndim(lhs), np.ndim(rhs)), (lbd, rbd),
dimension_numbers)
- # TODO Should probably check that any ragged dimensions have corresponding
- # sizes, because otherwise the dot product is technically undefined.
- #
- # This masking is not strictly necessary for non-contraction dimensions;
- # we could micro-optimize here by avoiding computing that mask.
- if type(lbd) is RaggedAxis:
- lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd)
- lhs_shape = batching.bdim_as_shape(lbd, lhs.shape)
- else:
- lhs_shape = np.shape(lhs)
- if type(rbd) is RaggedAxis:
- rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd)
- rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
- else:
- rhs_shape = np.shape(rhs)
- result_batch_dim = batching.shape_as_bdim(
- result_stack_dim,
- _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
+ lhs_shape = np.shape(lhs)
+ rhs_shape = np.shape(rhs)
+ result_shape = _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers)
+ result_batch_dim = canonicalize_axis(result_stack_dim, len(result_shape))
if out_sharding is not None:
out_sharding = batching.get_sharding_for_vmap(
@@ -5677,15 +5581,6 @@ def bump_dims(dims, b):
)
return new_dimension_numbers, result_batch_dim
-def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
- dimension_numbers, **params):
- lhs_aval, _ = in_avals
- (lhs_contract, _), _ = dimension_numbers
- padded_axes = [(i, lhs_aval.shape[i].val) for i in lhs_contract
- if isinstance(lhs_aval.shape[i], pe.BoundedAxisSize)]
- lhs_ = _replace_masked_values(lhs, 0, padded_axes)
- return [dot_general(lhs_, rhs, dimension_numbers=dimension_numbers, **params)]
-
def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
# * suppress printing precision or preferred_element_type when None.
# * print dimension_numbers as list-of-lists to be shorter.
@@ -5696,59 +5591,6 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
-def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
- assert len(invar_raggedness) == 2
- assert len(outvars) == 1
- invar_raggedness_lhs = invar_raggedness[0]
- invar_raggedness_rhs = invar_raggedness[1]
-
- dimension_numbers = eqn_params['dimension_numbers']
- (lhs_contracting, rhs_contracting), (_, _) = dimension_numbers
-
- if not invar_raggedness_lhs and not invar_raggedness_rhs:
- # Both are dense - it is valid to reach here, because dense operations
- # are legal in code running under ragged prop.
- return invar_raggedness, [None]
-
- if not invar_raggedness_lhs or not invar_raggedness_rhs:
- # One ragged, one dense
- if not invar_raggedness_lhs:
- # left is dense, right is ragged
- _, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs
- if rhs_contracting != ragged_axis_dim_rhs:
- # Contraction is on a dense dimension, this is valid!
- return invar_raggedness, [None]
- if not invar_raggedness_rhs:
- # left is ragged, right is dense
- _, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
- if lhs_contracting != ragged_axis_dim_lhs:
- # Contraction is on a dense dimension, this is valid!
- return invar_raggedness, [None]
-
- raise NotImplementedError('NYI - dense and ragged dim contraction')
-
- stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs
- stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs
-
- if stacked_axis_rhs != 0 or stacked_axis_lhs != 0:
- raise NotImplementedError(
- 'Dot general ragged prop for non 0 stacked axis, NYI'
- )
-
- # We only support ragged k atm, that is, lhs is (m, ragged_k) and rhs is
- # (ragged_k, n), meaning the output is dense.
- if ragged_axis_dim_lhs != 2 or ragged_axis_dim_rhs != 1:
- raise NotImplementedError(
- 'Dot general ragged prop for non contraction raggedness, NYI'
- )
-
- assert len(outvars) == 1
-
- # TODO(mvoz): A constant on batching.* ?
- # Dense (m, n) - no jumble only atm
- return invar_raggedness, [None]
-
-
dot_general_p = standard_primitive(
_dot_general_shape_rule,
_dot_general_dtype_rule,
@@ -5780,9 +5622,7 @@ def _dot_general_batch_unpack_dims(batch_dims):
)
batching.fancy_primitive_batchers[dot_general_p] = _dot_general_batch_rule
batching.skippable_batchers[dot_general_p] = lambda _: ()
-pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
-batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule
def _full_precision(precision: Precision) -> tuple[Precision, Precision]:
@@ -6582,96 +6422,52 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions,
orig_spec = iter(operand.sharding.spec)
new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))]
assert next(orig_spec, None) is None
+ mesh = (get_abstract_mesh() if operand.sharding.mesh.empty else
+ operand.sharding.mesh)
return operand.sharding.update(
- spec=operand.sharding.spec.update(partitions=new_spec))
+ mesh=mesh, spec=operand.sharding.spec.update(partitions=new_spec))
def _broadcast_in_dim_typecheck_rule(
- _, operand, *dyn_shape, shape, broadcast_dimensions, sharding):
- if not dyn_shape:
- out_aval, effects = broadcast_in_dim_p.abstract_eval(
- operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions,
- sharding=sharding)
- return [out_aval], effects
- else:
- # TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule
- out_shape = _merge_dyn_shape(shape, dyn_shape)
- out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
- out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
- operand.aval.weak_type)
- return [out_aval], core.no_effects
-
-def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
+ _, operand, shape, broadcast_dimensions, sharding):
+ out_aval, effects = broadcast_in_dim_p.abstract_eval(
+ operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions,
+ sharding=sharding)
+ return [out_aval], effects
+
+def _broadcast_in_dim_transpose_rule(ct, operand,
shape, broadcast_dimensions, sharding):
if type(ct) is ad_util.Zero:
return [ad_util.Zero(operand.aval)]
if not isinstance(operand, ad.UndefinedPrimal):
- return [None] * (1 + len(dyn_shape)) # transpose wrt literal
+ return [None] # transpose wrt literal
unit_dims = [i for i, s in enumerate(operand.aval.shape)
if core.definitely_equal(s, 1)]
bdims = tuple(np.delete(broadcast_dimensions, unit_dims))
axes = tuple(np.delete(range(len(shape)), bdims))
- return ([expand_dims(reduce_sum(ct, axes), unit_dims)] +
- [None] * len(dyn_shape))
+ return [expand_dims(reduce_sum(ct, axes), unit_dims)]
def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape,
broadcast_dimensions, sharding):
- # `dyn_shape` is the dynamic portion of the target shape. `shape`
- # is the target shape, with `None` for dynamic sections.
- # broadcast_dimensions gives indices where dimensions of the input
- # have to go: dimension i of the input becomes dimension
- # broadcast_dimensions[i] of the output.
- operand, *dyn_shape = batched_args
- operand_bdim, *dyn_shape_bdims = batch_dims
-
- stacked_size = None
- if operand_bdim is not None:
- if isinstance(operand_bdim, RaggedAxis):
- stacked_axis = operand_bdim.stacked_axis
- stacked_size = operand_bdim.size
- else:
- stacked_axis = operand_bdim
- stacked_size = operand.shape[stacked_axis]
- new_operand = batching.moveaxis(operand, stacked_axis, 0)
- new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
- else:
- new_operand = operand
- new_broadcast_dimensions = tuple(np.add(1, broadcast_dimensions))
-
- # TODO(mattjj,axch) This section assumes that the shape of the operand is
- # broadcast-compatible with the requested shape. We should tweak vmap to run
- # the abstract_eval rule so this can be checked while the raggedness
- # information is available.
- dyn_limits = []
- out_ragged_sizes = []
- for sizes, bdim in zip(dyn_shape, dyn_shape_bdims):
- if bdim is None:
- # TODO(mattjj,axch) Is this what bdim == None means?
- assert isinstance(sizes, int)
- bound = sizes
- else:
- bound = sizes.dtype.bound
- out_ragged_sizes.append(sizes)
- if stacked_size is None:
- stacked_size = len(sizes)
- else:
- msg = "All segments lengths arrays must be the same length"
- assert len(sizes) == stacked_size, msg
- dyn_limits.append(bound)
- new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits)
+ # `shape` is the target shape. broadcast_dimensions gives indices where
+ # dimensions of the input have to go: dimension i of the input becomes
+ # dimension broadcast_dimensions[i] of the output.
+ operand, = batched_args
+ operand_bdim, = batch_dims
+ assert operand_bdim is not None
+ new_operand = batching.moveaxis(operand, operand_bdim, 0)
+ new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
+ new_shape = (operand.shape[operand_bdim],) + shape
if sharding is not None:
sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)
result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions,
out_sharding=sharding)
- out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
- out_bdim = batching.make_batch_axis(
- result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes))
- return result, out_bdim
+ return result, 0
def _broadcast_in_dim_fwd_rule(eqn):
- v, *dyn = eqn.invars
- if (not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape)
+ v, = eqn.invars
+ if (core.definitely_equal_shape(eqn.params['shape'], v.aval.shape)
and (eqn.params['sharding'] is None or
eqn.params['sharding'] == v.aval.sharding)):
return [0], None
@@ -6679,103 +6475,51 @@ def _broadcast_in_dim_fwd_rule(eqn):
return [None], eqn
def _broadcast_in_dim_staging_rule(
- trace, source_info, x, *dyn, shape, broadcast_dimensions, sharding):
+ trace, source_info, x, shape, broadcast_dimensions, sharding):
params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
sharding=sharding)
- if not dyn:
- return trace.default_process_primitive(broadcast_in_dim_p, (x,), params,
- source_info=source_info)
- aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type)
- return _dyn_shape_staging_rule(trace, source_info, broadcast_in_dim_p, aval,
- x, *dyn, **params)
-
-def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
- shape, broadcast_dimensions):
- del in_avals, dyn_shape
- out_aval, = out_avals
- new_shape = []
- new_dyn_shape = []
- for d in out_aval.shape:
- if type(d) is pe.BoundedAxisSize:
- new_shape.append(d.bound)
- elif type(d) is int:
- new_shape.append(d)
- else:
- assert isinstance(d, core.Tracer)
- new_shape.append(None)
- new_dyn_shape.append(d)
- return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape),
- broadcast_dimensions=broadcast_dimensions)]
+ return trace.default_process_primitive(broadcast_in_dim_p, (x,), params,
+ source_info=source_info)
def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions,
sharding):
- operand, *dyn_shape = primals
+ operand, = primals
operand_dot, *_ = tangents
- y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape,
+ y = broadcast_in_dim_p.bind(operand, shape=shape,
broadcast_dimensions=broadcast_dimensions,
sharding=sharding)
if type(operand_dot) is ad_util.Zero:
y_dot = ad_util.Zero.from_primal_value(y)
else:
- y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape,
+ y_dot = broadcast_in_dim_p.bind(operand_dot, shape=shape,
broadcast_dimensions=broadcast_dimensions,
sharding=sharding)
return y, y_dot
def _broadcast_in_dim_partial_eval(
- trace, operand, *dyn_shape, shape, broadcast_dimensions, sharding):
- if not dyn_shape:
- return trace.default_process_primitive(
- broadcast_in_dim_p, (operand, *dyn_shape),
- dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
- sharding=sharding))
- assert all(t.pval.is_known() for t in dyn_shape)
- operand_tracer = trace.instantiate_const(operand)
- dyn_shape_tracers = map(trace.instantiate_const, dyn_shape)
- dyn_shape_tracers_ = iter(dyn_shape_tracers)
- shape_ = [next(dyn_shape_tracers_) if d is None else d for d in shape]
- out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type)
- out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
- eqn = pe.new_eqn_recipe(
- trace, [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p,
+ trace, operand, shape, broadcast_dimensions, sharding):
+ return trace.default_process_primitive(
+ broadcast_in_dim_p, (operand,),
dict(shape=shape, broadcast_dimensions=broadcast_dimensions,
- sharding=None),
- core.no_effects, source_info_util.current())
- out_tracer.recipe = eqn
- return out_tracer
+ sharding=sharding))
-def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
+def _broadcast_in_dim_lower(ctx, x, shape, broadcast_dimensions,
sharding) -> Sequence[ir.Value]:
aval_out, = ctx.avals_out
- if dyn_shape:
- aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
out = mlir.broadcast_in_dim(ctx, x, aval_out,
broadcast_dimensions=broadcast_dimensions)
return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
-def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
+def _broadcast_in_dim_abstract_eval(x, shape, broadcast_dimensions,
sharding):
- if (not dyn_shape and
- not any(isinstance(d, core.DArray) and
- type(core.get_aval(d).dtype) is core.bint for d in shape)):
- shape = _broadcast_in_dim_shape_rule( # error checking
- x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None)
- new_sharding = _broadcast_in_dim_sharding_rule(
- x, shape=shape, broadcast_dimensions=broadcast_dimensions,
- sharding=sharding)
- new_vma = core.standard_vma_rule('broadcast_in_dim', x)
- return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding,
- vma=new_vma, memory_space=x.memory_space)
- # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
- # (even if x is a ShapedArray)
- # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
- return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type)
-
-
-def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
- assert len(invar_raggedness) == 1
- assert not isinstance(invar_raggedness[0], core.Var)
- return invar_raggedness, [None] * len(outvars)
+ shape = _broadcast_in_dim_shape_rule( # error checking
+ x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None)
+ new_sharding = _broadcast_in_dim_sharding_rule(
+ x, shape=shape, broadcast_dimensions=broadcast_dimensions,
+ sharding=sharding)
+ new_vma = core.standard_vma_rule('broadcast_in_dim', x)
+ return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding,
+ vma=new_vma, memory_space=x.memory_space)
broadcast_in_dim_p = core.Primitive('broadcast_in_dim')
@@ -6788,12 +6532,8 @@ def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
pe.custom_partial_eval_rules[broadcast_in_dim_p] = _broadcast_in_dim_partial_eval
pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
-pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule
mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)
-batching.ragged_prop_rules[broadcast_in_dim_p] = (
- _broadcast_in_dim_ragged_prop_rule
-)
def _clamp_shape_rule(min, operand, max):
@@ -6862,7 +6602,6 @@ def _clamp_batch_rule(batched_args, batch_dims, **params):
select(lt(max, operand), g, _zeros(operand)))
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
mlir.register_lowering(clamp_p, partial(_nary_lower_hlo, hlo.clamp))
-pe.def_trivial_padding(clamp_p)
def _concatenate_shape_rule(*operands, **kwargs):
dimension = kwargs.pop('dimension')
@@ -6953,7 +6692,6 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension):
ad.deflinear2(concatenate_p, _concatenate_transpose_rule)
ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
-pe.padding_rules[concatenate_p] = _concatenate_pad_rule
def _concatenate_lower(ctx, *xs, dimension):
aval_out, = ctx.avals_out
@@ -6987,10 +6725,8 @@ def _split_transpose_rule(cotangents, operand, *, sizes, axis):
assert ad.is_undefined_primal(operand)
if all(type(t) is ad_util.Zero for t in cotangents):
return ad_util.Zero(operand.aval),
- cotangents = [
- _zeros(t.aval) if type(t) is ad_util.Zero else t
- for t in cotangents
- ]
+ cotangents = [t.instantiate() if type(t) is ad_util.Zero else t
+ for t in cotangents]
return concatenate(cotangents, dimension=axis),
def _split_batch_rule(batched_args, batch_dims, *, sizes, axis):
@@ -7173,12 +6909,11 @@ def _squeeze_transpose_rule(t, operand, *, dimensions):
def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
operand, = batched_args
bdim, = batch_dims
- operand, bdim = batching.move_stacked_axis(operand, bdim, 0)
+ operand = batching.moveaxis(operand, bdim, 0)
dimensions = tuple(np.add(1, dimensions))
- out_stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim
- bdim_out = batching.shape_as_bdim(
- out_stack_dim,
- _compute_squeeze_shape(batching.bdim_as_shape(bdim, operand.shape), dimensions))
+
+ result_shape = _compute_squeeze_shape(operand.shape, dimensions)
+ bdim_out = canonicalize_axis(0, len(result_shape))
return squeeze(operand, dimensions=dimensions), bdim_out
squeeze_p = standard_primitive(
@@ -7188,8 +6923,6 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions):
reduced_rule=_squeeze_reduced_rule)
ad.deflinear2(squeeze_p, _squeeze_transpose_rule)
batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
-pe.def_trivial_padding(squeeze_p)
-batching.ragged_prop_rules[squeeze_p] = batching.ragged_mask_no_op_rule
def _squeeze_lower(ctx, operand, *, dimensions):
del dimensions # Implied by the output aval.
@@ -7221,12 +6954,6 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding):
# TODO(necula): re-enable this check
operand_size = math.prod(np.shape(operand))
new_size = math.prod(new_sizes)
- if (not config.dynamic_shapes.value and
- not operand_size == new_size):
- msg = (f"reshape total size must be unchanged, got new_sizes {new_sizes} "
- f"(of total size {new_size}) for shape {np.shape(operand)} "
- f"(of total size {operand_size}).")
- raise TypeError(msg)
if dimensions is not None:
if set(dimensions) != set(range(np.ndim(operand))):
msg = ('reshape dimensions must be a permutation of operand dimensions, '
@@ -7370,20 +7097,12 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions):
return operand.sharding.update(spec=new_spec)
-def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions,
+def _reshape_typecheck_rule(_, operand, new_sizes, dimensions,
sharding):
- if not dyn_shape:
- out_aval, effects = reshape_p.abstract_eval(
- operand.aval, new_sizes=new_sizes, dimensions=dimensions,
- sharding=sharding)
- return [out_aval], effects
- else:
- # TODO(mattjj, necula): perform more checks like _reshape_shape_rule
- out_shape = _merge_dyn_shape(new_sizes, dyn_shape)
- out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
- out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
- operand.aval.weak_type)
- return [out_aval], core.no_effects
+ out_aval, effects = reshape_p.abstract_eval(
+ operand.aval, new_sizes=new_sizes, dimensions=dimensions,
+ sharding=sharding)
+ return [out_aval], effects
def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding):
@@ -7417,24 +7136,18 @@ def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes,
return out, 0
-def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
+def _reshape_lower(ctx, x, new_sizes, dimensions, sharding):
aval_out, = ctx.avals_out
if dimensions is not None:
x = hlo.transpose(x, mlir.dense_int_array(dimensions))
- if dyn_shape:
- aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
out = mlir.reshape(ctx, x, aval_out)
return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
def _reshape_staging_rule(
- trace, source_info, x, *dyn, new_sizes, dimensions, sharding):
+ trace, source_info, x, new_sizes, dimensions, sharding):
params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding)
- if not dyn:
- return trace.default_process_primitive(reshape_p, (x,), params,
- source_info=source_info)
- av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type)
- return _dyn_shape_staging_rule(trace, source_info, reshape_p, av, x, *dyn,
- **params)
+ return trace.default_process_primitive(reshape_p, (x,), params,
+ source_info=source_info)
reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule,
'reshape', sharding_rule=_reshape_sharding_rule,
@@ -7508,12 +7221,8 @@ def _transpose_reduced_rule(out_s, operand, *, permutation):
def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
operand, = batched_args
bdim, = batch_dims
- stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim
- perm = (stack_dim,) + tuple(i if i < stack_dim else i+1 for i in permutation)
- if isinstance(bdim, RaggedAxis):
- res_bdim = batching.transpose_ragged_axes(bdim.move_stacked_axis(0), perm)
- else:
- res_bdim = 0
+ perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation)
+ res_bdim = 0
return transpose(operand, perm), res_bdim
def _transpose_lower(ctx, x, *, permutation):
@@ -7535,7 +7244,6 @@ def _transpose_lower(ctx, x, *, permutation):
lambda t, _, permutation: [transpose(t, np.argsort(permutation))])
batching.primitive_batchers[transpose_p] = _transpose_batch_rule
mlir.register_lowering(transpose_p, _transpose_lower)
-pe.def_trivial_padding(transpose_p)
def _select_shape_rule(which, *cases):
@@ -7611,7 +7319,7 @@ def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs):
# vmapped function had a scalar which with nonscalar args
assert np.ndim(which) == 1
which = broadcast_in_dim(which, cases[0].shape, [which_bdim],
- out_sharding=core.typeof(cases[0]).sharding)
+ out_sharding=typeof(cases[0]).sharding)
return select_n(which, *cases), which_bdim
elif np.ndim(which) == 0 and all(bdim is not None for bdim in case_bdims):
if all(case_bdims[0] == bdim for bdim in case_bdims[1:]):
@@ -7633,7 +7341,7 @@ def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs):
# vmapped function had a scalar which with nonscalar args
assert np.ndim(which) == 1
which = broadcast_in_dim(which, cases[0].shape, [0],
- out_sharding=core.typeof(cases[0]).sharding)
+ out_sharding=typeof(cases[0]).sharding)
if np.ndim(which) > np.ndim(cases[0]):
assert np.ndim(cases[0]) == 0
cases = [broadcast(c, which.shape) for c in cases]
@@ -7720,7 +7428,6 @@ def _select(offset, cases):
batching.fancy_primitive_batchers[select_n_p] = _select_batch_rule
batching.skippable_batchers[select_n_p] = lambda _: ()
mlir.register_lowering(select_n_p, _select_hlo_lowering)
-pe.def_trivial_padding(select_n_p)
def _reduce_shape_rule(*avals, computation, jaxpr, dimensions):
@@ -7934,9 +7641,6 @@ def _reduce_sum_reduced_rule(out_s, operand, *, axes, **kwargs):
reduced_rule=_reduce_sum_reduced_rule)
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p, _get_sum_identity)
-pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum,
- _get_sum_identity)
-batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule
def _reduce_prod_jvp_rule(primals, tangents, *, axes):
reducer = lambda x, y: [mul(x, y)]
@@ -7956,9 +7660,6 @@ def _reduce_op_sharding_rule(operand, *, axes):
vma_rule=partial(core.standard_vma_rule, 'reduce_prod'))
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
batching.defreducer(reduce_prod_p, _get_prod_identity)
-pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod,
- _get_prod_identity)
-
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
# TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
@@ -7977,9 +7678,6 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
vma_rule=partial(core.standard_vma_rule, 'reduce_max'))
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p, _get_max_identity)
-pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max,
- _get_max_identity)
-batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule
reduce_min_p = standard_primitive(
@@ -7988,9 +7686,6 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
vma_rule=partial(core.standard_vma_rule, 'reduce_min'))
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p, _get_min_identity)
-pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min,
- _get_min_identity)
-
def _argminmax_shape_rule(operand, *, axes, index_dtype):
axis, = axes
@@ -8093,7 +7788,7 @@ def _reduce_logical_sharding_rule(operand, *, axes):
def _reduce_or_lin(nzs, x, *, axes):
nz, = nzs
y = reduce_or_p.bind(x, axes=axes)
- aval = core.typeof(y).to_tangent_aval()
+ aval = typeof(y).to_tangent_aval()
return y, False, (), lambda _, t: ad_util.Zero(aval)
reduce_or_p = standard_primitive(
@@ -8109,7 +7804,6 @@ def _reduce_or_lin(nzs, x, *, axes):
weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule,
vma_rule=partial(core.standard_vma_rule, 'reduce_and'))
batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
-batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule
reduce_xor_p = standard_primitive(
@@ -8189,12 +7883,19 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
}
-def _sort_abstract_eval(*args, **kwargs):
- args = tuple(args)
- if any(arg.shape != args[0].shape for arg in args[1:]):
- shapes = " ".join(str(a.shape) for a in args)
+def _sort_abstract_eval(*avals, **kwargs):
+ avals = tuple(avals)
+  if any(aval.shape != avals[0].shape for aval in avals[1:]):
+ shapes = " ".join(str(a.shape) for a in avals)
raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}")
- return args
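+  # Only compare shardings that are actually constrained: avals whose mesh is
+  # empty or has no explicit axes are skipped below.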
+ non_empty_s = [
+ a.sharding for a in avals
+ if not a.sharding.mesh.empty and a.sharding.mesh._any_axis_explicit]
+ if any(s != non_empty_s[0] for s in non_empty_s[1:]):
+ shardings = " ".join(str(s) for s in non_empty_s)
+ raise core.ShardingTypeError(
+ f'Arguments to sort must have equal shardings, got: {shardings}')
+ return avals
def _canonicalize_float_for_sort(x):
@@ -8292,7 +7993,9 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys
for arg, bdim in zip(batched_args, batch_dims):
if bdim is None:
dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
- new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims))
+ new_args.append(broadcast_in_dim(
+ arg, prototype_arg.shape, dims,
+ out_sharding=typeof(prototype_arg).sharding))
else:
new_args.append(batching.moveaxis(arg, bdim, new_bdim))
new_dimension = dimension + (new_bdim <= dimension)
@@ -8441,7 +8144,6 @@ def _stop_gradient_batch_rule(batched_args, batch_dims):
ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule
-pe.def_trivial_padding(ad_util.stop_gradient_p)
def create_token(_=None):
@@ -8694,7 +8396,6 @@ def _copy_impl(prim, *args, **kwargs):
copy_p.def_abstract_eval(lambda x: x)
mlir.register_lowering(copy_p, lambda ctx, x: [x])
ad.deflinear(copy_p, lambda t: [copy_p.bind(t)])
-pe.def_trivial_padding(copy_p)
batching.defvectorized(copy_p)
# The dce_sink_p primitive marks a value as "used" from the perspective of DCE
@@ -8714,7 +8415,6 @@ class NoDCEEffect(effects.Effect):
dce_sink_p.def_effectful_abstract_eval(lambda _: ([], {no_dce_effect}))
mlir.register_lowering(dce_sink_p, lambda ctx, _: [])
ad.deflinear(dce_sink_p, lambda _: [])
-pe.def_trivial_padding(dce_sink_p)
batching.primitive_batchers[dce_sink_p] = lambda x, bd: (x, bd)
def rng_bit_generator(key, shape, dtype=np.uint32,
@@ -8745,10 +8445,9 @@ def rng_bit_generator(key, shape, dtype=np.uint32,
out_sharding=out_sharding))
-def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding):
- if not dyn_shape:
- # TODO(mattjj) Generalize shape_like checking to permit dynamic shapes
- _check_shapelike("iota", "shape", shape)
+def _iota_abstract_eval(dtype, shape, dimension, sharding):
+ # TODO(mattjj) Generalize shape_like checking to permit dynamic shapes
+ _check_shapelike("iota", "shape", shape)
if not any(dtypes.issubdtype(dtype, t) for t in _num):
msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
typename = dtype_to_string(dtype)
@@ -8757,88 +8456,35 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding):
if not 0 <= dimension < len(shape):
raise ValueError("iota dimension must be between 0 and len(shape), got "
f"{dimension=} for {shape=}")
- if (not dyn_shape and
- not any(isinstance(d, core.DArray) and
- type(core.get_aval(d).dtype) is core.bint for d in shape)):
- if sharding is None:
- sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape)))
- return ShapedArray(shape, dtype, sharding=sharding)
- # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
- return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)
-
+ if sharding is None:
+ sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape)))
+ return ShapedArray(shape, dtype, sharding=sharding)
iota_p = Primitive('iota')
iota_p.def_impl(partial(dispatch.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
-batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule
-def _iota_staging_rule(trace, source_info, *dyn_shape, dtype, shape, dimension,
+def _iota_staging_rule(trace, source_info, dtype, shape, dimension,
sharding):
params = dict(dtype=dtype, shape=shape, dimension=dimension,
sharding=sharding)
- if not dyn_shape:
- return trace.default_process_primitive(iota_p, (), params,
+ return trace.default_process_primitive(iota_p, (), params,
source_info=source_info)
- aval = core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False)
- return _dyn_shape_staging_rule(trace, source_info, iota_p, aval, *dyn_shape,
- **params)
pe.custom_staging_rules[iota_p] = _iota_staging_rule
-def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension, sharding):
- if not dyn_shape:
- out_aval, effects = iota_p.abstract_eval(
- dtype=dtype, shape=shape, dimension=dimension, sharding=sharding)
- return [out_aval], effects
- else:
- out_shape = _merge_dyn_shape(shape, dyn_shape)
- out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error
- out_aval = core.DShapedArray(tuple(out_shape), dtype, False)
- return [out_aval], core.no_effects
+def _iota_typecheck_rule(_, dtype, shape, dimension, sharding):
+ out_aval, effects = iota_p.abstract_eval(
+ dtype=dtype, shape=shape, dimension=dimension, sharding=sharding)
+ return [out_aval], effects
core.custom_typechecks[iota_p] = _iota_typecheck_rule
-def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding):
+def _iota_lower(ctx, dtype, shape, dimension, sharding):
del dtype
aval_out, = ctx.avals_out
- if dyn_shape:
- aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
out = mlir.iota(ctx, aval_out, dimension=dimension)
return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
mlir.register_lowering(iota_p, _iota_lower)
-def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension,
- sharding):
- (segment_lengths,), (ax,) = in_vals, in_dims
- assert ax == 0
- bound = segment_lengths.dtype.bound
- ragged_axis, = (i for i, dim in enumerate(shape) if dim is None)
- shape = (len(segment_lengths),) + _merge_dyn_shape(shape, (bound,))
- if sharding is not None:
- raise NotImplementedError('Please file an issue if you want this support')
- iota = broadcasted_iota(dtype, shape, dimension+1)
- return iota, batching.RaggedAxis(ax, ((ragged_axis+1, segment_lengths),))
-batching.primitive_batchers[iota_p] = _iota_batching_rule
-
-def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension,
- sharding):
- out_aval, = out_avals
- new_shape = []
- new_dyn_shape = []
- for d in out_aval.shape:
- if type(d) is pe.BoundedAxisSize:
- new_shape.append(d.bound)
- elif type(d) is int:
- new_shape.append(d)
- else:
- assert isinstance(d, core.Tracer)
- new_shape.append(None)
- new_dyn_shape.append(d)
- if sharding is not None:
- raise NotImplementedError('Please file an issue if you want this support')
- return [iota_p.bind(*new_dyn_shape, shape=tuple(new_shape),
- dtype=dtype, dimension=dimension, sharding=sharding)]
-pe.padding_rules[iota_p] = _iota_padding_rule
-
-
### util
_ndim = np.ndim
@@ -8981,9 +8627,6 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
# bool(obj) for an ndarray raises an error, so we check len
if not len(obj): # pylint: disable=g-explicit-length-test
return
- if (config.dynamic_shapes.value and isinstance(obj, (tuple, list)) and
- any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)):
- return # TODO(mattjj): handle more checks in the dynamic shape case
obj_arr = np.array(obj)
if obj_arr.ndim != 1:
msg = "{} {} must be 1-dimensional, got {}."
@@ -9214,43 +8857,9 @@ def _empty2_lower(ctx, *, dtype, memory_space):
ad.primitive_jvps[tie_p] = \
lambda primals, tangents: (tie_p.bind(*primals), tangents[-1])
ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct]
-pe.def_trivial_padding(tie_p)
batching.defvectorized(tie_p)
-class BIntRules:
- allow_conversion: bool = True
-
- @staticmethod
- def physical_element_aval(dtype) -> core.ShapedArray:
- return core.ShapedArray((), np.dtype('int32'))
-
- @staticmethod
- def result_handler(sticky_device, aval):
- def handler(_, buf):
- buf.aval = core.ShapedArray(buf.shape, buf.dtype)
- return core.DArray(aval, buf)
- return handler
-
- @staticmethod
- def global_sharded_result_handler(aval, out_sharding, committed):
- phys_aval = core.physical_aval(aval)
- phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
-
- if not dispatch.is_single_device_sharding(out_sharding):
- raise NotImplementedError # TODO(mattjj)
- else:
- phys_sharding = out_sharding
- phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
-
- def handler(bufs):
- return core.DArray(aval, phys_handler(bufs))
- return handler
-
-
-core.bint._rules = BIntRules
-
-
def optimization_barrier(operand, /):
"""Prevents the compiler from moving operations across the barrier.
diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py
index f67f64a40133..b3d54064f9b4 100644
--- a/jax/_src/lax/other.py
+++ b/jax/_src/lax/other.py
@@ -284,8 +284,9 @@ def _logaddexp_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
primal_out = logaddexp(x1, x2)
- tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
- lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
+ tangent_out = lax.add(
+ lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
+ lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
return primal_out, tangent_out
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index f1c916e39923..b06e543ae481 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -2830,6 +2830,8 @@ def _get_from(aval, axes: tuple[AxisName, ...], name) -> str:
_allowed_pcast_to = {'unreduced', 'reduced', 'varying'}
def pcast(x, axis_name, *, to: str):
+ if isinstance(axis_name, (set, frozenset)):
+ raise TypeError(f"{axis_name=} must be a tuple or a str. Got {axis_name}")
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
if not axis_name:
return x
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index efcb930012bb..9ff745e922ce 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -26,7 +26,6 @@
from jax._src import ad_util
from jax._src import api
-from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
@@ -174,15 +173,9 @@ def dynamic_slice(
"""
start_indices = _dynamic_slice_indices(
operand, start_indices, allow_negative_indices)
- if config.dynamic_shapes.value:
- dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes)
- else:
- dynamic_sizes = []
- static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore
- operand, *start_indices = core.standard_insert_pvary(
- operand, *start_indices)
- return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
- slice_sizes=tuple(static_sizes))
+ sizes = core.canonicalize_shape(slice_sizes) # type: ignore
+ operand, *start_indices = core.standard_insert_pvary(operand, *start_indices)
+ return dynamic_slice_p.bind(operand, *start_indices, slice_sizes=tuple(sizes))
def dynamic_update_slice(
@@ -1369,11 +1362,10 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
msg = ("slice start_indices must be greater than or equal to zero, "
"got start_indices of {}.")
raise TypeError(msg.format(start_indices))
- if not config.dynamic_shapes.value:
- if not all(map(operator.ge, limit_indices, start_indices)):
- msg = ("slice limit_indices must be greater than or equal to start_indices,"
- " got start_indices {} and limit_indices {}.")
- raise TypeError(msg.format(start_indices, limit_indices))
+ if not all(map(operator.ge, limit_indices, start_indices)):
+ msg = ("slice limit_indices must be greater than or equal to start_indices,"
+ " got start_indices {} and limit_indices {}.")
+ raise TypeError(msg.format(start_indices, limit_indices))
diff = tuple(map(operator.sub, limit_indices, start_indices))
if strides is None or tuple(strides) == (1,) * len(operand.shape):
return diff
@@ -1485,9 +1477,6 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices,
ad.deflinear2(slice_p, _slice_transpose_rule)
ad.fancy_transposes[slice_p] = _slice_transpose_fancy
batching.primitive_batchers[slice_p] = _slice_batching_rule
-# TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries
-# or supporting nested jumbles. NYI.
-batching.ragged_prop_rules[slice_p] = batching.ragged_mask_no_op_rule
# Override the standard impl to defer to dynamic_slice whenever possible.
# This lets us reuse the same program for many applications of slicing for as
@@ -1514,28 +1503,19 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
mlir.register_lowering(slice_p, _slice_lower)
-def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes):
- start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim])
- if operand.ndim != len(start_indices):
- msg = ("dynamic_slice start_indices must have length equal to the number "
- "of dimensions of the operand, got indices {} for operand shape {}.")
- raise TypeError(msg.format(start_indices, operand.shape))
- if len(start_indices) != len(slice_sizes):
- msg = ("dynamic_slice slice_sizes must have the same length as "
- "start_indices, got start_indices length {} and slice_sizes {}.")
- raise TypeError(msg.format(len(start_indices), slice_sizes))
- if not dyn and not all(map(operator.ge, operand.shape, slice_sizes)):
+def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes):
+  if operand.ndim != len(start_indices):
+    msg = ("dynamic_slice start_indices must have length equal to the number "
+           "of dimensions of the operand, got indices {} for operand shape {}.")
+    raise TypeError(msg.format(start_indices, operand.shape))
+  if len(start_indices) != len(slice_sizes):
+    msg = ("dynamic_slice slice_sizes must have the same length as "
+           "start_indices, got start_indices length {} and slice_sizes {}.")
+    raise TypeError(msg.format(len(start_indices), slice_sizes))
+  if not all(map(operator.ge, operand.shape, slice_sizes)):
msg = ("slice slice_sizes must be less than or equal to operand shape, "
"got slice_sizes {} for operand shape {}.")
raise TypeError(msg.format(slice_sizes, operand.shape))
- if not dyn and not all(ssz >= 0 for ssz in slice_sizes):
+ if not all(ssz >= 0 for ssz in slice_sizes):
msg = ("slice slice_sizes must be greater than or equal to zero, "
"got slice_sizes of {}.")
raise TypeError(msg.format(slice_sizes))
if any(idx.ndim != 0 for idx in start_indices):
raise TypeError("start_indices arguments to dynamic_slice must be scalars, "
f" got indices {start_indices}")
- return tuple(lax._merge_dyn_shape(slice_sizes, dyn))
+ return tuple(slice_sizes)
def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes):
out_shape = _dynamic_slice_shape_rule(
@@ -1592,13 +1572,16 @@ def _batch_dynamic_slice_indices(indices, bdims):
empty_marker = object()
size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None),
empty_marker)
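+  # Also grab the (mesh, spec entry) of the first batched index, if any, so the
+  # broadcast below can pass an explicit out_sharding.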
+ out = next(((core.typeof(x).sharding.mesh, core.typeof(x).sharding.spec[i])
+ for x, i in zip(indices, bdims) if i is not None), None)
if size is empty_marker:
return lax.concatenate([lax.broadcast(i, (1,)) for i in indices], 0), None
+ out_s = None if out is None else NamedSharding(out[0], P(out[1], None))
indices = lax.concatenate(
- [lax.broadcast_in_dim(x, (size, 1),
- broadcast_dimensions=((0,) if i is not None else ()))
- for x, i in zip(indices, bdims)],
- dimension=1)
+ [lax.broadcast_in_dim(
+ x, (size, 1), broadcast_dimensions=((0,) if i is not None else ()),
+ out_sharding=out_s)
+ for x, i in zip(indices, bdims)], dimension=1)
return indices, 0
def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
@@ -1626,42 +1609,16 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True,
mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None)
-def _dynamic_slice_staging_rule(trace, source_info, x, *starts_and_dyn_sizes,
+def _dynamic_slice_staging_rule(trace, source_info, x, *start_indices,
slice_sizes):
- start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.ndim])
- if not dyn:
- return trace.default_process_primitive(
- dynamic_slice_p, (x, *start_indices), dict(slice_sizes=slice_sizes),
- source_info=source_info)
- shape = lax._merge_dyn_shape(slice_sizes, dyn)
- aval = core.DShapedArray(shape, x.dtype, False)
- return lax._dyn_shape_staging_rule(trace, source_info, dynamic_slice_p, aval,
- x, *starts_and_dyn_sizes,
- slice_sizes=slice_sizes)
-
-def _dynamic_slice_typecheck_rule(_, x, *starts_and_dyn_sizes, slice_sizes):
- start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.aval.ndim])
- if not dyn:
- out_aval, effects = dynamic_slice_p.abstract_eval(
- x.aval, *(d.aval for d in start_indices), slice_sizes=slice_sizes)
- return [out_aval], effects
- else:
- # TODO(mattjj): perform more checks
- out_shape = lax._merge_dyn_shape(slice_sizes, dyn)
- out_shape = [d.val if type(d) is core.Literal else d for d in out_shape]
- out_aval = core.DShapedArray(tuple(out_shape), x.aval.dtype,
- x.aval.weak_type)
- return [out_aval], core.no_effects
-
-def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn,
- slice_sizes):
- x_aval, start_indices_avals, dyn_avals = util.split_list(in_avals, [1, x.ndim])
- start_indices, dyn = util.split_list(starts_and_dyn, [x.ndim])
- dyn_ = [a.dtype.bound if type(a.dtype) is core.bint else d
- for a, d in zip(dyn_avals, dyn)]
- slice_sizes_ = lax._merge_dyn_shape(slice_sizes, dyn_)
- start_idx = [d.val if type(d) is core.DArray else d for d in start_indices]
- return [dynamic_slice(x, start_idx, slice_sizes_)]
+ return trace.default_process_primitive(
+ dynamic_slice_p, (x, *start_indices), dict(slice_sizes=slice_sizes),
+ source_info=source_info)
+
+def _dynamic_slice_typecheck_rule(_, x, *start_indices, slice_sizes):
+ out_aval, effects = dynamic_slice_p.abstract_eval(
+ x.aval, *(d.aval for d in start_indices), slice_sizes=slice_sizes)
+ return [out_aval], effects
dynamic_slice_p = standard_primitive(
_dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
@@ -1675,14 +1632,10 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn,
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
pe.custom_staging_rules[dynamic_slice_p] = _dynamic_slice_staging_rule
core.custom_typechecks[dynamic_slice_p] = _dynamic_slice_typecheck_rule
-pe.padding_rules[dynamic_slice_p] = _dynamic_slice_padding_rule
-def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
+def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
x_aval, *_ = ctx.avals_in
- start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x_aval.ndim])
aval_out, = ctx.avals_out
- if dyn:
- aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn))
out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)
return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)]
@@ -2249,12 +2202,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers,
def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
slice_sizes, unique_indices, indices_are_sorted,
mode, fill_value):
- operand, indices, *dyn_slice_sizes = batched_args
- operand_bdim, indices_bdim, *dyn_slice_size_bds = batch_dims
- dyn_slice_size_bounds = [b.dtype.bound for b in dyn_slice_sizes]
+ operand, indices = batched_args
+ operand_bdim, indices_bdim = batch_dims
if operand_bdim is not None and indices_bdim is None:
- operand, operand_bdim = batching.move_stacked_axis(operand, operand_bdim, 0)
+ operand = batching.moveaxis(operand, operand_bdim, 0)
+ operand_bdim = 0
slice_sizes = (operand.shape[0],) + slice_sizes
offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims))
collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims))
@@ -2269,29 +2222,10 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
operand_batching_dims=operand_batching_dims,
start_indices_batching_dims=dimension_numbers.start_indices_batching_dims,
)
- if isinstance(operand_bdim, batching.RaggedAxis):
- ragged_slice_sizes = batching.bdim_as_shape(operand_bdim, slice_sizes)
- for orig, fabricated in zip(
- lax._merge_dyn_shape(slice_sizes, dyn_slice_sizes),
- ragged_slice_sizes):
- if isinstance(fabricated, batching.IndexedAxisSize):
- if not core.same_referent(orig, fabricated.lengths):
- # Don't know what to do when slicing a ragged dimension with a
- # different size. To wit, if the client tries to index outside the
- # ragged size, the resulting element should be determined by the
- # out of bounds `mode`, but the underlying gather will only do that
- # if the client tries to index outside the _padded_ array. I guess
- # we should read the mode and apply a mask that writes the correct
- # fill element into all out-of-bounds locations?
- raise NotImplementedError
- bdim_out = batching.shape_as_bdim(
- operand_bdim.stacked_axis,
- _gather_shape_computation(indices, dnums, ragged_slice_sizes))
- else:
- bdim_out = operand_bdim
+ bdim_out = operand_bdim
return gather(
operand, indices, dimension_numbers=dnums,
- slice_sizes=lax._merge_dyn_shape(slice_sizes, dyn_slice_size_bounds),
+ slice_sizes=slice_sizes,
unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode,
fill_value=fill_value), bdim_out
@@ -2348,18 +2282,6 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
indices_are_sorted=indices_are_sorted, mode=mode,
fill_value=fill_value), 0
-def _gather_pad_rule(in_avals, out_avals, operand, indices, *,
- dimension_numbers, slice_sizes, unique_indices,
- indices_are_sorted, mode, fill_value):
- operand_aval, indices_aval = in_avals
- if any(isinstance(d, pe.BoundedAxisSize) for d in operand_aval.shape):
- raise NotImplementedError
- if mode != GatherScatterMode.PROMISE_IN_BOUNDS:
- # with fill, jnp.where on operand; with clip, jnp.where on indices
- raise NotImplementedError
- return [gather(operand, indices, dimension_numbers=dimension_numbers,
- slice_sizes=slice_sizes, mode=mode, fill_value=fill_value)]
-
gather_p = standard_primitive(
_gather_shape_rule, _gather_dtype_rule, 'gather',
weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule,
@@ -2367,7 +2289,6 @@ def _gather_pad_rule(in_avals, out_avals, operand, indices, *,
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
-pe.padding_rules[gather_p] = _gather_pad_rule
def _gather_lower_opaque(ctx, operand, indices, *,
diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py
index e24bea8fb6ff..669ffc510ae0 100644
--- a/jax/_src/lax/utils.py
+++ b/jax/_src/lax/utils.py
@@ -199,11 +199,6 @@ def standard_abstract_eval(
vma=out_vma, memory_space=out_mem_space)
core.check_avals_context_mesh([out_aval], prim.name)
return out_aval
- elif least_specialized is core.DShapedArray:
- shape = shape_rule(*avals, **kwargs)
- ty = (core.ShapedArray if all(type(d) is int for d in shape)
- else core.DShapedArray)
- return ty(shape, dtype_rule(*avals, **kwargs), weak_type)
else:
raise TypeError(avals, least_specialized)
diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py
index 7867286257e0..bf45604c2f5e 100644
--- a/jax/_src/linear_util.py
+++ b/jax/_src/linear_util.py
@@ -77,7 +77,7 @@ def trans1(static_arg, *dynamic_args, **kwargs):
from jax._src import core
from jax._src import traceback_util
from jax._src.tree_util import KeyPath, generate_key_paths, keystr
-from jax._src.util import HashableFunction, curry, fun_name, register_cache
+from jax._src.util import curry, fun_name, register_cache
traceback_util.register_exclusion(__file__)
@@ -405,13 +405,8 @@ def wrap_init(f: Callable, params=None, *, debug_info: DebugInfo) -> WrappedFun:
"""Wraps function `f` as a `WrappedFun`, suitable for transformation."""
params_dict = {} if params is None else params
params = () if params is None else tuple(sorted(params.items()))
+ debug_info = debug_info._replace(result_paths=None)
fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info)
- if debug_info.result_paths is initial_result_paths:
- fun, result_paths_thunk = _get_result_paths_thunk(fun)
- debug_info = debug_info._replace(
- result_paths=HashableFunction(result_paths_thunk, closure=()))
- fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores,
- fun.params, fun.in_type, debug_info)
return fun
@@ -421,54 +416,18 @@ def _clean_keystr_arg_names(k: KeyPath) -> str:
res = keystr(k)
return _re_clean_keystr_arg_names.sub(r"\1", res)
-@transformation_with_aux2
-def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
- ans = _fun(*args, **kwargs)
- result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans))
- if _store:
- # In some instances a lu.WrappedFun is called multiple times, e.g.,
- # the bwd function in a custom_vjp
- assert _store.val == result_paths, (_store, result_paths)
- else:
- _store.store(result_paths)
- return ans
-
def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun:
assert f.in_type is None
if in_type is None:
return f
_check_input_type(in_type)
- return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info)
+ return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params,
+ in_type, f.debug_info)
def _check_input_type(in_type: core.InputType) -> None:
# Check that in_type is syntactically well-formed
- assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
- assert all(isinstance(a, core.AbstractValue) and type(b) is bool
- for a, b in in_type)
-
- def valid_size(d) -> bool:
- if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:
- return True
- return (isinstance(d, (int, core.DBIdx, core.DArray)) and
- (not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape))
- assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray
- for d in a.shape)
-
- # Check that all DBIdx point to positions to the left of the input on which
- # they appear.
- assert all(d.val < i for i, (aval, _) in enumerate(in_type)
- if isinstance(aval, core.DShapedArray) for d in aval.shape
- if isinstance(d, core.DBIdx))
-
- # Check that all implicit arguments have at least one DBIdx pointing to them.
- provided = [e for _, e in in_type]
- for aval, _ in in_type:
- if type(aval) is core.DShapedArray:
- for d in aval.shape:
- if isinstance(d, core.DBIdx):
- provided[d.val] = True
- assert all(provided)
-
+ assert type(in_type) is tuple
+ assert all(isinstance(a, core.AbstractValue) for a in in_type)
def cache(call: Callable, *,
explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None):
diff --git a/jax/_src/literals.py b/jax/_src/literals.py
index 237072f0e606..5aed0f3c3256 100644
--- a/jax/_src/literals.py
+++ b/jax/_src/literals.py
@@ -51,6 +51,9 @@ def __new__(cls, value: float, dtype: np.dtype):
def __repr__(self):
return f'TypedFloat({float(self)}, dtype={self.dtype.name})'
+ def __str__(self):
+ return str(float(self))
+
def __getnewargs__(self):
return (float(self), self.dtype)
diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py
index a08e0b51b093..e3ffd1538322 100644
--- a/jax/_src/named_sharding.py
+++ b/jax/_src/named_sharding.py
@@ -515,19 +515,6 @@ def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None
f' for {mesh_lib.show_axes(multiple_uses)}'),
mesh=mesh, pspec=pspec)
-def check_pspec_mix_axis_type(mesh, pspec):
- for spec in pspec:
- if isinstance(spec, tuple):
- if all(mesh._name_to_type[spec[0]] == mesh._name_to_type[p]
- for p in spec):
- continue
- if any(mesh._name_to_type[p] == AxisType.Manual for p in spec):
- raise ValueError(
- 'Tuple subset of `PartitionSpec` cannot contain `Manual` mixed'
- f' with `Auto` or `Explicit`. Got pspec {pspec} and subset'
- f' {spec} with axis types:'
- f' ({", ".join(str(mesh._name_to_type[p]) for p in spec)})')
-
def _check_mesh_resource_axis(mesh, pspec):
for p in pspec:
if p is PartitionSpec.UNCONSTRAINED or p is None:
@@ -538,7 +525,6 @@ def _check_mesh_resource_axis(mesh, pspec):
raise ValueError(
f"Resource axis: {r} of {pspec} "
f"is not found in mesh: {tuple(mesh.shape.keys())}.")
- check_pspec_mix_axis_type(mesh, pspec)
if (AxisType.Auto not in mesh.axis_types and
PartitionSpec.UNCONSTRAINED in pspec):
raise ValueError(
diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py
index 98c1b7e13c82..050ec7f28f3a 100644
--- a/jax/_src/nn/functions.py
+++ b/jax/_src/nn/functions.py
@@ -880,7 +880,7 @@ def _get_padding_mask_encoded(T, q_seqlen):
def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen,
local_window_size):
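+  # The early return is only valid when no masking of any kind is requested,
+  # including a local attention window.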
- if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None:
+  if (mask is None and not is_causal and q_seqlen is None
+      and kv_seqlen is None and local_window_size is None):
return logits
combined_mask = jnp.ones_like(logits, dtype=bool)
diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py
index 320b95548fb4..e0496d157b67 100644
--- a/jax/_src/numpy/array_methods.py
+++ b/jax/_src/numpy/array_methods.py
@@ -557,9 +557,10 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr
if lax_numpy.issubdtype(self.dtype, np.complexfloating):
new_shape = (*self.shape[:-1], self.shape[-1] * 2)
new_dtype = lax_numpy.finfo(self.dtype).dtype
- self = (array_creation.zeros(new_shape, new_dtype)
- .at[..., 0::2].set(self.real)
- .at[..., 1::2].set(self.imag))
+ new_sharding = core.typeof(self).sharding
+ self = (array_creation.zeros(new_shape, new_dtype, out_sharding=new_sharding)
+ .at[..., 0::2].set(self.real)
+ .at[..., 1::2].set(self.imag))
return _view(self, dtype)
if dtype == bool:
@@ -1211,7 +1212,6 @@ def _set_array_abstract_methods(basearray):
def register_jax_array_methods():
"""Call this function once to register methods of JAX arrays"""
_set_shaped_array_attributes(core.ShapedArray)
- _set_shaped_array_attributes(core.DShapedArray)
_set_array_base_attributes(ArrayImpl, exclude={'__getitem__'})
_set_tracer_aval_forwarding(core.Tracer, exclude={*_impl_only_array_methods, "at"})
diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py
index 761756f56780..b8f72081df62 100644
--- a/jax/_src/numpy/einsum.py
+++ b/jax/_src/numpy/einsum.py
@@ -20,7 +20,6 @@
import opt_einsum
from jax._src import api
-from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src.export import shape_poly
@@ -537,7 +536,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
# NOTE(mattjj): this can fail non-deterministically in python3, maybe
# due to opt_einsum
- assert config.dynamic_shapes.value or all(
+ assert all(
name in lhs_names and name in rhs_names and
lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
for name in contracted_names), (
diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py
index 039bed953c20..9f808540d85f 100644
--- a/jax/_src/numpy/indexing.py
+++ b/jax/_src/numpy/indexing.py
@@ -15,12 +15,14 @@
# pytype: skip-file
"""Indexing code for jax.numpy."""
+from __future__ import annotations
+
+import dataclasses
import enum
from functools import partial
import operator
import string
-from typing import Any, NamedTuple, cast
-from types import EllipsisType
+from typing import Any, NamedTuple
from collections.abc import Sequence
import numpy as np
@@ -36,6 +38,7 @@
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import utils as lax_utils
+from jax._src.numpy import array_constructors
from jax._src.numpy import einsum
from jax._src.numpy import error as jnp_error
from jax._src.numpy import lax_numpy
@@ -44,13 +47,384 @@
from jax._src.partition_spec import PartitionSpec
from jax._src.pjit import auto_axes
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding
-from jax._src.tree_util import tree_flatten
-from jax._src.typing import Array, ArrayLike, Index, StaticIndex, StaticScalar
-from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update
+from jax._src.tree_util import tree_flatten, tree_unflatten, register_pytree_node_class
+from jax._src.typing import Array, ArrayLike, Index, StaticScalar
+from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update, unzip3
export = set_module('jax.numpy')
+# Internal utilities for parsing and validating NumPy-style indices.
+
+class IndexType(enum.Enum):
+ """Enum for tracking the type of an index."""
+ NONE = "none"
+ SLICE = "slice"
+ ELLIPSIS = "ellipsis"
+ INTEGER = "integer"
+ BOOLEAN = "boolean"
+ ARRAY = "array"
+
+ @classmethod
+ def from_index(cls, idx: Index) -> IndexType:
+ """Create an IndexType enum from a supported JAX array index."""
+ if idx is None:
+ return cls.NONE
+ elif idx is Ellipsis:
+ return cls.ELLIPSIS
+ elif isinstance(idx, slice):
+ return cls.SLICE
+ elif _is_integer_index(idx):
+ return cls.INTEGER
+ elif _is_boolean_index(idx):
+ return cls.BOOLEAN
+ elif isinstance(idx, (Array, np.ndarray, literals.TypedNdArray)):
+ if dtypes.issubdtype(idx.dtype, np.integer):
+ return cls.ARRAY
+ else:
+ raise TypeError(
+ f"Indexer must have integer or boolean type, got indexer with type {idx.dtype}")
+ elif isinstance(idx, str):
+ # TODO(jakevdp): this TypeError is for backward compatibility.
+ # We should switch to IndexError for consistency.
+ raise TypeError(f"JAX does not support string indexing; got {idx=}")
+ elif isinstance(idx, Sequence):
+ if not idx: # empty indices default to float, so special-case this.
+ return cls.ARRAY
+ idx_aval = api.eval_shape(array_constructors.asarray, idx)
+ if idx_aval.dtype == bool:
+ return cls.BOOLEAN
+ elif dtypes.issubdtype(idx_aval.dtype, np.integer):
+ return cls.ARRAY
+ else:
+ raise TypeError(
+ f"Indexer must have integer or boolean type, got indexer with type {idx_aval.dtype}")
+ elif isinstance(idx, (float, complex, np.generic)):
+ raise TypeError(
+ f"Indexer must have integer or boolean type, got indexer with type {np.dtype(type(idx))}")
+ else:
+ raise IndexError("only integers, slices (`:`), ellipsis (`...`), newaxis (`None`)"
+ f" and integer or boolean arrays are valid indices. Got {idx}")
+
+
+class ParsedIndex(NamedTuple):
+ """Structure for tracking an indexer parsed within the context of an array shape."""
+  index: Index  # type: ignore[assignment]  # mypy spuriously flags this annotation.
+ typ: IndexType
+ consumed_axes: tuple[int, ...]
+
+
+def _parse_indices(
+ indices: tuple[Index, ...],
+ shape: tuple[int, ...],
+) -> list[ParsedIndex]:
+ """Parse indices in the context of an array shape.
+
+ Args:
+ indices: a tuple of user-supplied indices to be parsed.
+ shape: the shape of the array being indexed.
+
+ Returns:
+ The list of parsed indices stored in :class:`ParsedIndex` objects.
+ This list will have the same length as ``indices``.
+
+ Raises:
+ IndexError: if any unrecognized index types are present or if there
+ are too many indices, or too many ellipses.
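+
+  For example (illustrative): with ``shape=(2, 3, 4)`` and
+  ``indices=(0, ..., None)``, the ellipsis absorbs the two remaining axes, so
+  the parsed entries consume axes ``(0,)``, ``(1, 2)``, and ``()`` respectively.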
+ """
+  # 1. Go through indices to count the number of consumed dimensions.
+ # This is required to determine the effect of any ellipses.
+ dimensions_consumed: list[int] = []
+ ellipses_indices: list[int] = []
+ index_types: list[IndexType] = []
+ for i, idx in enumerate(indices):
+ typ = IndexType.from_index(idx)
+ index_types.append(typ)
+
+ if typ == IndexType.NONE:
+ dimensions_consumed.append(0)
+ elif typ == IndexType.ELLIPSIS:
+ # We don't yet know how many dimensions are consumed, so set to zero
+ # for now and update later.
+ dimensions_consumed.append(0)
+ ellipses_indices.append(i)
+ elif typ == IndexType.BOOLEAN:
+ dimensions_consumed.append(np.ndim(idx)) # type: ignore[arg-type]
+ elif typ in [IndexType.INTEGER, IndexType.ARRAY, IndexType.SLICE]:
+ dimensions_consumed.append(1)
+ else:
+ raise IndexError(f"Unrecognized index type: {typ}")
+
+ # 2. Validate the consumed dimensions and ellipses.
+ if len(ellipses_indices) > 1:
+ raise IndexError("an index can only have a single ellipsis ('...')")
+ total_consumed = sum(dimensions_consumed)
+ if total_consumed > len(shape):
+ raise IndexError(f"Too many indices: array is {len(shape)}-dimensional,"
+ f" but {total_consumed} were indexed")
+ if ellipses_indices:
+ dimensions_consumed[ellipses_indices[0]] = len(shape) - total_consumed
+
+ # 3. Generate the final sequence of parsed indices.
+ result: list[ParsedIndex] = []
+ current_dim = 0
+ for index, typ, n_consumed in safe_zip(indices, index_types, dimensions_consumed):
+ consumed_axes = tuple(range(current_dim, current_dim + n_consumed))
+ current_dim += len(consumed_axes)
+ result.append(ParsedIndex(index=index, typ=typ, consumed_axes=consumed_axes))
+ return result
+
+
+@register_pytree_node_class
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class NDIndexer:
+ """Object that implements NumPy-style indexing operations on top of JAX.
+
+ Generally this will be constructed via the :meth:`NDIndexer.from_raw_indices`
+ method.
+
+ Attributes:
+ shape: the shape of the array being indexed.
+ indices: a list of :class:`ParsedIndex` objects.
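+
+  Example (illustrative; hypothetical shape)::
+
+    indexer = NDIndexer.from_raw_indices((0, slice(None), None), shape=(4, 5))
+    indexer = indexer.expand_ellipses().normalize_indices()
+    [i.typ for i in indexer.indices]
+    # -> [IndexType.INTEGER, IndexType.SLICE, IndexType.NONE]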
+ """
+ shape: tuple[int, ...]
+ indices: list[ParsedIndex]
+
+ @classmethod
+ def from_raw_indices(cls, indices: Index | tuple[Index, ...], shape: tuple[int, ...]) -> NDIndexer:
+ """Create an NDIndexer object from raw user-supplied indices."""
+ indices = eliminate_deprecated_list_indexing(indices)
+ indices = _parse_indices(indices, shape)
+ return cls(shape=shape, indices=indices)
+
+ def validate_static_indices(self, normalize_indices: bool = True) -> None:
+ """Check that all static integer indices are in-bounds.
+
+ Raises an IndexError in case of out-of-bound indices
+ """
+ for position, idx in enumerate(self.indices):
+ if idx.typ == IndexType.INTEGER:
+ assert isinstance(idx.index, (int, np.integer))
+ i = operator.index(idx.index)
+ axis, = idx.consumed_axes
+ size = self.shape[axis]
+ normed_idx = i + size if normalize_indices and i < 0 else i
+ if not 0 <= normed_idx < size:
+ raise IndexError(f"index {i} out of bounds for axis {axis} with size {size}"
+ f" ({normalize_indices=})")
+
+ def validate_slices(self) -> None:
+ """Check that all slices have static start/stop/step values.
+
+ Raises an IndexError in case of non-static entries.
+ """
+ for position, idx in enumerate(self.indices):
+ if idx.typ == IndexType.SLICE:
+ assert isinstance(idx.index, slice)
+ if not all(_is_slice_element_none_or_constant_or_symbolic(val)
+ for val in [idx.index.start, idx.index.stop, idx.index.step]):
+ raise IndexError("Slice entries must be static integers."
+ f" Got {idx.index} at position {position}")
+
+ def expand_bool_indices(self) -> NDIndexer:
+ """Returns a new NDIndexer with boolean indices replaced by array indices.
+
+    The only exceptions are scalar boolean indices, which are left in place.
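+
+    For example (illustrative): a boolean mask ``np.array([True, False, True])``
+    over an axis of size 3 is replaced by the integer index array ``[0, 2]``
+    consuming the same axis.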
+ """
+ expanded_indices: list[ParsedIndex] = []
+
+ for position, idx in enumerate(self.indices):
+ if idx.typ != IndexType.BOOLEAN:
+ expanded_indices.append(idx)
+ continue
+ if not core.is_concrete(idx.index):
+ # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
+ raise errors.NonConcreteBooleanIndexError(core.get_aval(idx.index))
+ assert isinstance(idx.index, (bool, np.ndarray, Array, literals.TypedNdArray, list))
+ if np.ndim(idx.index) == 0:
+ # Scalar booleans
+ assert idx.consumed_axes == ()
+ expanded_indices.append(ParsedIndex(index=bool(idx.index), typ=idx.typ, consumed_axes=()))
+ continue
+ idx_shape = np.shape(idx.index)
+ expected_shape = [self.shape[i] for i in idx.consumed_axes]
+ if not all(s1 in (0, s2) for s1, s2 in zip(idx_shape, expected_shape)):
+ raise IndexError("boolean index did not match shape of indexed array in index"
+ f" {position}: got {idx_shape}, expected {expected_shape}")
+ expanded_indices_raw = np.where(np.asarray(idx.index))
+ expanded_indices.extend(ParsedIndex(index=i, typ=IndexType.ARRAY, consumed_axes=(axis,))
+ for i, axis in safe_zip(expanded_indices_raw, idx.consumed_axes))
+ return NDIndexer(shape=self.shape, indices=expanded_indices)
+
+ def expand_scalar_bool_indices(self, sharding_spec: Any = None) -> tuple[NDIndexer, Any]:
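+    """Returns a new NDIndexer (and updated sharding spec) in which scalar
+    boolean indices are replaced by integer array indices over an inserted
+    size-1 dimension: ``True`` becomes ``np.arange(1)`` (keeping the new
+    dimension) and ``False`` becomes ``np.arange(0)`` (selecting nothing).
+    """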
+ new_shape = list(self.shape)
+ new_sharding_spec = list((None for _ in self.shape) if sharding_spec is None else sharding_spec)
+ new_indices = list(self.indices)
+ current_dim = 0
+ for i, idx in enumerate(self.indices):
+ if idx.typ == IndexType.BOOLEAN and np.ndim(idx.index) == 0: # type: ignore[arg-type]
+ new_shape.insert(i, 1)
+ new_sharding_spec.insert(i, None)
+ new_indices[i] = ParsedIndex(
+ np.arange(int(idx.index)), typ=IndexType.ARRAY, consumed_axes=(current_dim,)) # type: ignore[arg-type]
+ current_dim += 1
+ else:
+ n_consumed = len(idx.consumed_axes)
+ new_indices[i] = ParsedIndex(
+ index=idx.index,
+ typ=idx.typ,
+            consumed_axes=tuple(range(current_dim, current_dim + n_consumed))
+ )
+ current_dim += n_consumed
+ new_sharding_spec = None if sharding_spec is None else tuple(new_sharding_spec)
+ return NDIndexer(indices=new_indices, shape=tuple(new_shape)), new_sharding_spec
+
+ def convert_sequences_to_arrays(self) -> NDIndexer:
+ new_indices = [ParsedIndex(lax_numpy.asarray(idx.index), typ=idx.typ, consumed_axes=idx.consumed_axes)
+ if isinstance(idx.index, Sequence) else idx for idx in self.indices]
+ return NDIndexer(indices=new_indices, shape=self.shape)
+
+ def expand_ellipses(self) -> NDIndexer:
+ """
+ Returns a new indexer with ellipsis and implicit trailing slices
+ replaced by explicit empty slices.
+ """
+ expanded: list[ParsedIndex] = []
+ consumed = 0
+ for idx in self.indices:
+ consumed += len(idx.consumed_axes)
+ if idx.typ == IndexType.ELLIPSIS:
+ for axis in idx.consumed_axes:
+ expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,)))
+ else:
+ expanded.append(idx)
+ for axis in range(consumed, len(self.shape)):
+ expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,)))
+ return NDIndexer(shape=self.shape, indices=expanded)
+
+ def normalize_indices(self) -> NDIndexer:
+ new_indices: list[ParsedIndex] = []
+ for idx in self.indices:
+ if idx.typ == IndexType.INTEGER:
+ axis, = idx.consumed_axes
+ size: ArrayLike = self.shape[axis]
+ if isinstance(idx.index, np.unsignedinteger):
+ normed_index: Index = idx.index
+ else:
+ normed_index = idx.index + size if idx.index < 0 else idx.index # type: ignore[assignment,operator]
+ new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes))
+ elif idx.typ == IndexType.ARRAY:
+ assert isinstance(idx.index, (Array, np.ndarray, literals.TypedNdArray))
+ axis, = idx.consumed_axes
+ if dtypes.issubdtype(idx.index.dtype, np.unsignedinteger):
+ normed_index = idx.index
+ else:
+ size = self.shape[axis]
+ if core.is_constant_dim(size):
+ size = lax._const(idx.index, size)
+ else:
+ size = lax.convert_element_type(core.dimension_as_value(size),
+ idx.index.dtype)
+ normed_index = lax.select(idx.index < 0, lax.add(idx.index, size), idx.index)
+ new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes))
+ else:
+ new_indices.append(idx)
+ return NDIndexer(indices=new_indices, shape=self.shape)
+
+ def compute_via_static_slice(self, arr: Array) -> Array:
+ """Equivalent of arr[idx] implemented in terms of static :func:`lax.slice` operations.
+
+ This supports only INTEGER, ELLIPSIS, and SLICE indices, and will raise a TypeError
+ if other indices are present.
+ """
+ # Validation of the unmodified user indices.
+ self.validate_static_indices(normalize_indices=True)
+ self.validate_slices()
+
+ for position, pidx in enumerate(self.indices):
+ if pidx.typ in [IndexType.INTEGER, IndexType.ELLIPSIS, IndexType.SLICE]:
+ pass
+ elif pidx.typ == IndexType.NONE:
+ raise TypeError(f"static_slice: got {pidx.index} at position {position}")
+ elif pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN]:
+ raise TypeError("static_slice: indices must be static scalars or slices."
+ f" Got {pidx.index} at position {position}")
+ else:
+ raise TypeError(f"static_slice: unrecognized index {pidx.index} at position {position}.")
+
+ # Now re-iterate to generate static slices.
+ start_indices: list[int] = []
+ limit_indices: list[int] = []
+ strides: list[int] = []
+ rev_axes: list[int] = []
+ squeeze_axes: list[int] = []
+
+ expanded = self.expand_ellipses()
+ for pidx in expanded.indices:
+ if pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.NONE, IndexType.ELLIPSIS]:
+ raise RuntimeError(f"Internal: unexpected index encountered: {pidx}")
+ elif pidx.typ == IndexType.INTEGER:
+ assert isinstance(pidx.index, (int, np.integer))
+ axis, = pidx.consumed_axes
+ start_index = int(pidx.index + arr.shape[axis] if pidx.index < 0 else pidx.index)
+ start_indices.append(start_index)
+ limit_indices.append(start_index + 1)
+ strides.append(1)
+ squeeze_axes.append(axis)
+ elif pidx.typ == IndexType.SLICE:
+ assert isinstance(pidx.index, slice)
+ axis, = pidx.consumed_axes
+ size = arr.shape[axis]
+ start, stop, stride = pidx.index.indices(size)
+ if stride < 0:
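+          # Negative strides are emulated by a forward strided slice followed
+          # by lax.rev. E.g. (illustrative) size=5 with slice(4, 0, -2) gives
+          # (start, stop, stride) = (4, 0, -2); the rewritten forward slice is
+          # [2:5:2] -> positions (2, 4), which the reversal turns into (4, 2),
+          # matching NumPy's arr[4:0:-2].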
+ new_start = stop + 1 + abs(start - stop - 1) % abs(stride)
+ start_indices.append(new_start)
+ limit_indices.append(max(new_start, start + 1))
+ strides.append(abs(stride))
+ rev_axes.append(axis)
+ else:
+ start_indices.append(start)
+ limit_indices.append(stop)
+ strides.append(stride)
+ else:
+ raise TypeError(f"static_slice: unrecognized index {pidx.index}")
+ result = arr
+ if start_indices:
+ result = slicing.slice(result, start_indices, limit_indices, strides)
+ if rev_axes:
+ result = lax.rev(result, rev_axes)
+ if squeeze_axes:
+ result = lax.squeeze(result, squeeze_axes)
+ return result
+
+ def is_advanced_int_indexer(self):
+ """Returns True if idx should trigger int array indexing, False otherwise."""
+ # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
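+    # For example (illustrative): arr[[0, 2], :] uses a non-scalar integer
+    # array index and triggers advanced indexing, while arr[0, :] and
+    # arr[:, None] do not.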
+ return any(idx.typ in [IndexType.ARRAY, IndexType.BOOLEAN] and np.ndim(idx.index) > 0
+ for idx in self.indices)
+
+ def to_gather(self, x_sharding: NamedSharding | Any,
+ normalize_indices: bool = True) -> _GatherIndexer:
+ return _index_to_gather(self, x_sharding=x_sharding, normalize_indices=normalize_indices)
+
+ def tree_flatten(self):
+ # split dynamic and static indices
+ def is_dynamic(i: ParsedIndex):
+ return i.typ in [IndexType.INTEGER, IndexType.ARRAY, IndexType.BOOLEAN]
+ raw_dynamic_indices = [i.index if is_dynamic(i) else None for i in self.indices]
+ static_metadata = [
+ ParsedIndex(index=None, typ=i.typ, consumed_axes=i.consumed_axes) if is_dynamic(i) else i
+ for i in self.indices]
+ return raw_dynamic_indices, (self.shape, static_metadata)
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ shape, static_metadata = aux_data
+ indices = [idx if dyn_index is None else ParsedIndex(dyn_index, idx.typ, idx.consumed_axes)
+ for dyn_index, idx in safe_zip(children, static_metadata)]
+ return cls(indices=indices, shape=shape)
+
+
@export
def take(
a: ArrayLike,
@@ -532,12 +906,14 @@ def _is_contiguous_slice(idx):
(idx.step is None or (_is_integer_index(idx.step) and idx.step == 1)))
def _attempt_rewriting_take_via_slice(
- arr: Array,
- idx: Index | tuple[Index, ...], *,
+ arr: Array, indexer: NDIndexer, *,
mode: str | slicing.GatherScatterMode | None,
out_sharding: NamedSharding | PartitionSpec | None = None) -> Array | None:
# attempt to compute _rewriting_take via lax.slice(); return None if not possible.
- idx = idx if isinstance(idx, tuple) else (idx,)
+
+ # TODO(jakevdp): update implementation to use indexer directly, and to reuse code
+ # from compute_via_static_slice
+ idx = tuple(i.index for i in indexer.indices)
if not all(isinstance(i, int) for i in arr.shape):
return None
@@ -598,7 +974,7 @@ def _attempt_rewriting_take_via_slice(
allow_negative_indices.append(start < 0 or stop < 0)
else:
assert np.issubdtype(dtypes.dtype(ind), np.integer) # checked above
- assert np.shape(ind) == () # checked above
+ assert np.shape(ind) == () # type: ignore[arg-type] # checked above
start_indices.append(ind)
slice_sizes.append(1)
allow_negative_indices.append(
@@ -635,72 +1011,6 @@ def _attempt_rewriting_take_via_slice(
return arr
-def static_slice(arr: Array, idx: StaticIndex | tuple[StaticIndex, ...]):
- """Compute NumPy-style indexing for static slices only."""
- idx = idx if isinstance(idx, tuple) else (idx,)
-
- # First validate the types of entries before expanding ellipses: this allows
- # error messages to point to particular positions supplied by the user.
- # Valid index types here are integers, ellipses, and slices.
- for position, ind in enumerate(idx):
- if isinstance(ind, (int, np.integer, EllipsisType)):
- pass
- elif isinstance(ind, slice):
- if not all(val is None or isinstance(val, (int, np.integer))
- for val in [ind.start, ind.stop, ind.step]):
- raise ValueError("Slice entries must be static integers."
- f" Got {ind} at position {position}")
- elif ind is None:
- raise TypeError(f"static_slice: got {ind} at position {position}")
- elif isinstance(ind, (np.ndarray, Array, tuple, list, Sequence)):
- raise TypeError("static_slice: indices must be static scalars or slices."
- f" Got {ind} at position {position}")
- else:
- raise TypeError("static_slice: unrecognized index {ind} at position {position}.")
-
- # Now expand ellipses and validate the index values. This allows error messages
- # to point to relevant array dimensions.
- idx = _canonicalize_tuple_index(arr.ndim, idx)
- start_indices: list[int] = []
- limit_indices: list[int] = []
- strides: list[int] = []
- rev_axes: list[int] = []
- squeeze_axes: list[int] = []
-
- for axis, (ind, size) in enumerate(safe_zip(idx, arr.shape)):
- if isinstance(ind, (int, np.integer)):
- if not (-size <= ind < size):
- raise IndexError(f"index {ind} out of bounds for axis {axis} with size {size}")
- if ind < 0:
- ind += size
- start_indices.append(ind)
- limit_indices.append(ind + 1)
- strides.append(1)
- squeeze_axes.append(axis)
- elif isinstance(ind, slice):
- start, stop, stride = ind.indices(size)
- if stride < 0:
- new_start = stop + 1 + abs(start - stop - 1) % abs(stride)
- start_indices.append(new_start)
- limit_indices.append(max(new_start, start + 1))
- strides.append(abs(stride))
- rev_axes.append(axis)
- else:
- start_indices.append(start)
- limit_indices.append(stop)
- strides.append(stride)
- else:
- raise ValueError(f"Unexpected index: {ind} at axis {axis}")
-
- if start_indices:
- result = slicing.slice(arr, start_indices, limit_indices, strides)
- if rev_axes:
- result = lax.rev(result, rev_axes)
- if squeeze_axes:
- result = lax.squeeze(result, squeeze_axes)
- return result
-
-
class IndexingStrategy(enum.Enum):
AUTO = 'auto'
GATHER = 'gather'
@@ -722,40 +1032,31 @@ def rewriting_take(
# Computes arr[idx].
# All supported cases of indexing can be implemented as an XLA gather,
# followed by an optional reverse and broadcast_in_dim.
+ indexer = NDIndexer.from_raw_indices(idx, arr.shape)
if not isinstance(strategy, IndexingStrategy):
raise TypeError(f"Expected strategy to be IndexingStrategy; got {strategy}")
+  if config.check_static_indices.value and (
+      mode is None
+      or slicing.GatherScatterMode.from_any(mode)
+      == slicing.GatherScatterMode.PROMISE_IN_BOUNDS):
+ indexer.validate_static_indices(normalize_indices=normalize_indices)
+
if strategy == IndexingStrategy.STATIC_SLICE:
if not normalize_indices:
raise ValueError("strategy=STATIC_SLICE is only supported when normalize_indices=True.")
- return static_slice(arr, cast(StaticIndex | tuple[StaticIndex, ...], idx))
+ return indexer.compute_via_static_slice(arr)
# For simplicity of generated primitives, we call lax.slice or lax.dynamic_slice
# in the simplest cases: i.e. non-dynamic arrays indexed with integers and slices.
# TODO(jakevdp): lower to slice even when normalize_indices is False
if strategy == IndexingStrategy.AUTO and normalize_indices:
- result = _attempt_rewriting_take_via_slice(arr, idx, mode=mode, out_sharding=out_sharding)
+ result = _attempt_rewriting_take_via_slice(arr, indexer, mode=mode, out_sharding=out_sharding)
if result is not None:
return result
- # otherwise, strategy is GATHER or SCATTER
-
- # TODO(mattjj,dougalm): expand dynamic shape indexing support
- if config.dynamic_shapes.value and arr.ndim > 0:
- try: aval = core.get_aval(idx)
- except: pass
- else:
- if (isinstance(aval, core.DShapedArray) and aval.shape == () and
- dtypes.issubdtype(aval.dtype, np.integer) and
- not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
- isinstance(arr.shape[0], int)):
- assert isinstance(idx, (int, Array))
- return slicing.dynamic_index_in_dim(arr, idx, keepdims=False)
-
- treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape)
+ indexer = indexer.expand_bool_indices()
+ dynamic_idx, treedef = tree_flatten(indexer)
internal_gather = partial(
- _gather, treedef=treedef, static_idx=static_idx,
+ _gather, treedef=treedef,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode, fill_value=fill_value, normalize_indices=normalize_indices)
if out_sharding is not None:
@@ -769,12 +1070,11 @@ def rewriting_take(
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @api.jit(static_argnums=(1, 2))
-def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted,
+def _gather(arr, dynamic_idx, *, treedef, indices_are_sorted,
unique_indices, mode, fill_value, normalize_indices):
- idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
- indexer = index_to_gather(
- np.shape(arr), idx, core.typeof(arr).sharding,
- normalize_indices=normalize_indices) # shared with _scatter_update
+ parsed_idx = tree_unflatten(treedef, dynamic_idx)
+ indexer = parsed_idx.to_gather(core.typeof(arr).sharding,
+ normalize_indices=normalize_indices)
jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices)
y = arr
@@ -809,7 +1109,7 @@ def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted,
return lax.expand_dims(y, indexer.newaxis_dims)
-class _Indexer(NamedTuple):
+class _GatherIndexer(NamedTuple):
# The expected shape of the slice output.
slice_shape: Sequence[int]
# The slice shape to pass to lax.gather().
@@ -841,123 +1141,43 @@ class _Indexer(NamedTuple):
slice_sharding: NamedSharding | None = None
-def split_index_for_jit(idx, shape):
- """Splits indices into necessarily-static and dynamic parts.
-
- Used to pass indices into `jit`-ted function.
- """
- # Convert list indices to tuples in cases (deprecated by NumPy.)
- idx = eliminate_deprecated_list_indexing(idx)
- if any(isinstance(i, str) for i in idx):
- raise TypeError(f"JAX does not support string indexing; got {idx=}")
-
- # Expand any (concrete) boolean indices. We can then use advanced integer
- # indexing logic to handle them.
- idx = _expand_bool_indices(idx, shape)
-
- leaves, treedef = tree_flatten(idx)
- dynamic = [None] * len(leaves)
- static = [None] * len(leaves)
- for i, x in enumerate(leaves):
- if x is Ellipsis:
- static[i] = x
- elif isinstance(x, slice):
- # slice objects aren't hashable.
- static[i] = (x.start, x.stop, x.step)
- else:
- dynamic[i] = x
- return treedef, tuple(static), dynamic
-
-def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
- """Recombines indices that were split by split_index_for_jit."""
- idx = []
- for s, d in zip(static_idx, dynamic_idx):
- if d is not None:
- idx.append(d)
- elif isinstance(s, tuple):
- idx.append(slice(s[0], s[1], s[2]))
- else:
- idx.append(s)
- return treedef.unflatten(idx)
-
-def _int(aval):
- return not aval.shape and dtypes.issubdtype(aval.dtype, np.integer)
-
-def _aval_or_none(x):
- try:
- return core.get_aval(x)
- except:
- return None
+def _index_to_gather(indexer: NDIndexer, *, x_sharding: NamedSharding | Any,
+ normalize_indices: bool = True) -> _GatherIndexer:
+ indexer.validate_slices()
+ indexer = indexer.convert_sequences_to_arrays()
-def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
- x_sharding, normalize_indices: bool = True) -> _Indexer:
- # Convert sequences to arrays
- idx = tuple(lax_numpy.asarray(i, dtype=None if i else int)
- if isinstance(i, Sequence) else i for i in idx)
- abstract_idx = [_aval_or_none(i) for i in idx]
- float_indices = [(i, val, aval) for i, (val, aval) in enumerate(zip(idx, abstract_idx))
- if aval is not None and dtypes.issubdtype(aval, np.inexact)]
-
- # Check for float or complex indices:
- if float_indices:
- i, val, aval = float_indices[0]
- msg = ("Indexer must have integer or boolean type, got indexer "
- "with type {} at position {}, indexer value {}")
- raise TypeError(msg.format(aval.dtype.name, i, val))
-
- # Check whether advanced indices are contiguous. We must do this before
- # removing ellipses (https://github.com/jax-ml/jax/issues/25109)
- # If advanced idexing axes do not appear contiguously, NumPy semantics
- # move the advanced axes to the front.
- (is_advanced,) = np.nonzero([
- isinstance(e, (int, np.integer, Array, np.ndarray,
- literals.TypedNdArray))
- or lax_numpy.isscalar(e)
- for e in idx
- ])
+  (is_advanced,) = np.nonzero([idx.typ in {IndexType.ARRAY, IndexType.INTEGER} for idx in indexer.indices])
advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)
- # Remove ellipses and add trailing slice(None)s.
- idx = _canonicalize_tuple_index(len(x_shape), idx)
+ indexer = indexer.expand_ellipses()
- x_spec = x_sharding.spec
+ scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(indexer.indices) if i.typ == IndexType.BOOLEAN]
+ indexer, x_spec = indexer.expand_scalar_bool_indices(x_sharding.spec)
- # Check for scalar boolean indexing: this requires inserting extra dimensions
- # before performing the rest of the logic.
- scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)]
- if scalar_bool_dims:
- idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx)
- x_shape = list(x_shape)
- x_spec = list(x_spec)
- for i in sorted(scalar_bool_dims):
- x_shape.insert(i, 1)
- x_spec.insert(i, None)
- x_shape = tuple(x_shape)
- x_spec = tuple(x_spec)
+ if normalize_indices:
+ indexer = indexer.normalize_indices()
# Check for advanced indexing:
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
- advanced_indexes: Sequence[Array | np.ndarray] | None = None
+ # The advanced indices.
+ advanced_indexes: Sequence[Array] = []
# The positions of the advanced indexing axes in `idx`.
idx_advanced_axes: Sequence[int] = []
# The positions of the advanced indexes in x's shape.
# collapsed, after None axes have been removed. See below.
- x_advanced_axes: Sequence[int] | None = None
+ x_advanced_axes: Sequence[int] = []
- if _is_advanced_int_indexer(idx):
- idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
+ if indexer.is_advanced_int_indexer():
+ idx_without_none = [(i, d) for i, d in enumerate(indexer.indices) if d.typ != IndexType.NONE]
advanced_pairs = (
- (lax_numpy.asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
- if lax_numpy.isscalar(e)
- or isinstance(e, (Sequence, Array, np.ndarray,
- literals.TypedNdArray)))
- if normalize_indices:
- advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
- for e, i, j in advanced_pairs)
- advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
+ (lax_numpy.asarray(e.index), i, j)
+ for j, (i, e) in enumerate(idx_without_none)
+ if e.typ in [IndexType.ARRAY, IndexType.INTEGER]
+ )
+ advanced_indexes, idx_advanced_axes, x_advanced_axes = unzip3(advanced_pairs)
x_axis = 0 # Current axis in x.
y_axis = 0 # Current axis in y, before collapsing. See below.
@@ -968,7 +1188,7 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
collapsed_slice_dims: list[int] = []
start_index_map: list[int] = []
- index_dtype = lax_utils.int_dtype_for_shape(x_shape, signed=True)
+ index_dtype = lax_utils.int_dtype_for_shape(indexer.shape, signed=True)
# Gather indices.
# Pairs of (array, start_dim) values. These will be broadcast into
@@ -990,11 +1210,11 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
gather_slice_shape: list[int] = []
slice_spec = []
- for idx_pos, i in enumerate(idx):
+ for idx_pos, index in enumerate(indexer.indices):
# Handle the advanced indices here if:
# * the advanced indices were not contiguous and we are the start.
# * we are at the position of the first advanced index.
- if (advanced_indexes is not None and
+ if (advanced_indexes and
(advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
not advanced_axes_are_contiguous and idx_pos == 0)):
advanced_index_arrs = util._broadcast_arrays(*advanced_indexes)
@@ -1023,46 +1243,35 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
gather_slice_shape.append(1)
continue
- # Handle basic int indexes.
- abstract_i = _aval_or_none(i)
- if isinstance(abstract_i, core.ShapedArray) and _int(abstract_i):
- if core.definitely_equal(x_shape[x_axis], 0):
+ if index.typ in [IndexType.INTEGER, IndexType.ARRAY] and np.ndim(index.index) == 0: # type: ignore[arg-type]
+ # Basic scalar int indices
+ if core.definitely_equal(indexer.shape[x_axis], 0):
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
- i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
- i_converted = lax.convert_element_type(i, index_dtype)
+ i_converted = lax.convert_element_type(index.index, index_dtype) # type: ignore[arg-type]
gather_indices.append((i_converted, len(gather_indices_shape)))
collapsed_slice_dims.append(x_axis)
gather_slice_shape.append(1)
start_index_map.append(x_axis)
x_axis += 1
- # Handle np.newaxis (None)
- elif i is None:
+
+ elif index.typ == IndexType.NONE:
+ # None indexing: add a dimension.
slice_shape.append(1)
slice_spec.append(None)
newaxis_dims.append(y_axis)
y_axis += 1
- elif isinstance(i, slice):
- # Handle slice index (only static, otherwise an error is raised)
- if not all(_is_slice_element_none_or_constant_or_symbolic(elt)
- for elt in (i.start, i.stop, i.step)):
- msg = ("Array slice indices must have static start/stop/step to be used "
- "with NumPy indexing syntax. "
- f"Found slice({i.start}, {i.stop}, {i.step}). "
- "To index a statically sized "
- "array at a dynamic position, try lax.dynamic_slice/"
- "dynamic_update_slice (JAX does not support dynamically sized "
- "arrays within JIT compiled functions).")
- raise IndexError(msg)
-
- start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis])
+ elif index.typ == IndexType.SLICE:
+ # Handle static slice index.
+ assert isinstance(index.index, slice)
+ start, step, slice_size = core.canonicalize_slice(index.index, indexer.shape[x_axis])
slice_shape.append(slice_size)
slice_spec.append(x_spec[x_axis])
if core.definitely_equal(step, 1):
- # Avoid generating trivial gather (an optimization)
- if not core.definitely_equal(slice_size, x_shape[x_axis]):
+ # Optimization: avoid generating trivial gather.
+ if not core.definitely_equal(slice_size, indexer.shape[x_axis]):
gather_indices.append((lax.convert_element_type(start, index_dtype),
len(gather_indices_shape)))
start_index_map.append(x_axis)
@@ -1085,14 +1294,7 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
y_axis += 1
x_axis += 1
else:
- if (abstract_i is not None and
- not (dtypes.issubdtype(abstract_i.dtype, np.integer) or dtypes.issubdtype(abstract_i.dtype, np.bool_))):
- msg = ("Indexer must have integer or boolean type, got indexer "
- "with type {} at position {}, indexer value {}")
- raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
-
- raise IndexError("Indexing mode not yet supported. Got unsupported indexer "
- f"at position {idx_pos}: {i!r}")
+ raise IndexError(f"Got unsupported indexer at position {idx_pos}: {index!r}")
if len(gather_indices) == 0:
gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype)
@@ -1113,15 +1315,15 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
start_index_map = tuple(start_index_map)
)
slice_sharding = x_sharding.update(spec=slice_spec)
- return _Indexer(
+ return _GatherIndexer(
slice_shape=slice_shape,
newaxis_dims=tuple(newaxis_dims),
gather_slice_shape=gather_slice_shape,
reversed_y_dims=reversed_y_dims,
dnums=dnums,
gather_indices=gather_indices_array,
- unique_indices=advanced_indexes is None,
- indices_are_sorted=advanced_indexes is None,
+ unique_indices=not advanced_indexes,
+ indices_are_sorted=not advanced_indexes,
scalar_bool_dims=scalar_bool_dims,
slice_sharding=slice_sharding)
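As a rough, hedged illustration of the decomposition `_index_to_gather` builds (not part of this change): a basic integer index such as `x[2]` corresponds to a `lax.gather` whose indexed axis is collapsed, which is what the `collapsed_slice_dims` / `start_index_map` bookkeeping above records. The dimension numbers below are chosen by hand for this single case.

import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.arange(20.0).reshape(4, 5)
dnums = lax.GatherDimensionNumbers(
    offset_dims=(0,),            # axis 1 of x survives as output axis 0
    collapsed_slice_dims=(0,),   # the integer-indexed axis is dropped
    start_index_map=(0,),        # the single index addresses axis 0 of x
)
out = lax.gather(x, np.array([2], np.int32), dnums, slice_sizes=(1, 5))
assert (out == x[2]).all()       # shape (5,): the selected row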
@@ -1166,52 +1368,6 @@ def _is_boolean_index(i):
or isinstance(i, list) and i and all(_is_scalar(e)
and dtypes.issubdtype(dtypes.dtype(e), np.bool_) for e in i))
-def _expand_bool_indices(idx, shape):
- """Converts concrete bool indexes into advanced integer indexes."""
- out = []
- total_dims = len(shape)
- num_ellipsis = sum(e is Ellipsis for e in idx)
- if num_ellipsis > 1:
- raise IndexError("an index can only have a single ellipsis ('...')")
- elif num_ellipsis == 1:
- total_dims = sum(np.ndim(e) if _is_boolean_index(e) else 1 for e in idx
- if e is not None and e is not Ellipsis)
- ellipsis_offset = 0
- newaxis_offset = 0
- for dim_number, i in enumerate(idx):
- try:
- abstract_i = core.get_aval(i)
- except TypeError:
- abstract_i = None
- if _is_boolean_index(i):
- if isinstance(i, list):
- i = lax_numpy.array(i)
- abstract_i = core.get_aval(i)
-
- if not core.is_concrete(i):
- # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
- raise errors.NonConcreteBooleanIndexError(abstract_i)
- elif np.ndim(i) == 0:
- out.append(bool(i))
- else:
- i_shape = np.shape(i)
- start = len(out) + ellipsis_offset - newaxis_offset
- expected_shape = shape[start: start + np.ndim(i)]
- if len(i_shape) != len(expected_shape):
- raise IndexError(f"too many boolean indices at index {dim_number}: got mask of shape "
- f"{i_shape}, but only {len(expected_shape)} dimensions remain.")
- if not all(s1 in (0, s2) for s1, s2 in zip(i_shape, expected_shape)):
- raise IndexError("boolean index did not match shape of indexed array in index "
- f"{dim_number}: got {i_shape}, expected {expected_shape}")
- out.extend(np.where(i))
- else:
- out.append(i)
- if i is Ellipsis:
- ellipsis_offset = len(shape) - total_dims - 1
- if i is None:
- newaxis_offset += 1
- return tuple(out)
-
def _is_slice_element_none_or_constant_or_symbolic(elt):
"""Return True if elt is a constant or None."""
@@ -1222,23 +1378,6 @@ def _is_slice_element_none_or_constant_or_symbolic(elt):
except TypeError:
return False
-# TODO(mattjj): clean up this logic
-def _is_advanced_int_indexer(idx):
- """Returns True if idx should trigger int array indexing, False otherwise."""
- # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
- assert isinstance(idx, tuple)
- if all(e is None or e is Ellipsis or isinstance(e, slice)
- or _is_scalar(e) and dtypes.issubdtype(dtypes.dtype(e), np.integer) for e in idx):
- return False
- return all(e is None or e is Ellipsis or isinstance(e, slice)
- or _is_int_arraylike(e) for e in idx)
-
-def _is_int_arraylike(x):
- """Returns True if x is array-like with integer dtype, False otherwise."""
- return (isinstance(x, int) and not isinstance(x, bool)
- or dtypes.issubdtype(getattr(x, "dtype", None), np.integer)
- or isinstance(x, (list, tuple)) and all(_is_int_arraylike(e) for e in x))
-
def _is_scalar(x):
"""Checks if a Python or NumPy scalar."""
return np.isscalar(x) or (
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 87a3bb327f2f..96d100e5534b 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -36,7 +36,6 @@
import numpy as np
from jax._src import api
-from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
@@ -143,8 +142,12 @@ def iscomplexobj(x: Any) -> bool:
>>> jnp.iscomplexobj(jnp.array([0, 1+2j]))
True
"""
- if x is None:
+ # Fast path for common types.
+ if isinstance(x, (complex, np.complexfloating)):
+ return True
+ if x is None or isinstance(x, (bool, int, float, str, np.generic)):
return False
+ # Fall back to dtype attribute lookup.
try:
typ = x.dtype.type
except AttributeError:
@@ -4554,7 +4557,8 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
except along the specified axis. If a single array is given it will be
treated equivalently to `arrays = unstack(arrays)`, but the implementation
will avoid explicit unstacking.
- axis: specify the axis along which to concatenate.
+ axis: specify the axis along which to concatenate. If None, the arrays are
+ flattened before concatenation.
dtype: optional dtype of the resulting array. If not specified, the dtype
will be determined via type promotion rules described in :ref:`type-promotion`.
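A small usage sketch of the `axis=None` behavior documented above (output values assume NumPy-compatible semantics):

import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([5, 6])
# axis=None flattens each input before concatenating.
print(jnp.concatenate([a, b], axis=None))   # [1 2 3 4 5 6]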
@@ -5954,46 +5958,64 @@ def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None,
step: ArrayLike | None = None, dtype: DTypeLike | None = None,
out_sharding: NamedSharding | None = None) -> Array:
+ # Validate inputs
if dtype is not None:
dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "arange")
- if not config.dynamic_shapes.value:
- util.check_arraylike("arange", start)
- if stop is None and step is None:
- start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'")
- else:
- start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'")
- util.check_arraylike_or_none("arange", None, stop, step)
+ util.check_arraylike_or_none("arange", start, stop, step)
+
+ # Ensure start/stop/step are concrete
+ start_name = "stop" if stop is None and step is None else "start"
+ start = core.concrete_or_error(None, start, f"It arose in the jnp.arange argument '{start_name}'")
stop = core.concrete_or_error(None, stop, "It arose in the jnp.arange argument 'stop'")
step = core.concrete_or_error(None, step, "It arose in the jnp.arange argument 'step'")
- start_name = "stop" if stop is None and step is None else "start"
+
+ # Ensure start/stop/step are scalars
for name, val in [(start_name, start), ("stop", stop), ("step", step)]:
if val is not None and np.ndim(val) != 0:
raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}")
+
+ # Handle symbolic dimensions
if any(core.is_symbolic_dim(v) for v in (start, stop, step)):
- # Some dynamic shapes
- if stop is None and step is None:
- stop = start
- start = 0
- step = 1
- elif stop is not None and step is None:
+ if stop is None:
+ start, stop = 0, start
+ if step is None:
step = 1
return _arange_dynamic(start, stop, step, dtype or dtypes.default_int_dtype())
+
if dtype is None:
- dtype = result_type(start, *(x for x in [stop, step] if x is not None))
+ dtype = dtypes.result_type(start, *(x for x in [stop, step] if x is not None))
dtype = dtypes.jax_dtype(dtype)
- if stop is None and step is None:
- start_dtype = _dtype(start)
- if (not dtypes.issubdtype(start_dtype, np.integer) and
- not dtypes.issubdtype(start_dtype, dtypes.extended)):
- ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil
- start = ceil_(start).astype(int)
- return lax.broadcasted_iota(dtype, (start,), 0, out_sharding=out_sharding) # type: ignore[arg-type]
+
+ if iscomplexobj(start) or iscomplexobj(stop) or iscomplexobj(step):
+ deprecations.warn(
+ "jax-numpy-arange-complex",
+ (
+ "Passing complex start/stop/step to jnp.arange is deprecated;"
+ " in the future this will result in a ValueError."
+ ),
+ stacklevel=3
+ )
+ # Complex arange is poorly defined; fall back to NumPy here.
+    # TODO(jakevdp): remove the complex case once the deprecation above completes.
+ return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding)
+
+ if step is not None:
+ # arange(N, M, K): when step is specified, fall back to NumPy.
+ return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding)
+
+ if stop is None:
+ start, stop = 0, start
+
+ if start == 0:
+ # arange(M) or arange(0, M)
+ size = max(0, int(np.ceil(stop)))
+ return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)
+
else:
- if step is None and start == 0 and stop is not None:
- return lax.broadcasted_iota(dtype, (np.ceil(stop).astype(int),), 0,
- out_sharding=out_sharding)
- return array(np.arange(start, stop=stop, step=step, dtype=dtype),
- device=out_sharding)
+ # arange(N, M)
+ size = max(0, int(np.ceil(stop - start)))
+ return lax.add(lax.convert_element_type(start, dtype),
+ lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding))
def _arange_dynamic(
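A quick sanity sketch of the two iota-based branches above: with no step, the result is an iota of length `ceil(stop - start)`, shifted by `start` when `start != 0`. The values below follow the logic as written and are meant as expected behavior, not a test of this exact revision.

import jax.numpy as jnp

print(jnp.arange(5))        # [0 1 2 3 4]    size = ceil(5) = 5
print(jnp.arange(3, 8))     # [3 4 5 6 7]    3 + iota(5)
print(jnp.arange(0, 2.5))   # [0. 1. 2.]     size = ceil(2.5) = 3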
diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py
index ec56e1c0506b..1d6869d2140d 100644
--- a/jax/_src/numpy/reductions.py
+++ b/jax/_src/numpy/reductions.py
@@ -34,7 +34,7 @@
from jax._src.lax import other as lax_other
from jax._src.lax import parallel as lax_parallel
from jax._src.lax import slicing as lax_slicing
-from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
+from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import canonicalize_axis, canonicalize_axis_tuple, maybe_named_axis, set_module
@@ -202,7 +202,7 @@ def _cast_to_numeric(operand: Array) -> Array:
return promote_dtypes_numeric(operand)[0]
def _require_integer(arr: Array) -> Array:
- if not dtypes.isdtype(arr, ("bool", "integral")):
+ if not dtypes.isdtype(arr.dtype, ("bool", "integral")):
raise ValueError(f"integer argument required; got dtype={arr.dtype}")
return arr
@@ -2371,12 +2371,11 @@ def cumulative_prod(
# Quantiles
-# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
-@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
+@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method'))
def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
- keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
+ keepdims: bool = False) -> Array:
"""Compute the quantile of the data along the specified axis.
JAX implementation of :func:`numpy.quantile`.
@@ -2418,18 +2417,14 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No
if overwrite_input or out is not None:
raise ValueError("jax.numpy.quantile does not support overwrite_input=True "
"or out != None")
- # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0
- if not isinstance(interpolation, DeprecatedArg):
- raise TypeError("quantile() argument interpolation was removed in JAX"
- " v0.8.0. Use method instead.")
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False)
-# TODO(jakevdp): interpolation argument deprecated 2024-05-16
+
@export
-@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
+@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method'))
def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
- keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
+ keepdims: bool = False) -> Array:
"""Compute the quantile of the data along the specified axis, ignoring NaNs.
JAX implementation of :func:`numpy.nanquantile`.
@@ -2473,10 +2468,6 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
msg = ("jax.numpy.nanquantile does not support overwrite_input=True or "
"out != None")
raise ValueError(msg)
- # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0
- if not isinstance(interpolation, DeprecatedArg):
- raise TypeError("nanquantile() argument interpolation was removed in JAX"
- " v0.8.0. Use method instead.")
return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True)
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
@@ -2603,13 +2594,12 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
return lax.convert_element_type(result, a.dtype)
-# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
-@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
+@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method'))
def percentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
- keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
+ keepdims: bool = False) -> Array:
"""Compute the percentile of the data along the specified axis.
JAX implementation of :func:`numpy.percentile`.
@@ -2649,21 +2639,16 @@ def percentile(a: ArrayLike, q: ArrayLike,
"""
a, q = ensure_arraylike("percentile", a, q)
q, = promote_dtypes_inexact(q)
- # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0
- if not isinstance(interpolation, DeprecatedArg):
- raise TypeError("percentile() argument interpolation was removed in JAX"
- " v0.8.0. Use method instead.")
return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
-# TODO(jakevdp): interpolation argument deprecated 2024-05-16
@export
-@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method'))
+@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method'))
def nanpercentile(a: ArrayLike, q: ArrayLike,
axis: int | tuple[int, ...] | None = None,
out: None = None, overwrite_input: bool = False, method: str = "linear",
- keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array:
+ keepdims: bool = False) -> Array:
"""Compute the percentile of the data along the specified axis, ignoring NaN values.
JAX implementation of :func:`numpy.nanpercentile`.
@@ -2706,10 +2691,6 @@ def nanpercentile(a: ArrayLike, q: ArrayLike,
a, q = ensure_arraylike("nanpercentile", a, q)
q, = promote_dtypes_inexact(q)
q = q / 100
- # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0
- if not isinstance(interpolation, DeprecatedArg):
- raise TypeError("nanpercentile() argument interpolation was removed in JAX"
- " v0.8.0. Use method instead.")
return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
method=method, keepdims=keepdims)
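With the removed `interpolation` keyword gone, callers select the estimation rule via `method`, for example:

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
print(jnp.quantile(x, 0.5))                      # 2.5 (default method="linear")
print(jnp.percentile(x, 75, method="nearest"))   # 3.0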
diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py
index 360cb96ed1ed..4ebf75b020a7 100644
--- a/jax/_src/numpy/scalar_types.py
+++ b/jax/_src/numpy/scalar_types.py
@@ -102,7 +102,7 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
complex64 = csingle = _make_scalar_type(np.complex64)
complex128 = cdouble = _make_scalar_type(np.complex128)
-int_ = int32 if dtypes.int_ == np.int32 else int64
-uint = uint32 if dtypes.uint == np.uint32 else uint64
-float_: Any = float32 if dtypes.float_ == np.float32 else float64
-complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128
+int_ = int64
+uint = uint64
+float_ = float64
+complex_ = complex128
diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py
index 4d7d992545bc..53a78cf6c8ac 100644
--- a/jax/_src/numpy/util.py
+++ b/jax/_src/numpy/util.py
@@ -49,23 +49,16 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]:
return [lax.asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
- if config.dynamic_shapes.value:
- # With dynamic shapes we don't support singleton-dimension broadcasting;
- # we instead broadcast out to the full shape as a temporary workaround.
- # TODO(mattjj): revise this workaround
- res_shape = lax.broadcast_shapes(*shapes) # Can raise an error!
- return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
+ if all(len(shapes[0]) == len(s) for s in shapes[1:]):
+ return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
+ nonscalar_ranks = {len(shp) for shp in shapes if shp}
+ if len(nonscalar_ranks) < 2:
+ return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
else:
- if all(len(shapes[0]) == len(s) for s in shapes[1:]):
- return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
- nonscalar_ranks = {len(shp) for shp in shapes if shp}
- if len(nonscalar_ranks) < 2:
- return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion
- else:
- if config.numpy_rank_promotion.value != "allow":
- _rank_promotion_warning_or_error(fun_name, shapes)
- result_rank = len(lax.broadcast_shapes(*shapes))
- return [lax.broadcast_to_rank(arg, result_rank) for arg in args]
+ if config.numpy_rank_promotion.value != "allow":
+ _rank_promotion_warning_or_error(fun_name, shapes)
+ result_rank = len(lax.broadcast_shapes(*shapes))
+ return [lax.broadcast_to_rank(arg, result_rank) for arg in args]
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
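For context, a hedged sketch of the rank-promotion behavior this helper implements; `jax.numpy_rank_promotion` is assumed to be the relevant public switch.

import jax
import jax.numpy as jnp

x = jnp.ones((3,))
y = jnp.ones((2, 3))
print((x + y).shape)           # (2, 3): x is broadcast up to rank 2

with jax.numpy_rank_promotion("raise"):
  try:
    x + y                      # implicit rank promotion now errors
  except Exception as e:
    print(type(e).__name__)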
diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py
index 0cb35e310e25..f120d386b5fd 100644
--- a/jax/_src/ops/scatter.py
+++ b/jax/_src/ops/scatter.py
@@ -73,18 +73,19 @@ def _scatter_update(x: ArrayLike, idx: Index | tuple[Index, ...],
# XLA gathers and scatters are very similar in structure; the scatter logic
# is more or less a transpose of the gather equivalent.
- treedef, static_idx, dynamic_idx = indexing.split_index_for_jit(idx, x.shape)
+ indexer = indexing.NDIndexer.from_raw_indices(idx, x.shape).expand_bool_indices()
+ dynamic_idx, treedef = tree_util.tree_flatten(indexer)
internal_scatter = partial(
_scatter_impl, scatter_op=scatter_op, treedef=treedef,
- static_idx=static_idx, indices_are_sorted=indices_are_sorted,
+ indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode,
normalize_indices=normalize_indices)
if out_sharding is not None:
return auto_axes(internal_scatter, out_sharding=out_sharding,
axes=out_sharding.mesh.explicit_axes # type: ignore
)(x, y, dynamic_idx)
- return internal_scatter(x, y, dynamic_idx)
+ return internal_scatter(x, y, tuple(dynamic_idx))
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
@@ -92,7 +93,7 @@ def _scatter_update(x: ArrayLike, idx: Index | tuple[Index, ...],
# @jit(static_argnums=(2, 3, 4))
def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *,
scatter_op: Callable[..., Array],
- treedef: tree_util.PyTreeDef, static_idx: tuple[Any, ...],
+ treedef: tree_util.PyTreeDef,
indices_are_sorted: bool, unique_indices: bool,
mode: slicing.GatherScatterMode | str | None, normalize_indices: bool):
dtype = lax.dtype(x)
@@ -107,9 +108,8 @@ def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *,
"In future JAX releases this will result in an error.",
FutureWarning)
- idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
- indexer = indexing.index_to_gather(np.shape(x), idx, core.typeof(x).sharding,
- normalize_indices=normalize_indices)
+ general_indexer = tree_util.tree_unflatten(treedef, dynamic_idx)
+ indexer = general_indexer.to_gather(core.typeof(x).sharding, normalize_indices=normalize_indices)
# Avoid calling scatter if the slice shape is empty, both as a fast path and
# to handle cases like zeros(0)[array([], int32)].
diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD
index 2207af4e9dc2..469b31d9b326 100644
--- a/jax/_src/pallas/BUILD
+++ b/jax/_src/pallas/BUILD
@@ -66,3 +66,14 @@ py_library(
"//jax/_src/lib",
] + py_deps("numpy"),
)
+
+py_library(
+ name = "pallas_test_util",
+ srcs = [
+ "pallas_test_util.py",
+ ],
+ deps = [
+ ":pallas",
+ "//jax/_src:test_util",
+ ],
+)
diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index dff563129ce0..6c995002976b 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -101,7 +101,7 @@ class semaphore_dtype(dtypes.extended):
"""Common dtype for all kinds of semaphore dtypes.
This is an abstract class that should never be instantiated, but rather
- exists for the sake of `jnp.issubdtype`.
+ exists for the sake of ``jnp.issubdtype``.
"""
class semaphore(semaphore_dtype):
@@ -359,7 +359,7 @@ def __str__(self):
class BoundedSlice:
"""Allows to specify a bounded slice of a dimension.
- Specifically, the index_map need to return a `pl.Slice/pl.ds` for this
+  Specifically, the index_map needs to return a ``pl.Slice/pl.ds`` for this
dimension. The start and size may be dynamic, as long as the size <=
block_size.
"""
@@ -529,15 +529,7 @@ def to_block_mapping(
)
ref_block_shape = _get_ref_block_shape(block_shape)
- if isinstance(array_aval, jax_core.DShapedArray):
- # Get the "max" shape for the ragged array.
- block_array_aval = array_aval.update(shape=ref_block_shape)
- block_array_aval = jax_core.ShapedArray(
- block_array_aval.shape,
- block_array_aval.dtype,
- block_array_aval.weak_type,
- )
- elif isinstance(array_aval, ShapedArrayWithMemorySpace):
+ if isinstance(array_aval, ShapedArrayWithMemorySpace):
block_array_aval = jax_core.ShapedArray(
ref_block_shape, array_aval.dtype, array_aval.weak_type
)
@@ -618,10 +610,6 @@ def to_block_mapping(
f"{origin} must not capture constants: {consts}"
)
- if isinstance(array_aval, (jax_core.ShapedArray, jax_core.DShapedArray)):
- array_aval_shape = _max_shape_from_aval(array_aval)
- array_aval = array_aval.update(shape=array_aval_shape)
-
mapping = BlockMapping(
block_shape=block_shape,
transformed_block_aval=block_aval, # There are no transforms by default
@@ -1064,8 +1052,6 @@ def _max_shape_from_aval(array_aval: jax_core.ShapedArray):
for i, s in enumerate(array_aval.shape):
try:
aval = jax_core.get_aval(s)
- if isinstance(aval, jax_core.DShapedArray):
- array_aval_shape[i] = aval.dtype.bound
except OverflowError as e:
# Note - there are annoying cases where on 32 bit hardware,
# a flattened index space may overflow - for these cases,
@@ -1379,6 +1365,27 @@ def _get_sds(aval: jax_core.AbstractValue):
core_map_p = jax_core.Primitive("core_map")
core_map_p.multiple_results = True
+def _core_map_is_high(*avals, jaxpr, **params):
+ del avals, params
+ return jaxpr.is_high
+core_map_p.is_high = _core_map_is_high # type: ignore[method-assign]
+
+def _core_map_to_lojax(*consts, jaxpr, mesh, **params):
+ closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
+ with (
+ tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()),
+ jax_core.extend_axis_env_nd(mesh.shape.items()),
+ ):
+ closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr)
+ assert not closed_lo_jaxpr.is_high
+ return core_map_p.bind(
+ *closed_lo_jaxpr.consts,
+ jaxpr=closed_lo_jaxpr.jaxpr,
+ mesh=mesh,
+ **params,
+ )
+core_map_p.to_lojax = _core_map_to_lojax
+
def core_map(
mesh,
@@ -1549,8 +1556,13 @@ def default_mesh_discharge_rule(
scratch_shapes,
):
"""Discharges a ``core_map`` over a mesh to a ``pallas_call``."""
- del out_avals # Unused.
default_memory_space = memory_space
+ if not all(
+ isinstance(aval, state.AbstractRef) for aval in (in_avals + out_avals)
+ ):
+ raise ValueError(
+ "default_mesh_discharge_rule only supports Ref inputs/outputs."
+ )
def body(*args):
# Due to aliasing, ``args`` contains aliased inputs and outputs so we
@@ -1619,15 +1631,24 @@ def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, debug_info,
for var in jaxpr.constvars
if not isinstance(aval := var.aval, state.AbstractRef)
]
- if consts_avals:
+ is_scalar_const_aval = [
+ isinstance(aval, jax_core.ShapedArray) and not aval.shape
+ for aval in consts_avals
+ ]
+ if not all(is_scalar_const_aval):
ctx = jax_core.JaxprPpContext()
- pp_const_avals = ", ".join(
- jax_core.pp_aval(aval, ctx) for aval in consts_avals
+ non_scalar_const_avals = [
+ aval
+ for aval, is_scalar in zip(consts_avals, is_scalar_const_aval)
+ if not is_scalar
+ ]
+ non_scalar_const_pp_avals = ", ".join(
+ jax_core.pp_aval(aval, ctx) for aval in non_scalar_const_avals
)
raise ValueError(
"The kernel function in core_map"
- f" {debug_info.func_src_info} captures constants"
- f" [{pp_const_avals}]. You should pass them as inputs."
+ f" {debug_info.func_src_info} captures non-scalar constants"
+ f" [{non_scalar_const_pp_avals}]. You should pass them as inputs."
)
return _core_map_mesh_rules[type(mesh)](
in_avals, out_avals, *args_flat, jaxpr=jaxpr, mesh=mesh, **kwargs
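A schematic sketch of the constraint the new error enforces (kernel signatures here are illustrative, not the exact `core_map` API): scalar closure constants are tolerated, while non-scalar ones must become explicit inputs.

import jax.numpy as jnp

table = jnp.arange(8.0)            # a non-scalar value

def bad_kernel(o_ref):
  o_ref[...] = table               # closed-over non-scalar constant -> rejected

def good_kernel(table_ref, o_ref):
  o_ref[...] = table_ref[...]      # passed in as an input Ref instead

scale = 2.0                        # scalar constants remain allowed
def also_ok(o_ref):
  o_ref[...] = scale * o_ref[...]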
diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py
index ef2cf10beef3..83d35e2ae977 100644
--- a/jax/_src/pallas/cost_estimate.py
+++ b/jax/_src/pallas/cost_estimate.py
@@ -205,10 +205,12 @@ def dot_general_cost_rule(ctx: Context,
assert len(lhs_batch_dims) == len(rhs_batch_dims)
flops = 1
# Flops along a contracting dim is 2*dim (addition and multiplication)
+ contracting_flops = 1
for i in range(len(lhs_contracting_dims)):
lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i]
assert x_shape[lhs_dim] == y_shape[rhs_dim]
- flops *= 2 * x_shape[lhs_dim]
+ contracting_flops *= x_shape[lhs_dim]
+ flops *= 2 * contracting_flops
# Now we handle all other dimensions.
for i, lhs_dim in enumerate(x_shape):
if i in lhs_contracting_dims:
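Worked example of the corrected formula: the contracting sizes are multiplied together once, then doubled (one multiply plus one add per contracted element), and finally scaled by the non-contracting output size.

# (M, K) @ (K, N): M*N outputs, each needing K multiplies and K adds.
M, K, N = 8, 32, 16
contracting_flops = K              # product over contracting dims
flops = 2 * contracting_flops      # per output element
flops *= M * N                     # remaining (non-contracting) dims
assert flops == 2 * M * N * K == 8192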
diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py
index 9c8ae14ab2b7..026abbbe5731 100644
--- a/jax/_src/pallas/helpers.py
+++ b/jax/_src/pallas/helpers.py
@@ -36,6 +36,19 @@
@api.named_call
def empty_like(x: object):
+ """Create an empty PyTree of possibly uninitialized values.
+
+ Args:
+ x: A PyTree with leaves specifying the shape and dtype of
+ the uninitialized object.
+
+ Returns:
+ A PyTree with the same structure as ``x``, but with uninitialized
+ values.
+
+ See Also:
+ :func:`jax.lax.empty`
+ """
return tree_util.tree_map(lambda leaf: empty(leaf.shape, leaf.dtype), x)
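A hedged usage sketch of the docstring above; `jnp.empty` stands in for the Pallas `empty` primitive used here (which, unlike `jnp.empty`, may return genuinely uninitialized memory).

import jax
import jax.numpy as jnp

tree = {"x": jax.ShapeDtypeStruct((8, 128), jnp.float32),
        "y": jax.ShapeDtypeStruct((16,), jnp.int32)}
out = jax.tree_util.tree_map(lambda leaf: jnp.empty(leaf.shape, leaf.dtype), tree)
print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), out))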
diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py
index ec100e1d796e..f3faddbc13dd 100644
--- a/jax/_src/pallas/hlo_interpreter.py
+++ b/jax/_src/pallas/hlo_interpreter.py
@@ -411,9 +411,6 @@ def pallas_call_hlo_interpret(
# to catch OOB accesses.
for carry_element in carry:
aval = carry_element.aval
- if isinstance(aval, jax_core.DShapedArray):
- aval = jax_core.ShapedArray(aval.shape, aval.dtype)
- carry_element.aval = aval
carry = map(_pad_to_block_dimension, carry, block_shapes)
carry.extend(scratch_values)
diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD
index 3538428e97f3..afcdefffa81e 100644
--- a/jax/_src/pallas/mosaic/BUILD
+++ b/jax/_src/pallas/mosaic/BUILD
@@ -175,6 +175,7 @@ pytype_strict_library(
deps = [
":core",
":lowering",
+ ":sc_core",
":sc_lowering",
"//jax",
"//jax/_src:core",
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index 58643515b239..f9fe91fe14b1 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -16,20 +16,23 @@
from __future__ import annotations
import collections
+from collections.abc import Mapping
from collections.abc import Sequence
import dataclasses
import enum
from typing import Any, ClassVar, Literal
-from collections.abc import Mapping
import jax
-import jax.numpy as jnp
-from jax.extend import backend as jex_backend
from jax._src import core as jax_core
+from jax._src import deprecations
+from jax._src import linear_util as lu
from jax._src import state
from jax._src import util
from jax._src.frozen_dict import FrozenDict
+from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
+from jax.extend import backend as jex_backend
+import jax.numpy as jnp
import numpy as np
@@ -100,6 +103,8 @@ class CompilerParams(pallas_core.CompilerParams):
skip_device_barrier: Skip the default device barrier for the kernel.
allow_collective_id_without_custom_barrier: Allow the use of collective_id
without a custom barrier.
+ use_tc_tiling_on_sc: Use TensorCore tiling for SparseCore. This flag is
+ only used for ``SC_*_SUBCORE`` kernels.
"""
BACKEND: ClassVar[pallas_core.Backend] = "mosaic_tpu"
dimension_semantics: tuple[DimensionSemantics, ...] | None = None
@@ -115,6 +120,7 @@ class CompilerParams(pallas_core.CompilerParams):
skip_device_barrier: bool = False
allow_collective_id_without_custom_barrier: bool = False
shape_invariant_numerics: bool = True
+ use_tc_tiling_on_sc: bool | None = None
def __init__(
self,
@@ -131,6 +137,7 @@ def __init__(
skip_device_barrier: bool = False,
allow_collective_id_without_custom_barrier: bool = False,
shape_invariant_numerics: bool = True,
+ use_tc_tiling_on_sc: bool | None = None,
):
object.__setattr__(
self,
@@ -163,12 +170,13 @@ def __init__(
object.__setattr__(
self, "shape_invariant_numerics", shape_invariant_numerics
)
+ object.__setattr__(self, "use_tc_tiling_on_sc", use_tc_tiling_on_sc)
# Replace is a method, not a field.
replace = dataclasses.replace
+
class MemorySpace(enum.Enum):
- ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY.
VMEM = "vmem"
VMEM_SHARED = "vmem_shared"
SMEM = "smem"
@@ -187,6 +195,21 @@ def __call__(self, shape: Sequence[int], dtype: jnp.dtype):
# A convenience function for constructing MemoryRef types of ShapedArrays.
return self.from_type(jax_core.ShapedArray(tuple(shape), dtype))
+ def __getattr__(self, name):
+ if name == "ANY":
+ # Deprecated on Dec 10, 2025.
+ deprecations.warn(
+ "pltpu-memory-space-any",
+ "pltpu.MemorySpace.ANY is deprecated. Use pl.ANY instead.",
+ stacklevel=2,
+ )
+ return pallas_core.MemorySpace.ANY
+    raise AttributeError(name)
+
+
+# TODO(slebedev): Remove this after
+MemorySpace.ANY = pallas_core.MemorySpace.ANY
+
class dma_semaphore(pallas_core.semaphore_dtype): pass
class DMASemaphore(pallas_core.AbstractSemaphoreTy):
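Migration sketch for the `MemorySpace.ANY` deprecation above (standard `pl` / `pltpu` aliases assumed; for now the class attribute keeps the old spelling working, with the `__getattr__` path carrying the warning once that alias is dropped):

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

old_space = pltpu.MemorySpace.ANY   # deprecated spelling (still aliased for now)
new_space = pl.ANY                  # preferred spelling going forward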
@@ -336,6 +359,49 @@ def _tensorcore_mesh_discharge_rule(
"TensorCoreMesh does not support VMEM inputs/outputs when there are"
" >1 cores. Use HBM or ANY instead."
)
+ def allowed_aval(aval):
+ if isinstance(aval, state.AbstractRef):
+ return True
+ if isinstance(aval, jax_core.ShapedArray):
+ # Only scalars are allowed.
+ return not aval.shape
+ return False
+ assert all(allowed_aval(v.aval) for v in jaxpr.constvars + jaxpr.invars)
+
+ is_scalar_const = [
+ isinstance(v.aval, jax_core.ShapedArray) and not v.aval.shape
+ for v in jaxpr.constvars
+ ]
+ if any(is_scalar_const):
+ # Rewrite body jaxpr to take in scalar values as Refs.
+ def new_body(*args):
+ args = [
+ a[0] if is_scalar else a
+ for a, is_scalar in zip(args, is_scalar_const)
+ ]
+ return jax_core.eval_jaxpr(jaxpr, args)
+      # TODO(sharadmv): Remove this once Mosaic supports passing scalars as values.
+ new_trace_avals = [
+ state.AbstractRef( # pylint: disable=g-long-ternary
+ jax_core.ShapedArray((1,), v.aval.dtype),
+ memory_space=MemorySpace.SMEM,
+ )
+ if is_scalar
+ else v.aval
+ for v, is_scalar in zip(jaxpr.constvars, is_scalar_const)
+ ]
+ new_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
+ lu.wrap_init(
+ new_body, debug_info=jaxpr.debug_info.with_unknown_names()
+ ),
+ new_trace_avals,
+ )
+ jaxpr = new_jaxpr.replace(invars=[], constvars=new_jaxpr.invars)
+ args = tuple(
+ a[None] if is_scalar else a
+ for a, is_scalar in zip(args, is_scalar_const)
+ )
+ in_avals, out_avals = util.split_list(new_trace_avals, [len(in_avals)])
return pallas_core.default_mesh_discharge_rule(
in_avals,
out_avals,
diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py
index 8286ab0b5e2a..3d8714945e90 100644
--- a/jax/_src/pallas/mosaic/error_handling.py
+++ b/jax/_src/pallas/mosaic/error_handling.py
@@ -36,7 +36,7 @@
r'( to (?P[0-9]+)?:(?P[0-9]+))?\)'
)
MLIR_ERR_PREFIX = (
- 'Pallas encountered an internal verification error.'
+ 'Pallas encountered an internal verification error. '
'Please file a bug at https://github.com/jax-ml/jax/issues. '
'Error details: '
)
diff --git a/jax/_src/pallas/mosaic/interpret/BUILD b/jax/_src/pallas/mosaic/interpret/BUILD
index 97300e4d0a3d..2a86f2258032 100644
--- a/jax/_src/pallas/mosaic/interpret/BUILD
+++ b/jax/_src/pallas/mosaic/interpret/BUILD
@@ -33,6 +33,8 @@ py_library(
deps = [
":race_detection_state",
":shared_memory",
+ ":thread_map",
+ ":utils",
":vector_clock",
"//jax",
"//jax/_src:api",
@@ -79,3 +81,23 @@ pytype_strict_library(
"//jax/_src:source_info_util",
],
)
+
+pytype_strict_library(
+ name = "thread_map",
+ srcs = ["thread_map.py"],
+ deps = [
+ "//jax",
+ "//jax/_src:callback",
+ ],
+)
+
+pytype_strict_library(
+ name = "utils",
+ srcs = ["utils.py"],
+ deps = [
+ "//jax",
+ "//jax/_src:core",
+ "//jax/_src:util",
+ "//jax/_src/pallas",
+ ] + py_deps("numpy"),
+)
diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py
index e1094a43c19d..9b815860ac47 100644
--- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py
+++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py
@@ -39,6 +39,8 @@
from jax._src.pallas.mosaic.interpret import shared_memory as memory
from jax._src.pallas.mosaic.interpret import vector_clock as vc
from jax._src.pallas.mosaic.interpret.race_detection_state import RaceDetectionState
+from jax._src.pallas.mosaic.interpret.thread_map import thread_map
+import jax._src.pallas.mosaic.interpret.utils as interpret_utils
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
@@ -58,7 +60,7 @@
@dataclasses.dataclass(frozen=True, kw_only=True)
-class InterpretParams:
+class InterpretParams(interpret_utils.InterpretParams):
"""Parameters for TPU interpret mode.
  TPU interpret mode is a way to run Pallas TPU kernels on CPU, while simulating
@@ -71,35 +73,16 @@ class InterpretParams:
:func:`jax.experimental.pallas.pallas_call` or
:func:`jax.experimental.pallas.core_map`.
+ NOTE: If an exception is raised while interpreting a kernel, you must call
+ :func:`reset_tpu_interpret_mode_state` before using TPU interpret mode
+ again in the same process.
+
Attributes:
dma_execution_mode: If "eager", DMAs are executed as soon as they are
issued. If "on_wait", DMA reads or writes are only executed when a device
is waiting on a DMA semaphore that will be signaled when the read or write
is complete.
Default: "on_wait".
- detect_races: If True, a dynamic, happens-before race detector will be used
- to detect data races during kernel interpretation. If any races are
- detected, a message will be printed and `races.races_found` will be set to
- True.
- Default: False.
- out_of_bounds_reads: If "raise", an exception will be raised on any
- out-of-bounds read of a buffer. If "uninitialized_value", any parts of
- the read that are out-of-bounds will return the value used to fill
- uninitialized memory, which can be configured via the
- "uninitialized_memory". NOTE: If an exception is raised while
- interpreting a kernel, you must call
- :func:`reset_tpu_interpret_mode_state` before using TPU interpret mode
- again in the same process.
- Default: "raise".
- skip_floating_point_ops: If True, operations that produce only floating
- point values will not be interpreted; instead, their results will be
- replaced with arrays all of `jnp.inf`. Additionally any floating point
- operands to any operation will be replaced with (arrays of) `jnp.inf`.
- Default: False.
- uninitialized_memory: If "nan", allocated buffers are initialized to contain
- all NaNs (or to their maximum possible value for integers). If "zero",
- allocated buffers are initialized to all zeros.
- Default: "nan".
random_seed: Seed for random number generator used during interpretation.
Currently random numbers are used to randomize the grid coordinates along
dimensions with 'parallel' semantics.
@@ -112,33 +95,25 @@ class InterpretParams:
along grid dimensions with 'parallel' semantics and - the mapping of grid
points to local (i.e. per-device) cores.
Default: None.
- num_cores_per_device: The number of cores per device.
- Default: 1.
allow_hbm_allocation_in_run_scoped: If `True`, allows the allocation of HBM
buffers (which are then shared across the cores in a device) in
`run_scoped`. While this behavior can be enabled in the interpreter,
allocating HBM buffers with `run_scoped` is not supported when executing
Pallas kernels on a real TPU.
Default: `False`.
- vector_clock_size: The number of entries in the vector clocks. This should
- be an integer bigger then the total number of cores, i.e. bigger than
- `number of devices * num_cores_per_device`. If `None`, the vector clock
- size that is used in the interpreter will default to twice the total
- number of cores.
- Default: None.
"""
+
dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
- detect_races: bool = False
- out_of_bounds_reads: Literal["raise", "uninitialized"] = "raise"
- skip_floating_point_ops: bool = False
- uninitialized_memory: Literal["nan", "zero"] = "nan"
random_seed: int | None = None
grid_point_recorder: (
Callable[[tuple[np.int32, ...], np.int32], None] | None
) = None
- num_cores_per_device: int = 1
allow_hbm_allocation_in_run_scoped: bool = False
- vector_clock_size: int | None = None
+
+ @property
+ def num_cores_per_device(self) -> int:
+ return self.num_cores_or_threads_per_device
+
@contextlib.contextmanager
def force_tpu_interpret_mode(params: InterpretParams = InterpretParams()):
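A hedged usage sketch of the parameters documented above, using the names defined in this module; the import path shown is internal and the public export may differ.

from jax._src.pallas.mosaic.interpret.interpret_pallas_call import (
    InterpretParams, force_tpu_interpret_mode)

params = InterpretParams(dma_execution_mode="eager", random_seed=0)
with force_tpu_interpret_mode(params):
  ...  # pallas_call / core_map kernels now run on CPU under the interpreter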
@@ -167,26 +142,12 @@ def set_tpu_interpret_mode(params: InterpretParams = InterpretParams()):
config.pallas_tpu_interpret_mode_context_manager.set_global(params) # type: ignore[arg-type]
-class Counter:
- """A simple counter that is thread-safe."""
-
- def __init__(self, initial_value: int):
- self.value = initial_value
- self.lock = threading.Lock()
-
- def get_next(self):
- with self.lock:
- result = self.value
- self.value += 1
- return result
-
-
# TODO(jburnim): Do we want to support multiple instances of SharedMemory?
# Maybe for running multiple distinct interpreted computations in parallel?
_shared_memory: memory.SharedMemory | None = None
_shared_memory_init_lock = threading.Lock()
races: RaceDetectionState | None = None
-dma_id_counter: Counter | None = None
+dma_id_counter: interpret_utils.Counter | None = None
def reset_tpu_interpret_mode_state():
"""Resets all global, shared state used by TPU interpret mode.
@@ -218,23 +179,6 @@ def _clear_shared_memory():
_shared_memory = None
-def _get_vector_clock_size(
- num_devices, num_cores_per_device, *, interpret_params
-) -> int:
- """Returns the number of vector clocks to use.`"""
- num_cores = num_devices * num_cores_per_device
- if interpret_params.vector_clock_size is not None:
- if num_cores >= interpret_params.vector_clock_size:
- raise ValueError(
- f'Vector clock size ({interpret_params.vector_clock_size}) must be '
- f'greater than the total number of cores ({num_cores}).'
- )
- return interpret_params.vector_clock_size
- else:
- # Default the vector clock size to twice the total number of cores.
- return 2 * num_cores
-
-
def _initialize_shared_memory(
device_id, num_devices, num_cores_per_device, *, interpret_params
):
@@ -247,11 +191,9 @@ def _initialize_shared_memory(
with _shared_memory_init_lock:
if _shared_memory is None:
- vector_clock_size = _get_vector_clock_size(
- num_devices, num_cores_per_device, interpret_params=interpret_params
- )
+ vector_clock_size = interpret_params.get_vector_clock_size(num_devices)
races = RaceDetectionState(num_cores=num_cores)
- dma_id_counter = Counter(100)
+ dma_id_counter = interpret_utils.Counter(100)
_shared_memory = memory.SharedMemory(
num_devices=num_devices,
num_cores_per_device=num_cores_per_device,
@@ -273,31 +215,16 @@ def _initialize_shared_memory(
assert _shared_memory.num_cores == num_cores
-def _update_clocks(low_global_core_id, high_global_core_id):
- """Synchronizes the vector clocks for the cores with ids in the range between the two arguments."""
- shared_memory = _get_shared_memory()
- # Despite only updating the vector clocks for some cores, we still need to
- # hold the global lock to ensure that no other devices are concurrently
- # accessing the same vector clocks.
- with shared_memory.lock:
- for c in shared_memory.clocks[low_global_core_id + 1 : high_global_core_id]:
- vc.update_vector_clock(shared_memory.clocks[low_global_core_id], c)
- for c in shared_memory.clocks[low_global_core_id + 1 : high_global_core_id]:
- vc.update_vector_clock(c, shared_memory.clocks[low_global_core_id])
-
-
def _update_clocks_for_device_barrier(device_id):
"""Synchronizes the vector clocks for the cores on the given device."""
shared_memory = _get_shared_memory()
- low_core_id = device_id * shared_memory.num_cores_per_device
- high_core_id = (device_id + 1) * shared_memory.num_cores_per_device
- _update_clocks(low_core_id, high_core_id)
+ shared_memory.update_clocks_for_device_barrier(device_id)
def _update_clocks_for_global_barrier():
"""Synchronizes all vector clocks."""
shared_memory = _get_shared_memory()
- _update_clocks(0, shared_memory.num_cores)
+ shared_memory.update_clocks(0, shared_memory.num_cores)
def _barrier(device_id):
@@ -322,7 +249,8 @@ def _check_for_revisiting(device_id, local_core_id, loop_idx, output_blocks):
except:
raise ValueError('Advanced indexers are not supported on TPU')
output_ranges = [
- _to_range(b) if b is not None else None for b in output_blocks
+ interpret_utils.to_range(b) if b is not None else None
+ for b in output_blocks
]
shared_memory = _get_shared_memory()
@@ -527,13 +455,16 @@ def _allocate_semaphores(
TPU_MEMORY_SPACE_IDXS: dict[
mosaic_core.MemorySpace | pallas_core.MemorySpace | None, int
] = {v: i for i, v in enumerate(mosaic_core.MemorySpace)}
-TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY] = TPU_MEMORY_SPACE_IDXS[
- mosaic_core.MemorySpace.ANY
-]
TPU_MEMORY_SPACE_NAMES = {
i: v.value for i, v in enumerate(mosaic_core.MemorySpace)
}
+# Inject ANY as the last memory space.
+TPU_MEMORY_SPACE_NAMES[len(TPU_MEMORY_SPACE_IDXS)] = (
+ pallas_core.MemorySpace.ANY.value
+)
+TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY] = len(TPU_MEMORY_SPACE_IDXS)
+
# Default to VMEM when no memory space is specified.
TPU_MEMORY_SPACE_IDXS[None] = TPU_MEMORY_SPACE_IDXS[
mosaic_core.MemorySpace.VMEM
@@ -548,60 +479,6 @@ def get_barrier_semaphore(device_id, collective_id):
return np.int16(collective_id)
-def _transform_slice_or_index(slice_or_idx):
- if isinstance(slice_or_idx, int):
- return slice_or_idx
- else:
- start = int(slice_or_idx.start)
- size = int(slice_or_idx.size)
- stride = int(slice_or_idx.stride)
- return slice(start, start + size * stride, stride)
-
-
-def _compose_slice_or_index(slice_or_idx1, slice_or_idx2):
- ret = []
- i = 0
- j = 0
- while True:
- if i == len(slice_or_idx1):
- ret.extend(slice_or_idx2[j:])
- return tuple(ret)
- elif j == len(slice_or_idx2):
- ret.extend(slice_or_idx1[i:])
- return tuple(ret)
- elif isinstance(slice_or_idx1[i], int):
- ret.append(slice_or_idx1[i])
- i += 1
- elif isinstance(slice_or_idx2[j], int):
- ret.append(
- slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step
- )
- i += 1
- j += 1
- else:
- ret.append(
- slice(
- slice_or_idx1[i].start
- + slice_or_idx2[j].start * slice_or_idx1[i].step,
- slice_or_idx1[i].start
- + slice_or_idx2[j].stop * slice_or_idx1[i].step,
- slice_or_idx1[i].step * slice_or_idx2[j].step,
- )
- )
- i += 1
- j += 1
-
-
-def _to_range(transforms) -> tuple[slice | int, ...]:
- ret = ()
- for transform in transforms:
- # For now, assume only NDIndexer transforms.
- ret = _compose_slice_or_index(
- ret, tuple(_transform_slice_or_index(i) for i in transform.indices)
- )
- return ret
-
-
def _to_int(x: int | Array | None) -> int | None:
"""Converts a value to an integer, or returns None if the value is None."""
if x is None:
@@ -649,7 +526,7 @@ def get(
global_core_id = shared_memory.get_global_core_id(device_id, local_core_id)
key = (memory_space, buffer_id, device_id, local_core_id_for_buffer)
- read_range = _to_range(transforms)
+ read_range = interpret_utils.to_range(transforms)
ret, (shape, dtype), clock_ = shared_memory.get_buffer_content(
key, read_range, global_core_id
)
@@ -702,7 +579,9 @@ def get(
# out_of_bounds_reads == "uninitialized"
uninit_array = np.full(
full_read_shape,
- _uninitialized_value(dtype, shared_memory.uninitialized_memory),
+ interpret_utils.get_uninitialized_value(
+ dtype, shared_memory.uninitialized_memory
+ ),
dtype=dtype,
)
if ret is None:
@@ -771,7 +650,7 @@ def store(
global_core_id = shared_memory.get_global_core_id(device_id, local_core_id)
key = (memory_space, buffer_id, device_id, local_core_id_for_buffer)
- write_range = _to_range(transforms)
+ write_range = interpret_utils.to_range(transforms)
in_bounds, (shape, _), clock_ = shared_memory.store_buffer_content(
key, write_range, val, global_core_id
)
@@ -842,7 +721,7 @@ def swap(
global_core_id = shared_memory.get_global_core_id(device_id, local_core_id)
key = (memory_space, buffer_id, device_id, local_core_id_for_buffer)
- read_write_range = _to_range(transforms)
+ read_write_range = interpret_utils.to_range(transforms)
ret, (shape, _), clock = shared_memory.swap_buffer_content(
key, read_write_range, val, mask, global_core_id
)
@@ -1184,25 +1063,6 @@ def _compute_transformed_shape_and_dtype(shape, dtype, transforms):
dtype = transform.transform_dtype(dtype)
return shape, dtype
-def _device_coords_to_logical_id(device_coords, axis_sizes):
- if not isinstance(device_coords, tuple):
- device_coords = (device_coords,)
- assert len(device_coords) == len(axis_sizes)
- sizes = list(axis_sizes.values())
- ret = 0
- for i in range(len(device_coords)):
- ret += device_coords[i] * math.prod(sizes[i+1:])
- return ret
-
-def _device_id_to_logical(device_id, device_id_type, axis_sizes):
- if device_id is None:
- return None
- if device_id_type == primitives.DeviceIdType.MESH:
- return _device_coords_to_logical_id(device_id, axis_sizes)
- elif device_id_type == primitives.DeviceIdType.LOGICAL:
- return device_id
- else:
- raise ValueError(f'Unsupported device ID type: {device_id_type}')
@lu.cache
def _to_jaxpr(flat_fun, in_avals):
@@ -1211,24 +1071,15 @@ def _to_jaxpr(flat_fun, in_avals):
return new_jaxpr
def _is_any(memory_space):
- return ((memory_space == mosaic_core.MemorySpace.ANY) or
- (memory_space == pallas_core.MemorySpace.ANY))
+ return memory_space is pallas_core.MemorySpace.ANY
-def _is_float(dtype):
- return jnp.issubdtype(dtype, jnp.floating)
_SENTINEL = jnp.inf
-@dataclasses.dataclass(frozen=True)
-class Placeholder:
- """Placeholder for use in `_interpret_jaxpr` below instead of putting a concrete value into `env`."""
- shape: tuple[int, ...]
- dtype: jnp.dtype
-
def _get_memory_space_and_raise_if_hbm(aval, primitive_name, message=None):
memory_space = aval.memory_space
- if memory_space in [mosaic_core.MemorySpace.HBM, mosaic_core.MemorySpace.ANY]:
+ if memory_space in [mosaic_core.MemorySpace.HBM, pallas_core.MemorySpace.ANY]:
if message is None:
message = (
f'{primitive_name}: Buffers with a memory space of HBM or ANY cannot'
@@ -1250,23 +1101,14 @@ def _interpret_jaxpr(
compiler_params,
interpret_params
):
- env = {}
-
- def read(var):
- if isinstance(var, jax_core.Literal):
- result = var.val
- else:
- result = env[var]
- if isinstance(result, Placeholder):
- result = jax.lax.full(result.shape, _SENTINEL, result.dtype)
- return result
-
- def write(var, value):
- if interpret_params.skip_floating_point_ops and _is_float(value.dtype):
- value = Placeholder(value.shape, value.dtype)
- env[var] = value
-
- jax._src.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args)
+ sentinel_for_floating_point_values = (
+ _SENTINEL if interpret_params.skip_floating_point_ops else None
+ )
+ env = interpret_utils.JaxprEnv(
+ vars=jaxpr.constvars + jaxpr.invars,
+ values=args,
+ sentinel_for_floating_point_values=sentinel_for_floating_point_values,
+ )
# TODO(jburnim): Clean up and finish this evaluation loop. For example:
# - Replace the big if-statement with a dictionary of rules.
@@ -1290,9 +1132,7 @@ def write(var, value):
# not need to do any reads if `interpret_params.skip_floating_point_ops`
# is True. If this is the case, we want to avoid materializing the read
# array into the jaxpr when this function is traced.
- deferred_invals = functools.partial(
- jax._src.util.safe_map, read, eqn.invars
- )
+ deferred_invals = functools.partial(env.read_many, eqn.invars)
if prim is primitives.load_p:
(ref, transforms, mask, _) = jax.tree.unflatten(
@@ -1438,8 +1278,8 @@ def f(*args, jaxpr):
device_id,
local_core_id,
TPU_MEMORY_SPACE_IDXS[memory_space],
- _uninitialized_array(
- v.aval.shape, v.aval.dtype, interpret_params
+ interpret_params.get_uninitialized_array(
+ v.aval.shape, v.aval.dtype
),
ordered=True,
)
@@ -1514,16 +1354,17 @@ def f(*args, jaxpr):
src_sem_transforms,
target_device_id,
) = jax.tree.unflatten(eqn.params['tree'], deferred_invals())
- target_device_id = _device_id_to_logical(
- target_device_id, eqn.params['device_id_type'], axis_sizes)
+ target_device_id = interpret_utils._device_id_to_logical(
+ target_device_id, eqn.params['device_id_type'], axis_sizes,
+ axis_indices)
(orig_src_ref, _, orig_dst_ref, *_
) = jax.tree.unflatten(eqn.params['tree'], eqn.invars)
src_memory_space = getattr(orig_src_ref.aval, 'memory_space', None)
if src_memory_space is None:
- src_memory_space = mosaic_core.MemorySpace.ANY
+ src_memory_space = pallas_core.MemorySpace.ANY
dst_memory_space = getattr(orig_dst_ref.aval, 'memory_space', None)
if dst_memory_space is None:
- dst_memory_space = mosaic_core.MemorySpace.ANY
+ dst_memory_space = pallas_core.MemorySpace.ANY
callback.io_callback(
functools.partial(dma_start, source_info=eqn.source_info),
(),
@@ -1579,8 +1420,9 @@ def f(*args, jaxpr):
elif prim is primitives.semaphore_signal_p:
sem, sem_transforms, inc, target_device_id, core_index = (
jax.tree.unflatten(eqn.params['args_tree'], deferred_invals()))
- target_device_id = _device_id_to_logical(
- target_device_id, eqn.params['device_id_type'], axis_sizes)
+ target_device_id = interpret_utils._device_id_to_logical(
+ target_device_id, eqn.params['device_id_type'], axis_sizes,
+ axis_indices)
callback.io_callback(
semaphore_signal,
(),
@@ -1618,7 +1460,7 @@ def f(*args, jaxpr):
else:
if interpret_params.skip_floating_point_ops and all(
- _is_float(ovar.aval.dtype) for ovar in eqn.outvars
+ interpret_utils.is_float(ovar.aval.dtype) for ovar in eqn.outvars
):
# Skip `prim.bind` since `prim` only produces floating-point values.
# It is safe to populate `out` with avals since mapping `write` over
@@ -1632,9 +1474,9 @@ def f(*args, jaxpr):
out = prim.bind(*subfuns, *deferred_invals(), **bind_params)
out = out if prim.multiple_results else [out]
- jax._src.util.safe_map(write, eqn.outvars, out)
+ env.write_many(eqn.outvars, out)
- return jax._src.util.safe_map(read, jaxpr.outvars)
+ return env.read_many(jaxpr.outvars)
def _compute_start_indices(
block_mapping, loop_idx, *args,
@@ -1796,7 +1638,6 @@ def _remove_memory_space_abstract_eval(x):
if (
x.memory_space is None
or x.memory_space is pallas_core.MemorySpace.ANY
- or x.memory_space is mosaic_core.MemorySpace.ANY
or x.memory_space is mosaic_core.MemorySpace.HBM
):
return jax_core.ShapedArray(x.shape, x.dtype)
@@ -1839,106 +1680,10 @@ def _get_grid_point(
grid_point.append(li if jnp.size(coords) == 0 else coords[li])
return jnp.array(grid_point, dtype=np.int32)
-def _uninitialized_value(dtype, uninitialized_memory: Literal['nan', 'zero']):
- if uninitialized_memory == 'nan':
- if jnp.issubdtype(dtype, jnp.floating):
- return np.nan
- elif jnp.issubdtype(dtype, jnp.integer):
- return jnp.iinfo(dtype).max
- elif jnp.issubdtype(dtype, jnp.bool):
- return True
- if uninitialized_memory == 'zero':
- return 0
- raise NotImplementedError(
- uninitialized_memory + ' + ' + str(dtype))
-
-def _uninitialized_array(shape, dtype, interpret_params):
- return jnp.full(
- shape,
- _uninitialized_value(dtype, interpret_params.uninitialized_memory),
- dtype,
- )
-
-def _pad_to_block_dimension(value, block_shape, interpret_params):
- """Pads values so the shape evenly divides into block dimensions.
-
- For example, if values has a shape of (33, 2, 5) with a block_shape of
- (32, 2, 4), this function will pad the value of shape to (64, 2, 8).
-
- Args:
- value: Array to be padded.
- block_shape: Block shapes to use for padding. If None, no padding will
- be performed.
-
- Returns:
- A padded array.
- """
- padded_shape = tuple(
- ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
- )
- if padded_shape != value.shape:
- pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
- pad_value = _uninitialized_array((), value.dtype, interpret_params)
- value = jnp.pad(value, pad_width, constant_values=pad_value)
- return value
def get_interpret_effects():
return {callback._OrderedIOEffect}
-def _thread_map(f, num_threads):
- if num_threads == 1:
- f(jnp.int32(0))
- return
-
- def _f(core_index):
- f(core_index)
- return ()
- jaxpr = jax.make_jaxpr(_f)(jnp.int32(0))
-
- _call_threadmap_callback(jaxpr.jaxpr, num_threads, *jaxpr.consts)
-
-def _run_jaxpr(jaxpr, consts, *args):
- def _run(jaxpr, consts, *args):
- jax_core.eval_jaxpr(jaxpr, consts, *args)
- traced = jax.jit(_run, static_argnums=(0,)).trace(jaxpr, consts, *args)
- traced.lower().compile()(consts, *args)
- return
-
-import concurrent.futures
-
-def _thread_map_callback(jaxpr, num_threads, consts):
- num_threads = int(num_threads)
- threads = []
- with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
- for i in range(num_threads):
- threads.append(
- executor.submit(_run_jaxpr, jaxpr, consts, jnp.int32(i)))
- exceptions = []
- for i in range(num_threads):
- try:
- threads[i].result()
- except Exception as e:
- exceptions.append(e)
- if exceptions:
- # TODO(jburnim): Use ExceptionGroup once JAX requires Python 3.11.
- # raise ExceptionGroup('Exceptions raised during _thread_map', exceptions)
- raise exceptions[0]
-
-def _call_threadmap_callback(jaxpr, num_threads, *consts):
- # NOTE: At runtime, _thread_map_callback will lower and compile the
- # given jaxpr. (JAX's caches should ensure the jaxpr is only lowered and
- # compiled once.)
- #
- # TODO(jburnim): Would it be worth trying to lower/compile the jaxpr at
- # lowering/compilation time? E.g., by using a custom primitive here, could
- # we lower/compile jaxpr at lowering time, and then pass the compiled
- # function to the callback?
- return callback.io_callback(
- functools.partial(_thread_map_callback, jaxpr),
- (),
- num_threads,
- consts,
- ordered=True)
def interpret_pallas_call(
*args,
@@ -1963,7 +1708,8 @@ def interpret_pallas_call(
# that users don't have to specify it in the InterpretParams.
assert len(mesh.shape) == 1
interpret_params = dataclasses.replace(
- interpret_params, num_cores_per_device=mesh.devices.shape[0])
+ interpret_params, num_cores_or_threads_per_device=mesh.devices.shape[0]
+ )
args = [remove_memory_space_p.bind(a) for a in args]
# args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
@@ -1983,8 +1729,9 @@ def interpret_pallas_call(
num_devices = functools.reduce(
jnp.multiply, axis_sizes.values(), jnp.int32(1))
axis_indices = {k: lax.axis_index(k) for k in axis_sizes.keys()}
- device_id = _device_coords_to_logical_id(
- tuple(axis_indices.values()), axis_sizes)
+ device_id = interpret_utils.device_coords_to_logical_id(
+ tuple(axis_indices.values()), axis_sizes, axis_indices
+ )
callback.io_callback(
functools.partial(
_initialize_shared_memory, interpret_params=interpret_params
@@ -2007,7 +1754,7 @@ def interpret_pallas_call(
]
num_inputs = grid_mapping.num_inputs
input_args = [
- _pad_to_block_dimension(a, bs, interpret_params)
+ interpret_params.pad_to_block_dimension(a, bs)
for a, bs in zip(input_args, block_shapes[:num_inputs])
]
@@ -2025,7 +1772,7 @@ def interpret_pallas_call(
jax.ShapeDtypeStruct((), jnp.int16),
device_id,
None, # local_core_id
- TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY],
+ TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY],
input_args[i],
ordered=True,
)
@@ -2048,11 +1795,11 @@ def interpret_pallas_call(
output_buffer_shapes.append(input_args[oi_alias_map[i]].shape)
output_vals.append(input_args[oi_alias_map[i]])
else:
- out_val = _uninitialized_array(bm.array_aval.shape,
- bm.array_aval.dtype,
- interpret_params)
- padded_val = _pad_to_block_dimension(
- out_val, output_block_shapes[i], interpret_params
+ out_val = interpret_params.get_uninitialized_array(
+ bm.array_aval.shape, bm.array_aval.dtype
+ )
+ padded_val = interpret_params.pad_to_block_dimension(
+ out_val, output_block_shapes[i]
)
output_buffer_ids.append(
callback.io_callback(
@@ -2060,7 +1807,7 @@ def interpret_pallas_call(
jax.ShapeDtypeStruct((), jnp.int16),
device_id,
None, # local_core_id
- TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY],
+ TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY],
padded_val,
ordered=True,
)
@@ -2121,8 +1868,8 @@ def interpret_pallas_call(
device_id,
None, # local_core_id,
TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
- _uninitialized_array(
- var.aval.shape, var.aval.dtype, interpret_params
+ interpret_params.get_uninitialized_array(
+ var.aval.shape, var.aval.dtype
),
ordered=True,
)
@@ -2300,7 +2047,7 @@ def _store_slice_to_kernel_input(index, input_var):
jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype),
device_id,
core_index,
- TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY],
+ TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY],
input_buffer_ids[index],
(transform,),
cur_block_indices[index],
@@ -2369,7 +2116,7 @@ def _store_to_output_buffer(index, output_var, transform):
(),
device_id,
core_index,
- TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY],
+ TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY],
output_buffer_ids[index],
(transform,),
kernel_output_val,
@@ -2471,7 +2218,7 @@ def _store_to_output_buffer(index, output_var, transform):
_update_clocks_for_device_barrier, (), device_id, ordered=True
)
- _thread_map(_execute_grid_for_core, interpret_params.num_cores_per_device)
+ thread_map(
+ _execute_grid_for_core,
+ interpret_params.num_cores_or_threads_per_device,
+ )
# TODO(jburnim): Should we only create happens-before here from the other
# cores to core 0?
@@ -2488,7 +2235,7 @@ def _store_to_output_buffer(index, output_var, transform):
val,
device_id,
0, # local_core_id
- TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY],
+ TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY],
output_buffer_id,
(
indexing.NDIndexer.from_indices_shape(
diff --git a/jax/_src/pallas/mosaic/interpret/race_detection_state.py b/jax/_src/pallas/mosaic/interpret/race_detection_state.py
index 64f6568c3754..ff76778119d3 100644
--- a/jax/_src/pallas/mosaic/interpret/race_detection_state.py
+++ b/jax/_src/pallas/mosaic/interpret/race_detection_state.py
@@ -116,9 +116,9 @@ def check_read(
# between real device IDs vs. DMA IDs.
print(
f'RACE DETECTED\n read of {buffer_key}[{rnge}] from {device_id},'
- f' {local_core_id}, {user_frame}\n write of'
+ f' {local_core_id}, {user_frame}\n clock: {clock}\n write of'
f' {buffer_key}[{write_range}] from {write_device_id},'
- f' {write_local_core_id} {write_frame}'
+ f' {write_local_core_id}, {write_frame}\n clock: {write_clock}\n'
)
with self.lock:
self.races_found = True
@@ -158,9 +158,9 @@ def check_write(
# between real device IDs vs. DMA IDs.
print(
f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},'
- f' {local_core_id}, {user_frame}\n write of'
+ f' {local_core_id}, {user_frame}\n clock: {clock}\n write of'
f' {buffer_key}[{write_range}] from {write_device_id},'
- f' {write_local_core_id}, {write_frame}'
+ f' {write_local_core_id}, {write_frame}\n clock: {write_clock}\n'
)
with self.lock:
self.races_found = True
@@ -178,9 +178,9 @@ def check_write(
# between real device IDs vs. DMA IDs.
print(
f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},'
- f' {local_core_id}, {user_frame}\n read of'
+ f' {local_core_id}, {user_frame}\n clock: {clock}\n read of'
f' {buffer_key}[{read_range}] from {read_device_id},'
- f' {read_local_core_id}, {read_frame}'
+ f' {read_local_core_id}, {read_frame}\n clock: {read_clock}\n'
)
with self.lock:
self.races_found = True
diff --git a/jax/_src/pallas/mosaic/interpret/shared_memory.py b/jax/_src/pallas/mosaic/interpret/shared_memory.py
index 74a078bd9fa2..21fd6600928a 100644
--- a/jax/_src/pallas/mosaic/interpret/shared_memory.py
+++ b/jax/_src/pallas/mosaic/interpret/shared_memory.py
@@ -572,3 +572,20 @@ def swap_buffer_content(
mask[in_bounds_idx], value[in_bounds_idx], raw_result
)
return result.copy(), shape_and_dtype, clock
+
+ def update_clocks(self, low_global_core_id, high_global_core_id):
+ """Synchronizes the vector clocks for the cores with ids in the range between the two arguments."""
+ # Despite only updating the vector clocks for some cores, we still need to
+ # hold the global lock to ensure that no other devices are concurrently
+ # accessing the same vector clocks.
+ with self.lock:
+ for c in self.clocks[low_global_core_id + 1 : high_global_core_id]:
+ vc.update_vector_clock(self.clocks[low_global_core_id], c)
+ for c in self.clocks[low_global_core_id + 1 : high_global_core_id]:
+ vc.update_vector_clock(c, self.clocks[low_global_core_id])
+
+ def update_clocks_for_device_barrier(self, device_id):
+ """Synchronizes the vector clocks for the cores on the given device."""
+ low_core_id = device_id * self.num_cores_per_device
+ high_core_id = (device_id + 1) * self.num_cores_per_device
+ self.update_clocks(low_core_id, high_core_id)
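+ # For example, with num_cores_per_device == 2, a barrier on device 1
+ # synchronizes the clocks of global cores 2 and 3 via update_clocks(2, 4).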
diff --git a/jax/_src/pallas/mosaic/interpret/thread_map.py b/jax/_src/pallas/mosaic/interpret/thread_map.py
new file mode 100644
index 000000000000..3b162a70daeb
--- /dev/null
+++ b/jax/_src/pallas/mosaic/interpret/thread_map.py
@@ -0,0 +1,80 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from concurrent import futures
+import functools
+
+import jax
+from jax._src import callback
+import jax.core as jax_core
+import jax.numpy as jnp
+
+
+def _run_jaxpr(jaxpr, consts, *args):
+ def _run(jaxpr, consts, *args):
+ jax_core.eval_jaxpr(jaxpr, consts, *args)
+
+ traced = jax.jit(_run, static_argnums=(0,)).trace(jaxpr, consts, *args)
+ traced.lower().compile()(consts, *args)
+ return
+
+
+def _thread_map_callback(jaxpr, num_threads, consts):
+ num_threads = int(num_threads)
+ threads = []
+ with futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+ for i in range(num_threads):
+ threads.append(executor.submit(_run_jaxpr, jaxpr, consts, jnp.int32(i)))
+ exceptions = []
+ for i in range(num_threads):
+ try:
+ threads[i].result()
+ except Exception as e:
+ exceptions.append(e)
+ if exceptions:
+ # TODO(jburnim): Use ExceptionGroup once JAX requires Python 3.11.
+ # raise ExceptionGroup('Exceptions raised during _thread_map', exceptions)
+ raise exceptions[0]
+
+
+def _call_threadmap_callback(jaxpr, num_threads, *consts):
+ # NOTE: At runtime, _thread_map_callback will lower and compile the
+ # given jaxpr. (JAX's caches should ensure the jaxpr is only lowered and
+ # compiled once.)
+ #
+ # TODO(jburnim): Would it be worth trying to lower/compile the jaxpr at
+ # lowering/compilation time? E.g., by using a custom primitive here, could
+ # we lower/compile jaxpr at lowering time, and then pass the compiled
+ # function to the callback?
+ return callback.io_callback(
+ functools.partial(_thread_map_callback, jaxpr),
+ (),
+ num_threads,
+ consts,
+ ordered=True,
+ )
+
+
+def thread_map(f, num_threads):
+ if num_threads == 1:
+ f(jnp.int32(0))
+ return
+
+ def _f(core_index):
+ f(core_index)
+ return ()
+
+ jaxpr = jax.make_jaxpr(_f)(jnp.int32(0))
+
+ _call_threadmap_callback(jaxpr.jaxpr, num_threads, *jaxpr.consts)
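+
+
+# Example usage (a minimal sketch; `body` is a hypothetical per-core function
+# that must be traceable by jax.make_jaxpr):
+#
+#   def body(core_index):
+#     jax.debug.print("running core {i}", i=core_index)
+#
+#   thread_map(body, num_threads=2)  # runs the traced body on two Python threads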
diff --git a/jax/_src/pallas/mosaic/interpret/utils.py b/jax/_src/pallas/mosaic/interpret/utils.py
new file mode 100644
index 000000000000..460e5031e065
--- /dev/null
+++ b/jax/_src/pallas/mosaic/interpret/utils.py
@@ -0,0 +1,336 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+import dataclasses
+import math
+import threading
+from typing import Any, Literal
+
+from jax import lax
+from jax._src import core as jax_core
+from jax._src.pallas import primitives
+from jax._src.util import safe_map
+import jax.numpy as jnp
+import numpy as np
+
+
+def get_uninitialized_value(
+ dtype, uninitialized_memory: Literal["nan", "zero"]
+):
+ if uninitialized_memory == "nan":
+ if jnp.issubdtype(dtype, jnp.floating):
+ return np.nan
+ elif jnp.issubdtype(dtype, jnp.integer):
+ return jnp.iinfo(dtype).max
+ elif jnp.issubdtype(dtype, jnp.bool):
+ return True
+ if uninitialized_memory == "zero":
+ return 0
+ raise NotImplementedError(uninitialized_memory + " + " + str(dtype))
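+
+# As a quick check of the fill values above (a sketch, not part of the change):
+# get_uninitialized_value(jnp.float32, "nan") returns np.nan,
+# get_uninitialized_value(jnp.int8, "nan") returns 127 (the maximum int8
+# value), and get_uninitialized_value(jnp.int8, "zero") returns 0.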
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class InterpretParams:
+ """Parameters for kernel interpret mode.
+
+ Interpret mode is a way to run Pallas kernels on CPU, while simulating TPU/GPU
+ shared memory, communication, and synchronization operations.
+
+ Attributes:
+ detect_races: If True, a dynamic, happens-before race detector will be used
+ to detect data races during kernel interpretation. If any races are
+ detected, a message will be printed and `races.races_found` will be set to
+ True.
+ Default: False.
+ out_of_bounds_reads: If "raise", an exception will be raised on any
+ out-of-bounds read of a buffer. If "uninitialized_value", any parts of
+ the read that are out-of-bounds will return the value used to fill
+ uninitialized memory, which can be configured via the
+ "uninitialized_memory".
+ Default: "raise".
+ skip_floating_point_ops: If True, operations that produce only floating
+ point values will not be interpreted; instead, their results will be
+ replaced with arrays filled with `jnp.inf`. Additionally, any floating
+ point operands to any operation will be replaced with (arrays of)
+ `jnp.inf`.
+ Default: False.
+ uninitialized_memory: If "nan", allocated buffers are initialized to contain
+ all NaNs (or to their maximum possible value for integers). If "zero",
+ allocated buffers are initialized to all zeros.
+ Default: "nan".
+ num_cores_or_threads_per_device: The number of cores (TPU) or threads (GPU)
+ per device.
+ Default: 1.
+ vector_clock_size: The number of entries in the vector clocks. This should
+ be an integer greater than the total number of cores/threads, i.e.
+ greater than `number of devices * num_cores_or_threads_per_device`. If
+ `None`, the vector clock size used in the interpreter defaults to twice
+ the total number of cores/threads.
+ Default: None.
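+
+ Example (a minimal sketch of constructing parameters; how they are passed
+ to the interpreter is shown elsewhere in this change):
+
+ params = InterpretParams(
+ detect_races=True,
+ skip_floating_point_ops=True,
+ num_cores_or_threads_per_device=2,
+ )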
+ """
+
+ detect_races: bool = False
+ out_of_bounds_reads: Literal["raise", "uninitialized"] = "raise"
+ skip_floating_point_ops: bool = False
+ uninitialized_memory: Literal["nan", "zero"] = "nan"
+ num_cores_or_threads_per_device: int = 1
+ vector_clock_size: int | None = None
+
+ def get_vector_clock_size(self, num_devices) -> int:
+ """Returns the number of vector clocks to use.`"""
+ num_cores_or_threads = num_devices * self.num_cores_or_threads_per_device
+ if self.vector_clock_size is not None:
+ if num_cores_or_threads >= self.vector_clock_size:
+ raise ValueError(
+ f"Vector clock size ({self.vector_clock_size}) must be greater than"
+ f" the total number of cores/threads ({num_cores_or_threads})."
+ )
+ return self.vector_clock_size
+ else:
+ # Default to twice the total number of cores/threads.
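+ # For example, with 4 devices and num_cores_or_threads_per_device=2,
+ # the default is 2 * (4 * 2) = 16 entries per vector clock.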
+ return 2 * num_cores_or_threads
+
+ def get_uninitialized_array(self, shape, dtype):
+ return jnp.full(
+ shape,
+ get_uninitialized_value(dtype, self.uninitialized_memory),
+ dtype,
+ )
+
+ def pad_to_block_dimension(self, value, block_shape):
+ """Pads values so the shape evenly divides into block dimensions.
+
+ For example, if `value` has shape (33, 2, 5) and `block_shape` is
+ (32, 2, 4), this function pads `value` to shape (64, 2, 8).
+
+ Args:
+ value: Array to be padded.
+ block_shape: Block shapes to use for padding. If None, no padding will be
+ performed.
+
+ Returns:
+ A padded array.
+ """
+ padded_shape = tuple(
+ ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape)
+ )
+ if padded_shape != value.shape:
+ pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape))
+ pad_value = self.get_uninitialized_array((), value.dtype)
+ value = jnp.pad(value, pad_width, constant_values=pad_value)
+ return value
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class InterpretGPUParams(InterpretParams):
+ ...
+
+
+class Counter:
+ """A simple counter that is thread-safe."""
+
+ def __init__(self, initial_value: int):
+ self.value = initial_value
+ self.lock = threading.Lock()
+
+ def get_next(self):
+ with self.lock:
+ result = self.value
+ self.value += 1
+ return result
+
+
+# TODO(sharadmv): De-dup this w/ the impl in primitives.py.
+def _device_id_dict_to_mesh(device_id_dict, axis_sizes, axis_indices):
+ physical_axis_dict = {}
+ axis_names = axis_sizes.keys()
+ for axis, idx in device_id_dict.items():
+ if isinstance(axis, tuple) and any(a in axis_names for a in axis):
+ if not all(a in axis_names for a in axis):
+ raise NotImplementedError(
+ f"{axis} mixes JAX mesh and Pallas mesh grid axes"
+ )
+ axes_dimensions = [axis_sizes[name] for name in axis]
+ for axis_index, axis_name in enumerate(axis):
+ axis_size = axis_sizes[axis_name]
+ inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :])
+ minor_divisor = inner_mesh_size
+
+ # Fast path for power of 2s
+ if inner_mesh_size & (inner_mesh_size - 1) == 0:
+ shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1
+ partial_device_idx = idx >> shift_len
+ else:
+ partial_device_idx = idx // minor_divisor
+
+ if axis_size & (axis_size - 1) == 0:
+ device_idx = partial_device_idx & (axis_size - 1)
+ else:
+ device_idx = partial_device_idx % axis_size
+ physical_axis_dict[axis_name] = device_idx
+ else:
+ physical_axis_dict[axis] = idx
+ device_id = []
+ for axis in axis_names:
+ if axis in physical_axis_dict:
+ device_id.append(physical_axis_dict[axis])
+ else:
+ device_id.append(axis_indices[axis])
+ non_mesh_axes = {
+ k: v for k, v in physical_axis_dict.items() if k not in axis_names
+ }
+ return tuple(device_id), non_mesh_axes
+
+
+def device_coords_to_logical_id(device_coords, axis_sizes, axis_indices):
+ if isinstance(device_coords, dict):
+ device_coords, non_mesh_axes = _device_id_dict_to_mesh(
+ device_coords, axis_sizes, axis_indices
+ )
+ if non_mesh_axes:
+ raise NotImplementedError(non_mesh_axes)
+ if not isinstance(device_coords, tuple):
+ device_coords = (device_coords,)
+ assert len(device_coords) == len(axis_sizes)
+ sizes = list(axis_sizes.values())
+ ret = 0
+ for i in range(len(device_coords)):
+ ret += device_coords[i] * math.prod(sizes[i + 1 :])
+ return ret
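+
+# For example, with axis_sizes {"x": 2, "y": 4}, device coordinates (1, 3)
+# map to logical device id 1 * 4 + 3 = 7.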
+
+
+def _device_id_to_logical(device_id, device_id_type, axis_sizes, axis_indices):
+ if device_id is None:
+ return None
+ if device_id_type == primitives.DeviceIdType.MESH:
+ return device_coords_to_logical_id(device_id, axis_sizes, axis_indices)
+ elif device_id_type == primitives.DeviceIdType.LOGICAL:
+ return device_id
+ else:
+ raise ValueError(f"Unsupported device ID type: {device_id_type}")
+
+
+def is_int(dtype):
+ return jnp.issubdtype(dtype, jnp.integer)
+
+
+def is_float(dtype):
+ return jnp.issubdtype(dtype, jnp.floating)
+
+
+@dataclasses.dataclass(frozen=True)
+class Placeholder:
+ """Placeholder for use in `JaxprEnv` below instead of storing a concrete value."""
+
+ shape: tuple[int, ...]
+ dtype: jnp.dtype
+
+
+class JaxprEnv:
+ """An environment for interpreting jaxprs, mapping variables to values."""
+
+ def __init__(
+ self,
+ *,
+ vars: Sequence[jax_core.Var] | None = None,
+ values: Sequence[Any] | None = None,
+ sentinel_for_floating_point_values: Any = None,
+ ):
+ self._sentinel_for_floating_point_values = (
+ sentinel_for_floating_point_values
+ )
+ self._env: dict[jax_core.Var, Any] = {}
+
+ if vars is None and values is None:
+ return
+
+ vars = vars or []
+ values = values or []
+ self.write_many(vars, values)
+
+ def read(self, var):
+ if isinstance(var, jax_core.Literal):
+ result = var.val
+ else:
+ result = self._env[var]
+ if isinstance(result, Placeholder):
+ result = lax.full(
+ result.shape, self._sentinel_for_floating_point_values, result.dtype
+ )
+ return result
+
+ def read_many(self, vars):
+ return safe_map(self.read, vars)
+
+ def write(self, var, value):
+ if (self._sentinel_for_floating_point_values is not None
+ and is_float(value.dtype)):
+ value = Placeholder(value.shape, value.dtype)
+ self._env[var] = value
+
+ def write_many(self, vars, values):
+ safe_map(self.write, vars, values)
+
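+# Example of how an evaluation loop can drive this environment (a minimal
+# sketch; `my_jaxpr` and `flat_args` are hypothetical stand-ins for a traced
+# jaxpr and its flattened inputs):
+#
+#   env = JaxprEnv(
+#       vars=my_jaxpr.constvars + my_jaxpr.invars,
+#       values=flat_args,
+#       sentinel_for_floating_point_values=jnp.inf,
+#   )
+#   for eqn in my_jaxpr.eqns:
+#     out = eqn.primitive.bind(*env.read_many(eqn.invars), **eqn.params)
+#     out = out if eqn.primitive.multiple_results else [out]
+#     env.write_many(eqn.outvars, out)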
+
+def _transform_slice_or_index(slice_or_idx):
+ if isinstance(slice_or_idx, int):
+ return slice_or_idx
+ else:
+ start = int(slice_or_idx.start)
+ size = int(slice_or_idx.size)
+ stride = int(slice_or_idx.stride)
+ return slice(start, start + size * stride, stride)
+
+
+def _compose_slice_or_index(slice_or_idx1, slice_or_idx2):
+ ret = []
+ i = 0
+ j = 0
+ while True:
+ if i == len(slice_or_idx1):
+ ret.extend(slice_or_idx2[j:])
+ return tuple(ret)
+ elif j == len(slice_or_idx2):
+ ret.extend(slice_or_idx1[i:])
+ return tuple(ret)
+ elif isinstance(slice_or_idx1[i], int):
+ ret.append(slice_or_idx1[i])
+ i += 1
+ elif isinstance(slice_or_idx2[j], int):
+ ret.append(
+ slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step
+ )
+ i += 1
+ j += 1
+ else:
+ ret.append(
+ slice(
+ slice_or_idx1[i].start
+ + slice_or_idx2[j].start * slice_or_idx1[i].step,
+ slice_or_idx1[i].start
+ + slice_or_idx2[j].stop * slice_or_idx1[i].step,
+ slice_or_idx1[i].step * slice_or_idx2[j].step,
+ )
+ )
+ i += 1
+ j += 1
+
+
+def to_range(transforms) -> tuple[slice | int, ...]:
+ ret = ()
+ for transform in transforms:
+ # For now, assume only NDIndexer transforms.
+ ret = _compose_slice_or_index(
+ ret, tuple(_transform_slice_or_index(i) for i in transform.indices)
+ )
+ return ret
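+
+# For example, composing an outer slice(0, 8, 1) with an inner slice(2, 6, 1)
+# yields slice(2, 6, 1) in the underlying buffer's coordinates, and composing
+# slice(4, 12, 2) with the integer index 3 selects element 4 + 3 * 2 = 10.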
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index d6c30be8946f..045628dbd445 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -21,7 +21,7 @@
import functools
import operator
import string
-from typing import Any, Protocol, Self, TypeVar, cast
+from typing import Any, Literal, Protocol, Self, TypeVar, cast
import jax
from jax import api_util
@@ -34,8 +34,8 @@
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dtypes
-from jax._src import literals
from jax._src import linear_util as lu
+from jax._src import literals
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import prng
@@ -51,7 +51,6 @@
from jax._src.lax import control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import BranchesPlatforms
-
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
@@ -89,6 +88,7 @@
AnyMemorySpace = pallas_core.MemorySpace | TPUMemorySpace
VMEM = TPUMemorySpace.VMEM
SMEM = TPUMemorySpace.SMEM
+ANY = pallas_core.MemorySpace.ANY
# Booleans are stored as the following type in memrefs.
BOOL_MEMREF_TYPE = np.dtype('int32')
@@ -249,10 +249,11 @@ def is_cloud_tpu_older_than(self, year: int, month: int, day: int):
return is_cloud_tpu_older_than(year, month, day, backend)
-def _memory_space_to_tpu_memory_space(memory_space: AnyMemorySpace | None
- ) -> TPUMemorySpace:
+def _memory_space_to_tpu_memory_space(
+ memory_space: AnyMemorySpace | None,
+) -> TPUMemorySpace | Literal[ANY]:
if memory_space == jax_core.MemorySpace.Device:
- return TPUMemorySpace.ANY
+ return ANY
match memory_space:
case None:
@@ -261,7 +262,7 @@ def _memory_space_to_tpu_memory_space(memory_space: AnyMemorySpace | None
return TPUMemorySpace.VMEM
case pallas_core.MemorySpace.ANY:
# Map the general ANY memory space to TPU ANY memory space
- return TPUMemorySpace.ANY
+ return ANY
case pallas_core.MemorySpace.HOST:
return TPUMemorySpace.HOST
case (
@@ -341,6 +342,8 @@ def aval_to_ir_type(
if isinstance(aval, state.AbstractRef):
if shape is None:
shape = aval.shape
+ if memory_space is None:
+ memory_space = aval.memory_space
memspace = _memory_space_to_mosaic_attribute(memory_space)
shape = dynamic_shape_replacement_fn(shape)
return ir.MemRefType.get(shape,
@@ -415,9 +418,6 @@ def _get_arg_type(
memory_space = None
if isinstance(aval, state.AbstractRef):
memory_space = _memory_space_to_tpu_memory_space(aval.memory_space)
- # We assume unannotated memory refs are in VMEM
- if memory_space is None:
- memory_space = TPUMemorySpace.VMEM
return aval_to_ir_type(
dynamic_shape_replacement_fn, aval, shape=shape, memory_space=memory_space
)
@@ -663,10 +663,8 @@ def err_details():
"rank >= 1. " + err_details())
if (
- (memory_space == tpu_core.MemorySpace.ANY
- or memory_space == tpu_core.MemorySpace.HBM)
- and not bm.has_trivial_window()
- ):
+ memory_space is ANY or memory_space == tpu_core.MemorySpace.HBM
+ ) and not bm.has_trivial_window():
raise ValueError(
"The Pallas TPU lowering currently supports in memory space ANY "
"only blocks having the same block shape as the array shape "
@@ -804,7 +802,7 @@ def dynamic_shape_replacement_fn(
tpu_memory_space = _memory_space_to_tpu_memory_space(
bm.block_aval.memory_space)
if (
- tpu_memory_space == tpu_core.MemorySpace.ANY
+ tpu_memory_space is ANY
or tpu_memory_space == tpu_core.MemorySpace.HBM
or tpu_memory_space == tpu_core.MemorySpace.SEMAPHORE
):
@@ -3393,7 +3391,8 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
@register_lowering_rule(pjit.reshard_p)
-def _reshard_lowering_rule(ctx: LoweringRuleContext, x, dst_sharding):
+def _reshard_lowering_rule(ctx: LoweringRuleContext, x, *, dst_sharding,
+ concrete_mesh):
return x
@@ -3594,7 +3593,7 @@ def _stochastic_round_lowering_rule(
return tpu.stochastic_convert(out_type, x, random_bits)
-def _check_elementwise_packing_dtypes(unpacked_dtype, packed_dtype):
+def _check_elementwise_unpack_dtypes(unpacked_dtype, packed_dtype):
if unpacked_dtype == jnp.float32 and packed_dtype == jnp.bfloat16:
return
if unpacked_dtype == jnp.int32 and packed_dtype in [
@@ -3612,11 +3611,9 @@ def _pack_elementwise_lowering_rule(
ctx: LoweringRuleContext, *xs, packed_dtype
):
in_aval = ctx.avals_in[0]
- _check_elementwise_packing_dtypes(in_aval.dtype, packed_dtype)
+ out_aval = ctx.avals_out[0]
packed_ir_type = _dtype_to_ir_type(packed_dtype)
- out_type = ir.VectorType.get(
- in_aval.shape, _dtype_to_ir_type(jnp.uint32)
- )
+ out_type = ir.VectorType.get(in_aval.shape, _dtype_to_ir_type(out_aval.dtype))
return tpu.pack_elementwise(out_type, xs, target_type=packed_ir_type)
@@ -3625,7 +3622,7 @@ def _unpack_elementwise_lowering_rule(
ctx: LoweringRuleContext, x, index, packed_dtype, unpacked_dtype
):
in_aval = ctx.avals_in[0]
- _check_elementwise_packing_dtypes(unpacked_dtype, packed_dtype)
+ _check_elementwise_unpack_dtypes(unpacked_dtype, packed_dtype)
out_type = ir.VectorType.get(
in_aval.shape, _dtype_to_ir_type(unpacked_dtype)
)
diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py
index 390781e05820..dd8bb00dcc78 100644
--- a/jax/_src/pallas/mosaic/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -18,6 +18,7 @@
from collections.abc import Sequence
import dataclasses
+import json
from typing import cast
import jax
@@ -46,8 +47,7 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue):
after loading from a memref inside of the kernel.
"""
assert isinstance(
- x, (jax.Array, jax_core.ShapedArray, jax_core.DShapedArray,
- state_types.AbstractLinVal)
+ x, (jax.Array, jax_core.ShapedArray, state_types.AbstractLinVal)
), type(x)
if isinstance(x, jax.Array):
if dtypes.issubdtype(x.dtype, jax.numpy.bool_):
@@ -62,7 +62,7 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue):
def _get_memory_space_from_aval(
- out_aval: jax_core.AbstractValue,
+ out_aval: jax_core.AbstractValue, kernel_type: tpu_core.KernelType
) -> tpu_custom_call.MemorySpace | None:
if not isinstance(out_aval, jax_core.ShapedArray):
raise ValueError("Memory spaces not defined for non-ShapedArrays")
@@ -73,10 +73,6 @@ def _get_memory_space_from_aval(
# If we are passed an aval with an explicit memory space tag, we use it
# to constrain the memory space.
match out_aval.memory_space:
- case None:
- return None
- case tpu_core.MemorySpace.ANY:
- return None
case tpu_core.MemorySpace.HBM:
return tpu_custom_call.MemorySpace.HBM
case tpu_core.MemorySpace.VMEM:
@@ -84,20 +80,29 @@ def _get_memory_space_from_aval(
case tpu_core.MemorySpace.SMEM:
return tpu_custom_call.MemorySpace.SMEM
case tpu_core.MemorySpace.SEMAPHORE:
- return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
+ match kernel_type:
+ case tpu_core.KernelType.SC_SCALAR_SUBCORE:
+ return tpu_custom_call.MemorySpace.SC_SCALAR_SEMAPHORE_MEM
+ case tpu_core.KernelType.TC:
+ return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
+ case _:
+ raise ValueError(f"Invalid kernel type for semaphore: {kernel_type}")
case tpu_core.MemorySpace.HOST:
return tpu_custom_call.MemorySpace.HOST
return None
def _get_memory_spaces_from_avals(
- avals: Sequence[jax_core.AbstractValue],
+ avals: Sequence[jax_core.AbstractValue], kernel_type: tpu_core.KernelType
) -> tuple[tpu_custom_call.MemorySpace | None, ...] | None:
memory_spaces = None
if any(
isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) for aval in avals
):
- memory_spaces = tuple(map(_get_memory_space_from_aval, avals))
+ memory_spaces = tuple(
+ _get_memory_space_from_aval(aval, kernel_type=kernel_type)
+ for aval in avals
+ )
return memory_spaces
@@ -140,7 +145,7 @@ def pallas_call_tpu_lowering_rule(
mlir_ctx.load_all_available_dialects()
tpu.register_dialect(mlir_ctx)
- match mosaic_params.kernel_type:
+ match (kernel_type := mosaic_params.kernel_type):
case tpu_core.KernelType.TC:
lower_jaxpr_to_module = lowering.lower_jaxpr_to_module
case tpu_core.KernelType.SC_SCALAR_SUBCORE | tpu_core.KernelType.SC_VECTOR_SUBCORE:
@@ -156,7 +161,7 @@ def pallas_call_tpu_lowering_rule(
grid_mapping,
jaxpr,
dimension_semantics=mosaic_params.dimension_semantics,
- kernel_type=mosaic_params.kernel_type,
+ kernel_type=kernel_type,
mesh=jax_mesh,
dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(),
)
@@ -191,18 +196,17 @@ def _maybe_cast_inputs(*args):
# Dynamic grid bounds have to go at the front.
dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:]
kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals)
- output_memory_spaces = _get_memory_spaces_from_avals(out_avals)
+ output_memory_spaces = _get_memory_spaces_from_avals(
+ out_avals, kernel_type=kernel_type
+ )
input_memory_spaces = None
if any(
isinstance(aval, pallas_core.ShapedArrayWithMemorySpace)
for aval in ctx.avals_in
):
- # TODO(sharadmv): Support dynamic grid bounds.
- if num_dyn_bounds != 0:
- raise NotImplementedError(
- "Dynamic grid bounds are not supported when specifying memory spaces for inputs."
- )
- input_memory_spaces = _get_memory_spaces_from_avals(ctx.avals_in)
+ input_memory_spaces = _get_memory_spaces_from_avals(
+ ctx.avals_in, kernel_type=kernel_type
+ )
if cost_estimate is not None:
mosaic_cost_estimate = cast(
tpu_custom_call.CostEstimate, dataclasses.asdict(cost_estimate)
@@ -246,6 +250,39 @@ def _maybe_cast_inputs(*args):
has_side_effects = tpu_custom_call.TpuSideEffectType.SIDE_EFFECTING
case _:
raise ValueError(f"Invalid side effect type: {mosaic_params.has_side_effects}")
+ tiling: tpu_custom_call.Tiling | None = None
+ if mosaic_params.use_tc_tiling_on_sc is not None:
+ if kernel_type not in (
+ tpu_core.KernelType.SC_SCALAR_SUBCORE,
+ tpu_core.KernelType.SC_VECTOR_SUBCORE,
+ ):
+ raise ValueError(
+ "use_tc_tiling_on_sc= is only supported for SC_*_SUBCORE kernels"
+ )
+
+ tiling = (
+ tpu_custom_call.Tiling.COMPACT
+ if mosaic_params.use_tc_tiling_on_sc
+ else tpu_custom_call.Tiling.SPARSE_CORE
+ )
+ dict_metadata = dict(metadata) if metadata is not None else {}
+ del metadata
+ if jax_mesh is not None:
+ mesh_axes = {
+ e.name
+ for e in jaxpr.effects
+ if isinstance(e, jax_core.NamedAxisEffect)
+ # Filter for only device mesh axis name effects
+ and e.name in jax_mesh.axis_names
+ }
+ # Only put mesh axes in metadata if there are any.
+ if mesh_axes:
+ if "mesh_axes" in dict_metadata:
+ raise ValueError("Metadata already contains mesh axes.")
+ mesh_axes_list = list(mesh_axes)
+ if all(isinstance(a, str) for a in mesh_axes):
+ mesh_axes_list = sorted(mesh_axes) # type: ignore
+ dict_metadata["mesh_axes"] = json.dumps(mesh_axes_list)
out_nodes = mosaic.lower_module_to_custom_call(
kernel_ctx,
*dynamic_grid_args,
@@ -265,10 +302,11 @@ def _maybe_cast_inputs(*args):
output_memory_spaces=output_memory_spaces,
disable_bounds_checks=mosaic_params.disable_bounds_checks,
input_memory_spaces=input_memory_spaces,
- metadata=dict(metadata) if metadata is not None else None,
+ metadata=dict_metadata,
skip_device_barrier=mosaic_params.skip_device_barrier,
allow_collective_id_without_custom_barrier=mosaic_params.allow_collective_id_without_custom_barrier,
shape_invariant_numerics=mosaic_params.shape_invariant_numerics,
+ tiling=tiling,
)
_maybe_cast_to_bool = (
lambda x, aval: x.astype(jax.numpy.bool_)
diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py
index 8d5ef57505ee..80adc8bac314 100644
--- a/jax/_src/pallas/mosaic/pipeline.py
+++ b/jax/_src/pallas/mosaic/pipeline.py
@@ -20,9 +20,10 @@
import dataclasses
import enum
import functools
-from typing import Any, Union
+from typing import Any, Literal, Union
import jax
+from jax import core as jax_core
from jax import lax
from jax import tree_util
from jax._src import util as jax_util
@@ -30,16 +31,17 @@
from jax._src.pallas import primitives as primitives
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import helpers as tpu_helpers
-from jax._src.pallas.mosaic import tpu_info
from jax._src.pallas.mosaic import primitives as tpu_primitives
+from jax._src.pallas.mosaic import tpu_info
+from jax._src.state import types as state_types
from jax.experimental import pallas as pl
import jax.numpy as jnp
-import numpy as np
SMEM = tpu_core.MemorySpace.SMEM
VMEM = tpu_core.MemorySpace.VMEM
-ANY = tpu_core.MemorySpace.ANY
+HBM = tpu_core.MemorySpace.HBM
+ANY = pallas_core.MemorySpace.ANY
REF = pallas_core.MemoryRef
GridDimensionSemantics = tpu_core.GridDimensionSemantics
PARALLEL = tpu_core.PARALLEL
@@ -79,17 +81,32 @@ def add_leaves(i, x):
def _get_tpu_generation() -> int:
return tpu_info.get_tpu_info().generation
-def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]:
- # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions
- # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
- # (2, 3, 128, 128) -> (1, 1, 8, 128).
+
+def _make_tiling(
+ shape: tuple[int, ...], ty: jax_core.AbstractValue
+) -> tuple[int | None, ...]:
+ """Compute a tiling for the given shape and type.
+
+ For an n-dimensional shape, returns (8, 128) for the last 2 dimensions
+ and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and
+ (2, 3, 128, 128) -> (1, 1, 8, 128).
+
+ Types are not required to have a dtype, so for such types we return None for
+ all dimensions because their tiling is unknown.
+ """
+
if len(shape) < 2:
raise ValueError(f"Shape must have at least 2 dimensions: {shape=}")
+
+ if not hasattr(ty, 'dtype'):
+ return (None,) * len(shape)
+
leading_dims, final_dims = shape[:-2], shape[-2:]
# We want to find the minimum power of 2 that fits the second-minor dimension
# of shape, with maximum value 8.
second_minor, _ = final_dims
- packing = 4 // dtype.itemsize
+
+ packing = 4 // ty.dtype.itemsize
max_tiling = _TILING[0]
second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing
while second_minor_tiling < min(second_minor, max_tiling):
@@ -114,13 +131,18 @@ def _make_block_ds(
assert isinstance(out, pl.Slice)
return out
-def _create_blocked_slice(block_index: jax.Array | int,
- block_size: int,
- dim_size: int,
- tiling: int):
+
+def _create_blocked_slice(
+ block_index: jax.Array | int,
+ block_size: int,
+ dim_size: int,
+ tiling: int | None,
+):
block_start = block_size * block_index
if (dim_rem := dim_size % block_size) == 0:
return pl.ds(block_start, block_size)
+ if tiling is None:
+ raise ValueError("If tiling is None, block_size must divide dim_size.")
if block_size % tiling != 0:
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
num_blocks = pl.cdiv(dim_size, block_size)
@@ -137,12 +159,15 @@ def _create_bounded_slice(slice_start: jax.Array | int,
slice_size: jax.Array | int,
block_size: int,
dim_size: int,
- tiling: int):
- if block_size % tiling != 0:
+ tiling: int | None):
+ if tiling is not None and block_size % tiling != 0:
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
# We assume by construction that slice_size <= block_size. We also assume
# that the slice_start is already aligned to the tiling.
+ if tiling is None:
+ return pl.ds(slice_start, slice_size)
+
# If we are out of bound, we need to round the slice size down to the nearest
# multiple of the tiling.
is_oob = slice_start + slice_size > dim_size
@@ -157,7 +182,7 @@ def _create_bounded_slice(slice_start: jax.Array | int,
def _make_block_slice(
block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int,
- tiling: int
+ tiling: int | None
) -> pl.Slice | slice | int | jax.Array:
# Computes a slice given a block index and block size. In the default case,
# we return slice(block_index * block_size, (block_index + 1) * block_size).
@@ -332,7 +357,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None:
def compute_index(self):
return self.spec.index_map
- def get_dma_slice(self, src_shape, src_dtype, grid_indices):
+ def get_dma_slice(self, src_ty, grid_indices):
# We need to handle blocks that might go OOB in the src array. An in bounds
# block looks like this (for array shape (600, 600) and block shape
# (256, 256)):
@@ -379,10 +404,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
# Suppose A is now (601, 600), instead of picking a (88, 256)-sized block
# for the last iteration on that dimension, we will pick the next highest
# tile multiple, i.e. (96, 256).
+
+ if (src_shape := getattr(src_ty, "shape", None)) is None:
+ raise ValueError(f'Type {src_ty} does not have a shape.')
+
if len(src_shape) < 2:
raise NotImplementedError("Must use >1D values.")
- tiling = _make_tiling(src_shape, src_dtype)
+ tiling = _make_tiling(src_shape, src_ty)
block_indices = self.compute_index(*grid_indices)
return tuple(
_make_block_slice(bi, bs, ss, t)
@@ -403,17 +432,24 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase:
"""Returns a new BufferedRefBase with the given block spec."""
raise NotImplementedError()
+def _ref_to_value_aval(ref):
+ """Return the inner of a ref, or a ShapedArray for TransformedRefs."""
+ return (
+ jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype)
+ if isinstance(ref, state_types.TransformedRef)
+ else jax.typeof(ref).inner_aval
+ )
+
# TODO(justinfu): Refactor and rename slot fields to reflect cumulative values
# instead of slot index.
-@tree_util.register_pytree_node_class
+@tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class BufferedRef(BufferedRefBase):
"""A helper class to automate VMEM double buffering in pallas pipelines.
Attributes:
spec: pallas blockspec.
- dtype: dtype for buffers.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
window_ref: a multiple-buffer to hold the working and dirty buffers used
@@ -443,9 +479,8 @@ class BufferedRef(BufferedRefBase):
swap: Tracks whether the BufferedRef slots need to be swapped before next
copy.
"""
- _spec: pl.BlockSpec # static metadata
- dtype: Any # static metadata
- _buffer_type: BufferType # static metadata
+ _spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True))
+ _buffer_type: BufferType = dataclasses.field(metadata=dict(static=True))
window_ref: ArrayRef | None
accum_ref: ArrayRef | None
copy_in_slot: ArrayRef | None
@@ -502,47 +537,28 @@ def buffer_count(self) -> int:
raise ValueError("buffer count is undefined")
return self.window_ref.shape[0] # type: ignore[union-attr]
- def tree_flatten(self):
- return (
- (
- self.window_ref,
- self.accum_ref,
- self.copy_in_slot,
- self.wait_in_slot,
- self.copy_out_slot,
- self.wait_out_slot,
- self._copy_in_slot_reg,
- self._wait_in_slot_reg,
- self._copy_out_slot_reg,
- self._wait_out_slot_reg,
- self.next_fetch_smem,
- self.next_fetch_sreg,
- self.sem_recvs,
- self.sem_sends,
- self.swap,
- ),
- (self._spec, self.dtype, self._buffer_type),
- )
-
- @classmethod
- def tree_unflatten(cls, meta, data):
- return cls(*meta, *data)
-
@staticmethod
def buffer_types() -> type[BufferType]:
return BufferType
@classmethod
- def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
- needs_swap_ref=True,
- grid_rank=None,
- use_lookahead=False,
- source_memory_space: tpu_core.MemorySpace = ANY) -> BufferedRef:
+ def create(
+ cls,
+ spec: pl.BlockSpec,
+ dtype_or_type,
+ buffer_type,
+ buffer_count,
+ needs_swap_ref=True,
+ grid_rank=None,
+ use_lookahead=False,
+ source_memory_space: tpu_core.MemorySpace | Literal[ANY] = ANY, # type: ignore[valid-type]
+ ) -> BufferedRef:
"""Create a BufferedRef.
Args:
spec: pallas blockspec.
- dtype: dtype for buffers.
+ dtype_or_type: dtype or aval for buffers. If an aval, the shape is
+ ignored.
buffer_type: enum indicating whether this is an input, output, or in/out
accumulator buffered reference.
needs_swap_ref: whether a swap slots tracker needs to be allocated.
@@ -553,21 +569,29 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
Returns:
Initialized BufferedRef
"""
+
+ # (123, 456) is a dummy shape since we never use ty without
+ # calling .update(shape=...) first.
+ ty = (
+ dtype_or_type
+ if isinstance(dtype_or_type, jax_core.AbstractValue)
+ else jax_core.ShapedArray((123, 456), dtype_or_type)
+ )
+
block_shape = _get_block_shape(spec)
if buffer_type is BufferType.ACCUMULATOR:
- accum_ref = VMEM(block_shape, dtype)
+ accum_ref = VMEM.from_type(ty.update(shape=block_shape))
else:
accum_ref = None
- if source_memory_space == VMEM:
- # We don't need to do any double-buffering in the case that our pipeline
- # reference is already in VMEM, we just need allocate the accumulation
- # buffer and we will refer to the original reference slices directly.
- if spec.memory_space not in (VMEM, None):
- raise ValueError(
- f"Cannot hold a non-buffered ref in {spec.memory_space=}")
+ buffer_memory_space = (
+ VMEM if spec.memory_space is None else spec.memory_space)
+ if buffer_memory_space not in (SMEM, VMEM, HBM):
+ raise ValueError(
+ f"Unsupported buffer memory space: {buffer_memory_space}"
+ )
+ if source_memory_space is buffer_memory_space:
return cls(
_spec=spec,
- dtype=dtype,
_buffer_type=buffer_type,
window_ref=None, # to be bound to existing ref by the pipeline routine
accum_ref=accum_ref,
@@ -586,21 +610,16 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
swap=None,
)
else:
- buffer_memory_space = (
- VMEM if spec.memory_space is None else spec.memory_space)
- if buffer_memory_space not in (SMEM, VMEM):
- raise ValueError(
- f"Unsupported buffer memory space: {buffer_memory_space}"
- )
if use_lookahead and grid_rank is None:
raise ValueError(
"grid_rank must be specified when use_lookahead is True."
)
+
+ buffer_ty = ty.update(shape=(buffer_count, *block_shape))
return cls(
_spec=spec,
- dtype=dtype,
_buffer_type=buffer_type,
- window_ref=buffer_memory_space((buffer_count,) + block_shape, dtype),
+ window_ref=buffer_memory_space.from_type(buffer_ty),
accum_ref=accum_ref,
copy_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
wait_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None,
@@ -627,22 +646,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count,
)
@classmethod
- def input(cls, spec, dtype, buffer_count=2, **kwargs):
- return cls.create(spec, dtype, BufferType.INPUT, buffer_count, **kwargs)
+ def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
+ return cls.create(
+ spec, dtype_or_type, BufferType.INPUT, buffer_count, **kwargs
+ )
@classmethod
- def output(cls, spec, dtype, buffer_count=2, **kwargs):
- return cls.create(spec, dtype, BufferType.OUTPUT, buffer_count, **kwargs)
+ def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
+ return cls.create(
+ spec, dtype_or_type, BufferType.OUTPUT, buffer_count, **kwargs
+ )
@classmethod
- def accumulator(cls, spec, dtype, buffer_count=2, **kwargs):
- return cls.create(spec, dtype, BufferType.ACCUMULATOR, buffer_count,
- **kwargs)
+ def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
+ return cls.create(
+ spec, dtype_or_type, BufferType.ACCUMULATOR, buffer_count, **kwargs
+ )
@classmethod
- def input_output(cls, spec, dtype, buffer_count=2, **kwargs):
- return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, buffer_count,
- **kwargs)
+ def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs):
+ return cls.create(
+ spec, dtype_or_type, BufferType.INPUT_OUTPUT, buffer_count, **kwargs
+ )
@property
def block_shape(self):
@@ -949,7 +974,7 @@ def copy_in(self, src_ref, grid_indices):
if self.swap is not None:
self.swap[0] = True
slot = self.current_copy_in_slot
- src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
+ src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
dst_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(src_slice, self.block_shape)
@@ -970,7 +995,7 @@ def copy_out(self, dst_ref, grid_indices):
if self.swap is not None:
self.swap[0] = True
slot = self.current_copy_out_slot
- dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
+ dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
src_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(dst_slice, self.block_shape)
@@ -988,7 +1013,7 @@ def wait_in(self, src_ref, grid_indices):
if not self.is_buffered: return
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
assert self.sem_recvs is not None
- src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
+ src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices)
dst_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(src_slice, self.block_shape)
@@ -1010,7 +1035,7 @@ def wait_out(self, dst_ref, grid_indices):
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
assert self.sem_sends is not None
wait_slot = self.current_wait_out_slot
- dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
+ dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices)
src_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(dst_slice, self.block_shape)
@@ -1305,7 +1330,7 @@ def out_of_fetch(self, buffered_ref):
# Currently this is based on the iteration, but if we want to support
# lookahead this will depend on whether the lookahead reached the end.
if not buffered_ref.is_buffered:
- return False
+ return jnp.bool(False)
return self.step >= (self.num_steps - buffered_ref.buffer_count + 1)
def has_changed(self, buffered_ref):
@@ -1708,7 +1733,9 @@ def make_input_bref(in_spec, in_ref):
use_lookahead = in_spec.pipeline_mode.use_lookahead
if use_lookahead and grid is None:
raise ValueError("Grid must be specified when using lookahead.")
- return BufferedRef.input(in_spec, in_ref.dtype, buffer_count,
+
+ in_aval = _ref_to_value_aval(in_ref)
+ return BufferedRef.input(in_spec, in_aval, buffer_count,
needs_swap_ref=needs_swap_ref,
grid_rank=len(grid),
use_lookahead=use_lookahead,
@@ -1721,11 +1748,13 @@ def make_output_bref(out_spec, out_ref, accumulate):
if out_spec.pipeline_mode.use_lookahead:
raise ValueError("Output buffering does not support lookahead.")
+ out_aval = _ref_to_value_aval(out_ref)
+
if accumulate:
- return BufferedRef.accumulator(out_spec, out_ref.dtype, buffer_count,
+ return BufferedRef.accumulator(out_spec, out_aval, buffer_count,
needs_swap_ref=needs_swap_ref,
source_memory_space=out_ref.memory_space)
- return BufferedRef.output(out_spec, out_ref.dtype, buffer_count,
+ return BufferedRef.output(out_spec, out_aval, buffer_count,
needs_swap_ref=needs_swap_ref,
source_memory_space=out_ref.memory_space)
out_brefs = jax.tree.map(
@@ -1843,7 +1872,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices):
bref = dst
hbm_ref = src
copy_in = True
- hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices)
+ hbm_slice = bref.get_dma_slice(_ref_to_value_aval(hbm_ref), indices)
bref_slice = tuple(
pl.ds(0, s.size)
for s, bd in zip(hbm_slice, bref.block_shape)
diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py
index be9146efc9c3..67aaab743d24 100644
--- a/jax/_src/pallas/mosaic/primitives.py
+++ b/jax/_src/pallas/mosaic/primitives.py
@@ -362,6 +362,8 @@ def _get_dma_effects(
dst_transforms_avals,
dst_sem_transforms_avals,
src_sem_aval,
+ device_id_aval,
+ device_id_type,
):
n_src_transforms = len(tree_util.tree_leaves(src_transforms_avals))
n_dst_transforms = len(tree_util.tree_leaves(dst_transforms_avals))
@@ -377,12 +379,67 @@ def _get_dma_effects(
1 + n_src_transforms + 1 + n_dst_transforms + 1 + n_dst_sem_transforms
)
effs.add(state.WriteEffect(src_sem_index))
+ if device_id_aval is not None:
+ if device_id_type is primitives.DeviceIdType.MESH and isinstance(
+ device_id_aval, dict
+ ):
+ for k in device_id_aval:
+ if not isinstance(k, tuple):
+ k = (k,)
+ for k_ in k:
+ effs.add(jax_core.NamedAxisEffect(k_))
return effs
dma_start_p = jax_core.Primitive('dma_start')
dma_start_p.multiple_results = True
+def _dma_is_high(*avals, **params):
+ return any(aval.is_high for aval in avals)
+
+dma_start_p.is_high = _dma_is_high # type: ignore[method-assign]
+
+def _dma_start_to_lojax(*args, tree, device_id_type, priority, add):
+ (
+ src_ref,
+ src_transforms,
+ dst_ref,
+ dst_transforms,
+ dst_sem,
+ dst_sem_transforms,
+ src_sem,
+ src_sem_transforms,
+ device_id,
+ ) = tree_util.tree_unflatten(tree, args)
+ src_ref_aval = jax_core.get_aval(src_ref)
+ dst_ref_aval = jax_core.get_aval(dst_ref)
+ if not (src_ref_aval.is_high and dst_ref_aval.is_high):
+ raise NotImplementedError("dma_start not implemented in LoJAX yet.")
+ dst_sem_aval = jax_core.get_aval(dst_sem)
+ if dst_sem_aval.is_high:
+ raise NotImplementedError("dma_start not implemented in LoJAX yet.")
+ if src_sem is not None:
+ if jax_core.get_aval(src_sem).is_high:
+ raise NotImplementedError("dma_start not implemented in LoJAX yet.")
+ src_transformed_ref = state.TransformedRef(src_ref, src_transforms)
+ dst_transformed_ref = state.TransformedRef(dst_ref, dst_transforms)
+ if src_sem is not None:
+ src_sem = state.TransformedRef(src_sem, src_sem_transforms)
+ dst_sem = state.TransformedRef(dst_sem, dst_sem_transforms)
+
+ src_ref_aval.inner_aval.dma_start(
+ src_transformed_ref,
+ dst_transformed_ref,
+ src_sem,
+ dst_sem,
+ device_id=device_id,
+ priority=priority,
+ device_id_type=device_id_type,
+ add=add
+ )
+ return []
+dma_start_p.to_lojax = _dma_start_to_lojax
+
@dma_start_p.def_effectful_abstract_eval
def _dma_start_abstract_eval(*args, tree, device_id_type, priority, add):
if priority < 0:
@@ -425,6 +482,8 @@ def _dma_start_abstract_eval(*args, tree, device_id_type, priority, add):
dst_transforms_avals,
dst_sem_transforms_avals,
src_sem_aval,
+ device_id_aval,
+ device_id_type,
)
def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn,
@@ -646,9 +705,48 @@ def do_discharge_src_sem(src_sem=src_sem):
dma_wait_p = jax_core.Primitive('dma_wait')
dma_wait_p.multiple_results = True
+dma_wait_p.is_high = _dma_is_high # type: ignore[method-assign]
+
+def _dma_wait_to_lojax(*args, tree, device_id_type):
+ (
+ src_ref,
+ src_transforms,
+ dst_ref,
+ dst_transforms,
+ dst_sem,
+ dst_sem_transforms,
+ src_sem,
+ src_sem_transforms,
+ device_id,
+ ) = tree_util.tree_unflatten(tree, args)
+ src_ref_aval = jax_core.get_aval(src_ref)
+ dst_ref_aval = jax_core.get_aval(dst_ref)
+ if not (src_ref_aval.is_high and dst_ref_aval.is_high):
+ raise NotImplementedError("dma_wait not implemented in LoJAX yet.")
+ dst_sem_aval = jax_core.get_aval(dst_sem)
+ if dst_sem_aval.is_high:
+ raise NotImplementedError("dma_wait not implemented in LoJAX yet.")
+ if src_sem is not None:
+ if jax_core.get_aval(src_sem).is_high:
+ raise NotImplementedError("dma_wait not implemented in LoJAX yet.")
+ src_transformed_ref = state.TransformedRef(src_ref, src_transforms)
+ dst_transformed_ref = state.TransformedRef(dst_ref, dst_transforms)
+ if src_sem is not None:
+ src_sem = state.TransformedRef(src_sem, src_sem_transforms)
+ dst_sem = state.TransformedRef(dst_sem, dst_sem_transforms)
+ src_ref_aval.inner_aval.dma_wait(
+ src_transformed_ref,
+ dst_transformed_ref,
+ src_sem,
+ dst_sem,
+ device_id=device_id,
+ device_id_type=device_id_type,
+ )
+ return []
+dma_wait_p.to_lojax = _dma_wait_to_lojax
+
@dma_wait_p.def_effectful_abstract_eval
def _dma_wait_abstract_eval(*args, tree, device_id_type):
- del device_id_type
(
src_ref_aval,
src_transforms_avals,
@@ -665,6 +763,8 @@ def _dma_wait_abstract_eval(*args, tree, device_id_type):
dst_transforms_avals,
dst_sem_transforms_avals,
src_sem_aval,
+ device_id_aval,
+ device_id_type,
)
def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
@@ -749,7 +849,16 @@ def _get_ref_and_transforms(ref):
def make_async_copy(src_ref, dst_ref, sem) -> AsyncCopyDescriptor:
- """Issues a DMA copying from src_ref to dst_ref."""
+ """Creates a description of an asynchronous copy operation.
+
+ Args:
+ src_ref: The source Reference.
+ dst_ref: The destination Reference.
+ sem: The semaphore used to track completion of the copy.
+
+ Returns:
+ An AsyncCopyDescriptor.
+ """
src_ref, src_transforms = _get_ref_and_transforms(src_ref)
dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref)
sem, sem_transforms = _get_ref_and_transforms(sem)
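# Editor's note (not part of the patch): a hedged usage sketch of
# make_async_copy inside a Pallas TPU kernel -- copy an HBM-resident input
# into VMEM scratch, wait on the DMA semaphore, then compute.  The
# memory-space spellings (pltpu.ANY, pltpu.VMEM, pltpu.SemaphoreType.DMA)
# follow the public Pallas TPU docs but may differ across JAX versions.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_hbm_ref, o_ref, scratch_ref, sem):
    copy = pltpu.make_async_copy(x_hbm_ref, scratch_ref, sem)
    copy.start()
    copy.wait()
    o_ref[...] = scratch_ref[...] * 2

x = jnp.ones((8, 128), jnp.float32)
y = pl.pallas_call(
    copy_kernel,
    in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)],
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[pltpu.VMEM(x.shape, jnp.float32), pltpu.SemaphoreType.DMA],
)(x)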
@@ -835,6 +944,7 @@ def async_remote_copy(
device_id,
device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH,
) -> AsyncCopyDescriptor:
+ """Issues a remote DMA copying from src_ref to dst_ref."""
copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem,
device_id, device_id_type)
copy_descriptor.start()
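# Editor's note (not part of the patch): the remote variant, in the spirit of
# the Pallas distributed TPU examples -- each device sends its block to its
# right neighbor along a 1D mesh axis "x".  The axis name, the ring-size
# constant, and the surrounding shard_map/pallas_call plumbing are
# assumptions of this sketch and omitted here.
import jax
from jax.experimental.pallas import tpu as pltpu

NUM_DEVICES = 4  # hypothetical ring size

def right_permute_kernel(x_ref, o_ref, send_sem, recv_sem):
    my_id = jax.lax.axis_index("x")
    right = jax.lax.rem(my_id + 1, NUM_DEVICES)
    rdma = pltpu.make_async_remote_copy(
        src_ref=x_ref, dst_ref=o_ref,
        send_sem=send_sem, recv_sem=recv_sem,
        device_id=(right,),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    rdma.start()
    rdma.wait()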
@@ -969,6 +1079,7 @@ def _stochastic_round_abstract_eval(x, random_bits, *, target_dtype):
)
return jax_core.ShapedArray(x.shape, target_dtype)
+
def _get_elementwise_packing_factor(unpacked_dtype, packed_dtype):
unpacked_bitwidth = dtypes.itemsize_bits(unpacked_dtype)
packed_bitwidth = dtypes.itemsize_bits(packed_dtype)
@@ -995,13 +1106,22 @@ def _pack_elementwise_abstract_eval(*xs, packed_dtype):
raise ValueError("All sources must have the same shape")
if not all(x.dtype == first.dtype for x in xs):
raise ValueError("All sources must have the same dtype")
+ if not (first.dtype == jnp.float32 and packed_dtype == jnp.bfloat16) and not (
+ jnp.issubdtype(first.dtype, jnp.integer)
+ and jnp.issubdtype(packed_dtype, jnp.integer)
+ ):
+ raise ValueError(
+ "Only f32 -> bf16 and int -> int are supported. Got"
+ f" {first.dtype} and {packed_dtype}"
+ )
packing_factor = _get_elementwise_packing_factor(first.dtype, packed_dtype)
if len(xs) != packing_factor:
raise ValueError(
"The number of sources must match the packing factor "
f"({packing_factor}), got {len(xs)}"
)
- return jax_core.ShapedArray(first.shape, jnp.uint32)
+ out_dtype = jnp.dtype(f"uint{dtypes.itemsize_bits(first.dtype)}")
+ return jax_core.ShapedArray(first.shape, out_dtype)
unpack_elementwise_p = jax_core.Primitive("unpack_elementwise")
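# Editor's note (not part of the patch): a plain-jnp reference for the
# f32 -> bf16 case checked above.  Two f32 inputs are rounded to bf16 and
# packed into one uint32 per element, matching the new rule that the output
# dtype has the *unpacked* element width.  Which input lands in the low vs.
# high half is an assumption of this sketch, not a guarantee of the primitive.
import jax
import jax.numpy as jnp

def pack_f32_pair(x0, x1):
    lo = jax.lax.bitcast_convert_type(x0.astype(jnp.bfloat16), jnp.uint16)
    hi = jax.lax.bitcast_convert_type(x1.astype(jnp.bfloat16), jnp.uint16)
    return lo.astype(jnp.uint32) | (hi.astype(jnp.uint32) << 16)

packed = pack_f32_pair(jnp.array(1.5, jnp.float32), jnp.array(-2.0, jnp.float32))
assert packed.dtype == jnp.uint32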
@@ -1029,7 +1149,7 @@ def with_memory_space_constraint(
) -> jax.Array:
"""Constrains the memory space of an array.
- This primitive does not change the value of `x`, but it constrains the
+ This primitive does not change the value of ``x``, but it constrains the
memory space where it should be allocated. This is useful to force
Pallas to allocate an array in a specific memory space.
@@ -1042,9 +1162,9 @@ def with_memory_space_constraint(
memory_space: The memory space to constrain to.
Returns:
- The array `x` with the memory space constraint.
+ The array ``x`` with the memory space constraint.
"""
- if memory_space in {tpu_core.MemorySpace.ANY, pl_core.MemorySpace.ANY}:
+ if memory_space is pl_core.MemorySpace.ANY:
return x
if memory_space not in {
tpu_core.MemorySpace.HBM,
diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py
index b3725619caff..3751b5611655 100644
--- a/jax/_src/pallas/mosaic/random.py
+++ b/jax/_src/pallas/mosaic/random.py
@@ -171,13 +171,13 @@ def sample_block(sampler_fn: SampleFnType,
**kwargs) -> jax.Array:
"""Samples a block of random values with invariance guarantees.
- `sample_block` allows the sampling of identical blocks of random values
+ ``sample_block`` allows the sampling of identical blocks of random values
across kernels with different block shapes and iteration orders. Each call
to `sample_block` returns a `block_size`-shaped array of random samples
corresponding to the `block_index`.
- `tile_size` should be chosen such that it is a divisor to all block sizes
- one needs to be invariant to. The larger the `tile_size`, the more
+ ``tile_size`` should be chosen such that it is a divisor of all block sizes
+ one needs to be invariant to. The larger the ``tile_size``, the more
efficient the sampling process will be and therefore the best choice is
typically the greatest common divisor between all possible block sizes.
@@ -186,7 +186,7 @@ def sample_block(sampler_fn: SampleFnType,
random samples.
global_key: The global key to use for sampling.
block_size: The shape of an individual block.
- tile_size: The shape of a `tile`, which is the smallest unit at
+ tile_size: The shape of a ``tile``, which is the smallest unit at
which samples are generated. This should be selected to be a divisor
of all block sizes one needs to be invariant to.
total_size: The total size of the array to sample.
@@ -195,8 +195,8 @@ def sample_block(sampler_fn: SampleFnType,
**kwargs: Additional arguments to pass to the sampler_fn.
Returns:
- A `block_size` shaped array of samples for the current block corresponding
- to `block_index`.
+ A ``block_size`` shaped array of samples for the current block corresponding
+ to ``block_index``.
"""
if len(block_size) != len(tile_size):
raise ValueError(f"block_size ({len(block_size)}) and tile_size "
diff --git a/jax/_src/pallas/mosaic/sc_core.py b/jax/_src/pallas/mosaic/sc_core.py
index 2eaab2546e8a..8f3001f25730 100644
--- a/jax/_src/pallas/mosaic/sc_core.py
+++ b/jax/_src/pallas/mosaic/sc_core.py
@@ -152,7 +152,7 @@ class BlockMapping(pallas_core.BlockMapping):
def get_sparse_core_info() -> tpu_info.SparseCoreInfo:
"""Returns the SparseCore information for the current device."""
return tpu_info.get_tpu_info().sparse_core or tpu_info.SparseCoreInfo(
- num_cores=0, num_subcores=0, num_lanes=0
+ num_cores=0, num_subcores=0, num_lanes=0, dma_granule_size_bytes=0,
)
@@ -219,6 +219,11 @@ def _scalar_subcore_mesh_discharge_rule(
compiler_params = tpu_core.CompilerParams()
if compiler_params.dimension_semantics is not None:
raise ValueError("ScalarSubcoreMesh does not support dimension_semantics=")
+ sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)]
+ if sa_avals:
+ raise NotImplementedError(
+ f"Cannot close over values in core_map: {sa_avals}"
+ )
return pallas_core.default_mesh_discharge_rule(
in_avals,
out_avals,
diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py
index b49c981cdc6a..c6da2956557b 100644
--- a/jax/_src/pallas/mosaic/sc_lowering.py
+++ b/jax/_src/pallas/mosaic/sc_lowering.py
@@ -330,7 +330,10 @@ def body_func(*args: ir.Value):
mosaic_grid_mapping.block_mappings,
):
d = {}
- if str(arg.type.memory_space) == "#tpu.memory_space":
+ if (
+ str(arg.type.memory_space) == "#tpu.memory_space"
+ or str(arg.type.memory_space) == "#tpu.memory_space"
+ ):
d["sc.persistent"] = ir.UnitAttr.get()
if isinstance(bm, sc_core.BlockMapping) and bm.indexed_by is not None:
d["sc.indexed_by"] = mlir.i32_attr(bm.indexed_by)
diff --git a/jax/_src/pallas/mosaic/sc_primitives.py b/jax/_src/pallas/mosaic/sc_primitives.py
index cd8a2cb303c5..7f06e37d8474 100644
--- a/jax/_src/pallas/mosaic/sc_primitives.py
+++ b/jax/_src/pallas/mosaic/sc_primitives.py
@@ -33,6 +33,7 @@
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import lowering as tc_lowering
+from jax._src.pallas.mosaic import sc_core
from jax._src.pallas.mosaic import sc_lowering
from jax._src.state import primitives as state_primitives
from jax._src.state import types as state_types
@@ -634,6 +635,72 @@ def _reduce_sum_lowering_rule(
_cumsum_lowering_rule(ctx, x, 0, reverse=False), [], [vec_dim - 1])
+masked_sort_p = jax_core.Primitive("masked_sort")
+masked_sort_p.multiple_results = True
+
+@masked_sort_p.def_abstract_eval
+def _masked_sort_abstract_eval(keys, values, *maybe_mask, descending):
+ del descending # Unused.
+ supported_shape = (sc_core.get_sparse_core_info().num_lanes,)
+ if keys.dtype not in (jnp.int32, jnp.float32):
+ raise NotImplementedError(
+ f"sort_key_val: keys dtype {keys.dtype} should be int32 or float32")
+ if keys.shape != supported_shape:
+ raise ValueError(f"keys shape {keys.shape} must be {supported_shape}")
+ if jnp.dtype(values.dtype).itemsize != 4:
+ raise NotImplementedError(
+ f"sort_key_val: values dtype {values.dtype} should be 32 bits")
+ if values.shape != supported_shape:
+ raise ValueError(f"values shape {values.shape} must be {supported_shape}")
+ if maybe_mask:
+ [mask] = maybe_mask
+ if not jnp.issubdtype(mask.dtype, jnp.bool):
+ raise TypeError(f"mask dtype {mask.dtype} is not boolean")
+ if mask.shape != supported_shape:
+ raise ValueError(f"mask shape {mask.shape} must be {supported_shape}")
+ return keys, values, *maybe_mask
+
+@sc_lowering.register_lowering_rule(masked_sort_p)
+def _masked_sort_lowering_rule(
+ ctx: sc_lowering.LoweringRuleContext, keys, values, *maybe_mask, descending):
+ del ctx # Unused.
+ if maybe_mask:
+ [mask] = maybe_mask
+ else:
+ mask_type = ir.VectorType.get(
+ [sc_core.get_sparse_core_info().num_lanes],
+ ir.IntegerType.get_signless(1))
+ mask = arith.constant(mask_type, ir.DenseElementsAttr.get_splat(
+ mask_type, ir.BoolAttr.get(True)))
+ out_mask, sorted_keys, sorted_values = tpu.sort(
+ mask.type, keys.type, values.type, keys, values, mask=mask,
+ descending=descending
+ )
+ if maybe_mask:
+ return sorted_keys, sorted_values, out_mask
+ return sorted_keys, sorted_values
+
+def sort_key_val(
+ keys: jax.Array, values: jax.Array, *,
+ mask: jax.Array | None = None, descending: bool = False
+) -> list[jax.Array]:
+ """Sorts keys and values, pushing invalid elements to the last positions.
+
+ Args:
+ keys: An array of integers or floats.
+ values: An array of values corresponding to the keys.
+ mask: An optional array of booleans, which specifies which elements of
+ `keys` and `values` are valid. If `None`, all elements are valid.
+ descending: Whether to sort in descending order.
+
+ Returns:
+ sorted_keys, sorted_values, [output_mask]: The sorted keys and values, and,
+ if a mask was given, the corresponding mask for output keys and values.
+ """
+ maybe_mask = () if mask is None else (mask,)
+ return masked_sort_p.bind(keys, values, *maybe_mask, descending=descending)
+
+
parallel_loop_p = jax_core.Primitive("parallel_loop")
parallel_loop_p.is_effectful = lambda params: bool(params["jaxpr"].effects) # type: ignore
parallel_loop_p.multiple_results = True
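# Editor's note (not part of the patch): a plain-jnp reference for the
# semantics sort_key_val describes ("invalid elements go last").  The
# SparseCore primitive above is the real implementation; the relative order
# of invalid elements and tie-breaking here are assumptions of this sketch.
import jax.numpy as jnp

def masked_sort_reference(keys, values, mask, descending=False):
    order = -keys if descending else keys
    # Primary key: validity (valid first); secondary key: the key values.
    perm = jnp.lexsort((order, ~mask))
    return keys[perm], values[perm], mask[perm]

k = jnp.array([3, 1, 2, 5], jnp.int32)
v = jnp.array([30, 10, 20, 50], jnp.int32)
m = jnp.array([True, True, False, True])
sk, sv, sm = masked_sort_reference(k, v, m)
# sk == [1, 3, 5, 2]; sm marks the first three entries as valid.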
diff --git a/jax/_src/pallas/mosaic/tpu_info.py b/jax/_src/pallas/mosaic/tpu_info.py
index 40006dc94880..3159fe0d1a98 100644
--- a/jax/_src/pallas/mosaic/tpu_info.py
+++ b/jax/_src/pallas/mosaic/tpu_info.py
@@ -20,8 +20,8 @@
from jax import numpy as jnp
from jax._src import dtypes
-from jax._src.pallas.mosaic import core
from jax._src import util as jax_util
+from jax._src.pallas.mosaic import core
class ChipVersionBase:
@@ -41,12 +41,15 @@ class ChipVersion(ChipVersionBase, enum.Enum):
def __str__(self) -> str:
return self.value
+
@dataclasses.dataclass(frozen=True, kw_only=True)
class SparseCoreInfo:
"""SparseCore-specific information."""
+
num_cores: int
num_subcores: int
num_lanes: int
+ dma_granule_size_bytes: int
@dataclasses.dataclass(frozen=True, kw_only=True)
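# Editor's note (not part of the patch): dma_granule_size_bytes is new here.
# The helper below is a hypothetical illustration of how a DMA granule size
# is typically consumed -- rounding a transfer size up to whole granules --
# and is not code from this change.  A granule of 0 is the "no SparseCore"
# default used above.
def round_up_to_granule(nbytes: int, granule_size_bytes: int) -> int:
    if granule_size_bytes == 0:
        return nbytes
    return -(-nbytes // granule_size_bytes) * granule_size_bytes

assert round_up_to_granule(100, 32) == 128
assert round_up_to_granule(100, 0) == 100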
@@ -122,10 +125,7 @@ def is_matmul_supported(
or (lhs_dt in {U4, S4} and rhs_dt in {U4, S4})
)
case 7:
- return (
- lhs_dt in {F32, BF16}
- and rhs_dt in {F32, BF16}
- ) or (
+ return (lhs_dt in {F32, BF16} and rhs_dt in {F32, BF16}) or (
lhs_dt in {F32, BF16, F8E5M2, F8E4M3FN}
and rhs_dt in {F8E5M2, F8E4M3FN}
)
@@ -154,6 +154,7 @@ def get_sublane_tiling(self, dtype: jnp.dtype) -> int:
def is_tpu_device() -> bool:
+ """Returns whether the current device is a TPU."""
return core.get_device_kind() in {
"TPU v2",
"TPU v3",
@@ -171,6 +172,7 @@ def is_tpu_device() -> bool:
registry: dict[str, Callable[[], TpuInfo]] = {}
+
@jax_util.cache(trace_context_in_key=True)
def get_tpu_info() -> TpuInfo:
"""Returns the TPU hardware information for the current device.
@@ -301,7 +303,12 @@ def get_tpu_info() -> TpuInfo:
int8_ops_per_second=int(9.18e14 // num_chip_cores),
fp8_ops_per_second=0, # Not Available
int4_ops_per_second=int(1.84e15 // num_chip_cores),
- sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8),
+ sparse_core=SparseCoreInfo(
+ num_cores=4,
+ num_subcores=16,
+ num_lanes=8,
+ dma_granule_size_bytes=32,
+ ),
)
case "TPU v6 lite" | "TPU v6e": # 1 TensorCore per chip
return TpuInfo(
@@ -320,29 +327,39 @@ def get_tpu_info() -> TpuInfo:
int8_ops_per_second=int(1.84e15),
fp8_ops_per_second=int(9.20e14),
int4_ops_per_second=int(3.68e15),
- sparse_core=SparseCoreInfo(num_cores=2, num_subcores=16, num_lanes=8),
+ sparse_core=SparseCoreInfo(
+ num_cores=2,
+ num_subcores=16,
+ num_lanes=8,
+ dma_granule_size_bytes=32,
+ ),
)
case "TPU7x":
num_cores = core.get_num_device_cores()
num_chip_cores = 2
return TpuInfo(
- chip_version=ChipVersion.TPU_7X,
- generation=7,
- num_cores=num_cores,
- num_lanes=128,
- num_sublanes=8,
- mxu_column_size=256,
- vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core
- cmem_capacity_bytes=0,
- smem_capacity_bytes=1024 * 1024, # 1 MiB per core
- hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
- mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
- bf16_ops_per_second=int(2.31e15 // num_chip_cores),
- int8_ops_per_second=0, # Not Available
- fp8_ops_per_second=int(4.60e15 // num_chip_cores),
- int4_ops_per_second=0, # Not Available
- sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=16),
- )
+ chip_version=ChipVersion.TPU_7X,
+ generation=7,
+ num_cores=num_cores,
+ num_lanes=128,
+ num_sublanes=8,
+ mxu_column_size=256,
+ vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core
+ cmem_capacity_bytes=0,
+ smem_capacity_bytes=1024 * 1024, # 1 MiB per core
+ hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
+ mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
+ bf16_ops_per_second=int(2.31e15 // num_chip_cores),
+ int8_ops_per_second=0, # Not Available
+ fp8_ops_per_second=int(4.60e15 // num_chip_cores),
+ int4_ops_per_second=0, # Not Available
+ sparse_core=SparseCoreInfo(
+ num_cores=4,
+ num_subcores=16,
+ num_lanes=16,
+ dma_granule_size_bytes=64,
+ ),
+ )
case _ as d:
if d in registry:
return registry[d]()
diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py
index c8198d2cc655..0167c96d04ce 100644
--- a/jax/_src/pallas/mosaic_gpu/core.py
+++ b/jax/_src/pallas/mosaic_gpu/core.py
@@ -431,7 +431,8 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
union_bytes = 0
for ref_group in ref_union.refs:
byte_offset = 0
- for ref in jax.tree.leaves(ref_group):
+ def unflatten(ref):
+ nonlocal byte_offset
byte_offset = align_to(byte_offset, SMEM_ALIGNMENT)
assert isinstance(ref, state.AbstractRef) or isinstance(
ref, pallas_core.TransformedRef
@@ -439,10 +440,8 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
if not isinstance(ref, pallas_core.TransformedRef):
ref = pallas_core.TransformedRef(ref, transforms=())
transform = ExtractAliasedRef.from_transformed_ref(ref, byte_offset)
- flat_refs.append(
- pallas_core.TransformedRef(
- ref_union, transforms=(transform, *ref.transforms)
- )
+ result = pallas_core.TransformedRef(
+ ref_union, transforms=(transform, *ref.transforms)
)
if jnp.issubdtype(ref.dtype, jnp.integer):
nbits = jnp.iinfo(ref.dtype).bits
@@ -457,13 +456,16 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
f" {ref.dtype}{ref.shape}"
)
byte_offset += ref_bits // 8
+ return result
+ flat_refs.append(jax.tree.map(unflatten, ref_group))
union_bytes = max(union_bytes, byte_offset)
assert union_bytes == ref_union.shape[0]
elif ref_union.memory_space == TMEM:
union_cols = 0
for ref_group in ref_union.refs:
col_offset = 0
- for ref in jax.tree.leaves(ref_group):
+ def unflatten(ref):
+ nonlocal col_offset
col_offset = align_to(col_offset, TMEM_COL_ALIGNMENT)
if not isinstance(ref, pallas_core.TransformedRef):
ref = pallas_core.TransformedRef(ref, transforms=())
@@ -471,12 +473,12 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]:
dtypes.itemsize_bits(ref.dtype))
transform = ExtractAliasedRef.from_transformed_ref(
ref, col_offset, layout=ref.layout)
- flat_refs.append(
- pallas_core.TransformedRef(
- ref_union, transforms=(transform, *ref.transforms)
- )
+ result = pallas_core.TransformedRef(
+ ref_union, transforms=(transform, *ref.transforms)
)
col_offset += ncols
+ return result
+ flat_refs.append(jax.tree.map(unflatten, ref_group))
union_cols = max(union_cols, col_offset)
assert union_cols == ref_union.shape[1], (union_cols, ref_union.shape[1])
else:
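# Editor's note (not part of the patch): the refactor above swaps
# jax.tree.leaves plus a flat result list for jax.tree.map with a nonlocal
# running offset, so each ref group keeps its pytree structure.  A
# stripped-down version of the same pattern, on plain integer sizes:
import jax

def assign_aligned_offsets(sizes_tree, alignment):
    offset = 0
    def visit(size):
        nonlocal offset
        offset = (offset + alignment - 1) // alignment * alignment
        start = offset
        offset += size
        return start
    starts = jax.tree.map(visit, sizes_tree)
    return starts, offset  # per-leaf offsets (same structure) and total bytes

starts, total = assign_aligned_offsets({"a": 12, "b": (7, 9)}, alignment=16)
# starts == {"a": 0, "b": (16, 32)}, total == 41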
@@ -670,7 +672,8 @@ def untransform_reshape(
self, dtype: jnp.dtype, shape: tuple[int, ...]
) -> tuple[tuple[int, ...], state_types.Transform]:
del dtype
- raise NotImplementedError("Reshapes don't commute with transposes.")
+ # TODO(slebedev): Support this.
+ raise NotImplementedError("Reshapes don't commute with tiling.")
def untransform_index(
self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...]
@@ -1367,6 +1370,11 @@ def _gpu_mesh_discharge_rule(
)
if not compiler_params:
compiler_params = CompilerParams()
+ sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)]
+ if sa_avals:
+ raise NotImplementedError(
+ f"Cannot close over values in core_map: {sa_avals}"
+ )
return pallas_core.default_mesh_discharge_rule(
in_avals,
out_avals,
diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 6b9a4e5b2a1b..57834b6b7f63 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -1557,9 +1557,16 @@ def _get_lowering_rule(
dtype = ctx.avals_out[0].dtype
transforms = jax.tree.unflatten(tree, leaves)
+ transposed = ctx.out_layout_hint and ctx.out_layout_hint in (
+ mgpu.WGMMA_TRANSPOSED_LAYOUT,
+ mgpu.TCGEN05_TRANSPOSED_LAYOUT,
+ )
+ transposed = bool(transposed)
x_smem, transforms = _handle_transforms(
- ctx, x_ref, transforms, allow_peer_refs=True
+ ctx, x_ref, transforms, handle_transposes=not transposed,
+ allow_peer_refs=True
)
+ x_smem = cast(ir.Value, x_smem)
del x_ref # Don't use x_ref anymore. Use x_smem instead!
is_signed = mgpu_utils.is_signed(dtype)
@@ -1569,20 +1576,49 @@ def _get_lowering_rule(
return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed)
match transforms:
- case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)):
+ case (
+ gpu_core.UnswizzleRef(swizzle),
+ gpu_core.UntileRef(tiling),
+ *maybe_transpose,
+ ):
if len(tiling) != 2:
raise NotImplementedError(f"Only 2D tiling is supported, got: {tiling}")
- expected_minor_tiling = swizzle * 8 // dtypes.itemsize_bits(dtype)
+ bw = dtypes.itemsize_bits(ctx.avals_out[0].dtype)
+ expected_minor_tiling = swizzle * 8 // bw
if tiling[-1] != expected_minor_tiling:
raise NotImplementedError(
"Minor tiling dimension does not fit swizzle: "
f" expected {expected_minor_tiling}, got {tiling[-1]}"
)
- layout = ctx.out_layout_hint or mgpu.WGMMA_LAYOUT
+
+ if transposed != bool(maybe_transpose):
+ raise ValueError(
+ "Either both the ref and the value are transposed or neither is."
+ )
+
+ if maybe_transpose:
+ if maybe_transpose != [gpu_core.TransposeRef((1, 0))]:
+ raise NotImplementedError(
+ f"Unsupported transforms: {transforms} ({maybe_transpose})"
+ )
+
+ x_smem = mgpu.memref_transpose(x_smem, (1, 0, 3, 2))
return mgpu.FragmentedArray.load_tiled(
- x_smem, is_signed=is_signed, swizzle=swizzle, layout=layout, optimized=optimized
+ x_smem,
+ is_signed=is_signed,
+ swizzle=swizzle,
+ layout=ctx.out_layout_hint or mgpu.WGMMA_LAYOUT,
+ optimized=optimized,
)
- case ():
+ case (*maybe_transpose,):
+ if maybe_transpose:
+ if len(maybe_transpose) != 1 or not isinstance(
+ maybe_transpose[0], gpu_core.TransposeRef
+ ):
+ raise NotImplementedError(
+ f"Unsupported transforms: {transforms} ({maybe_transpose})"
+ )
+ x_smem = mgpu.memref_transpose(x_smem, maybe_transpose[0].permutation)
match ctx.out_layout_hint:
case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size):
ref_ty = ir.MemRefType(x_smem.type)
@@ -1672,16 +1708,17 @@ def _swap_lowering_rule(
if ctx.module_ctx.auto_barriers:
barrier() # Make sure reads have completed before we write.
+
match transforms:
- case _ if not ctx.avals_out[0].shape: # Scalar case.
+ case _ if math.prod(ctx.avals_out[0].shape) == 1: # Scalar or single-element case.
+ zero_idx = _ir_constant(0, ir.IndexType.get())
+ indices = [zero_idx] * len(ctx.avals_out[0].shape)
old_value = mgpu.FragmentedArray.splat(
- memref_dialect.load(x_smem, []),
+ memref_dialect.load(x_smem, indices),
shape=(),
is_signed=mgpu_utils.is_signed(v_aval.dtype),
)
- memref_dialect.store(
- _ensure_ir_value(value, ctx.avals_out[0].dtype), x_smem, []
- )
+ value.store_untiled(x_smem)
case (
gpu_core.UnswizzleRef(swizzle),
gpu_core.UntileRef(tiling),
@@ -1717,9 +1754,16 @@ def _swap_lowering_rule(
layout=value.layout,
)
value.store_tiled(x_smem, swizzle=swizzle)
- case ():
+ case () | (gpu_core.TransposeRef(),):
+ transposed = bool(transforms)
match value.layout:
case mgpu.TiledLayout():
+ if transposed:
+ assert isinstance(
+ transforms[0], gpu_core.TransposeRef
+ ) # silence pytype
+ permutation = transforms[0].permutation
+ x_smem = mgpu.memref_transpose(x_smem, permutation)
old_value = mgpu.FragmentedArray.load_untiled(
x_smem,
layout=value.layout,
@@ -1728,6 +1772,8 @@ def _swap_lowering_rule(
)
value.store_untiled(x_smem, optimized=False)
case _:
+ if transposed:
+ raise NotImplementedError(f"Unsupported transforms: {transforms}")
old_value = mgpu.FragmentedArray.load_strided(
x_smem, is_signed=mgpu_utils.is_signed(v_aval.dtype)
)
@@ -1758,6 +1804,7 @@ def _swap_lowering_rule_wg(
"Transforms are not yet implemented for warpgroup semantics"
)
assert isinstance(x_smem, ir.Value)
+ value = _ensure_ir_value(value, ctx.avals_in[1].dtype)
if shape:
old_value = mgpu.dialect.vector_load(x_smem)
mgpu.dialect.vector_store(value, x_smem)
@@ -1859,6 +1906,14 @@ def _broadcast_in_dim_lowering_rule(
if (isinstance(x.layout, mgpu.WGSplatFragLayout) and
broadcast_dimensions == tuple(range(rank_diff, rank_diff + x_aval.ndim))):
return x.broadcast(shape)
+ if (
+ isinstance(x.layout, mgpu.WGStridedFragLayout)
+ and broadcast_dimensions == tuple(range(rank_diff, y_aval.ndim))
+ ):
+ new_layout = mgpu.WGStridedFragLayout(
+ shape=y_aval.shape, vec_size=x.layout.vec_size
+ )
+ return x.broadcast_in_dim(y_aval.shape, broadcast_dimensions, new_layout)
if not isinstance(layout := x.layout, mgpu.TiledLayout):
raise NotImplementedError(f"Unsupported layout: {x.layout}")
if any(d1 >= d2 for d1, d2 in zip(broadcast_dimensions[:-1], broadcast_dimensions[1:])):
@@ -2216,6 +2271,14 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
return res
+@register_lowering_rule(lax.clamp_p, mgpu.LoweringSemantics.Lane)
+@register_lowering_rule(lax.clamp_p, mgpu.LoweringSemantics.Warpgroup)
+def _clamp_lowering_rule(ctx: LoweringRuleContext, l, x, u):
+ return _lower_fun(
+ lambda l, x, u: lax.min(lax.max(x, l), u), multiple_results=False
+ )(ctx, l, x, u)
+
+
@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Warpgroup)
def _square_lowering_rule(ctx: LoweringRuleContext, x):
@@ -2301,6 +2364,31 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
)
return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
+@register_lowering_rule(lax.sin_p, mgpu.LoweringSemantics.Lane)
+@register_lowering_rule(lax.sin_p, mgpu.LoweringSemantics.Warpgroup)
+def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
+ if accuracy is not None:
+ raise NotImplementedError("Not implemented: accuracy")
+ [x_aval] = ctx.avals_in
+ if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
+ return _ensure_fa(x, x_aval.dtype).sin(approx=ctx.module_ctx.approx_math)
+ fastmath = (
+ arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
+ )
+ return math_dialect.sin(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
+
+@register_lowering_rule(lax.cos_p, mgpu.LoweringSemantics.Lane)
+@register_lowering_rule(lax.cos_p, mgpu.LoweringSemantics.Warpgroup)
+def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
+ if accuracy is not None:
+ raise NotImplementedError("Not implemented: accuracy")
+ [x_aval] = ctx.avals_in
+ if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
+ return _ensure_fa(x, x_aval.dtype).cos(approx=ctx.module_ctx.approx_math)
+ fastmath = (
+ arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None
+ )
+ return math_dialect.cos(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup)
@@ -2316,6 +2404,57 @@ def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath)
+@register_lowering_rule(lax.reshape_p, mgpu.LoweringSemantics.Lane)
+def _reshape_lowering_rule(
+ ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding
+):
+ if dimensions is not None:
+ raise NotImplementedError("Not implemented: dimensions")
+ if sharding is not None:
+ raise NotImplementedError("Not implemented: sharding")
+ [x_aval] = ctx.avals_in
+ return _ensure_fa(x, x_aval.dtype).reshape(new_sizes)
+
+
+@register_lowering_rule(lax.reshape_p, mgpu.LoweringSemantics.Warpgroup)
+def _reshape_lowering_rule_wg(
+ ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding
+):
+ if dimensions is not None:
+ raise NotImplementedError("Not implemented: dimensions")
+ if sharding is not None:
+ raise NotImplementedError("Not implemented: sharding")
+ [x_aval] = ctx.avals_in
+ x = _ensure_ir_value(x, x_aval.dtype)
+ if x_aval.ndim == 0: # scalar
+ res_ty = ir.VectorType.get(new_sizes, x.type)
+ return vector_dialect.broadcast(res_ty, x)
+ else:
+ res_ty = ir.VectorType.get(new_sizes, ir.VectorType(x.type).element_type)
+ return vector_dialect.shape_cast(res_ty, x)
+
+
+@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Lane)
+def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions):
+ [x_aval] = ctx.avals_in
+ [y_aval] = ctx.avals_out
+ return _ensure_fa(x, x_aval.dtype).reshape(y_aval.shape)
+
+
+@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Warpgroup)
+def _squeeze_lowering_rule_wg(ctx: LoweringRuleContext, x, dimensions):
+ [x_aval] = ctx.avals_in
+ [y_aval] = ctx.avals_out
+ x = _ensure_ir_value(x, x_aval.dtype)
+ if y_aval.ndim == 0: # scalar
+ return vector_dialect.extract(
+ x, dynamic_position=[], static_position=[0] * x_aval.ndim
+ )
+ else:
+ res_ty = ir.VectorType.get(y_aval.shape, ir.VectorType(x.type).element_type)
+ return vector_dialect.shape_cast(res_ty, x)
+
+
def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs):
[x_aval] = ctx.avals_in
match x.layout:
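# Editor's note (not part of the patch): a hedged sketch of what the new
# lowering rules enable inside a Pallas Mosaic GPU kernel body -- elementwise
# sin/cos and register-level reshape/squeeze.  Whether a particular
# shape/layout combination is supported still depends on the kernel's
# layouts; the block shapes here are illustrative only.
import jax.numpy as jnp

def kernel(x_ref, o_ref):        # x_ref: a (128, 2) block, o_ref: a (256,) block
    x = x_ref[...]
    y = jnp.sin(x) + jnp.cos(x)  # handled by the sin/cos rules above
    o_ref[...] = y.reshape(256)  # handled by the reshape rule above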
diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py
index 61d75d5297b7..4c8b49cbce60 100644
--- a/jax/_src/pallas/mosaic_gpu/pipeline.py
+++ b/jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -59,8 +59,15 @@ def _get_block_size(
raise NotImplementedError(f"Unsupported block size type: {type(bd)}")
def _get_block_shape(spec: pallas_core.BlockSpec):
- assert spec.block_shape is not None
- return tuple(_get_block_size(bd) for bd in spec.block_shape)
+ if spec.block_shape is None:
+ raise ValueError("Block shape must be specified.")
+
+ block_shape = tuple(
+ _get_block_size(bd)
+ for bd in spec.block_shape
+ if not (bd is None or isinstance(bd, pl.Squeezed))
+ )
+ return block_shape
map_brefs = functools.partial(
@@ -84,18 +91,27 @@ def get_ref_for_slot(
return self.gmem_ref
return self.smem_ref.at[slot]
- def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
+ def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice | jax.Array, ...]:
index_map = self.spec.index_map
assert index_map is not None
+ assert self.spec.block_shape is not None
# We don't allow Python scalars here, because they are interpreted
# differently depending on the x32/x64 mode.
assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices)
- sizes = _get_block_shape(self.spec)
+
+ def _make_block_slice(block_index: jax.Array, bd: pl.BlockDim | int | None):
+ match bd:
+ case int():
+ return pl.Slice(block_index * bd, bd)
+ case pl.Blocked(block_size):
+ return pl.Slice(block_index * block_size, block_size)
+ case None | pl.Squeezed():
+ return block_index
+ case _:
+ raise ValueError(f"Unsupported block dimension type: {bd}")
+
return tuple(
- pl.Slice(idx * size, size) # type: ignore[arg-type]
- for idx, size in zip(
- index_map(*grid_indices), sizes # type: ignore[arg-type]
- )
+ map(_make_block_slice, index_map(*grid_indices), self.spec.block_shape)
)
def copy_in(self, slot, grid_indices, barrier_ref, barrier_slot=None):
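# Editor's note (not part of the patch): a worked example of the mapping in
# _make_block_slice, assuming the block-dimension classes are reachable as
# pl.Blocked / pl.Squeezed / pl.Slice; error handling is elided.
from jax.experimental import pallas as pl

def gmem_index_for(block_index, bd):
    if isinstance(bd, pl.Blocked):
        return pl.Slice(block_index * bd.block_size, bd.block_size)
    if isinstance(bd, int):
        return pl.Slice(block_index * bd, bd)
    return block_index  # None or pl.Squeezed(): raw index, axis is dropped

idx = (gmem_index_for(3, pl.Blocked(64)), gmem_index_for(7, pl.Squeezed()))
# idx == (pl.Slice(192, 64), 7): blocked dims become element-space slices,
# squeezed dims pass the block index through (and _get_block_shape drops them
# from the SMEM allocation).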
@@ -166,26 +182,6 @@ def _inc_grid_by_1(
def _in_smem(spec: pallas_core.BlockSpec) -> bool:
return spec.memory_space in (None, gpu_core.SMEM)
-
-# ``pl.Slice`` uses a different pytree encoding, depending on whether the
-# start/size are static or dynamic. This leads to pytree structure mismatch
-# in the pipeline body. So, we define a different ``Slice`` class below.
-
-
-@dataclasses.dataclass(frozen=True)
-class _Slice:
- start: int | jax.Array
- size: int | jax.Array
-
- def __eq__(self, other: _Slice) -> jax.Array: # type: ignore
- return lax.bitwise_and(self.start == other.start, self.size == other.size)
-
-
-jax.tree_util.register_dataclass(
- _Slice, data_fields=["start", "size"], meta_fields=[]
-)
-
-
def _downcast_spec(
spec: gpu_core.BlockSpec | pallas_core.BlockSpec,
) -> gpu_core.BlockSpec:
@@ -341,7 +337,7 @@ def prologue(step, fetch_indices):
# need to fetch more data anyway.
def loop_body(step, carry):
slot = lax.rem(step, max_concurrent_steps)
- indices, fetch_index_levels, last_store_slices, prev_body_carry = carry
+ indices, fetch_index_levels, last_store_indices, prev_body_carry = carry
if barrier_ref is not None:
# Wait for the current GMEM->SMEM copy to complete, if any.
@@ -365,19 +361,17 @@ def loop_body(step, carry):
gpu_primitives.commit_smem()
# Copy the output from SMEM to GMEM.
- new_store_slices = last_store_slices[:]
+ new_store_indices = last_store_indices[:]
for idx, bref in enumerate(out_brefs):
if bref.is_index_invariant:
- assert last_store_slices[idx] is None
+ assert last_store_indices[idx] is None
continue
- assert last_store_slices[idx] is not None
- new_store_slices[idx] = tuple(
- _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
- )
+ assert last_store_indices[idx] is not None
+ new_store_indices[idx] = bref.spec.index_map(*indices)
are_same_slices = map(
lambda old, new: old == new,
- last_store_slices[idx],
- new_store_slices[idx],
+ last_store_indices[idx],
+ new_store_indices[idx],
)
slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
is_last_step = step == num_steps - 1
@@ -419,7 +413,7 @@ def do_fetch():
return (
_inc_grid_by_1(indices, grid),
next_fetch_indices_levels,
- new_store_slices,
+ new_store_indices,
next_body_carry if init_carry is not None else None,
)
@@ -431,17 +425,17 @@ def do_fetch():
fetch_index_levels.append(fetch_indices)
# TODO(justinfu): Only store base pointer instead of all indices.
- last_store_slices = [
+ last_store_indices = [
None
if bref.is_index_invariant
- else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+ else (jnp.array(-1),) * len(bref.spec.block_shape)
for bref in out_brefs
]
last_indices, _, _, final_carry = lax.fori_loop(
0,
num_steps,
loop_body,
- (indices, fetch_index_levels, last_store_slices, init_carry),
+ (indices, fetch_index_levels, last_store_indices, init_carry),
)
# Outputs invariant to the sequential axis are never written from inside the
@@ -690,7 +684,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree):
slots = max_concurrent_steps if has_seq_dim else 1
smem_allocs.append(
gpu_core.SMEM(
- (slots, *spec.block_shape), # type: ignore
+ (slots, *_get_block_shape(spec)), # type: ignore
gmem_ref.dtype,
transforms=getattr(spec, "transforms", ()),
)
@@ -826,7 +820,7 @@ def compute_block():
needs_epilogue = any(bref.is_index_invariant for bref in smem_out_brefs)
def compute_loop_body(step, carry):
- indices, last_store_slices, prev_body_carry = carry
+ indices, last_store_indices, prev_body_carry = carry
slot = lax.rem(step, max_concurrent_steps)
consumed_slot = lax.rem(step - delay_release, max_concurrent_steps)
# Wait for the current GMEM->SMEM copies to complete.
@@ -873,19 +867,17 @@ def compute_loop_body(step, carry):
if copies_out_in_loop:
gpu_primitives.commit_smem()
- new_store_slices = last_store_slices[:]
+ new_store_indices = last_store_indices[:]
for idx, bref in enumerate(flat_out_brefs):
if bref.is_index_invariant:
- assert last_store_slices[idx] is None
+ assert last_store_indices[idx] is None
continue
- assert last_store_slices[idx] is not None
- new_store_slices[idx] = tuple(
- _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
- )
+ assert last_store_indices[idx] is not None
+ new_store_indices[idx] = bref.spec.index_map(*indices)
are_same_slices = map(
lambda old, new: old == new,
- last_store_slices[idx],
- new_store_slices[idx],
+ last_store_indices[idx],
+ new_store_indices[idx],
)
slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
bref.copy_out(_get_slot(slot, not bref.is_index_invariant),
@@ -893,13 +885,14 @@ def compute_loop_body(step, carry):
predicate=slices_changed)
gpu_primitives.commit_smem_to_gmem_group()
next_indices = _inc_grid_by_1(indices, grid)
- return (next_indices, new_store_slices, next_body_carry)
+ return (next_indices, new_store_indices, next_body_carry)
init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid)
+
# TODO(justinfu): Only store base pointer instead of all indices.
- last_store_slices = [
+ last_store_indices = [
None
if bref.is_index_invariant
- else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+ else (jnp.array(-1),) * len(bref.spec.block_shape)
for bref in flat_out_brefs
]
@@ -910,7 +903,7 @@ def pipeline_callback(user_init_carry):
if last_indices is not None:
raise ValueError(
"Cannot call pipeline more than once in `compute_context`")
- init_loop_carry = (init_indices, last_store_slices, user_init_carry)
+ init_loop_carry = (init_indices, last_store_indices, user_init_carry)
last_indices, _, final_body_carry = lax.fori_loop(0,
num_steps,
compute_loop_body,
@@ -923,7 +916,7 @@ def pipeline_callback(user_init_carry):
assert compute_context is None
last_indices, _, _ = lax.fori_loop(
0, num_steps, compute_loop_body,
- (init_indices, last_store_slices, None)
+ (init_indices, last_store_indices, None)
)
# Handle index_invariant outputs after the loop. They are not
diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py
index b0ed9f5d0848..ae8a6c06dba1 100644
--- a/jax/_src/pallas/mosaic_gpu/primitives.py
+++ b/jax/_src/pallas/mosaic_gpu/primitives.py
@@ -53,6 +53,7 @@
import numpy as np
+AxisName = jax_core.AxisName
WARP_SIZE = 32
WARPGROUP_SIZE = 128
@@ -248,14 +249,32 @@ def _copy_smem_to_gmem_lowering(
"GMEM refs with peer ids are not supported in warpgroup lowering."
)
assert not copy_params.get("gmem_transform")
- mgpu.dialect.async_store(
- src,
- dst,
- indices,
- slice_lengths,
- predicate=predicate,
- commit_group=commit_group, # type: ignore[call-arg]
- )
+ if reduction_op is not None:
+ # TODO(b/415721295): Call mgpu.dialect.async_store after the if, after
+ # the minimal jaxlib version is 0.8.2.
+ if not hasattr(mgpu.dialect, "TMAReduction"):
+ raise NotImplementedError("Reduction op is not supported yet.")
+ reduction_op_attr = getattr(
+ mgpu.dialect.TMAReduction, reduction_op.capitalize()
+ )
+ mgpu.dialect.async_store(
+ src,
+ dst,
+ indices,
+ slice_lengths,
+ predicate=predicate,
+ commit_group=commit_group, # type: ignore[call-arg]
+ reduction_op=reduction_op_attr,
+ )
+ else:
+ mgpu.dialect.async_store(
+ src,
+ dst,
+ indices,
+ slice_lengths,
+ predicate=predicate,
+ commit_group=commit_group, # type: ignore[call-arg]
+ )
return ()
@@ -3184,7 +3203,10 @@ def _async_store_tmem_lowering_rule_wg(
async_copy_scales_to_tmem_p = jax_core.Primitive("async_copy_scales_to_tmem")
async_copy_scales_to_tmem_p.multiple_results = True
-def async_copy_scales_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
+
+def async_copy_scales_to_tmem(
+ smem_ref: _Ref, tmem_ref: _Ref, collective_axis: AxisName | None = None,
+):
"""Copies the MMA scales from SMEM to TMEM.
The copy is performed asynchronously and can be awaited by calling
@@ -3208,12 +3230,17 @@ def async_copy_scales_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
async_copy_scales_to_tmem_p.bind(
smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms,
smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef,
+ collective_axis=collective_axis,
)
+
async_copy_sparse_metadata_to_tmem_p = jax_core.Primitive("async_copy_sparse_metadata_to_tmem")
async_copy_sparse_metadata_to_tmem_p.multiple_results = True
-def async_copy_sparse_metadata_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
+
+def async_copy_sparse_metadata_to_tmem(
+ smem_ref: _Ref, tmem_ref: _Ref, collective_axis: AxisName | None = None
+):
"""Copies the MMA sparse metadata from SMEM to TMEM.
The copy is performed asynchronously and can be awaited by calling
@@ -3237,11 +3264,13 @@ def async_copy_sparse_metadata_to_tmem(smem_ref: _Ref, tmem_ref: _Ref):
async_copy_sparse_metadata_to_tmem_p.bind(
smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms,
smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef,
+ collective_axis=collective_axis,
)
+
@async_copy_scales_to_tmem_p.def_effectful_abstract_eval
@async_copy_sparse_metadata_to_tmem_p.def_effectful_abstract_eval
-def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *avals_flat, smem_tree, tmem_tree):
+def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *_args, **_kwargs):
if smem_ref.memory_space != gpu_core.MemorySpace.SMEM:
raise ValueError("async_copy_scales_to_tmem source must be an SMEM ref")
if tmem_ref.memory_space != gpu_core.MemorySpace.TMEM:
@@ -3249,7 +3278,7 @@ def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *avals_flat, smem_tree
return (), {gpu_core._memory_effect}
def _async_copy_to_tmem_lowering_rule(
- impl, ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree
+ impl, ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree, collective_axis
):
assert isinstance(tmem_ref, tcgen05.TMEMRef)
smem_leaves, tmem_leaves = util.split_list(leaves, [smem_tree.num_leaves])
@@ -3261,8 +3290,17 @@ def _async_copy_to_tmem_lowering_rule(
raise NotImplementedError(f"Unimplemented transforms for SMEM refs: {smem_transforms}")
if tmem_transforms:
raise NotImplementedError(f"Unimplemented transforms for TMEM refs: {tmem_transforms}")
- with mgpu.when(ctx.module_ctx.single_lane_predicate):
- impl(smem_ref, tmem_ref)
+
+ predicate = ctx.module_ctx.single_lane_predicate
+ if collective_axis is not None:
+ is_leader_block = _collective_mma_predicate(ctx, collective_axis)
+ predicate = arith_dialect.andi(predicate, is_leader_block)
+ collective = True
+ else:
+ collective = False
+
+ with mgpu.when(predicate):
+ impl(smem_ref, tmem_ref, collective=collective)
return ()
@lowering.register_lowering_rule(
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index c2d8d0e97e6b..bf81467e4669 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -18,6 +18,7 @@
from collections.abc import Callable, Mapping, Sequence
import contextlib
import enum
+import math
from functools import partial, reduce
import types
from typing import Any
@@ -37,16 +38,17 @@
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src import typing as jax_typing
+from jax._src.mesh import get_abstract_mesh
from jax._src.frozen_dict import FrozenDict
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
-from jax._src.pallas import helpers as pallas_helpers
from jax._src.pallas import hlo_interpreter
from jax._src.pallas import primitives
from jax._src.state import discharge as state_discharge
+from jax._src.shard_map import shard_map, P, _as_manual_mesh
from jax._src.state import types as state_types
from jax._src.util import (
safe_map,
@@ -116,12 +118,26 @@ def _pallas_call_abstract_eval(
raise ValueError(f"input pinned buffers without input_output_aliases:"
f"{missing}")
outin_aliases = {out_idx: in_idx for in_idx, out_idx in inout_aliases.items()}
- out_avals = [jax_core.ShapedArray(a.shape, a.dtype, a.weak_type)
+ # Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
+ out_avals = [jax_core.ShapedArray(a.shape, a.dtype, a.weak_type,
+ sharding=a.sharding)
if isinstance(a, pallas_core.ShapedArrayWithMemorySpace) else
avals[outin_aliases[out_idx]] if out_idx in outin_aliases
else a for out_idx, a in enumerate(out_avals)]
-
- # Make sure we don't return ShapedArrayWithMemorySpace to the outside world.
+ # TODO(mattjj,yashkatariya): if we hide vmapped away mesh axes, use this:
+ # if not (all(a.sharding.mesh.are_all_axes_manual for a in avals) and
+ # all(a.sharding.mesh.are_all_axes_manual for a in out_avals) and
+ # get_abstract_mesh().are_all_axes_manual):
+ # raise ValueError("pallas_call requires all mesh axes to be Manual, "
+ # f"got {get_abstract_mesh().axis_types}")
+
+ # NOTE(mattjj,yashkatariya): this doesn't catch auto-mode non-manual axes
+ if not (all(p is None for a in avals if isinstance(a, jax_core.ShapedArray)
+ for p in a.sharding.spec) and
+ all(p is None for a in out_avals if isinstance(a, jax_core.ShapedArray)
+ for p in a.sharding.spec)):
+ raise ValueError("pallas_call requires all mesh axes to be Manual, "
+ f"got {get_abstract_mesh().axis_types}")
return out_avals, effs
@@ -329,17 +345,12 @@ def _pallas_call_jvp_rule(
def _batch_block_mapping(
grid_mapping: GridMapping,
axis_size: int,
- for_ragged: bool,
aval: jax_core.ShapedArray,
dim: int | batching.NotMapped,
block_mapping: BlockMapping,
- ragged_axis_values,
) -> BlockMapping:
def _block_map_function(new_idx, *args):
- if for_ragged:
- drop_last_args = args[:-1]
- else:
- drop_last_args = args
+ drop_last_args = args
indices = jax_core.eval_jaxpr(
block_mapping.index_map_jaxpr.jaxpr,
@@ -352,27 +363,10 @@ def _block_map_function(new_idx, *args):
unflat_indices = (unflat_indices,)
unflat_indices = list(unflat_indices)
if dim is not batching.not_mapped:
- if isinstance(dim, batching.RaggedAxis):
- assert for_ragged, "Ragged axis not supported for non-ragged batching."
- stacked_axis = dim.stacked_axis
- unflat_indices.insert(stacked_axis, new_idx)
- else:
- unflat_indices.insert(dim, new_idx)
+ unflat_indices.insert(dim, new_idx)
return tuple(unflat_indices)
idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals]
- if for_ragged:
- if isinstance(dim, batching.RaggedAxis):
- assert for_ragged, "Ragged axis not supported for non-ragged batching."
- _, _, _, lengths_aval = ragged_axis_values
- idx_avals = [*idx_avals, lengths_aval]
- else:
- i32_aval_memref = state.AbstractRef(
- jax_core.ShapedArray(([axis_size]), jnp.int32),
- pallas_core.MemorySpace.INDEX,
- )
- idx_avals = [*idx_avals, i32_aval_memref]
-
block_mapping_flat_fn, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(_block_map_function,
debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info.with_unknown_names()),
@@ -387,23 +381,10 @@ def _block_map_function(new_idx, *args):
new_block_shape = shape
new_array_aval = block_mapping.array_aval
else:
- if isinstance(dim, batching.RaggedAxis):
- assert for_ragged, "Ragged axis not supported for non-ragged batching."
- new_block_shape = shape
- stacked_axis = dim.stacked_axis
- new_block_shape = tuple_insert(
- new_block_shape, stacked_axis, pallas_core.squeezed
- )
- else:
- new_block_shape = tuple_insert(shape, dim, pallas_core.squeezed)
+ new_block_shape = tuple_insert(shape, dim, pallas_core.squeezed)
array_shape = block_mapping.array_aval.shape
- if isinstance(dim, batching.RaggedAxis):
- assert for_ragged, "Ragged axis not supported for non-ragged batching."
- stacked_axis = dim.stacked_axis
- array_shape = tuple_insert(array_shape, stacked_axis, axis_size)
- else:
- array_shape = tuple_insert(array_shape, dim, axis_size)
+ array_shape = tuple_insert(array_shape, dim, axis_size)
new_array_aval = jax_core.ShapedArray(
array_shape, block_mapping.array_aval.dtype
@@ -437,12 +418,6 @@ def _broadcast_input_output_aliases(
for input_index, _ in input_output_aliases:
dim = dims_[input_index]
dims_[input_index] = 0
- if isinstance(dim, batching.RaggedAxis):
- stacked_axis = dim.stacked_axis
- if stacked_axis != 0:
- raise NotImplementedError("Ragged aliasing on non 0 dim NYI")
- return tuple(args_), tuple(dims_)
-
if dim is batching.not_mapped:
args_[input_index] = batching.broadcast(
args_[input_index], axis_size, 0, None)
@@ -557,6 +532,7 @@ def body(batch_index: jax_typing.Array, state: list[jax_typing.Array]) -> list[j
def _pallas_call_batching_rule(
+ axis_data,
args,
dims,
*,
@@ -585,19 +561,21 @@ def _maybe_squeeze_out_bdim(
return x
return jnp.squeeze(x, axis=bdim)
- def get_size(i, x, d):
- if not isinstance(d, batching.RaggedAxis):
- return x.shape[d]
- return x.aval.shape[d.stacked_axis]
+ # this is the _global_ axis size if axis_data.explicit_mesh_axis is not None
+ # we want to convert it to the local axis size
+ axis_size = axis_data.size
+ ema = axis_data.explicit_mesh_axis
+ abs_mesh = get_abstract_mesh()
+ if ema:
+ mesh_size = math.prod(abs_mesh.shape[i] for i in ema)
+ axis_size, ragged = divmod(axis_size, mesh_size)
+ assert not ragged
- (axis_size,) = {
- get_size(i=i, x=x, d=d)
- for i, (x, d) in enumerate(zip(args, dims))
- if d is not batching.not_mapped
- }
if axis_size == 1:
# Why are we even vmapping?
args = map(_maybe_squeeze_out_bdim, args, dims)
+ if ema:
+ raise NotImplementedError()
out = pallas_call_p.bind(
*args,
jaxpr=jaxpr,
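# Editor's note (not part of the patch): a worked example of the local-size
# computation above.  A global vmapped batch of 32 sharded over explicit mesh
# axes of total size 4 gives axis_data.size == 32; each shard batches its
# pallas_call over 32 // 4 == 8 elements, and the remainder must be 0.
axis_size, ragged = divmod(32, 4)
assert (axis_size, ragged) == (8, 0)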
@@ -633,6 +611,8 @@ def get_size(i, x, d):
elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims):
# TODO(amagni, sharadmv): Explore possibility of batching dynamic grid
# bounds.
+ if ema:
+ raise NotImplementedError()
return _batch_with_explicit_loop(
args=dynamic_grid_args + args,
dims=dynamic_grid_dims + dims,
@@ -670,6 +650,8 @@ def get_size(i, x, d):
else:
# TODO(amagni,sharadmv,apaszke): enable efficient batching over
# prefetched scalar args.
+ if ema:
+ raise NotImplementedError()
return _batch_with_explicit_loop(
args=scalar_args + args,
dims=scalar_bdims + bdims,
@@ -702,30 +684,7 @@ def get_size(i, x, d):
args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size
)
- # Each dim either has data about its ragged axis, or None
- ragged_axis_values = []
- for d in dims:
- if isinstance(d, batching.RaggedAxis):
- stacked_axis, ragged_axis_dim, ragged_axis_length = (
- batching._ragged_axis_parts(d)
- )
- aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32)
- if isinstance(aval, jax_core.DShapedArray):
- aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type)
- lengths_aval = state.AbstractRef(
- aval,
- pallas_core.MemorySpace.INDEX,
- )
- # TODO(mvoz): Give this its own type
- ragged_axis_values.append(
- (stacked_axis, ragged_axis_dim, ragged_axis_length, lengths_aval)
- )
- else:
- ragged_axis_values.append(None) # type: ignore[arg-type]
-
all_dims = list(dims) + [0] * grid_mapping.num_outputs
- ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs
-
num_index_operands = grid_mapping.num_index_operands
num_scratch_operands = grid_mapping.num_scratch_operands
@@ -739,34 +698,16 @@ def get_size(i, x, d):
_batch_block_mapping,
grid_mapping,
axis_size,
- any(ragged_axis_values),
),
avals_to_batch,
all_dims[num_index_operands:],
block_mappings,
- ragged_axis_values[num_index_operands:],
)
index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten(
grid_mapping.index_map_avals)
assert not index_map_tree_kwargs
batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args
-
- lengths_aval = None # type: ignore[assignment]
-
- # Check all the ragged axis values, ensure their raggedness pattern
- # is identical (consider moving this check up!)
- for rav in ragged_axis_values:
- if rav is not None:
- if lengths_aval is None:
- lengths_aval = rav[3]
- else:
- assert lengths_aval == rav[3], "NYI - different lengths in ragged batch"
-
- if lengths_aval:
- batched_index_map_args = batched_index_map_args + (lengths_aval,)
- num_index_operands += 1
-
batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten(
(batched_index_map_args, {}))
@@ -791,290 +732,47 @@ def get_size(i, x, d):
else:
batched_cost_estimate = None
- # Start the ragged handling code
- # Here, we:
- # - Rewrite the indexer to save memory (skip indices outside the ragged bounds)
- # - Rewrite the kernel to save compute (skip elements outside the ragged bounds)
- # - Update various internal structures/metadata to account for the new
- # block spec.
- # - Set the hacky flag of ragged_originating on the mapping, to signal to
- # the lowering code to treat mapped dimensions as part of the user grid.
- if lengths_aval:
- batched_grid_mapping = batched_grid_mapping.replace(
- get_grid_indices=lambda indices, maybe_include_mapped_dims: indices,
- local_grid_env=lambda loop_idx, grid: tuple(
- pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid)
- ),
- )
-
- # Note - on zero filling counterfactuals
- # A debug util to produce a counterfactual version of the when
- # gating, where for all values that don't pass the @when check,
- # we write 0s. This is useful for debugging, as certain lowering paths
- # like mosaic will write the last data as passthrough, leading to
- # potentially confusing results.
- block_mapped_dim_idxs = []
- for block_mapping in batched_grid_mapping.block_mappings:
- mapped_dim_idxs = []
- for i, d in enumerate(block_mapping.block_shape):
- if isinstance(d, pallas_core.Squeezed):
- mapped_dim_idxs.append(i)
- else:
- mapped_dim_idxs.append(None) # type: ignore[arg-type]
- block_mapped_dim_idxs.append(mapped_dim_idxs)
-
- mapped_dim_idx = None
- for rav, mapped_dim_idxs in zip(ragged_axis_values, block_mapped_dim_idxs):
- if rav is not None:
- stacked_axis = rav[0]
- if mapped_dim_idx is None:
- mapped_dim_idx = mapped_dim_idxs[stacked_axis]
- if mapped_dim_idxs[stacked_axis] is None:
- raise ValueError(
- f"Expected mapped dim to be {stacked_axis}, but got"
- f" {mapped_dim_idxs[stacked_axis]}"
- )
- else:
- assert mapped_dim_idx == mapped_dim_idxs[stacked_axis], (
- f"Different mapped dims - expected {mapped_dim_idx}, but got"
- f" {mapped_dim_idxs[stacked_axis]}"
- )
-
- # This is the blockspec size of the dimension
- block_shapes = [b.block_shape for b in batched_grid_mapping.block_mappings]
-
- # Parse out the operations from the jaxpr to determine how to mask the output
- # NOTE! while this *could* be a default dict of None, and None is sound, as
- # it denotes that there is no raggedness for the given var, we explicitly
- # do not do this, so as to get better signal on implementation of rules
- # A misimplemented rule that does not account for new vars being introduced
- # will result in an error on the next op using the new var. The benefit of
- # of forcing implementers to account for all outputs and intermediaries is
- # a very nice one.
-
- var_to_raggedness = {}
- for invar, rav in zip(jaxpr.invars, ragged_axis_values):
- var_to_raggedness[invar] = rav
-
- for eqn in jaxpr.eqns:
- prim = eqn.primitive
- if prim not in batching.ragged_prop_rules:
- raise NotImplementedError(f"Not implemented - ragged prop for {prim}")
- rule = batching.ragged_prop_rules[prim]
-
- invar_raggedness = [
- (
- var_to_raggedness.get(invar, None)
- if isinstance(invar, jax_core.Var)
- else None
- )
- for invar in eqn.invars
- ]
- try:
- invar_raggedness, outvar_raggedness = rule(
- eqn.params, invar_raggedness, eqn.outvars # type: ignore[arg-type]
- )
- except Exception as e:
- raise RuntimeError(
- f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:"
- f" {eqn.outvars}. Underlying reason: {e}"
- ) from e
-
- for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment]
- if isinstance(invar, jax_core.Var):
- var_to_raggedness[invar] = rav
- for outvar, rav in zip(eqn.outvars, outvar_raggedness):
- if isinstance(outvar, jax_core.Var):
- var_to_raggedness[outvar] = rav
-
- for pos, invar in enumerate(jaxpr.invars):
- ragged_axis_values[pos] = var_to_raggedness[invar]
-
- per_input_ragged_axis_dim: list[int | None] = []
- for rav in ragged_axis_values:
- if rav is not None:
- per_input_ragged_axis_dim.append(rav[1])
- else:
- per_input_ragged_axis_dim.append(None)
-
- def when_wrapped_kernel(lengths_ref, *args, **kwargs):
- b_idx = primitives.program_id(mapped_dim_idx)
-
- b_len = lengths_ref[b_idx]
- run_kernel = jnp.array(True)
- for i, _ in enumerate(args):
- ragged_axis_dim = per_input_ragged_axis_dim[i]
- if ragged_axis_dim is None:
- continue
- arg_i_idx = (
- primitives.program_id(ragged_axis_dim)
- * pallas_core.get_block_dim_size(block_shapes[i][ragged_axis_dim])
- )
- run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len)
-
- # TODO(mvoz): Unimplemented primitive in pallas
- # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0)
- # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0")
-
- @pallas_helpers.when(run_kernel)
- def f():
- # Important! This allows us to trace the inner kernel with the correct
- # grid to preserve user program_id semantics. Ex: program_id(0) will
- # always be analogous to program_id(1) in the outer kernel.
- with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
- jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs)
-
- kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars]
- flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(
- list(kernel_avals)
- )
-
- def _rewrite_index_jaxpr(enumerate_batched_block_mapping):
- arg_pos, batched_block_mapping = enumerate_batched_block_mapping
- indexer_avals = [
- v.aval for v in batched_block_mapping.index_map_jaxpr.jaxpr.invars
- ]
- flat_indexer_avals, indexer_in_tree = tree_util.tree_flatten(
- list(indexer_avals)
- )
-
- def index_rewrite_kernel(*indexer_args):
- ragged_axis_dim = per_input_ragged_axis_dim[arg_pos]
-
- # the problem here seems to be that we are rnning this for all inputs, per input, because they each have an indexer - which means
- # that the indexer for output isn't getting written - before, it always was
-
- lengths_ref = indexer_args[-1]
- rest_indexer_args = indexer_args[:-1]
- # Lengths are always the last argument of the indexer.
- # lengths_ref = args[-1]
- # Invariant: Stacked axis is enforced to be the mapped axis above.
- b_idx = indexer_args[mapped_dim_idx]
-
- nargs = list(rest_indexer_args)
-
- if ragged_axis_dim is not None:
- val_at_ragged_dim = pallas_core.get_block_dim_size(
- batched_block_mapping.block_shape[ragged_axis_dim])
-
- # The current index into the ragged dimension.
- # Invariant: There is only one ragged dimension, enforced above.
- i_idx = indexer_args[ragged_axis_dim]
-
- # grid space -> element space
- i_len = i_idx * val_at_ragged_dim
-
- # The length of the current batch.
- b_len = lengths_ref[b_idx]
-
- # Have we reached the end of the current batch?
- not_done = i_len < b_len
-
- am_last_batch = b_idx == axis_size - 1
- last_good_block = lax.div(b_len, val_at_ragged_dim) - 1
-
- # The logic below can be thought of as:
- # if index_oob_ragged:
- # if not last_batch:
- # batch_idx += 1
- # ragged_idx = 0
- # else:
- # ragged_idx = last_good_block
- #
- # wherein we find the next good block by incrementing the batch index
- # and setting the ragged index to 0 if we are not in the last batch.
- # Otherwise, we set the ragged index to the last good block.
- b_next = jnp.where(
- not_done, b_idx, jnp.where(am_last_batch, b_idx, b_idx + 1)
- )
- i_next = jnp.where(
- not_done, i_idx, jnp.where(am_last_batch, last_good_block, 0)
- )
- nargs[ragged_axis_dim] = i_next
- nargs[mapped_dim_idx] = b_next
-
- nargs = nargs + [lengths_ref]
- return jax_core.eval_jaxpr(
- batched_block_mapping.index_map_jaxpr.jaxpr,
- batched_block_mapping.index_map_jaxpr.consts,
- *nargs,
- )
- index_jaxpr, _ = _trace_kernel_to_jaxpr(
- index_rewrite_kernel,
- batched_block_mapping.index_map_jaxpr.jaxpr.debug_info,
- batched_grid_mapping,
- tuple(flat_indexer_avals),
- indexer_in_tree,
- tuple(() for _ in flat_indexer_avals),
- indexer=True,
- )
-
- batched_block_mapping = batched_block_mapping.replace(
- index_map_jaxpr=pe.close_jaxpr(index_jaxpr)
- )
- return batched_block_mapping
-
- # Important! This allows us to trace the outer kernel with the correct grid
- # to enable accessing the batch program_id.
- with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
- batched_block_mappings = map(
- _rewrite_index_jaxpr, enumerate(batched_block_mappings)
- )
-
- batched_grid_mapping = batched_grid_mapping.replace(
- block_mappings=tuple(batched_block_mappings),
- )
-
- jaxpr, consts = _trace_kernel_to_jaxpr(
- when_wrapped_kernel,
- jaxpr.debug_info,
- batched_grid_mapping,
- tuple(flat_kernel_avals),
- kernel_in_tree,
- tuple(() for _ in flat_kernel_avals),
- )
- if consts:
- raise NotImplementedError("consts not supported in pallas_call")
-
- # We need to rewrite the input_output_aliases here, the initial call
- # to broadcast is done, and we have inserted a new input (lengths), so
- # there's an off-by-one here now.
- new_input_output_aliases = []
- for k, v in input_output_aliases:
- new_input_output_aliases.append((k + 1, v))
- input_output_aliases = tuple(new_input_output_aliases)
-
- # assert ragged_axis_length is not None
- args = (ragged_axis_length, *args)
assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals)
batched_out_avals = []
for aval in out_avals:
- sharding = aval.sharding.update(spec=tuple_insert(aval.sharding.spec, 0, None))
+ manual_mesh = (_as_manual_mesh(aval.sharding.mesh, ema) if ema else
+ aval.sharding.mesh)
+ sharding = aval.sharding.update(
+ mesh=manual_mesh, spec=tuple_insert(aval.sharding.spec, 0, None))
shape = tuple_insert(aval.shape, 0, axis_size)
batched_out_avals.append(aval.update(shape=shape, sharding=sharding))
batched_out_avals = tuple(batched_out_avals)
- out = pallas_call_p.bind(
- *dynamic_grid_args,
- *args,
- jaxpr=jaxpr,
- grid_mapping=batched_grid_mapping,
- mesh=mesh,
- input_output_aliases=input_output_aliases,
- debug=debug,
- interpret=interpret,
- compiler_params=compiler_params,
- cost_estimate=batched_cost_estimate,
- out_avals=batched_out_avals,
- backend=backend,
- metadata=metadata,
- name=name,
- )
+ bind = partial(
+ pallas_call_p.bind, jaxpr=jaxpr, grid_mapping=batched_grid_mapping,
+ mesh=mesh, input_output_aliases=input_output_aliases, debug=debug,
+ interpret=interpret, compiler_params=compiler_params,
+ cost_estimate=batched_cost_estimate, out_avals=batched_out_avals,
+ backend=backend, metadata=metadata, name=name)
+
+ if ema:
+ # TODO: all batching rules should probably run in the outer mesh context.
+ bind = remove_explicit(ema)(shard_map(
+ bind, out_specs=P(ema), axis_names=set(ema)))
+
+ out = bind(*dynamic_grid_args, *args)
return out, (0,) * len(out)
+batching.fancy_primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
+batching.skippable_batchers[pallas_call_p] = lambda _: ()
+
-batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
+@contextlib.contextmanager
+def remove_explicit(ema):
+ prev = jax_core.trace_ctx.axis_env
+ # assert set(prev.explicit_mesh_axis_names) == set(ema)
+ new = jax_core.AxisEnv(prev.axis_sizes, prev.spmd_axis_names, set())
+ try:
+ jax_core.trace_ctx.set_axis_env(new)
+ yield
+ finally:
+ jax_core.trace_ctx.set_axis_env(prev)
def checkify_pallas_kernel_body_jaxpr(
@@ -1501,7 +1199,8 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
return jax_core.ShapedArray(
shape=out_shape.shape, dtype=out_shape.dtype,
sharding=jax_core.get_cur_mesh_sharding(), vma=out_shape.vma)
- return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)
+ return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype,
+ sharding=jax_core.get_cur_mesh_sharding())
case pallas_core.MemoryRef():
return out_shape.get_array_aval()
case hijax.HiType():
@@ -1668,7 +1367,10 @@ def pallas_call(
backend: Backend | None = None,
metadata: dict[str, str] | None = None,
) -> Callable[..., Any]:
- """Invokes a Pallas kernel on some inputs.
+ """Entry point for creating a Pallas kernel.
+
+ In contrast to :func:`jax.experimental.pallas.kernel`, this entry point
+ assumes that the kernel will be executed over a ``grid``.
See `Pallas Quickstart `_.
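As a rough illustration of the grid-based entry point described in the docstring above (a minimal sketch using only the public ``jax.experimental.pallas`` API; it is not part of the patch):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # With no grid specified there is a single kernel invocation over the
  # whole array; x_ref and o_ref each cover the full (8, 128) block.
  o_ref[...] = x_ref[...] + 1.0

out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(jnp.zeros((8, 128), jnp.float32))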
@@ -1747,6 +1449,9 @@ def pallas_call(
"If `grid_spec` is specified, then `scratch_shapes` must "
f"be `()`. It is {scratch_shapes}")
del grid, in_specs, out_specs
+ # We can infer a backend from compiler_params if it is not specified.
+ if backend is None and isinstance(compiler_params, pallas_core.CompilerParams):
+ backend = compiler_params.BACKEND
return _pallas_call(
kernel,
out_shape,
@@ -1884,18 +1589,6 @@ def wrapped(*args):
f"[0, {len(flat_out_avals)})")
in_aval = flat_in_avals[i_idx]
out_aval = flat_out_avals[o_idx]
- if isinstance(in_aval, jax_core.DShapedArray):
- new_shape = []
- for d in in_aval.shape:
- if isinstance(d, int):
- new_shape.append(d)
- else:
- new_shape.append(d.dtype.bound)
-
- in_aval = jax_core.ShapedArray(
- tuple(new_shape), in_aval.dtype, in_aval.weak_type
- )
-
if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype:
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
diff --git a/jax/_src/pallas/pallas_test_util.py b/jax/_src/pallas/pallas_test_util.py
new file mode 100644
index 000000000000..621ca70b72bd
--- /dev/null
+++ b/jax/_src/pallas/pallas_test_util.py
@@ -0,0 +1,55 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Pallas test utilities."""
+import sys
+
+from jax._src import test_util as jtu
+from jax._src.pallas import pallas_call
+from jax.experimental import pallas as pl
+
+use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value
+
+
+@jtu.with_config(jax_traceback_filtering="off")
+class PallasTest(jtu.JaxTestCase):
+ INTERPRET: bool = False
+
+ def setUp(self):
+ if not jtu.test_device_matches(['cpu']) and self.INTERPRET:
+ self.skipTest('Only run interpret tests on CPU.')
+ if not self.INTERPRET:
+ # Running on accelerator
+ if jtu.test_device_matches(["cpu"]):
+ self.skipTest("On CPU the test works only in interpret mode")
+ if (jtu.test_device_matches(["cuda"]) and
+ not jtu.is_cuda_compute_capability_at_least("8.0")):
+ self.skipTest("Only works on GPU with capability >= sm80")
+ if (jtu.test_device_matches(["cuda"]) and use_mosaic_gpu and
+ not jtu.is_cuda_compute_capability_at_least("9.0")):
+ self.skipTest("Mosaic GPU requires capability >= sm90")
+ if sys.platform == "win32":
+ self.skipTest("Only works on non-Windows platforms")
+ super().setUp()
+
+ def pallas_call(self, *args, **kwargs):
+ return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
+
+
+class PallasTPUTest(PallasTest):
+ """A test case that only runs on TPUs or in interpret mode on CPU."""
+
+ def setUp(self):
+ if not jtu.test_device_matches(['tpu']) and not self.INTERPRET:
+ self.skipTest('Test requires TPUs')
+ super().setUp()
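For context, the typical way this helper is consumed is by subclassing; a hypothetical sketch (the class name below is illustrative, not from the patch):

class PallasInterpretTest(PallasTest):
  # Same tests, routed through interpret mode so they can run on CPU.
  INTERPRET = True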
diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index a37d3d4338da..2706f2ecc9b2 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -39,7 +39,6 @@
from jax._src import state
from jax._src import util
from jax._src.interpreters import ad
-from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
@@ -59,16 +58,15 @@
zip, unsafe_zip = util.safe_zip, zip
program_id_p = jax_core.Primitive("program_id")
-batching.ragged_prop_rules[program_id_p] = batching.ragged_mask_no_op_rule
def program_id(axis: int) -> jax_typing.Array:
"""Returns the kernel execution position along the given axis of the grid.
- For example, with a 2D `grid` in the kernel execution corresponding to the
- grid coordinates `(1, 2)`,
- `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`.
+ For example, with a 2D ``grid`` in the kernel execution corresponding to the
+ grid coordinates ``(1, 2)``,
+ ``program_id(axis=0)`` returns ``1`` and ``program_id(axis=1)`` returns ``2``.
- The returned value is an array of shape `()` and dtype `int32`.
+ The returned value is an array of shape ``()`` and dtype ``int32``.
Args:
axis: the axis of the grid along which to count the program.
@@ -350,6 +348,8 @@ def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val):
mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x])
def max_contiguous(x, values):
+ """A compiler hint that asserts the ``values`` first values of ``x`` are contiguous.
+ """
if not isinstance(values, (list, tuple)):
values = (values,)
return max_contiguous_p.bind(x, values=tuple(values))
@@ -364,6 +364,18 @@ def _max_contiguous_abstract_eval(aval, **_):
mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x])
def multiple_of(x: jax_typing.Array, values: Sequence[int] | int) -> jax_typing.Array:
+ """A compiler hint that asserts a value is a static multiple of another.
+
+ Note that misusing this function, such as asserting ``x`` is a multiple of
+ ``N`` when it is not, can result in undefined behavior.
+
+ Args:
+ x: The input array.
+ values: A set of static divisors that ``x`` is a multiple of.
+
+ Returns:
+ A copy of ``x``.
+ """
values = (values,) if isinstance(values, int) else tuple(values)
return multiple_of_p.bind(x, values=values)
@@ -713,6 +725,24 @@ def _handle_small(dtype: jax_typing.DTypeLike):
def dot(a, b, trans_a: bool = False, trans_b: bool = False,
allow_tf32: bool | None = None, precision=None):
+ """Computes the dot product of two arrays.
+
+ The inputs can optionally be transposed before computing the
+ product. Depending on the hardware, this can be cheaper than
+ computing the transpose beforehand.
+
+ Args:
+ a: The left-hand side of the dot product, of shape ``(..., N)``.
+ b: The right-hand side of the dot product, of shape ``(..., N, M)``.
+ trans_a: Whether to transpose ``a`` before the product.
+ trans_b: Whether to transpose ``b`` before the product.
+ allow_tf32: Whether to use tf32 precision.
+ Mutually exclusive with ``precision``.
+ precision: Specifies the precision of the dot product.
+
+ See Also:
+ :func:`jax.numpy.dot`
+ """
if (a.ndim != 2) or (b.ndim != 2):
raise ValueError("`a` and `b` must be 2D arrays.")
lhs_contract_dim = 0 if trans_a else 1
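A small sketch of the transpose arguments documented above (a hedged example assuming a GPU/Triton Pallas kernel where ``pl.dot`` is available; shapes are illustrative):

from jax.experimental import pallas as pl

def matmul_kernel(a_ref, b_ref, o_ref):
  # b_ref holds the right operand stored as (N, K); trans_b folds the
  # transpose into the dot rather than materializing it first.
  o_ref[...] = pl.dot(a_ref[...], b_ref[...], trans_b=True)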
@@ -824,6 +854,17 @@ def wrap_with_transforms(f, transforms, *args):
run_scoped_p = jax_core.Primitive("run_scoped")
run_scoped_p.multiple_results = True
+def _run_scoped_is_high(*avals, jaxpr, **params):
+ del avals, params
+ return jaxpr.is_high
+run_scoped_p.is_high = _run_scoped_is_high # type: ignore[method-assign]
+
+def _run_scoped_to_lojax(*args, jaxpr, **params):
+ closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, args)
+ closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr)
+ consts = closed_lo_jaxpr.consts
+ return run_scoped_p.bind(*consts, jaxpr=closed_lo_jaxpr.jaxpr, **params)
+run_scoped_p.to_lojax = _run_scoped_to_lojax
def run_scoped(
f: Callable[..., Any],
@@ -837,9 +878,9 @@ def run_scoped(
to allocate for each argument. Each backend has its own set of reference
types in addition to :class:`jax.experimental.pallas.MemoryRef`.
- When `collective_axes` is specified, the same allocation will be returned for
+ When ``collective_axes`` is specified, the same allocation will be returned for
all programs that only differ in their program ids along the collective axes.
- It is an error not to call the same `run_scoped` in all programs along that
+ It is an error not to call the same ``run_scoped`` in all programs along that
axis.
"""
if not isinstance(collective_axes, tuple):
@@ -974,12 +1015,12 @@ def _lower_fun(*lower_fun_args):
def get_global(what: pallas_core.ScratchShape) -> jax_typing.Array:
"""Returns a global reference that persists across all kernel invocations.
- Each call to get_global returns a different and unique reference, but one that
+ Each call to ``get_global`` returns a different and unique reference, but one that
is stable across invocations of the kernel body.
Args:
what: The reference type to allocate. Each backend has its own set of
- reference types (e.g., `plgpu.SemaphoreType.REGULAR` for GPU).
+ reference types (e.g., :class:`jax.experimental.pallas.mosaic_gpu.SemaphoreType` for GPU).
Example::
@@ -1043,7 +1084,7 @@ def check_sem_avals(
):
raise ValueError(
f"Must {name} semaphores of the following types:"
- f" {allowed_semaphore_types}."
+ f" {allowed_semaphore_types}. Got {sem_dtype}."
)
@@ -1064,7 +1105,15 @@ def _transform_semaphore(ref_value, transforms, ref_aval):
semaphore_read_p.multiple_results = False
-def semaphore_read(sem_or_view):
+def semaphore_read(sem_or_view) -> jax_typing.Array:
+ """Reads the value of a semaphore.
+
+ Args:
+ sem_or_view: A Ref (or view) representing a semaphore.
+
+ Returns:
+ A scalar Array containing the value of the semaphore.
+ """
ref, transforms = _get_ref_and_transforms(sem_or_view)
args = [ref, transforms]
flat_args, args_tree = tree_util.tree_flatten(args)
@@ -1107,6 +1156,24 @@ def semaphore_signal(
device_id_type: DeviceIdType = DeviceIdType.MESH,
core_index: int | jax_typing.Array | None = None,
):
+ """Increments the value of a semaphore.
+
+ This operation can also be performed remotely if ``device_id`` is specified,
+ in which case ``sem_or_view`` refers to a Ref located on another device.
+ Note that it is assumed that ``sem_or_view`` is already allocated
+ (e.g. through the proper use of barriers), or else this operation could
+ result in undefined behavior.
+
+ Args:
+ sem_or_view: A Ref (or view) representing a semaphore.
+ inc: The value to increment by.
+ device_id (optional): Specifies which device to signal.
+ If not specified, ``sem_or_view`` is assumed to be local.
+ device_id_type (optional): The format in which
+ ``device_id`` should be specified.
+ core_index (optional): If on a multi-core device,
+ specifies which core to signal.
+ """
ref, transforms = _get_ref_and_transforms(sem_or_view)
inc = jnp.asarray(inc, dtype=jnp.int32)
args = [ref, transforms, inc, device_id, core_index]
@@ -1124,26 +1191,32 @@ def _semaphore_signal_abstract_eval(
args_tree,
device_id_type: DeviceIdType,
):
- del device_id_type
(
sem_aval,
sem_transforms_avals,
value_aval,
- device_id_avals,
+ device_id_aval,
core_index_aval,
) = tree_util.tree_unflatten(args_tree, avals)
check_sem_avals(sem_aval, sem_transforms_avals, "signal")
if value_aval.dtype != jnp.dtype("int32"):
raise ValueError(f"Must signal an int32 value, but got {value_aval.dtype}")
effs : set[effects.Effect] = set()
- if device_id_avals is not None:
- device_id_flat_avals = tree_util.tree_leaves(device_id_avals)
+ if device_id_aval is not None:
+ device_id_flat_avals = tree_util.tree_leaves(device_id_aval)
for aval in device_id_flat_avals:
if aval.dtype != jnp.dtype("int32"):
raise ValueError(
f"`device_id`s must be an int32 value, but got {aval.dtype}"
)
- effs.add(pallas_core.comms_effect)
+ if device_id_type is DeviceIdType.MESH and isinstance(device_id_aval, dict):
+ for k in device_id_aval:
+ if not isinstance(k, tuple):
+ k = (k,)
+ for k_ in k:
+ effs.add(jax_core.NamedAxisEffect(k_))
+ else:
+ effs.add(pallas_core.comms_effect)
return [], effs
def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn,
@@ -1208,6 +1281,14 @@ def _semaphore_signal_discharge_rule(in_avals,
def semaphore_wait(
sem_or_view, value: int | jax_typing.Array = 1, *, decrement: bool = True
):
+ """Blocks execution of the current thread until a semaphore reaches a value.
+
+ Args:
+ sem_or_view: A Ref (or view) representing a semaphore.
+ value: The target value that the semaphore should reach before unblocking.
+ decrement: Whether to decrement the value of the semaphore after
+ a successful wait.
+ """
ref, transforms = _get_ref_and_transforms(sem_or_view)
value = jnp.asarray(value, dtype=jnp.int32)
args = [ref, transforms, value, decrement]
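To tie the three semaphore docstrings together, a hedged sketch of the usual signal/wait pairing; it assumes the conventional ``from jax.experimental.pallas import tpu as pltpu`` alias and is not taken from the patch:

from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref):
  def body(sem):
    pltpu.semaphore_signal(sem, 1)  # producer side: bump the count by one
    pltpu.semaphore_wait(sem, 1)    # consumer side: block until 1, then decrement
    o_ref[...] = x_ref[...]
  pl.run_scoped(body, pltpu.SemaphoreType.REGULAR)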
@@ -1270,38 +1351,59 @@ def _semaphore_wait_discharge_rule(in_avals,
)
-def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict, get_axis_index):
+def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo | None, device_id_dict, get_axis_index):
i32 = ir.IntegerType.get_signless(32)
- assert mesh_context is not None
- mesh_axis_sizes = dict(zip(mesh_context.axis_names, mesh_context.mesh_shape))
+ if mesh_context is None:
+ mesh_axis_sizes = {}
+ else:
+ mesh_axis_sizes = dict(
+ zip(mesh_context.axis_names, mesh_context.mesh_shape)
+ )
physical_axis_dict = {}
# Handle joint axes (i.e., one logical axis over >1 physical axes)
- for axis, idx in device_id_dict.items():
- if isinstance(axis, tuple) and any(a in mesh_context.axis_names for a in axis):
- if not all(a in mesh_context.axis_names for a in axis):
+ for axis_name, idx in device_id_dict.items():
+ if isinstance(axis_name, tuple) and any(
+ a in mesh_axis_sizes for a in axis_name
+ ):
+ if not all(a in mesh_axis_sizes for a in axis_name):
raise NotImplementedError(
- f"{axis} mixes JAX mesh and Pallas mesh grid axes"
- )
- axes_dimensions = [mesh_axis_sizes[name] for name in axis]
- for axis_index, axis_name in enumerate(axis):
- axis_size = arith.constant(i32, mesh_axis_sizes[axis_name])
- minor_divisor = arith.constant(
- i32, math.prod(axes_dimensions[axis_index + 1 :])
+ f"{axis_name} mixes JAX mesh and Pallas mesh grid axes"
)
- device_idx = arith.remsi(arith.divsi(idx, minor_divisor), axis_size)
+ axes_dimensions = [mesh_axis_sizes[name] for name in axis_name]
+ for axis_index, axis_name in enumerate(axis_name):
+ axis_size = mesh_axis_sizes[axis_name]
+ inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :])
+ minor_divisor = arith.constant(i32, inner_mesh_size)
+
+ # Fast path for powers of two.
+ if inner_mesh_size & (inner_mesh_size - 1) == 0:
+ shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1
+ partial_device_idx = arith.shrui(idx, arith.constant(i32, shift_len))
+ else:
+ partial_device_idx = arith.divsi(idx, minor_divisor)
+
+ if axis_size & (axis_size - 1) == 0:
+ device_idx = arith.andi(
+ partial_device_idx,
+ arith.constant(i32, mesh_axis_sizes[axis_name] - 1),
+ )
+ else:
+ device_idx = arith.remsi(
+ partial_device_idx, arith.constant(i32, axis_size)
+ )
physical_axis_dict[axis_name] = device_idx
else:
- physical_axis_dict[axis] = idx
+ physical_axis_dict[axis_name] = idx
device_id = []
- for axis in mesh_context.axis_names:
- if axis in physical_axis_dict:
- device_id.append(physical_axis_dict[axis])
+ for axis_name in mesh_axis_sizes:
+ if axis_name in physical_axis_dict:
+ device_id.append(physical_axis_dict[axis_name])
else:
- device_id.append(get_axis_index(axis))
+ device_id.append(get_axis_index(axis_name))
non_mesh_axes = {
k: v
for k, v in physical_axis_dict.items()
- if k not in mesh_context.axis_names
+ if k not in mesh_axis_sizes
}
return tuple(device_id), non_mesh_axes
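The fast path above relies on standard shift/mask identities for power-of-two sizes; a quick plain-Python check of that assumption (not part of the lowering, and valid for nonnegative indices):

idx, size = 29, 8                      # size is a power of two: 8 == 1 << 3
shift = size.bit_length() - 1
assert idx // size == idx >> shift         # divsi -> shrui
assert idx % size == idx & (size - 1)      # remsi -> andi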
@@ -1323,13 +1425,15 @@ def device_id_to_logical(
"`device_id_type` must be MESH if `device_id` is a dict,"
f" got: {device_id_type = }."
)
- assert mesh_context is not None
device_id, non_mesh_axes = _device_id_dict_to_mesh(mesh_context, device_id, get_axis_index)
if device_id_type is DeviceIdType.MESH:
- assert mesh_context is not None
# Mesh means we are passed the mesh coordinates for the device
device_ids = tree_util.tree_leaves(device_id)
- mesh_strides = mesh_context.mesh_strides
+ mesh_strides: tuple[int, ...]
+ if mesh_context is None:
+ mesh_strides = ()
+ else:
+ mesh_strides = mesh_context.mesh_strides
if len(device_ids) != len(mesh_strides):
raise ValueError(
"Number of device ids must match the number of mesh axes, but got"
diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py
index 2caa9f860bcb..62411e7f8659 100644
--- a/jax/_src/pallas/triton/lowering.py
+++ b/jax/_src/pallas/triton/lowering.py
@@ -2537,7 +2537,7 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_):
@register_lowering(pjit.reshard_p)
-def _reshard_lowering_rule(ctx, x, dst_sharding):
+def _reshard_lowering_rule(ctx, x, *, dst_sharding, concrete_mesh):
return x
diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py
index 77e157201107..90a61aeb4302 100644
--- a/jax/_src/pallas/utils.py
+++ b/jax/_src/pallas/utils.py
@@ -45,6 +45,14 @@ def cdiv(a: jax_typing.Array, b: jax_typing.Array) -> jax_typing.Array:
...
def cdiv(a: int | jax_typing.Array, b: int | jax_typing.Array) -> int | jax_typing.Array:
+ """Computes the ceiling division of a divided by b.
+
+ Examples:
+ >>> cdiv(8, 2)
+ 4
+ >>> cdiv(9, 2) # 9 / 2 = 4.5, which rounds up to 5
+ 5
+ """
if isinstance(a, int) and isinstance(b, int):
return (a + b - 1) // b
return lax.div(a + (b - 1), b)
diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py
index 7720adc2d072..d1ee8d6a40fc 100644
--- a/jax/_src/partition_spec.py
+++ b/jax/_src/partition_spec.py
@@ -163,6 +163,15 @@ def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec:
out.extend([None] * (ndim - len(out)))
return self.update(partitions=out)
+ def _check_compatible_wrt_shape(self, shape):
+ if len(shape) < len(self._partitions):
+ extra_msg = (' For scalars the PartitionSpec should be P()'
+ if len(shape) == 0 else '')
+ raise ValueError(
+ f"PartitionSpec {self} is only valid for values of rank at least "
+ f"{len(self._partitions)}, but was applied to a value of rank "
+ f"{len(shape)}.{extra_msg}")
+
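A brief illustration of the rank rule this helper enforces (a sketch only; the private method comes from the patch and is called directly here just for demonstration):

from jax.sharding import PartitionSpec as P

P('x', None)._check_compatible_wrt_shape((4, 8))  # ok: rank 2 >= 2
P()._check_compatible_wrt_shape(())               # ok: scalars use P()
# P('x')._check_compatible_wrt_shape(())          # would raise ValueError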
PartitionSpec.__module__ = 'jax.sharding'
P = PartitionSpec
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 54aedc578248..bcb4afe64c20 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -21,7 +21,7 @@
import inspect
import logging
import weakref
-from typing import NamedTuple, Any, Union, cast
+from typing import NamedTuple, Any, Union
import warnings
import numpy as np
@@ -46,10 +46,9 @@
from jax._src import xla_bridge as xb
from jax._src.core import typeof, cur_qdd
from jax._src.api_util import (
- argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
- donation_vector, check_callable, resolve_argnums,
- argnames_partial_except, debug_info, check_no_aliased_ref_args,
- _check_no_aliased_closed_over_refs)
+ argnums_partial_except, flatten_axes, flatten_fun3, donation_vector,
+ check_callable, resolve_argnums, argnames_partial_except, debug_info,
+ check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import ad
@@ -72,8 +71,8 @@
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
- treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr,
- PyTreeDef, none_leaf_registry as none_lr, tree_map)
+ treedef_children, prefix_errors, keystr, PyTreeDef,
+ none_leaf_registry as none_lr, tree_map)
from jax._src.typing import ArrayLike
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log,
@@ -120,7 +119,6 @@ class PjitInfo(NamedTuple):
backend: str | None
keep_unused: bool
inline: bool
- abstracted_axes: Any | None
use_resource_env: bool # False for jit, True for pjit
compiler_options_kvs: tuple[tuple[str, Any], ...]
@@ -188,7 +186,7 @@ def _need_to_rebuild_with_fdo(pgle_profiler):
def _get_fastpath_data(
executable, out_tree, args_flat, out_flat, effects, consts_for_constvars,
- abstracted_axes, pgle_profiler, const_args: Sequence[ArrayLike]
+ pgle_profiler, const_args: Sequence[ArrayLike]
) -> pxla.MeshExecutableFastpathData | None:
if (
executable is None
@@ -197,7 +195,6 @@ def _get_fastpath_data(
# No effects in computation
or executable.unsafe_call.ordered_effects
or executable.unsafe_call.has_unordered_effects
- or abstracted_axes is not None
# no ref state effects
or any(isinstance(e, RefEffect) for e in effects)
# no prng reuse checking
@@ -266,7 +263,7 @@ def cache_miss(*args, **kwargs):
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, jaxpr.effects, jaxpr.consts,
- jit_info.abstracted_axes, pgle_profiler,
+ pgle_profiler,
const_args)
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
@@ -350,7 +347,6 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any,
donate_argnames: str | Iterable[str] | None,
keep_unused: bool, device: xc.Device | None,
backend: str | None, inline: bool,
- abstracted_axes: Any | None,
compiler_options: dict[str, Any] | None,
use_resource_env: bool) -> PjitInfo:
"""Parses the arguments to jit/pjit.
@@ -358,9 +354,6 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any,
Performs any preprocessing and validation of the arguments that we can do
ahead of time before the jit()-ed function is invoked.
"""
- if abstracted_axes and not config.dynamic_shapes.value:
- raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
-
check_callable(fun)
if backend is not None or device is not None:
@@ -428,7 +421,6 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
- abstracted_axes=abstracted_axes,
use_resource_env=use_resource_env,
compiler_options_kvs=compiler_options_kvs)
@@ -444,7 +436,6 @@ def make_jit(fun: Callable,
device: xc.Device | None,
backend: str | None,
inline: bool,
- abstracted_axes: Any | None,
compiler_options: dict[str, Any] | None,
use_resource_env: bool) -> Any:
"""jit() and pjit() are thin wrappers around this function."""
@@ -453,7 +444,7 @@ def make_jit(fun: Callable,
static_argnums=static_argnums, static_argnames=static_argnames,
donate_argnums=donate_argnums, donate_argnames=donate_argnames,
keep_unused=keep_unused, device=device, backend=backend, inline=inline,
- abstracted_axes=abstracted_axes, compiler_options=compiler_options,
+ compiler_options=compiler_options,
use_resource_env=use_resource_env)
return _cpp_pjit(fun, jit_info)
@@ -491,8 +482,6 @@ def _infer_params_impl(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
- axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
-
f = lu.wrap_init(fun, debug_info=dbg)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
del args
@@ -500,7 +489,7 @@ def _infer_params_impl(
del kwargs
explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs))
- flat_fun, out_tree = flatten_fun(f, in_tree)
+ flat_fun, out_tree_and_result_paths = flatten_fun3(f, in_tree)
if (ji.donate_argnums or ji.donate_argnames) and not config.debug_nans.value:
donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames, in_tree)
@@ -531,14 +520,9 @@ def _infer_params_impl(
assert None not in out_shardings_leaves
in_type: core.InputType | tuple[core.AbstractValue, ...]
- if config.dynamic_shapes.value:
- assert in_avals is None
- in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
- in_avals = tuple(a for a, e in in_type if e)
- else:
- in_type = in_avals # type: ignore
- in_type = tuple(core.AvalQDD(a, cur_qdd(x)) if a.has_qdd # type: ignore
- else a for a, x in zip(in_type, explicit_args))
+ in_type = in_avals # type: ignore
+ in_type = tuple(core.AvalQDD(a, cur_qdd(x)) if a.has_qdd # type: ignore
+ else a for a, x in zip(in_type, explicit_args))
assert in_avals is not None
in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
@@ -550,6 +534,7 @@ def _infer_params_impl(
jaxpr, consts, out_avals = _create_pjit_jaxpr(
flat_fun, in_type, qdd_token, IgnoreKey(ji.inline))
+
if config.mutable_array_checks.value:
_check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
_qdd_cache_update(flat_fun, in_type, qdd_token, consts,
@@ -557,25 +542,25 @@ def _infer_params_impl(
out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
- ji.out_layouts_leaves, HashableFunction(out_tree, closure=()),
+ ji.out_layouts_leaves, HashableFunction(lambda: out_tree_and_result_paths()[0], closure=()),
tuple(out_avals), jaxpr.jaxpr._debug_info, device_or_backend_set)
assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat)
- if config.dynamic_shapes.value:
- implicit_args = _extract_implicit_args(
- cast(core.InputType, in_type), explicit_args)
- else:
- implicit_args = []
- args_flat = [*implicit_args, *explicit_args]
+ args_flat = explicit_args
- num_extra_args = len(implicit_args) + len(consts)
+ num_extra_args = len(consts)
in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
donated_invars = (False,) * num_extra_args + donated_invars
assert (len(in_shardings_flat) == len(in_layouts_flat) ==
len(donated_invars) == len(consts) + len(args_flat))
+ out_tree, result_paths = out_tree_and_result_paths()
+ result_paths = tuple(f"result{lu._clean_keystr_arg_names(path)}"
+ for path in result_paths)
+ jaxpr.jaxpr._debug_info = jaxpr.debug_info._replace(result_paths=result_paths)
+
params = dict(
jaxpr=jaxpr,
in_shardings=in_shardings_flat,
@@ -589,8 +574,9 @@ def _infer_params_impl(
inline=ji.inline,
compiler_options_kvs=ji.compiler_options_kvs,
)
+
return (PjitParams(consts, params, in_avals,
- in_tree, out_tree(), dbg.safe_arg_names(len(in_avals))),
+ in_tree, out_tree, dbg.safe_arg_names(len(in_avals))),
args_flat)
@@ -639,11 +625,6 @@ def _infer_params_internal(
static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)
- if config.dynamic_shapes.value: # don't use the cache
- p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg_fn(),
- args, kwargs, in_avals=None)
- return p, p.consts + args_flat
-
signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
ji.static_argnames, tree_util.default_registry)
@@ -687,45 +668,6 @@ def _infer_input_type(fun: Callable, dbg_fn: Callable[[], core.DebugInfo],
check_no_aliased_ref_args(dbg_fn, avals, explicit_args)
return tuple(avals)
-def _extract_implicit_args(
- in_type: Sequence[tuple[core.AbstractValue, bool]],
- explicit_args: Sequence[Any]
-) -> Sequence[core.Tracer]:
- """
- Given an input type and explicitly-passed arguments (per the user-facing API
- calling convention), extract implicit axis size arguments from shapes of
- explicit arguments (for the trace-time / jaxpr-level calling convention).
- """
- # First, using `in_type` construct a list to represent the full argument list,
- # leaving the implicit arguments as None placeholders for now.
- explicit_args_ = iter(explicit_args)
- args = [next(explicit_args_) if expl else None for _, expl in in_type]
- assert next(explicit_args_, None) is None
- del explicit_args, explicit_args_
-
- # Next, populate the implicit arguments using the DBIdxs in `in_type`.
- for i, (aval, explicit) in enumerate(in_type):
- if not explicit or not isinstance(aval, core.DShapedArray):
- continue # can't populate an implicit argument
- arg = args[i]
- assert arg is not None
- for d1, d2 in zip(aval.shape, arg.aval.shape):
- if isinstance(d1, core.DBIdx):
- if args[d1.val] is None:
- args[d1.val] = d2
- assert core.same_referent(args[d1.val], d2)
- assert all(x is not None for x in args)
- return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore
-
-def _flat_axes_specs(abstracted_axes, *args, **kwargs
- ) -> list[pe.AbstractedAxesSpec] | None:
- if abstracted_axes is None: return None
- if kwargs: raise NotImplementedError
- def ax_leaf(l):
- return (isinstance(l, dict) and all_leaves(l.values()) or
- isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
- return broadcast_prefix(abstracted_axes, args, ax_leaf)
-
class JitWrapped(stages.Wrapped):
@@ -752,7 +694,6 @@ def pjit(
device: xc.Device | None = None,
backend: str | None = None,
inline: bool = False,
- abstracted_axes: Any | None = None,
compiler_options: dict[str, Any] | None = None,
) -> JitWrapped:
"""`jax.experimental.pjit.pjit` has been deprecated. Please use `jax.jit`."""
@@ -761,8 +702,7 @@ def pjit(
static_argnums=static_argnums, static_argnames=static_argnames,
donate_argnums=donate_argnums, donate_argnames=donate_argnames,
keep_unused=keep_unused, device=device, backend=backend, inline=inline,
- abstracted_axes=abstracted_axes, compiler_options=compiler_options,
- use_resource_env=True)
+ compiler_options=compiler_options, use_resource_env=True)
def hashable_pytree(pytree):
@@ -884,13 +824,12 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
in_layouts_flat = flatten_axis_resources(
"pjit in_layouts", in_tree, in_layouts, tupled_args=True)
- if not config.dynamic_shapes.value:
- pjit_check_aval_sharding(in_shardings_flat, in_avals,
- debug_info.safe_arg_names(len(in_avals)),
- "pjit arguments", allow_uneven_sharding=False)
- check_aval_layout_compatibility(
- in_layouts_flat, in_avals,
- debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type]
+ pjit_check_aval_sharding(in_shardings_flat, in_avals,
+ debug_info.safe_arg_names(len(in_avals)),
+ "pjit arguments", allow_uneven_sharding=False)
+ check_aval_layout_compatibility(
+ in_layouts_flat, in_avals,
+ debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type]
return in_shardings_flat, in_layouts_flat
callsites_with_tracing_cache_miss: set[str] = set()
@@ -1051,7 +990,7 @@ def arg_type_to_str(at):
if t[0] != ot[0]:
unavailable(f"fun_transforms[{i}] transform", t, ot)
continue
- if t_name == "flatten_fun":
+ if t_name == "flatten_fun3":
explain_in_tree_diff(t[1][0], ot[1][0])
continue
if t_name == "_argnums_partial":
@@ -1176,12 +1115,7 @@ def _create_pjit_jaxpr(
with dispatch.log_elapsed_time(
"Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
- if config.dynamic_shapes.value:
- assert isinstance(in_type, core.InputType)
- jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
- lu.annotate(fun, in_type))
- else:
- jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
+ jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
if config.debug_key_reuse.value:
# Import here to avoid circular imports
@@ -1220,15 +1154,14 @@ def _check_and_canonicalize_out_shardings(
out_layouts_flat = flatten_axis_resources(
"pjit out_layouts", out_tree(), out_layouts, tupled_args=False)
- if not config.dynamic_shapes.value:
- pjit_check_aval_sharding(
- out_shardings_flat, out_avals,
- debug_info.safe_result_paths(len(out_avals)),
- "pjit outputs", allow_uneven_sharding=False)
- check_aval_layout_compatibility(
- out_layouts_flat, out_avals,
- debug_info.safe_result_paths(len(out_avals)),
- "jit outputs")
+ pjit_check_aval_sharding(
+ out_shardings_flat, out_avals,
+ debug_info.safe_result_paths(len(out_avals)),
+ "pjit outputs", allow_uneven_sharding=False)
+ check_aval_layout_compatibility(
+ out_layouts_flat, out_avals,
+ debug_info.safe_result_paths(len(out_avals)),
+ "jit outputs")
return out_shardings_flat, out_layouts_flat
_seen_qdds = weakref.WeakKeyDictionary() # type: ignore
@@ -1273,30 +1206,25 @@ def pjit_check_aval_sharding(
name_str = f' with pytree key path {name}' if name else ''
shape = aval.shape
try:
- # Sharding interfaces can implement `check_compatible_aval` as an optional
- # method to raise a more meaningful error.
- if hasattr(s, 'check_compatible_aval'):
- s.check_compatible_aval(shape)
- else:
- s._to_xla_hlo_sharding(len(shape))
+ s.check_compatible_aval(shape)
except ValueError as e:
raise ValueError(
f'One of {what_aval}{name_str} is incompatible with its sharding '
f'annotation {s}: {e}')
- # Use the `OpSharding` proto to find out how many ways each dimension of
- # the aval is sharded. This approach will work across all
- # Sharding.
- hlo_sharding = s._to_xla_hlo_sharding(len(shape))
- assert hlo_sharding is not None
- num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(
- hlo_sharding, allow_partial_manual)
- for i, size in enumerate(num_ways_dim_sharded):
- if not allow_uneven_sharding and shape[i] % size != 0:
- raise ValueError(f"One of {what_aval}{name_str} was given the sharding "
- f"of {s}, which implies that "
- f"the global size of its dimension {i} should be "
- f"divisible by {size}, but it is equal to {shape[i]} "
- f"(full shape: {shape})")
+
+ if not allow_uneven_sharding:
+ hlo_sharding = s._to_xla_hlo_sharding(len(shape))
+ assert hlo_sharding is not None
+ num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded(
+ hlo_sharding, allow_partial_manual)
+ for i, size in enumerate(num_ways_dim_sharded):
+ if shape[i] % size != 0:
+ raise ValueError(
+ f'One of {what_aval}{name_str} was given the sharding '
+ f'of {s}, which implies that '
+ f'the global size of its dimension {i} should be '
+ f'divisible by {size}, but it is equal to {shape[i]} '
+ f'(full shape: {shape})')
def check_aval_layout_compatibility(
@@ -1672,7 +1600,7 @@ def call_impl_cache_miss(*args_, **kwargs_):
inline=inline, compiler_options_kvs=compiler_options_kvs)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat,
- jaxpr.effects, jaxpr.consts, None, pgle_profiler,
+ jaxpr.effects, jaxpr.consts, pgle_profiler,
const_args)
return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
@@ -1738,36 +1666,12 @@ def pjit_staging_rule(trace, source_info, *args, **params):
all(i is None for i in params["in_layouts"]) and
all(o is None for o in params["out_layouts"])):
jaxpr = params["jaxpr"]
- if config.dynamic_shapes.value:
- # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic
- # shapes are enabled, use eval_jaxpr, which uses the tracing machinery,
- # but redundantly performs abstract evaluation again.
- with core.set_current_trace(trace):
- out = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args,
- propagate_source_info=False)
- else:
- out = pe.inline_jaxpr_into_trace(
- trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args)
+ out = pe.inline_jaxpr_into_trace(
+ trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args)
return [trace.to_jaxpr_tracer(x, source_info) for x in out]
jaxpr = params['jaxpr']
- if config.dynamic_shapes.value:
- jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
- jaxpr, params['out_shardings'], params['out_layouts'])
- params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
- out_layouts=out_layouts)
- outvars = map(trace.frame.newvar, _out_type(jaxpr))
- eqn = core.new_jaxpr_eqn(
- [arg.var for arg in args], outvars, jit_p, params,
- jaxpr.effects, source_info)
- trace.frame.add_eqn(eqn)
- out_tracers = [pe.DynamicJaxprTracer(trace, v.aval, v, source_info)
- for v in outvars]
- out_tracers_ = iter(out_tracers)
- out_tracers = [args[f] if type(f) is int else next(out_tracers_)
- for f in in_fwd]
- assert next(out_tracers_, None) is None
- elif any(isinstance(c, core.Ref) for c in jaxpr.consts):
+ if any(isinstance(c, core.Ref) for c in jaxpr.consts):
jaxpr, consts = pxla._move_mutable_consts(jaxpr)
consts = [trace.new_const(c, source_info) for c in consts]
in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
@@ -1795,15 +1699,7 @@ def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
return jaxpr, in_fwd, out_shardings, out_layouts
def pjit_forwarding_rule(eqn):
- if not config.dynamic_shapes.value:
- return [None] * len(eqn.outvars), eqn
- jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding(
- eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts'])
- new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None]
- new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=out_shardings,
- out_layouts=out_layouts)
- new_eqn = eqn.replace(params=new_params, outvars=new_outvars)
- return in_fwd, new_eqn
+ return [None] * len(eqn.outvars), eqn
# TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule.
pe.forwarding_rules[jit_p] = pjit_forwarding_rule
@@ -1817,11 +1713,6 @@ def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]:
if type(x) is core.Var}
for x in jaxpr.jaxpr.outvars:
aval = x.aval
- if type(aval) is core.DShapedArray:
- shape = [core.InDBIdx(in_idx[d]) if d in in_idx else
- core.OutDBIdx(out_idx[d]) if d in out_idx else
- d for d in x.aval.shape]
- aval = aval.update(shape=tuple(shape))
out.append(aval)
return out
@@ -1949,10 +1840,7 @@ def _pjit_batcher(axis_data, vals_in,
in_shardings, out_shardings, in_layouts, out_layouts,
donated_invars, ctx_mesh, name, keep_unused, inline,
compiler_options_kvs):
- segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in)
new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)
-
- # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs
in_shardings = tuple(
_pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh,
aval.ndim)
@@ -1983,16 +1871,13 @@ def _pjit_batcher(axis_data, vals_in,
inline=inline,
compiler_options_kvs=compiler_options_kvs)
- resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
- vals_in, vals_out, axes_out)
- return vals_out, resolved_axes_out
+ return vals_out, axes_out
batching.fancy_primitive_batchers[jit_p] = _pjit_batcher
-batching.ragged_prop_rules[jit_p] = batching.ragged_mask_no_op_rule
def _pjit_batcher_for_sharding(
- s, dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None,
+ s, dim: int, spmd_axis_name: tuple[str, ...] | None,
mesh, ndim: int):
if isinstance(s, UnspecifiedValue):
return s
@@ -2338,81 +2223,6 @@ def _pjit_partial_eval_custom_params_updater(
_pjit_partial_eval_custom_params_updater)
-@lu.cache
-def _pjit_transpose_trace(fun: lu.WrappedFun,
- in_avals: Sequence[core.AbstractValue]):
- transpose_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
- transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts)
- return transpose_jaxpr
-
-
-def _pjit_transpose(cts_in, *primals_in,
- jaxpr: core.ClosedJaxpr,
- in_shardings, out_shardings, in_layouts, out_layouts,
- donated_invars, ctx_mesh, name, keep_unused, inline,
- compiler_options_kvs):
- def prune_type(ty, xs, maybe_zeros):
- return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty)
-
- dbg = jaxpr.jaxpr.debug_info.with_unknown_names()
- body = lu.wrap_init(ad.closed_backward_pass, debug_info=dbg)
- body = lu.hashable_partial(body, jaxpr, False)
- primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in))
- body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef)
-
- transpose_in_shardings = (
- *prune_type(ad.UndefinedPrimal, in_shardings, primals_in),
- *prune_type(ad.Zero, out_shardings, cts_in)
- )
- transpose_in_layouts = (
- *prune_type(ad.UndefinedPrimal, in_layouts, primals_in),
- *prune_type(ad.Zero, out_layouts, cts_in)
- )
- global_cts_in_avals = tuple(
- core.AvalQDD(a, cur_qdd(x)) if (a := typeof(x)).has_qdd else a
- for x in primals_and_nz_cts_in)
-
- transpose_jaxpr = _pjit_transpose_trace(body, global_cts_in_avals)
- cts_out_treedef = cts_out_treedef_thunk()
- transpose_out_shardings = prune_type(
- ad.Zero,
- in_shardings,
- tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
- transpose_out_layouts = prune_type(
- ad.Zero,
- in_layouts,
- tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves))
-
- try:
- nz_cts_out = jit_p.bind(
- *primals_and_nz_cts_in,
- jaxpr=transpose_jaxpr,
- in_shardings=transpose_in_shardings,
- out_shardings=transpose_out_shardings,
- in_layouts=transpose_in_layouts,
- out_layouts=transpose_out_layouts,
- donated_invars=(False,) * len(primals_and_nz_cts_in),
- ctx_mesh=ctx_mesh,
- name=name,
- keep_unused=keep_unused,
- inline=inline,
- compiler_options_kvs=compiler_options_kvs)
- except api_util.InternalFloatingPointError as e:
- print("Invalid nan value encountered in the backward pass of a jax.jit "
- "function. Calling the de-optimized backward pass.")
- try:
- _ = ad.closed_backward_pass(jaxpr, None, primals_in, cts_in)
- except (FloatingPointError, ZeroDivisionError) as e2:
- raise e2 from None # great
- else:
- # If control reaches this line, we got a NaN on the output of `compiled`
- # but not `fun.call_wrapped` on the same arguments. Let's tell the user.
- api_util._raise_no_nan_in_deoptimized(e)
-
- return tree_unflatten(cts_out_treedef, nz_cts_out)
-ad.primitive_transposes[jit_p] = _pjit_transpose
-
-
def _pjit_transpose_fancy(
cts_in, *args, jaxpr, in_shardings, out_shardings, in_layouts,
out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline,
@@ -2809,40 +2619,51 @@ def reshard(xs, out_shardings):
f'and have a nonempty mesh. Got sharding {s}.'
)
ds = ds.update(spec=ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error
- out_flat.append(reshard_p.bind(x, dst_sharding=ds))
+ cmesh = (s.mesh if (isinstance(s, NamedSharding) and
+ isinstance(s.mesh, mesh_lib.Mesh))
+ else None)
+ out_flat.append(reshard_p.bind(x, dst_sharding=ds, concrete_mesh=cmesh))
return tree_unflatten(treedef, out_flat)
reshard_p = core.Primitive('reshard')
reshard_p.skip_canonicalization = True
-def _reshard_abstract_eval(aval, dst_sharding):
+def _reshard_abstract_eval(aval, *, dst_sharding, concrete_mesh):
assert isinstance(aval, core.ShapedArray)
if aval.sharding == dst_sharding:
return aval
return aval.update(sharding=dst_sharding)
reshard_p.def_abstract_eval(_reshard_abstract_eval)
-def _reshard_impl(x, dst_sharding):
- return dispatch.apply_primitive(reshard_p, x, dst_sharding=dst_sharding)
+def _reshard_impl(x, *, dst_sharding, concrete_mesh):
+ thunk = lambda: dispatch.apply_primitive(
+ reshard_p, x, dst_sharding=dst_sharding, concrete_mesh=concrete_mesh)
+ if concrete_mesh is None:
+ return thunk()
+ else:
+ with sharding_impls.set_mesh(concrete_mesh):
+ return thunk()
reshard_p.def_impl(_reshard_impl)
-def _reshard_transpose_rule(ct, x, dst_sharding):
+def _reshard_transpose_rule(ct, x, *, dst_sharding, concrete_mesh):
assert ad.is_undefined_primal(x)
out_sharding = x.aval.to_cotangent_aval().sharding
with mesh_lib.use_abstract_mesh(out_sharding.mesh):
- x_bar = reshard_p.bind(ct, dst_sharding=out_sharding)
+ x_bar = reshard_p.bind(ct, dst_sharding=out_sharding,
+ concrete_mesh=concrete_mesh)
return [x_bar]
ad.deflinear2(reshard_p, _reshard_transpose_rule)
-def _reshard_transpose_fancy(ct, x, dst_sharding):
+def _reshard_transpose_fancy(ct, x, *, dst_sharding, concrete_mesh):
assert isinstance(x, ad.GradAccum)
out_sharding = x.aval.to_cotangent_aval().sharding
with mesh_lib.use_abstract_mesh(out_sharding.mesh):
- x_bar = reshard_p.bind(ct, dst_sharding=out_sharding)
+ x_bar = reshard_p.bind(ct, dst_sharding=out_sharding,
+ concrete_mesh=concrete_mesh)
x.accum(x_bar)
ad.fancy_transposes[reshard_p] = _reshard_transpose_fancy
-def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding):
+def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding, concrete_mesh):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
@@ -2853,12 +2674,13 @@ def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding):
return [mlir.lower_with_sharding_in_types(ctx, x_node, aval_out, proto)]
mlir.register_lowering(reshard_p, _reshard_hlo_lowering)
-def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding):
+def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding, concrete_mesh):
x, = vals_in
d, = dims_in
vmapped_dst_sharding = batching.get_sharding_for_vmap(
axis_data, dst_sharding, d)
- y = reshard_p.bind(x, dst_sharding=vmapped_dst_sharding)
+ y = reshard_p.bind(x, dst_sharding=vmapped_dst_sharding,
+ concrete_mesh=concrete_mesh)
return y, d
batching.fancy_primitive_batchers[reshard_p] = _reshard_batcher
batching.skippable_batchers[reshard_p] = lambda _: ()
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 4cbaa1719ff7..7feb345ccb2e 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -42,6 +42,7 @@
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax
from jax._src.lax import slicing as lax_slicing
+from jax._src.lib import jaxlib_extension_version
from jax._src.lib import gpu_prng
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
@@ -402,10 +403,16 @@ def local_sharded_result_handler(aval, sharding, indices):
phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices)
# set up a handler that calls the physical one and wraps back up
- def handler(bufs):
- return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
+ if jaxlib_extension_version >= 390:
+ def handler(arr):
+ return PRNGKeyArray(aval.dtype._impl, arr)
- return handler
+ return phys_handler.wrap(handler)
+ else:
+ def handler(bufs):
+ return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
+
+ return handler
@staticmethod
def global_sharded_result_handler(aval, out_sharding, committed):
@@ -414,9 +421,14 @@ def global_sharded_result_handler(aval, out_sharding, committed):
phys_sharding = physical_sharding(aval, out_sharding)
phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed)
- def handler(bufs):
- return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
- return handler
+ if jaxlib_extension_version >= 390:
+ def handler(bufs):
+ return PRNGKeyArray(aval.dtype._impl, bufs)
+ return phys_handler.wrap(handler)
+ else:
+ def handler(bufs):
+ return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs))
+ return handler
@staticmethod
def make_sharded_array(aval, sharding, arrays, committed):
diff --git a/jax/_src/random.py b/jax/_src/random.py
index bdd55ea2f013..5da171c269af 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -565,21 +565,17 @@ def randint_via_uniform(key, shape, minval, maxval, dtype):
if not dtypes.issubdtype(dtype, np.integer):
raise TypeError(f"randint only accepts integer dtypes, got {dtype}")
- # TODO(jakevdp): migrate users to safer randint and remove the old version.
- if config.safer_randint.value:
- info = dtypes.iinfo(dtype)
- dtype_for_sampling = dtype
- if info.bits < 32:
- # Sample in 32 bits to avoid biased results.
- dtype_for_sampling = np.dtype('int32')
- minval = jnp.asarray(minval).astype('int32').clip(int(info.min), int(info.max))
- maxval = jnp.asarray(maxval).astype('int32').clip(int(info.min), int(info.max) + 1)
-
- return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype_for_sampling)(
- key, minval, maxval).astype(dtype)
-
- return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype)(
- key, minval, maxval)
+ info = dtypes.iinfo(dtype)
+ dtype_for_sampling = dtype
+ if info.bits < 32:
+ # Sample in 32 bits to avoid biased results.
+ dtype_for_sampling = np.dtype('int32')
+ minval = jnp.asarray(minval).astype('int32').clip(int(info.min), int(info.max))
+ maxval = jnp.asarray(maxval).astype('int32').clip(int(info.min), int(info.max) + 1)
+
+ return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype_for_sampling)(
+ key, minval, maxval).astype(dtype)
+
@jit(static_argnums=(3, 4))
def _randint(key, minval, maxval, shape, dtype) -> Array:
diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py
index 9c6df2737dae..3f4767f101a7 100644
--- a/jax/_src/scipy/optimize/_lbfgs.py
+++ b/jax/_src/scipy/optimize/_lbfgs.py
@@ -22,6 +22,7 @@
import numpy as np
from jax._src import api
+from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy import linalg as jnp_linalg
@@ -112,7 +113,7 @@ def _minimize_lbfgs(
Optimization results.
"""
d = len(x0)
- dtype = np.dtype(x0)
+ dtype = dtypes.dtype(x0)
# ensure there is at least one termination condition
if (maxiter is None) and (maxfun is None) and (maxgrad is None):
diff --git a/jax/_src/scipy/optimize/line_search.py b/jax/_src/scipy/optimize/line_search.py
index 9bffd95328fb..6d16b67f1c66 100644
--- a/jax/_src/scipy/optimize/line_search.py
+++ b/jax/_src/scipy/optimize/line_search.py
@@ -191,6 +191,16 @@ def body(state):
),
),
)
+ state = state._replace(
+ **_binary_replace(
+ lo_to_j & ~hi_to_lo,
+ state._asdict(),
+ dict(
+ a_rec=state.a_lo,
+ phi_rec=state.phi_lo,
+ ),
+ ),
+ )
state = state._replace(
**_binary_replace(
lo_to_j,
@@ -199,8 +209,6 @@ def body(state):
a_lo=a_j,
phi_lo=phi_j,
dphi_lo=dphi_j,
- a_rec=state.a_lo,
- phi_rec=state.phi_lo,
),
),
)
diff --git a/jax/_src/scipy/stats/poisson.py b/jax/_src/scipy/stats/poisson.py
index bf314842d3ff..bb2f9399dfdc 100644
--- a/jax/_src/scipy/stats/poisson.py
+++ b/jax/_src/scipy/stats/poisson.py
@@ -17,8 +17,8 @@
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
-from jax._src.numpy.util import promote_args_inexact
-from jax._src.scipy.special import xlogy, gammaln, gammaincc
+from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact, ensure_arraylike
+from jax._src.scipy.special import xlogy, entr, gammaln, gammaincc
from jax._src.typing import Array, ArrayLike
@@ -114,3 +114,126 @@ def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
x = lax.sub(k, loc)
p = gammaincc(jnp.floor(1 + x), mu)
return jnp.where(lax.lt(x, zero), zero, p)
+
+def entropy(mu: ArrayLike, loc: ArrayLike = 0) -> Array:
+ r"""Shannon entropy of the Poisson distribution.
+
+ JAX implementation of :obj:`scipy.stats.poisson` ``entropy``.
+
+ The entropy :math:`H(X)` of a Poisson random variable
+ :math:`X \sim \text{Poisson}(\mu)` is defined as:
+
+ .. math::
+
+ H(X) = -\sum_{k=0}^\infty p(k) \log p(k)
+
+ where :math:`p(k) = e^{-\mu} \mu^k / k!` for
+ :math:`k \geq \max(0, \lfloor \text{loc} \rfloor)`.
+
+ This implementation uses **regime switching** for numerical stability
+ and performance:
+
+ - **Small** :math:`\mu < 10`: Direct summation over PMF with adaptive
+ upper bound :math:`k \leq \mu + 20`
+ - **Medium** :math:`10 \leq \mu < 100`: Summation with bound
+ :math:`k \leq \mu + 10\sqrt{\mu} + 20`
+ - **Large** :math:`\mu \geq 100`: Asymptotic Stirling approximation:
+ :math:`H(\mu) \approx \frac{1}{2} \log(2\pi e \mu) - \frac{1}{12\mu}`
+
+ Matches SciPy to relative error :math:`< 10^{-5}` across all regimes.
+
+ Args:
+ mu: arraylike, mean parameter of the Poisson distribution.
+ Must be ``> 0``.
+ loc: arraylike, optional location parameter (default: 0).
+ Accepted for API compatibility with scipy but does not
+ affect the entropy.
+
+ Returns:
+ Array of entropy values with shape broadcast from ``mu`` and ``loc``.
+ Returns ``0`` for ``mu == 0`` and ``NaN`` for ``mu < 0``.
+
+ Examples:
+    >>> import jax.numpy as jnp
+    >>> from jax.scipy.stats import poisson
+    >>> poisson.entropy(5.0)
+    Array(2.204394, dtype=float32)
+    >>> poisson.entropy(jnp.array([1, 10, 100]))
+ Array([1.3048419, 2.5614073, 3.7206903], dtype=float32)
+
+ See Also:
+ - :func:`jax.scipy.stats.poisson.pmf`
+ - :func:`jax.scipy.stats.poisson.logpmf`
+ - :obj:`scipy.stats.poisson`
+ """
+ mu, loc = ensure_arraylike("poisson.entropy", mu, loc)
+ promoted_mu, promoted_loc = promote_dtypes_inexact(mu, loc)
+
+  # Note: loc does not affect the entropy (entropy is translation
+  # invariant); it is accepted only for compatibility with the SciPy API.
+ result_shape = jnp.broadcast_shapes(
+ promoted_mu.shape,
+ promoted_loc.shape
+ )
+
+ mu_flat = jnp.ravel(promoted_mu)
+ zero_result = jnp.zeros_like(mu_flat)
+
+ # Choose the computation regime based on mu value
+ result = jnp.where(
+ mu_flat == 0,
+ zero_result,
+ jnp.where(
+ mu_flat < 10,
+ _entropy_small_mu(mu_flat),
+ jnp.where(
+ mu_flat < 100,
+ _entropy_medium_mu(mu_flat),
+ _entropy_large_mu(mu_flat)
+ )
+ )
+ )
+
+ result_mu_shape = jnp.reshape(result, promoted_mu.shape)
+
+ # Restore original shape
+ return jnp.broadcast_to(result_mu_shape, result_shape)
+
+def _entropy_small_mu(mu: Array) -> Array:
+ """Entropy via direct PMF summation for small μ (< 10).
+ Uses adaptive upper bound k ≤ μ + 20 to capture >99.999% of mass.
+ """
+ max_k = 35
+
+ k = jnp.arange(max_k, dtype=mu.dtype)[:, None]
+ probs = pmf(k, mu, 0)
+
+ # Mask: only compute up to mu + 20 for each value
+ upper_bounds = jnp.ceil(mu + 20).astype(k.dtype)
+ mask = k < upper_bounds[None, :]
+ probs_masked = jnp.where(mask, probs, 0.0)
+
+ return jnp.sum(entr(probs_masked), axis=0)
+
+def _entropy_medium_mu(mu: Array) -> Array:
+ """Entropy for medium mu (10-100): Adaptive bounds based on std dev.
+
+ Bounds: k ≤ μ + 10√μ + 20. Caps at k=250 for JIT compatibility.
+ """
+ max_k = 250 # Static bound for JIT. For mu<100, upper bound < 220
+
+ k = jnp.arange(max_k, dtype=mu.dtype)[:, None]
+ probs = pmf(k, mu, 0)
+
+ upper_bounds = jnp.ceil(mu + 10 * jnp.sqrt(mu) + 20).astype(k.dtype)
+ mask = k < upper_bounds[None, :]
+ probs_masked = jnp.where(mask, probs, 0.0)
+
+ return jnp.sum(entr(probs_masked), axis=0)
+
+def _entropy_large_mu(mu: Array) -> Array:
+ """Entropy for large mu (>= 100): Asymptotic approximation.
+
+ Formula: H(λ) ≈ 0.5*log(2πeλ) - 1/(12λ) + O(λ^-2)
+ """
+ return 0.5 * jnp.log(2 * np.pi * np.e * mu) - 1.0 / (12 * mu)
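
As a quick numerical sanity check of the large-mu branch (illustrative only, done in plain NumPy/SciPy and assuming SciPy is available; not part of the patch), the Stirling formula above agrees closely with direct summation at mu = 100:

import numpy as np
from scipy.special import gammaln

mu = 100.0
k = np.arange(1000)
log_pmf = -mu + k * np.log(mu) - gammaln(k + 1)
pmf = np.exp(log_pmf)
direct = -np.sum(pmf * log_pmf)  # direct summation over the PMF
stirling = 0.5 * np.log(2 * np.pi * np.e * mu) - 1.0 / (12 * mu)
print(direct, stirling)  # both print roughly 3.7207
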
diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py
index dbcfadb3fc95..d77c1299fcf0 100644
--- a/jax/_src/shard_map.py
+++ b/jax/_src/shard_map.py
@@ -40,6 +40,7 @@
from jax._src.mesh import (AbstractMesh, Mesh, BaseMesh, AxisType,
use_abstract_mesh, get_abstract_mesh,
get_concrete_mesh)
+from jax._src.pjit import reshard
from jax._src.lax import lax, parallel as lax_parallel
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, sdy
@@ -248,6 +249,15 @@ def wrapped(*args):
fun = _implicit_pvary_on_output(fun, out_specs_thunk)
fun = _implicit_unreduced_on_output(fun, out_specs_thunk)
+ # TODO(yashkatariya): Add support for partial manual
+ mesh_axis_names_wo_vmap = (
+ frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names)
+ if (mesh_axis_names_wo_vmap == axis_names and
+ all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
+ args_flat = [a if typeof(a).sharding.spec == s
+ else reshard(a, NamedSharding(mesh, s))
+ for a, s in zip(args_flat, in_specs_flat)]
+
try:
out_flat = shard_map_p.bind(
fun, *args_flat, mesh=mesh, in_specs=in_specs_flat,
@@ -1381,8 +1391,6 @@ def _shard_map_batch(
in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset
) -> Sequence[batching.BatchTracer]:
in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers))
- if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
- raise NotImplementedError
spmd_axis_name = trace.axis_data.spmd_name
explicit_mesh_axis = trace.axis_data.explicit_mesh_axis
if spmd_axis_name is not None:
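
For context, a hedged sketch of the case the new branch targets: when every mesh axis used by the shard_map is Explicit, an argument whose sharding spec does not match its entry in `in_specs` is now resharded automatically before binding. The snippet below is illustrative only; it assumes a recent public `jax.shard_map`/`jax.set_mesh` API and a device count that divides the array size.

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P

mesh = jax.make_mesh((len(jax.devices()),), ("i",),
                     axis_types=(AxisType.Explicit,))
jax.set_mesh(mesh)

def double(block):
  return block * 2

x = jnp.arange(8.)  # replicated by default under the explicit mesh
# With the change above, x is resharded to match in_specs=P("i") instead of
# requiring the caller to reshard it by hand.
y = jax.shard_map(double, mesh=mesh, in_specs=P("i"), out_specs=P("i"))(x)
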
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 3f039d862672..48247bcce753 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -20,7 +20,6 @@
import dataclasses
import functools
import math
-import warnings
from typing import Any, NamedTuple, cast
from jax._src import config
@@ -33,7 +32,6 @@
from jax._src import source_info_util
from jax._src import xla_bridge as xb
from jax._src import mesh_utils
-from jax._src import deprecations
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src.named_sharding import ( # noqa: F401
@@ -181,6 +179,9 @@ def is_fully_replicated(self) -> bool:
def is_fully_addressable(self) -> bool:
return xb.process_index(self._device.client) == self._device.process_index
+ def check_compatible_aval(self, aval_shape: Shape) -> None:
+ return
+
SingleDeviceSharding.__module__ = 'jax.sharding'
@util.cache(max_size=4096, trace_context_in_key=False)
@@ -193,7 +194,6 @@ def pmap_sharding_devices_indices_map(
@use_cpp_class(xc.PmapSharding)
class PmapSharding(jsharding.Sharding):
- """Describes a sharding used by :func:`jax.pmap`."""
devices: np.ndarray
sharding_spec: sharding_specs.ShardingSpec
_internal_device_list: xc.DeviceList
@@ -324,6 +324,9 @@ def is_fully_replicated(self) -> bool:
def is_fully_addressable(self) -> bool:
return self._internal_device_list.is_fully_addressable
+ def check_compatible_aval(self, aval_shape: Shape) -> None:
+ return
+
def shard_shape(self, global_shape: Shape) -> Shape:
sharded_dim = None
sharded_dim_size = None
@@ -472,7 +475,6 @@ def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
def prepare_axis_resources(axis_resources, arg_name,
allow_unconstrained_dims=False):
- # PyTrees don't treat None values as leaves, so we use an is_leaf function.
entries, treedef = tree_util.tree_flatten(
axis_resources, is_leaf=lambda x: x is None)
what = f"{arg_name} leaf specifications"
@@ -485,6 +487,9 @@ def prepare_axis_resources(axis_resources, arg_name,
if isinstance(entry, PmapSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not '
'allowed.')
+ if isinstance(entry, NamedSharding) and entry.mesh.empty:
+ raise ValueError(f'One of {what} got an empty NamedSharding: {entry} '
+ 'which is not allowed.')
if (not allow_unconstrained_dims and isinstance(entry, NamedSharding) and
PartitionSpec.UNCONSTRAINED in entry.spec):
raise ValueError(
@@ -1188,18 +1193,7 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
'`jax.make_mesh` does not support multi-slice topologies. Please use'
' jax.experimental.mesh_utils.create_hybrid_device_mesh')
if axis_types is None:
- if deprecations.is_accelerated('jax-make-mesh-default-explicit'):
- axis_types = (mesh_lib.AxisType.Explicit,) * len(mesh_devices.shape)
- else:
- warnings.warn(
- 'The default axis_types will change in JAX v0.9.0 to'
- ' jax.sharding.AxisType.Explicit. To maintain the old behavior, pass'
- ' `axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names)`. To'
- ' opt-into the new behavior, pass'
- ' `axis_types=(jax.sharding.AxisType.Explicit,) * len(axis_names)',
- category=DeprecationWarning,
- stacklevel=2,
- )
+ axis_types = (mesh_lib.AxisType.Explicit,) * len(mesh_devices.shape)
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
class set_mesh:
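
With the deprecation period over, `jax.make_mesh` now defaults to Explicit axis types. A small usage sketch (illustrative; assumes a JAX build with `jax.sharding.AxisType`) showing how callers can still opt back into Auto axes:

import jax
from jax.sharding import AxisType

mesh = jax.make_mesh((len(jax.devices()),), ("x",))  # axes are now Explicit
auto_mesh = jax.make_mesh((len(jax.devices()),), ("x",),
                          axis_types=(AxisType.Auto,))  # previous default
print(mesh.axis_types, auto_mesh.axis_types)
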
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index 779744b4ca99..387fca3ee460 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -378,7 +378,7 @@ def _traced_args_info(self):
def _traced_out_info(self):
out_shardings = [None if isinstance(s, UnspecifiedValue) else s
- for s in self._params['out_shardings']]
+ for s in self._params['out_shardings']]
out_layouts = [None if isinstance(l, AutoLayout) else l
for l in self._params['out_layouts']]
out = []
@@ -402,8 +402,12 @@ class Traced(Stage):
A traced computation is ready for lowering. This class carries the
traced representation with the remaining information needed to later
lower, compile, and execute it.
+
+ Provides access to both the hijax (high-level) and lojax (low-level)
+ representations via `.jaxpr` and `.lojax` properties respectively.
"""
- __slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree', '_consts']
+ __slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree', '_consts',
+ '_lojax']
def __init__(self, meta_tys_flat, params, in_tree, out_tree, consts):
self._meta_tys_flat = meta_tys_flat
@@ -411,6 +415,7 @@ def __init__(self, meta_tys_flat, params, in_tree, out_tree, consts):
self._in_tree = in_tree
self.out_tree = out_tree
self._consts = consts
+ self._lojax = None
jaxpr = property(lambda self: self._params['jaxpr'])
fun_name = property(lambda self: self._params['name'])
@@ -422,12 +427,18 @@ def __init__(self, meta_tys_flat, params, in_tree, out_tree, consts):
def out_avals(self):
return tree_unflatten(self.out_tree, self.jaxpr.out_avals)
- def fall(self):
+ @property
+ def lojax(self) -> LoJax:
+ if self._lojax is not None:
+ return self._lojax
+
if not self.jaxpr.is_high:
- return Fallen(self._meta_tys_flat, self._params, self._in_tree,
- self.out_tree, (self._in_tree, self.jaxpr.in_avals),
- (self.out_tree, self.jaxpr.out_avals),
- self._consts)
+ self._lojax = LoJax(
+ self._meta_tys_flat, self._params, self._in_tree, self.out_tree,
+ (self._in_tree, self.jaxpr.in_avals),
+ (self.out_tree, self.jaxpr.out_avals),
+ self._consts)
+ return self._lojax
# TODO(mattjj): when pmap is deleted, merge with pjit.py BUILD rule
from jax._src.interpreters import partial_eval as pe # type:ignore
@@ -435,23 +446,45 @@ def fall(self):
_, closed_over_himutables = pe.convert_const_himutables(hi_jaxpr)
if closed_over_himutables: raise NotImplementedError # TODO(mattjj)
lo_jaxpr = pe.lower_jaxpr(hi_jaxpr)
- in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree)
- out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree)
+ if any(a.is_high for a in hi_jaxpr.final_aval_qdds):
+ in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree)
+ else:
+ in_tree = self._in_tree
+ if any(a.is_high for a in hi_jaxpr.out_avals):
+ out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree)
+ else:
+ out_tree = self.out_tree
params = dict(lojax_expand_params(hi_jaxpr, self._params), jaxpr=lo_jaxpr)
lo_meta_tys = [mty.replace(aval=lo_ty)
for mty, aq in zip(self._meta_tys_flat, hi_jaxpr.in_aval_qdds)
for lo_ty in (mty.aval.lo_ty_qdd(aq.qdd)
if mty.aval.has_qdd else mty.aval.lo_ty())]
- return Fallen(lo_meta_tys, params, in_tree, out_tree,
- (self._in_tree, hi_jaxpr.final_aval_qdds),
- (self.out_tree, hi_jaxpr.out_avals),
- self._consts)
+ self._lojax = LoJax(
+ lo_meta_tys, params, in_tree, out_tree,
+ (self._in_tree, hi_jaxpr.final_aval_qdds),
+ (self.out_tree, hi_jaxpr.out_avals),
+ self._consts)
+ return self._lojax
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
_private_parameters: mlir.LoweringParameters | None = None):
"""Lower to compiler input, returning a ``Lowered`` instance."""
- return self.fall().lower(lowering_platforms=lowering_platforms,
- _private_parameters=_private_parameters)
+ lo = self.lojax
+ if _private_parameters is None:
+ _private_parameters = mlir.LoweringParameters()
+ try:
+ from jax._src.pjit import _resolve_and_lower # type: ignore
+ lowering = _resolve_and_lower(
+ lo._meta_tys_flat, **lo._params, lowering_platforms=lowering_platforms,
+ lowering_parameters=_private_parameters, pgle_profiler=None)
+ except DeviceAssignmentMismatchError as e:
+ fails, = e.args
+ msg = _device_assignment_mismatch_error(
+ lo._params['name'], fails, lo._meta_tys_flat, 'jit',
+ lo.jaxpr.debug_info.safe_arg_names(len(lo.jaxpr.in_avals)))
+ raise ValueError(msg) from None
+ return Lowered(lowering, lo.args_info, lo.out_tree,
+ in_types=lo._in_types, out_types=lo._out_types)
def lojax_expand_params(jaxpr, params):
@@ -468,8 +501,7 @@ def lojax_pytree(hi_avals, tree):
return tree_structure(tree_unflatten(tree, lo_avals))
-class Fallen(Stage):
- """True leader of the Decepticons."""
+class LoJax:
__slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree',
'_consts', '_in_types', '_out_types']
@@ -489,28 +521,6 @@ def __init__(self, meta_tys_flat, params, in_tree, out_tree, in_types, out_types
out_info = property(_traced_out_info)
_num_consts = property(lambda self: len(self._consts))
- @property
- def out_avals(self):
- return tree_unflatten(self.out_tree, self.jaxpr.out_avals)
-
- def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
- _private_parameters: mlir.LoweringParameters | None = None):
- """Lower to compiler input, returning a ``Lowered`` instance."""
- if _private_parameters is None:
- _private_parameters = mlir.LoweringParameters()
- try:
- from jax._src.pjit import _resolve_and_lower # type: ignore
- lowering = _resolve_and_lower(
- self._meta_tys_flat, **self._params, lowering_platforms=lowering_platforms,
- lowering_parameters=_private_parameters, pgle_profiler=None)
- except DeviceAssignmentMismatchError as e:
- fails, = e.args
- msg = _device_assignment_mismatch_error(
- self._params['name'], fails, self._meta_tys_flat, 'jit',
- self.jaxpr.debug_info.safe_arg_names(len(self.jaxpr.in_avals)))
- raise ValueError(msg) from None
- return Lowered(lowering, self.args_info, self.out_tree,
- in_types=self._in_types, out_types=self._out_types)
class Lowered(Stage):
@@ -543,6 +553,20 @@ def __init__(self, lowering: Lowering, args_info,
self._in_types = in_types # type: ignore
self._out_types = out_types # type: ignore
+ @property
+ def in_avals(self):
+ in_avals_ = self._lowering.compile_args.get("global_in_avals", None)
+ if in_avals_ is None: # For old pmap code i.e. PmapComputation
+ return tree_util.tree_map(lambda x: x._aval, self.args_info)
+ kept_var_idx = self._lowering.compile_args["kept_var_idx"]
+ non_dce_avals = self._lowering.compile_args["all_args_info"].in_avals
+ if self.in_tree.num_leaves > len(in_avals_):
+ iter_in_avals = iter(in_avals_)
+ in_avals_ = [
+ next(iter_in_avals) if i in kept_var_idx
+ else a for i, a in zip(range(self.in_tree.num_leaves), non_dce_avals)]
+ return self.in_tree.unflatten(in_avals_)
+
@property
def out_info(self): # PyTree of OutInfo
out_avals = self._lowering.compile_args["global_out_avals"]
@@ -707,6 +731,17 @@ def memory_analysis(self) -> Any | None:
except NotImplementedError:
return None
+ @property
+ def in_avals(self):
+ in_avals_ = self._executable.in_avals
+ if self.in_tree.num_leaves > len(in_avals_):
+ iter_in_avals = iter(in_avals_)
+ non_dce_avals = self._executable._all_args_info.in_avals
+ in_avals_ = [
+ next(iter_in_avals) if i in self._executable._kept_var_idx
+ else a for i, a in zip(range(self.in_tree.num_leaves), non_dce_avals)]
+ return self.in_tree.unflatten(in_avals_)
+
@property
def out_info(self): # PyTree of jax.ShapeDtypeStruct
out_avals = self._executable.out_avals
@@ -783,8 +818,6 @@ def call(*args, **kwargs):
# which might conflict here.
params = args[0]
args = args[1:] # Not including const_args
- if config.dynamic_shapes.value:
- raise NotImplementedError
if params.no_kwargs and kwargs:
kws = ', '.join(kwargs.keys())
raise NotImplementedError(
@@ -1005,7 +1038,7 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
first, second = mismatched_args_msg # pytype: disable=bad-unpacking
extra_msg = f" Got {first} and {second}"
elif len(mismatched_args_msg) == 1:
- first, second = fails
+ first, second = fails
# Choose the failure left which is not already covered by ARG_SHARDING.
left = second if first.m_type == MismatchType.ARG_SHARDING else first
extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}"
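
A short usage sketch for the `in_avals` properties added above on the lowered and compiled stages (illustrative; the output structure depends on the traced function's signature):

import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
  return x @ y

lowered = f.lower(jnp.ones((4, 8)), jnp.ones((8, 2)))
print(lowered.in_avals)            # pytree of input avals
print(lowered.compile().in_avals)  # same structure; DCE'd args are restored
                                   # from all_args_info, per the code above
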
diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py
index 56d2ae8e868f..69d26237db54 100644
--- a/jax/_src/state/primitives.py
+++ b/jax/_src/state/primitives.py
@@ -70,7 +70,6 @@
get_p = core.Primitive("get")
get_p.is_effectful = lambda params: True # type: ignore
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
-batching.ragged_prop_rules[get_p] = batching.ragged_mask_transfer_identity
get_p.is_high = lambda ref_aval, *_, tree: ref_aval.is_high # type: ignore
def _get_to_lojax(ref, *idx, tree):
@@ -192,16 +191,6 @@ def _swap_to_lojax(ref, val, *idx, tree):
swap_p.to_lojax = _swap_to_lojax # type: ignore
-def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars):
- assert len(invar_raggedness) == 2
- invar_raggedness_lhs = invar_raggedness[0]
- invar_raggedness_rhs = invar_raggedness[1]
-
- return [invar_raggedness_rhs, invar_raggedness_lhs], [None]
-
-
-batching.ragged_prop_rules[swap_p] = swap_ragged_prop_rule
-
@partial(traceback_util.api_boundary, repro_api_name="jax.ref.swap")
def ref_swap(
ref: core.Ref | TransformedRef,
@@ -1079,6 +1068,18 @@ def _addupdate_vmap(axis_data, batched_args, batched_dims, *, tree):
broadcast_to_p = core.Primitive('broadcast_to')
def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array:
+ """Broadcasts an array to a new shape.
+
+ Args:
+ a: The array to broadcast.
+ shape: The desired shape to broadcast to.
+
+ Returns:
+ An array of shape ``shape``.
+
+ See Also:
+ :func:`jax.numpy.broadcast_to`
+ """
import jax.numpy as jnp # pytype: disable=import-error
a = jnp.asarray(a)
if a.shape == shape:
@@ -1114,9 +1115,12 @@ def _ref_lin(nzs, x, *, memory_space, kind):
nz, = nzs
x_ref = core.ref_p.bind(x, memory_space=memory_space, kind=kind)
def mut_lin(_, x_dot):
+ if kind == 'anselm_ref':
+ aval = x_dot.aval if type(x_dot) is ad.Zero else core.typeof(x_dot)
+ return ad.Zero(AbstractRef(aval))
zero = ad_util.instantiate(x_dot)
return core.ref_p.bind(zero, memory_space=memory_space, kind=kind)
- return x_ref, True, None, mut_lin
+ return x_ref, kind != 'anselm_ref', None, mut_lin
ad.primitive_jvps[core.ref_p] = _ref_jvp
ad.primitive_linearizations[core.ref_p] = _ref_lin
diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py
index 2644f8392416..3c7c09684f99 100644
--- a/jax/_src/state/types.py
+++ b/jax/_src/state/types.py
@@ -409,7 +409,10 @@ def is_high(self):
return self.inner_aval.is_high
def lo_ty(self):
- return map(AbstractRef, self.inner_aval.lo_ty())
+ return [
+ AbstractRef(x, memory_space=self.memory_space)
+ for x in self.inner_aval.lo_ty()
+ ]
def lower_val(self, ref):
if not self.is_high:
@@ -552,7 +555,12 @@ def __repr__(self) -> str:
__str__ = __repr__
def to_tangent_aval(self):
- return AbstractRef(self.inner_aval.to_tangent_aval(), self.memory_space, kind=self.kind)
+ return AbstractRef(self.inner_aval.to_tangent_aval(), self.memory_space,
+ kind=self.kind)
+
+ def to_cotangent_aval(self):
+ return AbstractRef(self.inner_aval.to_cotangent_aval(), self.memory_space,
+ kind=self.kind)
def __eq__(self, other):
return (type(self) is type(other) and self.inner_aval == other.inner_aval
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index f5217b5f6d38..b9397babb7b7 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -131,8 +131,7 @@ def to_default_dtype(arr: ArrayLike) -> np.ndarray:
"""Convert a value to an array with JAX's default dtype.
This is generally used for type conversions of values returned by numpy functions,
- to make their dtypes take into account the state of the ``jax_enable_x64`` and
- ``jax_default_dtype_bits`` flags.
+ to make their dtypes take into account the state of the ``jax_enable_x64`` flag.
"""
arr = np.asarray(arr)
dtype_fn = _dtypes.default_types.get(arr.dtype.kind)
@@ -143,8 +142,7 @@ def with_jax_dtype_defaults(func: Callable[..., Any], use_defaults: bool = True)
This is generally used to wrap numpy functions within tests, in order to make
their default output dtypes match those of corresponding JAX functions, taking
- into account the state of the ``jax_enable_x64`` and ``jax_default_dtype_bits``
- flags.
+ into account the state of the ``jax_enable_x64`` flag.
Args:
use_defaults : whether to convert any given output to the default dtype. May be
@@ -431,7 +429,7 @@ def is_sanitized():
# built at least `date``.
# TODO(b/327203806): after libtpu adds a XLA version and the oldest support
# libtpu contains the XLA version, remove using built time to skip tests.
-def if_cloud_tpu_at_least(year: int, month: int, day: int):
+def is_cloud_tpu_at_least(year: int, month: int, day: int):
date = datetime.date(year, month, day)
if not is_cloud_tpu():
return True
@@ -545,9 +543,11 @@ def test_method_wrapper(self, *args, **kwargs):
)
def get_cuda_nonportable_max_cluster_size():
- if device_kind_match("GB10$"):
- # 12 is the nonportable maximum cluster size on DGX Spark,
- # determined by querying cuOccupancyMaxPotentialClusterSize.
+  # Per-device nonportable maximum cluster sizes for Jetson Thor and DGX
+  # Spark (GB10), determined by querying cuOccupancyMaxPotentialClusterSize.
+ if device_kind_match("Thor$"):
+ return 8
+ elif device_kind_match("GB10$"):
return 12
# 16 is the nonportable maximum cluster size on:
# - Hopper: https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#:~:text=cluster%20size%20of-,16,-by%20opting%20in
@@ -1246,8 +1246,6 @@ class JaxTestCase(parameterized.TestCase):
'jax_legacy_prng_key': 'error',
}
-
-
def setUp(self):
super().setUp()
self.enterContext(assert_global_configs_unchanged())
@@ -1339,14 +1337,16 @@ def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol=
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes,
err_msg=err_msg)
elif is_sequence(actual) and not hasattr(actual, '__array__'):
- self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__'))
+ self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__'),
+ msg=f"Expected sequence, got {desired}")
self.assertEqual(len(actual), len(desired))
for actual_elt, desired_elt in zip(actual, desired):
self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol,
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes,
err_msg=err_msg)
elif hasattr(actual, '__array__') or np.isscalar(actual):
- self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired))
+ self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired),
+ msg=f"Expected array-like, got {desired}")
if check_dtypes:
self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes)
actual = np.asarray(actual)
@@ -1554,7 +1554,7 @@ def mesh_fn(*args, **kwargs):
def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None):
size = math.prod(mesh_shape)
if len(xla_bridge.devices()) < size:
- raise unittest.SkipTest(f"Test requires {size} global devices.")
+    raise unittest.SkipTest(
+        f"Test requires {size} global devices but found "
+        f"{len(xla_bridge.devices())}.")
if iota_order:
devices = sorted(xla_bridge.devices(), key=lambda d: d.id)
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index f9c94f860ced..c7c2ceba1f0e 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -101,6 +101,7 @@ class MemorySpace(enum.Enum):
SEMAPHORE_MEM = enum.auto()
SMEM = enum.auto()
HOST = enum.auto()
+ SC_SCALAR_SEMAPHORE_MEM = enum.auto()
@property
def color(self) -> int:
@@ -110,6 +111,8 @@ def color(self) -> int:
return 1
elif self == MemorySpace.SEMAPHORE_MEM:
return 2
+ elif self == MemorySpace.SC_SCALAR_SEMAPHORE_MEM:
+ return 8
elif self == MemorySpace.SMEM:
return 4
elif self == MemorySpace.HOST:
@@ -142,6 +145,11 @@ class TpuSideEffectType(enum.Enum):
SIDE_EFFECTING = "side_effecting"
+class Tiling(enum.Enum):
+ COMPACT = "TILING_COMPACT"
+ SPARSE_CORE = "TILING_SPARSE_CORE"
+
+
@dataclasses.dataclass(frozen=True)
class CustomCallBackendConfig:
"""Represents an unserialized backend config for custom calls."""
@@ -163,6 +171,7 @@ class CustomCallBackendConfig:
input_memory_spaces: tuple[MemorySpace | None, ...] | None
skip_device_barrier: bool
shape_invariant_numerics: bool
+ tiling: Tiling | None = None # Only used for SparseCore.
def __post_init__(self):
if self.allow_input_fusion is not None:
@@ -192,9 +201,7 @@ def to_json(self) -> bytes:
config.write(str(self.collective_id).encode("ascii"))
if self.cost_estimate is not None:
config.write(b', "cost_estimate": ')
- config.write(
- json.dumps(dict(self.cost_estimate), sort_keys=True).encode("ascii")
- )
+ config.write(_compact_json_object(**self.cost_estimate))
if self.needs_hlo_passes:
config.write(b', "needs_hlo_passes": ')
config.write(str(self.needs_hlo_passes).lower().encode("ascii"))
@@ -211,7 +218,6 @@ def to_json(self) -> bytes:
config.write(b', "allow_input_fusion": [')
for i, value in enumerate(self.allow_input_fusion):
config.write(b"true" if value else b"false")
- # config.write(str(value).lower().encode("ascii"))
if i + 1 != len(self.allow_input_fusion):
config.write(b",")
config.write(b"]")
@@ -261,6 +267,9 @@ def to_json(self) -> bytes:
config.write(b', "skip_device_barrier": ')
config.write(str(self.skip_device_barrier).lower().encode("ascii"))
config.write(b"}") # End of custom_call_config.
+ if self.tiling is not None:
+ config.write(b', "sparse_core_config": ')
+ config.write(_compact_json_object(tiling=self.tiling.value))
if self.device_type is not None:
config.write(b', "device_type": ')
config.write(
@@ -300,6 +309,12 @@ def to_json(self) -> bytes:
return config.getvalue()
+def _compact_json_object(**kwargs: Any) -> bytes:
+ return json.dumps(
+ kwargs, sort_keys=True, indent=0, separators=(",", ":")
+ ).encode("ascii")
+
+
@tpu_custom_call_p.def_abstract_eval
def _tpu_custom_call_abstract_eval(*_, out_avals, **__):
return out_avals
@@ -322,8 +337,9 @@ def _tpu_custom_call_lowering(
result_types = [mlir.aval_to_ir_type(aval) for aval in out_avals]
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
+ manual_axes = axis_context.manual_axes | set(axis_context.mesh.manual_axes)
if (axis_context.manual_axes and
- axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
+ manual_axes != frozenset(axis_context.mesh.axis_names)):
raise NotImplementedError(
"Mosaic kernels cannot be automatically partitioned. Please wrap the"
" call in a shard_map."
@@ -365,7 +381,9 @@ def _tpu_custom_call_lowering(
)
metadata_dict = {}
if metadata is not None:
- metadata_dict["kernel_metadata"] = ir.StringAttr.get(json.dumps(metadata))
+ metadata_dict["kernel_metadata"] = ir.StringAttr.get(
+ _compact_json_object(**metadata)
+ )
assert isinstance(has_side_effects, TpuSideEffectType)
if has_side_effects == TpuSideEffectType.DATAFLOW_SIDE_EFFECTING:
metadata_dict["xla_allow_dce_side_effecting_op"] = ir.StringAttr.get("true")
@@ -541,6 +559,7 @@ def _lower_to_custom_call_config(
allow_collective_id_without_custom_barrier: bool = False,
shape_invariant_numerics: bool = False,
needs_layout_passes: bool | None = None,
+ tiling: Tiling | None = None,
) -> CustomCallBackendConfig:
device_type = _get_device_type(module)
needs_hlo_passes = _MOSAIC_ALLOW_HLO.value
@@ -575,6 +594,7 @@ def _lower_to_custom_call_config(
skip_device_barrier=skip_device_barrier,
allow_collective_id_without_custom_barrier=allow_collective_id_without_custom_barrier,
shape_invariant_numerics=shape_invariant_numerics,
+ tiling=tiling,
)
@@ -600,6 +620,7 @@ def _lowered_to_custom_call_config(
skip_device_barrier: bool = False,
allow_collective_id_without_custom_barrier: bool = False,
shape_invariant_numerics: bool = False,
+ tiling: Tiling | None = None,
):
if has_custom_barrier:
if collective_id is None:
@@ -616,7 +637,7 @@ def _lowered_to_custom_call_config(
"vmem_limit_bytes must be an int: provided with a"
f" {type(vmem_limit_bytes)}."
)
- config = CustomCallBackendConfig(
+ return CustomCallBackendConfig(
lowered_module_asm,
has_communication,
collective_id,
@@ -635,8 +656,8 @@ def _lowered_to_custom_call_config(
input_memory_spaces=input_memory_spaces,
skip_device_barrier=skip_device_barrier,
shape_invariant_numerics=shape_invariant_numerics,
+ tiling=tiling,
)
- return config
def lower_module_to_custom_call(
@@ -662,6 +683,7 @@ def lower_module_to_custom_call(
allow_collective_id_without_custom_barrier: bool = False,
shape_invariant_numerics: bool = False,
needs_layout_passes: bool | None = None,
+ tiling: Tiling | None = None,
) -> Sequence[ir.Value]:
if isinstance(has_side_effects, bool):
has_side_effects = (
@@ -686,6 +708,7 @@ def lower_module_to_custom_call(
allow_collective_id_without_custom_barrier=allow_collective_id_without_custom_barrier,
shape_invariant_numerics=shape_invariant_numerics,
needs_layout_passes=needs_layout_passes,
+ tiling=tiling,
)
return _tpu_custom_call_lowering(
ctx,
diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py
index 4f439a770e69..ad8ac40ff836 100644
--- a/jax/_src/tree_util.py
+++ b/jax/_src/tree_util.py
@@ -38,6 +38,7 @@
H = TypeVar("H", bound=Hashable)
Leaf = Any
+PyTree = Any
PyTreeDef = pytree.PyTreeDef
default_registry = pytree.default_registry()
@@ -92,6 +93,13 @@ def tree_leaves(tree: Any,
return default_registry.flatten(tree, is_leaf)[0]
+@export
+def tree_leaves_checked(treedef_expected: PyTreeDef, tree: Any) -> list[Leaf]:
+ flat_vals, treedef_actual = tree_flatten(tree)
+ assert treedef_actual == treedef_expected
+ return flat_vals
+
+
@export
def tree_structure(tree: Any,
is_leaf: None | (Callable[[Any],
@@ -543,7 +551,7 @@ class Partial(functools.partial):
>>> print_zero()
0
>>> call_func(print_zero) # doctest:+ELLIPSIS
- JitTracer<~int32[]>
+ JitTracer(~int32[])
"""
def __new__(klass, func, *args, **kw):
@@ -1336,3 +1344,119 @@ def _prefix_error(
f"{prefix_tree_keys} and {full_tree_keys}")
for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
yield from _prefix_error((*key_path, k), t1, t2)
+
+# === flat tree ===
+
+class FlatTree:
+ """A FlatTree stores a treedef and a flat list of values. It's meant to be
+ isomorphic to the corresponding pytree but we can map over it more easily.
+ Compared to `tree_map`, FlatTree.map has these benefits:
+ 1. It doesn't touch user flatten/unflatten code (which shouldn't have side
+ effects but sometimes does in practice).
+ 2. It can be faster, because it skips the recursive traversal.
+ 3. It actually obeys the functor rules. For example,
+ `flat_tree.map(lambda x: (f(x), g(x))).unzip2()[0]` will give
+ the same result as `flat_tree.map(f)`, whereas in the `tree_map` version
+ the tuple-returning function would change the tree structure and `unzip`
+ wouldn't be able to recover it.
+ """
+ def __init__(self, vals:Sequence, treedef:PyTreeDef):
+ assert isinstance(treedef, pytree.PyTreeDef)
+ self.tree = treedef
+ self.vals = tuple(vals)
+
+ def map(self, f:Callable) -> FlatTree:
+ ans_vals = []
+ for x in self.vals:
+ ans_vals.append(f(x))
+ return FlatTree(ans_vals, self.tree)
+
+  def map2(self:FlatTree, f:Callable, t2:FlatTree) -> FlatTree:
+ n = len(self)
+ assert len(t2) == n
+ ans_vals = []
+ for x1, x2 in zip(self.vals, t2.vals):
+ ans_vals.append(f(x1, x2))
+ return FlatTree(ans_vals, self.tree)
+
+ def map3(
+ self:FlatTree, f:Callable, t2:FlatTree, t3:FlatTree) -> FlatTree:
+ n = len(self)
+ assert len(t2) == n and len(t3) == n
+ ans_vals = []
+ for x1, x2, x3 in zip(self.vals, t2.vals, t3.vals):
+ ans_vals.append(f(x1, x2, x3))
+ return FlatTree(ans_vals, self.tree)
+
+ def zip(self, t2:FlatTree) -> FlatTree:
+ assert False
+
+ def unzip2(self:FlatTree) -> tuple[FlatTree, FlatTree]:
+ ys = []
+ zs = []
+ for y, z in self.vals:
+ ys.append(y)
+ zs.append(z)
+ return FlatTree(ys, self.tree), FlatTree(zs, self.tree)
+
+  # TODO: add zip3, unzip3, etc. as needed
+
+ @staticmethod
+ def pack(tree):
+ # We could generalize this to arbitrary pytrees of FlatTree but tuples/dicts
+ # are sufficient for now.
+ if isinstance(tree, FlatTree):
+ return tree
+ elif isinstance(tree, tuple):
+ vals = []
+ trees = []
+ for child_tree in tree:
+ child = FlatTree.pack(child_tree)
+ vals.extend(child.vals)
+ trees.append(child.tree)
+ return FlatTree(vals, treedef_tuple(trees))
+ elif isinstance(tree, dict):
+ # only empty case handled for now
+ if tree == {}:
+ return FlatTree.flatten({})
+ else:
+ assert False
+ else:
+ assert False
+
+ def unpack(self:FlatTree) -> tuple[FlatTree, ...]:
+ # TODO: this is O(N) not O(1) (with N as the number of leaves). If it
+ # becomes a problem we can fix it with a fancier data tree.
+ trees = treedef_children(self.tree)
+ children = []
+ offset = 0
+ for tree in trees:
+ new_offset = offset + tree.num_leaves
+ children.append(FlatTree(self.vals[offset:new_offset], tree))
+ offset = new_offset
+ return tuple(children)
+
+ @staticmethod
+ def flatten(tree: PyTree) -> FlatTree:
+ return FlatTree(*tree_flatten(tree))
+
+ def unflatten(self) -> PyTree:
+ return tree_unflatten(self.tree, self.vals)
+
+ def update_from_list(self, new_vals:list) -> FlatTree:
+ return FlatTree(new_vals, self.tree)
+
+ def __len__(self):
+ return self.tree.num_leaves
+
+ def __iter__(self):
+ return self.vals.__iter__()
+
+ def __eq__(self, other):
+ return (isinstance(other, FlatTree)
+ and self.vals == other.vals
+ and self.tree == other.tree)
+
+ def __hash__(self):
+ return hash((self.vals, self.tree))
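
A usage sketch for the new `FlatTree` helper (illustrative; this is an internal `jax._src` API): mapping a tuple-returning function and then unzipping yields the same `FlatTree` as mapping each component directly, which is the functor property the docstring describes.

from jax._src.tree_util import FlatTree

ft = FlatTree.flatten({"a": 1, "b": (2, 3)})
doubled = ft.map(lambda x: x * 2)
pairs = ft.map(lambda x: (x * 2, x * 3))
left, right = pairs.unzip2()
assert left == doubled           # same leaves and same treedef
print(doubled.unflatten())       # {'a': 2, 'b': (4, 6)}
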
diff --git a/jax/core.py b/jax/core.py
index d82cf3592482..ff8a29779596 100644
--- a/jax/core.py
+++ b/jax/core.py
@@ -15,35 +15,28 @@
# Note: import as is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
+import jax._src.core as _src_core
from jax._src.core import (
- AbstractToken as AbstractToken,
AbstractValue as AbstractValue,
Atom as Atom,
CallPrimitive as CallPrimitive,
DebugInfo as DebugInfo,
- DShapedArray as DShapedArray,
DropVar as DropVar,
Effect as Effect,
Effects as Effects,
- get_opaque_trace_state as get_opaque_trace_state,
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
JaxprPpContext as JaxprPpContext,
JaxprPpSettings as JaxprPpSettings,
JaxprTypeError as JaxprTypeError,
- nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
OutputType as OutputType,
ParamDict as ParamDict,
ShapedArray as ShapedArray,
Trace as Trace,
Tracer as Tracer,
- unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
- unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
- unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
Value as Value,
abstract_token as abstract_token,
aval_mapping_handlers as aval_mapping_handlers,
call as call,
- call_impl as call_impl,
check_jaxpr as check_jaxpr,
concrete_or_error as concrete_or_error,
concretization_function_error as concretization_function_error,
@@ -53,26 +46,86 @@
eval_jaxpr as eval_jaxpr,
find_top_trace as find_top_trace,
gensym as gensym,
- get_aval as get_aval,
+ get_opaque_trace_state as get_opaque_trace_state,
is_concrete as is_concrete,
is_constant_dim as is_constant_dim,
is_constant_shape as is_constant_shape,
jaxprs_in_params as jaxprs_in_params,
literalable_types as literalable_types,
- mapped_aval as mapped_aval,
max_dim as max_dim,
min_dim as min_dim,
new_jaxpr_eqn as new_jaxpr_eqn,
no_axis_name as no_axis_name,
no_effects as no_effects,
+ nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
pytype_aval_mappings as pytype_aval_mappings,
- set_current_trace as set_current_trace,
- subjaxprs as subjaxprs,
- take_current_trace as take_current_trace,
trace_ctx as trace_ctx,
- TraceTag as TraceTag,
- traverse_jaxpr_params as traverse_jaxpr_params,
- unmapped_aval as unmapped_aval,
+ unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
+ unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
+ unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
valid_jaxtype as valid_jaxtype,
)
+
+_deprecations = {
+ # Added for v0.8.2
+ "call_impl": (
+ "jax.core.call_impl is deprecated.",
+ _src_core.call_impl,
+ ),
+ "get_aval": (
+ "jax.core.get_aval is deprecated; use jax.typeof instead.",
+ _src_core.get_aval,
+ ),
+ "mapped_aval": (
+ "jax.core.mapped_aval is deprecated.",
+ _src_core.mapped_aval,
+ ),
+ "set_current_trace": (
+ "jax.core.set_current_trace is deprecated.",
+ _src_core.set_current_trace,
+ ),
+ "subjaxprs": (
+ "jax.core.subjaxprs is deprecated.",
+ _src_core.subjaxprs,
+ ),
+ "take_current_trace": (
+ "jax.core.take_current_trace is deprecated.",
+ _src_core.take_current_trace,
+ ),
+ "traverse_jaxpr_params": (
+ "jax.core.traverse_jaxpr_params is deprecated.",
+ _src_core.traverse_jaxpr_params,
+ ),
+ "unmapped_aval": (
+ "jax.core.unmapped_aval is deprecated.",
+ _src_core.unmapped_aval,
+ ),
+ "AbstractToken": (
+ "jax.core.AbstractToken is deprecated.",
+ _src_core.AbstractToken,
+ ),
+ "TraceTag": (
+ "jax.core.TraceTag is deprecated.",
+ _src_core.TraceTag,
+ ),
+}
+
+import typing as _typing
+if _typing.TYPE_CHECKING:
+ call_impl = _src_core.call_impl
+ get_aval = _src_core.get_aval
+ mapped_aval = _src_core.mapped_aval
+ subjaxprs = _src_core.subjaxprs
+ set_current_trace = _src_core.set_current_trace
+ take_current_trace = _src_core.take_current_trace
+ traverse_jaxpr_params = _src_core.traverse_jaxpr_params
+ unmapped_aval = _src_core.unmapped_aval
+ AbstractToken = _src_core.AbstractToken
+ TraceTag = _src_core.TraceTag
+else:
+ from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
+ __getattr__ = _deprecation_getattr(__name__, _deprecations)
+ del _deprecation_getattr
+del _typing
+del _src_core
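
A sketch of what the shim above means for downstream code (illustrative; the warning category is assumed to be the `DeprecationWarning` emitted by `jax._src.deprecations.deprecation_getattr`): deprecated names still resolve through the module-level `__getattr__`, but each access warns and points to the replacement.

import warnings
import jax.core

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  _ = jax.core.get_aval  # still resolves, but warns to use jax.typeof
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
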
diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py
index 2dde0d3cacbb..edefdae40c44 100644
--- a/jax/custom_derivatives.py
+++ b/jax/custom_derivatives.py
@@ -35,22 +35,3 @@
SymbolicZero as SymbolicZero,
zero_from_primal as zero_from_primal
)
-
-_deprecations = {
- # Finalized for v0.8.0; remove in v0.9.0
- "custom_jvp_call_jaxpr_p": (
- ("jax.custom_derivatives.custom_jvp_call_jaxpr_p was deprecated in v0.7.0"
- " and removed in v0.8.0. use jax.extend.core.primitives.custom_jvp_call_p"
- " instead, and please note that you must `import jax.extend` explicitly."),
- None,
- ),
-}
-
-import typing
-if typing.TYPE_CHECKING:
- pass
-else:
- from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
- __getattr__ = _deprecation_getattr(__name__, _deprecations)
- del _deprecation_getattr
-del typing
diff --git a/jax/dlpack.py b/jax/dlpack.py
index da04ed7119d7..6fa73748ee8b 100644
--- a/jax/dlpack.py
+++ b/jax/dlpack.py
@@ -16,19 +16,3 @@
from_dlpack as from_dlpack,
is_supported_dtype as is_supported_dtype,
)
-
-_deprecations = {
- # Deprecated in JAX v0.7.0
- "SUPPORTED_DTYPES": (
- (
- "jax.SUPPORTED_DTYPES is deprecated in JAX v0.7.0 and will be removed"
- " in JAX v0.8.0. Use jax.dlpack.is_supported_dtype() instead."
- ),
- None,
- ),
-}
-
-
-from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
-__getattr__ = _deprecation_getattr(__name__, _deprecations)
-del _deprecation_getattr
diff --git a/jax/errors.py b/jax/errors.py
index 928ab6c8a7f2..a4a6c5388db2 100644
--- a/jax/errors.py
+++ b/jax/errors.py
@@ -31,14 +31,3 @@
JaxRuntimeError = _jax.JaxRuntimeError
JaxRuntimeError.__module__ = "jax.errors"
del _jax
-
-_deprecations = {
- "SimplifiedTraceback": (
- "jax.errors.SimplifiedTraceback is deprecated and will be removed in JAX v0.8.",
- None,
- ),
-}
-
-from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
-__getattr__ = _deprecation_getattr(__name__, _deprecations)
-del _deprecation_getattr
diff --git a/jax/example_libraries/BUILD b/jax/example_libraries/BUILD
index 46805163d572..e740757c32a1 100644
--- a/jax/example_libraries/BUILD
+++ b/jax/example_libraries/BUILD
@@ -19,6 +19,14 @@ package(
default_visibility = ["//jax:internal"],
)
+pytype_strict_library(
+ name = "example_libraries",
+ srcs = [
+ "__init__.py",
+ ],
+ visibility = ["//jax:internal"],
+)
+
pytype_strict_library(
name = "stax",
srcs = [
@@ -39,11 +47,3 @@ pytype_strict_library(
"//jax/_src:util",
] + py_deps("numpy"),
)
-
-# TODO(dsuo): Remove this filegroup once experimental aliases from jax/BUILD are
-# removed.
-filegroup(
- name = "jax_example_libraries",
- srcs = glob(["*.py"]),
- visibility = ["//jax:internal"],
-)
diff --git a/jax/experimental/BUILD b/jax/experimental/BUILD
index f9c4de048eb5..bb18041e83dd 100644
--- a/jax/experimental/BUILD
+++ b/jax/experimental/BUILD
@@ -15,6 +15,7 @@
load(
"//jaxlib:jax.bzl",
"buffer_callback_internal_users",
+ "experimental_transfer_users",
"if_cuda_is_configured",
"jax_visibility",
"mosaic_gpu_internal_users",
@@ -41,6 +42,12 @@ package_group(
packages = buffer_callback_internal_users,
)
+package_group(
+ name = "experimental_transfer_users",
+ includes = ["//jax:internal"],
+ packages = experimental_transfer_users,
+)
+
package_group(
name = "mosaic_users",
includes = ["//jax:internal"],
@@ -117,11 +124,13 @@ pytype_strict_library(
"//jax",
"//jax/_src:api",
"//jax/_src:api_util",
+ "//jax/_src:config",
"//jax/_src:traceback_util",
"//jax/_src:tree_util",
"//jax/_src:util",
"//jax/_src:xla_bridge",
"//jax/_src/lib",
+ "//jax/extend:backend",
"//jax/extend:ifrt_programs",
] + py_deps("numpy") + py_deps("cloudpickle"),
)
@@ -582,7 +591,9 @@ pytype_strict_library(
# be used in new code. Use jax.shard_map instead.
name = "shard_map",
srcs = ["shard_map.py"],
- visibility = jax_visibility("experimental/shard_map"),
+ visibility = [
+ "//jax:internal",
+ ] + jax_visibility("experimental/shard_map"),
deps = [
"//jax",
"//jax/_src:mesh",
@@ -685,7 +696,10 @@ pytype_strict_library(
pytype_strict_library(
name = "transfer",
srcs = ["transfer.py"],
- visibility = ["//jax:internal"],
+ visibility = [
+ ":experimental_transfer_users",
+ "//jax:internal",
+ ],
deps = [
"//jax",
"//jax/_src:util",
@@ -738,18 +752,3 @@ filegroup(
],
visibility = ["//jax:internal"],
)
-
-filegroup(
- name = "jax_experimental",
- srcs = glob(
- [
- "*.py",
- ],
- exclude = [
- "buffer_callback.py",
- "mental/mosaic/gpu/*.py",
- "serialize_executable.py",
- ],
- ),
- visibility = ["//jax:internal"],
-)
diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py
index 474f9ce5f675..606033c43ac7 100644
--- a/jax/experimental/__init__.py
+++ b/jax/experimental/__init__.py
@@ -21,10 +21,6 @@
# experimental features and as a result, more flexibility to manage their status
# and lifetimes.
-from jax._src.api import (
- saved_input_vjp as saved_input_vjp,
- si_vjp as si_vjp
-)
from jax._src.callback import (
io_callback as io_callback
)
diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py
index 27f8f31f70cb..220f0cfdf540 100644
--- a/jax/experimental/colocated_python/func.py
+++ b/jax/experimental/colocated_python/func.py
@@ -15,12 +15,14 @@
from __future__ import annotations
+from collections.abc import Callable, Sequence
import dataclasses
import inspect
import random
import threading
from typing import Any
-from collections.abc import Callable, Sequence
+import uuid
+import weakref
import jax
from jax._src import api
@@ -31,7 +33,8 @@
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
-from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs
+from jax.experimental.colocated_python.serialization import _deserialize, _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs
+from jax.extend.backend import register_backend_cache as jax_register_backend_cache
from jax.extend.ifrt_programs import ifrt_programs
ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct]
@@ -186,12 +189,14 @@ def call(*args, **kwargs):
# TODO(hyeontaek): Implement colocated Python support in McJAX and remove
# this fallback path.
if "PjRtCompiler requires an HloProgram" in str(e):
- return fun
+ return _deserialize(pickled_function)[0]
raise
def _make_output_specs_and_push_result_fun(
- info: FunctionInfo, specialization: Specialization, uid: int
+ info: FunctionInfo,
+ specialization: Specialization,
+ uid: int,
) -> Callable[..., Any]:
"""Creates a function that computes output specs and pushes the result to the result store."""
assert specialization.in_specs_treedef is not None
@@ -226,7 +231,9 @@ def lowered_fun(*args, **kwargs) -> jax.Array:
def _make_pop_result_fun(
- info: FunctionInfo, specialization: Specialization, uid: int
+ info: FunctionInfo,
+ specialization: Specialization,
+ uid: int,
) -> Callable[..., Any]:
"""Makes a function that pops results from the result store."""
assert specialization.out_specs_treedef is not None
@@ -259,7 +266,8 @@ def lowered_fun():
def _make_async_execution_fun(
- info: FunctionInfo, specialization: Specialization
+ info: FunctionInfo,
+ specialization: Specialization,
) -> Callable[..., Any]:
"""Makes a function that asynchronously executes the function."""
assert specialization.in_specs_treedef is not None
@@ -280,9 +288,9 @@ def _make_async_execution_fun(
)
-@jax._src.util.cache(max_size=None)
-def _get_specialized_func(
- info: FunctionInfo, specialization: Specialization
+def _uncached_get_specialized_func(
+ info: FunctionInfo,
+ specialization: Specialization,
) -> Callable[..., Any]:
"""Returns a specialized function for the given specialization."""
util.test_event("colocated_python_func._get_specialized_func")
@@ -302,9 +310,14 @@ def specialized_func(*args, **kwargs):
if async_execution_func is None:
if specialization.out_specs_treedef is None:
if specialization.out_specs_fn is None:
- serialized_out_specs = _make_output_specs_and_push_result_fun(
- info, specialization, uid
- )(*args, **kwargs)
+ output_specs_and_push_result_fun = (
+ _make_output_specs_and_push_result_fun(
+ info, specialization, uid
+ )
+ )
+ serialized_out_specs = output_specs_and_push_result_fun(
+ *args, **kwargs
+ )
# Waits for the output_specs. This may block.
out_specs_treedef, out_specs_leaves = _deserialize_specs(
@@ -321,6 +334,13 @@ def specialized_func(*args, **kwargs):
info, specialization
)
+ # Hold the PyExecutable until async_execution_fun is called at
+ # least once, so the number of _OBJECT_STORE references at the
+ # backend does not drop to 0.
+ async_execution_func.output_specs_and_push_result_fun = (
+ output_specs_and_push_result_fun
+ )
+
return _make_pop_result_fun(info, specialization, uid)()
else:
# Compute out_specs using out_specs_fn and inputs.
@@ -348,122 +368,345 @@ def specialized_func(*args, **kwargs):
# Asynchronous execution runs outside of the mutex to allow concurrent
# execution for inline executors.
- return async_execution_func(*args, **kwargs)
+ result = async_execution_func(*args, **kwargs)
+ with mutex:
+ async_execution_func.output_specs_and_push_result_fun = None
+ return result
return specialized_func
-def make_callable(
- fun: Callable[..., Any],
- fun_sourceinfo: str | None,
- fun_signature: inspect.Signature | None,
-):
- """Makes a colocated Python callable."""
- return _make_callable(
- FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization()
- )
+class _SpecializedCollection:
+ """Collection of specialized functions for a single unspecialized function.
+
+ The `get()` method retrieves the specialized function for the provided input
+ spec, either by looking up a cache or by compiling the specialized function.
+
+ Looking up a cache with an input spec as a key can be slow, because
+ `Sharding`'s equivalence comparison is slow. Instead, we maintain two caches
+ for the same value: we use the ID of the sharding object (via `WeakSpec`) as
+ the key in one cache, and the corresponding strong references to the sharding
+ object (via `StrongSpec`) as the key in another cache. Looking up the
+ `WeakSpec`-keyed cache is fast. Note that the ID integer in the `WeakSpec`
+ cache will remain valid as long as a strong-ref exists in the `StrongSpec`
+ cache.
+
+ The `StrongSpec`-keyed cache is unbounded, while the `WeakSpec`-keyed cache
+ is LRU(1): if there is a miss in the `WeakSpec` cache but a hit in the
+  `StrongSpec` cache, the strong-ref in the `StrongSpec` cache and the ID
+  integer in the `WeakSpec` cache are both updated.
+ """
+
+ @dataclasses.dataclass(slots=True, unsafe_hash=True)
+ class WeakSpec:
+ """WeakSpec stores just the `id()` of the input spec sharding."""
+
+ dtypes: tuple[jax.numpy.dtype, ...]
+ shapes: tuple[tuple[int, ...], ...]
+ sharding_ids: tuple[int, ...]
+ treedef: tree_util.PyTreeDef
+
+ def __init__(
+ self, args_leaves: Sequence[jax.Array], treedef: tree_util.PyTreeDef
+ ):
+ self.dtypes = tuple(x.dtype for x in args_leaves)
+ self.shapes = tuple(x.shape for x in args_leaves)
+ self.sharding_ids = tuple(id(x.sharding) for x in args_leaves)
+ self.treedef = treedef
+
+ @dataclasses.dataclass(slots=True, unsafe_hash=True)
+ class StrongSpec:
+ """StrongSpec stores the full input spec sharding."""
+
+ in_specs_treedef: tree_util.PyTreeDef | None = None
+ in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
+
+ def __init__(
+ self, args_leaves: Sequence[jax.Array], pytreedef: tree_util.PyTreeDef
+ ):
+ self.in_specs_leaves = tuple(_get_spec(x) for x in args_leaves)
+ self.in_specs_treedef = pytreedef
+
+ def __init__(self):
+ CompiledId = int
+
+ self._weak_to_id: dict[_SpecializedCollection.WeakSpec, CompiledId] = {}
+ self._id_to_weak: dict[CompiledId, _SpecializedCollection.WeakSpec] = {}
+ self._strong_to_id: dict[_SpecializedCollection.StrongSpec, CompiledId] = {}
+ self._id_to_compiled: dict[CompiledId, Callable[..., Any]] = {}
+
+ self._counter = 0
+ self._mu = threading.Lock()
+
+ def get(
+ self,
+ args_leaves: Sequence[jax.Array],
+ pytreedef: tree_util.PyTreeDef,
+ func_info: FunctionInfo,
+ specialization: Specialization,
+ ) -> Callable[..., Any]:
+ # TODO(hyeontaek): Allow Python values in args_leaves, similar to the todo
+ # in _get_spec().
+
+ # Attempt fast-path cache hit.
+ weak_spec = _SpecializedCollection.WeakSpec(args_leaves, pytreedef)
+ compiled_id = self._weak_to_id.get(weak_spec)
+ if compiled_id is not None:
+ return self._id_to_compiled[compiled_id]
+
+ with self._mu:
+ # Attempt slow-path cache hit.
+ strong_spec = _SpecializedCollection.StrongSpec(args_leaves, pytreedef)
+ compiled_id = self._strong_to_id.pop(strong_spec, None)
+ if compiled_id is not None:
+ # Update the caches so that the fast-path cache stores the `id()` of the
+ # shardings presented by the current invocation.
+ old_weak = self._id_to_weak.pop(compiled_id)
+ del self._weak_to_id[old_weak]
+
+ self._strong_to_id[strong_spec] = compiled_id
+ self._weak_to_id[weak_spec] = compiled_id
+ self._id_to_weak[compiled_id] = weak_spec
+
+ return self._id_to_compiled[compiled_id]
+
+ # Cache-miss: compile.
+ if specialization.devices is None:
+ result = _uncached_get_specialized_func(
+ func_info,
+ specialization.update(
+ in_specs_treedef=strong_spec.in_specs_treedef,
+ in_specs_leaves=strong_spec.in_specs_leaves,
+ devices=_infer_devices_from_args(args_leaves),
+ ),
+ )
+ else:
+ result = _uncached_get_specialized_func(
+ func_info,
+ specialization.update(
+ in_specs_treedef=strong_spec.in_specs_treedef,
+ in_specs_leaves=strong_spec.in_specs_leaves,
+ ),
+ )
+ compiled_id = self._counter
+ self._counter += 1
-def _make_callable(info: FunctionInfo, specialization: Specialization):
- """Internal implementation of make_callable."""
+ self._weak_to_id[weak_spec] = compiled_id
+ self._strong_to_id[strong_spec] = compiled_id
+ self._id_to_weak[compiled_id] = weak_spec
+ self._id_to_compiled[compiled_id] = result
+ return result
- def specialize(
- in_specs: ShapeDtypeStructTree | None = None,
- out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
- devices: Sequence[jax.Device] | None = None,
+
+class _JaxSecondLevelCaches:
+ """Manages second-level caches registered as a single cache with JAX."""
+
+ def __init__(self, name: str):
+ self._lock = threading.Lock()
+ self._callbacks: dict[int, Callable[..., Any]] = {}
+ jax_register_backend_cache(self, name)
+
+ def cache_clear(self):
+ """Meant to be invoked by JAX internals."""
+ for callback in self._callbacks.values():
+ callback()
+ self._callbacks.clear()
+
+ def register_second_level(
+ self, uid: int, cache_clear_callback: Callable[..., Any]
):
- """Returns a colocated Python callable with extra specialization.
-
- Args:
- in_specs: Optionally specifies the expected input specs. Input specs are
- expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a
- function call.
- out_specs_fn: Optionally specifies a function that computes the output
- specs from input specs. If unspecified, colocated Python will compute
- the output specs during the very first execution, and this execution
- will be synchronous.
- devices: Optionally specifies the devices to execute the function on. Must
- be provided if `in_specs` has no leaves because devices cannot be
- inferred from input specs or arguments.
-
- Returns:
- A colocated Python callable with extra specialization.
- """
- # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if
- # `out_specs_fn(in_specs)` returns at least one leaf that we can use for
- # inferring `devices`.
- if in_specs is None:
- in_specs_leaves, in_specs_treedef = None, None
- else:
- in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs)
- in_specs_leaves = tuple(in_specs_leaves_list)
- return _make_callable(
- info,
- specialization.update(
- in_specs_treedef=in_specs_treedef,
- in_specs_leaves=in_specs_leaves,
- out_specs_fn=out_specs_fn,
- devices=devices,
- ),
- )
+ self._callbacks[uid] = cache_clear_callback
+
+ def remove_second_level(self, uid: int):
+ try:
+ self._callbacks.pop(uid)
+ except KeyError:
+ pass
+
+
+class _CachedColocatedFunctionMaker:
+ """Function maker for colocated Python functions.
+
+ Generated functions are stored (cached) indefinitely so that they can be
+ reused, until the cache is dropped.
+ """
- @api_boundary
- def __call__(*args, **kwargs):
- """Executes the given Python function on the same devices as the arguments or as specialized.
-
- If the callable has not been specialized with output shapes and shardings
- (see `specialize` above), the very first call will run synchronously to
- discover output shapes and shardings, and will run asynchronously after. If
- specialized with output shapes and shardings, every execution of the
- callable will be asynchronous.
- """
- args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs))
-
- in_specs_leaves = tuple(_get_spec(x) for x in args_leaves)
- if specialization.in_specs_treedef is None:
- # Allow input polymorphism by applying input_specs specialization
- # temporarily for this call.
- return _make_callable(
+ JAX_CACHE = _JaxSecondLevelCaches("colocated_python_specialized_func_cache")
+
+ def __init__(self, held_by: int | None):
+ self.held_by = held_by if held_by is not None else uuid.uuid4().int
+ specialized_collections: list[_SpecializedCollection] = []
+ specialized_functions: list[Callable[..., Any]] = []
+
+ def clear_caches():
+ specialized_collections.clear()
+ specialized_functions.clear()
+
+ _CachedColocatedFunctionMaker.JAX_CACHE.register_second_level(
+ self.held_by,
+ clear_caches,
+ )
+ self.specialized_collections = specialized_collections
+ self.specialized_functions = specialized_functions
+
+ def __del__(self):
+ self.specialized_collections.clear()
+ self.specialized_functions.clear()
+ try:
+ _CachedColocatedFunctionMaker.JAX_CACHE.remove_second_level(self.held_by)
+ except AttributeError:
+ # Ignore error during python finalization.
+ pass
+
+ def _make_callable(
+ self,
+ info: FunctionInfo,
+ specialization: Specialization,
+ ):
+ """Internal implementation of make_callable."""
+
+ def specialize(
+ in_specs: ShapeDtypeStructTree | None = None,
+ out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
+ devices: Sequence[jax.Device] | None = None,
+ ):
+ """Returns a colocated Python callable with extra specialization.
+
+ Args:
+ in_specs: Optionally specifies the expected input specs. Input specs are
+ expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a
+ function call.
+ out_specs_fn: Optionally specifies a function that computes the output
+ specs from input specs. If unspecified, colocated Python will compute
+ the output specs during the very first execution, and this execution
+ will be synchronous.
+ devices: Optionally specifies the devices to execute the function on.
+ Must be provided if `in_specs` has no leaves because devices cannot be
+ inferred from input specs or arguments.
+
+ Returns:
+ A colocated Python callable with extra specialization.
+ """
+ # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if
+ # `out_specs_fn(in_specs)` returns at least one leaf that we can use for
+ # inferring `devices`.
+ if in_specs is None:
+ in_specs_leaves, in_specs_treedef = None, None
+ else:
+ in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(
+ in_specs
+ )
+ in_specs_leaves = tuple(in_specs_leaves_list)
+ return self._make_callable(
info,
specialization.update(
in_specs_treedef=in_specs_treedef,
in_specs_leaves=in_specs_leaves,
+ out_specs_fn=out_specs_fn,
+ devices=devices,
),
- )(*args, **kwargs)
+ )
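A minimal usage sketch of the `specialize` entry point defined above. It assumes the public `colocated_python.colocated_python` decorator and the `colocated_cpu_devices` helper are available; the function body and dtype are illustrative only, not part of this change.

import jax
import numpy as np
from jax.experimental import colocated_python

cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())

@colocated_python.colocated_python
def add_one(x):
  return x + 1

# With devices and out_specs_fn provided, output specs never need to be
# discovered at runtime, so every call can be dispatched asynchronously.
add_one_specialized = add_one.specialize(
    devices=cpu_devices, out_specs_fn=lambda x: x
)
x = jax.device_put(np.int32(1), cpu_devices[0])
print(jax.device_get(add_one_specialized(x)))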
- if specialization.devices is None:
- devices = _infer_devices_from_args(args_leaves)
- if devices is None:
+ # Caches for a collection of specialized functions or a specialized function
+ # itself. The latter is used as a performance optimization when the input
+ # spec is explicitly specified and can skip a collection lookup. The caches
+ # use weakrefs so that we avoid creating cyclic references.
+ specialized_collections_wref = lambda: None
+ specialized_functions_wref = lambda: None
+ wref_mu = threading.Lock()
+
+ @api_boundary
+ def __call__(*args, **kwargs):
+ """Executes the given Python function on the same devices as the arguments or as specialized.
+
+ If the callable has not been specialized with output shapes and shardings
+ (see `specialize` above), the very first call will run synchronously to
+ discover output shapes and shardings, and will run asynchronously after.
+ If specialized with output shapes and shardings, every execution of the
+ callable will be asynchronous.
+ """
+ args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs))
+
+ no_input = len(args_leaves) == 0
+ if no_input and specialization.devices is None:
raise ValueError(
"No devices found. colocated_python function without input"
" arguments must be first specialized with devices."
)
- # Allow device polymorphism by applying devices specialization temporarily
- # for this call.
- return _make_callable(info, specialization.update(devices=devices))(
- *args, **kwargs
+
+ fully_specified_in_spec = (
+ specialization.in_specs_treedef is not None
+ and specialization.in_specs_leaves is not None
)
- # Assertion is added to silence mypy error: Unsupported operand types for !=
- # ("PyTreeDef" and "None") [operator]
- assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef)
-
- # If input_specs is known, verify that it matches actual inputs.
- if (specialization.in_specs_treedef != in_specs_treedef
- or specialization.in_specs_leaves != in_specs_leaves):
- raise ValueError(
- "Input specs in specialization and input specs of arguments must have"
- " the same pytree structure, but they have the following structural"
- " differences:\n"
- + ("\n".join(
- f" - {tree_util.keystr(path)} is a {thing1} in value 1 and"
- f" a {thing2} in value 2, so {explanation}.\n"
- for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef(
- specialization.in_specs_treedef, in_specs_treedef
- ))))
-
- return _get_specialized_func(info, specialization)(*args, **kwargs)
-
- __call__ = wraps(info.fun)(__call__)
- __call__.specialize = specialize
- return __call__
+ if not fully_specified_in_spec and not no_input:
+ # We need to handle input polymorphism
+ nonlocal specialized_collections_wref
+ with wref_mu:
+ collection: _SpecializedCollection = specialized_collections_wref()
+ if collection is None:
+ collection = _SpecializedCollection()
+ self.specialized_collections.append(collection)
+ specialized_collections_wref = weakref.ref(collection)
+ result = collection.get(
+ args_leaves, in_specs_treedef, info, specialization
+ )(*args, **kwargs)
+ del collection
+ return result
+
+ # No input polymorphism -- exactly one compiled function is possible.
+ with wref_mu:
+ nonlocal specialized_functions_wref
+ func: Callable[..., Any] = specialized_functions_wref()
+ if func is None:
+ if fully_specified_in_spec and specialization.devices is not None:
+ func = _uncached_get_specialized_func(info, specialization)
+ elif fully_specified_in_spec:
+ func = _uncached_get_specialized_func(
+ info,
+ specialization.update(
+ devices=_infer_devices_from_args(args_leaves)
+ ),
+ )
+ elif no_input:
+ func = _uncached_get_specialized_func(
+ info,
+ specialization.update(
+ in_specs_leaves=tuple(),
+ in_specs_treedef=in_specs_treedef,
+ ),
+ )
+ self.specialized_functions.append(func)
+ specialized_functions_wref = weakref.ref(func)
+ result = func(*args, **kwargs)
+ del func
+ return result
+
+ __call__ = wraps(info.fun)(__call__)
+ __call__.specialize = specialize
+ return __call__
+
+ def make_callable(
+ self,
+ fun: Callable[..., Any],
+ fun_sourceinfo: str | None,
+ fun_signature: inspect.Signature | None,
+ ):
+ """Makes a colocated Python callable."""
+ return self._make_callable(
+ FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization()
+ )
+
+
+_DEFAULT_FUNCTION_MAKER = _CachedColocatedFunctionMaker(None)
+
+
+def make_callable(
+ fun: Callable[..., Any],
+ fun_sourceinfo: str | None,
+ fun_signature: inspect.Signature | None,
+):
+ return _DEFAULT_FUNCTION_MAKER.make_callable(
+ fun, fun_sourceinfo, fun_signature
+ )
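The maker above keeps strong references to generated functions in per-instance lists, hands out weak references on the call path, and registers a clear callback so a global cache clear can drop them. A generic sketch of that pattern under hypothetical names (not the JAX internals):

import threading
import weakref


class SecondLevelCaches:
  """Registry of clear callbacks keyed by owner id (illustrative only)."""

  def __init__(self):
    self._callbacks = {}

  def register(self, uid, callback):
    self._callbacks[uid] = callback

  def unregister(self, uid):
    self._callbacks.pop(uid, None)

  def clear_all(self):
    for callback in list(self._callbacks.values()):
      callback()


class CachingMaker:
  """Builds callables once and memoizes them behind a weak reference."""

  REGISTRY = SecondLevelCaches()

  def __init__(self, uid):
    self._held = []  # Strong references keep the cached callables alive.
    CachingMaker.REGISTRY.register(uid, self._held.clear)

  def make(self, build):
    cached_wref = lambda: None  # Behaves like an already-dead weakref.
    lock = threading.Lock()

    def call(*args, **kwargs):
      nonlocal cached_wref
      with lock:
        fn = cached_wref()
        if fn is None:  # Never built, or the registry cleared the cache.
          fn = build()
          self._held.append(fn)
          cached_wref = weakref.ref(fn)
      return fn(*args, **kwargs)

    return call

Calling `CachingMaker.REGISTRY.clear_all()` empties every maker's strong references, so the weakrefs die and the next call rebuilds the function; this mirrors how `_JaxSecondLevelCaches` ties these caches to JAX's cache-clearing machinery.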
diff --git a/jax/experimental/colocated_python/obj.py b/jax/experimental/colocated_python/obj.py
index 2351acd0d096..b804cd836a12 100644
--- a/jax/experimental/colocated_python/obj.py
+++ b/jax/experimental/colocated_python/obj.py
@@ -15,14 +15,15 @@
from __future__ import annotations
+from collections.abc import Callable
import inspect
import random
import threading
from typing import Any
-from collections.abc import Callable
import jax
from jax._src import api_util
+from jax._src import config
from jax._src import tree_util
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
@@ -30,6 +31,20 @@
from jax.experimental.colocated_python import obj_backend
+# TODO(madthanu): Remove the following config option and make its behavior the
+# default, once the behavior has been declared stable.
+_USE_WEAKREFS = config.bool_state(
+ 'jax_experimental_colocated_python_object_use_weakrefs_at_backend',
+ False,
+ help=(
+ 'Unstable in-development feature that switches the colocated-python'
+ ' implementation to internally use reference counting for destructing'
+ ' objects at the colocated backend, instead of invoking an explicit'
+ ' delete-object function from the frontend.'
+ ),
+)
+
+
class _InstanceRegistry:
"""Registry of object instances."""
@@ -78,19 +93,50 @@ def _make_method(
init_kwargs: dict[str, Any],
method_name: str,
original_method: Callable[..., Any],
+ func_maker: func._CachedColocatedFunctionMaker,
+ use_weakrefs: bool,
):
- # Initializer to use when the object is not present in the backend.
- def initializer() -> object:
- return cls(*init_args, **init_kwargs)
- # Method to call on the backend.
- def method(*args, **kwargs):
- obj = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(uid, initializer)
- return getattr(obj, method_name)(*args, **kwargs)
+ class MethodCallerAtBackend:
+
+ def __init__(self):
+ self._lock = threading.Lock()
+
+ def __reduce__(self):
+ return type(self), ()
+
+ def _first_call(self):
+ # Temporarily hold a strong reference to a new object if it is created
+ # using initializer.
+ new_obj = None
+
+ def initializer():
+ nonlocal new_obj
+ new_obj = cls(*init_args, **init_kwargs)
+ if use_weakrefs:
+ import weakref
+
+ return weakref.ref(new_obj)
+ return new_obj
+
+ retrieved = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(
+ uid, initializer
+ )
+
+ if use_weakrefs:
+ self.obj = retrieved()
+ else:
+ self.obj = retrieved
+
+ def __call__(self, *args, **kwargs):
+ with self._lock:
+ if not hasattr(self, 'obj'):
+ self._first_call()
+ return getattr(self.obj, method_name)(*args, **kwargs)
# Colocated Python callable for the controller.
- callable = func.make_callable(
- method,
+ callable = func_maker.make_callable(
+ MethodCallerAtBackend(),
cls_sourceinfo,
api_util.fun_signature(original_method),
)
@@ -143,6 +189,8 @@ def __init__(self, *init_args, **init_kwargs) -> None:
uid = self._colocated_python_uid = (
SINGLETON_INSTANCE_REGISTRY.new_instance()
)
+ self.func_maker = func._CachedColocatedFunctionMaker(uid)
+ self.use_weakrefs = _USE_WEAKREFS.value
for attr_name in dir(cls):
original_member = getattr(cls, attr_name)
if not inspect.isfunction(original_member):
@@ -162,12 +210,17 @@ def __init__(self, *init_args, **init_kwargs) -> None:
init_kwargs,
attr_name,
original_member,
+ self.func_maker,
+ self.use_weakrefs,
)
# TODO(hyeontaek): Support method specialization similar to function
# specialization.
setattr(self, attr_name, method)
- def __del__(self) -> None:
+ def __del__(self):
+ del self.func_maker
+ if self.use_weakrefs:
+ return
uid = self._colocated_python_uid
devices = SINGLETON_INSTANCE_REGISTRY.pop_instance(uid)
if devices:
@@ -175,9 +228,6 @@ def __del__(self) -> None:
def remove_object() -> None:
obj_backend.SINGLETON_OBJECT_STORE.remove(uid)
- # TODO(hyeontaek): Request "best-effort" non-SPMD execution that tries
- # to run this function on any healthy processes instead of failing when
- # any process of the execution is unhealthy.
destructor = func.make_callable(
remove_object,
cls_sourceinfo,
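The `MethodCallerAtBackend` class added above defers construction of the backend object to the first method call and, when the weakref flag is on, keeps only a weak reference in the shared store. A self-contained sketch of that lazy, lock-guarded initialization; the store, factory, and class names below are stand-ins, not the JAX singletons:

import threading
import weakref

_STORE = {}  # Stand-in for a backend-side object store keyed by uid.


def get_or_create(uid, initializer):
  if uid not in _STORE:
    _STORE[uid] = initializer()
  return _STORE[uid]


class MethodCaller:

  def __init__(self, uid, factory, method_name, use_weakrefs=False):
    self._uid = uid
    self._factory = factory
    self._method_name = method_name
    self._use_weakrefs = use_weakrefs
    self._lock = threading.Lock()

  def _first_call(self):
    new_obj = None  # Pin the object while the store only holds a weakref.

    def initializer():
      nonlocal new_obj
      new_obj = self._factory()
      return weakref.ref(new_obj) if self._use_weakrefs else new_obj

    retrieved = get_or_create(self._uid, initializer)
    # The caller itself keeps the strong reference from here on.
    self.obj = retrieved() if self._use_weakrefs else retrieved

  def __call__(self, *args, **kwargs):
    with self._lock:
      if not hasattr(self, "obj"):
        self._first_call()
      return getattr(self.obj, self._method_name)(*args, **kwargs)


class Counter:
  def __init__(self):
    self.n = 0

  def bump(self):
    self.n += 1
    return self.n


bump = MethodCaller(uid=7, factory=Counter, method_name="bump")
assert bump() == 1 and bump() == 2  # The Counter is built on the first call.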
diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py
index 8b993c1c142a..8c820c5434fe 100644
--- a/jax/experimental/compilation_cache/compilation_cache.py
+++ b/jax/experimental/compilation_cache/compilation_cache.py
@@ -16,30 +16,3 @@
set_cache_dir as set_cache_dir,
reset_cache as reset_cache,
)
-
-_deprecations = {
- # Finalized for v0.8.0; remove in v0.9.0
- "is_initialized": (
- (
- "compilation_cache.is_initialized was deprecated in JAX v0.4.24 and"
- " removed in JAX v0.8.0."
- ),
- None,
- ),
- "initialize_cache": (
- (
- "compilation_cache.initialize_cache was deprecated in JAX v0.4.24 and"
- " removed in JAX v0.8.0. use compilation_cache.set_cache_dir instead."
- ),
- None,
- ),
-}
-
-import typing as _typing
-if _typing.TYPE_CHECKING:
- pass
-else:
- from jax._src.deprecations import deprecation_getattr
- __getattr__ = deprecation_getattr(__name__, _deprecations)
- del deprecation_getattr
-del _typing
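The block removed above routed finalized deprecations through a module-level `__getattr__`. For reference, a generic sketch of that PEP 562 pattern using only the standard library; this is the general mechanism, not JAX's internal `deprecation_getattr` helper:

import warnings

# Deprecated name -> (message, replacement); None means the name was removed.
_deprecations = {
    "old_name": ("old_name was removed; use new_name instead.", None),
}


def __getattr__(name):
  if name in _deprecations:
    message, replacement = _deprecations[name]
    if replacement is None:
      raise AttributeError(message)
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return replacement
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")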
diff --git a/jax/experimental/hijax.py b/jax/experimental/hijax.py
index 5e5bb0512c79..087569ae9234 100644
--- a/jax/experimental/hijax.py
+++ b/jax/experimental/hijax.py
@@ -36,4 +36,5 @@
)
from jax._src.state import (
AbstractRef as AbstractRef,
+ TransformedRef as TransformedRef
)
diff --git a/jax/experimental/jax2tf/examples/saved_model_main_test.py b/jax/experimental/jax2tf/examples/saved_model_main_test.py
index 28bce3d014e6..b515ced90ca5 100644
--- a/jax/experimental/jax2tf/examples/saved_model_main_test.py
+++ b/jax/experimental/jax2tf/examples/saved_model_main_test.py
@@ -48,9 +48,11 @@ def setUp(self):
def test_train_and_save_full(self,
model="mnist_flax",
serving_batch_size=-1):
+ self.skipTest("no more dynamic shapes")
if (serving_batch_size == -1 and
- config.jax2tf_default_native_serialization.value and
- not config.dynamic_shapes.value):
+ config.jax2tf_default_native_serialization.value
+ # and not config.dynamic_shapes.value
+ ):
self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")
FLAGS.model = model
FLAGS.model_classifier_layer = True
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py
index fe1ee305eb3a..80e44f8938c1 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_test.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -19,6 +19,7 @@
import os
import re
import unittest
+import warnings
from absl import logging
from absl.testing import absltest, parameterized
@@ -40,7 +41,12 @@
import numpy as np
try:
- import tensorflow as tf
+ # TODO(b/470156950): Remove this once a proper fix is in place
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore",
+ category=FutureWarning,
+ message=".*np.object.*")
+ import tensorflow as tf
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
JaxToTfTestCase = tf_test_util.JaxToTfTestCase
@@ -1229,8 +1235,9 @@ def assertAllOperationStartWith(self, g: "tf.Graph", scope_name: str):
self.fail(f"{op.name} does not start with {scope_name}.")
def test_name_scope_polymorphic(self):
- if not config.dynamic_shapes.value:
- self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")
+ self.skipTest("no more dynamic shapes")
+ # if not config.dynamic_shapes.value:
+ # self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.")
def func_jax(x, y):
return jnp.sin(x) + jnp.cos(y)
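Several test files in this change wrap the TensorFlow import in a warnings filter. The same pattern in isolation, assuming TensorFlow is installed; only the matching FutureWarning is suppressed, and the filter is restored when the `with` block exits:

import warnings

with warnings.catch_warnings():
  # Match only the FutureWarning whose message mentions np.object.
  warnings.filterwarnings(
      "ignore", category=FutureWarning, message=".*np.object.*")
  import tensorflow as tf  # noqa: F401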
diff --git a/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py b/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py
index 2576a277ab83..fa4861f55d92 100644
--- a/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py
+++ b/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py
@@ -22,9 +22,15 @@
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P
import unittest
+import warnings
try:
- import tensorflow as tf
+ # TODO(b/470156950): Remove this once a proper fix is in place
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore",
+ category=FutureWarning,
+ message=".*np.object.*")
+ import tensorflow as tf
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import tf_test_util
JaxToTfTestCase = tf_test_util.JaxToTfTestCase
diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py
index 0167c3c45ea8..f0bc0ffa78d5 100644
--- a/jax/experimental/jax2tf/tests/sharding_test.py
+++ b/jax/experimental/jax2tf/tests/sharding_test.py
@@ -24,6 +24,7 @@
import re
from typing import Any
import unittest
+import warnings
from absl import app
from absl.testing import absltest
@@ -46,7 +47,12 @@
import numpy as np
-import tensorflow as tf
+# TODO(b/470156950): Remove this once a proper fix is in place
+with warnings.catch_warnings():
+ warnings.filterwarnings("ignore",
+ category=FutureWarning,
+ message=".*np.object.*")
+ import tensorflow as tf
config.parse_flags_with_absl()
jtu.request_cpu_devices(8)
diff --git a/jax/experimental/mosaic/gpu/constraints.py b/jax/experimental/mosaic/gpu/constraints.py
index b1d1f3ec45a9..b6fda3d520fc 100644
--- a/jax/experimental/mosaic/gpu/constraints.py
+++ b/jax/experimental/mosaic/gpu/constraints.py
@@ -24,7 +24,7 @@
from collections.abc import Sequence
import dataclasses
import math
-from typing import Any, Callable, assert_never, final
+from typing import Any, assert_never, final
from . import fragmented_array as fa
from . import launch_context as lc
@@ -86,22 +86,6 @@ def __str__(self):
return f"C({self.value})"
-@dataclasses.dataclass(frozen=True)
-class LeastReplicated:
- expressions: tuple[Expression, ...]
-
- def __post_init__(self):
- assert len(self.expressions) >= 1
-
-
-@dataclasses.dataclass(frozen=True)
-class MostReplicated:
- expressions: tuple[Expression, ...]
-
- def __post_init__(self):
- assert len(self.expressions) >= 1
-
-
@dataclasses.dataclass(frozen=True)
class Reduce:
expression: Expression
@@ -136,8 +120,6 @@ def __str__(self):
Expression = (
Variable
| Constant
- | LeastReplicated
- | MostReplicated
| Reduce
| BroadcastInDim
| Reshape
@@ -145,62 +127,6 @@ def __str__(self):
)
-def reduce_replicated_expression(
- input_expr: LeastReplicated | MostReplicated,
- assignments: dict[Variable, Constant],
- reducer: Callable[[fa.FragmentedLayout, fa.FragmentedLayout], fa.FragmentedLayout | None]
-) -> Expression | Unsatisfiable:
- assert input_expr.expressions
-
- new_expressions: list[Expression] = []
- # Use a set to eliminate duplicates, but preserve the order.
- seen: set[Expression] = set()
- for expr in input_expr.expressions:
- reduced_expr = reduce_expression(expr, assignments)
- if isinstance(reduced_expr, Unsatisfiable):
- return Unsatisfiable()
- if reduced_expr in seen:
- continue
- new_expressions.append(reduced_expr)
- seen.add(reduced_expr)
-
- if len(new_expressions) == 1:
- return new_expressions[0]
-
- consts = []
- unknowns = []
- for e in new_expressions:
- if not isinstance(e, Constant):
- unknowns.append(e)
- continue
- if not isinstance(e, RegisterLayout):
- raise ValueError(
- f"Reduction of non-register layout constant is not supported: {e}"
- )
- consts.append(e)
-
- if consts:
- const_red, *consts = consts
- red = const_red
- for cst in consts:
- red_value = reducer(red.value, cst.value)
- if red_value is None:
- # The layouts are not compatible up to replication, this expression
- # cannot be simplified.
- return Unsatisfiable()
- red = RegisterLayout(red_value)
- else:
- red = None
-
- constructor = type(input_expr)
- if red is not None:
- if unknowns:
- return constructor((red, *unknowns))
- return red
-
- return constructor(tuple(unknowns))
-
-
def reduce_broadcast_expression(
broadcast: BroadcastInDim, assignments: dict[Variable, Constant]
) -> Expression | Unsatisfiable:
@@ -314,14 +240,6 @@ def reduce_expression(
return expr
case Variable():
return assignments.get(expr, expr)
- case MostReplicated():
- return reduce_replicated_expression(
- expr, assignments, layouts_lib.join_layouts
- )
- case LeastReplicated():
- return reduce_replicated_expression(
- expr, assignments, layouts_lib.meet_layouts
- )
case Reduce(expression=expr, axes=axes):
reduced_expr = reduce_expression(expr, assignments)
match reduced_expr:
@@ -340,6 +258,7 @@ def reduce_expression(
case _:
assert_never(expr)
+
@dataclasses.dataclass(frozen=True)
class Equals:
"""States that `lhs` and `rhs` are equal."""
@@ -471,25 +390,29 @@ def holds(self) -> bool | None:
Returns `None` if the constraint can't be checked.
"""
- source = self.source
- target = self.target
- if isinstance(source, TMEMLayout) and isinstance(target, RegisterLayout):
- return self._is_valid_tmem_transfer(source.value, target.value)
- if isinstance(target, TMEMLayout) and isinstance(source, RegisterLayout):
- return self._is_valid_tmem_transfer(target.value, source.value)
- if isinstance(source, TMEMLayout) and isinstance(target, TMEMLayout):
- return source == target
- if isinstance(source, SMEMTiling) and isinstance(target, RegisterLayout):
- return self._is_valid_smem_transfer(source.value, target.value)
- if isinstance(target, SMEMTiling) and isinstance(source, RegisterLayout):
- return self._is_valid_smem_transfer(target.value, source.value)
- if isinstance(target, Constant) and isinstance(source, Constant):
- source_type = type(source).__name__
- target_type = type(target).__name__
- raise NotImplementedError(f"Unsupported transfer: {source_type} -> {target_type}")
+ assert self.source != self.target, (
+ "IsTransferable constraints within the same memory space are not"
+ " supported."
+ )
- return None
+ match self.source, self.target:
+ case TMEMLayout(value=src), RegisterLayout(value=dst):
+ return self._is_valid_tmem_transfer(src, dst)
+ case RegisterLayout(value=src), TMEMLayout(value=dst):
+ return self._is_valid_tmem_transfer(dst, src)
+ case SMEMTiling(value=src), RegisterLayout(value=dst):
+ return self._is_valid_smem_transfer(src, dst)
+ case RegisterLayout(value=src), SMEMTiling(value=dst):
+ return self._is_valid_smem_transfer(dst, src)
+ case Constant(), Constant():
+ source_type = type(self.source).__name__
+ target_type = type(self.target).__name__
+ raise NotImplementedError(
+ f"Unsupported transfer: {source_type} -> {target_type}"
+ )
+ case _:
+ return None
def __str__(self):
return f"IsTransferable({self.source} ⟶ {self.target})"
@@ -568,10 +491,9 @@ def __str__(self):
def reduce_constraint(
constraint: Constraint, assignments: dict[Variable, Constant]
-) -> Constraint | Tautological | Unsatisfiable:
+) -> Constraint | Unsatisfiable:
"""Reduces a constraint."""
- new_constraint: Constraint
match constraint:
case Equals(lhs=lhs, rhs=rhs):
lhs_red = reduce_expression(lhs, assignments)
@@ -580,7 +502,7 @@ def reduce_constraint(
rhs_red = reduce_expression(rhs, assignments)
if isinstance(rhs_red, Unsatisfiable):
return Unsatisfiable()
- new_constraint = Equals(lhs_red, rhs_red)
+ return Equals(lhs_red, rhs_red)
case Relayout(source=source, target=target):
source_red = reduce_expression(source, assignments)
target_red = reduce_expression(target, assignments)
@@ -588,31 +510,26 @@ def reduce_constraint(
target_red, Unsatisfiable
):
return Unsatisfiable()
- new_constraint = Relayout(source_red, target_red)
+ return Relayout(source_red, target_red)
case NotOfType(expr=expr, type=type):
expr_red = reduce_expression(expr, assignments)
if isinstance(expr_red, Unsatisfiable):
return Unsatisfiable()
- new_constraint = NotOfType(expr_red, type)
+ return NotOfType(expr_red, type)
case IsTransferable(source=source, target=target, shape=shape):
source_red = reduce_expression(source, assignments)
target_red = reduce_expression(target, assignments)
if isinstance(source_red, Unsatisfiable) or isinstance(target_red, Unsatisfiable):
return Unsatisfiable()
- new_constraint = IsTransferable(source_red, target_red, shape)
+ return IsTransferable(source_red, target_red, shape)
case Divides(expr=expr, tiling_multiple=tiling_multiple):
expr_red = reduce_expression(expr, assignments)
if isinstance(expr_red, Unsatisfiable):
return Unsatisfiable()
- new_constraint = Divides(expr_red, tiling_multiple)
+ return Divides(expr_red, tiling_multiple)
case _ as never:
assert_never(never)
- constraint_holds = new_constraint.holds()
- if constraint_holds is None:
- return new_constraint
- return Tautological() if constraint_holds else Unsatisfiable()
-
@dataclasses.dataclass
class ConstraintSystem:
@@ -640,12 +557,6 @@ def extract_variables(expr: Expression) -> None:
free_variables.append(expr)
case Constant():
...
- case MostReplicated(expressions=expressions):
- for e in expressions:
- extract_variables(e)
- case LeastReplicated(expressions=expressions):
- for e in expressions:
- extract_variables(e)
case Reduce(expression=e):
extract_variables(e)
case BroadcastInDim(expression=e):
@@ -676,8 +587,10 @@ def extract_variables(expr: Expression) -> None:
return free_variables
def __and__(
- self, other: ConstraintSystem
+ self, other: ConstraintSystem | Unsatisfiable
) -> ConstraintSystem | Unsatisfiable:
+ if isinstance(other, Unsatisfiable):
+ return Unsatisfiable()
for variable, assignment in self.assignments.items():
if variable in other.assignments and assignment != other.assignments[variable]:
return Unsatisfiable()
@@ -704,10 +617,6 @@ def __and__(self, other: ConstraintSystem | Unsatisfiable) -> Unsatisfiable:
return self
-class Tautological:
- ...
-
-
def non_splat_variables(
constraints: Sequence[Constraint],
) -> set[Variable]:
@@ -916,11 +825,16 @@ def try_assign(var: Variable, cst: Constant) -> bool:
if not try_assign(var, cst):
return Unsatisfiable()
changed = True
- case Tautological():
- changed = True
case _ as new_constraint:
- changed |= new_constraint != constraint
- constraints.append(new_constraint)
+ assert isinstance(new_constraint, Constraint) # make pytype happy
+ match new_constraint.holds():
+ case None:
+ constraints.append(new_constraint)
+ changed |= new_constraint != constraint
+ case False:
+ return Unsatisfiable()
+ case True:
+ changed = True
new_constraints = merge_divides_constraints(constraints)
changed |= len(new_constraints) != len(constraints)
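With `Tautological` removed, the solver above asks each reduced constraint whether it `holds()`: `True` discharges it, `False` makes the system unsatisfiable, and `None` keeps it for the next round. A generic sketch of that tri-state propagation loop, with placeholder `reduce_fn` and `holds_fn` callables standing in for the real reduction and check:

def propagate(constraints, reduce_fn, holds_fn):
  """Reduces constraints to a fixed point, dropping the ones that hold."""
  changed = True
  while changed:
    changed = False
    remaining = []
    for c in constraints:
      new_c = reduce_fn(c)
      verdict = holds_fn(new_c)
      if verdict is False:
        return None  # The system is unsatisfiable.
      if verdict is True:
        changed = True  # Discharged; drop it.
        continue
      changed |= new_c != c
      remaining.append(new_c)
    constraints = remaining
  return constraints

# Toy domain: reducing decrements an int constraint; it holds at zero.
assert propagate([2, 3],
                 reduce_fn=lambda c: c - 1,
                 holds_fn=lambda c: True if c == 0 else None) == []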
diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index 0492bf0a13dd..1d4f46c1cf9f 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -606,32 +606,16 @@ def _broadcasted_iota_op_lowering_rule(
return [fragmented_array_to_ir(a, result_type)]
-@_register_lowering(vector.BroadcastOp)
-def _vector_splat_op_lowering_rule(
- _: LoweringContext, vector_splat_op: vector.BroadcastOp
-) -> Sequence[ir.Value]:
-
- out_vec_ty = ir.VectorType(vector_splat_op.aggregate.type)
- fragmented_array = fa.FragmentedArray.splat(
- vector_splat_op.input,
- tuple(out_vec_ty.shape),
- layouts.from_layout_attr(vector_splat_op.attributes["out_layouts"][0]),
- is_signed=_default_is_signed(out_vec_ty.element_type),
- )
- return [fragmented_array_to_ir(fragmented_array, out_vec_ty)]
-
-
@_register_lowering(vector.BroadcastOp)
def _vector_broadcast_op_lowering_rule(
- _: LoweringContext, vector_broadcast_op: vector.BroadcastOp
+ _: LoweringContext, op: vector.BroadcastOp
) -> Sequence[ir.Value]:
-
- out_vec_ty = ir.VectorType(vector_broadcast_op.vector.type)
+ out_vec_ty = ir.VectorType(op.vector.type)
fragmented_array = fa.FragmentedArray.splat(
- vector_broadcast_op.source,
+ op.source,
tuple(out_vec_ty.shape),
layouts.from_layout_attr(
- vector_broadcast_op.attributes["out_layouts"][0]
+ op.attributes["out_layouts"][0]
),
is_signed=_default_is_signed(out_vec_ty.element_type),
)
@@ -646,7 +630,9 @@ def _vector_shape_cast_op_lowering_rule(
out_vec_ty = ir.VectorType(op.result.type)
assert out_vec_ty.has_static_shape
a = _fragmented_array_from_ir(op.source, layout)
- return [fragmented_array_to_ir(a.reshape(out_vec_ty.shape), out_vec_ty)]
+ return [
+ fragmented_array_to_ir(a.reshape(tuple(out_vec_ty.shape)), out_vec_ty)
+ ]
@_register_lowering(vector.ExtractStridedSliceOp)
@@ -673,6 +659,33 @@ def _vector_extract_strided_slice_op_lowering_rule(
return [fragmented_array_to_ir(result, out_vec_ty)]
+@_register_lowering(vector.ExtractOp)
+def _vector_extract_op_lowering_rule(
+ ctx: LoweringContext, op: vector.ExtractOp
+) -> Sequence[ir.Value]:
+ del ctx
+ if op.dynamic_position:
+ raise NotImplementedError("Only slicing with static indices allowed.")
+
+ [in_layout] = inference_utils.in_layouts(op)
+ a = _fragmented_array_from_ir(op.source, in_layout)
+
+ if not ir.VectorType.isinstance(op.result.type): # scalar result
+ result = a[tuple(op.static_position)]
+ assert isinstance(result.layout, fa.WGSplatFragLayout)
+ return [result.registers.item()]
+
+ [out_layout] = inference_utils.out_layouts(op)
+ assert in_layout == out_layout
+ a = _fragmented_array_from_ir(op.source, in_layout)
+ result_type = ir.VectorType(op.result.type)
+ slices = tuple(slice(i, i + 1) for i in op.static_position)
+ # TODO(allanrenucci): Add direct support for indexing to FragmentedArray.
+ result = a[slices].reshape(tuple(result_type.shape))
+ assert result.layout == layouts.from_layout_attr(out_layout)
+ return [fragmented_array_to_ir(result, result_type)]
+
+
@_register_lowering(vector.ReductionOp)
def _vector_reduction_op_lowering_rule(
ctx: LoweringContext, op: vector.ReductionOp
@@ -991,6 +1004,12 @@ def _mgpu_async_store_op_lowering_rule(
# flatten -> async_copy -> unflatted here, as long as flattened size is a
# multiple of 16.
+ # TODO(b/415721295): Simplify after the minimal jaxlib version is 0.8.2.
+ if hasattr(mgpu, "TMAReduction") and store_op.reduction_op is not None:
+ reduction_op = mgpu.TMAReduction(store_op.reduction_op.value).name.lower()
+ else:
+ reduction_op = None
+
# TODO(dasenov): Add support for the remaining op properties.
ctx.launch_context.async_copy(
src_ref=unwrapped_source,
@@ -1000,6 +1019,7 @@ def _mgpu_async_store_op_lowering_rule(
gmem_transform=transforms,
predicate=ctx.single_thread_per_warpgroup_predicate,
arrive=store_op.commit_group,
+ reduction_op=reduction_op,
)
return []
@@ -1092,6 +1112,8 @@ def _unary_op_lowering_rule(
(mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None),
(mlir_math.ExpOp, fa.FragmentedArray.exp, None),
(mlir_math.Exp2Op, fa.FragmentedArray.exp2, None),
+ (mlir_math.SinOp, fa.FragmentedArray.sin, None),
+ (mlir_math.CosOp, fa.FragmentedArray.cos, None),
(mlir_math.LogOp, fa.FragmentedArray.log, None),
(mlir_math.TanhOp, fa.FragmentedArray.tanh, None),
]:
@@ -1332,7 +1354,7 @@ def _mgpu_arrive_expect_tx_op_lowering_rule(
barrier = utils.DialectBarrierRef.from_barrier_memref(
arrive_expect_tx_op.barrier
)
- nvvm.mbarrier_arrive_expect_tx(barrier.get_ptr(), bytes)
+ utils.nvvm_mbarrier_arrive_expect_tx(barrier.get_ptr(), bytes)
return []
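The new `vector.ExtractOp` lowering above handles the vector-result case as unit slices followed by a reshape, since `FragmentedArray` does not yet support direct indexing. The same shape manipulation in plain NumPy, purely as an analogy for the register array rather than the actual lowering:

import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
static_position = (1,)  # Extract along the leading dimension.

# Unit slices keep the indexed dimensions, giving shape (1, 3, 4) ...
slices = tuple(slice(i, i + 1) for i in static_position)
sliced = a[slices]

# ... and reshaping to the result shape drops the leading unit dimension,
# which is what the lowering does before converting back to IR.
result = sliced.reshape(3, 4)
assert np.array_equal(result, a[1])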
diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD
index 7edfe1c74db0..42d6c9555018 100644
--- a/jax/experimental/mosaic/gpu/examples/BUILD
+++ b/jax/experimental/mosaic/gpu/examples/BUILD
@@ -35,7 +35,7 @@ py_library(
srcs = ["matmul.py"],
deps = [
"//jax",
- "//jax:mosaic_gpu",
+ "//jax/experimental:mosaic_gpu",
],
)
@@ -44,7 +44,7 @@ py_library(
srcs = ["matmul_blackwell.py"],
deps = [
"//jax",
- "//jax:mosaic_gpu",
+ "//jax/experimental:mosaic_gpu",
],
)
@@ -53,7 +53,7 @@ py_library(
srcs = ["flash_attention.py"],
deps = [
"//jax",
- "//jax:mosaic_gpu",
+ "//jax/experimental:mosaic_gpu",
],
)
@@ -68,6 +68,6 @@ jax_multiplatform_test(
"notap",
],
deps = [
- "//jax:mosaic_gpu",
+ "//jax/experimental:mosaic_gpu",
] + py_deps("numpy"),
)
diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index 223580bb9582..7a4d21dbc45f 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -616,10 +616,14 @@ class WGSplatFragLayout:
def can_broadcast_to(self, shape) -> bool:
"""Check that the shape can be broadcast.
- Only dimensions of size 1 can be broadcast. All other dimensions
- must be the same as the argument shape.
+ Each source dimension must equal the corresponding trailing target
+ dimension or be 1 (i.e. we can broadcast 1-sized dimensions and create
+ new leading dimensions).
"""
- return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1]))
+ return len(self.shape) <= len(shape) and all(
+ dim1 == dim2 or dim1 == 1
+ for dim1, dim2 in zip(self.shape[::-1], shape[::-1])
+ )
def registers_element_type(self, t: ir.Type) -> ir.Type:
return t
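The broadcast rule restated in the `can_broadcast_to` docstring above is the usual trailing-dimension alignment; a standalone check with a few illustrative cases:

def can_broadcast_to(src_shape, dst_shape):
  # New leading dimensions may be created; existing ones must match or be 1.
  return len(src_shape) <= len(dst_shape) and all(
      s == d or s == 1
      for s, d in zip(src_shape[::-1], dst_shape[::-1])
  )

assert can_broadcast_to((1, 4), (3, 4))         # broadcast a 1-sized dim
assert can_broadcast_to((4,), (2, 3, 4))        # create new leading dims
assert not can_broadcast_to((3, 4), (4, 4))     # mismatched trailing dim
assert not can_broadcast_to((2, 3, 4), (3, 4))  # source rank is higher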
@@ -1301,9 +1305,8 @@ def to_layout(self, new_layout: FragmentedLayout) -> FragmentedArray:
raise NotImplementedError(
f"Cannot convert from {self.layout} to {new_layout}"
)
- [reg] = self.registers.flat
return type(self).splat(
- reg, self.shape, new_layout, is_signed=self.is_signed
+ self.registers.item(), self.shape, new_layout, is_signed=self.is_signed
)
def _pointwise(
@@ -1722,16 +1725,19 @@ def bitcast(
)
def __getitem__(self, idx) -> FragmentedArray:
+ base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
+ if isinstance(self.layout, WGSplatFragLayout):
+ shape = tuple(d for d, s in zip(slice_shape, is_squeezed) if not s)
+ return self.splat(self.registers.item(), shape, is_signed=self.is_signed)
if not isinstance(self.layout, TiledLayout):
raise NotImplementedError("Only arrays with tiled layouts can be sliced")
- base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
if any(isinstance(idx, ir.Value) for idx in base_idx):
raise ValueError("Only slicing with static indices allowed")
if any(is_squeezed):
raise NotImplementedError("Integer indexing not implemented (only slicing allowed)")
base_tile_shape = self.layout.base_tile_shape
- if len(base_tile_shape) != len(self.shape):
- raise NotImplementedError("Tiling has different rank than array")
+ if untiled_rank := len(self.shape) - len(base_tile_shape):
+ base_tile_shape = (1,) * untiled_rank + base_tile_shape
if any(b % t for b, t in zip(base_idx, base_tile_shape, strict=True)):
raise ValueError(
"Base indices of array slices must be aligned to the beginning of a"
@@ -2448,6 +2454,38 @@ def reduce(
)
def broadcast(self, shape) -> FragmentedArray:
+ if isinstance(self.layout, WGStridedFragLayout):
+ src_shape, dst_shape = self.layout.shape, shape
+ if len(src_shape) > len(dst_shape):
+ raise ValueError(
+ f"Shape length mismatch. Expected len({src_shape}) <= len({dst_shape})"
+ )
+ if not all(s == 1 or s == d for s, d in zip(src_shape[::-1], dst_shape[::-1])):
+ raise ValueError(
+ "Can broadcast if all source dimensions match trailing target"
+ " dimensions by being equal or set to 1. Broadcasting from"
+ f" {src_shape} to {dst_shape}"
+ )
+ rank_diff = len(dst_shape) - len(src_shape)
+ src_shape = tuple([1] * rank_diff + list(src_shape))
+
+ assert len(src_shape) == len(dst_shape), (src_shape, dst_shape)
+ len_suffix = next(
+ (i for i in range(len(src_shape)) if src_shape[~i] != dst_shape[~i]),
+ len(src_shape)
+ )
+ if len_suffix > 0 and all(x == 1 for x in src_shape[:-len_suffix]):
+ return FragmentedArray(
+ _registers=np.tile(self.registers, np.prod(dst_shape[:-len_suffix])),
+ _layout=WGStridedFragLayout(shape, self.layout.vec_size),
+ _is_signed=self.is_signed,
+ )
+
+ raise NotImplementedError(
+ "Only major-most broadcast for WGStridedFragLayout is implemented."
+ f" Broadcasting from: {src_shape}, to: {dst_shape}."
+ )
+
if not isinstance(self.layout, WGSplatFragLayout):
raise NotImplementedError(self.layout)
@@ -2463,7 +2501,7 @@ def broadcast(self, shape) -> FragmentedArray:
_is_signed=self.is_signed,
)
- def reshape(self, shape) -> FragmentedArray:
+ def reshape(self, shape: tuple[int, ...]) -> FragmentedArray:
if self.shape == shape:
return self
if math.prod(shape) != math.prod(self.shape):
@@ -2472,13 +2510,35 @@ def reshape(self, shape) -> FragmentedArray:
match self.layout:
case WGSplatFragLayout() | WGStridedFragLayout():
new_layout = dataclasses.replace(self.layout, shape=shape)
+ return FragmentedArray(
+ _registers=self.registers,
+ _layout=new_layout,
+ _is_signed=self.is_signed,
+ )
+ case TiledLayout():
+ base_tile_shape = self.layout.base_tile_shape
+ assert base_tile_shape
+ old_shape_suffix = self.shape[-len(base_tile_shape):]
+ new_shape_suffix = shape[-len(base_tile_shape):]
+ # We already know that old_shape_suffix[0] is divisible by
+ # base_tile_shape[0].
+ if (
+ old_shape_suffix[1:] != new_shape_suffix[1:]
+ or new_shape_suffix[0] % base_tile_shape[0]
+ ):
+ raise ValueError(
+ f"Can't reshape {self.shape} to {shape} with a tiled layout with"
+ f" base tile of {base_tile_shape}"
+ )
+ new_registers_shape = self.layout.registers_shape(shape)
+ return FragmentedArray(
+ _registers=self.registers.reshape(new_registers_shape),
+ _layout=self.layout,
+ _is_signed=self.is_signed,
+ )
case _:
raise NotImplementedError(self.layout)
- return FragmentedArray(
- _registers=self.registers, _layout=new_layout, _is_signed=self.is_signed
- )
-
def broadcast_minor(self, n) -> FragmentedArray:
if len(self.shape) != 1:
raise ValueError("Broadcast minor is only supported for 1D arrays")
@@ -2502,15 +2562,23 @@ def broadcast_in_dim(
f" {shape[target_dim]} in shape after broadcast"
)
if isinstance(self.layout, WGSplatFragLayout):
- if isinstance(layout, WGSplatFragLayout):
- if layout.shape != shape:
- raise ValueError(
- f"Layout shape {layout.shape} does not match broadcast shape {shape}"
- )
+ return type(self).splat(
+ self.registers.item(), shape, layout, is_signed=self.is_signed
+ )
+ if isinstance(self.layout, WGStridedFragLayout) and isinstance(layout, WGStridedFragLayout):
+ new_dims = set(range(len(shape))) - set(source_dimensions)
+ vec_match = self.layout.vec_size == layout.vec_size
+ broadcast_dim_match = new_dims == set(range(len(new_dims)))
+ assert layout.shape == shape, (layout.shape, shape)
+ if vec_match and broadcast_dim_match:
return FragmentedArray(
- _registers=self.registers, _layout=layout, _is_signed=self.is_signed,
+ _registers=np.tile(
+ self.registers,
+ np.prod(shape[:len(new_dims)]),
+ ),
+ _layout=layout,
+ _is_signed=self.is_signed,
)
- # TODO: Support splat to other layouts
if not isinstance(self.layout, TiledLayout) or not isinstance(layout, TiledLayout):
raise NotImplementedError(self.layout, layout)
if any(d1 >= d2 for d1, d2 in zip(source_dimensions, source_dimensions[1:])):
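The `TiledLayout` branch added to `reshape` above only lets the major tiled dimension change, and only in multiples of the base tile; the minor tiled dimensions must be left untouched. A shape-level sketch of that admissibility check, with made-up shapes and no Mosaic types:

import math


def tiled_reshape_ok(old_shape, new_shape, base_tile_shape):
  """Mirrors the shape conditions the TiledLayout reshape checks."""
  if math.prod(old_shape) != math.prod(new_shape):
    return False
  k = len(base_tile_shape)
  old_suffix, new_suffix = old_shape[-k:], new_shape[-k:]
  # Minor tiled dims must be unchanged; the major tiled dim must remain a
  # multiple of the base tile.
  return (old_suffix[1:] == new_suffix[1:]
          and new_suffix[0] % base_tile_shape[0] == 0)


assert tiled_reshape_ok((4, 128, 64), (2, 256, 64), base_tile_shape=(64, 8))
assert not tiled_reshape_ok((4, 128, 64), (4, 64, 128), base_tile_shape=(64, 8))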
diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py
index 0ab01a6de37f..ff8d94814c22 100644
--- a/jax/experimental/mosaic/gpu/launch_context.py
+++ b/jax/experimental/mosaic/gpu/launch_context.py
@@ -39,7 +39,24 @@
TMA_DESCRIPTOR_BYTES = 128
TMA_DESCRIPTOR_ALIGNMENT = 64
-TMAReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"]
+TMAReductionOp = Literal[
+ "add",
+ "min",
+ "max",
+ "inc",
+ "dec",
+ "and",
+ "or",
+ "xor",
+ "umin",
+ "umax",
+ "smin",
+ "smax",
+]
+
+def _reduction_op_to_ptx(reduction_op: TMAReductionOp) -> str:
+ # Convert [s|u]min and [s|u]max to plain min/max for PTX.
+ return reduction_op[-3:]
c = utils.c # This is too common to fully qualify.
@@ -426,6 +443,81 @@ def _find_kernel_argument_for_gmem_ref(
return gmem_ref
+def _is_tma_reduction_op_supported(
+ reduction_op: TMAReductionOp | None, dtype: ir.Type,
+) -> bool:
+ """Returns whether the given TMA reduction op supports the given dtype.
+
+ This function essentially implements the table at:
+ https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor
+ with the following differences:
+ - For `add` reductions, we also support int64, treating it as uint64.
+ - For `and`, `or`, and `xor` reductions, we support signed integer types.
+ - For `inc` and `dec` reductions, we support both signed and unsigned i32,
+ treating both as unsigned.
+ """
+ i32 = ir.IntegerType.get_signless(32)
+ i64 = ir.IntegerType.get_signless(64)
+ f16 = ir.F16Type.get()
+ f32 = ir.F32Type.get()
+ bf16 = ir.BF16Type.get()
+
+ match reduction_op:
+ case None:
+ return True
+ case "add":
+ return dtype in (f16, f32, bf16, i32, i64)
+ case "max" | "min":
+ return dtype in (f16, bf16)
+ case "umax" | "umin" | "smax" | "smin":
+ return dtype in (i32, i64)
+ case "inc" | "dec":
+ return dtype == i32
+ case "and" | "or" | "xor":
+ return dtype in (i32, i64)
+
+
+def _tma_dma_type(
+ element_type: ir.Type,
+ reduction_op: TMAReductionOp | None,
+) -> int:
+ """Returns the TMA DMA type for the given element type and signedness."""
+ if ir.IntegerType.isinstance(element_type):
+ bitwidth = utils.bitwidth_impl(element_type)
+ if bitwidth == 2:
+ tma_dtype = 8
+ elif bitwidth == 4:
+ tma_dtype = 0
+ elif bitwidth == 8:
+ tma_dtype = 1
+ elif bitwidth == 16:
+ tma_dtype = 2
+ elif bitwidth == 32:
+ tma_dtype = 9 if reduction_op in ("smin", "smax") else 3
+ elif bitwidth == 64:
+ tma_dtype = 10 if reduction_op in ("smin", "smax") else 4
+ else:
+ raise ValueError(f"Unsupported integer bitwidth: {bitwidth}")
+ elif ir.F16Type.isinstance(element_type):
+ tma_dtype = 5
+ elif ir.F32Type.isinstance(element_type):
+ tma_dtype = 6
+ elif ir.BF16Type.isinstance(element_type):
+ tma_dtype = 7
+ # We treat narrow floats as integers
+ elif ir.Float8E5M2Type.isinstance(element_type):
+ tma_dtype = 1
+ elif ir.Float8E4M3FNType.isinstance(element_type):
+ tma_dtype = 1
+ elif ir.Float8E8M0FNUType.isinstance(element_type):
+ tma_dtype = 1
+ elif ir.Float4E2M1FNType.isinstance(element_type):
+ tma_dtype = 0
+ else:
+ raise ValueError(f"unsupported TMA dtype {element_type}")
+ return tma_dtype
+
+
class AsyncCopyImplementation(enum.Enum):
TMA = enum.auto()
CP_ASYNC = enum.auto()
@@ -438,7 +530,7 @@ class LaunchContext:
cluster_size: tuple[int, int, int]
profiler: OnDeviceProfiler | None = None
tma_descriptors: dict[
- tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any],
+ tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any, int],
ir.Value,
] = dataclasses.field(default_factory=dict, init=False)
is_device_collective: bool = False
@@ -512,10 +604,11 @@ def _get_tma_desc(
reduction_op: TMAReductionOp | None,
):
gmem_ref = _find_kernel_argument_for_gmem_ref(gmem_ref)
+ tma_dtype = _tma_dma_type(ir.MemRefType(gmem_ref.type).element_type, reduction_op)
# Using ir.Values in cache keys is a little sketchy, but I think it should
# be fine. Having it in the key will keep it alive, and if comparison and
# hashing is by identity then it should work out.
- tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id)
+ tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id, tma_dtype)
if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
@@ -580,43 +673,6 @@ def init_tma_desc(host_ptr):
)
# TODO(apaszke): Better verification (e.g. slice is non-zero)
# TODO(apaszke): We always know strides statically.
- if isinstance(ref_ty.element_type, ir.IntegerType):
- if reduction_op is not None:
- raise ValueError(
- f"TMA with reduction_op={reduction_op} is not supported with Integers"
- )
- bitwidth = utils.bitwidth_impl(ref_ty.element_type)
- if bitwidth == 2:
- tma_dtype = 8
- elif bitwidth == 4:
- tma_dtype = 0
- elif bitwidth == 8:
- tma_dtype = 1
- elif bitwidth == 16:
- tma_dtype = 2
- elif bitwidth == 32:
- tma_dtype = 3
- elif bitwidth == 64:
- tma_dtype = 4
- else:
- raise ValueError(f"Unsupported integer bitwidth: {bitwidth}")
- elif ir.F16Type.isinstance(ref_ty.element_type):
- tma_dtype = 5
- elif ir.F32Type.isinstance(ref_ty.element_type):
- tma_dtype = 6
- elif ir.BF16Type.isinstance(ref_ty.element_type):
- tma_dtype = 7
- # We treat narrow floats as integers
- elif ir.Float8E5M2Type.isinstance(ref_ty.element_type):
- tma_dtype = 1
- elif ir.Float8E4M3FNType.isinstance(ref_ty.element_type):
- tma_dtype = 1
- elif ir.Float8E8M0FNUType.isinstance(ref_ty.element_type):
- tma_dtype = 1
- elif ir.Float4E2M1FNType.isinstance(ref_ty.element_type):
- tma_dtype = 0
- else:
- raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}")
dtype_or_bitwidth = c(tma_dtype, i64)
args = [
host_ptr,
@@ -862,7 +918,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
if max(slice_shape) > 256:
raise ValueError(
"Async copies only support copying <=256 elements along each"
- " dimension"
+ f" dimension, got {tuple(slice_shape)}"
)
if (zeroth_bw := slice_shape[-1] * element_bitwidth) % 128 != 0:
raise ValueError(
@@ -953,16 +1009,10 @@ def async_copy(
if reduction_op is not None:
if implementation != AsyncCopyImplementation.TMA:
raise ValueError("Only the TMA implementation supports reductions")
- if not any(
- t.isinstance(element_type)
- for t in (ir.F32Type, ir.BF16Type, ir.F16Type)
- ):
- raise ValueError(
- "TMA with reduction is only supported with f32, f16 and bf16"
- )
- if reduction_op != "add":
+ if not _is_tma_reduction_op_supported(reduction_op, element_type):
raise ValueError(
- "TMA with reduction is only supported with add operation"
+ f"Reduction op {reduction_op} not supported by the TMA"
+ f" implementation for element type {element_type}"
)
if src_ref_ty.memory_space is None and utils.is_smem_ref(dst_ref_ty):
@@ -1019,7 +1069,7 @@ def async_copy(
raise ValueError(
"Expected the SMEM reference to have the same shape as the"
f" transformed slice: {tuple(smem_ref_ty.shape)} !="
- f" {slice_shape[len(squeezed_dims):]}"
+ f" {tuple(slice_shape[len(squeezed_dims):])}"
)
if implementation == AsyncCopyImplementation.CP_ASYNC:
@@ -1157,7 +1207,7 @@ def async_copy(
if arrive:
arrive_predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP)
- nvvm.mbarrier_arrive_expect_tx(
+ utils.nvvm_mbarrier_arrive_expect_tx(
barrier_ptr,
transfer_bytes,
predicate=arrive_predicate,
@@ -1288,7 +1338,7 @@ def async_copy(
arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index),
)
arrive_predicate = arith.andi(predicate, first_block)
- nvvm.mbarrier_arrive_expect_tx(
+ utils.nvvm_mbarrier_arrive_expect_tx(
barrier_ptr, transfer_bytes, predicate=arrive_predicate
)
rank = len(slice_shape)
@@ -1309,7 +1359,7 @@ def async_copy(
)
else:
if arrive:
- nvvm.mbarrier_arrive_expect_tx(
+ utils.nvvm_mbarrier_arrive_expect_tx(
barrier_ptr, transfer_bytes, predicate=predicate
)
if collective_size > 1:
@@ -1329,7 +1379,7 @@ def async_copy(
llvm.inline_asm(
ir.Type.parse("!llvm.void"),
[predicate,smem_ptr,tma_desc,*rev_dyn_base_indices],
- f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];",
+ f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{_reduction_op_to_ptx(reduction_op)}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];",
"b,r,l" + ",r" * rank,
has_side_effects=True,
)
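The refactor above moves the TMA descriptor dtype codes into `_tma_dma_type` and routes signed integer min/max reductions to dedicated codes. As a rough reference, the integer portion of that mapping as plain data; the codes are copied from the function above, while the helper name is illustrative:

# Default TMA dtype codes per integer bitwidth, as used by _tma_dma_type.
_INT_TMA_CODES = {2: 8, 4: 0, 8: 1, 16: 2, 32: 3, 64: 4}
# Signed min/max reductions switch 32/64-bit integers to signed codes.
_SIGNED_MINMAX_TMA_CODES = {32: 9, 64: 10}


def int_tma_code(bitwidth: int, reduction_op: str | None = None) -> int:
  if reduction_op in ("smin", "smax"):
    return _SIGNED_MINMAX_TMA_CODES[bitwidth]
  return _INT_TMA_CODES[bitwidth]


assert int_tma_code(32) == 3 and int_tma_code(32, "smax") == 9
assert int_tma_code(64, "smin") == 10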
diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py
index 5a2eaa97178b..736a6b3cfd9b 100644
--- a/jax/experimental/mosaic/gpu/layout_inference.py
+++ b/jax/experimental/mosaic/gpu/layout_inference.py
@@ -82,6 +82,7 @@ class MemorySpace(enum.Enum):
_op_name_regex = re.compile(r"^(%\d+ = )?\S+")
+
@dataclasses.dataclass(frozen=True)
class ValueSite:
"""A unique identifier for a variable.
@@ -114,6 +115,11 @@ def value(self) -> ir.Value:
else:
return self.operation.regions[self.region_index].blocks[0].arguments[self.index]
+ @property
+ def shape(self) -> tuple[int, ...]:
+ """Returns the shape of the underlying value."""
+ return tuple(self.value.type.shape) # pytype: disable=attribute-error
+
@property
def memory_space(self) -> MemorySpace:
"""Returns the memory space associated with this value."""
@@ -138,52 +144,13 @@ def __str__(self):
return f"{match.group(0)}:a-{self.index}"
-@dataclasses.dataclass(frozen=True)
-class Hint:
- """Hints are used to model propagation of layouts across operations.
-
- Since using `relayout`s is always an option in principle, propagation across
- ops can not rely only on a constraint system. Instead, we introduce hints as
- a form of "soft constraints", i.e., it suggests that `variable` should be
- equal to `expression`.
- """
- variable: cs.Variable
- expression: cs.Expression
-
- def __str__(self):
- return f"{self.variable} ?= {self.expression}"
-
-
-def extract_constant_from_replicated_expression_for_hint(
- expression: cs.LeastReplicated | cs.MostReplicated,
-) -> cs.Constant | None:
- assert len(expression.expressions) >= 1
- choices: list[cs.Constant] = []
- for e in expression.expressions:
- if (red := extract_constant_for_hint(e)) is not None:
- choices.append(red)
-
- if not choices:
- return None
-
- # We reduce the expression here in order to recover an unambiguous
- # replicated layout if it exists.
- maybe_choice = cs.reduce_expression(type(expression)(tuple(choices)), {})
-
- if isinstance(maybe_choice, cs.Unsatisfiable):
- # TODO(bchetioui): consider other choices.
- return choices[0]
-
- assert isinstance(maybe_choice, cs.Constant)
- return maybe_choice
-
-
-def extract_constant_from_broadcast_in_dim_expression_for_hint(
- e: cs.BroadcastInDim,
-) -> cs.RegisterLayout | None:
- if not isinstance(e.expression, cs.RegisterLayout):
- return None
-
+def extract_assignment_candidates_from_reduce_equation(
+ small: cs.RegisterLayout,
+ large: cs.Variable,
+ reduction_dims: tuple[int, ...]
+) -> Iterator[cs.RegisterLayout]:
+ """Yields layout candidates for the reduce equation `small = reduce(large, reduction_dims)."""
+ large_shape = large.key.value.type.shape # pytype: disable=attribute-error
candidates = [
fa.WGMMA_LAYOUT,
fa.WGMMA_TRANSPOSED_LAYOUT,
@@ -191,66 +158,14 @@ def extract_constant_from_broadcast_in_dim_expression_for_hint(
fa.TCGEN05_TRANSPOSED_LAYOUT,
tcgen05.TMEM_NATIVE_LAYOUT,
]
- if e.shape[-1] % 16 == 0:
- candidates.append(tcgen05.fa_m64_collective_layout(e.shape[-1]))
+ if large_shape[-1] % 16 == 0:
+ candidates.append(tcgen05.fa_m64_collective_layout(large_shape[-1]))
- # TODO(allanrenucci): Allow returning multiple valid candidates.
- reduction_dims = tuple(d for d in range(len(e.shape)) if d not in e.axes)
for candidate in candidates:
- if len(candidate.base_tile_shape) > len(e.shape):
+ if len(candidate.base_tile_shape) > len(large_shape):
continue
- if candidate.reduce(reduction_dims) == e.expression.value:
- return cs.RegisterLayout(candidate)
- return None
-
-
-def extract_constant_for_hint(e: cs.Expression) -> cs.Constant | None:
- """Attempts to extract a `ConstantExpression` from a `Hint`'s `Expression`.
-
- Returns `None` if no `ConstantExpression` could be reasonably extracted.
- """
- match e:
- case cs.Constant():
- return e
- case cs.LeastReplicated() | cs.MostReplicated():
- return extract_constant_from_replicated_expression_for_hint(e)
- case cs.BroadcastInDim():
- return extract_constant_from_broadcast_in_dim_expression_for_hint(e)
- case cs.Variable():
- return None
- case _:
- raise NotImplementedError(f"Unsupported expression type: {type(e)}")
-
-
-def extract_variable_assignment_from_hint(
- hint: Hint,
-) -> tuple[cs.Variable, cs.Constant] | None:
- """Attempts to extract a single variable assignment from a `Hint`."""
- # TODO(bchetioui): make this a generator. This will allow us to maybe extract
- # different assignments that satisfy a replication constraint in the case
- # where replicated expressions are incompatible and several extractions are
- # possible.
- red = extract_constant_for_hint(hint.expression)
- return (hint.variable, red) if red is not None else None
-
-
-def reduce_hints(
- hints: Sequence[Hint], assignments: dict[cs.Variable, cs.Constant]
-) -> list[Hint]:
- """Reduces a sequence of `Hint`s.
-
- We reduce the `Hint`s' expressions, drop `Unsatisfiable` hints, and drop
- `Hint`s pertaining to pre-existing assignments.
- """
- new_hints: list[Hint] = []
- for h in hints:
- if h.variable not in assignments:
- reduced_expression = cs.reduce_expression(h.expression, assignments)
- if isinstance(reduced_expression, cs.Unsatisfiable):
- continue
- new_hints.append(dataclasses.replace(h, expression=reduced_expression))
-
- return new_hints
+ if candidate.reduce(reduction_dims) == small.value:
+ yield cs.RegisterLayout(candidate)
def _strided_layout_for_variable(
@@ -267,6 +182,19 @@ def _strided_layout_for_variable(
return fa.WGStridedFragLayout.from_shaped_type(type)
+def _default_tmem_layout_for_variable(
+ variable: cs.Variable,
+) -> tcgen05.TMEMLayout | None:
+ """Returns a default TMEM layout for the given variable, if one is defined."""
+ value = variable.key.value
+ parent = value.owner
+ if isinstance(parent, mgpu.TmemAllocOp):
+ return tcgen05._infer_tmem_layout(
+ tuple(value.type.shape), parent.collective, packing=1
+ )
+ return None
+
+
def _extract_tiling_candidate(
divide_constraint: cs.Divides, num_tiled_dims: int
) -> Iterator[tuple[cs.Variable, cs.Constant]]:
@@ -357,58 +285,73 @@ def _extract_variable_assignments_from_constraints(
match c:
case cs.IsTransferable():
yield from _extract_layout_candidates_from_memory_space_transfer(c, dpv)
+ case cs.Equals(cs.Reduce(cs.Variable() as large, axes=axes), cs.RegisterLayout() as small):
+ for layout in extract_assignment_candidates_from_reduce_equation(small, large, axes):
+ yield large, layout
+ case cs.Equals(cs.RegisterLayout() as small, cs.Reduce(cs.Variable() as large, axes=axes)):
+ for layout in extract_assignment_candidates_from_reduce_equation(small, large, axes):
+ yield large, layout
+ case cs.Relayout(cs.Variable() as var, cs.RegisterLayout() as layout):
+ yield var, layout
+ case cs.Relayout(cs.RegisterLayout() as layout, cs.Variable() as var):
+ yield var, layout
def conjure_assignment(
unknowns: Sequence[cs.Variable],
constraint_system: cs.ConstraintSystem,
- hints: Sequence[Hint],
) -> Iterator[tuple[cs.Variable, cs.Constant]]:
"""Attempts to conjure an assignment for an unknown variable."""
# TODO(allanrenucci): We should be able to short-circuit the search here if
# the constraint is not satisfiable.
- yield from _extract_variable_assignments_from_constraints(
- constraint_system.constraints
- )
- def assignment_order(
- assignment: tuple[cs.Variable, cs.Constant],
- ) -> int:
- match assignment:
- # Try TiledLayout first, before other hints, because TiledLayout` are
- # usually more useful to propagate than `WGSplat`. Also this often
- # improves the performance of the layout inference.
- case (_, cs.RegisterLayout(fa.TiledLayout())):
- return 0
+ # As we extract assignment candidates from constraints, we prioritize
+ # candidates that are more "interesting"; e.g., in the case of registers,
+ # introducing splat layout candidate assignments often leads to a dead end in
+ # practice---as opposed to tiled layouts, which are more likely to yield
+ # solutions to the constraint system.
+ low_priority_assignments: list[tuple[cs.Variable, cs.Constant]] = []
+ for variable, constant in _extract_variable_assignments_from_constraints(
+ constraint_system.constraints
+ ):
+ match constant:
+ case cs.RegisterLayout(value=value) if not isinstance(value, fa.TiledLayout):
+ low_priority_assignments.append((variable, constant))
case _:
- return 1
+ yield variable, constant
- assignments = [extract_variable_assignment_from_hint(h) for h in hints]
- assignments = [a for a in assignments if a is not None]
- assignments = sorted(assignments, key=assignment_order)
- yield from assignments
+ # After all high-priority assignments have been attempted, switch to using
+ # low-priority assignments.
+ for variable, constant in low_priority_assignments:
+ yield variable, constant
# Here, we have not managed to find an assignment for all the unknown
- # variables, and our hints have not proven sufficient to unblock us. We now
- # try to introduce new arbitrary (valid) assignments into the system, and
- # hope that they turn out to be compatible with the constraint system.
+ # variables. We now try to introduce new arbitrary (valid) assignments into
+ # the system, and hope that they turn out to be compatible with the constraint
+ # system.
for variable in unknowns:
if variable in constraint_system.assignments:
continue
- # Try to instantiate a single variable to a strided layout and see if it
+ # Try to instantiate a single variable to a default layout and see if it
# reduces the system.
- if variable.key.memory_space == MemorySpace.REG:
- layout = _strided_layout_for_variable(variable)
- if layout is not None:
- yield variable, cs.RegisterLayout(layout)
- elif variable.key.memory_space == MemorySpace.SMEM:
- yield variable, cs.SMEMTiling(None)
+ match variable.key.memory_space:
+ case MemorySpace.REG:
+ layout = _strided_layout_for_variable(variable)
+ if layout is not None:
+ yield variable, cs.RegisterLayout(layout)
+ case MemorySpace.SMEM:
+ yield variable, cs.SMEMTiling(None)
+ case MemorySpace.TMEM:
+ layout = _default_tmem_layout_for_variable(variable)
+ if layout is not None:
+ yield variable, cs.TMEMLayout(layout)
+ case _:
+ raise ValueError(f"Unsupported memory space: {variable.key.memory_space}")
def find_assignments_for(
unknowns: Sequence[cs.Variable],
constraint_system: cs.ConstraintSystem,
- hints: Sequence[Hint],
*,
fuel: int,
) -> tuple[dict[cs.Variable, cs.Constant] | cs.Unsatisfiable, int]:
@@ -418,7 +361,6 @@ def find_assignments_for(
unknowns: the set of variables that are unknown. Represented as a sequence
of `Variable`s for determinism purposes.
constraint_system: the constraint system to satisfy.
- hints: a list of hints that may be used to introduce new assignments.
fuel: the fuel to use for the search. Once the fuel is exhausted, we raise
an error.
@@ -449,17 +391,12 @@ def find_assignments_for(
v: k for v, k in constraint_system.assignments.items() if v in unknowns
}, fuel
- # Reduce the expressions in the remaining hints based on the current
- # assignments, and eliminate hints that pertain to variables that already
- # have an assignment.
- hints = reduce_hints(hints, constraint_system.assignments)
-
# If unknowns remain and we have fully reduced the system, we may still
- # be able to make progress by extracting an assignment from a `Hint`. This
- # new assignment could make the system unsatisfiable, so we use a recursive
+ # be able to make progress by trying out potential assignments. These
+ # new assignments could make the system unsatisfiable, so we use a recursive
# call to be able to backtrack if necessary.
for assignment in conjure_assignment(
- remaining_unknowns, constraint_system, hints
+ remaining_unknowns, constraint_system
):
if fuel <= 0:
raise ValueError(
@@ -476,7 +413,7 @@ def find_assignments_for(
# This assignment is not compatible with the constraint system.
continue
solution, fuel = find_assignments_for(
- unknowns, new_constraint_system, hints, fuel=fuel
+ unknowns, new_constraint_system, fuel=fuel
)
if not isinstance(solution, cs.Unsatisfiable):
return solution, fuel
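`find_assignments_for` keeps its fuel-bounded backtracking structure; candidates now come straight from the constraints rather than from hints, but each one is still applied tentatively and rolled back if the reduced system becomes unsatisfiable. A generic, self-contained sketch of that search shape over a toy domain (names and values are illustrative):

def find_assignment(unknowns, constraints, domain, assignments, fuel):
  """Fuel-bounded backtracking: try a candidate, recurse, back off on failure."""
  if all(v in assignments for v in unknowns):
    return assignments, fuel
  var = next(v for v in unknowns if v not in assignments)
  for value in domain:
    if fuel <= 0:
      raise ValueError("Ran out of fuel while searching for assignments.")
    fuel -= 1
    candidate = {**assignments, var: value}
    if not all(ok(candidate) for ok in constraints):
      continue  # Incompatible with the system; try the next candidate.
    solution, fuel = find_assignment(
        unknowns, constraints, domain, candidate, fuel)
    if solution is not None:
      return solution, fuel
  return None, fuel


# Toy system: x and y must differ, and y (once assigned) must be "tiled".
constraints = [
    lambda a: a.get("x") != a.get("y"),
    lambda a: a.get("y", "tiled") == "tiled",
]
solution, _ = find_assignment(
    ["x", "y"], constraints, ["splat", "tiled"], {}, fuel=16)
assert solution == {"x": "splat", "y": "tiled"}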
@@ -516,8 +453,8 @@ def producer_ref(self, operand: ValueSite) -> cs.Variable:
ValueSitesForVariable = dict[cs.Variable, list[ValueSite]]
# A constraint system derivation rule is a function that takes an MLIR operation
-# and returns a constraint system, a mapping from variables to value site
-# identifiers, and a list of hints.
+# and returns a constraint system and a mapping from variables to value site
+# identifiers.
#
# The intended meaning of the mapping is that, for each identifier in the list
# keyed by a given variable, the MLIR operand/result/argument corresponding to
@@ -528,9 +465,12 @@ def producer_ref(self, operand: ValueSite) -> cs.Variable:
# and each identifier in the mapping must be keyed by exactly one variable.
# Lastly, the mapping must only refer to variables and
# operands/results/arguments that correspond to the given operation.
+ConstraintSystemDerivationRuleResult = cs.Unsatisfiable | tuple[
+ cs.ConstraintSystem, ValueSitesForVariable
+]
ConstraintSystemDerivationRule = Callable[
[DerivationContext, ir.OpView],
- tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]],
+ ConstraintSystemDerivationRuleResult,
]
_constraint_system_derivation_rules: dict[
str, ConstraintSystemDerivationRule
@@ -561,11 +501,11 @@ def _is_tmem_ref(v: ir.Value) -> bool:
def _pointwise_op_constraint_system(
ctx: DerivationContext,
op: ir.OpView,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
all_value_sites = vector_value_sites(op)
variable = cs.Variable(all_value_sites[-1])
- return cs.ConstraintSystem(), {variable: all_value_sites}, []
+ return cs.ConstraintSystem(), {variable: all_value_sites}
for op in [
@@ -604,6 +544,8 @@ def _pointwise_op_constraint_system(
arith.XOrIOp,
mlir_math.ExpOp,
mlir_math.Exp2Op,
+ mlir_math.SinOp,
+ mlir_math.CosOp,
mlir_math.LogOp,
mlir_math.RsqrtOp,
mlir_math.TanhOp,
@@ -615,7 +557,7 @@ def _pointwise_op_constraint_system(
def _vector_load_constraint_system(
ctx: DerivationContext,
op: mgpu.VectorLoadOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
# TODO(b/447079781): Investigate whether we should check for contiguous
# strides here. An initial implementation of this failed the
# test_gmem_to_smem_with_multiple_smem_indexers_and_transforms test, but
@@ -636,14 +578,14 @@ def _vector_load_constraint_system(
constraints.append(cs.IsTransferable(source_var, dest_var, shape))
system = cs.ConstraintSystem(constraints=constraints)
- return system, value_sites_for_variable, []
+ return system, value_sites_for_variable
@_add_constraint_system_derivation_rule(mgpu.VectorStoreOp)
def _vector_store_constraint_system(
ctx: DerivationContext,
op: mgpu.VectorStoreOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
# TODO(b/447079781): Investigate whether we should check for contiguous
  # strides here. An initial implementation of this failed the
# test_gmem_to_smem_with_multiple_smem_indexers_and_transforms test, but
@@ -664,46 +606,46 @@ def _vector_store_constraint_system(
constraints.append(cs.IsTransferable(value_var, dest_var, shape))
system = cs.ConstraintSystem(constraints=constraints)
- return system, value_sites_for_variable, []
+ return system, value_sites_for_variable
@_add_constraint_system_derivation_rule(mgpu.DebugPrintOp)
def _debug_print_constraint_system(
ctx: DerivationContext,
op: mgpu.DebugPrintOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
value = ValueSite(op, VariableType.OPERAND, 0)
- return cs.ConstraintSystem(), {cs.Variable(value): [value]}, []
+ return cs.ConstraintSystem(), {cs.Variable(value): [value]}
@_add_constraint_system_derivation_rule(mgpu.PrintLayoutOp)
def _print_layout_constraint_system(
ctx: DerivationContext,
op: mgpu.PrintLayoutOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
value = ValueSite(op, VariableType.OPERAND, 0)
var = cs.Variable(value) if is_vector(op.value) else ctx.producer_ref(value)
- return cs.ConstraintSystem(), {var: [value]}, []
+ return cs.ConstraintSystem(), {var: [value]}
@_add_constraint_system_derivation_rule(mgpu.BroadcastedIotaOp)
def _broadcasted_iota_constraint_system(
ctx: DerivationContext,
op: mgpu.BroadcastedIotaOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
value = ValueSite(op, VariableType.RESULT, 0)
var = cs.Variable(value)
constraints = [cs.NotOfType(var, fa.WGSplatFragLayout)]
- return cs.ConstraintSystem(constraints=constraints), {var: [value]}, []
+ return cs.ConstraintSystem(constraints=constraints), {var: [value]}
@_add_constraint_system_derivation_rule(mgpu.OptimizationBarrierOp)
def _optimization_barrier_constraint_system(
ctx: DerivationContext,
op: ir.OpView,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
value_sites_for_variable: ValueSitesForVariable = {}
@@ -716,14 +658,14 @@ def _optimization_barrier_constraint_system(
ValueSite(op, VariableType.RESULT, i)
]
- return cs.ConstraintSystem(), value_sites_for_variable, []
+ return cs.ConstraintSystem(), value_sites_for_variable
@_add_constraint_system_derivation_rule(vector.BroadcastOp)
def _vector_splat_constraint_system(
ctx: DerivationContext,
op: ir.OpView,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
result = ValueSite(op, VariableType.RESULT, 0)
variable = cs.Variable(result)
@@ -731,14 +673,14 @@ def _vector_splat_constraint_system(
system = cs.ConstraintSystem(
assignments={variable: cs.RegisterLayout(layout)}
)
- return system, {variable: [result]}, []
+ return system, {variable: [result]}
@_add_constraint_system_derivation_rule(arith.ConstantOp)
def _constant_constraint_system(
ctx: DerivationContext,
constant_op: arith.ConstantOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
value = constant_op.value
result = ValueSite(constant_op, VariableType.RESULT, 0)
@@ -756,7 +698,7 @@ def _constant_constraint_system(
constant_is_not_splat = cs.NotOfType(variable, fa.WGSplatFragLayout)
system = cs.ConstraintSystem(constraints=[constant_is_not_splat])
- return system, {variable: [result]}, []
+ return system, {variable: [result]}
def _terminator(
@@ -775,7 +717,7 @@ def _terminator(
def _for_constraint_system(
ctx: DerivationContext,
op: scf.ForOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
[block] = op.region.blocks
yield_op = _terminator(block, scf.YieldOp)
value_sites_for_variable: ValueSitesForVariable = {}
@@ -797,7 +739,7 @@ def _for_constraint_system(
var = cs.Variable(operand) if is_vector(o) else ctx.producer_ref(operand)
value_sites_for_variable[var] = [operand, arg, result, yield_operand]
- return cs.ConstraintSystem(), value_sites_for_variable, []
+ return cs.ConstraintSystem(), value_sites_for_variable
def prime_decomposition(n: int) -> list[int]:
@@ -826,8 +768,8 @@ def dynamic_gcd(a: int, b: ir.Value) -> int:
raise ValueError("a must be strictly positive")
if not ir.IntegerType.isinstance(b.type) and not ir.IndexType.isinstance(b.type):
raise ValueError(f"Expected an integer dynamic value, got a {b.type}")
- if isinstance(b.owner, ir.Operation) and isinstance(b.owner.opview, arith.ConstantOp):
- return math.gcd(a, b.owner.opview.literal_value)
+ if isinstance(b.owner, arith.ConstantOp):
+ return math.gcd(a, b.owner.literal_value)
running_gcd = 1
for factor in prime_decomposition(a):
if utils.is_known_divisible(b, running_gcd * factor):
@@ -839,7 +781,7 @@ def dynamic_gcd(a: int, b: ir.Value) -> int:
def _while_constraint_system(
ctx: DerivationContext,
op: scf.WhileOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
[before_block] = op.before.blocks
[after_block] = op.after.blocks
@@ -871,14 +813,14 @@ def _while_constraint_system(
case _ as never:
assert_never(never) # pytype: disable=wrong-arg-types
- return cs.ConstraintSystem(), value_sites_for_variable, []
+ return cs.ConstraintSystem(), value_sites_for_variable
@_add_constraint_system_derivation_rule(scf.IndexSwitchOp)
def _index_switch_constraint_system(
ctx: DerivationContext,
op: scf.IndexSwitchOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
value_sites_for_variable: ValueSitesForVariable = {
cs.Variable(o): [o] for o in vector_value_sites(op)
@@ -893,23 +835,27 @@ def _index_switch_constraint_system(
)
value_sites_for_variable[value_site].append(yield_operand)
- return cs.ConstraintSystem(), value_sites_for_variable, []
+ return cs.ConstraintSystem(), value_sites_for_variable
@_add_constraint_system_derivation_rule(mgpu.LayoutCastOp)
def _layout_cast_constraint_system(
ctx: DerivationContext,
op: mgpu.LayoutCastOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
operand = ValueSite(op, VariableType.OPERAND, 0)
result = ValueSite(op, VariableType.RESULT, 0)
variable = cs.Variable(operand)
- out_layout = cs.RegisterLayout(layouts_lib.from_layout_attr(op.new_layout))
+ out_layout = layouts_lib.from_layout_attr(op.new_layout)
+ # TODO(bchetioui): think about raising a better error here.
+ if not is_valid_register_layout_assignment(operand.shape, out_layout):
+ return cs.Unsatisfiable()
return (
- cs.ConstraintSystem(assignments={variable: out_layout}),
+ cs.ConstraintSystem(
+ assignments={variable: cs.RegisterLayout(out_layout)}
+ ),
{variable: [operand, result]},
- [],
)
@@ -979,43 +925,55 @@ def _infer_wgmma_tiling(
def _wgmma_constraint_system(
ctx: DerivationContext,
op: mgpu.WGMMAOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
assignments: dict[cs.Variable, cs.Constant] = {}
value_sites_for_variable: ValueSitesForVariable = {}
acc_out = ValueSite(op, VariableType.RESULT, 0)
acc_in = ValueSite(op, VariableType.OPERAND, 0)
acc_var = cs.Variable(acc_out)
- assignments[acc_var] = cs.RegisterLayout(fa.WGMMA_LAYOUT)
+ acc_layout = fa.WGMMA_LAYOUT
+ assignments[acc_var] = cs.RegisterLayout(acc_layout)
+ acc_is_valid = is_valid_register_layout_assignment(acc_out.shape, acc_layout)
value_sites_for_variable[acc_var] = [acc_in, acc_out]
a_tiling, b_tiling = _infer_wgmma_tiling(op.a.type, op.b.type)
b = ValueSite(op, VariableType.OPERAND, 2)
b_var = ctx.producer_ref(b)
- assignments[b_var] = cs.SMEMTiling(lc.TileTransform(b_tiling))
+ b_tile_transform = lc.TileTransform(b_tiling)
+ b_is_valid = is_valid_smem_layout_assignment(b.shape, b_tile_transform)
+ assignments[b_var] = cs.SMEMTiling(b_tile_transform)
value_sites_for_variable[b_var] = [b]
a = ValueSite(op, VariableType.OPERAND, 1)
if _is_smem_ref(op.a):
a_var = ctx.producer_ref(a)
- assignments[a_var] = cs.SMEMTiling(lc.TileTransform(a_tiling))
+ a_tile_transform = lc.TileTransform(a_tiling)
+ assignments[a_var] = cs.SMEMTiling(a_tile_transform)
+ a_is_valid = is_valid_smem_layout_assignment(a.shape, a_tile_transform)
else:
assert a_tiling is None
a_var = cs.Variable(a)
if ir.IntegerType.get_signless(8) == ir.VectorType(op.a.type).element_type:
- assignments[a_var] = cs.RegisterLayout(fa.WGMMA_LAYOUT_8BIT)
+ layout = fa.WGMMA_LAYOUT_8BIT
else:
- assignments[a_var] = cs.RegisterLayout(fa.WGMMA_LAYOUT)
+ layout = fa.WGMMA_LAYOUT
+ assignments[a_var] = cs.RegisterLayout(layout)
+ a_is_valid = is_valid_register_layout_assignment(a.shape, layout)
+
value_sites_for_variable[a_var] = [a]
- return cs.ConstraintSystem(assignments), value_sites_for_variable, []
+ # TODO(bchetioui): think about raising a better error here.
+ if not a_is_valid or not b_is_valid or not acc_is_valid:
+ return cs.Unsatisfiable()
+ return cs.ConstraintSystem(assignments), value_sites_for_variable
@_add_constraint_system_derivation_rule(vector.BroadcastOp)
def _vector_broadcast_constraint_system(
ctx: DerivationContext,
op: vector.BroadcastOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
# This is not expected to be necessary at the moment. We should be using
# mgpu.BroadcastInDimOp instead when dealing with broadcasting vectors.
@@ -1026,7 +984,6 @@ def _vector_broadcast_constraint_system(
return (
cs.ConstraintSystem(assignments={out_variable: layout}),
{out_variable: [out_variable.key]},
- [],
)
@@ -1034,34 +991,29 @@ def _vector_broadcast_constraint_system(
def _vector_reduction_constraint_system(
ctx: DerivationContext,
op: vector.ReductionOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
in_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0))
- return cs.ConstraintSystem(), {in_variable: [in_variable.key]}, []
+ return cs.ConstraintSystem(), {in_variable: [in_variable.key]}
-def _reduction_constraint_and_hint(
+def _reduction_constraints(
larger: cs.Variable,
smaller: cs.Variable,
- larger_shape: tuple[int, ...],
reduction_dims: tuple[int, ...],
-) -> tuple[cs.Constraint, Hint]:
- reduce_expr = cs.Reduce(larger, reduction_dims)
- # There are always many options for broadcasting a layout, so we can only
- # derive a broadcast hint in the out_variable -> source_variable direction.
- broadcast_dims = tuple(
- i for i in range(len(larger_shape)) if i not in reduction_dims
- )
- broadcast_expr = cs.BroadcastInDim(smaller, broadcast_dims, larger_shape)
- broadcast_hint = Hint(variable=larger, expression=broadcast_expr)
- return cs.Equals(lhs=smaller, rhs=reduce_expr), broadcast_hint
+) -> list[cs.Constraint]:
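+  # Constrain the reduced (smaller) variable to carry the layout obtained by
+  # reducing the larger variable's layout along `reduction_dims`; strided
+  # layouts are ruled out because reducing them is not supported yet.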
+ return [
+ cs.Equals(lhs=smaller, rhs=cs.Reduce(larger, reduction_dims)),
+ # TODO(allanrenucci): Remove once we support reduction of strided layouts.
+ cs.NotOfType(larger, fa.WGStridedFragLayout),
+ ]
@_add_constraint_system_derivation_rule(vector.MultiDimReductionOp)
def _multi_dim_reduction_constraint_system(
ctx: DerivationContext,
op: vector.MultiDimReductionOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
source = ValueSite(op, VariableType.OPERAND, 0)
acc = ValueSite(op, VariableType.OPERAND, 1)
@@ -1069,19 +1021,17 @@ def _multi_dim_reduction_constraint_system(
source_variable = cs.Variable(source)
out_variable = cs.Variable(out)
- reduction_constraint, broadcast_hint = _reduction_constraint_and_hint(
+ reduction_constraints = _reduction_constraints(
source_variable,
out_variable,
- tuple(ir.ShapedType(op.source.type).shape),
tuple(op.reduction_dims),
)
# TODO(bchetioui): in the future, we may need to add rules that prevent
# strided layouts from being chosen---since trying to reduce a strided layout
# may cause us to raise an Exception at the moment.
return (
- cs.ConstraintSystem(constraints=[reduction_constraint]),
+ cs.ConstraintSystem(constraints=reduction_constraints),
{source_variable: [source], out_variable: [acc, out]},
- [broadcast_hint],
)
@@ -1089,7 +1039,7 @@ def _multi_dim_reduction_constraint_system(
def _broadcast_in_dim_constraint_system(
ctx: DerivationContext,
op: mgpu.BroadcastInDimOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
out_variable = cs.Variable(ValueSite(op, VariableType.RESULT, 0))
source_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0))
@@ -1097,25 +1047,23 @@ def _broadcast_in_dim_constraint_system(
reduction_dims = tuple(
i for i in range(len(out_shape)) if i not in op.broadcast_dimensions
)
-
- reduction_constraint, broadcast_hint = _reduction_constraint_and_hint(
- out_variable, source_variable, out_shape, reduction_dims
+ reduction_constraints = _reduction_constraints(
+ out_variable, source_variable, reduction_dims
)
return (
- cs.ConstraintSystem(constraints=[reduction_constraint]),
+ cs.ConstraintSystem(constraints=reduction_constraints),
{
source_variable: [source_variable.key],
out_variable: [out_variable.key],
},
- [broadcast_hint],
)
@_add_constraint_system_derivation_rule(vector.ShapeCastOp)
def _shape_cast_constraint_system(
ctx: DerivationContext, op: vector.ShapeCastOp
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
in_shape = tuple(cast(ir.ShapedType, op.source.type).shape)
out_shape = tuple(cast(ir.ShapedType, op.result.type).shape)
@@ -1152,14 +1100,13 @@ def _shape_cast_constraint_system(
],
),
{in_variable: [in_variable.key], out_variable: [out_variable.key]},
- [],
)
@_add_constraint_system_derivation_rule(vector.ExtractStridedSliceOp)
def _extract_strided_slice_constraint_system(
ctx: DerivationContext, op: vector.ExtractStridedSliceOp
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
if any(ir.IntegerAttr(s).value != 1 for s in op.strides):
raise NotImplementedError("`strides` must contain only 1s.")
@@ -1179,7 +1126,37 @@ def _extract_strided_slice_constraint_system(
# We use a single variable because lowering does not support two different
# layouts for `source` and `result`.
{variable: [operand, result]},
- [],
+ )
+
+
+@_add_constraint_system_derivation_rule(vector.ExtractOp)
+def _vector_extract_constraint_system(
+ ctx: DerivationContext, op: vector.ExtractOp
+) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]:
+ del ctx
+ if not ir.VectorType.isinstance(op.result.type): # scalar result
+ operand = ValueSite(op, VariableType.OPERAND, 0)
+ variable = cs.Variable(operand)
+ layout = fa.WGSplatFragLayout(tuple(op.source.type.shape))
+    # Scalar extraction is only supported when the operand has a splat layout.
+ assignments = {variable: cs.RegisterLayout(layout)}
+ return cs.ConstraintSystem(assignments), {variable: [operand]}
+
+ if op.dynamic_position:
+    raise NotImplementedError("Only slicing with static indices is supported.")
+ operand = ValueSite(op, VariableType.OPERAND, 0)
+ result = ValueSite(op, VariableType.RESULT, 0)
+ variable = cs.Variable(operand)
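+  # A single variable is used for both operand and result; the `Divides`
+  # constraint requires the chosen layout to be compatible with slicing out
+  # the result shape.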
+ constraints = [
+ cs.Divides(variable, tuple(op.result.type.shape)),
+ # TODO(allanrenucci): Remove once vectors with splat and strided layouts
+ # can be sliced.
+ cs.NotOfType(variable, fa.WGSplatFragLayout),
+ cs.NotOfType(variable, fa.WGStridedFragLayout),
+ ]
+ return (
+ cs.ConstraintSystem(constraints=constraints),
+ {variable: [operand, result]},
)
@@ -1187,7 +1164,7 @@ def _extract_strided_slice_constraint_system(
def _custom_primitive_constraint_system(
ctx: DerivationContext,
op: mgpu.CustomPrimitiveOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
assignments: dict[cs.Variable, cs.Constant] = {}
constraints: list[cs.Constraint] = []
in_layouts = iter(op.in_layouts)
@@ -1231,7 +1208,6 @@ def _custom_primitive_constraint_system(
return (
cs.ConstraintSystem(assignments, constraints),
{v: [v.key] for v in variables},
- [],
)
@@ -1249,15 +1225,17 @@ def _tmem_layout_from_layout_attr(
def _tmem_layout_cast_constraint_system(
ctx: DerivationContext,
op: mgpu.TmemLayoutCastOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
operand = ValueSite(op, VariableType.OPERAND, 0)
variable = ctx.producer_ref(operand)
result = ValueSite(op, VariableType.RESULT, 0)
- out_layout = cs.TMEMLayout(_tmem_layout_from_layout_attr(op.new_layout))
+ tmem_layout = _tmem_layout_from_layout_attr(op.new_layout)
+ if not is_valid_tmem_layout_assignment(operand.shape, tmem_layout):
+ return cs.Unsatisfiable()
+ out_layout = cs.TMEMLayout(tmem_layout)
return (
cs.ConstraintSystem(assignments={variable: out_layout}),
{variable: [operand, result]},
- [],
)
@@ -1265,43 +1243,34 @@ def _tmem_layout_cast_constraint_system(
def _tmem_alloc_constraint_system(
ctx: DerivationContext,
op: mgpu.TmemAllocOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
result = ValueSite(op, VariableType.RESULT, 0)
result_var = cs.Variable(result)
- layout = tcgen05._infer_tmem_layout(
- tuple(op.result.type.shape), op.collective, packing=1
- )
-
in_smem = ValueSite(op, VariableType.OPERAND, 0)
in_smem_var = cs.Variable(in_smem)
assignments: dict[cs.Variable, cs.Constant] = {
in_smem_var: cs.SMEMTiling(None)
}
operands_for_variable = {result_var: [result], in_smem_var: [in_smem]}
-
- # This is a hint, not a hard constraint. This will be the default layout if
- # none can be inferred.
- hint = Hint(result_var, cs.TMEMLayout(layout))
- system = cs.ConstraintSystem(assignments=assignments)
- return system, operands_for_variable, [hint]
+ return cs.ConstraintSystem(assignments=assignments), operands_for_variable
@_add_constraint_system_derivation_rule(mgpu.TmemDeallocOp)
def _tmem_dealloc_constraint_system(
ctx: DerivationContext,
op: mgpu.TmemDeallocOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
operand = ValueSite(op, VariableType.OPERAND, 0)
variable = ctx.producer_ref(operand)
- return cs.ConstraintSystem(), {variable: [operand]}, []
+ return cs.ConstraintSystem(), {variable: [operand]}
@_add_constraint_system_derivation_rule(mgpu.TcGen05MMAOp)
def _tcgen05_mma_constraint_system(
ctx: DerivationContext,
op: mgpu.TcGen05MMAOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
assignments: dict[cs.Variable, cs.Constant] = {}
operands_for_variable: ValueSitesForVariable = {}
@@ -1313,6 +1282,7 @@ def _tcgen05_mma_constraint_system(
tuple(acc_type.shape), op.collective, packing=1
)
assignments[acc_variable] = cs.TMEMLayout(acc_layout)
+ acc_is_valid = is_valid_tmem_layout_assignment(acc.shape, acc_layout)
operands_for_variable[acc_variable] = [acc]
if _is_tmem_ref(op.a):
@@ -1325,6 +1295,19 @@ def _tcgen05_mma_constraint_system(
)
assignments[a_var] = cs.TMEMLayout(a_layout)
operands_for_variable[a_var] = [a]
+ a_is_valid = is_valid_tmem_layout_assignment(a.shape, a_layout)
+ else:
+ assert _is_smem_ref(op.a)
+ a_tiling = _infer_tiling_for_mma_ref(
+ ir.MemRefType(op.a.type),
+ max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle,
+ )
+ a = ValueSite(op, VariableType.OPERAND, 1)
+ a_var = ctx.producer_ref(a)
+ a_tile_transform = lc.TileTransform(a_tiling)
+ assignments[a_var] = cs.SMEMTiling(a_tile_transform)
+ operands_for_variable[a_var] = [a]
+ a_is_valid = is_valid_smem_layout_assignment(a.shape, a_tile_transform)
# SMEM
M = op.accumulator.type.shape[0]
@@ -1344,27 +1327,23 @@ def _tcgen05_mma_constraint_system(
b_tiling = _infer_tiling_for_mma_ref(ir.MemRefType(op.b.type), max_b_swizzle)
b = ValueSite(op, VariableType.OPERAND, 2)
b_var = ctx.producer_ref(b)
- assignments[b_var] = cs.SMEMTiling(lc.TileTransform(b_tiling))
+ b_tile_transform = lc.TileTransform(b_tiling)
+ assignments[b_var] = cs.SMEMTiling(b_tile_transform)
operands_for_variable[b_var] = [b]
+ b_is_valid = is_valid_smem_layout_assignment(b.shape, b_tile_transform)
- if _is_smem_ref(op.a):
- a_tiling = _infer_tiling_for_mma_ref(
- ir.MemRefType(op.a.type),
- max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle,
- )
- a = ValueSite(op, VariableType.OPERAND, 1)
- a_var = ctx.producer_ref(a)
- assignments[a_var] = cs.SMEMTiling(lc.TileTransform(a_tiling))
- operands_for_variable[a_var] = [a]
+ # TODO(bchetioui): think about raising a better error here.
+ if not a_is_valid or not b_is_valid or not acc_is_valid:
+ return cs.Unsatisfiable()
- return cs.ConstraintSystem(assignments=assignments), operands_for_variable, []
+ return cs.ConstraintSystem(assignments=assignments), operands_for_variable
@_add_constraint_system_derivation_rule(mgpu.AsyncLoadTmemOp)
def _async_load_tmem_constraint_system(
ctx: DerivationContext,
op: mgpu.AsyncLoadTmemOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
source = ValueSite(op, VariableType.OPERAND, 0)
source_variable = ctx.producer_ref(source)
destination = ValueSite(op, VariableType.RESULT, 0)
@@ -1377,7 +1356,6 @@ def _async_load_tmem_constraint_system(
return (
cs.ConstraintSystem(constraints=[constraint]),
{source_variable: [source], destination_variable: [destination]},
- [],
)
@@ -1385,7 +1363,7 @@ def _async_load_tmem_constraint_system(
def _slice_tmem_constraint_system(
ctx: DerivationContext,
op: mgpu.SliceTmemOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
operand = ValueSite(op, VariableType.OPERAND, 0)
operand_variable = ctx.producer_ref(operand)
result = ValueSite(op, VariableType.RESULT, 0)
@@ -1393,7 +1371,6 @@ def _slice_tmem_constraint_system(
return (
cs.ConstraintSystem(),
{operand_variable: [operand], result_variable: [result]},
- [],
)
@@ -1401,7 +1378,7 @@ def _slice_tmem_constraint_system(
def _async_store_tmem_constraint_system(
ctx: DerivationContext,
op: mgpu.AsyncStoreTmemOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
source = ValueSite(op, VariableType.OPERAND, 0)
source_variable = cs.Variable(source)
destination = ValueSite(op, VariableType.OPERAND, 1)
@@ -1414,7 +1391,6 @@ def _async_store_tmem_constraint_system(
return (
cs.ConstraintSystem(constraints=[constraint]),
{source_variable: [source], destination_variable: [destination]},
- [],
)
@@ -1422,18 +1398,18 @@ def _async_store_tmem_constraint_system(
def _slice_smem_constraint_system(
ctx: DerivationContext,
op: mgpu.SliceSMEMOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
res = ValueSite(op, VariableType.RESULT, 0)
res_var = cs.Variable(res)
- return (cs.ConstraintSystem(), {res_var: [res]}, [])
+ return cs.ConstraintSystem(), {res_var: [res]}
@_add_constraint_system_derivation_rule(memref.SubViewOp)
def _memref_subview_constraint_system(
ctx: DerivationContext,
op: memref.SubViewOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
source = ValueSite(op, VariableType.OPERAND, 0)
dest = ValueSite(op, VariableType.RESULT, 0)
source_dest_var = ctx.producer_ref(source)
@@ -1471,25 +1447,25 @@ def _memref_subview_constraint_system(
constraints = [cs.Divides(source_dest_var, tuple(tiling_multiple))]
system = cs.ConstraintSystem(constraints=constraints)
- return system, {source_dest_var: [source, dest]}, []
+ return system, {source_dest_var: [source, dest]}
@_add_constraint_system_derivation_rule(memref.CastOp)
def _memref_cast_op_constraint_system(
ctx: DerivationContext,
op: memref.CastOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
source = ValueSite(op, VariableType.OPERAND, 0)
var_source_dest = ctx.producer_ref(source)
dest = ValueSite(op, VariableType.RESULT, 0)
- return cs.ConstraintSystem(), {var_source_dest: [source, dest]}, []
+ return cs.ConstraintSystem(), {var_source_dest: [source, dest]}
@_add_constraint_system_derivation_rule(memref.TransposeOp)
def _memref_transpose_op_constraint_system(
ctx: DerivationContext,
op: memref.TransposeOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
in_ty = ir.MemRefType(op.in_.type)
if len(in_ty.shape) != 2:
raise NotImplementedError(f"Only 2D memrefs are supported, got {in_ty}")
@@ -1502,7 +1478,7 @@ def _memref_transpose_op_constraint_system(
source_var = ctx.producer_ref(source)
if not transpose:
- return (cs.ConstraintSystem(), {source_var: [source, dest]}, [])
+ return cs.ConstraintSystem(), {source_var: [source, dest]}
dest_var = cs.Variable(dest)
constraints = [
@@ -1510,14 +1486,14 @@ def _memref_transpose_op_constraint_system(
cs.Equals(source_var, cs.Transpose(dest_var)),
]
system = cs.ConstraintSystem(constraints=constraints)
- return system, {source_var: [source], dest_var: [dest]}, []
+ return system, {source_var: [source], dest_var: [dest]}
@_add_constraint_system_derivation_rule(memref.ExpandShapeOp)
def _memref_expand_shape_op_equation_system(
ctx: DerivationContext,
op: memref.ExpandShapeOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
if utils.is_memref_transposed(ir.MemRefType(op.src.type)):
raise NotImplementedError(
"Transposed memrefs are not supported in ExpandShapeOp."
@@ -1538,7 +1514,7 @@ def _memref_expand_shape_op_equation_system(
reverse_tiling_multiple.append(dim)
constraints = [cs.Divides(var, tuple(reversed(reverse_tiling_multiple)))]
- return cs.ConstraintSystem(constraints=constraints), {var: [source, dest]}, []
+ return cs.ConstraintSystem(constraints=constraints), {var: [source, dest]}
# `memref.load` and `memref.store` are used to load barrier phases which are
@@ -1548,7 +1524,7 @@ def _memref_expand_shape_op_equation_system(
def _memref_load_store_op_constraint_system(
ctx: DerivationContext,
op: memref.LoadOp | memref.StoreOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
del ctx
ref_shape = ir.MemRefType(op.memref.type).shape
@@ -1561,7 +1537,7 @@ def _memref_load_store_op_constraint_system(
ref = ValueSite(op, VariableType.OPERAND, ref_op_index)
var = cs.Variable(ref)
assignments: dict[cs.Variable, cs.Constant] = {var: cs.SMEMTiling(None)}
- return cs.ConstraintSystem(assignments=assignments), {var: [ref]}, []
+ return cs.ConstraintSystem(assignments=assignments), {var: [ref]}
def _extract_smem_tiling_from_custom_transform_attrs(
@@ -1597,13 +1573,17 @@ def _extract_smem_tiling_from_custom_transform_attrs(
def _with_transforms_constraint_system(
ctx: DerivationContext,
op: mgpu.WithTransformsOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
source = ValueSite(op, VariableType.OPERAND, 0)
dest = ValueSite(op, VariableType.RESULT, 0)
var = ctx.producer_ref(source)
tiling = _extract_smem_tiling_from_custom_transform_attrs(op.ref.type, op.transforms)
+ if tiling.value is not None:
+ # TODO(bchetioui): think about raising a better error here.
+ if not is_valid_smem_layout_assignment(source.shape, tiling.value):
+ return cs.Unsatisfiable()
assignments: dict[cs.Variable, cs.Constant] = {var: tiling}
- return cs.ConstraintSystem(assignments=assignments), {var: [source, dest]}, []
+ return cs.ConstraintSystem(assignments=assignments), {var: [source, dest]}
@_add_constraint_system_derivation_rule(mgpu.AsyncLoadOp)
@@ -1611,7 +1591,7 @@ def _with_transforms_constraint_system(
def _async_load_store_constraint_system(
ctx: DerivationContext,
op: mgpu.AsyncLoadOp | mgpu.AsyncStoreOp,
-) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]:
+) -> ConstraintSystemDerivationRuleResult:
tiling_multiple = []
for size, index in zip(op.slice_lengths, op.indices, strict=True):
if size == -1:
@@ -1623,7 +1603,7 @@ def _async_load_store_constraint_system(
operand = ValueSite(op, VariableType.OPERAND, operand_index)
var = ctx.producer_ref(operand)
constraints = [cs.Divides(expr=var, tiling_multiple=tuple(tiling_multiple))]
- return cs.ConstraintSystem(constraints=constraints), {var: [operand]}, []
+ return cs.ConstraintSystem(constraints=constraints), {var: [operand]}
def _ensure_all_layouts_are_set(op: ir.OpView) -> None:
@@ -1824,9 +1804,9 @@ def producer_result(operand: ValueSite) -> ValueSite:
assert operand.type == VariableType.OPERAND
value = operand.value
producer = value.owner
- if isinstance(producer, ir.Operation):
+ if isinstance(producer, ir.OpView):
index = list(producer.results).index(value)
- return ValueSite(producer.opview, VariableType.RESULT, index)
+ return ValueSite(producer, VariableType.RESULT, index)
if isinstance(producer, ir.Block):
index = list(producer.arguments).index(value)
@@ -1845,17 +1825,16 @@ def consumer_operands(result: ValueSite) -> Sequence[ValueSite]:
# The layout can also be chosen from the layout of the consumers of the
# results.
for use in result.value.uses:
- consumer = use.owner.opview # pytype: disable=attribute-error
+ consumer = use.owner
index = use.operand_number
consumer_operands.append(ValueSite(consumer, VariableType.OPERAND, index))
return consumer_operands
-def derive_hints_and_constraints(
+def derive_relayout_constraints(
value_sites_for_variable: ValueSitesForVariable,
-) -> tuple[list[Hint], list[cs.Relayout]]:
- """Derives propagation hints from the given variable mapping."""
- hints: list[Hint] = []
+) -> list[cs.Relayout]:
+ """Derives relayout constraints from the given variable mapping."""
constraints: list[cs.Relayout] = []
variable_for_value_site: dict[ValueSite, cs.Variable] = {}
for variable, value_sites in value_sites_for_variable.items():
@@ -1895,16 +1874,7 @@ def derive_hints_and_constraints(
# A variable must be relayout-able to its consumers.
constraints.append(cs.Relayout(variable, consumer_variable))
visited.add(variable)
-
- if producers:
- least_replicated_producer = cs.LeastReplicated(tuple(producers))
- hint_expr = cs.MostReplicated((least_replicated_producer, *consumers))
- hints.append(Hint(variable, hint_expr))
- elif consumers:
- hint_expr = cs.MostReplicated(tuple(consumers))
- hints.append(Hint(variable, hint_expr))
-
- return hints, constraints
+ return constraints
def is_terminator(op: ir.OpView) -> bool:
@@ -1930,6 +1900,77 @@ def traverse_op(
traverse_op(block_op, callback)
+def is_valid_register_layout_assignment(
+ shape: tuple[int, ...], layout: fa.FragmentedLayout
+) -> bool:
+ match layout:
+ case fa.WGStridedFragLayout() as strided_layout:
+ return strided_layout.shape == shape
+ case fa.WGSplatFragLayout() as splat_layout:
+ return splat_layout.shape == shape
+ case fa.TiledLayout(tiling=tiling):
+ try:
+ # `tiling.tile_shape` will raise if the shape is not tileable.
+ _ = tiling.tile_shape(shape)
+ except ValueError:
+ return False
+ return True
+ case _:
+ assert False, f"Unreachable {shape}, {layout}"
+
+
+def is_valid_smem_layout_assignment(
+ shape: tuple[int, ...], tiling: lc.TileTransform
+) -> bool:
+ try:
+ # `tiling.transform_shape` will raise if the shape is not tileable.
+ _ = tiling.transform_shape(shape)
+ except ValueError:
+ return False
+ return True
+
+
+def is_valid_tmem_layout_assignment(
+ shape: tuple[int, ...], layout: tcgen05.TMEMLayout
+) -> bool:
+ try:
+ # `layout.tiling.tile_shape` will raise if the shape is not tileable.
+ _ = layout.tiling.tile_shape(shape)
+ except ValueError:
+ return False
+ return True
+
+
+def check_layout_assignment(v: ValueSite, layout: cs.Constant) -> None:
+ """Raises if the given layout can not be assigned to the given `ValueSite`."""
+ match v.memory_space, layout:
+ case MemorySpace.REG, cs.RegisterLayout(value=reg_layout):
+ if not is_valid_register_layout_assignment(v.shape, reg_layout):
+ raise ValueError(
+ f"Layout {reg_layout} is not compatible with register variable "
+ f"{v.value}. This is a bug."
+ )
+ case MemorySpace.TMEM, cs.TMEMLayout(value=tmem_layout):
+ if not is_valid_tmem_layout_assignment(v.shape, tmem_layout):
+ raise ValueError(
+ f"Layout {tmem_layout} is not compatible with TMEM variable "
+ f"{v.value}. This is a bug."
+ )
+ case MemorySpace.SMEM, cs.SMEMTiling(value=tiling_or_none):
+ if tiling_or_none is None:
+ return
+ if not is_valid_smem_layout_assignment(v.shape, tiling_or_none):
+ raise ValueError(
+ f"Layout {tiling_or_none} is not compatible with SMEM variable "
+ f"{v.value}. This is a bug."
+ )
+ case _:
+ raise ValueError(
+ f"Variable {v.value} in memory space {v.memory_space} should not be "
+ f"assigned a layout of type {type(layout)}. This is a bug."
+ )
+
+
def infer_layout(
module: ir.Module, *, fuel: int = _DEFAULT_LAYOUT_INFERENCE_FUEL
):
@@ -1949,7 +1990,6 @@ def infer_layout(
"""
global_constraint_system: cs.ConstraintSystem | cs.Unsatisfiable
global_constraint_system = cs.ConstraintSystem()
- hints: list[Hint] = []
ctx = DerivationContext()
def gather_constraints(op: ir.Operation):
@@ -1969,14 +2009,21 @@ def gather_constraints(op: ir.Operation):
rule = _constraint_system_derivation_rules.get(op.OPERATION_NAME, None) # pytype: disable=attribute-error
if rule is None:
raise NotImplementedError(f"No layout inference rule defined for {op}")
- constraint_system, mapping, op_hints = rule(ctx, op)
- ctx.update(mapping)
+ rule_result = rule(ctx, op)
nonlocal global_constraint_system
+ if isinstance(rule_result, cs.Unsatisfiable):
+ global_constraint_system = cs.Unsatisfiable()
+ return
+ constraint_system, mapping = rule_result
global_constraint_system &= constraint_system
- hints.extend(op_hints)
+ ctx.update(mapping)
for op in module.body:
traverse_op(op, gather_constraints)
+    # Short-circuit if the constraint system is already unsatisfiable; we won't
+    # be able to construct anything useful from here on.
+ if isinstance(global_constraint_system, cs.Unsatisfiable):
+ break
if isinstance(global_constraint_system, cs.Unsatisfiable):
raise ValueError(
@@ -1984,8 +2031,7 @@ def gather_constraints(op: ir.Operation):
"user-provided layout casts are unsatisfiable."
)
- propagation_hints, constraints = derive_hints_and_constraints(ctx.value_sites_for_variable)
- hints = reduce_hints(hints + propagation_hints, global_constraint_system.assignments) # pytype: disable=attribute-error
+ constraints = derive_relayout_constraints(ctx.value_sites_for_variable)
global_constraint_system &= cs.ConstraintSystem(constraints=constraints)
assert not isinstance(global_constraint_system, cs.Unsatisfiable)
@@ -2003,7 +2049,6 @@ def gather_constraints(op: ir.Operation):
solution, remaining_fuel = find_assignments_for(
list(ctx.value_sites_for_variable.keys()),
global_constraint_system,
- hints,
fuel=fuel,
)
@@ -2017,11 +2062,14 @@ def gather_constraints(op: ir.Operation):
"user-provided layout casts are unsatisfiable."
)
- layout_for_value_site = {
- k: solution[v]
- for v, ks in ctx.value_sites_for_variable.items()
- for k in ks
- }
+ layout_for_value_site: dict[ValueSite, cs.Constant] = {}
+ for variable, value_sites in ctx.value_sites_for_variable.items():
+ for value_site in value_sites:
+ layout = solution[variable]
+ # Ensure that the layout assignment is valid for the value site. This
+ # should only ever fail if our implementation is buggy.
+ check_layout_assignment(value_site, layout)
+ layout_for_value_site[value_site] = layout
# Assigns the layouts that we found to the ops.
assign_layouts(layout_for_value_site)
diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py
index 82870696eb32..1b75e4d48b47 100644
--- a/jax/experimental/mosaic/gpu/layouts.py
+++ b/jax/experimental/mosaic/gpu/layouts.py
@@ -15,7 +15,6 @@
"""Layout utilities."""
import re
-from typing import assert_never
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
@@ -224,118 +223,6 @@ def splat_is_compatible_with_tiled(
return all(d1 % d2 == 0 for d1, d2 in zip(s1, s2))
-def meet_layouts(
- layout1: fa.FragmentedLayout, layout2: fa.FragmentedLayout
-) -> fa.FragmentedLayout | None:
- """Returns the "meet" of two layouts that are compatible up to replication.
-
- The "meet" of the two layouts is the most replicated layout that is still
- less replicated than the arguments.
-
- This is the dual of `join_layouts`.
-
- Returns:
- The "meet" of the two layouts if both layouts are compatible up to
- replication.
-
- Raises:
- ValueError: if the two layouts are not compatible up to replication.
- """
- if layout1 == layout2:
- return layout1
-
- match (layout1, layout2):
- case (fa.WGSplatFragLayout(), _):
- if isinstance(layout2, fa.TiledLayout):
- if splat_is_compatible_with_tiled(layout1, layout2):
- return layout2
- elif layout1.shape == layout2.shape:
- return layout2
- case (_, fa.WGSplatFragLayout()):
- if isinstance(layout1, fa.TiledLayout):
- if splat_is_compatible_with_tiled(layout2, layout1):
- return layout1
- elif layout1.shape == layout2.shape:
- return layout1
- case (fa.TiledLayout(), fa.TiledLayout()):
- # TODO(bchetioui): handle `TiledLayout` replication.
- raise NotImplementedError("TiledLayout replication not supported yet")
-
- # Layouts are not compatible up to replication.
- return None
-
-# NOTE: We say that two layouts are compatible up to replication if the two
-# layouts satisfy at least one of the following conditions together:
-#
-# - The two layouts are equal;
-# - One of the layouts is a `WGSplatFragLayout`, and
-# * The other layout is a `WGStridedFragLayout` with the same shape;
-# * The other layout is a `TiledLayout` that can be used to tile the shape
-# embedded in the `WGSplatFragLayout`.
-#
-# If any of these conditions hold, then we are always able to substitute one
-# layout with the other without having to reorder any data in the underlying
-# array---i.e. a relayout is free.
-#
-# Note that there are other combinations of layouts for which relayout is free,
-# but we voluntarily narrowed down our definition to span a small, useful
-# subset.
-
-def join_layouts(
- layout1: fa.FragmentedLayout, layout2: fa.FragmentedLayout
-) -> fa.FragmentedLayout | None:
- """Returns the "join" of two layouts that are compatible up to replication.
-
- The "join" of the two layouts is the least replicated layout that is still
- more replicated than the arguments.
-
- This is the dual of `meet_layouts`.
-
- Returns:
- The "join" of the two layouts if both layouts are compatible up to
- replication.
-
- Raises:
- ValueError: if the two layouts are not compatible up to replication.
- """
- if layout1 == layout2:
- return layout1
-
- match (layout1, layout2):
- case (fa.WGSplatFragLayout(), _):
- if isinstance(layout2, fa.TiledLayout):
- if splat_is_compatible_with_tiled(layout1, layout2):
- return layout1
- elif layout1.shape == layout2.shape:
- return layout1
- case (_, fa.WGSplatFragLayout()):
- if isinstance(layout1, fa.TiledLayout):
- if splat_is_compatible_with_tiled(layout2, layout1):
- return layout2
- elif layout1.shape == layout2.shape:
- return layout2
- case (fa.TiledLayout(), fa.TiledLayout()):
- # TODO(bchetioui): handle `TiledLayout` replication.
- raise NotImplementedError("TiledLayout replication not supported yet")
-
- # Layouts are not compatible up to replication.
- return None
-
-
-def has_any_replication(layout: fa.FragmentedLayout) -> bool:
- match layout:
- case fa.WGSplatFragLayout():
- return True
- case fa.WGStridedFragLayout():
- return False
- case fa.TiledLayout():
- is_warp_replicated = any(isinstance(d, fa.Replicated) for d in layout.warp_dims)
- is_lane_replicated = any(isinstance(d, fa.Replicated) for d in layout.lane_dims)
- return is_warp_replicated or is_lane_replicated
- case _ as unreachable:
- return assert_never(unreachable) # pytype: disable=wrong-arg-types
-
-
_tile_transform_attr_pattern = re.compile(
r"^#mosaic_gpu.tile<[^>]+>$"
)
diff --git a/jax/experimental/mosaic/gpu/mma_utils.py b/jax/experimental/mosaic/gpu/mma_utils.py
index d4e04fdc67ec..ebd789ca348c 100644
--- a/jax/experimental/mosaic/gpu/mma_utils.py
+++ b/jax/experimental/mosaic/gpu/mma_utils.py
@@ -48,6 +48,7 @@ def create_descriptor(
# Soft deprecated. Use small tiling instead.
large_tile: tuple[int, int] | None = None,
mma_bytewidth_k: int = 32,
+ split_const: bool = False,
):
ref_ty = ir.MemRefType(ref.type)
element_bitwidth = utils.bitwidth(ref_ty.element_type)
@@ -183,6 +184,7 @@ def to_byte_stride(stride: int):
leading_byte_offset=leading_byte_offset,
stride_byte_offset=stride_byte_offset,
swizzle=swizzle,
+ split_const=split_const,
)
mn_tiles_per_group, rem = divmod(mn_group_size, mn_tiling)
@@ -221,7 +223,9 @@ def encode_descriptor(
stride_byte_offset: int,
swizzle: int | mgpu_dialect.SwizzlingMode | None,
const_init: int = 0,
+ split_const: bool = False,
):
+ i32 = ir.IntegerType.get_signless(32)
i64 = ir.IntegerType.get_signless(64)
if isinstance(ref_arg.type, ir.MemRefType):
ptr = utils.memref_ptr(ref_arg, 3)
@@ -246,7 +250,18 @@ def encode_descriptor(
const_init
| (encode_addr(leading_byte_offset) << 16)
| (encode_addr(stride_byte_offset) << 32)
+ | (swizzle_encoding << 62)
)
- desc = llvm.or_(arith.shli(c(swizzle_encoding), c(62)), c(desc_const))
- desc = llvm.or_(encoded_base_addr, desc)
- return desc
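+  # When split_const is requested, return the (32-bit) encoded base address and
+  # the constant part of the descriptor separately, so that callers can fold
+  # additional offsets into the constant at trace time.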
+ if split_const:
+ # The encoded base addr fits within a single 32-bit register.
+ return arith.trunci(i32, encoded_base_addr), desc_const
+ else:
+    # desc_const frequently has the MSB set, which breaks ir.IntegerAttr
+    # creation through the MLIR Python bindings, so build the constant from two
+    # 32-bit halves instead. LLVM should easily constant-fold this away.
+ if desc_const >> 63:
+ desc_val = c(desc_const & 0xFFFFFFFF)
+ desc_val = llvm.or_(desc_val, arith.shli(c(desc_const >> 32), c(32)))
+ else:
+ desc_val = c(desc_const)
+ return llvm.or_(encoded_base_addr, desc_val)
diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py
index 0897021a28ae..c623a1b3349c 100644
--- a/jax/experimental/mosaic/gpu/tcgen05.py
+++ b/jax/experimental/mosaic/gpu/tcgen05.py
@@ -447,6 +447,7 @@ def mma(
group_size=(m_group_elems, k_group_elems // (1 + is_sparse)),
logical_k_major=False,
mma_bytewidth_k=32,
+ split_const=True,
)
else:
a_fastest = mma_utils.Dim.K
@@ -462,6 +463,7 @@ def mma(
group_size=(k_group_elems, n_group_elems),
logical_k_major=True,
mma_bytewidth_k=64 if is_sparse else 32,
+ split_const=True,
)
if is_scaled and utils.bitwidth(mma_element_type) == 4:
@@ -496,9 +498,9 @@ def mma(
a_mk = a.slice(slice(None), utils.ds(ki * a_k_group_elems, a_k_group_elems)).address
else:
a_offset = mi * a_m_group_stride + ki * a_k_group_stride
- a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64))
+ a_mk = (a_desc_base[0], a_desc_base[1] + mma_utils.encode_addr(a_offset))
b_offset = ni * b_n_group_stride + ki * b_k_group_stride
- b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64))
+ b_nk = (b_desc_base[0], b_desc_base[1] + mma_utils.encode_addr(b_offset))
if a_sparse_addr_base is not None:
if n_groups != 1 or m_groups != 1:
raise NotImplementedError("A sparse metadata address calculation for multiple tiles")
@@ -559,8 +561,8 @@ def mma(
def _do_mma(
d_addr: ir.Value,
- a_desc_or_addr: ir.Value, # TMEM address if a_k_stride is None
- b_desc: ir.Value,
+ a_desc_or_addr: tuple[ir.Value, int] | ir.Value, # TMEM address if a_k_stride is None
+ b_desc: tuple[ir.Value, int],
a_transpose: bool,
b_transpose: bool,
a_k_strides: tuple[tuple[int, ...], tuple[int, ...]] | None,
@@ -638,14 +640,12 @@ def create_scaled_instr_descriptor(*args): # type: ignore
num_cta = 2 if collective else 1
a_in_tmem = a_k_strides is None
- a_ptx = "[$1]" if a_in_tmem else "$1"
- a_ptx_constraint = "r" if a_in_tmem else "l"
+ a_ptx = "[a_desc]" if a_in_tmem else "a_desc"
sparse_mod = ".sp" if is_sparse else ""
sparse_meta_ptx = "[$5], " if is_sparse else ""
extra_constraints += ",r" if is_sparse else ""
sparse_addr: tuple[Any, ...] = ()
scales_addrs: tuple[Any, ...] = ()
- assert a_desc_or_addr.type == ir.IntegerType.get_signless(32 if a_in_tmem else 64)
def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]):
assert len(idx_tiling) + 1 == len(strides)
idxs = []
@@ -654,7 +654,7 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...])
idx = idx % t
idxs.append(idx)
offset = sum(i * s for i, s in zip(idxs, strides, strict=True))
- return arith.constant(i64, offset >> 4)
+ return offset >> 4
for k_step in range(k // instr_k):
if is_scaled:
assert scale_steps is not None
@@ -666,7 +666,7 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...])
scale_id, scale_id, a_transpose, b_transpose
)
assert m == 128
- assert n % 128 == 0
+ assert (n * num_cta) % 128 == 0
# A scales are sharded, B scales are replicated across CTAs.
a_scale_addr_offset = arith.constant(i32, k_step // scale_steps * 4)
b_scale_addr_offset = arith.constant(i32, k_step // scale_steps * n // 32 * num_cta)
@@ -696,20 +696,32 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...])
)
if a_in_tmem:
cols_per_k_group = instr_k // packing // (1 + is_sparse)
- a_desc_or_addr_instr = arith.addi(
- a_desc_or_addr, arith.constant(i32, k_step * cols_per_k_group)
- )
+ a_offset = k_step * cols_per_k_group
+ assert isinstance(a_desc_or_addr, ir.Value)
+ assert a_desc_or_addr.type == ir.IntegerType.get_signless(32)
+ a_enc_addr_base = a_desc_or_addr
else:
assert a_k_idx_tiling is not None and a_k_strides is not None
- a_desc_or_addr_instr = arith.addi(
- a_desc_or_addr, _get_offset(k_step, a_k_idx_tiling, a_k_strides)
- )
- b_desc_instr = arith.addi(b_desc, _get_offset(k_step, b_k_idx_tiling, b_k_strides))
+ a_enc_addr_base, a_offset = a_desc_or_addr
+ a_offset += _get_offset(k_step, a_k_idx_tiling, a_k_strides)
+ b_enc_addr_base, b_offset = b_desc
+ b_offset += _get_offset(k_step, b_k_idx_tiling, b_k_strides)
+ a_offset_low, a_offset_high = a_offset & 0xFFFFFFFF, a_offset >> 32
+ b_offset_low, b_offset_high = b_offset & 0xFFFFFFFF, b_offset >> 32
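+    # Split the accumulated 64-bit descriptor offsets into 32-bit halves so
+    # they can be emitted as immediates in the inline PTX below.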
llvm.inline_asm(
ir.Type.parse("!llvm.void"),
- [d_addr, a_desc_or_addr_instr, b_desc_instr, i_desc, accumulate, *scales_addrs, *sparse_addr],
- f"tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, $2, {sparse_meta_ptx}$3, {extra_ptx}$4;",
- f"r,{a_ptx_constraint},l,r,b" + extra_constraints,
+ [d_addr, a_enc_addr_base, b_enc_addr_base, i_desc, accumulate, *scales_addrs, *sparse_addr],
+ f"""{{
+ .reg .b32 a_desc_low, a_desc_high, b_desc_low, b_desc_high;
+ .reg {".b32" if a_in_tmem else ".b64"} a_desc;
+ .reg .b64 b_desc;
+ add.s32 a_desc_low, $1, {a_offset_low};
+ add.s32 b_desc_low, $2, {b_offset_low};
+ mov.b64 b_desc, {{b_desc_low, {b_offset_high}}};
+ {"mov.b32 a_desc, a_desc_low;" if a_in_tmem else f"mov.b64 a_desc, {{a_desc_low, {a_offset_high}}};"}
+ tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, b_desc, {sparse_meta_ptx}$3, {extra_ptx}$4;
+ }}""",
+ "r,r,r,r,b" + extra_constraints,
has_side_effects=True,
)
accumulate = arith.constant(i1, 1)
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 33965db4ef6f..94f23066a88a 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -722,6 +722,18 @@ def memref_reshape(
(), ref_ty.element_type, new_layout, ref_ty.memory_space
)
return memref.collapse_shape(result_ty, ref, [])
+ # For contiguous refs we can do arbitrary reshapes easily.
+ strides, _ = ref_ty.get_strides_and_offset()
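+  # A size-1 dimension can have an arbitrary stride, so only require the
+  # contiguous stride to match for non-unit dimensions.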
+ if all(
+ d == 1 or s1 == s2
+ for d, s1, s2 in zip(
+ ref_ty.shape,
+ get_contiguous_strides(ref_ty.shape),
+ strides,
+ strict=True,
+ )
+ ):
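+    # Collapse the ref to 1D and unfold it into the target shape; this is valid
+    # because the ref is contiguous (up to unit dimensions).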
+ return memref_unfold(memref_fold(ref, 0, ref_ty.rank), 0, shape)
return _reshape(ref, src_shape, dst_shape)
@@ -1071,7 +1083,9 @@ def arrive_expect_tx(
elif ir.IndexType.isinstance(bytes.type):
i32 = ir.IntegerType.get_signless(32)
bytes = arith.index_cast(i32, bytes)
- nvvm.mbarrier_arrive_expect_tx(self.get_ptr(), bytes, predicate=predicate)
+ nvvm_mbarrier_arrive_expect_tx(
+ self.get_ptr(), bytes, predicate=predicate
+ )
def get_ptr(self):
i64 = ir.IntegerType.get_signless(64)
@@ -1840,7 +1854,7 @@ def is_known_divisible(value, divisor, max_depth=10) -> bool:
"""Returns True if the value is statically known to be divisible by the divisor."""
if divisor == 1:
return True
- if max_depth < 0 or not isinstance(value.owner, ir.Operation):
+ if max_depth < 0 or not isinstance(value.owner, ir.OpView):
return False
new_depth = max_depth - 1
@@ -1986,3 +2000,10 @@ def nanosleep(nanos: ir.Value):
"r",
has_side_effects=True,
)
+
+
+def nvvm_mbarrier_arrive_expect_tx(barrier: ir.Value, expect_tx: ir.Value, predicate: ir.Value | None = None):
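+  # Compatibility shim: newer MLIR bindings take an extra leading argument;
+  # try that signature first and fall back to the older one on TypeError.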
+ try:
+ return nvvm.mbarrier_arrive_expect_tx(None, barrier, expect_tx, predicate=predicate) # type: ignore
+ except TypeError:
+ return nvvm.mbarrier_arrive_expect_tx(barrier, expect_tx, predicate=predicate) # pytype: disable=missing-parameter
diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py
index bdb1c5c200fe..9af9e8965ad4 100644
--- a/jax/experimental/mosaic/gpu/wgmma.py
+++ b/jax/experimental/mosaic/gpu/wgmma.py
@@ -65,7 +65,8 @@ def value(self) -> fa.FragmentedArray:
@classmethod
def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None):
if m % 64 or n % 8:
- raise ValueError
+ raise ValueError("WGMMA requires m and n to be multiples of 64 and 8, "
+ f"got {m} and {n}")
if is_signed is False: # pylint: disable=g-bool-id-comparison
raise TypeError("PTX does not support unsigned WGMMA accumulators")
f32 = ir.F32Type.get()
diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py
index 1acb02e5b01f..f3026502abc6 100644
--- a/jax/experimental/multihost_utils.py
+++ b/jax/experimental/multihost_utils.py
@@ -123,8 +123,9 @@ def _handle_array_process_allgather(inp, tiled):
host_np_arr = np.expand_dims(host_np_arr, axis=0)
aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
+ pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
global_aval = pxla.mesh_local_to_global(
- global_mesh, pxla.get_array_mapping(pspec), aval)
+ global_mesh, sharding_impls.get_array_mapping(pspec), aval)
bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
global_arr = array.make_array_from_single_device_arrays(
@@ -225,7 +226,10 @@ def should_save(step_id: int) -> bool:
return False
sync_manager = distributed.global_state.preemption_sync_manager
if sync_manager is None:
- raise RuntimeError("Preemption sync manager has not been initialized.")
+ raise RuntimeError(
+ "Preemption sync manager has not been initialized. Make sure the"
+ " 'jax_enable_preemption_service' config is enabled."
+ )
return sync_manager.reached_sync_point(step_id)
@@ -236,13 +240,15 @@ def _flatten_pspecs(name, in_tree, pspecs_thunk):
@lru_cache
def _local_to_global_aval(local_aval, mesh, pspec):
- return pxla.mesh_local_to_global(mesh, pxla.get_array_mapping(pspec),
- local_aval)
+ pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
+ return pxla.mesh_local_to_global(
+ mesh, sharding_impls.get_array_mapping(pspec), local_aval)
@lru_cache
def _global_to_local_aval(global_aval, mesh, pspec):
- return pxla.mesh_global_to_local(mesh, pxla.get_array_mapping(pspec),
- global_aval)
+ pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
+ return pxla.mesh_global_to_local(
+ mesh, sharding_impls.get_array_mapping(pspec), global_aval)
def host_local_array_to_global_array_impl(
diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py
index 8adcd2da2521..eccc06881936 100644
--- a/jax/experimental/pallas/mosaic_gpu.py
+++ b/jax/experimental/pallas/mosaic_gpu.py
@@ -87,6 +87,7 @@
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait
from jax._src.pallas.mosaic_gpu.torch import as_torch_kernel as as_torch_kernel
from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics
+from jax.experimental.mosaic.gpu.fragmented_array import Replicated as Replicated
from jax.experimental.mosaic.gpu.fragmented_array import Tiling as Tiling
diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
index 5e3dcb271a4b..b35d6d6dc8d5 100644
--- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
+++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
@@ -547,10 +547,10 @@ def paged_attention(
if k_scales_pages is not None and v_scales_pages is not None:
in_specs = [
q_block_spec,
- pl.BlockSpec(memory_space=pltpu.ANY),
- pl.BlockSpec(memory_space=pltpu.ANY),
- pl.BlockSpec(memory_space=pltpu.ANY),
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
]
scratch_shapes = (
pltpu.VMEM(
@@ -595,9 +595,9 @@ def paged_attention(
else:
in_specs = [
q_block_spec,
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
None, # type: ignore[list-item]
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
None, # type: ignore[list-item]
]
scratch_shapes = (
diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py
index eea2f60c26f1..5ddeab270657 100644
--- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py
+++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py
@@ -831,7 +831,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
)
in_specs = [
q_block_spec,
- pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pl.ANY),
]
out_specs = q_block_spec
lm_scratch = pltpu.VMEM(
diff --git a/jax/experimental/pallas/ops/tpu/random/philox.py b/jax/experimental/pallas/ops/tpu/random/philox.py
index dcdbd94779db..cb108c319507 100644
--- a/jax/experimental/pallas/ops/tpu/random/philox.py
+++ b/jax/experimental/pallas/ops/tpu/random/philox.py
@@ -117,7 +117,7 @@ def kernel(offset_ref, key_ref, out_ref):
offset = prng_utils.compute_scalar_offset(
counts_idx, unpadded_shape, block_shape)
counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
- counts_lo = counts_lo + offset + offset_ref[0]
+ counts_lo = counts_lo + offset.astype(jnp.uint32) + offset_ref[0]
counts_lo = counts_lo.astype(jnp.uint32)
# TODO(justinfu): Support hi bits on count.
_zeros = jnp.zeros_like(counts_lo)
diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py
index 71a314e09b2d..06a82f4abac8 100644
--- a/jax/experimental/pallas/ops/tpu/random/threefry.py
+++ b/jax/experimental/pallas/ops/tpu/random/threefry.py
@@ -63,7 +63,7 @@ def kernel(key_ref, out_ref):
offset = prng_utils.compute_scalar_offset(
counts_idx, unpadded_shape, block_shape)
counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
- counts_lo = counts_lo + offset
+ counts_lo = counts_lo + offset.astype(jnp.uint32)
counts_lo = counts_lo.astype(jnp.uint32)
# TODO(justinfu): Support hi bits on count.
counts_hi = jnp.zeros_like(counts_lo)
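
The philox and threefry changes above are the same one-line fix: the scalar offset is cast to uint32 before being added to the uint32 iota. The diff does not spell out the motivation, but the effect is that the addition happens entirely in uint32 rather than going through integer type promotion. A small self-contained illustration with stand-in values (not the kernel code):

# Stand-ins for blocked_iota(...) and compute_scalar_offset(...); the point
# is only that casting the offset first keeps the sum in uint32.
import jax.numpy as jnp

counts_lo = jnp.arange(8, dtype=jnp.uint32)
offset = jnp.int32(3)
summed = counts_lo + offset.astype(jnp.uint32)
print(summed.dtype)  # uint32
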
diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py
index 0fbece6f5e42..43be2d80840b 100644
--- a/jax/experimental/pallas/tpu.py
+++ b/jax/experimental/pallas/tpu.py
@@ -89,10 +89,6 @@
HBM = MemorySpace.HBM
HOST = MemorySpace.HOST
SEMAPHORE = MemorySpace.SEMAPHORE
-# Expose ANY for backward compatibility.
-ANY = GeneralMemorySpace.ANY
-del GeneralMemorySpace
-
_deprecations = {
# Added Oct 31, 2025
@@ -100,13 +96,20 @@
"pltpu.delay is deprecated, use pl.delay instead.",
pl_primitives.delay
),
+ # Added Dec 10, 2025
+ "ANY": (
+ "pltpu.ANY is deprecated, use pl.ANY instead.",
+ GeneralMemorySpace.ANY
+ ),
}
if typing.TYPE_CHECKING:
delay = pl_primitives.delay
+ ANY = GeneralMemorySpace.ANY
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del pl_primitives
+del GeneralMemorySpace
diff --git a/jax/experimental/pallas/tpu_sc.py b/jax/experimental/pallas/tpu_sc.py
index 68503bb15c33..e9a90ac9f7ac 100644
--- a/jax/experimental/pallas/tpu_sc.py
+++ b/jax/experimental/pallas/tpu_sc.py
@@ -32,6 +32,7 @@
from jax._src.pallas.mosaic.sc_primitives import PackFormat as PackFormat
from jax._src.pallas.mosaic.sc_primitives import parallel_loop as parallel_loop
from jax._src.pallas.mosaic.sc_primitives import scan_count as scan_count
+from jax._src.pallas.mosaic.sc_primitives import sort_key_val as sort_key_val
from jax._src.pallas.mosaic.sc_primitives import store_compressed as store_compressed
from jax._src.pallas.mosaic.sc_primitives import store_scatter as store_scatter
from jax._src.pallas.mosaic.sc_primitives import subcore_barrier as subcore_barrier
diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py
index f388cd527cf9..dbd21e343bb7 100644
--- a/jax/experimental/sparse/__init__.py
+++ b/jax/experimental/sparse/__init__.py
@@ -15,10 +15,17 @@
"""
.. currentmodule:: jax.experimental.sparse
+.. note::
+
+ The methods in ``jax.experimental.sparse`` are experimental reference implementations
+ and are not recommended for use in performance-critical applications. The submodule is
+ no longer under active development, but we will continue to support existing features
+ as best we can.
+
The :mod:`jax.experimental.sparse` module includes experimental support for sparse matrix
-operations in JAX. It is under active development, and the API is subject to change. The
-primary interfaces made available are the :class:`BCOO` sparse array type, and the
-:func:`sparsify` transform.
+operations in JAX. The primary interfaces made available are the :class:`BCOO` sparse array
+type, and the :func:`sparsify` transform.
+
Batched-coordinate (BCOO) sparse matrices
-----------------------------------------
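
The rewritten docstring above points readers at the two primary interfaces, BCOO and sparsify. A short usage sketch of those two interfaces (standard public API, not part of this diff):

# Minimal BCOO / sparsify example for the interfaces named in the docstring.
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0.],
               [2., 0., 3.]])
M_sp = sparse.BCOO.fromdense(M)   # batched-coordinate sparse representation

@sparse.sparsify
def matvec(A, v):
  return A @ v                    # executed on the sparse representation

v = jnp.array([1., 2., 3.])
print(matvec(M_sp, v))            # same values as M @ v
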
diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py
index 6af63b3e7310..0d64fd123cc1 100644
--- a/jax/interpreters/ad.py
+++ b/jax/interpreters/ad.py
@@ -41,30 +41,16 @@
primitive_jvps as primitive_jvps,
primitive_transposes as primitive_transposes,
reducing_transposes as reducing_transposes,
- vjp as vjp,
zeros_like_aval as zeros_like_aval,
)
-def _deprecated_backward_pass(jaxpr, reduce_axes, transform_stack,
- consts, primals_in, cotangents_in):
- if reduce_axes:
- raise NotImplementedError("reduce_axes on ad.backward_pass is deprecated")
- del reduce_axes
- return _src_ad.backward_pass(
- jaxpr, transform_stack, consts, primals_in, cotangents_in)
-
-
_deprecations = {
# Deprecated for JAX v0.7.1; finalize in JAX v0.9.0.
"zeros_like_p": (
"jax.interpreters.ad.zeros_like_p is deprecated in JAX v0.7.1. It has been unused since v0.4.24.",
_src_ad_util.zeros_like_p,
),
- "backward_pass": (
- "jax.interpreters.ad.backward_pass is deprecated.",
- _deprecated_backward_pass
- ),
"bilinear_transpose": (
"jax.interpreters.ad.bilinear_transpose is deprecated.",
_src_ad.bilinear_transpose,
@@ -81,10 +67,6 @@ def _deprecated_backward_pass(jaxpr, reduce_axes, transform_stack,
"jax.interpreters.ad.call_transpose_param_updaters is deprecated.",
_src_ad.call_transpose_param_updaters,
),
- "closed_backward_pass": (
- "jax.interpreters.ad.closed_backward_pass is deprecated.",
- _src_ad.closed_backward_pass,
- ),
"custom_lin_p": (
"jax.interpreters.ad.custom_lin_p is deprecated.",
_src_ad.custom_lin_p,
@@ -161,12 +143,10 @@ def _deprecated_backward_pass(jaxpr, reduce_axes, transform_stack,
import typing
if typing.TYPE_CHECKING:
- backward_pass = _deprecated_backward_pass
bilinear_transpose = _src_ad.bilinear_transpose
call_param_updaters = _src_ad.call_param_updaters
call_transpose = _src_ad.call_transpose
call_transpose_param_updaters = _src_ad.call_transpose_param_updaters
- closed_backward_pass = _src_ad.closed_backward_pass
custom_lin_p = _src_ad.custom_lin_p
defjvp_zero = _src_ad.defjvp_zero
f_jvp_traceable = _src_ad.f_jvp_traceable
diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py
index 3eda5ab9bbbe..4db86da4d806 100644
--- a/jax/interpreters/batching.py
+++ b/jax/interpreters/batching.py
@@ -54,18 +54,6 @@
"jax.interpreters.batching.BatchingRule is deprecated.",
_src_batching.BatchingRule,
),
- "Jumble": (
- "jax.interpreters.batching.Jumble is deprecated.",
- _src_batching.Jumble,
- ),
- "JumbleAxis": (
- "jax.interpreters.batching.JumbleAxis is deprecated.",
- _src_batching.JumbleAxis,
- ),
- "JumbleTy": (
- "jax.interpreters.batching.JumbleTy is deprecated.",
- _src_batching.JumbleTy,
- ),
"Elt": (
"jax.interpreters.batching.Elt is deprecated.",
_src_batching.Elt,
@@ -78,10 +66,6 @@
"jax.interpreters.batching.GetIdx is deprecated.",
_src_batching.GetIdx,
),
- "IndexedAxisSize": (
- "jax.interpreters.batching.IndexedAxisSize is deprecated.",
- _src_batching.IndexedAxisSize,
- ),
"MakeIotaHandler": (
"jax.interpreters.batching.MakeIotaHandler is deprecated.",
_src_batching.MakeIotaHandler,
@@ -94,10 +78,6 @@
"jax.interpreters.batching.NotMapped is deprecated.",
_src_batching.NotMapped,
),
- "RaggedAxis": (
- "jax.interpreters.batching.RaggedAxis is deprecated.",
- _src_batching.RaggedAxis,
- ),
"ToEltHandler": (
"jax.interpreters.batching.ToEltHandler is deprecated.",
_src_batching.ToEltHandler,
@@ -130,10 +110,6 @@
"jax.interpreters.batching.batch_jaxpr is deprecated. It is an internal API.",
_src_batching.batch_jaxpr,
),
- "batch_jaxpr2": (
- "jax.interpreters.batching.batch_jaxpr2 is deprecated. It is an internal API.",
- _src_batching.batch_jaxpr2,
- ),
"batch_jaxpr_axes": (
"jax.interpreters.batching.batch_jaxpr_axes is deprecated. It is an internal API.",
_src_batching.batch_jaxpr_axes,
@@ -162,10 +138,6 @@
"jax.interpreters.batching.is_vmappable is deprecated. It is an internal API.",
_src_batching.is_vmappable,
),
- "jumble_axis": (
- "jax.interpreters.batching.jumble_axis is deprecated. It is an internal API.",
- _src_batching.jumble_axis,
- ),
"make_iota": (
"jax.interpreters.batching.make_iota is deprecated. It is an internal API.",
_src_batching.make_iota,
@@ -224,17 +196,12 @@
BatchTrace = _src_batching.BatchTrace
BatchTracer = _src_batching.BatchTracer
BatchingRule = _src_batching.BatchingRule
- Jumble = _src_batching.Jumble
- JumbleAxis = _src_batching.JumbleAxis
- JumbleTy = _src_batching.JumbleTy
Elt = _src_batching.Elt
FromEltHandler = _src_batching.FromEltHandler
GetIdx = _src_batching.GetIdx
- IndexedAxisSize = _src_batching.IndexedAxisSize
MakeIotaHandler = _src_batching.MakeIotaHandler
MapSpec = _src_batching.MapSpec
NotMapped = _src_batching.NotMapped
- RaggedAxis = _src_batching.RaggedAxis
ToEltHandler = _src_batching.ToEltHandler
Vmappable = _src_batching.Vmappable
Zero = _src_batching.Zero
@@ -243,7 +210,6 @@
batch_custom_jvp_subtrace = _src_batching.batch_custom_jvp_subtrace
batch_custom_vjp_bwd = _src_batching.batch_custom_vjp_bwd
batch_jaxpr = _src_batching.batch_jaxpr
- batch_jaxpr2 = _src_batching.batch_jaxpr2
batch_jaxpr_axes = _src_batching.batch_jaxpr_axes
batch_subtrace = _src_batching.batch_subtrace
broadcast_batcher = _src_batching.broadcast_batcher
@@ -251,7 +217,6 @@
from_elt = _src_batching.from_elt
from_elt_handlers = _src_batching.from_elt_handlers
is_vmappable = _src_batching.is_vmappable
- jumble_axis = _src_batching.jumble_axis
make_iota = _src_batching.make_iota
make_iota_handlers = _src_batching.make_iota_handlers
matchaxis = _src_batching.matchaxis
diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py
index 8b1022f73baa..67a246eba226 100644
--- a/jax/interpreters/partial_eval.py
+++ b/jax/interpreters/partial_eval.py
@@ -34,18 +34,6 @@
_deprecations = {
# Deprecated for JAX v0.7.1; finalize in JAX v0.9.0.
- "AbstractedAxesSpec": (
- "jax.interpreters.partial_eval.AbstractedAxesSpec is deprecated.",
- _pe_src.AbstractedAxesSpec,
- ),
- "AbstractedAxisName": (
- "jax.interpreters.partial_eval.AbstractedAxisName is deprecated.",
- _pe_src.AbstractedAxisName,
- ),
- "BoundedAxisSize": (
- "jax.interpreters.partial_eval.BoundedAxisSize is deprecated.",
- _pe_src.BoundedAxisSize,
- ),
"Const": (
"jax.interpreters.partial_eval.Const is deprecated.",
_pe_src.Const,
@@ -130,10 +118,6 @@
"jax.interpreters.partial_eval.abstract_eval_fun is deprecated.",
_pe_src.abstract_eval_fun,
),
- "call_padding_rule": (
- "jax.interpreters.partial_eval.call_padding_rule is deprecated.",
- _pe_src.call_padding_rule,
- ),
"call_param_updaters": (
"jax.interpreters.partial_eval.call_param_updaters is deprecated.",
_pe_src.call_param_updaters,
@@ -178,10 +162,6 @@
"jax.interpreters.partial_eval.custom_staging_rules is deprecated.",
_pe_src.custom_staging_rules,
),
- "def_trivial_padding": (
- "jax.interpreters.partial_eval.def_trivial_padding is deprecated.",
- _pe_src.def_trivial_padding,
- ),
"forwarding_rules": (
"jax.interpreters.partial_eval.forwarding_rules is deprecated.",
_pe_src.forwarding_rules,
@@ -190,10 +170,6 @@
"jax.interpreters.partial_eval.has_effects is deprecated.",
_pe_src.has_effects,
),
- "infer_lambda_input_type": (
- "jax.interpreters.partial_eval.infer_lambda_input_type is deprecated.",
- _pe_src.infer_lambda_input_type,
- ),
"instantiate_const_at": (
"jax.interpreters.partial_eval.instantiate_const_at is deprecated.",
_pe_src.instantiate_const_at,
@@ -214,14 +190,6 @@
"jax.interpreters.partial_eval.new_eqn_recipe is deprecated.",
_pe_src.new_eqn_recipe,
),
- "pad_jaxpr": (
- "jax.interpreters.partial_eval.pad_jaxpr is deprecated.",
- _pe_src.pad_jaxpr,
- ),
- "padding_rules": (
- "jax.interpreters.partial_eval.padding_rules is deprecated.",
- _pe_src.padding_rules,
- ),
"partial_eval_jaxpr_custom": (
"jax.interpreters.partial_eval.partial_eval_jaxpr_custom is deprecated.",
_pe_src.partial_eval_jaxpr_custom,
@@ -246,10 +214,6 @@
"jax.interpreters.partial_eval.recipe_to_eqn is deprecated.",
_pe_src.recipe_to_eqn,
),
- "trace_to_jaxpr_dynamic2": (
- "jax.interpreters.partial_eval.trace_to_jaxpr_dynamic2 is deprecated.",
- _pe_src.trace_to_jaxpr_dynamic2,
- ),
"trace_to_subjaxpr_nounits": (
"jax.interpreters.partial_eval.trace_to_subjaxpr_nounits is deprecated.",
_pe_src.trace_to_subjaxpr_nounits,
@@ -289,7 +253,6 @@
TracerAsName = _pe_src.TracerAsName
TracerId = _pe_src.TracerId
abstract_eval_fun = _pe_src.abstract_eval_fun
- call_padding_rule = _pe_src.call_padding_rule
call_param_updaters = _pe_src.call_param_updaters
call_partial_eval_custom_rule = _pe_src.call_partial_eval_custom_rule
call_partial_eval_rules = _pe_src.call_partial_eval_rules
@@ -301,7 +264,6 @@
convert_envvars_to_constvars = _pe_src.convert_envvars_to_constvars
convert_invars_to_constvars = _pe_src.convert_invars_to_constvars
custom_staging_rules = _pe_src.custom_staging_rules
- def_trivial_padding = _pe_src.def_trivial_padding
forwarding_rules = _pe_src.forwarding_rules
has_effects = _pe_src.has_effects
infer_lambda_input_type = _pe_src.infer_lambda_input_type
@@ -310,8 +272,6 @@
move_binders_to_back = _pe_src.move_binders_to_back
move_binders_to_front = _pe_src.move_binders_to_front
new_eqn_recipe = _pe_src.new_eqn_recipe
- pad_jaxpr = _pe_src.pad_jaxpr
- padding_rules = _pe_src.padding_rules
partial_eval_jaxpr_custom = _pe_src.partial_eval_jaxpr_custom
partial_eval_jaxpr_custom_rule_not_implemented = _pe_src.partial_eval_jaxpr_custom_rule_not_implemented
partial_eval_jaxpr_nounits = _pe_src.partial_eval_jaxpr_nounits
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 06c771169443..af179b1905a5 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -12,42 +12,153 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from jax._src.interpreters.pxla import (
- Index as Index,
- MapTracer as MapTracer,
- MeshAxisName as MeshAxisName,
- MeshComputation as MeshComputation,
- MeshExecutable as MeshExecutable,
- PmapExecutable as PmapExecutable,
- global_aval_to_result_handler as global_aval_to_result_handler,
- global_avals_to_results_handler as global_avals_to_results_handler,
- global_result_handlers as global_result_handlers,
- parallel_callable as parallel_callable,
- shard_args as shard_args,
- xla_pmap_p as xla_pmap_p,
-)
-from jax._src.mesh import (
- thread_resources as thread_resources,
-)
+# Note: import <name> as <name> is required for names to be exported.
+# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
-from jax._src.op_shardings import (
- are_hlo_shardings_equal as are_hlo_shardings_equal,
- is_hlo_sharding_replicated as is_hlo_sharding_replicated,
- op_sharding_to_indices as op_sharding_to_indices,
-)
+from jax._src.interpreters import pxla as _deprecated_pxla
+from jax._src import mesh as _deprecated_mesh
+from jax._src import op_shardings as _deprecated_op_shardings
+from jax._src import sharding_impls as _deprecated_sharding_impls
+from jax._src import sharding_specs as _deprecated_sharding_specs
-from jax._src.sharding_impls import (
- ArrayMapping as ArrayMapping,
- UNSPECIFIED as _UNSPECIFIED, # noqa: F401
- array_mapping_to_axis_resources as array_mapping_to_axis_resources,
-)
+_deprecations = {
+ # deprecated as of JAX v0.8.2 (Dec 2025)
+ "Index": (
+ "jax.interpreters.pxla.Index is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.Index,
+ ),
+ "MapTracer": (
+ "jax.interpreters.pxla.MapTracer is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.MapTracer,
+ ),
+ "MeshAxisName": (
+ "jax.interpreters.pxla.MeshAxisName is deprecated as of JAX v0.8.2. Use jax.sharding.Mesh axis names directly.",
+ _deprecated_pxla.MeshAxisName,
+ ),
+ "MeshComputation": (
+ "jax.interpreters.pxla.MeshComputation is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.MeshComputation,
+ ),
+ "MeshExecutable": (
+ "jax.interpreters.pxla.MeshExecutable is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.MeshExecutable,
+ ),
+ "PmapExecutable": (
+ "jax.interpreters.pxla.PmapExecutable is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.PmapExecutable,
+ ),
+ "global_aval_to_result_handler": (
+ "jax.interpreters.pxla.global_aval_to_result_handler is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.global_aval_to_result_handler,
+ ),
+ "global_avals_to_results_handler": (
+ "jax.interpreters.pxla.global_avals_to_results_handler is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.global_avals_to_results_handler,
+ ),
+ "global_result_handlers": (
+ "jax.interpreters.pxla.global_result_handlers is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.global_result_handlers,
+ ),
+ "parallel_callable": (
+ "jax.interpreters.pxla.parallel_callable is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.parallel_callable,
+ ),
+ "shard_args": (
+ "jax.interpreters.pxla.shard_args is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.shard_args,
+ ),
+ "xla_pmap_p": (
+ "jax.interpreters.pxla.xla_pmap_p is deprecated as of JAX v0.8.2.",
+ _deprecated_pxla.xla_pmap_p,
+ ),
+ "thread_resources": (
+ "jax.interpreters.pxla.thread_resources is deprecated as of JAX v0.8.2.",
+ _deprecated_mesh.thread_resources,
+ ),
+ "are_hlo_shardings_equal": (
+ "jax.interpreters.pxla.are_hlo_shardings_equal is deprecated as of JAX v0.8.2.",
+ _deprecated_op_shardings.are_hlo_shardings_equal,
+ ),
+ "is_hlo_sharding_replicated": (
+ "jax.interpreters.pxla.is_hlo_sharding_replicated is deprecated as of JAX v0.8.2.",
+ _deprecated_op_shardings.is_hlo_sharding_replicated,
+ ),
+ "op_sharding_to_indices": (
+ "jax.interpreters.pxla.op_sharding_to_indices is deprecated as of JAX v0.8.2.",
+ _deprecated_op_shardings.op_sharding_to_indices,
+ ),
+ "ArrayMapping": (
+ "jax.interpreters.pxla.ArrayMapping is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_impls.ArrayMapping,
+ ),
+ "_UNSPECIFIED": (
+ "jax.interpreters.pxla._UNSPECIFIED is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_impls.UNSPECIFIED,
+ ),
+ "array_mapping_to_axis_resources": (
+ "jax.interpreters.pxla.array_mapping_to_axis_resources is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_impls.array_mapping_to_axis_resources,
+ ),
+ "Chunked": (
+ "jax.interpreters.pxla.Chunked is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_specs.Chunked,
+ ),
+ "NoSharding": (
+ "jax.interpreters.pxla.NoSharding is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_specs.NoSharding,
+ ),
+ "Replicated": (
+ "jax.interpreters.pxla.Replicated is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_specs.Replicated,
+ ),
+ "ShardedAxis": (
+ "jax.interpreters.pxla.ShardedAxis is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_specs.ShardedAxis,
+ ),
+ "ShardingSpec": (
+ "jax.interpreters.pxla.ShardingSpec is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_specs.ShardingSpec,
+ ),
+ "Unstacked": (
+ "jax.interpreters.pxla.Unstacked is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_specs.Unstacked,
+ ),
+ "spec_to_indices": (
+ "jax.interpreters.pxla.spec_to_indices is deprecated as of JAX v0.8.2.",
+ _deprecated_sharding_specs.spec_to_indices,
+ ),
+}
-from jax._src.sharding_specs import (
- Chunked as Chunked,
- NoSharding as NoSharding,
- Replicated as Replicated,
- ShardedAxis as ShardedAxis,
- ShardingSpec as ShardingSpec,
- Unstacked as Unstacked,
- spec_to_indices as spec_to_indices,
-)
+import typing as _typing
+if _typing.TYPE_CHECKING:
+ Index = _deprecated_pxla.Index
+ MapTracer = _deprecated_pxla.MapTracer
+ MeshAxisName = _deprecated_pxla.MeshAxisName
+ MeshComputation = _deprecated_pxla.MeshComputation
+ MeshExecutable = _deprecated_pxla.MeshExecutable
+ PmapExecutable = _deprecated_pxla.PmapExecutable
+ global_aval_to_result_handler = _deprecated_pxla.global_aval_to_result_handler
+ global_avals_to_results_handler = _deprecated_pxla.global_avals_to_results_handler
+ global_result_handlers = _deprecated_pxla.global_result_handlers
+ parallel_callable = _deprecated_pxla.parallel_callable
+ shard_args = _deprecated_pxla.shard_args
+ xla_pmap_p = _deprecated_pxla.xla_pmap_p
+ thread_resources = _deprecated_mesh.thread_resources
+ are_hlo_shardings_equal = _deprecated_op_shardings.are_hlo_shardings_equal
+ is_hlo_sharding_replicated = _deprecated_op_shardings.is_hlo_sharding_replicated
+ op_sharding_to_indices = _deprecated_op_shardings.op_sharding_to_indices
+ ArrayMapping = _deprecated_sharding_impls.ArrayMapping
+ _UNSPECIFIED = _deprecated_sharding_impls.UNSPECIFIED
+ array_mapping_to_axis_resources = _deprecated_sharding_impls.array_mapping_to_axis_resources
+ Chunked = _deprecated_sharding_specs.Chunked
+ NoSharding = _deprecated_sharding_specs.NoSharding
+ Replicated = _deprecated_sharding_specs.Replicated
+ ShardedAxis = _deprecated_sharding_specs.ShardedAxis
+ ShardingSpec = _deprecated_sharding_specs.ShardingSpec
+ Unstacked = _deprecated_sharding_specs.Unstacked
+ spec_to_indices = _deprecated_sharding_specs.spec_to_indices
+else:
+ from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
+ __getattr__ = _deprecation_getattr(__name__, _deprecations)
+ del _deprecation_getattr
+del _typing
diff --git a/jax/lib/__init__.py b/jax/lib/__init__.py
index 989bcc944067..46b3668e0fdf 100644
--- a/jax/lib/__init__.py
+++ b/jax/lib/__init__.py
@@ -16,11 +16,3 @@
from jax._src.lib import (
version_str as __version__,
)
-
-# Dynamically load submodules because they warn on import.
-# TODO(jakevdp): remove this in JAX v0.9.0.
-def __getattr__(attr):
- if attr in {'xla_bridge', 'xla_client', 'xla_extension'}:
- import importlib
- return importlib.import_module(f'jax.lib.{attr}')
- raise AttributeError(f"module 'jax.lib' has no attribute {attr!r}")
diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py
deleted file mode 100644
index 39bda685f4ec..000000000000
--- a/jax/lib/xla_bridge.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2018 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from jax._src import deprecations as _deps
-
-_deps.warn(
- 'jax-lib-module',
- (
- 'jax.lib.xla_bridge module will be removed in JAX v0.9.0;'
- ' all its APIs were deprecated and removed by JAX v0.8.0.'
- ),
- stacklevel=4
-)
-
-_deprecations = {
- # Finalized in JAX v0.8.0; remove these messages in v0.9.0.
- "get_backend": (
- (
- "jax.lib.xla_bridge.get_backend is deprecated and will be removed"
- " in JAX v0.8.0; use jax.extend.backend.get_backend, and please"
- " note that you must `import jax.extend` explicitly."
- ),
- None,
- ),
- "get_compile_options": (
- (
- "jax.lib.xla_bridge.get_compile_options is deprecated in JAX v0.7.0"
- " and will be removed in JAX v0.8.0. Use"
- " jax.extend.backend.get_compile_options, and please note that you"
- " must `import jax.extend` explicitly."
- ),
- None,
- ),
-}
-
-__getattr__ = _deps.deprecation_getattr(__name__, _deprecations)
-del _deps
diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py
deleted file mode 100644
index ecebb1a7b9a6..000000000000
--- a/jax/lib/xla_client.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2024 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from jax._src import deprecations as _deps
-
-_deps.warn(
- 'jax-lib-module',
- (
- 'jax.lib.xla_client module will be removed in JAX v0.9.0;'
- ' all its APIs were deprecated and removed by JAX v0.8.0.'
- ),
- stacklevel=4
-)
-
-_deprecations = {
- # Finalized in JAX v0.8.0; remove these messages in v0.9.0.
- "Client": (
- (
- "jax.lib.xla_client.Client was deprecated in JAX v0.6.0 and will be"
- " removed in JAX v0.8.0"
- ),
- None,
- ),
- "CompileOptions": (
- (
- "jax.lib.xla_client.CompileOptions was deprecated in JAX v0.6.0 and"
- " will be removed in JAX v0.8.0"
- ),
- None,
- ),
- "Frame": (
- (
- "jax.lib.xla_client.Frame was deprecated in JAX v0.6.0 and will be"
- " removed in JAX v0.8.0"
- ),
- None,
- ),
- "HloSharding": (
- (
- "jax.lib.xla_client.HloSharding was deprecated in JAX v0.6.0 and"
- " will be removed in JAX v0.8.0"
- ),
- None,
- ),
- "OpSharding": (
- (
- "jax.lib.xla_client.OpSharding was deprecated in JAX v0.6.0 and"
- " will be removed in JAX v0.8.0"
- ),
- None,
- ),
- "Traceback": (
- (
- "jax.lib.xla_client.Traceback was deprecated in JAX v0.6.0 and will"
- " be removed in JAX v0.8.0"
- ),
- None,
- ),
-}
-
-__getattr__ = _deps.deprecation_getattr(__name__, _deprecations)
-del _deps
diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py
deleted file mode 100644
index c02710c081ad..000000000000
--- a/jax/lib/xla_extension.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2024 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from jax._src import deprecations as _deps
-
-_deps.warn(
- 'jax-lib-module',
- (
- 'jax.lib.xla_extension module will be removed in JAX v0.9.0;'
- ' all its APIs were deprecated and removed by JAX v0.8.0.'
- ),
- stacklevel=4
-)
-
-_deprecations = {
- # Finalized in JAX v0.8.0; remove these messages in v0.9.0.
- "ifrt_proxy": (
- "jax.lib.xla_extension.ifrt_proxy is deprecated.",
- None,
- ),
- "mlir": ("jax.lib.xla_extension.mlir is deprecated.", None),
- "profiler": (
- "jax.lib.xla_extension.profiler is deprecated.",
- None,
- ),
- "hlo_module_cost_analysis": (
- "jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.",
- None,
- ),
- "hlo_module_to_dot_graph": (
- "jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.",
- None,
- ),
- "HloPrintOptions": (
- "jax.lib.xla_extension.HloPrintOptions is deprecated.",
- None,
- ),
- "PjitFunction": (
- "jax.lib.xla_extension.PjitFunction is deprecated.",
- None,
- ),
- "PmapFunction": (
- "jax.lib.xla_extension.PmapFunction is deprecated.",
- None,
- ),
-}
-
-__getattr__ = _deps.deprecation_getattr(__name__, _deprecations)
-del _deps
diff --git a/jax/scipy/stats/poisson.py b/jax/scipy/stats/poisson.py
index 5fcde905f89b..ac7cfa141063 100644
--- a/jax/scipy/stats/poisson.py
+++ b/jax/scipy/stats/poisson.py
@@ -19,4 +19,5 @@
logpmf as logpmf,
pmf as pmf,
cdf as cdf,
+ entropy as entropy
)
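
This re-export simply adds entropy next to logpmf, pmf, and cdf. Assuming it follows the usual scipy-style signature entropy(mu, loc=0), usage looks like:

# Hedged example of the newly re-exported function; signature assumed to
# mirror scipy.stats.poisson.entropy.
import jax.numpy as jnp
from jax.scipy.stats import poisson

mu = jnp.array([1.0, 5.0, 10.0])
print(poisson.entropy(mu))  # Shannon entropy of Poisson(mu), elementwise
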
diff --git a/jax/version.py b/jax/version.py
index 24fccad06e8e..408a38510c28 100644
--- a/jax/version.py
+++ b/jax/version.py
@@ -21,7 +21,7 @@
import pathlib
import subprocess
-_version = "0.8.2"
+_version = "0.8.3"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
_release_version: str | None = None
@@ -152,7 +152,7 @@ def make_release_tree(self, base_dir, files):
__version__ = _get_version_string()
-_minimum_jaxlib_version = '0.8.1'
+_minimum_jaxlib_version = '0.8.2'
def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
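
For reference, the helper at the end of this hunk reduces version strings like the ones bumped above to integer tuples; a quick sanity check (copied helper, hypothetical dev-version input):

# Copy of the helper above, exercised on a release string and a dev string;
# non-numeric components are simply dropped.
def _version_as_tuple(version_str):
  return tuple(int(i) for i in version_str.split(".") if i.isdigit())

print(_version_as_tuple("0.8.3"))               # (0, 8, 3)
print(_version_as_tuple("0.8.3.dev20251210"))   # (0, 8, 3)
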
diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index e9cf8b41fc48..ebcd42f05304 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -397,6 +397,7 @@ nanobind_pywrap_extension(
"@xla//xla/pjrt/distributed:key_value_store_interface",
"@xla//xla/pjrt/distributed:protocol_proto_cc",
"@xla//xla/pjrt/distributed:service",
+ "@xla//xla/pjrt/distributed/preemption:preemption_sync_manager",
"@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"@xla//xla/python:logging",
@@ -406,6 +407,7 @@ nanobind_pywrap_extension(
"@xla//xla/python:types",
"@xla//xla/python:version",
"@xla//xla/python/ifrt",
+ "@xla//xla/python/ifrt:attribute_map",
"@xla//xla/python/ifrt:plugin_program",
"@xla//xla/python/ifrt:plugin_program_serdes",
"@xla//xla/python/pjrt_ifrt",
@@ -514,6 +516,7 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@nanobind",
diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi
index 831f44ea2e8e..405c485e1f76 100644
--- a/jaxlib/_jax/__init__.pyi
+++ b/jaxlib/_jax/__init__.pyi
@@ -1029,6 +1029,8 @@ def array_result_handler(
class ResultHandler:
def __call__(self, arg: Array | Sequence[Array], /) -> Array: ...
+ def wrap(self, arg: Callable, /) -> ResultHandler: ...
+ def pre_wrap(self, arg: Callable, /) -> ResultHandler: ...
class DeviceList:
def __init__(self, arg: tuple[Device, ...], /) -> None: ...
@@ -1199,6 +1201,7 @@ class LoadedExecutable:
def client(self) -> Client: ...
def local_devices(self) -> list[Device]: ...
def get_hlo_text(self) -> str: ...
+ def serialize(self) -> bytes: ...
def size_of_generated_code_in_bytes(self) -> int: ...
def get_compiled_memory_stats(self) -> CompiledMemoryStats: ...
def execute_sharded(
@@ -1226,7 +1229,7 @@ class ExecuteResults:
self, arg: int, /
) -> list[list[Array]]: ...
def consume_with_handlers(
- self, arg: Sequence[ResultHandler | object], /
+ self, out_handlers: Sequence[ResultHandler | object], strict: bool = ...
) -> list[object]: ...
def consume_token(self) -> ShardedToken: ...
diff --git a/jaxlib/call_location.cc b/jaxlib/call_location.cc
index df335af1c46e..96855b114fb2 100644
--- a/jaxlib/call_location.cc
+++ b/jaxlib/call_location.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "absl/base/no_destructor.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/log/check.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
@@ -31,10 +32,10 @@ limitations under the License.
#include "nanobind/stl/optional.h" // IWYU pragma: keep
#include "nanobind/stl/string.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
-#include "jaxlib/traceback.h"
#include "jaxlib/py_user_context.h"
-#include "xla/python/ifrt/executable.h"
+#include "jaxlib/traceback.h"
#include "xla/python/ifrt/attribute_map.h"
+#include "xla/python/ifrt/executable.h"
#include "xla/python/ifrt/user_context.h"
#include "xla/python/pjrt_ifrt/pjrt_executable.h"
@@ -124,14 +125,12 @@ void PopulateCallLocation(xla::ifrt::ExecuteOptions& options,
}
if (!call_location_str.empty()) {
- xla::ifrt::AttributeMap::Map attrs_map;
- if (options.custom_options.has_value()) {
- attrs_map = options.custom_options->map();
+ if (!options.custom_options.has_value()) {
+ options.custom_options.emplace(xla::ifrt::AttributeMap({}));
}
- attrs_map.insert(
- {std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation),
- xla::ifrt::AttributeMap::StringValue(std::move(call_location_str))});
- options.custom_options.emplace(std::move(attrs_map));
+ CHECK_OK(options.custom_options->Set(
+ std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation),
+ std::move(call_location_str)));
}
}
diff --git a/jaxlib/callback.cc b/jaxlib/callback.cc
index 83d46ed45c07..81263d149c64 100644
--- a/jaxlib/callback.cc
+++ b/jaxlib/callback.cc
@@ -104,7 +104,7 @@ absl::Status CpuCallback::PrepareAndCall(void** result, void** arg_ptrs) {
xla::primitive_util::ByteWidth(results_[i].type);
options.dims = dims;
options.permutation = results_[i].reversed_layout;
- options.input_layout = xla::TransposePlan::Striding{strides};
+ options.input_striding = xla::TransposePlan::Striding{strides};
absl::StatusOr<std::shared_ptr<xla::TransposePlan>> plan =
transpose_cache_.GetOrCreate(options);
if (!plan.ok()) {
diff --git a/jaxlib/dlpack.cc b/jaxlib/dlpack.cc
index 03388de494b4..14ccb51a85c1 100644
--- a/jaxlib/dlpack.cc
+++ b/jaxlib/dlpack.cc
@@ -178,11 +178,14 @@ absl::StatusOr> GetByteStrides(const DLTensor& dl_tensor) {
return strides;
}
-absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> MakePjrtBuffer(
- xla::PjRtDevice& device, ::DLManagedTensor* dlmt, const xla::Shape& shape,
- xla::PrimitiveType element_type, absl::Span<const int64_t> dimensions,
- std::optional<bool> copy = std::nullopt,
- std::optional<std::intptr_t> stream = std::nullopt) {
+// Makes a PjRtBuffer from a DLPack tensor. Returns a pair where the second
+// element is true if a copy actually happened.
+absl::StatusOr<std::pair<std::unique_ptr<xla::PjRtBuffer>, bool>>
+MakePjrtBuffer(xla::PjRtDevice& device, ::DLManagedTensor* dlmt,
+ const xla::Shape& shape, xla::PrimitiveType element_type,
+ absl::Span<const int64_t> dimensions,
+ std::optional<bool> copy = std::nullopt,
+ std::optional<std::intptr_t> stream = std::nullopt) {
std::function<void()> on_delete_callback;
if (dlmt->deleter) {
on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); };
@@ -204,7 +207,8 @@ absl::StatusOr<std::unique_ptr<xla::PjRtBuffer>> MakePjrtBuffer(
stream);
if (!(result.status().code() == absl::StatusCode::kInvalidArgument &&
fallback_to_copy)) {
- return result;
+ TF_RETURN_IF_ERROR(result.status());
+ return std::make_pair(*std::move(result), false);
}
}
@@ -217,10 +221,13 @@ absl::StatusOr> MakePjrtBuffer(
TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space());
// Create a copy.
- return device.client()->BufferFromHostBuffer(
- data, element_type, dimensions, byte_strides,
- xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy,
- on_delete_callback, memory_space, /*device_layout=*/nullptr);
+ TF_ASSIGN_OR_RETURN(
+ auto buffer,
+ device.client()->BufferFromHostBuffer(
+ data, element_type, dimensions, byte_strides,
+ xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy,
+ on_delete_callback, memory_space, /*device_layout=*/nullptr));
+ return std::make_pair(std::move(buffer), true);
}
} // namespace
@@ -365,9 +372,13 @@ absl::StatusOr<PyArray> DLPackManagedTensorToBuffer(
xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
element_type, dimensions, minor_to_major);
- TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
+ TF_ASSIGN_OR_RETURN(auto pjrt_buffer_and_copied,
MakePjrtBuffer(*device->pjrt_device(), dlmt, shape,
element_type, dimensions, copy, stream));
+ if (pjrt_buffer_and_copied.second) {
+ // A PjRtBuffer uses a default layout if it has been created using copy.
+ has_custom_layout = false;
+ }
// We have taken ownership of the array inside the capsule; make sure the
// capsule cannot be used again.
@@ -383,7 +394,8 @@ absl::StatusOr DLPackManagedTensorToBuffer(
PyUserContextScope user_context_scope;
TF_ASSIGN_OR_RETURN(
auto ifrt_array,
- ifrt_client->CreatePjRtArray(std::move(pjrt_buffer), has_custom_layout));
+ ifrt_client->CreatePjRtArray(std::move(pjrt_buffer_and_copied.first),
+ has_custom_layout));
return PyArray::MakeFromSingleDeviceArray(std::move(client),
std::move(ifrt_array), false, true);
}
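
The dlpack.cc change above threads a "did we copy?" bit out of MakePjrtBuffer so that buffers produced by the host-copy fallback are treated as having the default layout. From Python this all sits behind jax.dlpack.from_dlpack; a minimal sketch of that entry point (it does not specifically exercise the copy fallback):

# Minimal from_dlpack round trip; NumPy arrays implement __dlpack__.
import numpy as np
import jax.dlpack

x = np.arange(12, dtype=np.float32).reshape(3, 4)
y = jax.dlpack.from_dlpack(x)
print(y.shape, y.dtype)  # (3, 4) float32
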
diff --git a/jaxlib/ffi.cc b/jaxlib/ffi.cc
index ff1dd96958f5..e5f31079b332 100644
--- a/jaxlib/ffi.cc
+++ b/jaxlib/ffi.cc
@@ -21,6 +21,9 @@ limitations under the License.
#include
#include
#include
+#include
+#include
+#include
#include
#include "absl/base/casts.h"
diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD
index 9d6111fcadaa..6169c882aa05 100644
--- a/jaxlib/gpu/BUILD
+++ b/jaxlib/gpu/BUILD
@@ -82,6 +82,10 @@ proto_library(
cc_proto_library(
name = "triton_cc_proto",
compatible_with = None,
+ visibility = [
+ "//jax:internal",
+ "//third_party/py/enzyme_ad:__subpackages__",
+ ],
deps = [":triton_proto"],
)
diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc
index 862430257f27..d3f411cc87b6 100644
--- a/jaxlib/gpu/gpu_plugin_extension.cc
+++ b/jaxlib/gpu/gpu_plugin_extension.cc
@@ -48,9 +48,6 @@ namespace {
struct TritonCompilationResult {
std::string asm_text;
int64_t smem_bytes;
- int cluster_dim_x;
- int cluster_dim_y;
- int cluster_dim_z;
};
absl::StatusOr<TritonCompilationResult> CompileTritonToASM(
@@ -77,9 +74,6 @@ absl::StatusOr<TritonCompilationResult> CompileTritonToASM(
return TritonCompilationResult{
.asm_text = asm_text,
.smem_bytes = args.out_smem_bytes,
- .cluster_dim_x = args.out_cluster_dim_x,
- .cluster_dim_y = args.out_cluster_dim_y,
- .cluster_dim_z = args.out_cluster_dim_z,
};
}
@@ -240,10 +234,7 @@ void BuildGpuPluginExtension(nanobind::module_& m) {
nb::class_<TritonCompilationResult>(m, "TritonCompilationResult")
.def_ro("asm", &TritonCompilationResult::asm_text)
- .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes)
- .def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x)
- .def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y)
- .def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z);
+ .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes);
m.def("compile_triton_to_asm",
[](nb::capsule c_api, nb::bytes module, std::string_view arch_name,
diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc
index b3c091403c0c..f6678e023c6e 100644
--- a/jaxlib/gpu/py_client_gpu.cc
+++ b/jaxlib/gpu/py_client_gpu.cc
@@ -217,7 +217,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream,
absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
reversed_layout.begin());
options.permutation = reversed_layout;
- options.input_layout = xla::TransposePlan::Striding{strides};
+ options.input_striding = xla::TransposePlan::Striding{strides};
auto maybe_plan = transpose_cache->cache.GetOrCreate(options);
if (!maybe_plan.ok()) {
return xla::ffi::Error::Internal(maybe_plan.status().ToString());
diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc
index 42c58eb613a2..a1bb10ed510f 100644
--- a/jaxlib/gpu/triton.cc
+++ b/jaxlib/gpu/triton.cc
@@ -45,8 +45,8 @@ namespace jax::JAX_GPU_NAMESPACE {
NB_MODULE(_triton, m) {
nb::class_<Kernel>(m, "TritonKernel")
- .def(nb::init<std::string, uint32_t, uint32_t, std::string, std::string, int, uint32_t, uint32_t, uint32_t>());
+ .def(nb::init<std::string, uint32_t, uint32_t, uint32_t, std::string, std::string, int>());
nb::class_(m, "TritonParameter");
diff --git a/jaxlib/gpu/triton.proto b/jaxlib/gpu/triton.proto
index 786b07afbdbe..553f95dd5f17 100644
--- a/jaxlib/gpu/triton.proto
+++ b/jaxlib/gpu/triton.proto
@@ -5,6 +5,7 @@ package jax_triton;
message TritonKernel {
string kernel_name = 1; // Kernel function name within module.
uint32 num_warps = 2;
+ optional uint32 num_ctas = 10;
uint32 shared_mem_bytes = 3;
string ptx = 4;
string ttir = 5;
diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc
index 1961ada1bf76..0ad86f522d9d 100644
--- a/jaxlib/gpu/triton_kernels.cc
+++ b/jaxlib/gpu/triton_kernels.cc
@@ -315,17 +315,16 @@ class ModuleImage {
ABSL_GUARDED_BY(mutex_);
};
-Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
+Kernel::Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas,
uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
- int compute_capability, uint32_t cluster_dim_0,
- uint32_t cluster_dim_1, uint32_t cluster_dim_2)
+ int compute_capability)
: kernel_name_(std::move(kernel_name)),
block_dim_x_(num_warps * kNumThreadsPerWarp),
+ num_ctas_(num_ctas),
shared_mem_bytes_(shared_mem_bytes),
ptx_(std::move(ptx)),
ttir_(std::move(ttir)),
- compute_capability_(compute_capability),
- cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {}
+ compute_capability_(compute_capability) {}
absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
void** params) {
@@ -362,9 +361,7 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
module_image_->GetFunctionForContext(context));
- const uint32_t cluster_size =
- cluster_dims_[0] * cluster_dims_[1] * cluster_dims_[2];
- if (cluster_size <= 1) {
+ if (num_ctas_ == 1) {
return JAX_AS_STATUS(gpuLaunchKernel(
kernel, grid[0], grid[1], grid[2], block_dim_x_,
/*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
@@ -372,16 +369,16 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
}
CUlaunchAttribute launch_attrs[2];
launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
- launch_attrs[0].value.clusterDim.x = cluster_dims_[0];
- launch_attrs[0].value.clusterDim.y = cluster_dims_[1];
- launch_attrs[0].value.clusterDim.z = cluster_dims_[2];
+ launch_attrs[0].value.clusterDim.x = num_ctas_;
+ launch_attrs[0].value.clusterDim.y = 1;
+ launch_attrs[0].value.clusterDim.z = 1;
launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
launch_attrs[1].value.clusterSchedulingPolicyPreference =
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
CUlaunchConfig launch_config = {
- /*gridDimX=*/grid[0] * cluster_dims_[0],
- /*gridDimY=*/grid[1] * cluster_dims_[1],
- /*gridDimZ=*/grid[2] * cluster_dims_[2],
+ /*gridDimX=*/grid[0] * num_ctas_,
+ /*gridDimY=*/grid[1],
+ /*gridDimZ=*/grid[2],
/*blockDimX=*/block_dim_x_,
/*blockDimY=*/1,
/*blockDimZ=*/1,
@@ -396,23 +393,23 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
}
/*static*/ Kernel Kernel::FromProto(const jax_triton::TritonKernel& proto) {
- return Kernel(proto.kernel_name(), proto.num_warps(),
+ // Use 1 as the default when num_ctas is not set in previously serialized kernels.
+ int num_ctas = proto.has_num_ctas() ? proto.num_ctas() : 1;
+
+ return Kernel(proto.kernel_name(), proto.num_warps(), num_ctas,
proto.shared_mem_bytes(), proto.ptx(), proto.ttir(),
- proto.compute_capability(), proto.cluster_dim_0(),
- proto.cluster_dim_1(), proto.cluster_dim_2());
+ proto.compute_capability());
}
jax_triton::TritonKernel Kernel::ToProto() const {
jax_triton::TritonKernel proto;
proto.set_kernel_name(kernel_name_);
proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp);
+ proto.set_num_ctas(num_ctas_);
proto.set_shared_mem_bytes(shared_mem_bytes_);
proto.set_ptx(ptx_);
proto.set_ttir(ttir_);
proto.set_compute_capability(compute_capability_);
- proto.set_cluster_dim_0(cluster_dims_[0]);
- proto.set_cluster_dim_1(cluster_dims_[1]);
- proto.set_cluster_dim_2(cluster_dims_[2]);
return proto;
}
diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h
index 3ab3e9143fb8..08320a104183 100644
--- a/jaxlib/gpu/triton_kernels.h
+++ b/jaxlib/gpu/triton_kernels.h
@@ -38,10 +38,9 @@ class ModuleImage;
class Kernel {
public:
- Kernel(std::string kernel_name, uint32_t num_warps, uint32_t shared_mem_bytes,
- std::string ptx, std::string ttir, int compute_capability,
- uint32_t cluster_dim_0, uint32_t cluster_dim_1,
- uint32_t cluster_dim_2);
+ Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas,
+ uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
+ int compute_capability);
absl::Status Launch(gpuStream_t stream, uint32_t grid[3], void** params);
@@ -54,11 +53,11 @@ class Kernel {
private:
std::string kernel_name_;
uint32_t block_dim_x_;
+ uint32_t num_ctas_;
uint32_t shared_mem_bytes_;
std::string ptx_;
std::string ttir_;
int compute_capability_;
- uint32_t cluster_dims_[3];
ModuleImage* module_image_ = nullptr;
};
@@ -107,8 +106,7 @@ class AutotunedKernelCall {
AutotunedKernelCall(
std::string name, std::vector configs,
- std::vector> input_output_aliases);
+ std::vector> input_output_aliases);
static absl::StatusOr Autotune(AutotunedKernelCall kernel_call,
gpuStream_t stream,
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index 49e11679da88..abb6658cf640 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -45,6 +45,7 @@ tf_cuda_tests_tags = _tf_cuda_tests_tags
jax_internal_packages = []
jax_extend_internal_users = []
+experimental_transfer_users = []
mosaic_gpu_internal_users = []
mosaic_internal_users = []
pallas_gpu_internal_users = []
@@ -137,11 +138,13 @@ jax2tf_deps = []
def pytype_library(name, pytype_srcs = None, **kwargs):
_ = pytype_srcs # @unused
+ kwargs.pop("lazy_imports", None)
py_library(name = name, **kwargs)
def pytype_strict_library(name, pytype_srcs = [], **kwargs):
data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
+ new_kwargs.pop("lazy_imports", None)
py_library(name = name, data = data, **new_kwargs)
py_strict_library = py_library
@@ -150,6 +153,7 @@ py_strict_test = py_test
def py_library_providing_imports_info(*, name, lib_rule = py_library, pytype_srcs = [], **kwargs):
data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
+ new_kwargs.pop("lazy_imports", None)
lib_rule(name = name, data = data, **new_kwargs)
def py_extension(name, srcs, copts, deps, linkopts = []):
diff --git a/jaxlib/jax.cc b/jaxlib/jax.cc
index 1571ddb29bb9..30959db5f1de 100644
--- a/jaxlib/jax.cc
+++ b/jaxlib/jax.cc
@@ -65,6 +65,7 @@ limitations under the License.
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
#include "xla/pjrt/status_casters.h"
#include "xla/python/ifrt/array.h"
+#include "xla/python/ifrt/attribute_map.h"
#include "xla/python/ifrt/device.h"
#include "xla/python/ifrt/device_list.h"
#include "xla/python/ifrt/executable.h"
@@ -116,6 +117,7 @@ limitations under the License.
#include "xla/hlo/builder/lib/approx_topk_shape.h"
#include "xla/pjrt/c_api_client/pjrt_c_api_client.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
+#include "xla/pjrt/distributed/preemption/preemption_sync_manager.h"
#include "xla/pjrt/exceptions.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/pjrt/pjrt_client.h"
@@ -580,6 +582,30 @@ NB_MODULE(_jax, m) {
aux::RegisterTransferServerTypes(m);
#endif // defined(__linux__)
+#if JAX_IFRT_VERSION_NUMBER >= 39
+ nb::class_<xla::PreemptionSyncManager> preemption_sync_manager(
+ m, "PreemptionSyncManager");
+ preemption_sync_manager
+ .def(
+ "initialize",
+ [](xla::PreemptionSyncManager& manager,
+ xla::DistributedRuntimeClient* client) {
+ xla::CoordinationServiceAgent* agent =
+ xla::ValueOrThrow(client->GetCoordinationServiceAgent());
+ xla::ThrowIfError(manager.Initialize(agent));
+ },
+ nb::arg("distributed_client"))
+ .def("reached_sync_point",
+ [](xla::PreemptionSyncManager& manager, int step_counter) {
+ return manager.ReachedSyncPoint(step_counter);
+ })
+ .def("shutdown", [](xla::PreemptionSyncManager& manager) {
+ nb::gil_scoped_release gil_release;
+ manager.Shutdown();
+ });
+ m.def("create_preemption_sync_manager",
+ []() { return xla::CreatePreemptionSyncManager(); });
+#else
nb::class_<tsl::PreemptionSyncManager> preemption_sync_manager(
m, "PreemptionSyncManager");
preemption_sync_manager
@@ -602,6 +628,7 @@ NB_MODULE(_jax, m) {
});
m.def("create_preemption_sync_manager",
[]() { return tsl::CreatePreemptionSyncManager(); });
+#endif
nb::class_ distributed_runtime_service(
m, "DistributedRuntimeService");
@@ -885,11 +912,12 @@ NB_MODULE(_jax, m) {
.def("__getattr__",
[](xla::ifrt::Topology& topology,
std::string_view name) -> nb::object {
- const auto& attrs = topology.Attributes().map();
- auto it = attrs.find(name);
- if (it != attrs.end()) {
+ auto value =
+ topology.Attributes().Get(
+ std::string(name));
+ if (value.ok()) {
return std::visit([](auto&& v) { return nb::cast(v.value); },
- it->second);
+ *value);
}
throw nb::attribute_error(
absl::StrCat("Unknown attribute ", name).c_str());
@@ -898,7 +926,6 @@ NB_MODULE(_jax, m) {
nb::class_(
m, "TransferServerInterfaceFactory");
-
m.def("is_asan", IsAsan);
m.def("is_msan", IsMsan);
m.def("is_tsan", IsTsan);
diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD
index 5c1c8b58f1da..cced087e215b 100644
--- a/jaxlib/mosaic/BUILD
+++ b/jaxlib/mosaic/BUILD
@@ -148,16 +148,50 @@ gentbl_cc_library(
name = "tpu_inc_gen",
# compatible with libtpu
tbl_outs = {
- "dialect/tpu/tpu_ops.h.inc": ["-gen-op-decls"],
- "dialect/tpu/tpu_ops.cc.inc": ["-gen-op-defs"],
- "dialect/tpu/tpu_dialect.h.inc": ["-gen-dialect-decls"],
- "dialect/tpu/tpu_dialect.cc.inc": ["-gen-dialect-defs"],
- "dialect/tpu/tpu_enums.h.inc": ["-gen-enum-decls"],
- "dialect/tpu/tpu_enums.cc.inc": ["-gen-enum-defs"],
- "dialect/tpu/tpu_attr_defs.h.inc": ["-gen-attrdef-decls"],
- "dialect/tpu/tpu_attr_defs.cc.inc": ["-gen-attrdef-defs"],
- "dialect/tpu/tpu_type_defs.h.inc": ["-gen-typedef-decls"],
- "dialect/tpu/tpu_type_defs.cc.inc": ["-gen-typedef-defs"],
+ "dialect/tpu/tpu_ops.h.inc": [
+ "-gen-op-decls",
+ "-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_ops.cc.inc": [
+ "-gen-op-defs",
+ "-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_dialect.h.inc": [
+ "-gen-dialect-decls",
+ "-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_dialect.cc.inc": [
+ "-gen-dialect-defs",
+ "-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_enums.h.inc": [
+ "-gen-enum-decls",
+ "-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_enums.cc.inc": [
+ "-gen-enum-defs",
+ "-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_attr_defs.h.inc": [
+ "-gen-attrdef-decls",
+ "-dialect=tpu",
+ "--attrdefs-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_attr_defs.cc.inc": [
+ "-gen-attrdef-defs",
+ "-dialect=tpu",
+ "--attrdefs-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_type_defs.h.inc": [
+ "-gen-typedef-decls",
+ "-dialect=tpu",
+ "--typedefs-dialect=tpu",
+ ],
+ "dialect/tpu/tpu_type_defs.cc.inc": [
+ "-gen-typedef-defs",
+ "-dialect=tpu",
+ "--typedefs-dialect=tpu",
+ ],
"dialect/tpu/tpu_passes.h.inc": [
"-gen-pass-decls",
"-name=TPU",
@@ -172,8 +206,8 @@ gentbl_cc_library(
],
},
tblgen = "@llvm-project//mlir:mlir-tblgen",
- td_file = "dialect/tpu/tpu.td",
- deps = [":tpu_td_files"],
+ td_file = "dialect/tpu/tpu_ops.td",
+ deps = [":tpu_ops_td_files"],
)
td_library(
@@ -184,6 +218,18 @@ td_library(
# compatible with libtpu
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
+ ],
+)
+
+td_library(
+ name = "tpu_ops_td_files",
+ srcs = [
+ "dialect/tpu/tpu_ops.td",
+ ],
+ # compatible with libtpu
+ deps = [
+ ":tpu_td_files",
+ "@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
index 9cc2968ca8b0..cc733d28fbc8 100644
--- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
+++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
@@ -195,6 +195,25 @@ def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode",
let cppNamespace = "::mosaic_gpu";
}
+def MosaicGPU_TMAReduction : I32EnumAttr<"TMAReduction",
+ "Reduction operation for TMA.",
+ [
+ I32EnumAttrCase<"Add", 0, "add">,
+ I32EnumAttrCase<"Min", 1, "min">,
+ I32EnumAttrCase<"Max", 2, "max">,
+ I32EnumAttrCase<"Inc", 3, "inc">,
+ I32EnumAttrCase<"Dec", 4, "dec">,
+ I32EnumAttrCase<"And", 5, "and">,
+ I32EnumAttrCase<"Or", 6, "or">,
+ I32EnumAttrCase<"Xor", 7, "xor">,
+ I32EnumAttrCase<"Umin", 8, "umin">,
+ I32EnumAttrCase<"Umax", 9, "umax">,
+ I32EnumAttrCase<"Smin", 10, "smin">,
+ I32EnumAttrCase<"Smax", 11, "smax">
+ ]>{
+ let cppNamespace = "::mosaic_gpu";
+}
+
def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
let parameters = (ins ArrayRefParameter<"int32_t", "tiling">:$tiling);
let summary = "Specifies a transform that tiles suffix dimensions of a memref in SMEM.";
@@ -345,6 +364,10 @@ def MosaicGPU_AsyncStoreOp : Op:$commit_group
+ DefaultValuedOptionalAttr:$commit_group,
+ OptionalAttr:$reduction_op
);
let assemblyFormat = [{
@@ -470,7 +494,7 @@ def MosaicGPU_BroadcastInDimOp : Op {
}
-def MosaicGPU_SliceSMEMOp : Op {
+def MosaicGPU_SliceSMEMOp : Op {
let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address.";
let arguments = (ins I32:$offset);
diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 5f66e7caf9ff..b19614bdb3bd 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -1,4 +1,4 @@
-/* Copyright 2023 The JAX Authors.
+/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,18 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TPU_ATTRS
-#define TPU_ATTRS
+#ifndef TPU_BASE
+#define TPU_BASE
-include "mlir/IR/OpBase.td"
-include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
-include "mlir/IR/EnumAttr.td"
-include "mlir/Pass/PassBase.td"
-include "mlir/Interfaces/ControlFlowInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/Interfaces/InferTypeOpInterface.td"
def TPU_Dialect : Dialect {
let name = "tpu";
@@ -45,840 +38,6 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
let mnemonic = mnemonic_;
}
-// TODO(b/369418606): Find out the way to verify vreg size.
-def TPU_Vreg : Type;
-
-class TPU_Type<string name, string mnemonic_, list<Trait> traits = [],
- string baseCppType = "::mlir::Type">
- : TypeDef<TPU_Dialect, name, traits, baseCppType> {
- let mnemonic = mnemonic_;
-}
-
-def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [
- I32EnumAttrCase<"kTc", 0, "tc">,
- I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">,
- I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore">
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::tpu";
-}
-
-def TPU_CoreTypeEnum : EnumAttr {
- let assemblyFormat = "`<` $value `>`";
-}
-
-def TPU_PipelineMode : I32EnumAttr<"PipelineMode", "Pipeline mode", [
- I32EnumAttrCase<"kSynchronous", 1, "synchronous">,
- I32EnumAttrCase<"kDoubleBuffered", 2, "double_buffered">
- ]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::tpu";
-}
-
-def TPU_PipelineModeEnum : EnumAttr {
- let assemblyFormat = "`<` $value `>`";
-}
-
-def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>;
-def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>;
-def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>;
-
-def TPU_Float8EXMYType : TPU_Type<"Float8EXMY", "float8_exmy",
- [DeclareTypeInterfaceMethods]> {
- let summary = "EXMY type in a 8 bit container";
- let description = [{
- EXMY type in a 8 bit container. Meaningful bits are aligned to LSB, and
- bits higher than the underlying exmy type in the container are considered
- as ignored. See https://arxiv.org/abs/2405.13938 for more details.
- }];
-
- let parameters = (ins
- TypeParameter<"::mlir::FloatType", "Underlying EXMY type">:$underlying_type
- );
-
- let assemblyFormat = [{
- `<` $underlying_type `>`
- }];
-}
-
-def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [
- I32EnumAttrCase<"parallel", 0>,
- I32EnumAttrCase<"arbitrary", 1>,
- I32EnumAttrCase<"core_parallel", 2>,
- I32EnumAttrCase<"subcore_parallel", 3>
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::tpu";
-}
-
-def TPU_DimensionSemanticsEnum
- : EnumAttr {
- let assemblyFormat = "`<` $value `>`";
-}
-
-// All indices/sizes are in element-space.
-// Note that the implementation will require statically provable tile alignment.
-def TPU_ElementWindowAttr : TPU_Attr<"ElementWindow", "element_window"> {
- // Including low padding, to avoid backwards-incompatible changes once we add it.
- let parameters = (ins
- ArrayRefParameter<"int64_t", "">:$pad_low,
- ArrayRefParameter<"int64_t", "">:$pad_high
- );
- let assemblyFormat = "`<` `[` $pad_low `]` `,` `[` $pad_high `]` `>`";
-}
-
-def TPU_ContractPrecision : I32EnumAttr<"ContractPrecision", "Contraction precision", [
- I32EnumAttrCase<"kBF16", 0, "bf16">,
- I32EnumAttrCase<"kFP32", 1, "fp32">
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::tpu";
-}
-
-def TPU_ContractPrecisionEnum
- : EnumAttr {
- let assemblyFormat = "`<` $value `>`";
-}
-
-def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [
- I32EnumAttrCase<"kCompressed", 0, "compressed">,
- I32EnumAttrCase<"kInterleaved", 1, "interleaved">
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::tpu";
-}
-
-def TPU_PackFormatEnum : EnumAttr {
- let assemblyFormat = "`<` $value `>`";
-}
-
-def TPU_TiledCase : I32EnumAttrCase<"tiled", 0>;
-def TPU_LaneCase : I32EnumAttrCase<"lanes", 1>;
-def TPU_SublaneCase : I32EnumAttrCase<"sublanes", 2>;
-def TPU_VectorLayoutDim : I32EnumAttr<
- "VectorLayoutDim", "", [TPU_TiledCase, TPU_LaneCase, TPU_SublaneCase]>;
-
-def TPU_VectorLayoutAttr : TPU_Attr<"VectorLayout", "vpad"> {
- let description = [{TODO}];
-
- let parameters = (ins "Layout":$layout);
- let hasCustomAssemblyFormat = 1;
-}
-
-def TPU_TiledLayoutAttr
- : TPU_Attr<"TiledLayout", "tiled",
- [DeclareAttrInterfaceMethods]> {
- let description = [{TODO}];
- let parameters = (ins
- ArrayRefParameter<"::xla::Tile", "">:$tiles,
- ArrayRefParameter<"int64_t", "">:$tile_strides
- );
-
- let hasCustomAssemblyFormat = 1;
-}
-
-def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [
- I32EnumAttrCase<"kAny", 4294967295, "any">,
- I32EnumAttrCase<"kVmem", 0, "vmem">,
- I32EnumAttrCase<"kSmem", 1, "smem">,
- I32EnumAttrCase<"kHbm", 2, "hbm">,
- I32EnumAttrCase<"kCmem", 3, "cmem">,
- I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">,
- I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">,
- I32EnumAttrCase<"kHost", 6, "host">
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::tpu";
-}
-
-def TPU_MemorySpaceEnum
- : EnumAttr<TPU_Dialect, TPU_MemorySpace, "memory_space"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
-class TPU_Op<string mnemonic, list<Trait> traits = []> :
- Op<TPU_Dialect, mnemonic, traits> {
-}
-
-def DefaultMemWrite : MemoryEffects<[MemWrite]>;
-def DefaultMemRead : MemoryEffects<[MemRead]>;
-
-def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [
- I32EnumAttrCase<"kSum", 0, "sum">,
- I32EnumAttrCase<"kMax", 1, "max">,
- I32EnumAttrCase<"kMin", 2, "min">,
- I32EnumAttrCase<"kArgMax", 3, "arg_max">,
- I32EnumAttrCase<"kArgMin", 4, "arg_min">,
- I32EnumAttrCase<"kFindFirstSet", 5, "find_first_set">
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::tpu";
-}
-
-def TPU_ReductionKindAttr
- : EnumAttr<TPU_Dialect, TPU_ReductionKind, "reduction_kind"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
-def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure]> {
- let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind);
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($output)
- }];
- let hasVerifier = 1;
-}
-
-def TPU_ReduceIndexOp : TPU_Op<"reduce_index", [Pure]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$input,
- I32Attr:$axis,
- TPU_ReductionKindAttr:$kind
- );
- let results = (outs VectorOfNonZeroRankOf<[I32]>:$output);
- let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
- let hasVerifier = 1;
-}
-
-// tpu.scan performs a scan across a vector.
-//
-// If a mask is provided, all output elements before the first unmasked input
-// element are undefined. Subsequent masked elements hold the result of the
-// last unmasked element.
-//
-// For example, a "kSum" scan over an input vector [1, 2, 3, 4] with mask
-// [0, 1, 0, 1] produces the output vector [X, 2, 2, 6], where X is some
-// undefined value.
-//
-// output : Result vector. Must have the same shape as source.
-// input : Vector to scan.
-// kind : Reduction operator. Must be one of "kSum", "kMax", or "kMin".
-// Must be "kSum" if input is an I1 vector.
-// mask : Elementwise vector mask. The scan operation starts from the
-// lowest-indexed non-masked vector element (all previous elements
-// have undefined values). Not taken for I1 input vectors.
-def TPU_ScanOp : TPU_Op<"scan"> {
- let arguments = (ins
- VectorOfNonZeroRankOf<[I1, I16, I32, BF16, F32]>:$input,
- TPU_ReductionKindAttr:$kind,
- Optional<VectorOfNonZeroRankOf<[I1]>>:$mask
- );
- let results = (outs VectorOfNonZeroRankOf<[I16, I32, BF16, F32]>:$output);
- let assemblyFormat = [{
- $kind `,` $input (`masked` $mask^)? attr-dict `:` type($input) `,` type($mask) `->` type($output)
- }];
- let hasVerifier = 1;
-}
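The masked-scan rule above is subtle, so here is a small, self-contained C++ sketch (illustrative only, not dialect code; names are made up) that reproduces the `kSum` example from the comment:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Model of a masked kSum scan: outputs before the first unmasked element are
// undefined; masked elements repeat the running result of the last unmasked
// element.
std::vector<int32_t> masked_scan_sum(const std::vector<int32_t>& in,
                                     const std::vector<bool>& mask) {
  std::vector<int32_t> out(in.size(), -1);  // -1 stands in for "undefined".
  int32_t running = 0;
  bool seen_unmasked = false;
  for (size_t i = 0; i < in.size(); ++i) {
    if (mask[i]) {
      running += in[i];
      seen_unmasked = true;
    }
    if (seen_unmasked) out[i] = running;
  }
  return out;
}

int main() {
  // Input [1, 2, 3, 4] with mask [0, 1, 0, 1], as in the comment above.
  for (int32_t v : masked_scan_sum({1, 2, 3, 4}, {false, true, false, true})) {
    std::cout << v << ' ';  // Prints: -1 2 2 6
  }
  std::cout << '\n';
}
```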
-
-def TPU_SortOp : TPU_Op<"sort", [Pure]> {
- let summary = "Sorts key/value pairs based on keys.";
- let description = [{
- tpu.sort performs a stable sort of key/value pairs in ascending or
- descending order based on keys. Masked-out keys and values are placed at the
- end of the output vectors. An output mask indicates which outputs
- correspond to the valid inputs.
- }];
- let arguments = (ins
- VectorOfNonZeroRankOf<[I32,F32]>:$keys,
- VectorOfNonZeroRankOf<[I32,F32]>:$values,
- Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
- DefaultValuedAttr<BoolAttr, "false">:$descending
- );
- let results = (outs
- VectorOfNonZeroRankOf<[I1]>:$output_mask,
- VectorOfNonZeroRankOf<[I32,F32]>:$sorted_keys,
- VectorOfNonZeroRankOf<[I32,F32]>:$sorted_values
- );
- let assemblyFormat = [{
- $keys `,` $values (`masked` $mask^)? attr-dict `:` functional-type(operands, results)
- }];
- let hasVerifier = 1;
-}
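As a reference for the behaviour described in the summary, the following hedged C++ sketch (not the actual lowering) models tpu.sort: a stable key sort with masked-out pairs pushed to the end and an output mask marking the valid slots:

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

struct SortResult {
  std::vector<bool> out_mask;
  std::vector<int32_t> keys;
  std::vector<int32_t> values;
};

// Stable sort of key/value pairs; invalid (masked-out) pairs go to the end.
SortResult sort_kv(const std::vector<int32_t>& keys,
                   const std::vector<int32_t>& values,
                   const std::vector<bool>& mask, bool descending) {
  std::vector<size_t> perm(keys.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::stable_sort(perm.begin(), perm.end(), [&](size_t a, size_t b) {
    if (mask[a] != mask[b]) return bool(mask[a]);  // Valid entries first.
    if (!mask[a]) return false;                    // Keep masked-out order stable.
    return descending ? keys[a] > keys[b] : keys[a] < keys[b];
  });
  SortResult r;
  for (size_t i : perm) {
    r.out_mask.push_back(mask[i]);
    r.keys.push_back(keys[i]);
    r.values.push_back(values[i]);
  }
  return r;
}
```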
-
-def TPU_StoreOp : TPU_Op<"store", [DefaultMemWrite, AttrSizedOperandSegments]> {
- let arguments = (ins
- TPU_Vreg:$valueToStore,
- AnyType:$base,
- Variadic:$indices,
- DenseBoolArrayAttr:$sublane_mask,
- Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
- OptionalAttr<I32Attr>:$sublane_stride // In sublane-sized units
- );
- let results = (outs);
- let assemblyFormat = [{
- $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
- }];
-}
-
-def TPU_LoadOp : TPU_Op<"load", [DefaultMemRead]> {
- let arguments = (ins
- AnyType:$base,
- Variadic:$indices,
- DenseBoolArrayAttr:$sublane_mask,
- OptionalAttr<I32Attr>:$sublane_stride // In sublane-sized units
- );
- let results = (outs TPU_Vreg:$result);
- let assemblyFormat = [{
- $base `[` $indices `]` `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($result)
- }];
- let description = [{
- Similar to `vector::LoadOp` but with `sublane_mask` and `sublane_stride`.
- When `indices` are negative, the load reads from a negative offset
- relative to the `base` address.
- }];
-}
-
-// TODO(jevinjiang): migrate tpu.strided_store to general vector store op.
-def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperandSegments]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$valueToStore,
- AnyMemRef:$base,
- Variadic:$indices,
- DenseI32ArrayAttr:$strides,
- Optional<VectorOfNonZeroRankOf<[I1]>>:$mask, // Elementwise mask.
- DefaultValuedAttr<BoolAttr, "false">:$add
- );
- let results = (outs);
- let assemblyFormat = [{
- $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-// tpu.vector_load loads a vector from memory into a register.
-//
-// base : Memref to load from.
-// indices: Scalar indices into base. indices must be of the same rank as the
-// base memref shape.
-// strides: The stride to use for calculating the address of subsequent
-// elements. If left unspecified, the stride is implicitly 1 along
-// each dimension. Otherwise the stride must match the rank of the
-// memref shape.
-// mask : Elementwise vector mask. Must be broadcastable to the shape of the
-// result vector. Depending on the core type, this may be a dynamic
-// (lane) mask consumed from a register or a static (sublane) mask
-// that must be the result of arith.constant.
-def TPU_VectorLoadOp :TPU_Op<"vector_load", [DefaultMemRead, AttrSizedOperandSegments]> {
- let arguments = (ins
- AnyMemRef:$base,
- Variadic:$indices,
- DenseI32ArrayAttr:$strides,
- Optional<VectorOfNonZeroRankOf<[I1]>>:$mask // Elementwise mask.
- );
- let results = (outs AnyVectorOfNonZeroRank:$result);
- let assemblyFormat = [{
- $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `,` type($result) `,` type($mask)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-def TPU_StridedLoadOp : TPU_Op<"strided_load", [DefaultMemRead]> {
- let arguments = (ins
- AnyMemRef:$base,
- Variadic:$indices,
- DenseI32ArrayAttr:$strides
- );
- let results = (outs AnyVectorOfNonZeroRank:$result);
- let assemblyFormat = [{
- $base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
- }];
- let hasVerifier = 1;
-}
-
-def TPU_StridedStoreOp : TPU_Op<"strided_store", [DefaultMemWrite]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$valueToStore,
- AnyMemRef:$base,
- Variadic:$indices,
- DenseI32ArrayAttr:$strides
- );
- let results = (outs);
- let assemblyFormat = [{
- $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
- }];
- let hasVerifier = 1;
-}
-
-// TODO: b/435258666 - Merge with tpu.vector_load_idx.
-def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load", [DefaultMemRead]> {
- let arguments = (ins
- AnyMemRef:$base,
- Variadic:$indices,
- DenseBoolArrayAttr:$sublane_mask,
- DenseI32ArrayAttr:$sublane_offsets
- );
- let results = (outs TPU_Vreg:$result);
- let assemblyFormat = [{
- $base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-// TODO: b/435258666 - Merge with tpu.vector_store_idx.
-def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store", [DefaultMemWrite]> {
- let arguments = (ins
- TPU_Vreg:$valueToStore,
- AnyMemRef:$base,
- Variadic:$indices,
- DenseBoolArrayAttr:$sublane_mask,
- DenseI32ArrayAttr:$sublane_offsets
- );
- let results = (outs);
- let assemblyFormat = [{
- $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-// tpu.vector_load_idx loads values from arbitrary locations in memory.
-//
-// Each element in the output vector is loaded from an index in the base memref
-// specified by the corresponding elements in the 'indices' vectors. The shape
-// of each index vector must match the shape of the output vector. The number
-// of index vectors must equal the rank of the base memref.
-//
-// For example, for an output vector of length n and a rank-2 base memref, the
-// indices look like:
-// indices = [[idx0, idx1, ...], [idxn, idxn+1, ...]]
-// where (idx0, idxn) are the coordinates of the first vector element.
-//
-// base : Memref specifying the base address.
-// indices : Vectors of indices for each dimension of the base memref.
-// mask : Optional elementwise vector mask.
-def TPU_VectorLoadIdxOp :TPU_Op<"vector_load_idx", [DefaultMemRead, AttrSizedOperandSegments]> {
- let arguments = (ins
- MemRefOf<[I32, F32]>:$base,
- Variadic<VectorOfNonZeroRankOf<[I32]>>:$indices,
- Optional<VectorOfNonZeroRankOf<[I1]>>:$mask
- );
- let results = (outs VectorOfNonZeroRankOf<[I32, F32]>:$value);
- let assemblyFormat = [{
- $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($value) `,` type($mask)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-// tpu.vector_store_idx stores values to arbitrary locations in memory.
-//
-// Each element in the input vector is stored to an index in the base memref
-// specified by the corresponding elements in the 'indices' vectors. The shape
-// of each index vector must match the shape of the input vector. The number
-// of index vectors must equal the rank of the base memref.
-//
-// For example, for an input vector of length n and a rank-2 base memref, the
-// indices look like:
-// indices = [[idx0, idx1, ...], [idxn, idxn+1, ...]]
-// where (idx0, idxn) are the coordinates of the first vector element.
-//
-// When multiple vector elements have the same index to store to, the data from
-// the highest lane will be the one stored. If add is true, then the data will
-// be added from the lowest lane to the highest lane.
-//
-// valueToStore: Vector to be stored.
-// base : Memref specifying the base address.
-// indices : Vectors of indices for each dimension of the base memref.
-// mask : Optional elementwise vector mask.
-// add : If true, add source values to target values. Otherwise, overwrite.
-def TPU_VectorStoreIdxOp :TPU_Op<"vector_store_idx", [DefaultMemWrite, AttrSizedOperandSegments]> {
- let arguments = (ins
- VectorOfNonZeroRankOf<[I32, F32]>:$valueToStore,
- MemRefOf<[I32, F32]>:$base,
- Variadic<VectorOfNonZeroRankOf<[I32]>>:$indices,
- Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
- DefaultValuedAttr<BoolAttr, "false">:$add
- );
- let results = (outs);
- let assemblyFormat = [{
- $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($valueToStore) `,` type($mask)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-// TODO(jevinjiang): deprecate to use dynamic_rotate.
-def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
- let description = [{
- Rotates the given vector by the given amount in the given dimension, i.e.,
- for a 2D vector of shape (m, n), rotating dim 0 by `amount` will shift a row
- at index `i` to index `(i + amount) % m`
- }];
- let arguments = (ins
- AnyVectorOfNonZeroRank:$value,
- SI32Attr:$amount,
- SI32Attr:$dimension,
- // When the stride is specified, the rotation amount for each index on the
- // stride dimension will be (amount + stride * index).
- OptionalAttr<SI32Attr>:$stride,
- OptionalAttr<SI32Attr>:$stride_dimension
- );
- let results = (outs AnyVectorOfNonZeroRank:$result);
- let assemblyFormat = [{
- $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value)
- }];
- let hasVerifier = 1;
-}
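A minimal C++ sketch of the index arithmetic described above (illustrative; it assumes a rank-2 value rotated along dim 0, a non-negative amount and stride, and, when a stride is given, stride_dimension = 1):

```cpp
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<int32_t>>;

// Row i, column j of the input lands at row (i + amount + stride * j) % m,
// i.e. the plain rotate when stride == 0.
Matrix rotate_dim0(const Matrix& in, int64_t amount, int64_t stride = 0) {
  const int64_t m = in.size(), n = in[0].size();
  Matrix out(m, std::vector<int32_t>(n));
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      out[(i + amount + stride * j) % m][j] = in[i][j];
    }
  }
  return out;
}
```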
-
-def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$value,
- I32:$amount,
- SI32Attr:$dimension,
- // When the stride is specified, the rotation amount for each index on the
- // stride dimension will be (amount + stride * index).
- OptionalAttr<SI32Attr>:$stride,
- OptionalAttr<SI32Attr>:$stride_dimension
- );
- let results = (outs AnyVectorOfNonZeroRank:$result);
- let assemblyFormat = [{
- $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result)
- }];
- let hasVerifier = 1;
-}
-
-def TPU_ScanCountOp : TPU_Op<"scan_count", [Pure, InferTypeOpAdaptor, SameOperandsAndResultShape]> {
-let summary = [{
- ScanCountOp calculates the running duplicate occurrence count of the elements
- in the input vector. Elements eligible for counting are specified by the
- input mask vector. The output mask vector indicates one unique occurrence
- per duplicate that was counted.
- }];
-
- let description = [{
- ScanCountOp calculates the running duplicate occurrence count of the elements
- in the input vector, %values. The output vector, %counts, contains the running
- duplicate occurrence count for the corresponding element in
- the input vector, where the count is performed in ascending order of element
- indices. For example, if the elements of %values at indices 0, 5, and 7 had
- duplicate values, then the elements of %counts at indices 0, 5, and 7 would
- be 1, 2, and 3, respectively.
-
- A mask vector, %in_mask, specifies which of the elements in the input vector
- are eligible for counting. An element in %values that has its mask set to 0
- will always have a count of 1 in %counts, regardless of the position in the
- vector, or whether there were duplicates or not.
- }];
-
- let arguments = (ins
- VectorOfNonZeroRankOf<[I1]>:$in_mask,
- AnyVectorOfNonZeroRank:$values
- );
- let results = (outs
- VectorOfNonZeroRankOf<[I1]>:$out_mask,
- VectorOfNonZeroRankOf<[I32]>:$counts
- );
-
- let assemblyFormat = [{
- `mask` `(` $in_mask `:` type($in_mask) `)`
- `value` `(` $values `:` type($values) `)`
- attr-dict `:` type(results)
- }];
-
-}
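To make the counting rule concrete, here is a hedged C++ sketch of one reading of the semantics above (the output-mask convention, marking the first counted occurrence of each value, is an interpretation of the summary rather than something the text pins down):

```cpp
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

// Running duplicate-occurrence count: the k-th counted occurrence of a value
// gets count k. Masked-out elements always report a count of 1.
std::pair<std::vector<bool>, std::vector<int32_t>> scan_count(
    const std::vector<bool>& in_mask, const std::vector<int32_t>& values) {
  std::unordered_map<int32_t, int32_t> occurrences;
  std::vector<bool> out_mask(values.size(), false);
  std::vector<int32_t> counts(values.size(), 1);
  for (size_t i = 0; i < values.size(); ++i) {
    if (!in_mask[i]) continue;
    counts[i] = ++occurrences[values[i]];
    out_mask[i] = (counts[i] == 1);  // One "unique occurrence" per value.
  }
  return {out_mask, counts};
}
```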
-
-def TPU_IotaOp : TPU_Op<"iota", [Pure]> {
- let description = [{
- Creates a vector with values that start at 0 and increase along the
- dimension that results from collapsing the given `dimensions` together in
- row-major order.
-
- Example:
- ```
- tpu.iota {dimensions = array<i32: 2, 0>} : vector<4x3x2xi16>
- ```
- This produces a vector with the following values:
- ```
- [[[0, 4], [0, 4], [0, 4]]
- [[1, 5], [1, 5], [1, 5]]
- [[2, 6], [2, 6], [2, 6]]
- [[3, 7], [3, 7], [3, 7]]]
- ```
- }];
- let arguments = (ins DenseI32ArrayAttr:$dimensions);
- let results = (outs VectorOfNonZeroRankOf<[AnyInteger, Index]>:$output);
- let assemblyFormat = [{ attr-dict `:` type($output) }];
- let hasVerifier = 1;
-}
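The collapsed-dimension rule is easiest to see as index arithmetic. This small C++ sketch (illustrative only) computes the value at a given index; with shape (4, 3, 2) and dimensions [2, 0] — which is what the listed output implies for the example above — it yields i2 * 4 + i0:

```cpp
#include <cstdint>
#include <vector>

// Value at `index` = row-major position of the sub-index formed by `dimensions`.
int64_t iota_value(const std::vector<int64_t>& index,
                   const std::vector<int64_t>& shape,
                   const std::vector<int64_t>& dimensions) {
  int64_t value = 0;
  for (int64_t d : dimensions) {
    value = value * shape[d] + index[d];
  }
  return value;
}
```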
-
-def TPU_ReshapeOp : TPU_Op<"reshape", [Pure]> {
- let arguments = (ins AnyVectorOfNonZeroRank:$source);
- let results = (outs AnyVectorOfNonZeroRank:$result);
- let assemblyFormat = [{ $source attr-dict `:` type($source) `->` type($result) }];
- let hasVerifier = 1;
- let hasFolder = 1;
-}
-
-// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically.
-// b/376295711
-def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$source,
- I32Attr:$dimension,
- I32Attr:$times
- );
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }];
-}
-
-def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
- let description = [{
- For each sublane `i`, broadcasts the value in lane `lane + i` along the entire
- sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i`
- is not defined (can be anything).
- }];
- let arguments = (ins
- TPU_Vreg:$source, // All sublanes should be equal.
- I32Attr:$lane // Coordinates of the first element to take.
- );
- // Output shape should be the same, except for position dim which contains
- // the newly inserted dimension.
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{
- $source `,` $lane attr-dict `:` type($source) `->` type($output)
- }];
-}
-
-// Integer unpacks are always signed at the moment.
-//
-// When unpacking integers to integers, setting `sign_extended` to false will
-// leave bits higher than source bitwidth as undefined.
-//
-// Take int4 to int16 interleaved unpacking and `index = 1` as an example:
-//
-// Source:
-//
-// Bits 28 24 20 16 12 8 4 0
-// --------abcd------------efgh----
-//
-// where "a" and "e" are the sign bits of the values to be unpacked, and "-" are
-// bits to be ignored.
-//
-// Unpacked, sign_extend = true:
-//
-// Bits 28 24 20 16 12 8 4 0
-// aaaaaaaaaaaaabcdeeeeeeeeeeeeefgh
-//
-// Unpacked, sign_extend = false:
-//
-// Bits 28 24 20 16 12 8 4 0
-// ------------abcd------------efgh
-def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$source,
- I32Attr:$index,
- TPU_PackFormatEnum:$pack_format,
- DefaultValuedAttr<BoolAttr, "true">:$sign_extended
- );
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
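The sign_extended flag is easiest to check on a single field. The sketch below (illustrative; it ignores the interleaved/compressed position selection and just extracts the 4-bit field chosen by `index` from one 32-bit word) shows the two behaviours drawn in the bit diagrams above:

```cpp
#include <cstdint>

// Extract the `index`-th 4-bit field of `word` as a signed value.
int32_t unpack_int4(uint32_t word, int index, bool sign_extended) {
  const uint32_t field = (word >> (4 * index)) & 0xF;
  if (!sign_extended) {
    // The op leaves the bits above the source bitwidth undefined; here we
    // simply return the raw field.
    return static_cast<int32_t>(field);
  }
  // Sign-extend from bit 3: maps 0x8..0xF to -8..-1.
  return static_cast<int32_t>(field ^ 0x8u) - 0x8;
}
```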
-
-// Integer packs are always signed at the moment.
-// Float to integer packing rounds to nearest even.
-// WARNING: pack(pack(a, b), pack(c, d)) == pack(a, b, c, d) only holds for
-// compressed packing!
-// Below, we use [ ... ] to denote the bounds of the vreg and use regular parens
-// ( ... ) to denote packing of multiple subelements into a single 32-bit word.
-//
-// Interleaved packing
-//
-// Interleaved packing downcasts to a narrower dtype, and packs multiple elements
-// into the same word coordinate from which they originated. If a and b are packed
-// values, then interleaved packing first iterates over the operand list and only
-// then over the subelements within each word.
-// Take 16-bit vregs A, B, C and D:
-///
-// [ (A000 A001) (A010 A011) ... ]
-// [ (A100 A101) (A110 A111) ... ]
-// ...
-//
-// An interleaved pack(a, b) from 16-bit values produces:
-//
-// [ (A000 B000 A001 B001) (A010 B010 A011 B011) ...]
-// ...
-//
-// While an interleaved pack(a, b, c, d) produces the following subelements in
-// each vreg word:
-//
-// [ (A000 B000 C000 D000 A001 B001 C001 D001) ... ]
-// ...
-//
-// Compressed packing
-//
-// Compressed packing downcasts each value and then packs multiple rows together.
-// A compressed pack(a, b) from 16-bit values produces:
-//
-// [ (A000 A001 A100 A101) (A010 A011 A110 A111) ... ]
-// [ (A200 A201 A300 A301) (A210 A211 A310 A311) ... ]
-// ... # 2 more sublanes
-// [ (B000 B001 B100 B101) (B010 B011 B110 B111) ... ]
-// [ (B200 B201 B300 B301) (B210 B211 B310 B311) ... ]
-// ...
-def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> {
- let arguments = (ins
- Variadic<TPU_Vreg>:$sources,
- DenseI32ArrayAttr:$positions,
- TPU_PackFormatEnum:$pack_format
- );
- let results = (outs TPU_Vreg:$output);
- let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
- let builders = [
- OpBuilder<(ins "::mlir::VectorType":$output_type, "::mlir::ArrayRef<::mlir::Value>":$padded_sources, "::mlir::tpu::PackFormat":$pack_format)>,
- ];
- let extraClassDeclaration = [{
- static ::mlir::SmallVector<::mlir::Value> getPaddedSources(::mlir::ValueRange sources, ::mlir::ArrayRef<int32_t> positions, int packing_factor);
- }];
- let hasVerifier = 1;
-}
-
-def TPU_PackElementwiseOp : TPU_Op<"pack_elementwise", [Pure, SameTypeOperands, ElementwiseMappable]> {
- let description = [{
- Packs multiple `sources` elementwise into a single vector of a narrower `target_type`.
-
- The number of `sources` must equal the packing factor, which is the ratio of
- the element bitwidth of the `sources` to the element bitwidth of the
- `target_type`. Elements from the `sources` are interleaved and packed into
- each word of the `output`, ordered from lowest to highest bits,
- corresponding to their order in the `sources`.
- }];
- let arguments = (ins
- Variadic<VectorOfNonZeroRankOf<[F32, I32]>>:$sources,
- TypeAttr:$target_type
- );
- let results = (outs VectorOfNonZeroRankOf<[I32]>:$output);
- let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
- let hasVerifier = 1;
-}
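A hedged bit-level model of the word layout described above (how each 32-bit source element is first narrowed to target_bitwidth bits is not modeled; `narrowed_sources` is assumed to already hold those narrowed values, one per source, with size equal to the packing factor):

```cpp
#include <cstdint>
#include <vector>

// Source k occupies bits [k * target_bitwidth, (k + 1) * target_bitwidth) of
// the output word, so source 0 lands in the lowest bits.
uint32_t pack_elementwise_word(const std::vector<uint32_t>& narrowed_sources,
                               int target_bitwidth) {
  const uint32_t field_mask =
      target_bitwidth == 32 ? 0xffffffffu : (1u << target_bitwidth) - 1;
  uint32_t word = 0;
  for (size_t k = 0; k < narrowed_sources.size(); ++k) {
    word |= (narrowed_sources[k] & field_mask) << (k * target_bitwidth);
  }
  return word;
}
```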
-
-def TPU_UnpackElementwiseOp : TPU_Op<"unpack_elementwise", [Pure, ElementwiseMappable]> {
- let description = [{
- Unpacks a single vector from `source`, which contains multiple `source_type`
- vectors packed elementwise.
-
- The `index` selects which packed value to extract from each word of `source`.
- An `index` of 0 corresponds to the lowest bits. The extracted values are
- cast to the output element type.
- }];
- let arguments = (ins
- VectorOfNonZeroRankOf<[I32]>:$source,
- TypeAttr:$source_type,
- I32Attr:$index
- );
- let results = (outs VectorOfNonZeroRankOf<[F32, I32]>:$output);
- let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }];
- let hasVerifier = 1;
-}
-
-def TPU_RelayoutOp : TPU_Op<"relayout", [Pure, SameOperandsAndResultType]> {
- let arguments = (ins AnyVectorOfAnyRank:$input);
- let results = (outs AnyVectorOfAnyRank:$output);
- let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
- let hasVerifier = 1;
-}
-
-def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> {
- let arguments = (ins
- VectorOfNonZeroRankOf<[I1]>: $low,
- VectorOfNonZeroRankOf<[I1]>: $high
- );
- let results = (outs VectorOfNonZeroRankOf<[I1]>:$output);
- let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }];
-}
-
-def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$source,
- DenseI32ArrayAttr:$indices,
- I32Attr:$dimension
- );
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{
- $source `[` $indices `]` `in` $dimension attr-dict
- `:` type($source) `->` type($output)
- }];
-}
-
-def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, DeclareOpInterfaceMethods, AllShapesMatch<["indices", "output"]>, AllElementTypesMatch<["source", "output"]>]> {
- let description = [{
- Gathers elements from `source` using `indices`.
-
- The specified `dimensions` of `source` are collapsed together and indexed by
- `indices`.
-
- Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by
- `collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where
- - `collapsed_source` is the result of collapsing `dimensions` of `source`
- into a new trailing dimension of size `M`.
- - `jk` is the subsequence of the indices `i_n` for those `n` not in
- `dimensions`.
-
- When a single dimension is specified, this is similar to
- `np.take_along_axis`.
- }];
- let arguments = (ins
- AnyVectorOfNonZeroRank:$source,
- VectorOfNonZeroRankOf<[AnyInteger]>:$indices,
- DenseI32ArrayAttr:$dimensions
- );
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{
- $source `[` $indices `]` `in` $dimensions attr-dict
- `:` type($source) `,` type($indices) `->` type($output)
- }];
- let hasVerifier = 1;
-}
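For the single-dimension case mentioned last (the np.take_along_axis-like form), the indexing rule reduces to the following C++ sketch (illustrative; the wraparound used for the `mod M` step on negative indices is an assumption):

```cpp
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<int32_t>>;

// output[i][j] = source[i][ indices[i][j] mod cols ] for a rank-2 source
// gathered along dimension 1.
Matrix dynamic_gather_dim1(const Matrix& source, const Matrix& indices) {
  const int64_t cols = source[0].size();
  Matrix out(indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    for (int32_t idx : indices[i]) {
      out[i].push_back(source[i][((idx % cols) + cols) % cols]);
    }
  }
  return out;
}
```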
-
-def TPU_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode", [
- I32EnumAttrCase<"kTowardsZero", 0, "towards_zero">,
- I32EnumAttrCase<"kToNearestEven", 1, "to_nearest_even">,
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::tpu";
-}
-
-def TPU_RoundingModeEnum : EnumAttr<TPU_Dialect, TPU_RoundingMode, "rounding_mode"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
-// Internal operation. All arith.fptosi operations that change the bitwidth
-// must be canonicalized to this operation.
-def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> {
- let arguments = (ins AnyVectorOfAnyRank:$input, TPU_RoundingModeEnum:$rounding_mode);
- let results = (outs AnyVectorOfAnyRank:$output);
- let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
- let hasCanonicalizeMethod = 1;
-}
-
-// Internal operation. All arith.sitofp operations that change the bitwidth
-// must be canonicalized to this operation.
-def TPU_SIToFPOp : TPU_Op<"sitofp", [Pure, ElementwiseMappable]> {
- let arguments = (ins AnyType:$in, TPU_RoundingModeEnum:$rounding_mode);
- let results = (outs AnyType:$output);
- let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($output) }];
-}
-
-// Internal operation.
-def TPU_ExtFOp : TPU_Op<"extf", [Pure, ElementwiseMappable]> {
- let arguments = (ins AnyType:$in);
- let results = (outs AnyType:$out);
- let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }];
- let hasFolder = 1;
-}
-
-// Internal operation.
-def TPU_TruncFOp : TPU_Op<"truncf", [Pure, ElementwiseMappable]> {
- let arguments = (
- ins AnyType:$in,
- TPU_RoundingModeEnum:$rounding_mode
- );
- let results = (outs AnyType:$out);
- let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }];
- let hasFolder = 1;
-}
-
def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> {
let parameters = (ins
ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims,
@@ -898,628 +57,7 @@ def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension
"`[` $output_dim_order `]` `,` "
"`[` (`]`):($lhs_batch_dims^ `]`)? `,` "
"`[` (`]`):($rhs_batch_dims^ `]`)? `>`";
+ let constBuilderCall = "::mlir::tpu::DotDimensionNumbersAttr::get($_builder.getContext(), $0)";
}
-// TODO(apaszke): Think hard about precision
-def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$lhs,
- AnyVectorOfNonZeroRank:$rhs,
- AnyVectorOfNonZeroRank:$acc,
- // These flags are deprecated - if dimension_numbers are defined,
- // these flags are ignored. They will always be false after canonicalize.
- DefaultValuedAttr<BoolAttr, "false">:$transpose_lhs,
- DefaultValuedAttr<BoolAttr, "false">:$transpose_rhs,
- OptionalAttr<TPU_ContractPrecisionEnum>:$precision,
- // NOTE: User-level optional, once canonicalized, always present.
- OptionalAttr<TPU_DotDimensionNumbersAttr>:$dimension_numbers
- );
- let results = (outs AnyVectorOfNonZeroRank:$result);
- let assemblyFormat = [{
- $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result)
- }];
- let hasCanonicalizer = 1;
- let hasVerifier = 1;
-}
-
-def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure, DeclareOpInterfaceMethods]> {
- let arguments = (ins
- Variadic<AnyVectorOfNonZeroRank>:$sources,
- I32Attr:$dimension
- );
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{
- $sources `in` $dimension attr-dict `:` type($sources) `->` type($output)
- }];
- let hasVerifier = 1;
-}
-
-def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> {
- let arguments = (ins AnyVectorOfNonZeroRank:$input);
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
- let hasVerifier = 1;
-}
-
-def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> {
- let arguments = (ins TPU_Vreg:$input);
- let results = (outs TPU_Vreg:$output);
- let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
- let hasFolder = 1;
-}
-
-def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> {
- let arguments = (ins AnyType:$input); // F32 vector or scalar
- let results = (outs AnyType:$output); // I1 vector or scalar
- let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
- let hasVerifier = 1;
-}
-
-def TPU_ReciprocalOp : TPU_Op<"reciprocal", [Pure, SameOperandsAndResultType, ElementwiseMappable]> {
- let arguments = (ins
- AnyVectorOfNonZeroRank:$input,
- DefaultValuedAttr<BoolAttr, "false">:$approx
- );
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
- let hasVerifier = 1;
-}
-
-def TPU_StochasticConvertOp : TPU_Op<"stochastic_convert", [Pure, SameOperandsAndResultShape]> {
- let arguments = (ins
- VectorOfNonZeroRankOf<[F32]>:$input,
- VectorOfNonZeroRankOf<[I32]>:$random
- );
- let results = (outs AnyVectorOfNonZeroRank:$output);
- let assemblyFormat = [{ $input `,` $random attr-dict `:` type($input) `,` type($random) `->` type($output) }];
-}
-
-def TPU_StochasticConvertElementwiseOp : TPU_Op<"stochastic_convert_elementwise", [Pure, ElementwiseMappable]> {
- // Stochastically converts the input to the target dtype based on the mode.
- // When the target dtype is less than 32 bits, the result occupies the lowest {bitwidth} bits in the I32 output.
- let arguments = (ins
- VectorOfNonZeroRankOf<[F32]>:$input,
- VectorOfNonZeroRankOf<[I32]>:$random,
- TypeAttr:$dst_type
- );
- let results = (outs VectorOfNonZeroRankOf<[I32]>:$output);
- let assemblyFormat = [{ $input `,` $random attr-dict `:` type($input) `,` type($random) `->` type($output) }];
- let hasVerifier = 1;
-}
-
-def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
- let arguments = (ins Variadic:$input);
- let results = (outs AnyVectorOfAnyRank:$output);
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($output)
- }];
-}
-
-def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> {
- let arguments = (ins AnyVectorOfAnyRank:$input);
- let results = (outs Variadic:$output);
- let hasCanonicalizeMethod = 1;
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($output)
- }];
-}
-
-def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> {
- // high is exclusive
- let arguments = (ins Variadic<Index>:$low, Variadic<Index>:$high);
- let results = (outs AnyType:$output);
- let assemblyFormat = [{
- `[` $low `]``[` $high `]` attr-dict `:` type($output)
- }];
-}
-
-def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> {
- let summary = "Create a mask masking contiguous rows of subelements.";
- let description = [{
- The "half-sublanes", "quarter-sublanes", etc. (unit is determined by
- the type of `output`) of the mask are masked in the range specified by
- `from` and `to`.
-
- - If `from <= to`, the range `[from, to)` is set and the rest is unset.
- - If `to <= from`, the range `[to, from)` is unset and the rest is set.
-
- All lanes are set identically.
-
- Example:
-
- ```mlir
- %msk = tpu.create_subelement_mask 3, 9 : vector<8x128x2xi1>
- ```
-
- This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is:
-
- ```
- [[0, 0], [0, 1], [1, 1], [1, 1], [1, 0], [0, 0], [0, 0], [0, 0]]
- ```
-
- It is currently only supported:
- - In TPU v4, for `num_subelems` of 1 and 2.
- - In TPU v5, for `num_subelems` of 1, 2, and 4.
- }];
- let arguments = (ins
- I32Attr:$from, // inclusive
- I32Attr:$to // exclusive
- );
- let results = (outs AnyType:$output); // Verify this is a vmsk with num_subelems
- let assemblyFormat = [{
- $from `,` $to attr-dict `:` type($output)
- }];
-}
-
-def TPU_AssumeMultipleOp : TPU_Op<"assume_multiple", [Pure, SameOperandsAndResultType]> {
- let summary = "Assumes that a value is a multiple of a given integer.";
- let description = [{
- This operation is a hint to the compiler that the input `value` is guaranteed
- to be a multiple of `multiple`. This can be used to satisfy divisibility checks
- in some compiler passes.
-
- The result is the same as the input `value`.
-
- Example:
-
- ```mlir
- %val = tpu.assume_multiple %arg0, 16 : index
- ```
- }];
- let arguments = (ins
- AnyTypeOf<[Index, AnyInteger]>:$value,
- I32Attr:$multiple
- );
- let results = (outs AnyTypeOf<[Index, AnyInteger]>:$result);
- let assemblyFormat = [{$value `,` $multiple attr-dict `:` type($result)}];
- let hasVerifier = 1;
-}
-
-def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure, AttrSizedOperandSegments]> {
- let arguments = (ins
- AnyMemRef:$mem_ref,
- Variadic:$base_idx,
- Variadic:$dynamic_sizes
- );
- let results = (outs AnyMemRef:$result);
- let assemblyFormat = [{
- $mem_ref `[` $base_idx `]` (`<` $dynamic_sizes^ `>`)?
- attr-dict `:` type($mem_ref) `->` type($result)
- }];
- let hasVerifier = 1;
- let hasCanonicalizer = 1;
-}
-
-def TPU_MemRefSqueezeOp : TPU_Op<"memref_squeeze", [Pure]> {
- let arguments = (ins AnyMemRef:$input);
- let results = (outs AnyMemRef:$result);
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($result)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> {
- let arguments = (ins AnyMemRef:$input);
- let results = (outs AnyMemRef:$result);
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($result)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-def TPU_MemRefBitcastOp : TPU_Op<"memref_bitcast", [Pure]> {
- let arguments = (ins AnyMemRef:$input);
- let results = (outs AnyMemRef:$result);
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($result)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> {
- let arguments = (ins AnyMemRef:$input);
- let results = (outs AnyMemRef:$result);
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($result)
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-def TPU_AssumeLayoutOp : TPU_Op<"assume_layout", [Pure]> {
- let arguments = (ins AnyType:$input);
- let results = (outs AnyType:$result);
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($result)
- }];
-}
-
-// Erases the layout attribute from the memref.
-//
-// The resulting memref is identical to the input, except that it has an
-// identity layout.
-def TPU_EraseLayoutOp : TPU_Op<"erase_memref_layout", [Pure, InferTypeOpAdaptor]> {
- let arguments = (ins AnyMemRef:$operand);
- let results = (outs AnyMemRef:$result);
- let assemblyFormat = [{
- $operand attr-dict `:` type($operand) `->` type($result)
- }];
- let hasFolder = 1;
-}
-
-// Returns the ID of the current device.
-//
-// On the input to the compiler the return value is a logical ID in the XLA
-// device assignment. It changes to a physical ID after the
-// logical-to-physical-device-id pass.
-def TPU_DeviceIdOp : TPU_Op<"device_id", [Pure]> {
- let arguments = (ins);
- let results = (outs I32:$result);
- let assemblyFormat = [{ attr-dict `:` type($result) }];
-}
-
-def TPU_SemaphoreReadOp : TPU_Op<"sem_read"> {
- let arguments = (ins MemRefOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>:$semaphore);
- let results = (outs I32:$result);
- let assemblyFormat = [{ $semaphore attr-dict `:` type($semaphore) `->` type($result)}];
-}
-
-def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> {
- let arguments = (ins
- MemRefOf<[TPU_SemaphoreType]>:$semaphore,
- I32:$amount
- );
- let results = (outs);
- let assemblyFormat = [{ $semaphore `,` $amount attr-dict `:` type($semaphore)}];
- let hasVerifier = 1;
-}
-
-def TPU_AllocaSemaphoreOp : TPU_Op<"sem_alloc"> {
- let arguments = (ins);
- let results = (outs MemRefOf<[TPU_SomeSemaphoreType]>:$result);
- let assemblyFormat = [{ attr-dict `:` type($result) }];
-}
-
-def TPU_GetBarrierSemaphoreOp : TPU_Op<"sem_barrier"> {
- let arguments = (ins);
- let results = (outs MemRefOf<[TPU_SemaphoreType]>:$semaphore);
- let assemblyFormat = [{ attr-dict `:` type($semaphore) }];
- let hasVerifier = 1;
-}
-
-def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> {
- let arguments = (ins
- MemRefOf<[TPU_SemaphoreType]>:$semaphore,
- I32:$amount,
- Optional<I32>:$device_id, // For remote DMAs
- Optional<I32>:$core_id, // For megacore
- OptionalAttr:$core_type
- );
-let assemblyFormat = [{
- $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore)
- }];
- let hasVerifier = 1;
- let builders = [
- // A backward-compatible builder that sets `core_type` to nullptr.
- OpBuilder<(ins "Value":$semaphore, "Value":$amount,
- "Value":$device_id, "Value":$core_id)>,
- ];
-}
-
-def TPU_BarrierOp : TPU_Op<"barrier"> {
- let summary = [{Barrier synchronization across SC vector subcores.}];
- let description = [{
- Performs barrier synchronization across all SC vector subcores at the
- specified barrier id.
- }];
- let arguments = (ins Index:$barrier_id);
- let results = (outs);
- let assemblyFormat = [{ `barrier_id` `(` $barrier_id `)` attr-dict }];
-}
-
-// tpu.enqueue_dma enqueues a DMA operation.
-//
-// source : Memref to copy from.
-// source_semaphore : Semaphore to signal after the DMA completes.
-// target : Memref to copy to.
-// target_semaphore : Semaphore to wait on before the DMA completes.
-// device_id : The id of the device to copy to for remote DMAs.
-// core_id : The id of the core to copy to for remote and cross-core
-// DMAs.
-// priority : The priority of the DMA.
-// strict_ordering : True if the DMA requires strict ordering. If false, the
-// ordering is either strict or relaxed depending on the
-// source and destination.
-def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
- let arguments = (ins
- AnyMemRef:$source,
- Optional<MemRefOf<[TPU_DMASemaphoreType]>>:$source_semaphore, // For remote DMAs
- AnyMemRef:$target,
- MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore,
- Optional<I32>:$device_id, // For remote DMAs
- Optional<I32>:$core_id, // For megacore
- // Smaller number means higher priority. 0 is the highest and the default.
- DefaultValuedAttr<I32Attr, "0">:$priority,
- DefaultValuedAttr<BoolAttr, "false">:$strict_ordering
- );
- let assemblyFormat = [{
- `source` `(` $source `:` type($source) `)`
- `target` `(` $target `:` type($target) `)`
- (`source_semaphore` `(` $source_semaphore^ `:` type($source_semaphore) `)`)?
- `target_semaphore` `(` $target_semaphore `:` type($target_semaphore) `)`
- (`device_id` `(` $device_id^ `)`)?
- (`core_id` `(` $core_id^ `)`)?
- attr-dict
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
-}
-
-// A base class for all ops that need to differentiate between gather and
-// scatter.
-class IndirectDMAOp {
- code extraBaseClassDeclaration = [{
- // Return true if this op performs a gather. Returns false if it performs a
- // scatter.
- FailureOr<bool> isGather();
- }];
-}
-
-// tpu.enqueue_indirect_dma copies data between HBM and VMEM, or between
-// VMEM_SHARED and VMEM using indirect HBM offsets.
-//
-// If the source is in HBM or VMEM_SHARED and the target is in VMEM, performs a
-// gather from the source (operand) at the offsets to the target (gather
-// result).
-// If the source is in VMEM and the target is in HBM or VMEM_SHARED, performs a
-// scatter of the source (updates) to the target (operand) at the offsets.
-//
-// source : Memref to copy from.
-// target : Memref to copy to.
-// offsets : Gather or scatter offsets.
-// semaphore : Semaphore to wait on; receive semaphore for scatter, send semaphore for gather.
-// add : If true, add source values to target values. Otherwise, overwrite.
-// offset_filter : If set, don't write values at offsets whose value is equal to
-// the filter value.
-def TPU_EnqueueIndirectDMAOp : TPU_Op<"enqueue_indirect_dma">, IndirectDMAOp {
- let arguments = (ins
- AnyMemRef:$source,
- AnyMemRef:$target,
- AnyTypeOf<[MemRefOf<[I32]>, VectorOfRankAndType<[1], [I32]>]>:$offsets,
- MemRefOf<[TPU_DMASemaphoreType]>:$semaphore,
- Optional<I32>:$offset_filter,
- DefaultValuedAttr<BoolAttr, "false">:$add
- );
- let assemblyFormat = [{
- `source` `(` $source `:` type($source) `)`
- `target` `(` $target `:` type($target) `)`
- `offsets` `(` $offsets `:` type($offsets) `)`
- (`offset_filter` `(` $offset_filter^ `)`)?
- `semaphore` `(` $semaphore `:` type($semaphore) `)`
- attr-dict
- }];
- let hasVerifier = 1;
- let extraClassDeclaration = extraBaseClassDeclaration # [{
- LogicalResult verifyGather(MemRefType operand_ty,
- ArrayRef<int64_t> offsets_shape,
- MemRefType result_ty);
- LogicalResult verifyScatter(MemRefType updates_ty,
- ArrayRef<int64_t> offsets_shape,
- MemRefType operand_ty);
- }];
- let hasCanonicalizeMethod = 1;
-}
-
-// tpu.wait_dma2 waits for a DMA to complete.
-//
-// The number of bytes to wait for is determined based on the size of the
-// destination memref.
-def TPU_WaitDMA2Op : TPU_Op<"wait_dma2", [AttrSizedOperandSegments]> {
- let arguments = (ins
- MemRefOf<[TPU_DMASemaphoreType]>:$semaphore,
- AnyMemRef:$src,
- AnyMemRef:$dst,
- Optional<I32>:$device_id, // For remote DMAs
- Optional<I32>:$core_id, // For megacore
- DefaultValuedAttr<BoolAttr, "false">:$strict_ordering
- );
- let assemblyFormat = [{
- `semaphore` `(` $semaphore `:` type($semaphore) `)`
- `src` `(` $src `:` type($src) `)`
- `dst` `(` $dst `:` type($dst) `)`
- (`device_id` `(` $device_id^ `)`)?
- (`core_id` `(` $core_id^ `)`)?
- attr-dict
- }];
- let hasVerifier = 1;
- // A backward-compatible builder that sets `device_id` and `core_id` to nullptr.
- let builders = [
- OpBuilder<(ins "Value":$semaphore, "Value":$src, "Value":$dst)>
- ];
- let hasCanonicalizeMethod = 1;
-}
-
-// TODO(b/395630795): Remove after 2025-08-10.
-def TPU_WaitDMAOp : TPU_Op<"wait_dma"> {
- let arguments = (ins
- MemRefOf<[TPU_DMASemaphoreType]>:$semaphore,
- AnyMemRef:$ref
- );
- let hasVerifier = 1;
-}
-
-// Like tpu.wait_dma2, but for indirect DMAs.
-//
-// The number of bytes to wait for is determined based on the size of the
-// destination memref in a gather, and the size of the source memref in a
-// scatter. The op differentiates between gather and scatter based on the memory
-// spaces of the source and destination memrefs.
-def TPU_WaitIndirectDMAOp : TPU_Op<"wait_indirect_dma">, IndirectDMAOp {
- let arguments = (ins
- MemRefOf<[TPU_DMASemaphoreType]>:$semaphore,
- AnyMemRef:$src,
- AnyMemRef:$dst
- );
- let assemblyFormat = [{
- `semaphore` `(` $semaphore `:` type($semaphore) `)`
- `src` `(` $src `:` type($src) `)`
- `dst` `(` $dst `:` type($dst) `)`
- attr-dict
- }];
- let hasVerifier = 1;
- let hasCanonicalizeMethod = 1;
- let extraClassDeclaration = extraBaseClassDeclaration;
-}
-
-def TPU_RegionOp : TPU_Op<"region", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> {
- let arguments = (ins);
- let results = (outs Variadic:$results);
- let regions = (region AnyRegion:$region);
- let hasVerifier = 1;
-}
-
-def TPU_TraceOp : TPU_Op<"trace", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> {
- let arguments = (ins StrAttr:$message, I32Attr:$level);
- let results = (outs Variadic:$results);
- let regions = (region AnyRegion:$region);
-}
-
-def TPU_TraceStartOp : TPU_Op<"trace_start", []> {
- let arguments = (ins StrAttr:$message, I32Attr:$level);
- let results = (outs);
-}
-
-def TPU_TraceStopOp : TPU_Op<"trace_stop", []> {
- let arguments = (ins);
- let results = (outs);
-}
-
-def TPU_YieldOp : TPU_Op<"yield", [Pure, ReturnLike, Terminator]> {
- let arguments = (ins Variadic:$results);
- let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }];
-}
-
-def TPU_DelayOp : TPU_Op<"delay"> {
- let arguments = (ins I32:$nanos);
- let results = (outs);
-}
-
-// Expands the granularity of mask to subelements.
-def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> {
- let description = [{
- Cast a mask register into a different packing.
-
- If casting to a type with smaller packing, then values being packed together
- must be identical. For example, for 8x128x4xi1 -> 8x128x2xi1,
- input[i, j, 0] == input[i, j, 1] and input[i, j, 2] == input[i, j, 3] must
- hold for all i, j. Otherwise, the result is undefined.
- }];
- let arguments = (ins VectorOfNonZeroRankOf<[I1]>:$input);
- let results = (outs VectorOfNonZeroRankOf<[I1]>:$result);
- let assemblyFormat = [{
- $input attr-dict `:` type($input) `->` type($result)
- }];
- let hasVerifier = 1;
-}
-
-def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> {
- let arguments = (ins I32Attr:$dim);
- let results = (outs I32:$result);
- let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
-}
-
-def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> {
- let arguments = (ins);
- let results = (outs AnyMemRef:$result);
- let assemblyFormat = [{ attr-dict `:` type($result) }];
-}
-
-def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> {
- let arguments = (ins Variadic:$seeds);
- let results = (outs);
-}
-
-def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> {
- let arguments = (ins);
- let results = (outs AnyVectorOfNonZeroRank:$output);
-}
-
-def TPU_SublaneShuffleOp : TPU_Op<"sublane_shuffle", [SameOperandsAndResultType]> {
- // This op takes 2 physical vregs and a pattern, applies the pattern,
- // and returns the result as 1 vreg.
- //
- // The pattern is a list of integers, where the integer value is the
- // index of the sublane in the *combined input* [lhs, rhs], and the
- // position of the integer in the list is the index of the sublane
- // in the *output* vreg.
- //
- // The pattern size must match the operand/result sublane count.
- //
- // Example:
- // %0 = tpu.single_output_sublane_shuffle %a, %b,
- // [0, 1, 2, 3, 4, 5, 6, 7] // Result is %a
- // %1 = tpu.single_output_sublane_shuffle %a, %b,
- // [8, 9, 10, 11, 12, 13, 14, 15] // Result is %b
- // %2 = tpu.single_output_sublane_shuffle %a, %b,
- // [7, 6, 5, 4, 11, 10, 9, 8] // Result uses high half of a
- // // and low half of b, reversed.
- let arguments = (ins
- TPU_Vreg:$lhs,
- TPU_Vreg:$rhs,
- DenseI32ArrayAttr:$pattern
- );
- let results = (outs TPU_Vreg:$result);
- let assemblyFormat = [{
- $lhs `,` $rhs `,` $pattern attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
- }];
-
- let hasVerifier = 1;
-}
-
-def TPU_TransposeOp : TPU_Op<"transpose", [Pure]> {
- let summary = "tpu transpose operation";
- let arguments = (ins AnyVectorOfAnyRank:$vector,
- DenseI64ArrayAttr:$permutation);
- let results = (outs AnyVectorOfAnyRank:$result);
-
- let assemblyFormat = [{
- $vector `,` $permutation attr-dict `:` type($vector) `->` type($result)
- }];
- let extraClassDeclaration = [{
- VectorType getSourceVectorType() {
- return ::llvm::cast<VectorType>(getVector().getType());
- }
- VectorType getResultVectorType() {
- return ::llvm::cast<VectorType>(getResult().getType());
- }
- }];
- let hasVerifier = 1;
-}
-
-def TPU_LogOp : TPU_Op<"log"> {
- let arguments = (ins
- Variadic:$inputs,
- StrAttr:$tag,
- DefaultValuedAttr<BoolAttr, "false">:$formatted
- );
- let results = (outs);
- let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }];
- let hasVerifier = 1;
-}
-
-def TPU_LogBufferOp : TPU_Op<"log_buffer"> {
- let arguments = (ins
- AnyMemRef:$input,
- DenseI64ArrayAttr:$shape,
- StrAttr:$tag
- );
- let results = (outs);
- let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }];
- let hasVerifier = 1;
-}
-
-#endif // TPU_ATTRS
+#endif // TPU_BASE
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
index d49380877549..1fbc59746921 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
@@ -15,22 +15,26 @@ limitations under the License.
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
+#include
+#include
#include
#include
#include
#include "absl/hash/hash.h"
-#include "absl/log/log.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep.
+#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep.
@@ -42,6 +46,7 @@ limitations under the License.
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc"
#include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc"
+#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/layout.h"
// This is a bit unclean, but we need to squat the xla namespace to make sure
@@ -118,9 +123,50 @@ struct MemRefCastEraseLayout : public OpRewritePattern {
}
};
+// Rewrites memref.dim(tpu.memref_squeeze(x)) to memref.dim(x) with the
+// dimension index adjusted to account for squeezed dimensions.
+struct MemRefDimOfSqueeze : public OpRewritePattern<memref::DimOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::DimOp dim_op,
+ PatternRewriter& rewriter) const override {
+ auto squeeze_op = dim_op.getSource().getDefiningOp<MemRefSqueezeOp>();
+ if (!squeeze_op) {
+ return failure();
+ }
+ const std::optional<int64_t> maybe_dim =
+ getConstantIntValue(dim_op.getDimension());
+ if (!maybe_dim) {
+ return failure();
+ }
+ const int64_t dim = *maybe_dim;
+ MemRefType result_type = squeeze_op.getType();
+ if (dim < 0 || result_type.getRank() <= dim) {
+ return dim_op.emitWarning("Dimension index is out of bounds");
+ }
+ if (result_type.getDimSize(dim) != ShapedType::kDynamic) {
+ return failure();
+ }
+ MemRefType source_type = getMemRefType(squeeze_op.getInput());
+ FAILUREOR_ASSIGN_OR_RETURN(
+ SmallVector squeezed,
+ computeSqueezedDimsChecked(squeeze_op, source_type.getShape(),
+ result_type.getShape()));
+ int64_t source_dim = dim;
+ for (int squeezed_dim : squeezed) {
+ if (squeezed_dim <= source_dim) {
+ ++source_dim;
+ }
+ }
+ rewriter.replaceOpWithNewOp<memref::DimOp>(dim_op, squeeze_op.getInput(),
+ source_dim);
+ return success();
+ }
+};
+
void TPUDialect::getCanonicalizationPatterns(RewritePatternSet& results) const
/*override*/ {
- results.add<MemRefCastEraseLayout>(getContext());
+ results.add<MemRefCastEraseLayout, MemRefDimOfSqueeze>(getContext());
}
FailureOr<CoreType> GetCoreTypeOfParentFunc(Operation &op) {
@@ -215,32 +261,210 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
}
AffineMap TiledLayoutAttr::getAffineMap() const {
- AffineMap map =
- AffineMap::getMultiDimIdentityMap(getTileStrides().size(), getContext());
SmallVector<AffineExpr> exprs;
- for (const xla::Tile &tile : getTiles()) {
- exprs.clear();
+ for (int64_t i = 0; i < getRank(); ++i) {
+ exprs.push_back(getAffineDimExpr(i, getContext()));
+ }
+ for (const xla::Tile& tile : getTiles()) {
+ SmallVector<AffineExpr> new_exprs;
auto dimensions = tile.dimensions();
- int64_t untiled_dims = map.getNumResults() - dimensions.size();
- if (untiled_dims < 0) {
- LOG(FATAL) << "Invalid TiledLayoutAttr: Number of dims must be larger "
- "or equal to the rank of the tile";
+ int64_t untiled_rank = exprs.size() - dimensions.size();
+ assert(untiled_rank >= 0);
+ for (int64_t i = 0; i < untiled_rank; ++i) {
+ new_exprs.push_back(exprs[i]);
+ }
+ for (int64_t i = 0; i < dimensions.size(); ++i) {
+ new_exprs.push_back(exprs[untiled_rank + i].floorDiv(dimensions[i]));
+ }
+ for (int64_t i = 0; i < dimensions.size(); ++i) {
+ new_exprs.push_back(exprs[untiled_rank + i] % dimensions[i]);
+ }
+ exprs = std::move(new_exprs);
+ }
+ int64_t num_symbols = 0;
+ AffineExpr result = getAffineConstantExpr(0, getContext());
+ SmallVector<int64_t> strides = getExpandedStrides();
+ assert(strides.size() == exprs.size());
+ for (int64_t i = 0; i < exprs.size(); ++i) {
+ AffineExpr stride_expr =
+ ShapedType::isDynamic(strides[i])
+ ? getAffineSymbolExpr(num_symbols++, getContext())
+ : getAffineConstantExpr(strides[i], getContext());
+ result = result + exprs[i] * stride_expr;
+ }
+ return AffineMap::get(getRank(), num_symbols, result);
+}
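To make the construction above concrete, here is the map it produces for the common case of a rank-2 layout with a single (8, 128) tile and tile strides (s0, s1) (a worked example following the code above, not additional functionality):

```cpp
// exprs after the tiling loop: [d0 floordiv 8, d1 floordiv 128, d0 mod 8, d1 mod 128]
// expanded strides:            [s0 * 1024,     s1 * 1024,       128,      1]
//                              (1024 = 8 * 128, the element count of one tile)
// resulting affine map:
//   (d0, d1) -> (d0 floordiv 8) * s0 * 1024 + (d1 floordiv 128) * s1 * 1024
//               + (d0 mod 8) * 128 + (d1 mod 128)
// i.e. the linearized element offset of (d0, d1) in an (8, 128)-tiled buffer.
```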
+
+namespace {
+int64_t getUntiledRank(ArrayRef<xla::Tile> tiles, const int64_t rank) {
+ // Note: This implementation does not assume there is no nested tiling across
+ // the first level of tiling, though this is enforced by the verifier.
+ int64_t untiled_rank = rank;
+ int64_t tiled_rank = rank;
+ for (const xla::Tile& tile : tiles) {
+ const int64_t tile_ndims = tile.dimensions().size();
+ untiled_rank = std::min(untiled_rank, tiled_rank - tile_ndims);
+ tiled_rank += tile_ndims;
+ }
+ return untiled_rank;
+}
+} // namespace
+
+int64_t TiledLayoutAttr::getUntiledRank() const {
+ return mlir::tpu::getUntiledRank(getTiles(), getRank());
+}
+
+namespace {
+FailureOr<SmallVector<int64_t>> getExpandedShape(
+ const ArrayRef<int64_t> untiled_shape, const ArrayRef<xla::Tile> tiles,
+ const bool require_alignment) {
+ SmallVector<int64_t> shape(untiled_shape);
+ for (const xla::Tile& tile : tiles) {
+ const int64_t tile_ndims = tile.dimensions().size();
+ const llvm::ArrayRef<int64_t> tiled_shape =
+ llvm::ArrayRef(shape).take_back(tile_ndims);
+ llvm::SmallVector<int64_t> new_tiled_shape(2 * tile_ndims);
+ for (int64_t i = 0; i < tile_ndims; ++i) {
+ if (require_alignment && (ShapedType::isDynamic(tiled_shape[i]) ||
+ tiled_shape[i] % tile.dimension(i) != 0)) {
+ return failure();
+ }
+ if (ShapedType::isDynamic(tiled_shape[i])) {
+ new_tiled_shape[i] = ShapedType::kDynamic;
+ } else {
+ new_tiled_shape[i] =
+ llvm::divideCeil(tiled_shape[i], tile.dimension(i));
+ }
+ new_tiled_shape[tile_ndims + i] = tile.dimension(i);
+ }
+ shape.pop_back_n(tile_ndims);
+ shape.append(new_tiled_shape);
+ }
+ return shape;
+}
+} // namespace
+
+SmallVector<int64_t> TiledLayoutAttr::getDefaultTileStrides(
+ const ArrayRef<xla::Tile> tiles, const ArrayRef<int64_t> shape) {
+ SmallVector<int64_t> strides(shape.size());
+ int64_t stride = 1;
+ const xla::Tile* const first_tile = tiles.empty() ? nullptr : &tiles.front();
+ const int64_t first_tile_rank =
+ first_tile == nullptr ? 0 : first_tile->dimensions().size();
+ for (int64_t d = shape.size() - 1; d >= 0; --d) {
+ assert(!ShapedType::isDynamic(shape[d]));
+ strides[d] = stride;
+ if (d >= shape.size() - first_tile_rank) {
+ assert(first_tile != nullptr);
+ const int64_t tile_d = d - (shape.size() - first_tile_rank);
+ stride *= llvm::divideCeil(shape[d], first_tile->dimension(tile_d));
+ } else {
+ stride *= shape[d];
+ }
+ }
+ return strides;
+}
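A quick worked example of the convention implemented above (strides are counted in whole tiles of the first tiling level):

```cpp
// shape = (512, 1024), tiles = [T(8, 128)]:
//   d = 1: strides[1] = 1;  stride *= ceil(1024 / 128) = 8
//   d = 0: strides[0] = 8;  stride *= ceil(512 / 8)    = 64  (not used further)
// so the default tile strides are [8, 1]: stepping one tile along dim 0 skips
// the 8 tiles that form a full row of tiles.
```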
+
+bool TiledLayoutAttr::tilesAreKnownContiguous(
+ const ArrayRef<int64_t> shape) const {
+ const ArrayRef<xla::Tile> tiles = getTiles();
+ const ArrayRef<int64_t> tile_strides = getTileStrides();
+ int64_t stride = 1;
+ const xla::Tile* const first_tile = tiles.empty() ? nullptr : &tiles.front();
+ const int64_t first_tile_rank =
+ first_tile == nullptr ? 0 : first_tile->dimensions().size();
+ for (int64_t d = shape.size() - 1; d >= 0; --d) {
+ int64_t size_tiles;
+ if (d >= shape.size() - first_tile_rank &&
+ shape[d] != ShapedType::kDynamic) {
+ assert(first_tile != nullptr);
+ const int64_t tile_d = d - (shape.size() - first_tile_rank);
+ size_tiles = llvm::divideCeil(shape[d], first_tile->dimension(tile_d));
+ } else {
+ size_tiles = shape[d];
}
- for (int64_t i = 0; i < untiled_dims; ++i) {
- exprs.push_back(getAffineDimExpr(i, getContext()));
+ // Dimensions with only one element/tile can have any stride.
+ if (stride != tile_strides[d] && size_tiles != 1) {
+ return false;
}
- for (int i = 0; i < dimensions.size(); ++i) {
- exprs.push_back(getAffineDimExpr(untiled_dims + i, getContext())
- .floorDiv(dimensions[i]));
+ if (d == 0) {
+ break;
}
- for (int i = 0; i < dimensions.size(); ++i) {
- exprs.push_back(getAffineDimExpr(untiled_dims + i, getContext()) %
- dimensions[i]);
+ // When any dimension other than the leading one has a dynamic size, we
+ // cannot guarantee that there are no gaps.
+ if (size_tiles == ShapedType::kDynamic) {
+ return false;
}
- auto tile_map = AffineMap::get(map.getNumResults(), 0, exprs, getContext());
- map = tile_map.compose(map);
+ stride *= size_tiles;
}
- return map;
+ return true;
+}
+
+SmallVector<int64_t> TiledLayoutAttr::getExpandedShape(
+ ArrayRef<int64_t> untiled_shape) const {
+ // getExpandedShape should never fail without require_alignment
+ return *mlir::tpu::getExpandedShape(untiled_shape, getTiles(),
+ /*require_alignment=*/false);
+}
+
+SmallVector<int64_t> TiledLayoutAttr::getExpandedStrides() const {
+ if (getTiles().empty()) {
+ return SmallVector<int64_t>(getTileStrides());
+ }
+ SmallVector<int64_t> strides(getTileStrides());
+ // Expand front tile
+ const xla::Tile& first_tile = getTiles().front();
+ const FailureOr<SmallVector<int64_t>> failure_or_expanded_tile =
+ mlir::tpu::getExpandedShape(first_tile.dimensions(),
+ getTiles().drop_front(),
+ /*require_alignment=*/true);
+ // Verification should ensure this:
+ assert(succeeded(failure_or_expanded_tile));
+ const SmallVector<int64_t>& expanded_tile = *failure_or_expanded_tile;
+ strides.resize_for_overwrite(getRank() + expanded_tile.size());
+ int64_t first_tile_size = llvm::product_of(first_tile.dimensions());
+ int64_t tile_size = 1;
+ for (int64_t d = strides.size() - 1; d >= 0; --d) {
+ if (d >= getRank()) {
+ const int64_t new_stride = tile_size;
+ tile_size *= expanded_tile[d - getRank()];
+ strides[d] = new_stride;
+ } else {
+ strides[d] *= first_tile_size;
+ }
+ }
+ return strides;
+}
+
+LogicalResult TiledLayoutAttr::verify(
+ function_ref<InFlightDiagnostic()> emitError,
+ const llvm::ArrayRef<xla::Tile> tiles,
+ const llvm::ArrayRef<int64_t> tile_strides) {
+ if (llvm::any_of(tile_strides, ShapedType::isDynamic)) {
+ return emitError() << "Not implemented: Dynamic tile strides";
+ }
+ if (tiles.empty()) {
+ return success();
+ }
+ const int64_t rank = tile_strides.size();
+ const xla::Tile& first_tile = tiles.front();
+ const int64_t first_tile_rank = first_tile.dimensions().size();
+ // The interpretation of tile strides is unclear if there is nested tiling
+ // across first tiles (e.g. T(8, 128)(2, 4, 64)), and this has no applications
+ // anyway.
+ if (mlir::tpu::getUntiledRank(tiles, rank) != rank - first_tile_rank) {
+ return emitError() << "Not implemented: Nested tiling across first tiles";
+ }
+ // Check that nested tiles evenly divide previous tiles (so they don't add any
+ // padding or change the tile size)
+ if (failed(mlir::tpu::getExpandedShape(first_tile.dimensions(),
+ tiles.drop_front(),
+ /*require_alignment=*/true))) {
+ return emitError() << "Not implemented: Nested tiles must evenly divide "
+ << "the first tile " << first_tile.ToString()
+ << " but they do not (would add padding)";
+ }
+ return success();
}
MemRefType getMemRefType(Value value) {
@@ -250,6 +474,15 @@ MemRefType getMemRefType(Value value) {
return cast<MemRefType>(value.getType());
}
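+// Returns true if `value` is defined by an op of type `Op` and both of its
+// operands are themselves guaranteed to be divisible by `divisor`. The
+// remaining fuel is split between the two recursive checks.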
+template <typename Op>
+bool checkBothOperandsDivisible(Value value, int64_t divisor, int64_t fuel) {
+ if (auto op = value.getDefiningOp<Op>()) {
+ return isGuaranteedDivisible(op.getLhs(), divisor, fuel / 2) &&
+ isGuaranteedDivisible(op.getRhs(), divisor, (fuel + 1) / 2);
+ }
+ return false;
+}
+
bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) {
if (fuel <= 0) {
return false;
@@ -272,9 +505,16 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) {
if (auto cast_op = value.getDefiningOp()) {
return isGuaranteedDivisible(cast_op.getOperand(), divisor, fuel - 1);
}
- if (auto add_op = value.getDefiningOp<arith::AddIOp>()) {
- return isGuaranteedDivisible(add_op.getRhs(), divisor, fuel / 2) &&
- isGuaranteedDivisible(add_op.getLhs(), divisor, (fuel + 1) / 2);
+ if (checkBothOperandsDivisible(value, divisor, fuel) ||
+ checkBothOperandsDivisible(value, divisor, fuel) ||
+ checkBothOperandsDivisible(value, divisor, fuel) ||
+ checkBothOperandsDivisible(value, divisor, fuel)) {
+ return true;
+ }
+ if (auto select_op = value.getDefiningOp()) {
+ return isGuaranteedDivisible(select_op.getTrueValue(), divisor, fuel / 2) &&
+ isGuaranteedDivisible(select_op.getFalseValue(), divisor,
+ (fuel + 1) / 2);
}
return false;
}
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
index 16f710836f0b..9d9fcf624a40 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -77,7 +77,7 @@ LogicalResult specializeMemorySpace(TypedValue value,
// vector ops. This functions inverts the layout erasure applied to the value.
MemRefType getMemRefType(Value value);
-bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 8);
+bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 128);
DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder,
bool transpose_lhs,
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
index 855625337e64..ec56b3561591 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -2042,9 +2042,19 @@ LogicalResult PackElementwiseOp::verify() {
return emitOpError("At least one source is required");
}
const auto src_vty = cast<VectorType>(getSources().front().getType());
- if (failed(verifyElementwisePacking(*this, src_vty.getElementType(),
- getTargetType()))) {
- return failure();
+ if (getElementTypeBitwidth(src_vty) != getElementTypeBitwidth(getType())) {
+ return emitOpError("All sources must have the same bitwidth as the result");
+ }
+ if (!getType().getElementType().isSignlessInteger()) {
+ return emitOpError("Output type must be a signless integer type");
+ }
+
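+ // Element-type check: packing is only defined from f32 to bf16 or between
+ // signless integer element types.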
+ auto src_elem_ty = src_vty.getElementType();
+ auto tgt_elem_ty = getTargetType();
+ if (!(src_elem_ty.isF32() && tgt_elem_ty.isBF16()) &&
+ !(src_elem_ty.isSignlessInteger() && tgt_elem_ty.isSignlessInteger())) {
+ return emitOpError(
+ "Only packing f32 -> bf16 and integer -> integer is supported");
}
const int packing_factor =
getElementTypeBitwidth(src_vty) / getTypeBitwidth(getTargetType());
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.td b/jaxlib/mosaic/dialect/tpu/tpu_ops.td
new file mode 100644
index 000000000000..58f36f78d499
--- /dev/null
+++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.td
@@ -0,0 +1,1522 @@
+/* Copyright 2023 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TPU_OPS
+#define TPU_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/Pass/PassBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "jaxlib/mosaic/dialect/tpu/tpu.td"
+
+// TODO(b/369418606): Find out the way to verify vreg size.
+def TPU_Vreg : Type;
+
+class TPU_Type<string name, string mnemonic_, list<Trait> traits = [],
+ string baseCppType = "::mlir::Type">
+ : TypeDef<TPU_Dialect, name, traits, baseCppType> {
+ let mnemonic = mnemonic_;
+}
+
+def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [
+ I32EnumAttrCase<"kTc", 0, "tc">,
+ I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">,
+ I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::tpu";
+}
+
+def TPU_CoreTypeEnum : EnumAttr {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def TPU_PipelineMode : I32EnumAttr<"PipelineMode", "Pipeline mode", [
+ I32EnumAttrCase<"kSynchronous", 1, "synchronous">,
+ I32EnumAttrCase<"kDoubleBuffered", 2, "double_buffered">
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::tpu";
+}
+
+def TPU_PipelineModeEnum : EnumAttr {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>;
+def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>;
+def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>;
+
+def TPU_Float8EXMYType : TPU_Type<"Float8EXMY", "float8_exmy",
+ [DeclareTypeInterfaceMethods]> {
+ let summary = "EXMY type in a 8 bit container";
+ let description = [{
+ EXMY type in a 8 bit container. Meaningful bits are aligned to LSB, and
+ bits higher than the underlying exmy type in the container are considered
+ as ignored. See https://arxiv.org/abs/2405.13938 for more details.
+ }];
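+ // For example, if the underlying type has 6 meaningful bits, they occupy
+ // bits [5:0] of the container and bits [7:6] are ignored.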
+
+ let parameters = (ins
+ TypeParameter<"::mlir::FloatType", "Underlying EXMY type">:$underlying_type
+ );
+
+ let assemblyFormat = [{
+ `<` $underlying_type `>`
+ }];
+}
+
+def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [
+ I32EnumAttrCase<"parallel", 0>,
+ I32EnumAttrCase<"arbitrary", 1>,
+ I32EnumAttrCase<"core_parallel", 2>,
+ I32EnumAttrCase<"subcore_parallel", 3>
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::tpu";
+}
+
+def TPU_DimensionSemanticsEnum
+ : EnumAttr {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+// All indices/sizes are in element-space.
+// Note that the implementation will require statically provable tile alignment.
+def TPU_ElementWindowAttr : TPU_Attr<"ElementWindow", "element_window"> {
+ // Including low padding, to avoid backwards-incompatible changes once we add it.
+ let parameters = (ins
+ ArrayRefParameter<"int64_t", "">:$pad_low,
+ ArrayRefParameter<"int64_t", "">:$pad_high
+ );
+ let assemblyFormat = "`<` `[` $pad_low `]` `,` `[` $pad_high `]` `>`";
+}
+
+def TPU_ContractPrecision : I32EnumAttr<"ContractPrecision", "Contraction precision", [
+ I32EnumAttrCase<"kBF16", 0, "bf16">,
+ I32EnumAttrCase<"kFP32", 1, "fp32">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::tpu";
+}
+
+def TPU_ContractPrecisionEnum
+ : EnumAttr {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [
+ I32EnumAttrCase<"kCompressed", 0, "compressed">,
+ I32EnumAttrCase<"kInterleaved", 1, "interleaved">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::tpu";
+}
+
+def TPU_PackFormatEnum : EnumAttr {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def TPU_TiledCase : I32EnumAttrCase<"tiled", 0>;
+def TPU_LaneCase : I32EnumAttrCase<"lanes", 1>;
+def TPU_SublaneCase : I32EnumAttrCase<"sublanes", 2>;
+def TPU_VectorLayoutDim : I32EnumAttr<
+ "VectorLayoutDim", "", [TPU_TiledCase, TPU_LaneCase, TPU_SublaneCase]>;
+
+def TPU_VectorLayoutAttr : TPU_Attr<"VectorLayout", "vpad"> {
+ let description = [{TODO}];
+
+ let parameters = (ins "Layout":$layout);
+ let hasCustomAssemblyFormat = 1;
+}
+
+def TPU_TiledLayoutAttr
+ : TPU_Attr<"TiledLayout", "tiled",
+ [DeclareAttrInterfaceMethods]> {
+ let description = [{
+ This attribute represents tiled layouts in memrefs.
+
+ Multiple levels of tiling are supported with the following restrictions:
+ - Additional levels of tiling may not add any padding.
+ - Additional levels of tiling may not tile previously untiled dimensions,
+ that is, they cannot tile across first-level tiles.
+
+ Tile strides encode the stride when moving along a given dimension. They
+ must have the same rank as the shape and must be decreasing with increasing
+ dimension number. For tiled dimensions, the stride applies only when moving
+ across first-level tiles. The strides are in units of the size of the first
+ tile, or 1 if there are no tiles.
+ }];
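+ // Illustrative example: a (512, 1024) memref with tiles [(8, 128)] has
+ // ceil(512/8) x ceil(1024/128) = 64 x 8 first-level tiles; the contiguous
+ // (default) tile strides are (8, 1), i.e. moving one tile along dimension 1
+ // advances by one tile-size unit and along dimension 0 by 8 such units.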
+ let parameters = (ins
+ ArrayRefParameter<"::xla::Tile", "">:$tiles,
+ ArrayRefParameter<"int64_t", "">:$tile_strides
+ );
+ let extraClassDeclaration = [{
+ static ::llvm::SmallVector<int64_t> getDefaultTileStrides(::llvm::ArrayRef<::xla::Tile> tiles, ::llvm::ArrayRef<int64_t> shape);
+ bool tilesAreKnownContiguous(::llvm::ArrayRef<int64_t> shape) const;
+
+ int64_t getRank() const {
+ return getTileStrides().size();
+ }
+ int64_t getUntiledRank() const;
+
+ ::llvm::SmallVector<int64_t> getExpandedShape(::llvm::ArrayRef<int64_t> shape) const;
+ ::llvm::SmallVector<int64_t> getExpandedStrides() const;
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let genVerifyDecl = 1;
+}
+
+def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [
+ I32EnumAttrCase<"kAny", 4294967295, "any">,
+ I32EnumAttrCase<"kVmem", 0, "vmem">,
+ I32EnumAttrCase<"kSmem", 1, "smem">,
+ I32EnumAttrCase<"kHbm", 2, "hbm">,
+ I32EnumAttrCase<"kCmem", 3, "cmem">,
+ I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">,
+ I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">,
+ I32EnumAttrCase<"kHost", 6, "host">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::tpu";
+}
+
+def TPU_MemorySpaceEnum
+ : EnumAttr {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+class TPU_Op<string mnemonic, list<Trait> traits = []> :
+ Op<TPU_Dialect, mnemonic, traits> {
+}
+
+def DefaultMemWrite : MemoryEffects<[MemWrite]>;
+def DefaultMemRead : MemoryEffects<[MemRead]>;
+
+def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [
+ I32EnumAttrCase<"kSum", 0, "sum">,
+ I32EnumAttrCase<"kMax", 1, "max">,
+ I32EnumAttrCase<"kMin", 2, "min">,
+ I32EnumAttrCase<"kArgMax", 3, "arg_max">,
+ I32EnumAttrCase<"kArgMin", 4, "arg_min">,
+ I32EnumAttrCase<"kFindFirstSet", 5, "find_first_set">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::tpu";
+}
+
+def TPU_ReductionKindAttr
+ : EnumAttr {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure]> {
+ let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind);
+ let results = (outs AnyVectorOfNonZeroRank:$output);
+ let assemblyFormat = [{
+ $input attr-dict `:` type($input) `->` type($output)
+ }];
+ let hasVerifier = 1;
+}
+
+def TPU_ReduceIndexOp : TPU_Op<"reduce_index", [Pure]> {
+ let arguments = (ins
+ AnyVectorOfNonZeroRank:$input,
+ I32Attr:$axis,
+ TPU_ReductionKindAttr:$kind
+ );
+ let results = (outs VectorOfNonZeroRankOf<[I32]>:$output);
+ let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
+ let hasVerifier = 1;
+}
+
+// tpu.scan performs a scan across a vector.
+//
+// If a mask is provided, all output elements before the first unmasked input
+// element are undefined. Subsequent masked elements hold the result of the
+// last unmasked element.
+//
+// For example, a "kSum" reduction over a input vector [1, 2, 3, 4]
+// with mask [0, 1, 0, 1] will produce the output vector [X, 2, 2, 6].
+// where X is some undefined value.
+//
+// output : Result vector. Must have the same shape as source.
+// input : Vector to scan.
+// kind : Reduction operator. Must be one of "kSum", "kMax", or "kMin".
+// Must be "kSum" if input is an I1 vector.
+// mask : Elementwise vector mask. The scan operation starts from the
+// lowest-indexed non-masked vector element (all previous elements
+// have undefined values). Not taken for I1 input vectors.
+def TPU_ScanOp : TPU_Op<"scan"> {
+ let arguments = (ins
+ VectorOfNonZeroRankOf<[I1, I16, I32, BF16, F32]>:$input,
+ TPU_ReductionKindAttr:$kind,
+ Optional<VectorOfNonZeroRankOf<[I1]>>:$mask
+ );
+ let results = (outs VectorOfNonZeroRankOf<[I16, I32, BF16, F32]>:$output);
+ let assemblyFormat = [{
+ $kind `,` $input (`masked` $mask^)? attr-dict `:` type($input) `,` type($mask) `->` type($output)
+ }];
+ let hasVerifier = 1;
+}
+
+def TPU_SortOp : TPU_Op<"sort", [Pure]> {
+ let summary = "Sorts key/value pairs based on keys.";
+ let description = [{
+ tpu.sort performs a stable sort of key/value pairs in ascending or
+ descending order based on keys. Masked-out keys and values are placed at the
+ end of the output vectors. An output mask indicates which outputs
+ correspond to the valid inputs.
+ }];
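+ // Illustrative example (ascending, per the description above): keys
+ // [5, 2, 9], values [50, 20, 90], mask [1, 0, 1] yield output_mask
+ // [1, 1, 0], sorted_keys [5, 9, 2], sorted_values [50, 90, 20], with the
+ // masked-out pair moved to the end.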
+ let arguments = (ins
+ VectorOfNonZeroRankOf<[I32,F32]>:$keys,
+ VectorOfNonZeroRankOf<[I32,F32]>:$values,
+ Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
+ DefaultValuedAttr:$descending
+ );
+ let results = (outs
+ VectorOfNonZeroRankOf<[I1]>:$output_mask,
+ VectorOfNonZeroRankOf<[I32,F32]>:$sorted_keys,
+ VectorOfNonZeroRankOf<[I32,F32]>:$sorted_values
+ );
+ let assemblyFormat = [{
+ $keys `,` $values (`masked` $mask^)? attr-dict `:` functional-type(operands, results)
+ }];
+ let hasVerifier = 1;
+}
+
+def TPU_StoreOp : TPU_Op<"store", [DefaultMemWrite, AttrSizedOperandSegments]> {
+ let arguments = (ins
+ TPU_Vreg:$valueToStore,
+ AnyType:$base,
+ Variadic:$indices,
+ DenseBoolArrayAttr:$sublane_mask,
+ Optional:$mask,
+ OptionalAttr:$sublane_stride // In sublane-sized units
+ );
+ let results = (outs);
+ let assemblyFormat = [{
+ $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
+ }];
+}
+
+def TPU_LoadOp : TPU_Op<"load", [DefaultMemRead]> {
+ let arguments = (ins
+ AnyType:$base,
+ Variadic:$indices,
+ DenseBoolArrayAttr:$sublane_mask,
+ OptionalAttr:$sublane_stride // In sublane-sized units
+ );
+ let results = (outs TPU_Vreg:$result);
+ let assemblyFormat = [{
+ $base `[` $indices `]` `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($result)
+ }];
+ let description = [{
+ Similar to `vector::LoadOp` but with `sublane_mask` and `sublane_stride`.
+ Negative `indices` load from a negative offset relative to the `base`
+ address.
+ }];
+}
+
+// TODO(jevinjiang): migrate tpu.strided_store to general vector store op.
+def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperandSegments]> {
+ let arguments = (ins
+ AnyVectorOfNonZeroRank:$valueToStore,
+ AnyMemRef:$base,
+ Variadic:$indices,
+ DenseI32ArrayAttr:$strides,
+ Optional:$mask, // Elementwise mask.
+ DefaultValuedAttr:$add
+ );
+ let results = (outs);
+ let assemblyFormat = [{
+ $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
+ }];
+ let hasVerifier = 1;
+ let hasCanonicalizeMethod = 1;
+}
+
+// tpu.vector_load loads a vector from memory into a register.
+//
+// base : Memref to load from.
+// indices: Scalar indices into base. The number of indices must match the
+// rank of the base memref.
+// strides: The stride to use for calculating the address of subsequent
+// elements. If left unspecified, the stride is implicitly 1 along
+// each dimension. Otherwise the number of strides must match the rank
+// of the base memref.
+// mask : Elementwise vector mask. Must be broadcastable to the shape of the
+// result vector. Depending on the core type, this may be a dynamic
+// (lane) mask consumed from a register or a static (sublane) mask
+// that must be the result of arith.constant.
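+// Illustrative example (assuming unit strides): with a (512, 512) memref
+// base, a result vector of shape (8, 128), and indices (%i, %j), the op loads
+// the elements at rows [%i, %i + 8) and columns [%j, %j + 128).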
+def TPU_VectorLoadOp :TPU_Op<"vector_load", [DefaultMemRead, AttrSizedOperandSegments]> {
+ let arguments = (ins
+ AnyMemRef:$base,
+ Variadic