From 099f0a170820c182495da50705017b11ff46d989 Mon Sep 17 00:00:00 2001 From: Akshay Babbar <19975437+akshay-babbar@users.noreply.github.com> Date: Sun, 5 Oct 2025 20:41:46 +0530 Subject: [PATCH 001/315] Fix: local_window_size ignored when mask is None in dot_product_attention Add check for local_window_size in _apply_masks early return condition. Previously, the function would skip masking when no explicit mask was provided, causing local_window_size to be ignored. --- jax/_src/nn/functions.py | 2 +- tests/nn_test.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index b3b3150593f6..a9b433e96359 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -893,7 +893,7 @@ def _get_padding_mask_encoded(T, q_seqlen): def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, local_window_size): - if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None: + if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None and local_window_size is None: return logits combined_mask = jnp.ones_like(logits, dtype=bool) diff --git a/tests/nn_test.py b/tests/nn_test.py index 843af79b8655..e3366830a78c 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -763,6 +763,27 @@ def testLog1mExpGrad(self): atol=1e-3, ) + def testDotProductAttention_localWindowSizeWithoutMask(self): + dtype = jnp.float32 + B, S, T, N, H = 2, 128, 128, 4, 32 + keys = random.split(random.PRNGKey(0), 3) + Q = random.normal(keys[0], (B, T, N, H), dtype) + K = random.normal(keys[1], (B, S, N, H), dtype) + V = random.normal(keys[2], (B, S, N, H), dtype) + + output_large_window = nn.dot_product_attention( + Q, K, V, mask=None, local_window_size=(32, 32) + ) + + output_small_window = nn.dot_product_attention( + Q, K, V, mask=None, local_window_size=(1, 1) + ) + + self.assertFalse( + jnp.allclose(output_large_window, output_small_window), + "Attention output should differ with different local_window_size, even without a mask.", + ) + InitializerRecord = collections.namedtuple( "InitializerRecord", From 86d3c6d85568697b1773e1271ecbf26f38d5f9eb Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Mon, 24 Nov 2025 08:17:21 -0800 Subject: [PATCH 002/315] Support H-max to 256 in JAX/SPDA for Blackwell. --- jax/_src/cudnn/fused_attention_stablehlo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 9583b0fce0f9..d49cf00c8cf1 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -380,7 +380,8 @@ def check_is_flash_attention( # bf16/fp16 attention conditions # Check the head dim. is_on_hopper = is_cuda_compute_capability_equal("9.0") - H_max = 256 if is_on_hopper else 128 + is_on_blackwell = is_cuda_compute_capability_equal("10.0") + H_max = 256 if (is_on_hopper or is_on_blackwell) else 128 # check if multi-head latent attention is needed is_mla = qH != vH if not (qH <= H_max and qH % 8 == 0): From 3feb4748f1a8edb156c880d7223c7b746a9180a6 Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Mon, 24 Nov 2025 08:23:57 -0800 Subject: [PATCH 003/315] Apply check_compute_capability --- jax/_src/cudnn/fused_attention_stablehlo.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index d49cf00c8cf1..43599c264cbc 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -379,9 +379,8 @@ def check_is_flash_attention( else: # bf16/fp16 attention conditions # Check the head dim. - is_on_hopper = is_cuda_compute_capability_equal("9.0") - is_on_blackwell = is_cuda_compute_capability_equal("10.0") - H_max = 256 if (is_on_hopper or is_on_blackwell) else 128 + is_hopper_or_later = check_compute_capability("9.0") + H_max = 256 if is_hopper_or_later else 128 # check if multi-head latent attention is needed is_mla = qH != vH if not (qH <= H_max and qH % 8 == 0): From 7a57a4cd52220dce825e8138b1cd12bf182b90ab Mon Sep 17 00:00:00 2001 From: Prakhar Prasun Date: Wed, 19 Nov 2025 08:46:15 -0800 Subject: [PATCH 004/315] docs: improve conv_transpose clarity by documenting the diff spatial axes bw tensorflow and lax.conv_transpose --- jax/_src/lax/convolution.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index d355d13af6a8..c687e87c2bdb 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -294,6 +294,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], This function directly calculates a fractionally strided conv rather than indirectly calculating the gradient (transpose) of a forward convolution. + Notes: + TensorFlow/Keras Compatibility: By default, JAX does NOT reverse the + kernel's spatial dimensions. This differs from TensorFlow's "Conv2DTranspose" + and similar frameworks, which flip spatial axes and swap input/output channels. + + To match TensorFlow/Keras behavior, set "transpose_kernel=True" . + Args: lhs: a rank `n+2` dimensional input array. rhs: a rank `n+2` dimensional array of kernel weights. From ec5e3d64e9ff27652496694be3205f8e0f1e5d24 Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 25 Nov 2025 18:36:35 +0100 Subject: [PATCH 005/315] Fixed docstring of make_array_from_process_local_data --- jax/_src/array.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 13ca89fb25d7..1c6fd22e6b40 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -903,14 +903,14 @@ def make_array_from_process_local_data( >>> assert output_global_array.addressable_data(0).shape == per_device_shape >>> assert output_global_array.shape == global_shape - NB: While most shardings are uniform, It is possible to design am exotic + NB: While most shardings are uniform, It is possible to design an exotic sharding mesh where each process's devices will be arranged in a non-grid like pattern in some dimensions, or for indices to overlap non-trivially. Such sharding is called "non-uniform" in those dimensions. In that case, the global shape along those directions must match local shape as there is no meaningful way to represent all needed per-process data in non-overlapping fashion. For example for global_shape 4x4 - if sharding looks like this: + if sharding looks like this:: 0123 2103 @@ -918,7 +918,7 @@ def make_array_from_process_local_data( 4567 with 4 processes, containing devices (0,1), (2, 3), (4, 5), (6, 7) respectively. - Then the data for each host look like + Then the data for each host look like:: xx.. ..xx .... .... .xx. x..x .... .... @@ -932,7 +932,7 @@ def make_array_from_process_local_data( In this case user must provide global_shape explicitly and for local_shape=(2, 4), potentially valid global shapes are (2, 4) and (4, 4). - On the other hand for sharding: + On the other hand for sharding:: 0213 x.x. .x.x. .... .... 0213 x.x. .x.x. .... .... From 921b0a6a54f240be72fe562b94302a079e538a90 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Mon, 1 Dec 2025 13:52:49 -0800 Subject: [PATCH 006/315] [Pallas] Add missing docstrings and clean up mis-renderings. --- docs/jax.experimental.pallas.mosaic_gpu.rst | 1 + docs/jax.experimental.pallas.rst | 1 + jax/_src/pallas/core.py | 4 +- jax/_src/pallas/helpers.py | 13 ++++ jax/_src/pallas/mosaic/error_handling.py | 2 +- jax/_src/pallas/mosaic/primitives.py | 16 +++- jax/_src/pallas/mosaic/random.py | 12 +-- jax/_src/pallas/mosaic/tpu_info.py | 1 + jax/_src/pallas/pallas_call.py | 5 +- jax/_src/pallas/primitives.py | 84 ++++++++++++++++++--- jax/_src/pallas/utils.py | 8 ++ jax/_src/state/primitives.py | 12 +++ 12 files changed, 137 insertions(+), 22 deletions(-) diff --git a/docs/jax.experimental.pallas.mosaic_gpu.rst b/docs/jax.experimental.pallas.mosaic_gpu.rst index 0a639f1e9d07..bb3340f248ee 100644 --- a/docs/jax.experimental.pallas.mosaic_gpu.rst +++ b/docs/jax.experimental.pallas.mosaic_gpu.rst @@ -14,6 +14,7 @@ Classes CompilerParams MemorySpace Layout + SemaphoreType SwizzleTransform TilingTransform TransposeTransform diff --git a/docs/jax.experimental.pallas.rst b/docs/jax.experimental.pallas.rst index 12d5129a9d34..17242e1551f5 100644 --- a/docs/jax.experimental.pallas.rst +++ b/docs/jax.experimental.pallas.rst @@ -47,6 +47,7 @@ Functions debug_check debug_print dot + get_global loop max_contiguous multiple_of diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 84748eae4d55..9ff32c19e1ca 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -101,7 +101,7 @@ class semaphore_dtype(dtypes.extended): """Common dtype for all kinds of semaphore dtypes. This is an abstract class that should never be instantiated, but rather - exists for the sake of `jnp.issubdtype`. + exists for the sake of ``jnp.issubdtype``. """ class semaphore(semaphore_dtype): @@ -355,7 +355,7 @@ def __str__(self): class BoundedSlice: """Allows to specify a bounded slice of a dimension. - Specifically, the index_map need to return a `pl.Slice/pl.ds` for this + Specifically, the index_map need to return a ``pl.Slice/pl.ds`` for this dimension. The start and size may be dynamic, as long as the size <= block_size. """ diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py index 9c8ae14ab2b7..026abbbe5731 100644 --- a/jax/_src/pallas/helpers.py +++ b/jax/_src/pallas/helpers.py @@ -36,6 +36,19 @@ @api.named_call def empty_like(x: object): + """Create an empty PyTree of possibly uninitialized values. + + Args: + x: A PyTree with leaves specifying the shape and dtype of + the uninitialized object. + + Returns: + A PyTree with the same structure as ``x``, but with uninitialized + values. + + See Also: + :func:`jax.lax.empty` + """ return tree_util.tree_map(lambda leaf: empty(leaf.shape, leaf.dtype), x) diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py index 8286ab0b5e2a..3d8714945e90 100644 --- a/jax/_src/pallas/mosaic/error_handling.py +++ b/jax/_src/pallas/mosaic/error_handling.py @@ -36,7 +36,7 @@ r'( to (?P[0-9]+)?:(?P[0-9]+))?\)' ) MLIR_ERR_PREFIX = ( - 'Pallas encountered an internal verification error.' + 'Pallas encountered an internal verification error. ' 'Please file a bug at https://github.com/jax-ml/jax/issues. ' 'Error details: ' ) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index be9146efc9c3..e472253fb231 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -749,7 +749,16 @@ def _get_ref_and_transforms(ref): def make_async_copy(src_ref, dst_ref, sem) -> AsyncCopyDescriptor: - """Issues a DMA copying from src_ref to dst_ref.""" + """Creates a description of an asynchronous copy operation. + + Args: + src_ref: The source Reference. + dst_ref: The destination Reference. + sem: The semaphore used to track completion of the copy. + + Returns: + An AsyncCopyDescriptor. + """ src_ref, src_transforms = _get_ref_and_transforms(src_ref) dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref) sem, sem_transforms = _get_ref_and_transforms(sem) @@ -835,6 +844,7 @@ def async_remote_copy( device_id, device_id_type: primitives.DeviceIdType = primitives.DeviceIdType.MESH, ) -> AsyncCopyDescriptor: + """Issues a remote DMA copying from src_ref to dst_ref.""" copy_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type) copy_descriptor.start() @@ -1029,7 +1039,7 @@ def with_memory_space_constraint( ) -> jax.Array: """Constrains the memory space of an array. - This primitive does not change the value of `x`, but it constrains the + This primitive does not change the value of ``x``, but it constrains the memory space where it should be allocated. This is useful to force Pallas to allocate an array in a specific memory space. @@ -1042,7 +1052,7 @@ def with_memory_space_constraint( memory_space: The memory space to constrain to. Returns: - The array `x` with the memory space constraint. + The array ``x`` with the memory space constraint. """ if memory_space in {tpu_core.MemorySpace.ANY, pl_core.MemorySpace.ANY}: return x diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py index b3725619caff..3751b5611655 100644 --- a/jax/_src/pallas/mosaic/random.py +++ b/jax/_src/pallas/mosaic/random.py @@ -171,13 +171,13 @@ def sample_block(sampler_fn: SampleFnType, **kwargs) -> jax.Array: """Samples a block of random values with invariance guarantees. - `sample_block` allows the sampling of identical blocks of random values + ``sample_block`` allows the sampling of identical blocks of random values across kernels with different block shapes and iteration orders. Each call to `sample_block` returns a `block_size`-shaped array of random samples corresponding to the `block_index`. - `tile_size` should be chosen such that it is a divisor to all block sizes - one needs to be invariant to. The larger the `tile_size`, the more + ``tile_size`` should be chosen such that it is a divisor to all block sizes + one needs to be invariant to. The larger the ``tile_size``, the more efficient the sampling process will be and therefore the best choice is typically the greatest common divisor between all possible block sizes. @@ -186,7 +186,7 @@ def sample_block(sampler_fn: SampleFnType, random samples. global_key: The global key to use for sampling. block_size: The shape of an individual block. - tile_size: The shape of a `tile`, which is the smallest unit at + tile_size: The shape of a ``tile``, which is the smallest unit at which samples are generated. This should be selected to be a divisor of all block sizes one needs to be invariant to. total_size: The total size of the array to sample. @@ -195,8 +195,8 @@ def sample_block(sampler_fn: SampleFnType, **kwargs: Additional arguments to pass to the sampler_fn. Returns: - A `block_size` shaped array of samples for the current block corresponding - to `block_index`. + A ``block_size`` shaped array of samples for the current block corresponding + to ``block_index``. """ if len(block_size) != len(tile_size): raise ValueError(f"block_size ({len(block_size)}) and tile_size " diff --git a/jax/_src/pallas/mosaic/tpu_info.py b/jax/_src/pallas/mosaic/tpu_info.py index 40006dc94880..767b8d1c7fb8 100644 --- a/jax/_src/pallas/mosaic/tpu_info.py +++ b/jax/_src/pallas/mosaic/tpu_info.py @@ -154,6 +154,7 @@ def get_sublane_tiling(self, dtype: jnp.dtype) -> int: def is_tpu_device() -> bool: + """Returns whether the current device is a TPU.""" return core.get_device_kind() in { "TPU v2", "TPU v3", diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index c2d8d0e97e6b..1621ba411369 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1668,7 +1668,10 @@ def pallas_call( backend: Backend | None = None, metadata: dict[str, str] | None = None, ) -> Callable[..., Any]: - """Invokes a Pallas kernel on some inputs. + """Entry point for creating a Pallas kernel. + + In contrast to :func:`jax.experimental.pallas.kernel`, this entry point + assumes that the kernel will be executed over a ``grid``. See `Pallas Quickstart `_. diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index a37d3d4338da..54e5cae849eb 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -64,11 +64,11 @@ def program_id(axis: int) -> jax_typing.Array: """Returns the kernel execution position along the given axis of the grid. - For example, with a 2D `grid` in the kernel execution corresponding to the - grid coordinates `(1, 2)`, - `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. + For example, with a 2D ``grid`` in the kernel execution corresponding to the + grid coordinates ``(1, 2)``, + ``program_id(axis=0)`` returns ``1`` and ``program_id(axis=1)`` returns ``2``. - The returned value is an array of shape `()` and dtype `int32`. + The returned value is an array of shape ``()`` and dtype ``int32``. Args: axis: the axis of the grid along which to count the program. @@ -350,6 +350,8 @@ def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val): mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x]) def max_contiguous(x, values): + """A compiler hint that asserts the ``values`` first values of ``x`` are contiguous. + """ if not isinstance(values, (list, tuple)): values = (values,) return max_contiguous_p.bind(x, values=tuple(values)) @@ -364,6 +366,18 @@ def _max_contiguous_abstract_eval(aval, **_): mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x]) def multiple_of(x: jax_typing.Array, values: Sequence[int] | int) -> jax_typing.Array: + """A compiler hint that asserts a value is a static multiple of another. + + Note that misusing this function, such as asserting ``x`` is a multiple of + ``N`` when it is not, can result in undefined behavior. + + Args: + x: The input array. + values: A set of static divisors that ``x`` is a multiple of. + + Returns: + A copy of ``x``. + """ values = (values,) if isinstance(values, int) else tuple(values) return multiple_of_p.bind(x, values=values) @@ -713,6 +727,24 @@ def _handle_small(dtype: jax_typing.DTypeLike): def dot(a, b, trans_a: bool = False, trans_b: bool = False, allow_tf32: bool | None = None, precision=None): + """Computes the dot product of two arrays. + + The inputs can optionally be transposed before computing the + product. Depending on the hardware, this can be cheaper than + computing the transpose beforehand. + + Args: + a: The left-hand size of the dot product, of shape ``(..., N)``. + b: The right-hand size of the dot product, of shape ``(...N, M)``. + trans_a: Whether to transpose ``a`` before the product. + trans_b: Whether to transpose ``b`` before the product. + allow_tf32: Whether to use tf32 precision. + Mutually exclusive with ``precision``. + precision: Specifies the precision of the dot product. + + See Also: + :func:`jax.numpy.dot` + """ if (a.ndim != 2) or (b.ndim != 2): raise ValueError("`a` and `b` must be 2D arrays.") lhs_contract_dim = 0 if trans_a else 1 @@ -837,9 +869,9 @@ def run_scoped( to allocate for each argument. Each backend has its own set of reference types in addition to :class:`jax.experimental.pallas.MemoryRef`. - When `collective_axes` is specified, the same allocation will be returned for + When ``collective_axes`` is specified, the same allocation will be returned for all programs that only differ in their program ids along the collective axes. - It is an error not to call the same `run_scoped` in all programs along that + It is an error not to call the same ``run_scoped`` in all programs along that axis. """ if not isinstance(collective_axes, tuple): @@ -974,12 +1006,12 @@ def _lower_fun(*lower_fun_args): def get_global(what: pallas_core.ScratchShape) -> jax_typing.Array: """Returns a global reference that persists across all kernel invocations. - Each call to get_global returns a different and unique reference, but one that + Each call to ``get_global`` returns a different and unique reference, but one that is stable across invocations of the kernel body. Args: what: The reference type to allocate. Each backend has its own set of - reference types (e.g., `plgpu.SemaphoreType.REGULAR` for GPU). + reference types (e.g., :class:`jax.experimental.pallas.mosaic_gpu.SemaphoreType` for GPU). Example:: @@ -1064,7 +1096,15 @@ def _transform_semaphore(ref_value, transforms, ref_aval): semaphore_read_p.multiple_results = False -def semaphore_read(sem_or_view): +def semaphore_read(sem_or_view) -> jax_typing.Array: + """Reads the value of a semaphore. + + Args: + sem_or_view: A Ref (or view) representing a semaphore. + + Returns: + A scalar Array containing the value of the semaphore. + """ ref, transforms = _get_ref_and_transforms(sem_or_view) args = [ref, transforms] flat_args, args_tree = tree_util.tree_flatten(args) @@ -1107,6 +1147,24 @@ def semaphore_signal( device_id_type: DeviceIdType = DeviceIdType.MESH, core_index: int | jax_typing.Array | None = None, ): + """Increments the value of a semaphore. + + This operation can also be performed remotely if ``device_id`` is specified, + in which ``sem_or_view`` refers to a Ref located on another device. + Note that it is assumed that ``sem_or_view`` is already allocated + (e.g. through the proper use of barriers), or else this operation could + result in undefined behavior. + + Args: + sem_or_view: A Ref (or view) representing a semaphore. + inc: The value to increment by. + device_id (optional): Specifies which device to signal. + If not specified, ``sem_or_view`` is assumed to be local. + device_id_type (optional): The format in which + ``device_id`` should be specified. + core_index (optional): If on a multi-core device, + specifies which core to signal. + """ ref, transforms = _get_ref_and_transforms(sem_or_view) inc = jnp.asarray(inc, dtype=jnp.int32) args = [ref, transforms, inc, device_id, core_index] @@ -1208,6 +1266,14 @@ def _semaphore_signal_discharge_rule(in_avals, def semaphore_wait( sem_or_view, value: int | jax_typing.Array = 1, *, decrement: bool = True ): + """Blocks execution of the current thread until a semaphore reaches a value. + + Args: + sem_or_view: A Ref (or view) representing a semaphore. + value: The target value that the semaphore should reach before unblocking. + decrement: Whether to decrement the value of the semaphore after + a successful wait. + """ ref, transforms = _get_ref_and_transforms(sem_or_view) value = jnp.asarray(value, dtype=jnp.int32) args = [ref, transforms, value, decrement] diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 77e157201107..90a61aeb4302 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -45,6 +45,14 @@ def cdiv(a: jax_typing.Array, b: jax_typing.Array) -> jax_typing.Array: ... def cdiv(a: int | jax_typing.Array, b: int | jax_typing.Array) -> int | jax_typing.Array: + """Computes the ceiling division of a divided by b. + + Examples: + >>> cdiv(8, 2) + 4 + >>> cdiv(9, 2) # 9 / 2 = 4.5, which rounds up to 5 + 5 + """ if isinstance(a, int) and isinstance(b, int): return (a + b - 1) // b return lax.div(a + (b - 1), b) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 56d2ae8e868f..3b94cd5e705e 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -1079,6 +1079,18 @@ def _addupdate_vmap(axis_data, batched_args, batched_dims, *, tree): broadcast_to_p = core.Primitive('broadcast_to') def broadcast_to(a: Array, shape: tuple[int, ...]) -> Array: + """Broadcasts an array to a new shape. + + Args: + a: The array to broadcast. + shape: The desired shape to broadcast to. + + Returns: + An array of shape ``shape``. + + See Also: + :func:`jax.numpy.broadcast_to` + """ import jax.numpy as jnp # pytype: disable=import-error a = jnp.asarray(a) if a.shape == shape: From 4d9ff5bd390d2695d3e0a6ae3c57f0f0e826efd1 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Tue, 2 Dec 2025 18:16:11 -0800 Subject: [PATCH 007/315] Disable some JAX tests on TPU in open source. PiperOrigin-RevId: 839505833 --- tests/multiprocess/pjit_test.py | 3 +++ tests/pallas/ops_test.py | 3 +++ tests/pjit_test.py | 3 +++ tests/python_callback_test.py | 22 ++++++++++++++++++++++ 4 files changed, 31 insertions(+) diff --git a/tests/multiprocess/pjit_test.py b/tests/multiprocess/pjit_test.py index 79c0721ab66b..c25e2998ecb3 100644 --- a/tests/multiprocess/pjit_test.py +++ b/tests/multiprocess/pjit_test.py @@ -527,6 +527,9 @@ def f(x): self.assertEqual(output(), "") def test_print_in_multihost_shard_map(self): + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + devices = jax.devices() mesh = jax.sharding.Mesh(devices, ("i",)) num_devices = jax.local_device_count() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 12ff66b0b77f..c6b3d9e7db6e 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -2764,6 +2764,9 @@ class OpsInterpretTest(OpsTest): INTERPRET = True def test_debug_print(self): + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index fa5f11f1c5de..aa1a6d692200 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4240,6 +4240,9 @@ def test_in_out_shardings_unconstrained_error(self): in_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'x'))) def test_empty_io_callback_under_shard_map(self): + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + mesh = jtu.create_mesh((4,), 'i') def empty_callback(x): diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 3a70b08ea912..128632a992d3 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -53,6 +53,7 @@ ) +@unittest.skipIf(jtu.is_cloud_tpu(), "TODO: b/465504705") class PythonCallbackTest(jtu.JaxTestCase): def setUp(self): @@ -669,6 +670,7 @@ def f(x): np.testing.assert_array_equal(x, result) +@unittest.skipIf(jtu.is_cloud_tpu(), "TODO: b/465504705") class PureCallbackTest(jtu.JaxTestCase): def setUp(self): @@ -1150,6 +1152,9 @@ def tearDown(self): dispatch.runtime_tokens.clear() def test_io_callback_can_mutate_state(self): + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + x = 0 def cb(): nonlocal x @@ -1166,6 +1171,9 @@ def f(): self.assertEqual(x, 2) def test_io_callback_can_be_batched_if_unordered(self): + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + _mut = 0 def cb(x): nonlocal _mut @@ -1274,6 +1282,9 @@ def f(x, y): def test_can_use_io_callback_in_pjit( self, *, ordered: bool, with_sharding: bool ): + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + devices = jax.devices() mesh = jax.sharding.Mesh(np.array(devices), ['dev']) @@ -1334,6 +1345,9 @@ def f(x): @jtu.ignore_warning(message='.*Please use `jax.jit` instead.*', category=DeprecationWarning) def test_sequence_pjit_io_callback_ordered(self): + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + if jtu.is_device_tpu(7, 'x'): self.skipTest('TODO(b/453664256): Failing on TPU 7x.') @@ -1395,6 +1409,8 @@ def f_base(i, x): single_device=True) ) def test_can_shard_io_callback_manually(self, single_device: bool): + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") devices = jax.devices() if single_device: @@ -1429,6 +1445,9 @@ def f(shard_ids, x): def test_batching_with_side_effects(self): # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195 + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + x_lst = [] def append_x(x): nonlocal x_lst @@ -1445,6 +1464,9 @@ def f(x): def test_batching_with_side_effects_while_loop(self): # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219 + if jtu.is_cloud_tpu(): + self.skipTest("TODO: b/465504705") + x_lst = [] def append_x(x): nonlocal x_lst From 4592cfb33b875cce88d4fc1881e29bc928cc2e86 Mon Sep 17 00:00:00 2001 From: Zac Cranko Date: Tue, 2 Dec 2025 18:36:15 -0800 Subject: [PATCH 008/315] Fix footnote typo, move `jax.sharding.Mesh` to `jax.make_mesh` PiperOrigin-RevId: 839512055 --- docs/the-training-cookbook.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/the-training-cookbook.rst b/docs/the-training-cookbook.rst index 3f8018f41490..4059ec5f8ed2 100644 --- a/docs/the-training-cookbook.rst +++ b/docs/the-training-cookbook.rst @@ -109,7 +109,7 @@ Examining the call signature of the function ``adam_apply`` gives us a hint: .. tagged-block:: the-training-cookbook.py adam-apply -Because ``train_state.params`` is the first argument, :func:`jax.tree.map` uses its tree structure to guide the mapping process.[#prefix_tree]_ This means that ``train_state.opt`` is traversed only as deep as the leaves of ``train_state.params``. The optimizer state for each parameter is therefore passed in as a complete subtree, which allows us to easily access all relevant states (like ``mu`` and ``nu``) for a given ``param`` inside ``adam_apply``. +Because ``train_state.params`` is the first argument, :func:`jax.tree.map` uses its tree structure to guide the mapping process. [#prefix_tree]_ This means that ``train_state.opt`` is traversed only as deep as the leaves of ``train_state.params``. The optimizer state for each parameter is therefore passed in as a complete subtree, which allows us to easily access all relevant states (like ``mu`` and ``nu``) for a given ``param`` inside ``adam_apply``. .. tip:: @@ -286,7 +286,7 @@ The drawback of data-parallel sharding is that we have to keep multiple, full, r .. code-block:: python - mesh = jax.sharding.Mesh(jax.devices(), ('fsdp',)) + mesh = jax.make_mesh((128*4,), ("fsdp",)) *Parameter Shardings:* @@ -323,8 +323,8 @@ If our model is large enough and structured appropriately, it becomes beneficial *Mesh:* .. code-block:: python - - mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(128, 4), ("fsdp", "tensor")) + + mesh = jax.make_mesh((128,4), ("fsdp", "tensor")) *Parameter Shardings:* From de608ee48af40a9e23653c140ecbea898e58c7e1 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Tue, 2 Dec 2025 18:55:48 -0800 Subject: [PATCH 009/315] Update `rules_ml_toolchain` to disable `--@cuda_driver//:include_cuda_umd_libs` by default. PiperOrigin-RevId: 839517818 --- .bazelrc | 4 ++-- .github/workflows/bazel_cuda_h100_b200.yml | 2 ++ WORKSPACE | 6 +++--- ci/run_bazel_test_cuda_non_rbe.sh | 2 ++ ci/run_bazel_test_cuda_rbe.sh | 1 + 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index c5ef46f4e406..7a6ea6c44fa1 100644 --- a/.bazelrc +++ b/.bazelrc @@ -191,7 +191,6 @@ common:cuda_common --repo_env TF_NEED_CUDA=1 common:cuda_common --repo_env TF_NCCL_USE_STUB=1 common:cuda_common --@local_config_cuda//:enable_cuda common:cuda_common --@local_config_cuda//cuda:include_cuda_libs=true -common:cuda_common --@cuda_driver//:include_cuda_umd_libs=true # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to @@ -218,7 +217,8 @@ common:cuda --config=cuda12 # This config is used for building targets with CUDA/NVSHMEM libraries from stubs. common:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false -common:cuda_libraries_from_stubs --@cuda_driver//:include_cuda_umd_libs=false + +common:hermetic_cuda_umd --@cuda_driver//:include_cuda_umd_libs=true # common CUDA and other C++ targets with Clang common:build_cuda_with_clang --@local_config_cuda//:cuda_compiler=clang diff --git a/.github/workflows/bazel_cuda_h100_b200.yml b/.github/workflows/bazel_cuda_h100_b200.yml index 5ab758a5beae..25829ac8cf7e 100644 --- a/.github/workflows/bazel_cuda_h100_b200.yml +++ b/.github/workflows/bazel_cuda_h100_b200.yml @@ -76,6 +76,7 @@ jobs: bazel test \ --config=ci_linux_x86_64_cuda \ --config=ci_rbe_cache \ + --config=hermetic_cuda_umd \ --repo_env=HERMETIC_PYTHON_VERSION="3.14" \ --repo_env=HERMETIC_CUDNN_VERSION="9.11.0" \ --repo_env=HERMETIC_CUDA_UMD_VERSION="13.0.0" \ @@ -120,6 +121,7 @@ jobs: bazel test \ --config=ci_linux_x86_64_cuda \ --config=ci_rbe_cache \ + --config=hermetic_cuda_umd \ --repo_env=HERMETIC_PYTHON_VERSION="3.14" \ --repo_env=HERMETIC_CUDNN_VERSION="9.11.0" \ --repo_env=HERMETIC_CUDA_UMD_VERSION="13.0.0" \ diff --git a/WORKSPACE b/WORKSPACE index d3e5043da2ba..8d4ba31a0539 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -18,9 +18,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # Details: https://github.com/google-ml-infra/rules_ml_toolchain tf_http_archive( name = "rules_ml_toolchain", - sha256 = "8123d826b0a4c5ceda767abc8092419fcc980c3ce45fb0f438b101fb886c014c", - strip_prefix = "rules_ml_toolchain-552b53a04a86fd5cdb4d5091e7420411d8b2a045", - urls = tf_mirror_urls("https://github.com/google-ml-infra/rules_ml_toolchain/archive/552b53a04a86fd5cdb4d5091e7420411d8b2a045.tar.gz"), + sha256 = "b1e5e306d8b1103e73b9b778dfc3a9e069d20664437a03246a235724962b5c94", + strip_prefix = "rules_ml_toolchain-484235be45e6843db962c45d08fe4b2b65a6a24c", + urls = tf_mirror_urls("https://github.com/google-ml-infra/rules_ml_toolchain/archive/484235be45e6843db962c45d08fe4b2b65a6a24c.tar.gz"), ) load( diff --git a/ci/run_bazel_test_cuda_non_rbe.sh b/ci/run_bazel_test_cuda_non_rbe.sh index 25e3f97468c7..98d5955d60cb 100755 --- a/ci/run_bazel_test_cuda_non_rbe.sh +++ b/ci/run_bazel_test_cuda_non_rbe.sh @@ -122,6 +122,7 @@ bazel test --config=$TEST_CONFIG \ --action_env=NCCL_DEBUG=WARN \ --color=yes \ --config=cuda_libraries_from_stubs \ + --config=hermetic_cuda_umd \ //tests:gpu_tests //tests:backend_independent_tests \ //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests @@ -147,6 +148,7 @@ bazel test --config=$TEST_CONFIG \ --action_env=NCCL_DEBUG=WARN \ --color=yes \ --config=cuda_libraries_from_stubs \ + --config=hermetic_cuda_umd \ //tests:gpu_tests //tests/pallas:gpu_tests \ //tests/multiprocess:gpu_tests diff --git a/ci/run_bazel_test_cuda_rbe.sh b/ci/run_bazel_test_cuda_rbe.sh index 8aaee0505c0f..572cb132bcf0 100755 --- a/ci/run_bazel_test_cuda_rbe.sh +++ b/ci/run_bazel_test_cuda_rbe.sh @@ -67,6 +67,7 @@ bazel test --config=rbe_linux_x86_64_cuda${JAXCI_CUDA_VERSION} \ --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ $cuda_libs_flag \ + --config=hermetic_cuda_umd \ --//jax:build_jaxlib=$JAXCI_BUILD_JAXLIB \ --//jax:build_jax=$JAXCI_BUILD_JAX \ //tests:gpu_tests //tests:backend_independent_tests \ From cafc0325f5ebcaad4f5a6cfc533e7e0e19c19b28 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 2 Dec 2025 19:26:13 -0800 Subject: [PATCH 010/315] Fix pjit_test on tpu v5_x4 PiperOrigin-RevId: 839528083 --- tests/pjit_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index aa1a6d692200..e5f6f53fc658 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6450,11 +6450,12 @@ def g(x): out2 = core.jaxpr_as_fun(jaxpr)(arr) self.assertEqual(out2[0].sharding, NamedSharding(mesh, P('x', 'y'))) - @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + @jtu.with_explicit_mesh((2, 2), ('x', 'y'), iota_order=True) def test_device_put_different_dst_mesh(self, mesh): np1 = np.arange(16).reshape(8, 2) x = jax.device_put(np1, P('x', 'y')) - mesh2 = jax.make_mesh((4,), ('a',), axis_types=(AxisType.Explicit,)) + mesh2 = jtu.create_mesh((4,), ('a',), axis_types=(AxisType.Explicit,), + iota_order=True) y = jax.device_put(x, NamedSharding(mesh2, P('a', None))) self.assertEqual(y.sharding, NamedSharding(mesh2, P('a', None))) self.assertArraysEqual(y, np1) From 7b342dc925f2b8b2162f37e44b7e834e6d750c39 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 2 Dec 2025 19:53:00 -0800 Subject: [PATCH 011/315] Export `reshard` as `jax.reshard` PiperOrigin-RevId: 839537115 --- jax/__init__.py | 1 + tests/documentation_coverage_test.py | 2 +- tests/pjit_test.py | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 945fb9f46374..936b7f914377 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -136,6 +136,7 @@ from jax._src.sharding_impls import make_mesh as make_mesh from jax._src.sharding_impls import set_mesh as set_mesh from jax._src.partition_spec import P as P +from jax._src.pjit import reshard as reshard from jax._src.shard_map import shard_map as shard_map from jax._src.shard_map import smap as smap diff --git a/tests/documentation_coverage_test.py b/tests/documentation_coverage_test.py index 21822a7f4218..83ae55a7423c 100644 --- a/tests/documentation_coverage_test.py +++ b/tests/documentation_coverage_test.py @@ -51,7 +51,7 @@ def jax_docs_dir() -> str: UNDOCUMENTED_APIS = { - 'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'version'], + 'jax': ['NamedSharding', 'P', 'Ref', 'Shard', 'reshard', 'ad_checkpoint', 'api_util', 'checkpoint_policies', 'core', 'custom_derivatives', 'custom_transpose', 'debug_key_reuse', 'device_put_replicated', 'device_put_sharded', 'effects_barrier', 'example_libraries', 'explain_cache_misses', 'experimental', 'extend', 'float0', 'freeze', 'fwd_and_bwd', 'host_count', 'host_id', 'host_ids', 'interpreters', 'jax', 'jax2tf_associative_scan_reductions', 'legacy_prng_key', 'lib', 'make_user_context', 'new_ref', 'no_execution', 'numpy_dtype_promotion', 'remat', 'remove_size_one_mesh_axis_from_type', 'softmax_custom_jvp', 'threefry_partitionable', 'tools', 'transfer_guard_device_to_device', 'transfer_guard_device_to_host', 'transfer_guard_host_to_device', 'version'], 'jax.ad_checkpoint': ['checkpoint', 'checkpoint_policies', 'print_saved_residuals', 'remat', 'Offloadable', 'Recompute', 'Saveable'], 'jax.custom_batching': ['custom_vmap', 'sequential_vmap'], 'jax.custom_derivatives': ['CustomVJPPrimal', 'SymbolicZero', 'closure_convert', 'custom_gradient', 'custom_jvp', 'custom_jvp_call_p', 'custom_vjp', 'custom_vjp_call_p', 'custom_vjp_primal_tree_values', 'linear_call', 'remat_opt_p', 'zero_from_primal'], diff --git a/tests/pjit_test.py b/tests/pjit_test.py index e5f6f53fc658..94daf7c9cc44 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -29,6 +29,7 @@ import jax import jax.numpy as jnp +from jax import reshard from jax._src import core from jax._src import config from jax._src import dispatch @@ -42,7 +43,7 @@ from jax.lax import with_sharding_constraint from jax._src import prng from jax.sharding import (PartitionSpec as P, Mesh, auto_axes, explicit_axes, - reshard, AbstractDevice) + AbstractDevice) from jax.experimental import multihost_utils from jax._src.shard_map import shard_map from jax._src.compilation_cache import is_persistent_cache_enabled From d23905a7f8332553a59ec1ecd8a8972edca46d43 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 2 Dec 2025 21:43:26 -0800 Subject: [PATCH 012/315] Allow colocated python objects to use reference-counting to understand when the object can be destructed at the backend. Reverts bcc9102cecf823f85b5a30902fcb411876588ef6 PiperOrigin-RevId: 839574881 --- jax/experimental/BUILD | 2 + jax/experimental/colocated_python/func.py | 333 ++++++++++++++-------- jax/experimental/colocated_python/obj.py | 78 ++++- 3 files changed, 286 insertions(+), 127 deletions(-) diff --git a/jax/experimental/BUILD b/jax/experimental/BUILD index f9c4de048eb5..1038238db600 100644 --- a/jax/experimental/BUILD +++ b/jax/experimental/BUILD @@ -117,11 +117,13 @@ pytype_strict_library( "//jax", "//jax/_src:api", "//jax/_src:api_util", + "//jax/_src:config", "//jax/_src:traceback_util", "//jax/_src:tree_util", "//jax/_src:util", "//jax/_src:xla_bridge", "//jax/_src/lib", + "//jax/extend:backend", "//jax/extend:ifrt_programs", ] + py_deps("numpy") + py_deps("cloudpickle"), ) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index 27f8f31f70cb..a7fe5a52ba39 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -15,12 +15,13 @@ from __future__ import annotations +from collections.abc import Callable, Sequence import dataclasses +import functools import inspect import random import threading from typing import Any -from collections.abc import Callable, Sequence import jax from jax._src import api @@ -31,7 +32,8 @@ from jax._src.traceback_util import api_boundary from jax._src.util import wraps from jax.experimental.colocated_python import func_backend -from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs +from jax.experimental.colocated_python.serialization import _deserialize, _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs +from jax.extend.backend import register_backend_cache as jax_register_backend_cache from jax.extend.ifrt_programs import ifrt_programs ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct] @@ -186,12 +188,14 @@ def call(*args, **kwargs): # TODO(hyeontaek): Implement colocated Python support in McJAX and remove # this fallback path. if "PjRtCompiler requires an HloProgram" in str(e): - return fun + return _deserialize(pickled_function)[0] raise def _make_output_specs_and_push_result_fun( - info: FunctionInfo, specialization: Specialization, uid: int + info: FunctionInfo, + specialization: Specialization, + uid: int, ) -> Callable[..., Any]: """Creates a function that computes output specs and pushes the result to the result store.""" assert specialization.in_specs_treedef is not None @@ -226,7 +230,9 @@ def lowered_fun(*args, **kwargs) -> jax.Array: def _make_pop_result_fun( - info: FunctionInfo, specialization: Specialization, uid: int + info: FunctionInfo, + specialization: Specialization, + uid: int, ) -> Callable[..., Any]: """Makes a function that pops results from the result store.""" assert specialization.out_specs_treedef is not None @@ -259,7 +265,8 @@ def lowered_fun(): def _make_async_execution_fun( - info: FunctionInfo, specialization: Specialization + info: FunctionInfo, + specialization: Specialization, ) -> Callable[..., Any]: """Makes a function that asynchronously executes the function.""" assert specialization.in_specs_treedef is not None @@ -280,9 +287,9 @@ def _make_async_execution_fun( ) -@jax._src.util.cache(max_size=None) -def _get_specialized_func( - info: FunctionInfo, specialization: Specialization +def _uncached_get_specialized_func( + info: FunctionInfo, + specialization: Specialization, ) -> Callable[..., Any]: """Returns a specialized function for the given specialization.""" util.test_event("colocated_python_func._get_specialized_func") @@ -302,9 +309,14 @@ def specialized_func(*args, **kwargs): if async_execution_func is None: if specialization.out_specs_treedef is None: if specialization.out_specs_fn is None: - serialized_out_specs = _make_output_specs_and_push_result_fun( - info, specialization, uid - )(*args, **kwargs) + output_specs_and_push_result_fun = ( + _make_output_specs_and_push_result_fun( + info, specialization, uid + ) + ) + serialized_out_specs = output_specs_and_push_result_fun( + *args, **kwargs + ) # Waits for the output_specs. This may block. out_specs_treedef, out_specs_leaves = _deserialize_specs( @@ -321,6 +333,13 @@ def specialized_func(*args, **kwargs): info, specialization ) + # Hold the PyExecutable until async_execution_fun is called at + # least once, so the number of _OBJECT_STORE references at the + # backend does not drop to 0. + async_execution_func.output_specs_and_push_result_fun = ( + output_specs_and_push_result_fun + ) + return _make_pop_result_fun(info, specialization, uid)() else: # Compute out_specs using out_specs_fn and inputs. @@ -348,122 +367,210 @@ def specialized_func(*args, **kwargs): # Asynchronous execution runs outside of the mutex to allow concurrent # execution for inline executors. - return async_execution_func(*args, **kwargs) + result = async_execution_func(*args, **kwargs) + with mutex: + async_execution_func.output_specs_and_push_result_fun = None + return result return specialized_func -def make_callable( - fun: Callable[..., Any], - fun_sourceinfo: str | None, - fun_signature: inspect.Signature | None, -): - """Makes a colocated Python callable.""" - return _make_callable( - FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() - ) +class _CachedGetSpecializedFunction: + """Manages cached versions of `_uncached_get_specialized_func`. + This class holds a collection of caches, each identified by a unique ID, and + presents itself as a single cache to JAX's `register_backend_cache`. One can + clear individual caches identified by the UID, using the `cache_remove(uid)` + method. JAX's `clear_backend_cache()` will clear all caches. + """ -def _make_callable(info: FunctionInfo, specialization: Specialization): - """Internal implementation of make_callable.""" + def __init__(self): + self._lock = threading.Lock() + self._caches: dict[int, Any] = {} + jax_register_backend_cache(self, "colocated_python_specialized_func_cache") - def specialize( - in_specs: ShapeDtypeStructTree | None = None, - out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, - devices: Sequence[jax.Device] | None = None, - ): - """Returns a colocated Python callable with extra specialization. - - Args: - in_specs: Optionally specifies the expected input specs. Input specs are - expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a - function call. - out_specs_fn: Optionally specifies a function that computes the output - specs from input specs. If unspecified, colocated Python will compute - the output specs during the very first execution, and this execution - will be synchronous. - devices: Optionally specifies the devices to execute the function on. Must - be provided if `in_specs` has no leaves because devices cannot be - inferred from input specs or arguments. - - Returns: - A colocated Python callable with extra specialization. - """ - # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if - # `out_specs_fn(in_specs)` returns at least one leaf that we can use for - # inferring `devices`. - if in_specs is None: - in_specs_leaves, in_specs_treedef = None, None + def cache_clear(self): + self._caches.clear() + + def cache_remove(self, held_by: int): + try: + self._caches.pop(held_by) + except KeyError: + pass + + def get(self, held_by: int) -> Callable[..., Any]: + with self._lock: + try: + return self._caches[held_by] + except KeyError: + cache = functools.cache(_uncached_get_specialized_func) + self._caches[held_by] = cache + return cache + + +_SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION = _CachedGetSpecializedFunction() + + +class _CachedColocatedFunctionMaker: + """Function maker for colocated Python functions. + + Generated functions are stored (cached) indefinitely so that they can be + reused, until the cache is dropped. + """ + + def __init__(self, held_by: int | None): + self.held_by = held_by + if held_by is None: + self._get_specialized_func = jax._src.util.cache( + max_size=None, trace_context_in_key=False + )(_uncached_get_specialized_func) else: - in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(in_specs) - in_specs_leaves = tuple(in_specs_leaves_list) - return _make_callable( - info, - specialization.update( - in_specs_treedef=in_specs_treedef, - in_specs_leaves=in_specs_leaves, - out_specs_fn=out_specs_fn, - devices=devices, - ), - ) + self._get_specialized_func = ( + _SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION.get(held_by) + ) + + def __del__(self): + if self.held_by is not None: + _SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION.cache_remove(self.held_by) - @api_boundary - def __call__(*args, **kwargs): - """Executes the given Python function on the same devices as the arguments or as specialized. - - If the callable has not been specialized with output shapes and shardings - (see `specialize` above), the very first call will run synchronously to - discover output shapes and shardings, and will run asynchronously after. If - specialized with output shapes and shardings, every execution of the - callable will be asynchronous. - """ - args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs)) - - in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) - if specialization.in_specs_treedef is None: - # Allow input polymorphism by applying input_specs specialization - # temporarily for this call. - return _make_callable( + def _make_callable( + self, + info: FunctionInfo, + specialization: Specialization, + ): + """Internal implementation of make_callable.""" + + def specialize( + in_specs: ShapeDtypeStructTree | None = None, + out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None, + devices: Sequence[jax.Device] | None = None, + ): + """Returns a colocated Python callable with extra specialization. + + Args: + in_specs: Optionally specifies the expected input specs. Input specs are + expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a + function call. + out_specs_fn: Optionally specifies a function that computes the output + specs from input specs. If unspecified, colocated Python will compute + the output specs during the very first execution, and this execution + will be synchronous. + devices: Optionally specifies the devices to execute the function on. + Must be provided if `in_specs` has no leaves because devices cannot be + inferred from input specs or arguments. + + Returns: + A colocated Python callable with extra specialization. + """ + # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if + # `out_specs_fn(in_specs)` returns at least one leaf that we can use for + # inferring `devices`. + if in_specs is None: + in_specs_leaves, in_specs_treedef = None, None + else: + in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten( + in_specs + ) + in_specs_leaves = tuple(in_specs_leaves_list) + return self._make_callable( info, specialization.update( in_specs_treedef=in_specs_treedef, in_specs_leaves=in_specs_leaves, + out_specs_fn=out_specs_fn, + devices=devices, ), - )(*args, **kwargs) + ) - if specialization.devices is None: - devices = _infer_devices_from_args(args_leaves) - if devices is None: + @api_boundary + def __call__(*args, **kwargs): + """Executes the given Python function on the same devices as the arguments or as specialized. + + If the callable has not been specialized with output shapes and shardings + (see `specialize` above), the very first call will run synchronously to + discover output shapes and shardings, and will run asynchronously after. + If + specialized with output shapes and shardings, every execution of the + callable will be asynchronous. + """ + args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs)) + + in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) + if specialization.in_specs_treedef is None: + # Allow input polymorphism by applying input_specs specialization + # temporarily for this call. + return self._make_callable( + info, + specialization.update( + in_specs_treedef=in_specs_treedef, + in_specs_leaves=in_specs_leaves, + ), + )(*args, **kwargs) + + if specialization.devices is None: + devices = _infer_devices_from_args(args_leaves) + if devices is None: + raise ValueError( + "No devices found. colocated_python function without input" + " arguments must be first specialized with devices." + ) + # Allow device polymorphism by applying devices specialization temporarily + # for this call. + return self._make_callable( + info, + specialization.update(devices=devices), + )(*args, **kwargs) + + # Assertion is added to silence mypy error: Unsupported operand types for != + # ("PyTreeDef" and "None") [operator] + assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef) + + # If input_specs is known, verify that it matches actual inputs. + if ( + specialization.in_specs_treedef != in_specs_treedef + or specialization.in_specs_leaves != in_specs_leaves + ): raise ValueError( - "No devices found. colocated_python function without input" - " arguments must be first specialized with devices." + "Input specs in specialization and input specs of arguments must" + " have the same pytree structure, but they have the following" + " structural differences:\n" + + ( + "\n".join( + f" - {tree_util.keystr(path)} is a {thing1} in value 1" + f" and a {thing2} in value 2, so {explanation}.\n" + for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef( + specialization.in_specs_treedef, in_specs_treedef + ) + ) + ) ) - # Allow device polymorphism by applying devices specialization temporarily - # for this call. - return _make_callable(info, specialization.update(devices=devices))( - *args, **kwargs - ) - # Assertion is added to silence mypy error: Unsupported operand types for != - # ("PyTreeDef" and "None") [operator] - assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef) - - # If input_specs is known, verify that it matches actual inputs. - if (specialization.in_specs_treedef != in_specs_treedef - or specialization.in_specs_leaves != in_specs_leaves): - raise ValueError( - "Input specs in specialization and input specs of arguments must have" - " the same pytree structure, but they have the following structural" - " differences:\n" - + ("\n".join( - f" - {tree_util.keystr(path)} is a {thing1} in value 1 and" - f" a {thing2} in value 2, so {explanation}.\n" - for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef( - specialization.in_specs_treedef, in_specs_treedef - )))) - - return _get_specialized_func(info, specialization)(*args, **kwargs) - - __call__ = wraps(info.fun)(__call__) - __call__.specialize = specialize - return __call__ + return self._get_specialized_func(info, specialization)(*args, **kwargs) + + __call__ = wraps(info.fun)(__call__) + __call__.specialize = specialize + return __call__ + + def make_callable( + self, + fun: Callable[..., Any], + fun_sourceinfo: str | None, + fun_signature: inspect.Signature | None, + ): + """Makes a colocated Python callable.""" + return self._make_callable( + FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization() + ) + + +_DEFAULT_FUNCTION_MAKER = _CachedColocatedFunctionMaker(None) + + +def make_callable( + fun: Callable[..., Any], + fun_sourceinfo: str | None, + fun_signature: inspect.Signature | None, +): + return _DEFAULT_FUNCTION_MAKER.make_callable( + fun, fun_sourceinfo, fun_signature + ) diff --git a/jax/experimental/colocated_python/obj.py b/jax/experimental/colocated_python/obj.py index 2351acd0d096..b804cd836a12 100644 --- a/jax/experimental/colocated_python/obj.py +++ b/jax/experimental/colocated_python/obj.py @@ -15,14 +15,15 @@ from __future__ import annotations +from collections.abc import Callable import inspect import random import threading from typing import Any -from collections.abc import Callable import jax from jax._src import api_util +from jax._src import config from jax._src import tree_util from jax._src.traceback_util import api_boundary from jax._src.util import wraps @@ -30,6 +31,20 @@ from jax.experimental.colocated_python import obj_backend +# TODO(madthanu): Remove the following config option and make its behavior the +# default, once the behavior has been declared stable. +_USE_WEAKREFS = config.bool_state( + 'jax_experimental_colocated_python_object_use_weakrefs_at_backend', + False, + help=( + 'Unstable in-development feature that switches the colocated-python' + ' implementation to internally use reference counting for destructing' + ' objects at the colocated backend, instead of invoking an explicit' + ' delete-object function from the frontend.' + ), +) + + class _InstanceRegistry: """Registry of object instances.""" @@ -78,19 +93,50 @@ def _make_method( init_kwargs: dict[str, Any], method_name: str, original_method: Callable[..., Any], + func_maker: func._CachedColocatedFunctionMaker, + use_weakrefs: bool, ): - # Initializer to use when the object is not present in the backend. - def initializer() -> object: - return cls(*init_args, **init_kwargs) - # Method to call on the backend. - def method(*args, **kwargs): - obj = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(uid, initializer) - return getattr(obj, method_name)(*args, **kwargs) + class MethodCallerAtBackend: + + def __init__(self): + self._lock = threading.Lock() + + def __reduce__(self): + return type(self), () + + def _first_call(self): + # Temporarily hold a strong reference to a new object if it is created + # using initializer. + new_obj = None + + def initializer(): + nonlocal new_obj + new_obj = cls(*init_args, **init_kwargs) + if use_weakrefs: + import weakref + + return weakref.ref(new_obj) + return new_obj + + retrieved = obj_backend.SINGLETON_OBJECT_STORE.get_or_create( + uid, initializer + ) + + if use_weakrefs: + self.obj = retrieved() + else: + self.obj = retrieved + + def __call__(self, *args, **kwargs): + with self._lock: + if not hasattr(self, 'obj'): + self._first_call() + return getattr(self.obj, method_name)(*args, **kwargs) # Colocated Python callable for the controller. - callable = func.make_callable( - method, + callable = func_maker.make_callable( + MethodCallerAtBackend(), cls_sourceinfo, api_util.fun_signature(original_method), ) @@ -143,6 +189,8 @@ def __init__(self, *init_args, **init_kwargs) -> None: uid = self._colocated_python_uid = ( SINGLETON_INSTANCE_REGISTRY.new_instance() ) + self.func_maker = func._CachedColocatedFunctionMaker(uid) + self.use_weakrefs = _USE_WEAKREFS.value for attr_name in dir(cls): original_member = getattr(cls, attr_name) if not inspect.isfunction(original_member): @@ -162,12 +210,17 @@ def __init__(self, *init_args, **init_kwargs) -> None: init_kwargs, attr_name, original_member, + self.func_maker, + self.use_weakrefs, ) # TODO(hyeontaek): Support method specialization similar to function # specialization. setattr(self, attr_name, method) - def __del__(self) -> None: + def __del__(self): + del self.func_maker + if self.use_weakrefs: + return uid = self._colocated_python_uid devices = SINGLETON_INSTANCE_REGISTRY.pop_instance(uid) if devices: @@ -175,9 +228,6 @@ def __del__(self) -> None: def remove_object() -> None: obj_backend.SINGLETON_OBJECT_STORE.remove(uid) - # TODO(hyeontaek): Request "best-effort" non-SPMD execution that tries - # to run this function on any healthy processes instead of failing when - # any process of the execution is unhealthy. destructor = func.make_callable( remove_object, cls_sourceinfo, From 88d98bdb3cb08a9728c6c85789dee782e15afd0c Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 3 Dec 2025 00:17:57 -0800 Subject: [PATCH 013/315] Fix the jax line search, the zoom may only store two points in lo, hi and rec. PiperOrigin-RevId: 839622009 --- jax/_src/scipy/optimize/line_search.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/_src/scipy/optimize/line_search.py b/jax/_src/scipy/optimize/line_search.py index 9bffd95328fb..6d16b67f1c66 100644 --- a/jax/_src/scipy/optimize/line_search.py +++ b/jax/_src/scipy/optimize/line_search.py @@ -191,6 +191,16 @@ def body(state): ), ), ) + state = state._replace( + **_binary_replace( + lo_to_j & ~hi_to_lo, + state._asdict(), + dict( + a_rec=state.a_lo, + phi_rec=state.phi_lo, + ), + ), + ) state = state._replace( **_binary_replace( lo_to_j, @@ -199,8 +209,6 @@ def body(state): a_lo=a_j, phi_lo=phi_j, dphi_lo=dphi_j, - a_rec=state.a_lo, - phi_rec=state.phi_lo, ), ), ) From 403522ded1be5cba5714cbbad0cce965fd65a2b0 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 3 Dec 2025 02:04:56 -0800 Subject: [PATCH 014/315] [Mosaic GPU] Add support for reductions when writing to GMEM via TMA in WG semantics. The new logic supports all reductions, though the underlying `async_copy` only supports `add` for now, so I only added a test for that. PiperOrigin-RevId: 839655895 --- jax/_src/pallas/mosaic_gpu/primitives.py | 8 +++ .../mosaic/gpu/dialect_lowering.py | 6 +++ jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 22 +++++++- tests/mosaic/gpu_test.py | 53 +++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 1 - 5 files changed, 88 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index b0ed9f5d0848..78e9302a40e2 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -248,6 +248,13 @@ def _copy_smem_to_gmem_lowering( "GMEM refs with peer ids are not supported in warpgroup lowering." ) assert not copy_params.get("gmem_transform") + if reduction_op is not None: + reduction_op_attr = getattr( + mgpu.dialect.TMAReduction, reduction_op.capitalize() + ) + else: + reduction_op_attr = None + mgpu.dialect.async_store( src, dst, @@ -255,6 +262,7 @@ def _copy_smem_to_gmem_lowering( slice_lengths, predicate=predicate, commit_group=commit_group, # type: ignore[call-arg] + reduction_op=reduction_op_attr, ) return () diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 0492bf0a13dd..0baf0d90ef6c 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -991,6 +991,11 @@ def _mgpu_async_store_op_lowering_rule( # flatten -> async_copy -> unflatted here, as long as flattened size is a # multiple of 16. + if store_op.reduction_op is not None: + reduction_op = mgpu.TMAReduction(store_op.reduction_op.value).name.lower() + else: + reduction_op = None + # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=unwrapped_source, @@ -1000,6 +1005,7 @@ def _mgpu_async_store_op_lowering_rule( gmem_transform=transforms, predicate=ctx.single_thread_per_warpgroup_predicate, arrive=store_op.commit_group, + reduction_op=reduction_op, ) return [] diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 9cc2968ca8b0..ee262ff05400 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -195,6 +195,21 @@ def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode", let cppNamespace = "::mosaic_gpu"; } +def MosaicGPU_TMAReduction : I32EnumAttr<"TMAReduction", + "Reduction operation for TMA.", + [ + I32EnumAttrCase<"Add", 0, "add">, + I32EnumAttrCase<"Min", 1, "min">, + I32EnumAttrCase<"Max", 2, "max">, + I32EnumAttrCase<"Inc", 3, "inc">, + I32EnumAttrCase<"Dec", 4, "dec">, + I32EnumAttrCase<"And", 5, "and">, + I32EnumAttrCase<"Or", 6, "or">, + I32EnumAttrCase<"Xor", 7, "xor"> + ]>{ + let cppNamespace = "::mosaic_gpu"; +} + def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> { let parameters = (ins ArrayRefParameter<"int32_t", "tiling">:$tiling); let summary = "Specifies a transform that tiles suffix dimensions of a memref in SMEM."; @@ -345,6 +360,10 @@ def MosaicGPU_AsyncStoreOp : Op:$commit_group + DefaultValuedOptionalAttr:$commit_group, + OptionalAttr:$reduction_op ); let assemblyFormat = [{ diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 7b3cf8b18be6..dea06535ee1d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -5346,6 +5346,59 @@ def body( x = self.prng.uniform(0, 10, input_shape).astype(el_type) self.assertArraysEqual(kernel(x), x.reshape(output_shape)) + @parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.float16) + def test_async_store_add_reduction(self, dtype): + shape = (8, 128) + + def body(ctx, src, dst, smem): + del ctx + smem_ref, tma_barrier = smem + i32 = ir.IntegerType.get_signless(32) + zero = arith.constant(i32, 0) + indices = [zero, zero] + slice_lengths = smem_ref.type.shape + + tma_barrier.arrive_expect_tx( + utils.bitwidth(smem_ref.type.element_type) * math.prod(shape) // 8 + ) + + mgpu_dialect.async_load( + source=src, + destination=smem_ref, + barrier=tma_barrier.as_barrier_memref(), + indices=indices, + slice_lengths=slice_lengths, + collective=ir.ArrayAttr.get([]), + ) + + tma_barrier.wait() + + mgpu_dialect.async_store( + source=smem_ref, + destination=dst, + indices=indices, + slice_lengths=slice_lengths, + reduction_op=mgpu_dialect.TMAReduction.Add, + ) + nvvm.cp_async_bulk_wait_group(0) + + src = jnp.ones(shape, dtype=dtype) + dst = jnp.ones(shape, dtype=dtype) + + jax_shape = jax.ShapeDtypeStruct(shape, dtype) + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(jax_shape,), + out_shape=(), + inout_shape=(jax_shape,), + smem_scratch_shape=[jax_shape, core.TMABarrier(1)], + thread_semantics=mgpu.LoweringSemantics.Warpgroup, + ) + + np.testing.assert_array_equal(kernel(src, dst)[0], src + dst) + class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 241db8287310..35edac84c16b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -654,7 +654,6 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): @parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32) def test_copy_smem_to_gmem_reduction(self, dtype): - self.skip_if_wg_semantics() # Reduction not implemented. @functools.partial( self.pallas_call, grid=(200,), From 746e13ef02771ee9393886dee907e44dabf7e59d Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 3 Dec 2025 13:47:35 +0100 Subject: [PATCH 015/315] [export] Fix forward compatibility for exporting nr_devices In #33492 I have added a new 32-bit field to store the nr_devices, and renamed the old 16-bit field to nr_devices_short. However, that change mistakenly stops populating the 16-bit field, which means that when the export is read by old code, it will have a nr_devices == 0. --- jax/_src/export/serialization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 4dd4e5755ee4..f479a6222d7e 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -132,6 +132,7 @@ def _serialize_exported( ser_flatbuf.ExportedAddOutTree(builder, out_tree) ser_flatbuf.ExportedAddOutAvals(builder, out_avals) ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices) + ser_flatbuf.ExportedAddNrDevicesShort(builder, exp.nr_devices) # For forward compatibility, can remove after January 2026 ser_flatbuf.ExportedAddInShardings(builder, in_shardings) ser_flatbuf.ExportedAddOutShardings(builder, out_shardings) ser_flatbuf.ExportedAddPlatforms(builder, platforms) From b8dbb940596715bfaf6579e12d47aa230e97a90d Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 3 Dec 2025 06:02:34 -0800 Subject: [PATCH 016/315] [Pallas:MGPU] Fix Pallas `swap` lowering rule for scalar constants. PiperOrigin-RevId: 839724087 --- jax/_src/pallas/mosaic_gpu/lowering.py | 1 + tests/pallas/mosaic_gpu_test.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6b9a4e5b2a1b..e23416a566e5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1758,6 +1758,7 @@ def _swap_lowering_rule_wg( "Transforms are not yet implemented for warpgroup semantics" ) assert isinstance(x_smem, ir.Value) + value = _ensure_ir_value(value, ctx.avals_in[1].dtype) if shape: old_value = mgpu.dialect.vector_load(x_smem) mgpu.dialect.vector_store(value, x_smem) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 35edac84c16b..5d25674c94a9 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1322,6 +1322,17 @@ def kernel(x_ref, o_ref, scratch_ref, barrier_ref): x = jnp.arange(math.prod(shape), dtype=jnp.int32).reshape(shape) np.testing.assert_array_equal(kernel(x), x * 2) + def test_swap_scalar_constant(self): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((), jnp.int32), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + ) + def kernel(o_ref): + o_ref[...] = jnp.array(42) + + np.testing.assert_array_equal(kernel(), jnp.array(42, jnp.int32)) + def test_check(self): self.skip_if_wg_semantics() From 307bca0d03bd0b0d92f456d639f3fb5fb7f4ec8d Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Wed, 3 Dec 2025 07:24:24 -0800 Subject: [PATCH 017/315] [Mosaic] Support packed BroadcastInSublanesOp. PiperOrigin-RevId: 839747685 --- jaxlib/mosaic/dialect/tpu/tpu.td | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 5f66e7caf9ff..892032666299 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -624,9 +624,13 @@ def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { let description = [{ - For each sublane `i`, broadcasts the value in lane `lane + i` along the entire - sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i` - is not defined (can be anything). + For each sublane `i`, broadcasts the value in lane `lane + i` along the + entire sublane. For packed type, imagine the data is compressed unpacked + along sublane dimension, and the sublane count is multiplied by the packing + factor. + For example, for i16 with sublane count 8, `i` above is in [0, 8 * 2). + If `lane + i` is not in [0, lane_count), then the value in sublane `i` is + not defined (can be anything). }]; let arguments = (ins TPU_Vreg:$source, // All sublanes should be equal. From dd01b37c224d779c6e8e20425f2faff1fc92f6da Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 3 Dec 2025 08:52:19 -0800 Subject: [PATCH 018/315] Fix deprecated use of arrays in place of dtypes --- jax/_src/lax/lax.py | 2 +- jax/_src/numpy/reductions.py | 2 +- tests/api_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 6cb687d8f2bd..e06a842b733c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3485,7 +3485,7 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array: def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: """Like numpy.tri, create a 2D array with ones below a diagonal.""" offset = asarray(core.dimension_as_value(offset)) - if not dtypes.issubdtype(offset, np.integer): + if not dtypes.issubdtype(offset.dtype, np.integer): raise TypeError(f"offset must be an integer, got {offset!r}") shape_dtype = lax_utils.int_dtype_for_shape(shape, signed=True) if ( diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index ec56e1c0506b..05cc64bf1568 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -202,7 +202,7 @@ def _cast_to_numeric(operand: Array) -> Array: return promote_dtypes_numeric(operand)[0] def _require_integer(arr: Array) -> Array: - if not dtypes.isdtype(arr, ("bool", "integral")): + if not dtypes.isdtype(arr.dtype, ("bool", "integral")): raise ValueError(f"integer argument required; got dtype={arr.dtype}") return arr diff --git a/tests/api_test.py b/tests/api_test.py index 8bfb5bc3d146..98c4ffed849a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4174,7 +4174,7 @@ def __jax_array__(self): x = jnp.array(1) a = AlexArray(x) - for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.dtype]: + for f in [jnp.isscalar, jnp.size, jnp.shape, jnp.result_type]: self.assertEqual(f(x), f(a)) x = AlexArray(jnp.array(1)) From 8f48ac88e64a60fea06dfda95ed51d541e07a108 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Wed, 3 Dec 2025 09:27:30 -0800 Subject: [PATCH 019/315] Build JAX artifacts only in CPU pools for RBE configurations. PiperOrigin-RevId: 839788470 --- .bazelrc | 5 ++++- ci/build_artifacts.sh | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index 7a6ea6c44fa1..344c03a2278c 100644 --- a/.bazelrc +++ b/.bazelrc @@ -396,6 +396,9 @@ common:rbe --spawn_strategy=remote,worker,standalone,local common:rbe --remote_download_toplevel test:rbe --test_env=USER=anon +common:rbe_cpu_pool --repo_env=REMOTE_GPU_TESTING=0 +common:rbe_gpu_pool --repo_env=REMOTE_GPU_TESTING=1 + # RBE configs for Linux x86 # Set the remote worker pool common:rbe_linux_x86_64_base --remote_instance_name=projects/tensorflow-testing/instances/default_instance @@ -416,7 +419,7 @@ common:rbe_linux_x86_64 --config=rbe_linux_x86_64_base common:rbe_linux_x86_64 --config=ci_linux_x86_64 common:rbe_linux_x86_64_cuda_common --config=rbe_linux_x86_64_base -common:rbe_linux_x86_64_cuda_common --repo_env=REMOTE_GPU_TESTING=1 +common:rbe_linux_x86_64_cuda_common --config=rbe_gpu_pool common:rbe_linux_x86_64_cuda12 --config=rbe_linux_x86_64_cuda_common common:rbe_linux_x86_64_cuda12 --config=ci_linux_x86_64_cuda12 diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh index 4b72cda1c54f..3a9d3dcf8c6a 100755 --- a/ci/build_artifacts.sh +++ b/ci/build_artifacts.sh @@ -109,10 +109,11 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then # Build the artifact. python build/build.py build --wheels="$artifact" \ --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ + --bazel_options=--config=rbe_cpu_pool \ --bazel_startup_options="$bazel_startup_options" \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ $cuda_version_flag \ - --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --verbose --detailed_timestamped_log \ --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags @@ -121,10 +122,11 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then if [[ "$JAXCI_ARTIFACT_TYPE" == "release" ]]; then python build/build.py build --wheels="$artifact" \ --bazel_options=--config="$bazelrc_config" $bazel_remote_cache \ + --bazel_options=--config=rbe_cpu_pool \ --bazel_startup_options="$bazel_startup_options" \ --python_version=$JAXCI_HERMETIC_PYTHON_VERSION \ $cuda_version_flag \ - --verbose --detailed_timestamped_log --use_new_wheel_build_rule \ + --verbose --detailed_timestamped_log \ --output_path="$JAXCI_OUTPUT_DIR" \ $artifact_tag_flags --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="$JAXCI_WHEEL_RC_VERSION" fi From 23cd412cebf6ac17b27b3d0699aced5c9002ae89 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 3 Dec 2025 11:52:28 -0800 Subject: [PATCH 020/315] PR #33677: [no-thunks] Reduce stack frames in `cond` and friends Imported from GitHub PR https://github.com/jax-ml/jax/pull/33677 Copybara import of the project: -- 8e772e71f3a4402c94387635f489caf2a79e32e4 by Dougal : [no-thunks] Reduce stack frames in `cond` and friends Also deprecate the very old form of `cond`: `cond(predicate, true_arg, true_fun, false_arg, false_fun)`. Merging this change closes #33677 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/33677 from jax-ml:fewer-cond-stack-frames 8e772e71f3a4402c94387635f489caf2a79e32e4 PiperOrigin-RevId: 839851241 --- jax/_src/interpreters/partial_eval.py | 4 +- jax/_src/lax/control_flow/__init__.py | 3 - jax/_src/lax/control_flow/common.py | 78 ++++++++--------------- jax/_src/lax/control_flow/conditionals.py | 64 ++++--------------- jax/_src/lax/control_flow/loops.py | 9 ++- jax/_src/lax/control_flow/solves.py | 16 ++--- tests/api_test.py | 10 +-- tests/core_test.py | 4 -- tests/lax_control_flow_test.py | 58 +++++++---------- tests/metadata_test.py | 2 +- 10 files changed, 85 insertions(+), 163 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index a3dbfc55714f..757718104bc2 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2400,7 +2400,7 @@ def trace_to_jaxpr( in_tree: PyTreeDef, in_avals_flat: Sequence[AbstractValue | core.AvalQDD], debug_info: core.DebugInfo -) -> tuple[Jaxpr, PyTreeDef, list[Any]]: +) -> tuple[ClosedJaxpr, PyTreeDef, list[Any]]: config.enable_checks.value and debug_info.assert_arg_names(len(in_avals_flat)) parent_trace = core.trace_ctx.trace trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace) @@ -2424,6 +2424,8 @@ def trace_to_jaxpr( del trace, fun, in_tracers_flat, in_tracers, out_tracers, ans, ans_flat config.enable_checks.value and core.check_jaxpr(jaxpr) + # TODO(dougalm): remove this once we merge Jaxpr and ClosedJaxpr + jaxpr = close_jaxpr(convert_constvars_jaxpr(jaxpr)) return jaxpr, out_tree, consts diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index 44ee94e14ca2..5cbe5a39d381 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -50,9 +50,6 @@ # Private utilities used elsewhere in JAX # TODO(sharadmv): lift them into a more common place from jax._src.lax.control_flow.common import ( - _initial_style_open_jaxpr as _initial_style_open_jaxpr, - _initial_style_jaxpr as _initial_style_jaxpr, - _initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts, _check_tree_and_avals as _check_tree_and_avals, ) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index d29746560b46..3d78d193ddc3 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -43,78 +43,54 @@ def _typecheck_param(prim, param, name, msg_required, pred): msg = sep.join([msg, param_str]) raise core.JaxprTypeError(msg) -# TODO(dougalm): this is a silly wrapper now. Delete it. -@weakref_lru_cache -def _initial_style_open_jaxpr(fun: Callable, - in_tree: PyTreeDef, - in_avals: Sequence[core.AbstractValue | core.AvalQDD], - debug_info: core.DebugInfo): - jaxpr, out_tree, consts = pe.trace_to_jaxpr(fun, in_tree, in_avals, debug_info) - return jaxpr, consts, out_tree - -# TODO(dougalm): Delete. Make `trace_to_jaxpr` do the jaxpr-closing thing instead. -@weakref_lru_cache -def _initial_style_jaxpr(fun: Callable, - in_tree: PyTreeDef, - in_avals: Sequence[core.AbstractValue], - debug_info: core.DebugInfo) -> tuple[core.ClosedJaxpr, Sequence[Any], PyTreeDef]: - jaxpr, consts, out_tree = _initial_style_open_jaxpr( - fun, in_tree, in_avals, debug_info) - closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) - return closed_jaxpr, consts, out_tree - -def _initial_style_jaxprs_with_common_consts( - funs: Sequence[Callable], - in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue | core.AvalQDD], - debug_infos: Sequence[core.DebugInfo]): - jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, debug_info) - for fn, debug_info in zip(funs, debug_infos)] - if not jaxpr_data: return [], [], [] - jaxprs, all_consts, all_out_trees = zip(*jaxpr_data) - +# TODO(dougalm): this seems way too complicated. Why not allow different consts for each +# branch of a switch? +def _merge_common_consts( + jaxprs: Sequence[core.Jaxpr], + all_consts: Sequence[Sequence[Any]] + ) -> tuple[Sequence[core.ClosedJaxpr], Sequence[Any]]: # Jaxprs must share consts, so we concat consts and pad the jaxprs' constvars. lens = map(len, all_consts) consts = [c for cs in all_consts for c in cs] avalqdds = tuple(map(core.cur_aval_qdd, consts)) - jaxprs = [_pad_constvars(jaxpr, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):]) - for i, jaxpr in enumerate(jaxprs)] + num_constss = [len(cs) for cs in all_consts] + jaxprs = [_pad_constvars(jaxpr, num_consts, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):]) + for i, (jaxpr, num_consts) in enumerate(zip(jaxprs, num_constss))] # De-duplicate shared constants. const_ids = tuple(id(c) for c in consts) seen = set() - consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore - jaxprs = [_dedup_consts(jaxpr, const_ids) for jaxpr in jaxprs] - - closed_jaxprs = [pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) - for jaxpr in jaxprs] - return closed_jaxprs, consts, all_out_trees + dd_consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore + jaxprs = [_dedup_consts(jaxpr, len(consts), const_ids) for jaxpr in jaxprs] + return jaxprs, dd_consts @weakref_lru_cache -def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...], - right: tuple[core.AbstractValue, ...]) -> core.Jaxpr: +def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int, + left: tuple[core.AvalQDD, ...], + right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr: def make_var(aq): return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd) - constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)] - effs = pe._renumber_effects([*constvars, *jaxpr.invars], - [*jaxpr.constvars, *jaxpr.invars], jaxpr.effects) - jaxpr = jaxpr.replace(constvars=constvars, effects=effs) + invars = [*map(make_var, left), *jaxpr.invars[:num_consts], + *map(make_var, right), *jaxpr.invars[num_consts:]] + effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects) + jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs)) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr @weakref_lru_cache -def _dedup_consts(jaxpr, const_ids): +def _dedup_consts(jaxpr, num_consts, const_ids): newvars = {} canonicalize = {v: newvars.setdefault(constid, v) - for constid, v in zip(const_ids, jaxpr.constvars)} + for constid, v in zip(const_ids, jaxpr.invars[:num_consts])} eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var) else x for x in e.invars]) for e in jaxpr.eqns] outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x for x in jaxpr.outvars] - constvars = list(newvars.values()) - effs = pe._renumber_effects( - [*constvars, *jaxpr.invars], - [*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects) - jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars, - effects=effs) + invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]] + effs = pe._renumber_effects(invars, + [*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]], + jaxpr.effects) + jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars, + effects=effs)) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 668b880ab4dd..eab8b92b742d 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -53,7 +53,7 @@ import numpy as np from jax._src.lax.control_flow.common import ( - _avals_short, _typecheck_param, _initial_style_jaxprs_with_common_consts, + _avals_short, _typecheck_param, _merge_common_consts, _make_closed_jaxpr, _prune_zeros) map, unsafe_map = safe_map, map @@ -149,8 +149,10 @@ def _switch_internal( if config.mutable_array_checks.value: api_util.check_no_aliased_ref_args(lambda: dbgs[0], ops_avals, ops) - jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - branches, ops_tree, ops_avals, dbgs) + jaxprs_, out_trees, all_consts = zip(*[pe.trace_to_jaxpr( + branch, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbgs)]) + jaxprs, consts = _merge_common_consts(jaxprs_, all_consts) + if config.mutable_array_checks.value: api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops) for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): @@ -184,7 +186,7 @@ def _switch_internal( return tree_unflatten(out_trees[0], out) @partial(api_boundary, repro_api_name="jax_cond") -def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, +def cond(pred, true_fun: Callable, false_fun: Callable, *operands, operand=_no_operand_sentinel): """Conditionally apply ``true_fun`` or ``false_fun``. @@ -270,14 +272,16 @@ def cond(pred, true_fun, false_fun, *operands): if config.mutable_array_checks.value: api_util.check_no_aliased_ref_args(lambda: dbg_true_fun, ops_avals, ops) dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {}) - jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - (true_fun, false_fun), ops_tree, ops_avals, - [dbg_true_fun, dbg_false_fun]) - true_jaxpr, false_jaxpr = jaxprs + + true_jaxpr_, out_tree, true_consts = pe.trace_to_jaxpr( + true_fun, ops_tree, ops_avals, dbg_true_fun) + false_jaxpr_, false_out_tree, false_consts = pe.trace_to_jaxpr( + false_fun, ops_tree, ops_avals, dbg_false_fun) + (true_jaxpr, false_jaxpr), consts = _merge_common_consts( + (true_jaxpr_, false_jaxpr_), (true_consts, false_consts)) if config.mutable_array_checks.value: api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops) - out_tree, false_out_tree = out_trees if any(isinstance(out_aval, AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals): raise ValueError("Cannot return `Ref`s from `cond`.") @@ -399,48 +403,6 @@ def _capitalize(s): # s.capitalize() converts s[1:] to lowercase which we don't want. return s[0].capitalize() + s[1:] -@api_boundary -@functools.wraps(_cond) -def cond(*args, **kwargs): - # detect an attempt to call the former, deprecated cond - try: - ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs) - except TypeError: - pass - else: - assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch - _, true_operand, true_fun, false_operand, false_fun = ba.args - if callable(true_operand) and callable(true_fun): - # treat this as modern cond (with two operands) - return _cond(*args, **kwargs) - if callable(true_fun) and callable(false_fun): - return _cond_with_per_branch_args(*ba.args) - - return _cond(*args, **kwargs) - -@partial(api_boundary, repro_api_name="jax_cond_with_per_branch_args") -def _cond_with_per_branch_args(pred, - true_operand, true_fun: Callable, - false_operand, false_fun: Callable): - """Conditionally apply ``true_fun`` or ``false_fun``. - - Has equivalent semantics to this Python implementation:: - - def cond(pred, true_operand, true_fun, false_operand, false_fun): - if pred: - return true_fun(true_operand) - else: - return false_fun(false_operand) - - Pred has to be a scalar type, collection types (list, tuple) are not supported - """ - if not (callable(true_fun) and callable(false_fun)): - raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") - return _cond(pred, - lambda op: true_fun(op[0]), - lambda op: false_fun(op[1]), - (true_operand, false_operand)) - def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects: joined_effects = set() for b in branches: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d5e31b0ca0a2..dab448cde147 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -50,7 +50,7 @@ from jax._src.lax import slicing from jax._src.lax import windowed_reductions from jax._src.lax.control_flow.common import ( - _avals_short, _initial_style_jaxpr, _prune_zeros, _typecheck_param, + _avals_short, _prune_zeros, _typecheck_param, _make_closed_jaxpr) from jax._src.lax.other import logaddexp from jax._src.pjit import auto_axes, PartitionSpec as P, reshard @@ -281,9 +281,8 @@ def _create_jaxpr(init): init_flat, init_tree = tree_flatten(init) in_flat, in_tree = tree_flatten((init, xs)) carry_avals = tuple(_map(core.get_aval, init_flat)) - open_jaxpr, out_tree, consts = pe.trace_to_jaxpr( + jaxpr, out_tree, consts = pe.trace_to_jaxpr( f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body) - jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(open_jaxpr)) if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat) out_tree_children = out_tree.children() @@ -1712,10 +1711,10 @@ def _create_jaxpr(init_val): init_vals, in_tree = tree_flatten((init_val,)) init_avals = tuple(_map(core.get_aval, init_vals)) cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) - cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr( + cond_jaxpr, cond_tree, cond_consts = pe.trace_to_jaxpr( cond_fun, in_tree, init_avals, cond_dbg) body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) - body_jaxpr, body_consts, body_tree = _initial_style_jaxpr( + body_jaxpr, body_tree, body_consts = pe.trace_to_jaxpr( body_fun, in_tree, init_avals, body_dbg) if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 17ed44b69991..e65f0cda1480 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -27,6 +27,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.traceback_util import api_boundary from jax._src.tree_util import (tree_flatten, treedef_children, tree_leaves, @@ -36,7 +37,6 @@ from jax._src.lax.control_flow.common import ( _check_tree, - _initial_style_jaxpr, ) _map = safe_map @@ -95,7 +95,7 @@ def custom_root(f: Callable, guess_flat, in_args_tree = tree_flatten((initial_guess,)) guess_avals = tuple(_map(core.get_aval, guess_flat)) f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {}) - f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( + f_jaxpr, out_tree, f_consts = pe.trace_to_jaxpr( f, in_args_tree, guess_avals, f_debug) in_tree, = treedef_children(in_args_tree) @@ -104,7 +104,7 @@ def custom_root(f: Callable, solve_debug = api_util.debug_info("custom_root solve", solve, (f, initial_guess), {}, static_argnums=(0,)) - solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( + solve_jaxpr, solution_tree, solve_consts = pe.trace_to_jaxpr( partial(solve, f), in_args_tree, guess_avals, solve_debug) _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) @@ -114,7 +114,7 @@ def linearize_and_solve(x, b): linearize_and_solve_dbg = api_util.debug_info("custom_root tangent_solve", tangent_solve, (initial_guess, initial_guess), {}) - l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( + l_and_s_jaxpr, out_tree, l_and_s_consts = pe.trace_to_jaxpr( linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2, linearize_and_solve_dbg) _check_tree("tangent_solve", "x", out_tree, in_tree, False) @@ -268,7 +268,7 @@ def f_aux(x): matvec_debug = api_util.debug_info("custom_linear_solve", matvec, (b,), {}) # no auxiliary data assumed for matvec - matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( + matvec_jaxpr, out_tree, matvec_consts = pe.trace_to_jaxpr( _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, matvec_debug) _check_tree("matvec", "b", out_tree, tree, False) @@ -276,7 +276,7 @@ def f_aux(x): solve_debug = api_util.debug_info("custom_linear_solve solve", solve, (matvec, b), {}, static_argnums=(0,)) - solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( + solve_jaxpr, out_tree, solve_consts = pe.trace_to_jaxpr( _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, solve_debug) _check_tree("solve", "b", out_tree, tree, has_aux) @@ -294,11 +294,11 @@ def f_aux(x): vecmat_consts = matvec_consts else: vecmat = _transpose_one_output(matvec, b) - vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr( + vecmat_jaxpr, out_tree, vecmat_consts = pe.trace_to_jaxpr( vecmat, in_args_tree, b_avals, transpose_solve_debug) assert out_tree == tree - tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr( + tr_solve_jaxpr, out_tree, tr_solve_consts = pe.trace_to_jaxpr( _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux), in_args_tree, b_avals, transpose_solve_debug) _check_tree("transpose_solve", "b", out_tree, tree, has_aux) diff --git a/tests/api_test.py b/tests/api_test.py index 98c4ffed849a..f479739a1042 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7207,10 +7207,10 @@ def fun(x): def test_cond(self): def f(x): return lax.cond(x >= 0., + lambda xt, _: xt + x, + lambda _, xf: xf - x, x + 1., - lambda xt: xt + x, - x + 2., - lambda xf: xf - x) + x + 2.) expected = """{ lambda ; a:f32[]. let b:bool[] = ge a 0.0:f32[] c:f32[] = add a 1.0:f32[] @@ -7941,10 +7941,10 @@ def f(c, x): jax.lax.scan(f, 0, jnp.arange(4)) def test_cond_traceback(self): - if sys.version_info < (3, 14): + if sys.version_info < (3, 13): # Fails because 3.11 adds an extra stack frame due to a list comprehension self.skipTest("Expected failure.") - expected_depth = 8 + expected_depth = 4 init_depth = self.cur_depth() def f(): diff --git a/tests/core_test.py b/tests/core_test.py index c7cf4918f8e5..a2b3da15fd2c 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -448,15 +448,11 @@ class JaxprTypeChecks(jtu.JaxTestCase): def setUp(self): super().setUp() - lax_control_flow._initial_style_open_jaxpr.cache_clear() - lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() def tearDown(self): super().tearDown() - lax_control_flow._initial_style_open_jaxpr.cache_clear() - lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 01c8898749cc..7a3a45a485d9 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -53,25 +53,14 @@ # provides a lax.cond-compatible interface to a two-branch lax.switch. Several # tests in this file are parameterized such that they either call into lax.cond # or into this function. -def cond_via_switch(pred, true_fun, false_fun, op, *args): - if len(args) > 0: - assert len(args) == 1 - true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0] - op = (false_op, true_op) - false_fun = lambda op: _false_fun(op[0]) - true_fun = lambda op: _true_fun(op[1]) +def cond_via_switch(pred, true_fun, false_fun, *args): index = lax.convert_element_type(pred, np.int32) - return lax.switch(index, [false_fun, true_fun], op) - -def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args): - if args: - true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0] - op = (false_op, true_op) - false_fun = lambda op: _false_fun(op[0]) - true_fun = lambda op: _true_fun(op[1]) + return lax.switch(index, [false_fun, true_fun], *args) + +def cond_with_new_checkpoint(pred, true_fun, false_fun, *args): index = lax.convert_element_type(pred, np.int32) - fn = lambda index, op: lax.switch(index, [false_fun, true_fun], op) - return jax.checkpoint(fn)(index, op) + fn = lambda index, *args: lax.switch(index, [false_fun, true_fun], *args) + return jax.checkpoint(fn)(index, *args) COND_IMPLS = [ (lax.cond, 'cond'), @@ -171,8 +160,6 @@ class LaxControlFlowTest(jtu.JaxTestCase): def setUp(self): super().setUp() - lax_control_flow._initial_style_open_jaxpr.cache_clear() - lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() @@ -1000,8 +987,8 @@ def cfun(x): lax.lt(x, 2), lambda x: lax.mul(2, x), lambda x: cond(lax.lt(x, 5), - x, lambda x: lax.mul(3, x), - 4, lambda y: lax.mul(y, x)), + lambda x, _: lax.mul(3, x), + lambda _, y: lax.mul(y, x), x, 4), x) self.assertEqual(cfun(1), 2) @@ -1121,9 +1108,9 @@ def cfun(x): def testCondBatched(self): def fun(x, y, z): pred = lax.lt(x, 3) - true_fun = lambda y: y - false_fun = lambda z: lax.neg(z) - return lax.cond(pred, y, true_fun, z, false_fun) + true_fun = lambda y, _: y + false_fun = lambda _, z: lax.neg(z) + return lax.cond(pred, true_fun, false_fun, y, z) # these cases stay as cond x = jnp.array(2) @@ -1287,7 +1274,7 @@ def fun_ref(x): return 2. * x def fun(x): - return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x) + return cond(x < 3, lambda _: 2., lambda x: 2. * x, x) x = 3.14 ans = jax.jvp(fun, (x,), (x,)) @@ -1445,7 +1432,7 @@ def fun_ref(x): return 2. * x def fun(x): - return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x) + return cond(x < 3, lambda _: 2., lambda x: 2. * x, x) x = 3.14 ans = jax.grad(fun)(x) @@ -1475,8 +1462,9 @@ def fun_ref(x, y): def fun(x, y): return cond( x < 3, - None, lambda _: 2. * jnp.sin(y), - x, lambda x: 2. * x) + lambda _: 2. * jnp.sin(y), + lambda x: 2. * x, + x) y = 5.8 x = 3.14 @@ -1665,7 +1653,7 @@ def g(x): return jnp.where(x > 0, f_1(x), f_2(x)) def testIssue1263(self): def f(rng, x): cond = random.bernoulli(rng) - return lax.cond(cond, x, lambda x: x, jnp.abs(x) - 1., lambda x: x) + return lax.cond(cond, lambda x, _: x, lambda _, x: x, x, jnp.abs(x) - 1.) def body_fn(i, state): rng, x = state @@ -1680,8 +1668,9 @@ def g(rng, x): def testIssue514(self): # just check this doesn't crash lax.cond(True, - (0, 0), lambda x: (x[0], 0), - (1, 1), lambda x: x) + lambda x, _: (x[0], 0), + lambda _, x: x, + (0, 0), (1, 1)) def testIssue649(self): from jax import lax @@ -2388,8 +2377,9 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), - 1., lambda x: x) + lambda x, _: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), + lambda _, x: x, + x, 1.) elif loop == "fori_inside_scan": func = lambda x: lax.scan( lambda c, x: (lax.fori_loop(x, x + 2., lambda i, c1: c1 * c, x), None), @@ -2561,7 +2551,7 @@ def f(h, _): def test_disable_jit_cond_with_vmap(self): # https://github.com/jax-ml/jax/issues/3093 def fn(t): - return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1) + return lax.cond(t > 0, lambda x, _: 0, lambda _, x: 1, 0, 0) fn = jax.vmap(fn) with jax.disable_jit(): diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 917cf7bf5133..524768aaed87 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -79,7 +79,7 @@ def true_fun(x): def false_fun(x): return jnp.cos(x) def f(which, x): - return jax.lax.cond(which, x, true_fun, x, false_fun) + return jax.lax.cond(which, true_fun, false_fun, x) hlo = module_to_string(jax.jit(f).lower(True, 1.).compiler_ir()) self.assertRegex(hlo, r'loc\(".*cond/branch_0_fun/cos"') self.assertRegex(hlo, r'loc\(".*cond/branch_1_fun/sin"') From eaf6efc1beebaf922ab2db7db6cdf8398db12e88 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 3 Dec 2025 12:12:26 -0800 Subject: [PATCH 021/315] [Pallas] Fix cost estimate einsum bug. PiperOrigin-RevId: 839859433 --- ci/run_bazel_test_tpu.sh | 1 - jax/_src/pallas/cost_estimate.py | 4 +++- tests/pallas/BUILD | 8 +++----- tests/pallas/pallas_cost_estimate_test.py | 17 +++++++++++++++++ 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/ci/run_bazel_test_tpu.sh b/ci/run_bazel_test_tpu.sh index e17ba028bb79..19acc331c716 100755 --- a/ci/run_bazel_test_tpu.sh +++ b/ci/run_bazel_test_tpu.sh @@ -183,7 +183,6 @@ else //tests/pallas:tpu_pallas_test_tpu \ //tests/pallas:tpu_pallas_call_print_test_tpu \ //tests/pallas:indexing_test_tpu \ - //tests/pallas:pallas_cost_estimate_test_tpu \ //tests/pallas:pallas_error_handling_test_tpu \ //tests/pallas:pallas_jumble_test_tpu \ //tests/pallas:pallas_shape_poly_test_tpu \ diff --git a/jax/_src/pallas/cost_estimate.py b/jax/_src/pallas/cost_estimate.py index ef2cf10beef3..83d35e2ae977 100644 --- a/jax/_src/pallas/cost_estimate.py +++ b/jax/_src/pallas/cost_estimate.py @@ -205,10 +205,12 @@ def dot_general_cost_rule(ctx: Context, assert len(lhs_batch_dims) == len(rhs_batch_dims) flops = 1 # Flops along a contracting dim is 2*dim (addition and multiplication) + contracting_flops = 1 for i in range(len(lhs_contracting_dims)): lhs_dim, rhs_dim = lhs_contracting_dims[i], rhs_contracting_dims[i] assert x_shape[lhs_dim] == y_shape[rhs_dim] - flops *= 2 * x_shape[lhs_dim] + contracting_flops *= x_shape[lhs_dim] + flops *= 2 * contracting_flops # Now we handle all other dimensions. for i, lhs_dim in enumerate(x_shape): if i in lhs_contracting_dims: diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index b6e3fd4919e2..2d934e329e36 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -67,17 +67,15 @@ jax_multiplatform_test( ]), ) -jax_multiplatform_test( +jax_py_test( name = "pallas_cost_estimate_test", srcs = [ "pallas_cost_estimate_test.py", ], + args = ["--jax_test_dut=cpu"], deps = [ "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", + "//jax/_src:test_util", ] + py_deps([ "absl/testing", "numpy", diff --git a/tests/pallas/pallas_cost_estimate_test.py b/tests/pallas/pallas_cost_estimate_test.py index d9eb18e6f540..c9a68844a185 100644 --- a/tests/pallas/pallas_cost_estimate_test.py +++ b/tests/pallas/pallas_cost_estimate_test.py @@ -61,6 +61,23 @@ def matmul(a, b): self.assertEqual(cost.transcendentals, 0) self.assertEqual(cost.bytes_accessed, 4*(b*m*k + b*n*k + b*m*n)) + @parameterized.parameters( + ((10, 11, 12), (11, 12), "abc,bc->a"), + ((10, 11, 12), (13, 11, 12), "abc,dbc->ad"), + ((10, 11, 12), (9, 10, 11, 12), "abc,dabc->d"), + ) + def test_einsum(self, a_shape, b_shape, pattern): + a = jnp.ones(a_shape, dtype=jnp.float32) + b = jnp.ones(b_shape, dtype=jnp.float32) + def matmul(a, b): + return jnp.einsum(pattern, a, b) + cost = cost_estimate.estimate_cost( + matmul, + jax.ShapeDtypeStruct(a_shape, jnp.float32), + jax.ShapeDtypeStruct(b_shape, jnp.float32)) + xla_flops = jax.jit(matmul).lower(a, b).compile().cost_analysis()['flops'] + self.assertEqual(cost.flops, int(xla_flops)) + def test_attention(self): qk_dim = 16 v_dim = 4 From bcfea9822ceeb8748053d2c5bcfd95b43dfc4932 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Wed, 3 Dec 2025 13:13:17 -0800 Subject: [PATCH 022/315] Update hermetic CUDA UMD version to match CUDA driver version installed on RBE machines. PiperOrigin-RevId: 839885172 --- .bazelrc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.bazelrc b/.bazelrc index 344c03a2278c..371f4b78a376 100644 --- a/.bazelrc +++ b/.bazelrc @@ -420,6 +420,8 @@ common:rbe_linux_x86_64 --config=ci_linux_x86_64 common:rbe_linux_x86_64_cuda_common --config=rbe_linux_x86_64_base common:rbe_linux_x86_64_cuda_common --config=rbe_gpu_pool +# Update UMD version when RBE CUDA driver is updated. +common:rbe_linux_x86_64_cuda_common --repo_env=HERMETIC_CUDA_UMD_VERSION="13.0.1" common:rbe_linux_x86_64_cuda12 --config=rbe_linux_x86_64_cuda_common common:rbe_linux_x86_64_cuda12 --config=ci_linux_x86_64_cuda12 From c37f3904c391829c15a3d5ce4f02c6c5c250dfa4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 3 Dec 2025 13:19:50 -0800 Subject: [PATCH 023/315] Fix deprecated use of arrays in place of dtypes --- jax/_src/dtypes.py | 6 +++++- tests/nn_test.py | 4 ++-- tests/random_test.py | 10 ++++++++-- tests/stax_test.py | 2 +- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index fbf4c2fd86e1..7373f0ad3815 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -995,7 +995,11 @@ def dtype(x: Any) -> DType: dt = x.dtype else: try: - dt = np.result_type(x) + with warnings.catch_warnings(): + # Ignore warning associated with __numpy_dtype__ change in NumPy 2.4. + # TODO(jakevdp): remove this warning context after change is finalized. + warnings.simplefilter("ignore", DeprecationWarning) + dt = np.result_type(x) except TypeError as err: raise TypeError(f"Cannot determine dtype of {x}") from err if dt not in _jax_dtype_set and not issubdtype(dt, extended): diff --git a/tests/nn_test.py b/tests/nn_test.py index 88bbb4928530..9bae00527eb7 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -856,7 +856,7 @@ def testInitializer(self, initializer, shape, dtype): val = initializer(rng, shape, dtype) self.assertEqual(shape, jnp.shape(val)) - self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val)) + self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), val.dtype) @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( @@ -872,7 +872,7 @@ def testInitializerProvider(self, initializer_provider, shape, dtype): val = initializer(rng, shape) self.assertEqual(shape, jnp.shape(val)) - self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val)) + self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), val.dtype) def testVarianceScalingMultiAxis(self): rng = random.PRNGKey(0) diff --git a/tests/random_test.py b/tests/random_test.py index bf98c50997aa..ce3b2f6a9956 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -654,8 +654,14 @@ def test_issubdtype(self): self.assertFalse(jnp.issubdtype(key.dtype, np.integer)) self.assertFalse(jnp.issubdtype(key.dtype, np.number)) - with self.assertRaisesRegex(TypeError, "Cannot interpret"): - jnp.issubdtype(key, dtypes.prng_key) + if jtu.numpy_version() < (2, 4, 0): + with self.assertRaisesRegex(TypeError, "Cannot interpret"): + jnp.issubdtype(key, dtypes.prng_key) + else: + with jtu.ignore_warning(category=DeprecationWarning, + message="Implicit conversion of an array to a dtype"): + with self.assertRaisesRegex(ValueError, "Could not convert Array"): + jnp.issubdtype(key, dtypes.prng_key) @skipIf(not config.enable_custom_prng.value, 'relies on typed key upgrade flag') def test_construction_upgrade_flag(self): diff --git a/tests/stax_test.py b/tests/stax_test.py index d0ee095fa163..fe86543cee26 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -40,7 +40,7 @@ def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape): result_shape, params = init_fun(init_key, input_shape) inputs = random_inputs(test_case.rng(), input_shape) if params: - inputs = inputs.astype(np.dtype(params[0])) + inputs = inputs.astype(dtypes.dtype(params[0])) result = apply_fun(params, inputs, rng=rng_key) test_case.assertEqual(result.shape, result_shape) From cdb2b692d8e637fe9e62c6fb1f859c5f128c9668 Mon Sep 17 00:00:00 2001 From: IvyZX Date: Wed, 3 Dec 2025 13:30:38 -0800 Subject: [PATCH 024/315] Add Pallas core_map guide --- docs/conf.py | 2 + docs/pallas/tpu/core_map.ipynb | 628 +++++++++++++++++++++++++++++++++ docs/pallas/tpu/core_map.md | 363 +++++++++++++++++++ docs/pallas/tpu/index.rst | 1 + 4 files changed, 994 insertions(+) create mode 100644 docs/pallas/tpu/core_map.ipynb create mode 100644 docs/pallas/tpu/core_map.md diff --git a/docs/conf.py b/docs/conf.py index 9c3845800bac..ee7cafdf2aaa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -143,6 +143,7 @@ def _do_not_evaluate_in_jax( 'pallas/tpu/distributed.md', 'pallas/tpu/sparse.md', 'pallas/tpu/matmul.md', + 'pallas/tpu/core_map.md', 'jep/9407-type-promotion.md', 'autodidax.md', 'autodidax2_part1.md', @@ -245,6 +246,7 @@ def _do_not_evaluate_in_jax( 'pallas/tpu/distributed.*', 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', + 'pallas/tpu/core_map.*', 'distributed_data_loading.*', 'notebooks/host-offloading.*', ] diff --git a/docs/pallas/tpu/core_map.ipynb b/docs/pallas/tpu/core_map.ipynb new file mode 100644 index 000000000000..38be63d61cb4 --- /dev/null +++ b/docs/pallas/tpu/core_map.ipynb @@ -0,0 +1,628 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "last_runtime": { + "build_target": "//third_party/py/jax_triton/google/pallas_tpu:notebook", + "kind": "private" + } + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Pallas Core-specific Programming" + ], + "metadata": { + "id": "YIt0Za36LYg9" + } + }, + { + "cell_type": "markdown", + "source": [ + "In this guide, we explore using `pl.core_map` to write Pallas kernels. Compared with `pallas_call`, `core_map` offers a few key characteristics:\n", + "\n", + "* **Per-core level programming**: You write code for an TPU/GPU core, not for a JAX device. This gives you full control over what runs on every core, or how cores communicate and distribute work among one another.\n", + "\n", + "* **Collectives**: `core_map` explicitly models physical cores, so inter-core communication can be expressed safely.\n", + "\n", + "* **Platform generic**: `core_map` programming model works for TPU (TensorCore and SparseCore) and GPU with minimal boilerplate changes." + ], + "metadata": { + "id": "khDWSc7aOVts" + } + }, + { + "cell_type": "markdown", + "source": [ + "This guide focuses on TPU. For how to use `core_map` on GPU to achieve higher thread flexibility, check out our [Pallas GPU `core_map` tutorial](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#using-core-map)." + ], + "metadata": { + "id": "i8pl0CLqTVvL" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Environment setup\n", + "\n", + "Modern accelerators often have multiple cores under a device. For recent TPU chips (v4, v5p), every JAX device may contains 2 TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). Some TPUs (v5p, v6e, 7x) also contain [SparseCores](https://openxla.org/xla/sparsecore#specifications_at_a_glance), each of which consists of many subcores.\n", + "\n", + "This guide was written on a v5p chip, which contains 4 devices (2 TensorCores each) and 4 SparseCores, each with 16 subcores." + ], + "metadata": { + "id": "bsOPXdJkzC-x" + } + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "14PNaMVsLUur", + "executionInfo": { + "status": "ok", + "timestamp": 1764795546418, + "user_tz": 480, + "elapsed": 2087, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + }, + "outputId": "01976bb1-2f2f-40e9-ca23-f0e480a82ab3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Running on 4 TPU v5p devices.\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "\n", + "import jax\n", + "from jax.sharding import NamedSharding\n", + "from jax.experimental import pallas as pl\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "from jax.experimental.pallas import tpu_sc as plsc\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "\n", + "num_devices = jax.local_device_count()\n", + "assert num_devices > 1, \"Please run this notebook with more than one device.\"\n", + "\n", + "tpu_info = pltpu.get_tpu_info() # This notebook only runs on TPU.\n", + "print(f\"Running on {num_devices} TPU {tpu_info.chip_version} devices.\")" + ] + }, + { + "cell_type": "markdown", + "source": [ + "In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an addition dimension called `core`, with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total." + ], + "metadata": { + "id": "3f0XEhaYnGyk" + } + }, + { + "cell_type": "code", + "source": [ + "# Mesh of devices\n", + "mesh = jax.make_mesh((jax.device_count(),), ('device',))\n", + "print(mesh)\n", + "\n", + "# Mesh of cores, within a JAX device\n", + "tc_mesh = pltpu.create_tensorcore_mesh('core')\n", + "print(tc_mesh)\n", + "\n", + "num_devices = mesh.size\n", + "num_cores = len(tc_mesh.devices)\n", + "print(f\"There are {num_devices} devices, and {num_cores} cores each.\")" + ], + "metadata": { + "id": "jr5MARD-mIlC", + "executionInfo": { + "status": "ok", + "timestamp": 1764795546665, + "user_tz": 480, + "elapsed": 57, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + }, + "outputId": "1ea63c2f-3aec-4cdd-9674-d0e2df32460c" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mesh('device': 4, axis_types=(Explicit,))\n", + "TensorCoreMesh(devices=array([TensorCore(id=0), TensorCore(id=1)], dtype=object), axis_names=('core',))\n", + "There are 4 devices, and 2 cores each.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## A simple per-core kernel\n", + "\n", + "`pl.core_map` allows you to write per-core local code, just as `jax.shard_map` allows you to write per-device code.\n", + "\n", + "In the example kernel below, each core has its own VMEM and semaphore allocations. As with normal kernel, you can initiate copies between HBM and VMEM refs using `pltpu.async_copy`.\n", + "\n", + "**Communication between cores**\n", + "\n", + "Before communicating between cores, it is good practice to perform a barrier (using `pltpu.semaphore_signal`) to ensure resources have been allocated and both cores are at the same point during the program.\n", + "\n", + "Once the cores are synchronized, use `pltpu.make_async_remote_copy` to send data between them. The `device_id` keyword argument generically allows sending to any core on any device, but if you just pass in `{'core': other_core_id}`, it will perform a intra-device inter-core copy (the other axis names are held constant).\n" + ], + "metadata": { + "id": "CYxwiULfndlh" + } + }, + { + "cell_type": "code", + "source": [ + "# This runs on every core\n", + "def swap_cores_kernel(in_hbm, out_hbm,\n", + " in_vmem, scratch_vmem, out_vmem,\n", + " sem, send_sem, recv_sem):\n", + " core_index = jax.lax.axis_index('core')\n", + " num_cores = jax.lax.axis_size('core')\n", + " slc_size = in_hbm.shape[-1] // num_cores\n", + " slc = pl.ds(core_index * slc_size, slc_size)\n", + "\n", + " # Copy in a core-dependent slice of the input\n", + " pltpu.async_copy(in_hbm.at[:, slc], in_vmem, sem).wait()\n", + "\n", + " # A barrier to make sure all cores have entered run_scoped.\n", + " # You won't need this if not doing inter-core communications.\n", + " dst_core = (core_index + 1) % num_cores\n", + " sem0 = pltpu.get_barrier_semaphore()\n", + " pltpu.semaphore_signal(sem0, 1, device_id={'core': dst_core})\n", + " pltpu.semaphore_wait(sem0, 1)\n", + "\n", + " # Swap data between core 0 and core 1\n", + " the_copy = pltpu.make_async_remote_copy(\n", + " in_vmem, scratch_vmem, send_sem, recv_sem, device_id={'core': dst_core},\n", + " )\n", + " the_copy.start()\n", + " the_copy.wait()\n", + "\n", + " # Core-local compute\n", + " out_vmem[...] = scratch_vmem[...] * 2\n", + "\n", + " # Copy out the output\n", + " pltpu.async_copy(out_vmem, out_hbm.at[:, slc], sem).wait()\n" + ], + "metadata": { + "id": "GkGRT2HRJOUU", + "executionInfo": { + "status": "ok", + "timestamp": 1764795546946, + "user_tz": 480, + "elapsed": 53, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Once you have the local kernel:\n", + "\n", + " * Start your top-level JAX code with HBM refs, and allocate output refs if needed.\n", + "\n", + " * Use `pl.core_map`, which takes the TensorCore mesh, to start per-core programming.\n", + "\n", + " * You will need `collective_id` for the barrier semaphore.\n", + "\n", + " * Inside `pl.core_map`, invoke `pl.run_scoped` to allocate per-core scratch spaces (VMEM and semaphores) and run the local kernel." + ], + "metadata": { + "id": "2T0tSkFmoFLI" + } + }, + { + "cell_type": "code", + "source": [ + "input_shape = (32, 256)\n", + "local_vmem_shape = (32 // num_devices, 256 // num_cores)\n", + "in_spec = jax.P('device', None)\n", + "sharding = NamedSharding(mesh, in_spec)\n", + "\n", + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec,\n", + " check_vma=False)\n", + "def swap_cores(x):\n", + " # Get buffers out of the input and output\n", + " x_hbm_ref = jax.new_ref(x)\n", + " o_hbm_ref = jax.new_ref(jax.lax.empty(x.shape, x.dtype))\n", + "\n", + " @pl.core_map(tc_mesh, compiler_params=pltpu.CompilerParams(collective_id=0))\n", + " def _():\n", + " pl.run_scoped(\n", + " partial(swap_cores_kernel, x_hbm_ref, o_hbm_ref),\n", + " *([pltpu.VMEM(local_vmem_shape, x.dtype)] * 3), # VMEM allocations\n", + " *([pltpu.SemaphoreType.DMA] * 3), # semaphores\n", + " )\n", + " return o_hbm_ref[...]\n", + "\n", + "\n", + "x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n", + "x = jax.device_put(x, sharding)\n", + "y = swap_cores(x)\n", + "\n", + "np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)\n", + "np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)" + ], + "metadata": { + "id": "KT6zkEKi1Sbc", + "executionInfo": { + "status": "ok", + "timestamp": 1764795548996, + "user_tz": 480, + "elapsed": 1800, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Save the boilerplate\n", + "\n", + "You can use the `pl.kernel` decorator to wrap boilerplate such as `core_map`, `run_scoped`, and output buffer allocation.\n", + "\n", + "Note that this should run inside any `jax.shard_map` you may have at the top level." + ], + "metadata": { + "id": "dLV8sKa4HuSX" + } + }, + { + "cell_type": "code", + "source": [ + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n", + "def swap_cores(x):\n", + " scratch_shapes = [pltpu.VMEM(local_vmem_shape, x.dtype)] * 3 + [pltpu.SemaphoreType.DMA] * 3\n", + " return pl.kernel(swap_cores_kernel, out_shape=x, mesh=tc_mesh,\n", + " scratch_shapes=scratch_shapes,\n", + " compiler_params=pltpu.CompilerParams(collective_id=0))(x)\n", + "\n", + "y = swap_cores(x)\n", + "np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2)\n", + "np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2)" + ], + "metadata": { + "id": "7cHnsRHPHyfH", + "executionInfo": { + "status": "ok", + "timestamp": 1764795549347, + "user_tz": 480, + "elapsed": 106, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Pipelining with `core_map`\n", + "\n", + "Note that the kernel above only does simple copies and compute, without automatic pipelining via Pallas `grid` and `BlockSpec`. To do pipelining inside `core_map`, use `pltpu.emit_pipeline` inside the core-local kernel.\n", + "\n", + "**Automatically parallelize work amongst cores**\n", + "\n", + "The simple way is to annotate a block axis as `pltpu.PARALLEL`, and Pallas will automatically parallelize work along this axis. Both `pl.pallas_call` and `pltpu.emit_pipeline` supports this, via arguments `core_axis` and `dimension_semantics`. The `pallas_call` example is [in another guide](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration), and the `emit_pipeline` case is shown below.\n", + "\n", + "When the `PARALLEL` annotation is provided, the corresponding grid dimension will be logically split and executed on separate cores. (The exact semantics of which grid dimensions are executed on which core is guaranteed).\n", + "\n", + "**Scratch shapes allocation**\n", + "\n", + "Note that in the example below, the top level `pl.run_scoped` (wrapped inside `kernel`) did not allocate any VMEM scratch buffers. Instead, `pltpu.emit_pipeline` allocates its own scratch buffers in VMEM and use them for its multiple buffering.\n" + ], + "metadata": { + "id": "4-G--Wnysdjs" + } + }, + { + "cell_type": "code", + "source": [ + "def add_one_body(in_vmem, out_vmem):\n", + " out_vmem[...] = in_vmem[...] + 1\n", + "\n", + "input_shape = (1024, 1024)\n", + "in_spec = jax.P('device', None)\n", + "\n", + "def add_one_kernel(x_hbm_ref, o_hbm_ref):\n", + " in_shape = x_hbm_ref.shape\n", + " pltpu.emit_pipeline(\n", + " add_one_body,\n", + " grid=(in_shape[0] // 8, in_shape[1] // 128),\n", + " in_specs=[pl.BlockSpec(\n", + " block_shape=(8, 128), index_map=lambda i, j: (i, j),\n", + " )],\n", + " out_specs=[pl.BlockSpec(\n", + " block_shape=(8, 128), index_map=lambda i, j: (i, j),\n", + " )],\n", + " core_axis_name='core',\n", + " dimension_semantics=(pltpu.PARALLEL, pltpu.ARBITRARY),\n", + " )(x_hbm_ref, o_hbm_ref)\n", + "\n", + "\n", + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n", + "def add_one(x):\n", + " return pl.kernel(add_one_kernel, out_shape=x, mesh=tc_mesh, scratch_shapes=[])(x)\n", + "\n", + "\n", + "x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n", + "x = jax.device_put(x, NamedSharding(mesh, in_spec))\n", + "y = add_one(x)\n", + "\n", + "np.testing.assert_array_equal(y, x + 1)" + ], + "metadata": { + "id": "xUMRPLxb1rEH", + "executionInfo": { + "status": "ok", + "timestamp": 1764795550106, + "user_tz": 480, + "elapsed": 518, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Scalar prefetch\n", + "\n", + "The code below extends the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input.\n", + "\n", + "This involves pre-allocating an SMEM buffer (via the `pl.run_scoped` call inside `kernel`) and populating the buffer using a `sync_copy` before the pipeline starts. Close over the dynamic index value inside the `index_map` to use it.\n", + "\n", + "**Manually delegate work amongst cores**\n", + "\n", + "The code example below also shows how `core_map` allows you to customize exactly how the work is split between cores, without relying on the automatic API shown above.\n", + "\n", + "To achieve that, customize your `index_map` to use the core index to work on different slices on different cores.\n" + ], + "metadata": { + "id": "Cq5rYyvL2Tte" + } + }, + { + "cell_type": "code", + "source": [ + "input_shape = (1024, 1024)\n", + "in_spec = jax.P('device', None)\n", + "output_shape = (1024, 512)\n", + "\n", + "def indexed_add_one_kernel(in_refs, out_refs, i_smem_ref):\n", + " (x_hbm_ref, i_hbm_ref), o_hbm_ref = in_refs, out_refs\n", + " in_shape = x_hbm_ref.shape\n", + " pltpu.sync_copy(i_hbm_ref, i_smem_ref)\n", + "\n", + " core_idx = jax.lax.axis_index('core')\n", + " core_slc_size = in_shape[0] // num_cores\n", + " i_map = lambda i: core_idx * core_slc_size // 8 + i # split work among cores\n", + " j_map = lambda j: i_smem_ref[0] // 128 + j # use the prefetched offset\n", + "\n", + " pltpu.emit_pipeline(\n", + " add_one_body,\n", + " grid=(core_slc_size // 8, output_shape[1] // 128),\n", + " in_specs=[pl.BlockSpec(\n", + " block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j_map(j)),\n", + " )],\n", + " out_specs=[pl.BlockSpec(\n", + " block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j),\n", + " )]\n", + " )(x_hbm_ref, o_hbm_ref)\n", + "\n", + "\n", + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh,\n", + " in_specs=(in_spec, jax.P()), out_specs=in_spec, check_vma=False)\n", + "def indexed_add_one(x, index):\n", + " out_shape = jax.ShapeDtypeStruct((x.shape[0], x.shape[1] // 2), x.dtype)\n", + " return pl.kernel(indexed_add_one_kernel,\n", + " out_shape=out_shape, mesh=tc_mesh,\n", + " scratch_shapes=[pltpu.SMEM((1,), jnp.int32)])((x, index))\n", + "\n", + "\n", + "xs = jax.random.normal(jax.random.key(0), input_shape, jnp.float32)\n", + "xs = jax.device_put(xs, NamedSharding(mesh, in_spec))\n", + "idx = 256\n", + "y = indexed_add_one(xs, jnp.array([idx]))\n", + "\n", + "np.testing.assert_array_equal(y, xs[:, idx:(idx+512)] + 1)" + ], + "metadata": { + "id": "SE8pTStHeSWB", + "executionInfo": { + "status": "ok", + "timestamp": 1764795550778, + "user_tz": 480, + "elapsed": 378, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Mapping over SparseCores\n", + "\n", + "TPU v5p contains 4 [SparseCores](https://openxla.org/xla/sparsecore), which are specialized for sparse memory access and operations. This guide will not dive into the full capabilities of SparseCore, but rather show how to run a program on SparseCore with the same semantics and minimal changes from the TensorCore code.\n", + "\n", + "Start with knowing the basic SparseCore specs of your chip, and create a `VectorSubcoreMesh` for vector operations. Note that each SparseCore has 16 (or other number) subcores on TPU v5p, and `core_map` will run your code SPMD on each of them." + ], + "metadata": { + "id": "B8qeo-4A2KRm" + } + }, + { + "cell_type": "code", + "source": [ + "sc_info = pltpu.get_tpu_info().sparse_core\n", + "assert sc_info is not None\n", + "print(sc_info)\n", + "\n", + "sc_mesh = plsc.VectorSubcoreMesh(\n", + " core_axis_name=\"core\", subcore_axis_name=\"subcore\",\n", + " num_cores=sc_info.num_cores\n", + ")\n", + "sc_num_cores = sc_info.num_cores\n", + "sc_num_subcores = sc_info.num_subcores" + ], + "metadata": { + "id": "AHurx-yyYVvs", + "executionInfo": { + "status": "ok", + "timestamp": 1764795551102, + "user_tz": 480, + "elapsed": 55, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + }, + "outputId": "aa4a45da-dd9a-4f57-de1a-bc9b5872b2df" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "The code below is very similar to the `add_one_kernel` we wrote earlier, except for a few differences:\n", + "\n", + "1. You need to split the work amongst all subcores, so a few lines to compute the specific slice for each subcore.\n", + "\n", + "1. SparseCore register computation allows smaller slices (`4x16` max for int32), so you need nested loops to iterate the slice during computation phase." + ], + "metadata": { + "id": "n2_dfsUWFgwU" + } + }, + { + "cell_type": "code", + "source": [ + "input_shape = (4096, 128)\n", + "SC_REG_OP_SHAPE = (4, 16)\n", + "\n", + "def sc_add_one_body(in_vmem, out_vmem):\n", + " @pl.loop(0, in_vmem.shape[0], step=SC_REG_OP_SHAPE[0])\n", + " def _reg_loop_0(c0):\n", + " @pl.loop(0, in_vmem.shape[1], step=SC_REG_OP_SHAPE[1])\n", + " def _reg_loop_1(c1):\n", + " slc = (pl.ds(c0, SC_REG_OP_SHAPE[0]), pl.ds(c1, SC_REG_OP_SHAPE[1]))\n", + " out_vmem[slc] = in_vmem[slc] + 1\n", + "\n", + "\n", + "def sc_add_one_kernel(x_hbm_ref, o_hbm_ref):\n", + " in_shape = x_hbm_ref.shape\n", + " core_idx = jax.lax.axis_index('core')\n", + " subcore_idx = jax.lax.axis_index(\"subcore\")\n", + " cm_idx = core_idx * sc_num_subcores + subcore_idx # index on the core_map\n", + " slc_size = in_shape[0] // (sc_num_subcores * sc_num_cores)\n", + " index_map = lambda i, j: (\n", + " pl.ds(pl.multiple_of(cm_idx * slc_size + i * 8, 8), 8), j)\n", + "\n", + " pltpu.emit_pipeline(\n", + " sc_add_one_body,\n", + " grid=(slc_size // 8, in_shape[1] // 128),\n", + " in_specs=[pl.BlockSpec(\n", + " block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,\n", + " )],\n", + " out_specs=[pl.BlockSpec(\n", + " block_shape=(pl.BoundedSlice(8), 128), index_map=index_map,\n", + " )]\n", + " )(x_hbm_ref, o_hbm_ref)\n", + "\n", + "\n", + "@jax.jit\n", + "@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False)\n", + "def sc_add_one(x):\n", + " return pl.kernel(sc_add_one_kernel, out_shape=x, mesh=sc_mesh, scratch_shapes=[])(x)\n", + "\n", + "\n", + "x = jax.random.randint(jax.random.key(0), input_shape, 0, 64, jnp.int32)\n", + "x = jax.device_put(x, NamedSharding(mesh, in_spec))\n", + "y = sc_add_one(x)\n", + "\n", + "np.testing.assert_array_equal(y, x + 1)" + ], + "metadata": { + "id": "6fNShx6k2kxi", + "executionInfo": { + "status": "ok", + "timestamp": 1764795552411, + "user_tz": 480, + "elapsed": 1117, + "user": { + "displayName": "Ivy Zheng", + "userId": "15297372265856137303" + } + } + }, + "execution_count": 9, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/docs/pallas/tpu/core_map.md b/docs/pallas/tpu/core_map.md new file mode 100644 index 000000000000..4e00399b6fe8 --- /dev/null +++ b/docs/pallas/tpu/core_map.md @@ -0,0 +1,363 @@ +# Pallas Core-specific Programming + +In this guide, we explore using `pl.core_map` to write Pallas kernels. Compared with `pallas_call`, `core_map` offers a few key characteristics: + +* **Per-core level programming**: You write code for an TPU/GPU core, not for a JAX device. This gives you full control over what runs on every core, or how cores communicate and distribute work among one another. + +* **Collectives**: `core_map` explicitly models physical cores, so inter-core communication can be expressed safely. + +* **Platform generic**: `core_map` programming model works for TPU (TensorCore and SparseCore) and GPU with minimal boilerplate changes. + +This guide focuses on TPU. For how to use `core_map` on GPU to achieve higher thread flexibility, check out our [Pallas GPU `core_map` tutorial](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#using-core-map). + +## Environment setup + +Modern accelerators often have multiple cores under a device. For recent TPU chips (v4, v5p), every JAX device may contains 2 TensorCores (aka. a [Megacore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#chips)). Some TPUs (v5p, v6e, 7x) also contain [SparseCores](https://openxla.org/xla/sparsecore#specifications_at_a_glance), each of which consists of many subcores. + +This guide was written on a v5p chip, which contains 4 devices (2 TensorCores each) and 4 SparseCores, each with 16 subcores. + + +```python +from functools import partial + +import jax +from jax.sharding import NamedSharding +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas import tpu_sc as plsc +import jax.numpy as jnp +import numpy as np + + +num_devices = jax.local_device_count() +assert num_devices > 1, "Please run this notebook with more than one device." + +tpu_info = pltpu.get_tpu_info() # This notebook only runs on TPU. +print(f"Running on {num_devices} TPU {tpu_info.chip_version} devices.") +``` + + Running on 4 TPU v5p devices. + + +In addition to the typical TPU device mesh, you need to make a mesh of cores. Consider this as an addition dimension called `core`, with length 2, in addition to the 4-device mesh you work with. That is 8 cores in total. + + +```python +# Mesh of devices +mesh = jax.make_mesh((jax.device_count(),), ('device',)) +print(mesh) + +# Mesh of cores, within a JAX device +tc_mesh = pltpu.create_tensorcore_mesh('core') +print(tc_mesh) + +num_devices = mesh.size +num_cores = len(tc_mesh.devices) +print(f"There are {num_devices} devices, and {num_cores} cores each.") +``` + + Mesh('device': 4, axis_types=(Explicit,)) + TensorCoreMesh(devices=array([TensorCore(id=0), TensorCore(id=1)], dtype=object), axis_names=('core',)) + There are 4 devices, and 2 cores each. + + +## A simple per-core kernel + +`pl.core_map` allows you to write per-core local code, just as `jax.shard_map` allows you to write per-device code. + +In the example kernel below, each core has its own VMEM and semaphore allocations. As with normal kernel, you can initiate copies between HBM and VMEM refs using `pltpu.async_copy`. + +**Communication between cores** + +Before communicating between cores, it is good practice to perform a barrier (using `pltpu.semaphore_signal`) to ensure resources have been allocated and both cores are at the same point during the program. + +Once the cores are synchronized, use `pltpu.make_async_remote_copy` to send data between them. The `device_id` keyword argument generically allows sending to any core on any device, but if you just pass in `{'core': other_core_id}`, it will perform a intra-device inter-core copy (the other axis names are held constant). + + + +```python +# This runs on every core +def swap_cores_kernel(in_hbm, out_hbm, + in_vmem, scratch_vmem, out_vmem, + sem, send_sem, recv_sem): + core_index = jax.lax.axis_index('core') + num_cores = jax.lax.axis_size('core') + slc_size = in_hbm.shape[-1] // num_cores + slc = pl.ds(core_index * slc_size, slc_size) + + # Copy in a core-dependent slice of the input + pltpu.async_copy(in_hbm.at[:, slc], in_vmem, sem).wait() + + # A barrier to make sure all cores have entered run_scoped. + # You won't need this if not doing inter-core communications. + dst_core = (core_index + 1) % num_cores + sem0 = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(sem0, 1, device_id={'core': dst_core}) + pltpu.semaphore_wait(sem0, 1) + + # Swap data between core 0 and core 1 + the_copy = pltpu.make_async_remote_copy( + in_vmem, scratch_vmem, send_sem, recv_sem, device_id={'core': dst_core}, + ) + the_copy.start() + the_copy.wait() + + # Core-local compute + out_vmem[...] = scratch_vmem[...] * 2 + + # Copy out the output + pltpu.async_copy(out_vmem, out_hbm.at[:, slc], sem).wait() + +``` + +Once you have the local kernel: + + * Start your top-level JAX code with HBM refs, and allocate output refs if needed. + + * Use `pl.core_map`, which takes the TensorCore mesh, to start per-core programming. + + * You will need `collective_id` for the barrier semaphore. + + * Inside `pl.core_map`, invoke `pl.run_scoped` to allocate per-core scratch spaces (VMEM and semaphores) and run the local kernel. + + +```python +input_shape = (32, 256) +local_vmem_shape = (32 // num_devices, 256 // num_cores) +in_spec = jax.P('device', None) +sharding = NamedSharding(mesh, in_spec) + +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, + check_vma=False) +def swap_cores(x): + # Get buffers out of the input and output + x_hbm_ref = jax.new_ref(x) + o_hbm_ref = jax.new_ref(jax.lax.empty(x.shape, x.dtype)) + + @pl.core_map(tc_mesh, compiler_params=pltpu.CompilerParams(collective_id=0)) + def _(): + pl.run_scoped( + partial(swap_cores_kernel, x_hbm_ref, o_hbm_ref), + *([pltpu.VMEM(local_vmem_shape, x.dtype)] * 3), # VMEM allocations + *([pltpu.SemaphoreType.DMA] * 3), # semaphores + ) + return o_hbm_ref[...] + + +x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32) +x = jax.device_put(x, sharding) +y = swap_cores(x) + +np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2) +np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2) +``` + +### Save the boilerplate + +You can use the `pl.kernel` decorator to wrap boilerplate such as `core_map`, `run_scoped`, and output buffer allocation. + +Note that this should run inside any `jax.shard_map` you may have at the top level. + + +```python +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False) +def swap_cores(x): + scratch_shapes = [pltpu.VMEM(local_vmem_shape, x.dtype)] * 3 + [pltpu.SemaphoreType.DMA] * 3 + return pl.kernel(swap_cores_kernel, out_shape=x, mesh=tc_mesh, + scratch_shapes=scratch_shapes, + compiler_params=pltpu.CompilerParams(collective_id=0))(x) + +y = swap_cores(x) +np.testing.assert_array_equal(y[:, 128:], x[:, :128] * 2) +np.testing.assert_array_equal(y[:, :128], x[:, 128:] * 2) +``` + +## Pipelining with `core_map` + +Note that the kernel above only does simple copies and compute, without automatic pipelining via Pallas `grid` and `BlockSpec`. To do pipelining inside `core_map`, use `pltpu.emit_pipeline` inside the core-local kernel. + +**Automatically parallelize work amongst cores** + +The simple way is to annotate a block axis as `pltpu.PARALLEL`, and Pallas will automatically parallelize work along this axis. Both `pl.pallas_call` and `pltpu.emit_pipeline` supports this, via arguments `core_axis` and `dimension_semantics`. The `pallas_call` example is [in another guide](https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration), and the `emit_pipeline` case is shown below. + +When the `PARALLEL` annotation is provided, the corresponding grid dimension will be logically split and executed on separate cores. (The exact semantics of which grid dimensions are executed on which core is guaranteed). + +**Scratch shapes allocation** + +Note that in the example below, the top level `pl.run_scoped` (wrapped inside `kernel`) did not allocate any VMEM scratch buffers. Instead, `pltpu.emit_pipeline` allocates its own scratch buffers in VMEM and use them for its multiple buffering. + + + +```python +def add_one_body(in_vmem, out_vmem): + out_vmem[...] = in_vmem[...] + 1 + +input_shape = (1024, 1024) +in_spec = jax.P('device', None) + +def add_one_kernel(x_hbm_ref, o_hbm_ref): + in_shape = x_hbm_ref.shape + pltpu.emit_pipeline( + add_one_body, + grid=(in_shape[0] // 8, in_shape[1] // 128), + in_specs=[pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (i, j), + )], + out_specs=[pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (i, j), + )], + core_axis_name='core', + dimension_semantics=(pltpu.PARALLEL, pltpu.ARBITRARY), + )(x_hbm_ref, o_hbm_ref) + + +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False) +def add_one(x): + return pl.kernel(add_one_kernel, out_shape=x, mesh=tc_mesh, scratch_shapes=[])(x) + + +x = jax.random.normal(jax.random.key(0), input_shape, jnp.float32) +x = jax.device_put(x, NamedSharding(mesh, in_spec)) +y = add_one(x) + +np.testing.assert_array_equal(y, x + 1) +``` + +## Scalar prefetch + +The code below extends the kernel above but uses [scalar prefetch and dynamic block indexing](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html) to select a specific sub-slice of the input. + +This involves pre-allocating an SMEM buffer (via the `pl.run_scoped` call inside `kernel`) and populating the buffer using a `sync_copy` before the pipeline starts. Close over the dynamic index value inside the `index_map` to use it. + +**Manually delegate work amongst cores** + +The code example below also shows how `core_map` allows you to customize exactly how the work is split between cores, without relying on the automatic API shown above. + +To achieve that, customize your `index_map` to use the core index to work on different slices on different cores. + + + +```python +input_shape = (1024, 1024) +in_spec = jax.P('device', None) +output_shape = (1024, 512) + +def indexed_add_one_kernel(in_refs, out_refs, i_smem_ref): + (x_hbm_ref, i_hbm_ref), o_hbm_ref = in_refs, out_refs + in_shape = x_hbm_ref.shape + pltpu.sync_copy(i_hbm_ref, i_smem_ref) + + core_idx = jax.lax.axis_index('core') + core_slc_size = in_shape[0] // num_cores + i_map = lambda i: core_idx * core_slc_size // 8 + i # split work among cores + j_map = lambda j: i_smem_ref[0] // 128 + j # use the prefetched offset + + pltpu.emit_pipeline( + add_one_body, + grid=(core_slc_size // 8, output_shape[1] // 128), + in_specs=[pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j_map(j)), + )], + out_specs=[pl.BlockSpec( + block_shape=(8, 128), index_map=lambda i, j: (i_map(i), j), + )] + )(x_hbm_ref, o_hbm_ref) + + +@jax.jit +@partial(jax.shard_map, mesh=mesh, + in_specs=(in_spec, jax.P()), out_specs=in_spec, check_vma=False) +def indexed_add_one(x, index): + out_shape = jax.ShapeDtypeStruct((x.shape[0], x.shape[1] // 2), x.dtype) + return pl.kernel(indexed_add_one_kernel, + out_shape=out_shape, mesh=tc_mesh, + scratch_shapes=[pltpu.SMEM((1,), jnp.int32)])((x, index)) + + +xs = jax.random.normal(jax.random.key(0), input_shape, jnp.float32) +xs = jax.device_put(xs, NamedSharding(mesh, in_spec)) +idx = 256 +y = indexed_add_one(xs, jnp.array([idx])) + +np.testing.assert_array_equal(y, xs[:, idx:(idx+512)] + 1) +``` + +## Mapping over SparseCores + +TPU v5p contains 4 [SparseCores](https://openxla.org/xla/sparsecore), which are specialized for sparse memory access and operations. This guide will not dive into the full capabilities of SparseCore, but rather show how to run a program on SparseCore with the same semantics and minimal changes from the TensorCore code. + +Start with knowing the basic SparseCore specs of your chip, and create a `VectorSubcoreMesh` for vector operations. Note that each SparseCore has 16 (or other number) subcores on TPU v5p, and `core_map` will run your code SPMD on each of them. + + +```python +sc_info = pltpu.get_tpu_info().sparse_core +assert sc_info is not None +print(sc_info) + +sc_mesh = plsc.VectorSubcoreMesh( + core_axis_name="core", subcore_axis_name="subcore", + num_cores=sc_info.num_cores +) +sc_num_cores = sc_info.num_cores +sc_num_subcores = sc_info.num_subcores +``` + + SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8) + + +The code below is very similar to the `add_one_kernel` we wrote earlier, except for a few differences: + +1. You need to split the work amongst all subcores, so a few lines to compute the specific slice for each subcore. + +1. SparseCore register computation allows smaller slices (`4x16` max for int32), so you need nested loops to iterate the slice during computation phase. + + +```python +input_shape = (4096, 128) +SC_REG_OP_SHAPE = (4, 16) + +def sc_add_one_body(in_vmem, out_vmem): + @pl.loop(0, in_vmem.shape[0], step=SC_REG_OP_SHAPE[0]) + def _reg_loop_0(c0): + @pl.loop(0, in_vmem.shape[1], step=SC_REG_OP_SHAPE[1]) + def _reg_loop_1(c1): + slc = (pl.ds(c0, SC_REG_OP_SHAPE[0]), pl.ds(c1, SC_REG_OP_SHAPE[1])) + out_vmem[slc] = in_vmem[slc] + 1 + + +def sc_add_one_kernel(x_hbm_ref, o_hbm_ref): + in_shape = x_hbm_ref.shape + core_idx = jax.lax.axis_index('core') + subcore_idx = jax.lax.axis_index("subcore") + cm_idx = core_idx * sc_num_subcores + subcore_idx # index on the core_map + slc_size = in_shape[0] // (sc_num_subcores * sc_num_cores) + index_map = lambda i, j: ( + pl.ds(pl.multiple_of(cm_idx * slc_size + i * 8, 8), 8), j) + + pltpu.emit_pipeline( + sc_add_one_body, + grid=(slc_size // 8, in_shape[1] // 128), + in_specs=[pl.BlockSpec( + block_shape=(pl.BoundedSlice(8), 128), index_map=index_map, + )], + out_specs=[pl.BlockSpec( + block_shape=(pl.BoundedSlice(8), 128), index_map=index_map, + )] + )(x_hbm_ref, o_hbm_ref) + + +@jax.jit +@partial(jax.shard_map, mesh=mesh, in_specs=in_spec, out_specs=in_spec, check_vma=False) +def sc_add_one(x): + return pl.kernel(sc_add_one_kernel, out_shape=x, mesh=sc_mesh, scratch_shapes=[])(x) + + +x = jax.random.randint(jax.random.key(0), input_shape, 0, 64, jnp.int32) +x = jax.device_put(x, NamedSharding(mesh, in_spec)) +y = sc_add_one(x) + +np.testing.assert_array_equal(y, x + 1) +``` diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst index 1aca99f0dda9..784e037bca06 100644 --- a/docs/pallas/tpu/index.rst +++ b/docs/pallas/tpu/index.rst @@ -11,5 +11,6 @@ TPU specific documentation. matmul sparse distributed + core_map prng From c328f0625ef71aac2819e8aba2f0c7cad2122115 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 3 Dec 2025 21:34:33 +0000 Subject: [PATCH 025/315] improve error message and hasattr performance for Tracer.sharding Co-authored-by: Yash Katariya --- jax/_src/core.py | 4 ++-- tests/api_test.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 7950f281f33a..6da26eb1e2a5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1064,8 +1064,8 @@ def __getattr__(self, name): if name == 'sharding': raise AttributeError( - f"The 'sharding' attribute is not available on {self._error_repr()}." - f"{self._origin_msg()}") + f"The 'sharding' attribute is not available on {self._error_repr()}. " + "To query sharding information on tracers, use `jax.typeof(x)`.") try: attr = getattr(self.aval, name) diff --git a/tests/api_test.py b/tests/api_test.py index f479739a1042..8dde8b47d6c2 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5374,6 +5374,14 @@ def f(x): jax.vmap(f)(jnp.arange(3.)) # don't crash + def test_sharding_attr_on_tracer_error(self): + @jax.jit + def f(x): + with self.assertRaisesRegex(AttributeError, 'typeof'): + x.sharding + + f(jnp.arange(2.)) + class RematTest(jtu.JaxTestCase): From 92c283e916dd4da19c25c3d3d3317fbb2921ebf8 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 3 Dec 2025 14:09:29 -0800 Subject: [PATCH 026/315] BUILD fixes. PiperOrigin-RevId: 839908432 --- tests/BUILD | 2 +- tests/pallas/BUILD | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index 11966beb37b9..73d18edcbe77 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -271,7 +271,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "extend_test", srcs = ["extend_test.py"], - deps = ["//jax:extend"] + py_deps("absl/testing"), + deps = ["//jax/extend"] + py_deps("absl/testing"), ) jax_multiplatform_test( diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 2d934e329e36..a32ea7dace43 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -425,10 +425,10 @@ jax_multiplatform_test( ], shard_count = 8, deps = [ - "//jax:extend", "//jax:mesh_utils", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/extend", ] + py_deps([ "absl/testing", "numpy", @@ -453,9 +453,9 @@ jax_multiplatform_test( ), shard_count = 8, deps = [ - "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/extend", ] + py_deps([ "absl/testing", "numpy", @@ -483,10 +483,10 @@ jax_multiplatform_test( "multiaccelerator", ], deps = [ - "//jax:extend", "//jax:pallas_mosaic_gpu", "//jax/_src:test_multiprocess", "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/extend", ] + py_deps([ "portpicker", "absl/testing", @@ -522,10 +522,10 @@ jax_multiplatform_test( enable_backends = ["tpu"], tags = ["multiaccelerator"], deps = [ - "//jax:extend", "//jax:mesh_utils", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/extend", ] + py_deps([ "absl/testing", "numpy", @@ -556,10 +556,10 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:extend", "//jax:mesh_utils", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/extend", ] + py_deps([ "absl/testing", "numpy", @@ -623,8 +623,8 @@ jax_multiplatform_test( "notsan", ], deps = [ - "//jax:extend", "//jax:pallas_tpu", + "//jax/extend", ] + py_deps([ "absl/testing", "numpy", @@ -773,9 +773,9 @@ jax_multiplatform_test( ], shard_count = 10, deps = [ - "//jax:extend", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/extend", ] + py_deps([ "absl/testing", "numpy", @@ -1260,10 +1260,10 @@ jax_multiplatform_test( "multiaccelerator", ], deps = [ - "//jax:extend", "//jax:pallas_mosaic_gpu", "//jax/_src:test_multiprocess", "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/extend", ] + py_deps([ "portpicker", "absl/testing", From 19ab33dd20d2da782190d885f375ce6ce89767ad Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 3 Dec 2025 14:19:56 -0800 Subject: [PATCH 027/315] Reverts 23cd412cebf6ac17b27b3d0699aced5c9002ae89 PiperOrigin-RevId: 839912705 --- jax/_src/interpreters/partial_eval.py | 4 +- jax/_src/lax/control_flow/__init__.py | 3 + jax/_src/lax/control_flow/common.py | 78 +++++++++++++++-------- jax/_src/lax/control_flow/conditionals.py | 64 +++++++++++++++---- jax/_src/lax/control_flow/loops.py | 9 +-- jax/_src/lax/control_flow/solves.py | 16 ++--- tests/api_test.py | 10 +-- tests/core_test.py | 4 ++ tests/lax_control_flow_test.py | 58 ++++++++++------- tests/metadata_test.py | 2 +- 10 files changed, 163 insertions(+), 85 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 757718104bc2..a3dbfc55714f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2400,7 +2400,7 @@ def trace_to_jaxpr( in_tree: PyTreeDef, in_avals_flat: Sequence[AbstractValue | core.AvalQDD], debug_info: core.DebugInfo -) -> tuple[ClosedJaxpr, PyTreeDef, list[Any]]: +) -> tuple[Jaxpr, PyTreeDef, list[Any]]: config.enable_checks.value and debug_info.assert_arg_names(len(in_avals_flat)) parent_trace = core.trace_ctx.trace trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace) @@ -2424,8 +2424,6 @@ def trace_to_jaxpr( del trace, fun, in_tracers_flat, in_tracers, out_tracers, ans, ans_flat config.enable_checks.value and core.check_jaxpr(jaxpr) - # TODO(dougalm): remove this once we merge Jaxpr and ClosedJaxpr - jaxpr = close_jaxpr(convert_constvars_jaxpr(jaxpr)) return jaxpr, out_tree, consts diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index 5cbe5a39d381..44ee94e14ca2 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -50,6 +50,9 @@ # Private utilities used elsewhere in JAX # TODO(sharadmv): lift them into a more common place from jax._src.lax.control_flow.common import ( + _initial_style_open_jaxpr as _initial_style_open_jaxpr, + _initial_style_jaxpr as _initial_style_jaxpr, + _initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts, _check_tree_and_avals as _check_tree_and_avals, ) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index 3d78d193ddc3..d29746560b46 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -43,54 +43,78 @@ def _typecheck_param(prim, param, name, msg_required, pred): msg = sep.join([msg, param_str]) raise core.JaxprTypeError(msg) -# TODO(dougalm): this seems way too complicated. Why not allow different consts for each -# branch of a switch? -def _merge_common_consts( - jaxprs: Sequence[core.Jaxpr], - all_consts: Sequence[Sequence[Any]] - ) -> tuple[Sequence[core.ClosedJaxpr], Sequence[Any]]: +# TODO(dougalm): this is a silly wrapper now. Delete it. +@weakref_lru_cache +def _initial_style_open_jaxpr(fun: Callable, + in_tree: PyTreeDef, + in_avals: Sequence[core.AbstractValue | core.AvalQDD], + debug_info: core.DebugInfo): + jaxpr, out_tree, consts = pe.trace_to_jaxpr(fun, in_tree, in_avals, debug_info) + return jaxpr, consts, out_tree + +# TODO(dougalm): Delete. Make `trace_to_jaxpr` do the jaxpr-closing thing instead. +@weakref_lru_cache +def _initial_style_jaxpr(fun: Callable, + in_tree: PyTreeDef, + in_avals: Sequence[core.AbstractValue], + debug_info: core.DebugInfo) -> tuple[core.ClosedJaxpr, Sequence[Any], PyTreeDef]: + jaxpr, consts, out_tree = _initial_style_open_jaxpr( + fun, in_tree, in_avals, debug_info) + closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) + return closed_jaxpr, consts, out_tree + +def _initial_style_jaxprs_with_common_consts( + funs: Sequence[Callable], + in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue | core.AvalQDD], + debug_infos: Sequence[core.DebugInfo]): + jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, debug_info) + for fn, debug_info in zip(funs, debug_infos)] + if not jaxpr_data: return [], [], [] + jaxprs, all_consts, all_out_trees = zip(*jaxpr_data) + # Jaxprs must share consts, so we concat consts and pad the jaxprs' constvars. lens = map(len, all_consts) consts = [c for cs in all_consts for c in cs] avalqdds = tuple(map(core.cur_aval_qdd, consts)) - num_constss = [len(cs) for cs in all_consts] - jaxprs = [_pad_constvars(jaxpr, num_consts, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):]) - for i, (jaxpr, num_consts) in enumerate(zip(jaxprs, num_constss))] + jaxprs = [_pad_constvars(jaxpr, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):]) + for i, jaxpr in enumerate(jaxprs)] # De-duplicate shared constants. const_ids = tuple(id(c) for c in consts) seen = set() - dd_consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore - jaxprs = [_dedup_consts(jaxpr, len(consts), const_ids) for jaxpr in jaxprs] - return jaxprs, dd_consts + consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore + jaxprs = [_dedup_consts(jaxpr, const_ids) for jaxpr in jaxprs] + + closed_jaxprs = [pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) + for jaxpr in jaxprs] + return closed_jaxprs, consts, all_out_trees @weakref_lru_cache -def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int, - left: tuple[core.AvalQDD, ...], - right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr: +def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...], + right: tuple[core.AbstractValue, ...]) -> core.Jaxpr: def make_var(aq): return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd) - invars = [*map(make_var, left), *jaxpr.invars[:num_consts], - *map(make_var, right), *jaxpr.invars[num_consts:]] - effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects) - jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs)) + constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)] + effs = pe._renumber_effects([*constvars, *jaxpr.invars], + [*jaxpr.constvars, *jaxpr.invars], jaxpr.effects) + jaxpr = jaxpr.replace(constvars=constvars, effects=effs) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr @weakref_lru_cache -def _dedup_consts(jaxpr, num_consts, const_ids): +def _dedup_consts(jaxpr, const_ids): newvars = {} canonicalize = {v: newvars.setdefault(constid, v) - for constid, v in zip(const_ids, jaxpr.invars[:num_consts])} + for constid, v in zip(const_ids, jaxpr.constvars)} eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var) else x for x in e.invars]) for e in jaxpr.eqns] outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x for x in jaxpr.outvars] - invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]] - effs = pe._renumber_effects(invars, - [*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]], - jaxpr.effects) - jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars, - effects=effs)) + constvars = list(newvars.values()) + effs = pe._renumber_effects( + [*constvars, *jaxpr.invars], + [*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects) + jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars, + effects=effs) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index eab8b92b742d..668b880ab4dd 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -53,7 +53,7 @@ import numpy as np from jax._src.lax.control_flow.common import ( - _avals_short, _typecheck_param, _merge_common_consts, + _avals_short, _typecheck_param, _initial_style_jaxprs_with_common_consts, _make_closed_jaxpr, _prune_zeros) map, unsafe_map = safe_map, map @@ -149,10 +149,8 @@ def _switch_internal( if config.mutable_array_checks.value: api_util.check_no_aliased_ref_args(lambda: dbgs[0], ops_avals, ops) - jaxprs_, out_trees, all_consts = zip(*[pe.trace_to_jaxpr( - branch, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbgs)]) - jaxprs, consts = _merge_common_consts(jaxprs_, all_consts) - + jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( + branches, ops_tree, ops_avals, dbgs) if config.mutable_array_checks.value: api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops) for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): @@ -186,7 +184,7 @@ def _switch_internal( return tree_unflatten(out_trees[0], out) @partial(api_boundary, repro_api_name="jax_cond") -def cond(pred, true_fun: Callable, false_fun: Callable, *operands, +def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, operand=_no_operand_sentinel): """Conditionally apply ``true_fun`` or ``false_fun``. @@ -272,16 +270,14 @@ def cond(pred, true_fun, false_fun, *operands): if config.mutable_array_checks.value: api_util.check_no_aliased_ref_args(lambda: dbg_true_fun, ops_avals, ops) dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {}) - - true_jaxpr_, out_tree, true_consts = pe.trace_to_jaxpr( - true_fun, ops_tree, ops_avals, dbg_true_fun) - false_jaxpr_, false_out_tree, false_consts = pe.trace_to_jaxpr( - false_fun, ops_tree, ops_avals, dbg_false_fun) - (true_jaxpr, false_jaxpr), consts = _merge_common_consts( - (true_jaxpr_, false_jaxpr_), (true_consts, false_consts)) + jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( + (true_fun, false_fun), ops_tree, ops_avals, + [dbg_true_fun, dbg_false_fun]) + true_jaxpr, false_jaxpr = jaxprs if config.mutable_array_checks.value: api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops) + out_tree, false_out_tree = out_trees if any(isinstance(out_aval, AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals): raise ValueError("Cannot return `Ref`s from `cond`.") @@ -403,6 +399,48 @@ def _capitalize(s): # s.capitalize() converts s[1:] to lowercase which we don't want. return s[0].capitalize() + s[1:] +@api_boundary +@functools.wraps(_cond) +def cond(*args, **kwargs): + # detect an attempt to call the former, deprecated cond + try: + ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs) + except TypeError: + pass + else: + assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch + _, true_operand, true_fun, false_operand, false_fun = ba.args + if callable(true_operand) and callable(true_fun): + # treat this as modern cond (with two operands) + return _cond(*args, **kwargs) + if callable(true_fun) and callable(false_fun): + return _cond_with_per_branch_args(*ba.args) + + return _cond(*args, **kwargs) + +@partial(api_boundary, repro_api_name="jax_cond_with_per_branch_args") +def _cond_with_per_branch_args(pred, + true_operand, true_fun: Callable, + false_operand, false_fun: Callable): + """Conditionally apply ``true_fun`` or ``false_fun``. + + Has equivalent semantics to this Python implementation:: + + def cond(pred, true_operand, true_fun, false_operand, false_fun): + if pred: + return true_fun(true_operand) + else: + return false_fun(false_operand) + + Pred has to be a scalar type, collection types (list, tuple) are not supported + """ + if not (callable(true_fun) and callable(false_fun)): + raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") + return _cond(pred, + lambda op: true_fun(op[0]), + lambda op: false_fun(op[1]), + (true_operand, false_operand)) + def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects: joined_effects = set() for b in branches: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index dab448cde147..d5e31b0ca0a2 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -50,7 +50,7 @@ from jax._src.lax import slicing from jax._src.lax import windowed_reductions from jax._src.lax.control_flow.common import ( - _avals_short, _prune_zeros, _typecheck_param, + _avals_short, _initial_style_jaxpr, _prune_zeros, _typecheck_param, _make_closed_jaxpr) from jax._src.lax.other import logaddexp from jax._src.pjit import auto_axes, PartitionSpec as P, reshard @@ -281,8 +281,9 @@ def _create_jaxpr(init): init_flat, init_tree = tree_flatten(init) in_flat, in_tree = tree_flatten((init, xs)) carry_avals = tuple(_map(core.get_aval, init_flat)) - jaxpr, out_tree, consts = pe.trace_to_jaxpr( + open_jaxpr, out_tree, consts = pe.trace_to_jaxpr( f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body) + jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(open_jaxpr)) if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat) out_tree_children = out_tree.children() @@ -1711,10 +1712,10 @@ def _create_jaxpr(init_val): init_vals, in_tree = tree_flatten((init_val,)) init_avals = tuple(_map(core.get_aval, init_vals)) cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) - cond_jaxpr, cond_tree, cond_consts = pe.trace_to_jaxpr( + cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr( cond_fun, in_tree, init_avals, cond_dbg) body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) - body_jaxpr, body_tree, body_consts = pe.trace_to_jaxpr( + body_jaxpr, body_consts, body_tree = _initial_style_jaxpr( body_fun, in_tree, init_avals, body_dbg) if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index e65f0cda1480..17ed44b69991 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -27,7 +27,6 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.traceback_util import api_boundary from jax._src.tree_util import (tree_flatten, treedef_children, tree_leaves, @@ -37,6 +36,7 @@ from jax._src.lax.control_flow.common import ( _check_tree, + _initial_style_jaxpr, ) _map = safe_map @@ -95,7 +95,7 @@ def custom_root(f: Callable, guess_flat, in_args_tree = tree_flatten((initial_guess,)) guess_avals = tuple(_map(core.get_aval, guess_flat)) f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {}) - f_jaxpr, out_tree, f_consts = pe.trace_to_jaxpr( + f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( f, in_args_tree, guess_avals, f_debug) in_tree, = treedef_children(in_args_tree) @@ -104,7 +104,7 @@ def custom_root(f: Callable, solve_debug = api_util.debug_info("custom_root solve", solve, (f, initial_guess), {}, static_argnums=(0,)) - solve_jaxpr, solution_tree, solve_consts = pe.trace_to_jaxpr( + solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( partial(solve, f), in_args_tree, guess_avals, solve_debug) _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) @@ -114,7 +114,7 @@ def linearize_and_solve(x, b): linearize_and_solve_dbg = api_util.debug_info("custom_root tangent_solve", tangent_solve, (initial_guess, initial_guess), {}) - l_and_s_jaxpr, out_tree, l_and_s_consts = pe.trace_to_jaxpr( + l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2, linearize_and_solve_dbg) _check_tree("tangent_solve", "x", out_tree, in_tree, False) @@ -268,7 +268,7 @@ def f_aux(x): matvec_debug = api_util.debug_info("custom_linear_solve", matvec, (b,), {}) # no auxiliary data assumed for matvec - matvec_jaxpr, out_tree, matvec_consts = pe.trace_to_jaxpr( + matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, matvec_debug) _check_tree("matvec", "b", out_tree, tree, False) @@ -276,7 +276,7 @@ def f_aux(x): solve_debug = api_util.debug_info("custom_linear_solve solve", solve, (matvec, b), {}, static_argnums=(0,)) - solve_jaxpr, out_tree, solve_consts = pe.trace_to_jaxpr( + solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, solve_debug) _check_tree("solve", "b", out_tree, tree, has_aux) @@ -294,11 +294,11 @@ def f_aux(x): vecmat_consts = matvec_consts else: vecmat = _transpose_one_output(matvec, b) - vecmat_jaxpr, out_tree, vecmat_consts = pe.trace_to_jaxpr( + vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr( vecmat, in_args_tree, b_avals, transpose_solve_debug) assert out_tree == tree - tr_solve_jaxpr, out_tree, tr_solve_consts = pe.trace_to_jaxpr( + tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr( _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux), in_args_tree, b_avals, transpose_solve_debug) _check_tree("transpose_solve", "b", out_tree, tree, has_aux) diff --git a/tests/api_test.py b/tests/api_test.py index 8dde8b47d6c2..c7c4dda3a6be 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7215,10 +7215,10 @@ def fun(x): def test_cond(self): def f(x): return lax.cond(x >= 0., - lambda xt, _: xt + x, - lambda _, xf: xf - x, x + 1., - x + 2.) + lambda xt: xt + x, + x + 2., + lambda xf: xf - x) expected = """{ lambda ; a:f32[]. let b:bool[] = ge a 0.0:f32[] c:f32[] = add a 1.0:f32[] @@ -7949,10 +7949,10 @@ def f(c, x): jax.lax.scan(f, 0, jnp.arange(4)) def test_cond_traceback(self): - if sys.version_info < (3, 13): + if sys.version_info < (3, 14): # Fails because 3.11 adds an extra stack frame due to a list comprehension self.skipTest("Expected failure.") - expected_depth = 4 + expected_depth = 8 init_depth = self.cur_depth() def f(): diff --git a/tests/core_test.py b/tests/core_test.py index a2b3da15fd2c..c7cf4918f8e5 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -448,11 +448,15 @@ class JaxprTypeChecks(jtu.JaxTestCase): def setUp(self): super().setUp() + lax_control_flow._initial_style_open_jaxpr.cache_clear() + lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() def tearDown(self): super().tearDown() + lax_control_flow._initial_style_open_jaxpr.cache_clear() + lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 7a3a45a485d9..01c8898749cc 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -53,14 +53,25 @@ # provides a lax.cond-compatible interface to a two-branch lax.switch. Several # tests in this file are parameterized such that they either call into lax.cond # or into this function. -def cond_via_switch(pred, true_fun, false_fun, *args): +def cond_via_switch(pred, true_fun, false_fun, op, *args): + if len(args) > 0: + assert len(args) == 1 + true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0] + op = (false_op, true_op) + false_fun = lambda op: _false_fun(op[0]) + true_fun = lambda op: _true_fun(op[1]) index = lax.convert_element_type(pred, np.int32) - return lax.switch(index, [false_fun, true_fun], *args) - -def cond_with_new_checkpoint(pred, true_fun, false_fun, *args): + return lax.switch(index, [false_fun, true_fun], op) + +def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args): + if args: + true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0] + op = (false_op, true_op) + false_fun = lambda op: _false_fun(op[0]) + true_fun = lambda op: _true_fun(op[1]) index = lax.convert_element_type(pred, np.int32) - fn = lambda index, *args: lax.switch(index, [false_fun, true_fun], *args) - return jax.checkpoint(fn)(index, *args) + fn = lambda index, op: lax.switch(index, [false_fun, true_fun], op) + return jax.checkpoint(fn)(index, op) COND_IMPLS = [ (lax.cond, 'cond'), @@ -160,6 +171,8 @@ class LaxControlFlowTest(jtu.JaxTestCase): def setUp(self): super().setUp() + lax_control_flow._initial_style_open_jaxpr.cache_clear() + lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() @@ -987,8 +1000,8 @@ def cfun(x): lax.lt(x, 2), lambda x: lax.mul(2, x), lambda x: cond(lax.lt(x, 5), - lambda x, _: lax.mul(3, x), - lambda _, y: lax.mul(y, x), x, 4), + x, lambda x: lax.mul(3, x), + 4, lambda y: lax.mul(y, x)), x) self.assertEqual(cfun(1), 2) @@ -1108,9 +1121,9 @@ def cfun(x): def testCondBatched(self): def fun(x, y, z): pred = lax.lt(x, 3) - true_fun = lambda y, _: y - false_fun = lambda _, z: lax.neg(z) - return lax.cond(pred, true_fun, false_fun, y, z) + true_fun = lambda y: y + false_fun = lambda z: lax.neg(z) + return lax.cond(pred, y, true_fun, z, false_fun) # these cases stay as cond x = jnp.array(2) @@ -1274,7 +1287,7 @@ def fun_ref(x): return 2. * x def fun(x): - return cond(x < 3, lambda _: 2., lambda x: 2. * x, x) + return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x) x = 3.14 ans = jax.jvp(fun, (x,), (x,)) @@ -1432,7 +1445,7 @@ def fun_ref(x): return 2. * x def fun(x): - return cond(x < 3, lambda _: 2., lambda x: 2. * x, x) + return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x) x = 3.14 ans = jax.grad(fun)(x) @@ -1462,9 +1475,8 @@ def fun_ref(x, y): def fun(x, y): return cond( x < 3, - lambda _: 2. * jnp.sin(y), - lambda x: 2. * x, - x) + None, lambda _: 2. * jnp.sin(y), + x, lambda x: 2. * x) y = 5.8 x = 3.14 @@ -1653,7 +1665,7 @@ def g(x): return jnp.where(x > 0, f_1(x), f_2(x)) def testIssue1263(self): def f(rng, x): cond = random.bernoulli(rng) - return lax.cond(cond, lambda x, _: x, lambda _, x: x, x, jnp.abs(x) - 1.) + return lax.cond(cond, x, lambda x: x, jnp.abs(x) - 1., lambda x: x) def body_fn(i, state): rng, x = state @@ -1668,9 +1680,8 @@ def g(rng, x): def testIssue514(self): # just check this doesn't crash lax.cond(True, - lambda x, _: (x[0], 0), - lambda _, x: x, - (0, 0), (1, 1)) + (0, 0), lambda x: (x[0], 0), + (1, 1), lambda x: x) def testIssue649(self): from jax import lax @@ -2377,9 +2388,8 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - lambda x, _: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), - lambda _, x: x, - x, 1.) + x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), + 1., lambda x: x) elif loop == "fori_inside_scan": func = lambda x: lax.scan( lambda c, x: (lax.fori_loop(x, x + 2., lambda i, c1: c1 * c, x), None), @@ -2551,7 +2561,7 @@ def f(h, _): def test_disable_jit_cond_with_vmap(self): # https://github.com/jax-ml/jax/issues/3093 def fn(t): - return lax.cond(t > 0, lambda x, _: 0, lambda _, x: 1, 0, 0) + return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1) fn = jax.vmap(fn) with jax.disable_jit(): diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 524768aaed87..917cf7bf5133 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -79,7 +79,7 @@ def true_fun(x): def false_fun(x): return jnp.cos(x) def f(which, x): - return jax.lax.cond(which, true_fun, false_fun, x) + return jax.lax.cond(which, x, true_fun, x, false_fun) hlo = module_to_string(jax.jit(f).lower(True, 1.).compiler_ir()) self.assertRegex(hlo, r'loc\(".*cond/branch_0_fun/cos"') self.assertRegex(hlo, r'loc\(".*cond/branch_1_fun/sin"') From 245b2c54eb633b61043a9d9a8e62d200706bc892 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Wed, 3 Dec 2025 14:44:18 -0800 Subject: [PATCH 028/315] Remove CUDA12 tests with jaxlib built from HEAD from the mandatory presubmit. PiperOrigin-RevId: 839923210 --- .github/workflows/bazel_cuda_presubmit.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/bazel_cuda_presubmit.yml b/.github/workflows/bazel_cuda_presubmit.yml index 26124c37bd58..48eb4495f2a9 100644 --- a/.github/workflows/bazel_cuda_presubmit.yml +++ b/.github/workflows/bazel_cuda_presubmit.yml @@ -73,6 +73,9 @@ jobs: - python: "3.14" enable-x64: 1 jaxlib-version: "pypi_latest" + # Exclude CUDA 12 on jaxlib head because it's too slow. + - cuda-version: "12" + jaxlib-version: "head" name: "Bazel single accelerator ${{ format('{0}', 'CUDA tests') }}" # End Presubmit Naming Check github-cuda-presubmits with: From ec1359573aed686e9c5d5f0c6b21dc36e159ed65 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 3 Dec 2025 22:05:24 +0000 Subject: [PATCH 029/315] fix printing of unreduced Arrays Co-authored-by: Yash Katariya --- jax/_src/array.py | 27 +++++++++++++++++---------- tests/array_test.py | 7 +++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 13ca89fb25d7..9c7dc1738343 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -44,7 +44,7 @@ from jax._src.sharding import Sharding from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, + PmapSharding, SingleDeviceSharding, NamedSharding, device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape, _internal_use_concrete_mesh) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike, ExtendedDType @@ -284,9 +284,6 @@ def weak_type(self): def committed(self) -> bool: return self._committed - def __str__(self): - return str(self._value) - def __len__(self): try: return self.shape[0] @@ -394,11 +391,13 @@ def is_fully_replicated(self) -> bool: def __repr__(self): prefix = 'Array(' if self.aval is not None and self.aval.weak_type: - dtype_str = f'dtype={self.dtype.name}, weak_type=True)' + dtype_str = f'dtype={self.dtype.name}, weak_type=True' else: - dtype_str = f'dtype={self.dtype.name})' + dtype_str = f'dtype={self.dtype.name}' - if self.is_fully_addressable or self.is_fully_replicated: + if isinstance(self.sharding, NamedSharding) and self.sharding.spec.unreduced: + return f"Array(shape={self.shape}, {dtype_str}, sharding={self.sharding})" + elif self.is_fully_addressable or self.is_fully_replicated: line_width = np.get_printoptions()["linewidth"] if self.size == 0: s = f"[], shape={self.shape}" @@ -409,11 +408,19 @@ def __repr__(self): separator=', ', max_line_width=line_width) last_line_len = len(s) - s.rfind('\n') + 1 sep = ' ' - if last_line_len + len(dtype_str) + 1 > line_width: + if last_line_len + len(dtype_str) + 2 > line_width: sep = ' ' * len(prefix) - return f"{prefix}{s},{sep}{dtype_str}" + return f"{prefix}{s},{sep}{dtype_str})" + else: + return f"{prefix}shape={self.shape}, {dtype_str})" + + def __str__(self): + if isinstance(self.sharding, NamedSharding) and self.sharding.spec.unreduced: + return repr(self) + elif self.is_fully_addressable or self.is_fully_replicated: + return str(self._value) # doesn't print Array(...) else: - return f"{prefix}shape={self.shape}, {dtype_str}" + return repr(self) @property def is_fully_addressable(self) -> bool: diff --git a/tests/array_test.py b/tests/array_test.py index 025ce31d6c54..50447cf7a3d4 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -882,6 +882,13 @@ def test_make_array_from_single_device_arrays_bad_dtype_error(self): jax.make_array_from_single_device_arrays( shape, s, [arr], dtype=jnp.float32) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_unreduced_printing(self, mesh): + x = jax.device_put(jnp.arange(8., dtype='float32'), P('x')) + x = jax.lax.reduce_sum(x, [0], out_sharding=P(unreduced={'x'})) + self.assertIn('nreduced', str(x.sharding)) + self.assertIn('Array(shape=(), dtype=float32, sharding=', str(x)) + class ShardingTest(jtu.JaxTestCase): From 7bda7a0530d6735da5980d1f0870addafb0d57ce Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 3 Dec 2025 23:09:55 +0000 Subject: [PATCH 030/315] add ensure_compile_time_eval support to custom_jvp/vjp fixes #30787 --- jax/_src/interpreters/partial_eval.py | 6 ++++++ tests/custom_api_test.py | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index a3dbfc55714f..85e5ce37bc5a 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2245,6 +2245,9 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, jvp: lu.WrappedFun, tracers, symbolic_zeros: bool): + if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): + return prim.bind_with_trace(core.eval_trace, (fun, jvp, *tracers), + dict(symbolic_zeros=symbolic_zeros)) source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) tracers = map(to_jaxpr_tracer, tracers) @@ -2279,6 +2282,9 @@ def process_custom_vjp_call(self, prim: core.Primitive, fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): + if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): + return prim.bind_with_trace(core.eval_trace, (fun, fwd, bwd, *tracers), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) tracers = map(to_jaxpr_tracer, tracers) diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 904cfbeeda39..786409de5a6e 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -1476,6 +1476,33 @@ def sin_jvp(primals, tangents): ans, = f_vjp(1.) self.assertAllClose(ans, 1./2, check_dtypes=False) + def test_ensure_compile_time_eval(self): + @jax.custom_jvp + def f(x): + assert x == 0. # concrete! + return x + @f.defjvp + def f_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + assert x == 0. # concrete! + + @jax.jit + def g(): + with jax.ensure_compile_time_eval(): + return f(0.) + + g() # don't crash + + # TODO(mattjj): do we want to support autodiff here too? + # def h(x): + # @jax.jit + # def hh(): + # with jax.ensure_compile_time_eval(): + # return f(x) + # return hh() + + # jax.grad(h)(0.) # don't crash + class CustomVJPTest(jtu.JaxTestCase): From 4cbcbe7eb0d190443195c23838d9cb4035473530 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 2 Dec 2025 18:20:41 -0500 Subject: [PATCH 031/315] [no-thunks] Deprecate old form of `cond` and remove some tracing layers. This reduces the stack depth in `cond` and exposes fewer internals. --- jax/_src/interpreters/partial_eval.py | 10 ++- jax/_src/lax/control_flow/__init__.py | 3 - jax/_src/lax/control_flow/common.py | 84 ++++++++--------------- jax/_src/lax/control_flow/conditionals.py | 68 +++++------------- jax/_src/lax/control_flow/loops.py | 14 ++-- jax/_src/lax/control_flow/solves.py | 23 ++++--- tests/api_test.py | 10 +-- tests/core_test.py | 4 -- tests/lax_control_flow_test.py | 58 +++++++--------- tests/metadata_test.py | 2 +- 10 files changed, 105 insertions(+), 171 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index a3dbfc55714f..37dbd8dbadfa 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -884,6 +884,11 @@ def move_envvars(jaxpr: Jaxpr, which: tuple[bool, ...]) -> Jaxpr: constvars, envvars = partition_list(which, jaxpr.constvars) return jaxpr.replace(constvars=constvars, invars=[*envvars, *jaxpr.invars]) +@weakref_lru_cache +def separate_consts(jaxpr: ClosedJaxpr) -> tuple[ClosedJaxpr, list[Any]]: + """Moves the constvars to the start of invars and returns the consts explicitly.""" + return ClosedJaxpr(convert_constvars_jaxpr(jaxpr.jaxpr), []), jaxpr.consts + @weakref_lru_cache def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: """Moves the constvars to the start of invars.""" @@ -2400,7 +2405,7 @@ def trace_to_jaxpr( in_tree: PyTreeDef, in_avals_flat: Sequence[AbstractValue | core.AvalQDD], debug_info: core.DebugInfo -) -> tuple[Jaxpr, PyTreeDef, list[Any]]: +) -> tuple[ClosedJaxpr, PyTreeDef]: config.enable_checks.value and debug_info.assert_arg_names(len(in_avals_flat)) parent_trace = core.trace_ctx.trace trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace) @@ -2424,8 +2429,7 @@ def trace_to_jaxpr( del trace, fun, in_tracers_flat, in_tracers, out_tracers, ans, ans_flat config.enable_checks.value and core.check_jaxpr(jaxpr) - return jaxpr, out_tree, consts - + return ClosedJaxpr(jaxpr, consts), out_tree # TODO(dougalm): remove in favor of `trace_to_jaxpr` @profiler.annotate_function diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index 44ee94e14ca2..5cbe5a39d381 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -50,9 +50,6 @@ # Private utilities used elsewhere in JAX # TODO(sharadmv): lift them into a more common place from jax._src.lax.control_flow.common import ( - _initial_style_open_jaxpr as _initial_style_open_jaxpr, - _initial_style_jaxpr as _initial_style_jaxpr, - _initial_style_jaxprs_with_common_consts as _initial_style_jaxprs_with_common_consts, _check_tree_and_avals as _check_tree_and_avals, ) diff --git a/jax/_src/lax/control_flow/common.py b/jax/_src/lax/control_flow/common.py index d29746560b46..9518b4484bd9 100644 --- a/jax/_src/lax/control_flow/common.py +++ b/jax/_src/lax/control_flow/common.py @@ -15,7 +15,7 @@ from __future__ import annotations -from collections.abc import Callable, Sequence +from collections.abc import Sequence import os from functools import partial from typing import Any @@ -27,7 +27,7 @@ from jax._src.util import weakref_lru_cache, safe_map from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (equality_errors_pytreedef, tree_map, - tree_unflatten, keystr, PyTreeDef) + tree_unflatten, keystr) map, unsafe_map = safe_map, map @@ -43,78 +43,54 @@ def _typecheck_param(prim, param, name, msg_required, pred): msg = sep.join([msg, param_str]) raise core.JaxprTypeError(msg) -# TODO(dougalm): this is a silly wrapper now. Delete it. -@weakref_lru_cache -def _initial_style_open_jaxpr(fun: Callable, - in_tree: PyTreeDef, - in_avals: Sequence[core.AbstractValue | core.AvalQDD], - debug_info: core.DebugInfo): - jaxpr, out_tree, consts = pe.trace_to_jaxpr(fun, in_tree, in_avals, debug_info) - return jaxpr, consts, out_tree - -# TODO(dougalm): Delete. Make `trace_to_jaxpr` do the jaxpr-closing thing instead. -@weakref_lru_cache -def _initial_style_jaxpr(fun: Callable, - in_tree: PyTreeDef, - in_avals: Sequence[core.AbstractValue], - debug_info: core.DebugInfo) -> tuple[core.ClosedJaxpr, Sequence[Any], PyTreeDef]: - jaxpr, consts, out_tree = _initial_style_open_jaxpr( - fun, in_tree, in_avals, debug_info) - closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) - return closed_jaxpr, consts, out_tree - -def _initial_style_jaxprs_with_common_consts( - funs: Sequence[Callable], - in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue | core.AvalQDD], - debug_infos: Sequence[core.DebugInfo]): - jaxpr_data = [_initial_style_open_jaxpr(fn, in_tree, in_avals, debug_info) - for fn, debug_info in zip(funs, debug_infos)] - if not jaxpr_data: return [], [], [] - jaxprs, all_consts, all_out_trees = zip(*jaxpr_data) - +# TODO(dougalm): this seems way too complicated. Why not allow different consts for each +# branch of a switch? +def _merge_common_consts( + jaxprs: Sequence[core.ClosedJaxpr], + all_consts: Sequence[Sequence[Any]] + ) -> tuple[Sequence[core.ClosedJaxpr], Sequence[Any]]: # Jaxprs must share consts, so we concat consts and pad the jaxprs' constvars. lens = map(len, all_consts) consts = [c for cs in all_consts for c in cs] avalqdds = tuple(map(core.cur_aval_qdd, consts)) - jaxprs = [_pad_constvars(jaxpr, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):]) - for i, jaxpr in enumerate(jaxprs)] + num_constss = [len(cs) for cs in all_consts] + jaxprs = [_pad_constvars(jaxpr, num_consts, avalqdds[:sum(lens[:i])], avalqdds[sum(lens[:i+1]):]) + for i, (jaxpr, num_consts) in enumerate(zip(jaxprs, num_constss))] # De-duplicate shared constants. const_ids = tuple(id(c) for c in consts) seen = set() - consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore - jaxprs = [_dedup_consts(jaxpr, const_ids) for jaxpr in jaxprs] - - closed_jaxprs = [pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr)) - for jaxpr in jaxprs] - return closed_jaxprs, consts, all_out_trees + dd_consts = [c for c in consts if id(c) not in seen and not seen.add(id(c))] # type: ignore + jaxprs = [_dedup_consts(jaxpr, len(consts), const_ids) for jaxpr in jaxprs] + return jaxprs, dd_consts @weakref_lru_cache -def _pad_constvars(jaxpr: core.Jaxpr, left: tuple[core.AvalQDD, ...], - right: tuple[core.AbstractValue, ...]) -> core.Jaxpr: +def _pad_constvars(jaxpr: core.ClosedJaxpr, num_consts: int, + left: tuple[core.AvalQDD, ...], + right: tuple[core.AbstractValue, ...]) -> core.ClosedJaxpr: def make_var(aq): return core.Var(aq.aval, initial_qdd=aq.qdd, final_qdd=aq.qdd) - constvars = [*map(make_var, left), *jaxpr.constvars, *map(make_var, right)] - effs = pe._renumber_effects([*constvars, *jaxpr.invars], - [*jaxpr.constvars, *jaxpr.invars], jaxpr.effects) - jaxpr = jaxpr.replace(constvars=constvars, effects=effs) - config.enable_checks.value and core.check_jaxpr(jaxpr) + invars = [*map(make_var, left), *jaxpr.invars[:num_consts], + *map(make_var, right), *jaxpr.invars[num_consts:]] + effs = pe._renumber_effects(invars, jaxpr.invars, jaxpr.effects) + jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, effects=effs)) + config.enable_checks.value and core.check_jaxpr(jaxpr.jaxpr) return jaxpr @weakref_lru_cache -def _dedup_consts(jaxpr, const_ids): +def _dedup_consts(jaxpr, num_consts, const_ids): newvars = {} canonicalize = {v: newvars.setdefault(constid, v) - for constid, v in zip(const_ids, jaxpr.constvars)} + for constid, v in zip(const_ids, jaxpr.invars[:num_consts])} eqns = [e.replace(invars=[canonicalize.get(x, x) if isinstance(x, core.Var) else x for x in e.invars]) for e in jaxpr.eqns] outvars = [canonicalize.get(x, x) if isinstance(x, core.Var) else x for x in jaxpr.outvars] - constvars = list(newvars.values()) - effs = pe._renumber_effects( - [*constvars, *jaxpr.invars], - [*map(canonicalize.get, jaxpr.constvars), *jaxpr.invars], jaxpr.effects) - jaxpr = jaxpr.replace(constvars=constvars, eqns=eqns, outvars=outvars, - effects=effs) + invars = [*list(newvars.values()), *jaxpr.invars[num_consts:]] + effs = pe._renumber_effects(invars, + [*map(canonicalize.get, jaxpr.invars[:num_consts]), *jaxpr.invars[num_consts:]], + jaxpr.effects) + jaxpr = jaxpr.replace(jaxpr=jaxpr.jaxpr.replace(invars=invars, eqns=eqns, outvars=outvars, + effects=effs)) config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 668b880ab4dd..dc53fc16dc30 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -18,7 +18,6 @@ from collections.abc import Callable, Sequence import functools from functools import partial -import inspect import itertools import operator from typing import Any, TypeVar @@ -53,7 +52,7 @@ import numpy as np from jax._src.lax.control_flow.common import ( - _avals_short, _typecheck_param, _initial_style_jaxprs_with_common_consts, + _avals_short, _typecheck_param, _merge_common_consts, _make_closed_jaxpr, _prune_zeros) map, unsafe_map = safe_map, map @@ -149,8 +148,11 @@ def _switch_internal( if config.mutable_array_checks.value: api_util.check_no_aliased_ref_args(lambda: dbgs[0], ops_avals, ops) - jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - branches, ops_tree, ops_avals, dbgs) + jaxprs_, out_trees = zip(*[pe.trace_to_jaxpr( + branch, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbgs)]) + jaxprs_, all_consts = zip(*[pe.separate_consts(j) for j in jaxprs_]) + jaxprs, consts = _merge_common_consts(jaxprs_, all_consts) + if config.mutable_array_checks.value: api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops) for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): @@ -184,7 +186,7 @@ def _switch_internal( return tree_unflatten(out_trees[0], out) @partial(api_boundary, repro_api_name="jax_cond") -def _cond(pred, true_fun: Callable, false_fun: Callable, *operands, +def cond(pred, true_fun: Callable, false_fun: Callable, *operands, operand=_no_operand_sentinel): """Conditionally apply ``true_fun`` or ``false_fun``. @@ -270,14 +272,18 @@ def cond(pred, true_fun, false_fun, *operands): if config.mutable_array_checks.value: api_util.check_no_aliased_ref_args(lambda: dbg_true_fun, ops_avals, ops) dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {}) - jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts( - (true_fun, false_fun), ops_tree, ops_avals, - [dbg_true_fun, dbg_false_fun]) - true_jaxpr, false_jaxpr = jaxprs + + true_jaxpr_, out_tree = pe.trace_to_jaxpr( + true_fun, ops_tree, ops_avals, dbg_true_fun) + true_jaxpr_, true_consts = pe.separate_consts(true_jaxpr_) + false_jaxpr_, false_out_tree = pe.trace_to_jaxpr( + false_fun, ops_tree, ops_avals, dbg_false_fun) + false_jaxpr_, false_consts = pe.separate_consts(false_jaxpr_) + (true_jaxpr, false_jaxpr), consts = _merge_common_consts( + (true_jaxpr_, false_jaxpr_), (true_consts, false_consts)) if config.mutable_array_checks.value: api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops) - out_tree, false_out_tree = out_trees if any(isinstance(out_aval, AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals): raise ValueError("Cannot return `Ref`s from `cond`.") @@ -399,48 +405,6 @@ def _capitalize(s): # s.capitalize() converts s[1:] to lowercase which we don't want. return s[0].capitalize() + s[1:] -@api_boundary -@functools.wraps(_cond) -def cond(*args, **kwargs): - # detect an attempt to call the former, deprecated cond - try: - ba = inspect.signature(_cond_with_per_branch_args).bind(*args, **kwargs) - except TypeError: - pass - else: - assert not ba.kwargs # no catch-all **kwargs in _cond_with_per_branch - _, true_operand, true_fun, false_operand, false_fun = ba.args - if callable(true_operand) and callable(true_fun): - # treat this as modern cond (with two operands) - return _cond(*args, **kwargs) - if callable(true_fun) and callable(false_fun): - return _cond_with_per_branch_args(*ba.args) - - return _cond(*args, **kwargs) - -@partial(api_boundary, repro_api_name="jax_cond_with_per_branch_args") -def _cond_with_per_branch_args(pred, - true_operand, true_fun: Callable, - false_operand, false_fun: Callable): - """Conditionally apply ``true_fun`` or ``false_fun``. - - Has equivalent semantics to this Python implementation:: - - def cond(pred, true_operand, true_fun, false_operand, false_fun): - if pred: - return true_fun(true_operand) - else: - return false_fun(false_operand) - - Pred has to be a scalar type, collection types (list, tuple) are not supported - """ - if not (callable(true_fun) and callable(false_fun)): - raise TypeError("lax.cond: true_fun and false_fun arguments should be callable.") - return _cond(pred, - lambda op: true_fun(op[0]), - lambda op: false_fun(op[1]), - (true_operand, false_operand)) - def _join_cond_effects(branches: Sequence[core.ClosedJaxpr]) -> effects.Effects: joined_effects = set() for b in branches: diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d5e31b0ca0a2..0ca5cf773743 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -50,7 +50,7 @@ from jax._src.lax import slicing from jax._src.lax import windowed_reductions from jax._src.lax.control_flow.common import ( - _avals_short, _initial_style_jaxpr, _prune_zeros, _typecheck_param, + _avals_short, _prune_zeros, _typecheck_param, _make_closed_jaxpr) from jax._src.lax.other import logaddexp from jax._src.pjit import auto_axes, PartitionSpec as P, reshard @@ -281,9 +281,9 @@ def _create_jaxpr(init): init_flat, init_tree = tree_flatten(init) in_flat, in_tree = tree_flatten((init, xs)) carry_avals = tuple(_map(core.get_aval, init_flat)) - open_jaxpr, out_tree, consts = pe.trace_to_jaxpr( + jaxpr, out_tree = pe.trace_to_jaxpr( f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body) - jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(open_jaxpr)) + jaxpr, consts = pe.separate_consts(jaxpr) if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat) out_tree_children = out_tree.children() @@ -1712,11 +1712,11 @@ def _create_jaxpr(init_val): init_vals, in_tree = tree_flatten((init_val,)) init_avals = tuple(_map(core.get_aval, init_vals)) cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) - cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr( - cond_fun, in_tree, init_avals, cond_dbg) + cond_jaxpr, cond_tree = pe.trace_to_jaxpr(cond_fun, in_tree, init_avals, cond_dbg) + cond_jaxpr, cond_consts = pe.separate_consts(cond_jaxpr) body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) - body_jaxpr, body_consts, body_tree = _initial_style_jaxpr( - body_fun, in_tree, init_avals, body_dbg) + body_jaxpr, body_tree = pe.trace_to_jaxpr(body_fun, in_tree, init_avals, body_dbg) + body_jaxpr, body_consts = pe.separate_consts(body_jaxpr) if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." raise TypeError(msg.format(cond_tree)) diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 17ed44b69991..5a8600f8dcc1 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -27,6 +27,7 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.traceback_util import api_boundary from jax._src.tree_util import (tree_flatten, treedef_children, tree_leaves, @@ -36,7 +37,6 @@ from jax._src.lax.control_flow.common import ( _check_tree, - _initial_style_jaxpr, ) _map = safe_map @@ -95,8 +95,9 @@ def custom_root(f: Callable, guess_flat, in_args_tree = tree_flatten((initial_guess,)) guess_avals = tuple(_map(core.get_aval, guess_flat)) f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {}) - f_jaxpr, f_consts, out_tree = _initial_style_jaxpr( + f_jaxpr, out_tree = pe.trace_to_jaxpr( f, in_args_tree, guess_avals, f_debug) + f_jaxpr, f_consts = pe.separate_consts(f_jaxpr) in_tree, = treedef_children(in_args_tree) _check_tree("f", "initial_guess", out_tree, in_tree, False) @@ -104,8 +105,9 @@ def custom_root(f: Callable, solve_debug = api_util.debug_info("custom_root solve", solve, (f, initial_guess), {}, static_argnums=(0,)) - solve_jaxpr, solve_consts, solution_tree = _initial_style_jaxpr( + solve_jaxpr, solution_tree = pe.trace_to_jaxpr( partial(solve, f), in_args_tree, guess_avals, solve_debug) + solve_jaxpr, solve_consts = pe.separate_consts(solve_jaxpr) _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) def linearize_and_solve(x, b): @@ -114,9 +116,10 @@ def linearize_and_solve(x, b): linearize_and_solve_dbg = api_util.debug_info("custom_root tangent_solve", tangent_solve, (initial_guess, initial_guess), {}) - l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( + l_and_s_jaxpr, out_tree = pe.trace_to_jaxpr( linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2, linearize_and_solve_dbg) + l_and_s_jaxpr, l_and_s_consts = pe.separate_consts(l_and_s_jaxpr) _check_tree("tangent_solve", "x", out_tree, in_tree, False) all_consts = [f_consts, solve_consts, l_and_s_consts] @@ -268,17 +271,19 @@ def f_aux(x): matvec_debug = api_util.debug_info("custom_linear_solve", matvec, (b,), {}) # no auxiliary data assumed for matvec - matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr( + matvec_jaxpr, out_tree = pe.trace_to_jaxpr( _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, matvec_debug) + matvec_jaxpr, matvec_consts = pe.separate_consts(matvec_jaxpr) _check_tree("matvec", "b", out_tree, tree, False) solve_debug = api_util.debug_info("custom_linear_solve solve", solve, (matvec, b), {}, static_argnums=(0,)) - solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr( + solve_jaxpr, out_tree = pe.trace_to_jaxpr( _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, solve_debug) + solve_jaxpr, solve_consts = pe.separate_consts(solve_jaxpr) _check_tree("solve", "b", out_tree, tree, has_aux) if transpose_solve is None: @@ -294,13 +299,15 @@ def f_aux(x): vecmat_consts = matvec_consts else: vecmat = _transpose_one_output(matvec, b) - vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr( + vecmat_jaxpr, out_tree = pe.trace_to_jaxpr( vecmat, in_args_tree, b_avals, transpose_solve_debug) + vecmat_jaxpr, vecmat_consts = pe.separate_consts(vecmat_jaxpr) assert out_tree == tree - tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr( + tr_solve_jaxpr, out_tree = pe.trace_to_jaxpr( _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux), in_args_tree, b_avals, transpose_solve_debug) + tr_solve_jaxpr, tr_solve_consts = pe.separate_consts(tr_solve_jaxpr) _check_tree("transpose_solve", "b", out_tree, tree, has_aux) all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts] diff --git a/tests/api_test.py b/tests/api_test.py index c7c4dda3a6be..8dde8b47d6c2 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7215,10 +7215,10 @@ def fun(x): def test_cond(self): def f(x): return lax.cond(x >= 0., + lambda xt, _: xt + x, + lambda _, xf: xf - x, x + 1., - lambda xt: xt + x, - x + 2., - lambda xf: xf - x) + x + 2.) expected = """{ lambda ; a:f32[]. let b:bool[] = ge a 0.0:f32[] c:f32[] = add a 1.0:f32[] @@ -7949,10 +7949,10 @@ def f(c, x): jax.lax.scan(f, 0, jnp.arange(4)) def test_cond_traceback(self): - if sys.version_info < (3, 14): + if sys.version_info < (3, 13): # Fails because 3.11 adds an extra stack frame due to a list comprehension self.skipTest("Expected failure.") - expected_depth = 8 + expected_depth = 4 init_depth = self.cur_depth() def f(): diff --git a/tests/core_test.py b/tests/core_test.py index c7cf4918f8e5..a2b3da15fd2c 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -448,15 +448,11 @@ class JaxprTypeChecks(jtu.JaxTestCase): def setUp(self): super().setUp() - lax_control_flow._initial_style_open_jaxpr.cache_clear() - lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() def tearDown(self): super().tearDown() - lax_control_flow._initial_style_open_jaxpr.cache_clear() - lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 01c8898749cc..7a3a45a485d9 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -53,25 +53,14 @@ # provides a lax.cond-compatible interface to a two-branch lax.switch. Several # tests in this file are parameterized such that they either call into lax.cond # or into this function. -def cond_via_switch(pred, true_fun, false_fun, op, *args): - if len(args) > 0: - assert len(args) == 1 - true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0] - op = (false_op, true_op) - false_fun = lambda op: _false_fun(op[0]) - true_fun = lambda op: _true_fun(op[1]) +def cond_via_switch(pred, true_fun, false_fun, *args): index = lax.convert_element_type(pred, np.int32) - return lax.switch(index, [false_fun, true_fun], op) - -def cond_with_new_checkpoint(pred, true_fun, false_fun, op, *args): - if args: - true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0] - op = (false_op, true_op) - false_fun = lambda op: _false_fun(op[0]) - true_fun = lambda op: _true_fun(op[1]) + return lax.switch(index, [false_fun, true_fun], *args) + +def cond_with_new_checkpoint(pred, true_fun, false_fun, *args): index = lax.convert_element_type(pred, np.int32) - fn = lambda index, op: lax.switch(index, [false_fun, true_fun], op) - return jax.checkpoint(fn)(index, op) + fn = lambda index, *args: lax.switch(index, [false_fun, true_fun], *args) + return jax.checkpoint(fn)(index, *args) COND_IMPLS = [ (lax.cond, 'cond'), @@ -171,8 +160,6 @@ class LaxControlFlowTest(jtu.JaxTestCase): def setUp(self): super().setUp() - lax_control_flow._initial_style_open_jaxpr.cache_clear() - lax_control_flow._initial_style_jaxpr.cache_clear() lax_control_flow.common._dedup_consts.cache_clear() lax_control_flow.common._pad_constvars.cache_clear() @@ -1000,8 +987,8 @@ def cfun(x): lax.lt(x, 2), lambda x: lax.mul(2, x), lambda x: cond(lax.lt(x, 5), - x, lambda x: lax.mul(3, x), - 4, lambda y: lax.mul(y, x)), + lambda x, _: lax.mul(3, x), + lambda _, y: lax.mul(y, x), x, 4), x) self.assertEqual(cfun(1), 2) @@ -1121,9 +1108,9 @@ def cfun(x): def testCondBatched(self): def fun(x, y, z): pred = lax.lt(x, 3) - true_fun = lambda y: y - false_fun = lambda z: lax.neg(z) - return lax.cond(pred, y, true_fun, z, false_fun) + true_fun = lambda y, _: y + false_fun = lambda _, z: lax.neg(z) + return lax.cond(pred, true_fun, false_fun, y, z) # these cases stay as cond x = jnp.array(2) @@ -1287,7 +1274,7 @@ def fun_ref(x): return 2. * x def fun(x): - return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x) + return cond(x < 3, lambda _: 2., lambda x: 2. * x, x) x = 3.14 ans = jax.jvp(fun, (x,), (x,)) @@ -1445,7 +1432,7 @@ def fun_ref(x): return 2. * x def fun(x): - return cond(x < 3, None, lambda _: 2., x, lambda x: 2. * x) + return cond(x < 3, lambda _: 2., lambda x: 2. * x, x) x = 3.14 ans = jax.grad(fun)(x) @@ -1475,8 +1462,9 @@ def fun_ref(x, y): def fun(x, y): return cond( x < 3, - None, lambda _: 2. * jnp.sin(y), - x, lambda x: 2. * x) + lambda _: 2. * jnp.sin(y), + lambda x: 2. * x, + x) y = 5.8 x = 3.14 @@ -1665,7 +1653,7 @@ def g(x): return jnp.where(x > 0, f_1(x), f_2(x)) def testIssue1263(self): def f(rng, x): cond = random.bernoulli(rng) - return lax.cond(cond, x, lambda x: x, jnp.abs(x) - 1., lambda x: x) + return lax.cond(cond, lambda x, _: x, lambda _, x: x, x, jnp.abs(x) - 1.) def body_fn(i, state): rng, x = state @@ -1680,8 +1668,9 @@ def g(rng, x): def testIssue514(self): # just check this doesn't crash lax.cond(True, - (0, 0), lambda x: (x[0], 0), - (1, 1), lambda x: x) + lambda x, _: (x[0], 0), + lambda _, x: x, + (0, 0), (1, 1)) def testIssue649(self): from jax import lax @@ -2388,8 +2377,9 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"): elif loop == "fori_inside_cond": func = lambda x: lax.cond( True, - x, lambda x: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), - 1., lambda x: x) + lambda x, _: lax.fori_loop(x, x + 2., lambda i, c: c * 2., x), + lambda _, x: x, + x, 1.) elif loop == "fori_inside_scan": func = lambda x: lax.scan( lambda c, x: (lax.fori_loop(x, x + 2., lambda i, c1: c1 * c, x), None), @@ -2561,7 +2551,7 @@ def f(h, _): def test_disable_jit_cond_with_vmap(self): # https://github.com/jax-ml/jax/issues/3093 def fn(t): - return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1) + return lax.cond(t > 0, lambda x, _: 0, lambda _, x: 1, 0, 0) fn = jax.vmap(fn) with jax.disable_jit(): diff --git a/tests/metadata_test.py b/tests/metadata_test.py index 917cf7bf5133..524768aaed87 100644 --- a/tests/metadata_test.py +++ b/tests/metadata_test.py @@ -79,7 +79,7 @@ def true_fun(x): def false_fun(x): return jnp.cos(x) def f(which, x): - return jax.lax.cond(which, x, true_fun, x, false_fun) + return jax.lax.cond(which, true_fun, false_fun, x) hlo = module_to_string(jax.jit(f).lower(True, 1.).compiler_ir()) self.assertRegex(hlo, r'loc\(".*cond/branch_0_fun/cos"') self.assertRegex(hlo, r'loc\(".*cond/branch_1_fun/sin"') From aa886ea9de9be65cf1e7f283cc4bb9f54bb8b8e6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 3 Dec 2025 14:04:48 -0800 Subject: [PATCH 032/315] Remove lax_metal_test and associated CI job --- .github/workflows/metal_plugin_ci.yml | 54 - tests/BUILD | 14 - tests/lax_metal_test.py | 5732 ------------------------- 3 files changed, 5800 deletions(-) delete mode 100644 .github/workflows/metal_plugin_ci.yml delete mode 100644 tests/lax_metal_test.py diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml deleted file mode 100644 index 5ff1a76952db..000000000000 --- a/.github/workflows/metal_plugin_ci.yml +++ /dev/null @@ -1,54 +0,0 @@ -# JAX-Metal plugin CI - -name: Jax-Metal CI -on: - schedule: - - cron: "0 12 * * *" # Daily at 12:00 UTC - workflow_dispatch: # allows triggering the workflow run manually - pull_request: # Automatically trigger on pull requests affecting this file - branches: - - main - paths: - - '**workflows/metal_plugin_ci.yml' - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - cancel-in-progress: true -permissions: {} -jobs: - jax-metal-plugin-test: - - strategy: - fail-fast: false # don't cancel all jobs on failure - matrix: - jaxlib-version: ["pypi_latest", "nightly"] - name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})" - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Get repo - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - with: - path: jax - persist-credentials: false - - name: Setup build and test enviroment - run: | - rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv - python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv - source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate - pip install uv~=0.5.30 - uv pip install -U pip numpy wheel absl-py pytest - if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then - uv pip install --pre jaxlib \ - -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ - fi; - cd jax - uv pip install . jax-metal - - name: Run test - run: | - source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate - export ENABLE_PJRT_COMPATIBILITY=1 - cd jax - pytest tests/lax_metal_test.py - - diff --git a/tests/BUILD b/tests/BUILD index 73d18edcbe77..9b0c54777d60 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -879,20 +879,6 @@ jax_multiplatform_test( ]), ) -jax_multiplatform_test( - name = "lax_metal_test", - srcs = ["lax_metal_test.py"], - enable_backends = ["metal"], - tags = ["notap"], - deps = [ - "//jax/_src:internal_test_util", - "//jax/_src:lax_reference", - ] + py_deps([ - "absl/testing", - "numpy", - ]), -) - jax_multiplatform_test( name = "lax_autodiff_test", srcs = ["lax_autodiff_test.py"], diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py deleted file mode 100644 index d42a8f2762a3..000000000000 --- a/tests/lax_metal_test.py +++ /dev/null @@ -1,5732 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from array import array as make_python_array -import collections -import copy -from functools import partial -import io -import itertools -import math -import platform -from typing import Union, cast -import unittest -from unittest import SkipTest - -from absl.testing import absltest -from absl.testing import parameterized - -import numpy as np -try: - import numpy_dispatch -except ImportError: - numpy_dispatch = None - -import jax -from jax import lax -from jax import numpy as jnp -from jax.sharding import SingleDeviceSharding - -from jax._src import array -from jax._src import config -from jax._src import core -from jax._src import dtypes -from jax._src import test_util as jtu -from jax._src.lax import lax as lax_internal - -from jax._src.util import safe_zip - -try: - from jax_plugins import metal_plugin -except ImportError: - metal_plugin = None - -config.parse_flags_with_absl() - -nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] -nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes -one_dim_array_shapes = [(1,), (6,), (12,)] -empty_array_shapes = [(0,), (0, 4), (3, 0),] -broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] - -scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] -array_shapes = nonempty_array_shapes + empty_array_shapes -nonzerodim_shapes = nonempty_nonscalar_array_shapes + empty_array_shapes -nonempty_shapes = scalar_shapes + nonempty_array_shapes -all_shapes = scalar_shapes + array_shapes - -float_dtypes = jtu.dtypes.all_floating -complex_dtypes = jtu.dtypes.complex -int_dtypes = jtu.dtypes.all_integer -unsigned_dtypes = jtu.dtypes.all_unsigned -bool_dtypes = jtu.dtypes.boolean -default_dtypes = float_dtypes + int_dtypes -inexact_dtypes = float_dtypes + complex_dtypes -number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes -all_dtypes = number_dtypes + bool_dtypes - -NO_VALUE = object() - -python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_] - -# uint64 is problematic because with any uint type it promotes to float: -int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] - -def _indexer_with_default_outputs(indexer, use_defaults=True): - """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" - class Indexer: - @partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults) - def __getitem__(self, *args): - return indexer.__getitem__(*args) - return Indexer() - -def _valid_dtypes_for_shape(shape, dtypes): - # Not all (shape, dtype) pairs are valid. In particular, Python scalars only - # have one type in each category (float, bool, etc.) - if shape is jtu.PYTHON_SCALAR_SHAPE: - return [t for t in dtypes if t in python_scalar_dtypes] - return dtypes - -def _shape_and_dtypes(shapes, dtypes): - for shape in shapes: - for dtype in _valid_dtypes_for_shape(shape, dtypes): - yield (shape, dtype) - -def _compatible_shapes(shape): - if np.ndim(shape) == 0 or shape in scalar_shapes: - return [shape] - return (shape[n:] for n in range(len(shape) + 1)) - -OpRecord = collections.namedtuple( - "OpRecord", - ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", - "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"]) - -def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name=None, check_dtypes=True, - tolerance=None, inexact=False, kwargs=None): - test_name = test_name or name - return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name, check_dtypes, tolerance, inexact, kwargs) - - -JAX_ARGMINMAX_RECORDS = [ - op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), - op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []), - op_record("nanargmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), - op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []), -] - -def _shapes_are_broadcast_compatible(shapes): - try: - lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes)) - except ValueError: - return False - else: - return True - -def _shapes_are_equal_length(shapes): - return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) - -@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") -class LaxBackedNumpyTests(jtu.JaxTestCase): - """Tests for LAX-backed Numpy implementation.""" - - def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): - def f(): - out = [rng(shape, dtype or jnp.float_) - for shape, dtype in zip(shapes, dtypes)] - if np_arrays: - return out - return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a - for a in out] - return f - - @parameterized.parameters( - [dtype for dtype in [jnp.bool, jnp.uint8, jnp.uint16, jnp.uint32, - jnp.int8, jnp.int16, jnp.int32, jnp.int64, - jnp.float16, jnp.float32] - if dtype == dtypes.canonicalize_dtype(dtype)]) - def testDtypeWrappers(self, dtype): - arr = dtype(0) - self.assertIsInstance(arr, jax.Array) - self.assertEqual(arr.dtype, np.dtype(dtype)) - self.assertArraysEqual(arr, 0, check_dtypes=False) - - # No copy primitive is generated - jaxpr = jax.make_jaxpr(dtype)(0) - prims = [eqn.primitive for eqn in jaxpr.eqns] - self.assertEqual(prims, [lax.convert_element_type_p]) # No copy generated. - - def testBoolDtypeAlias(self): - self.assertIs(jnp.bool, jnp.bool_) - - @jtu.sample_product( - dtype=float_dtypes + [object], - allow_pickle=[True, False], - ) - def testLoad(self, dtype, allow_pickle): - if dtype == object and not allow_pickle: - self.skipTest("dtype=object requires allow_pickle=True") - rng = jtu.rand_default(self.rng()) - arr = rng((10), dtype) - with io.BytesIO() as f: - jnp.save(f, arr) - f.seek(0) - arr_out = jnp.load(f, allow_pickle=allow_pickle) - self.assertArraysEqual(arr, arr_out, allow_object_dtype=True) - - @unittest.skip("Jax-metal fail.") - def testArrayEqualExamples(self): - # examples from the array_equal() docstring. - self.assertTrue(jnp.array_equal([1, 2], [1, 2])) - self.assertTrue(jnp.array_equal(np.array([1, 2]), np.array([1, 2]))) - self.assertFalse(jnp.array_equal([1, 2], [1, 2, 3])) - self.assertFalse(jnp.array_equal([1, 2], [1, 4])) - - a = np.array([1, np.nan]) - self.assertFalse(jnp.array_equal(a, a)) - self.assertTrue(jnp.array_equal(a, a, equal_nan=True)) - - a = np.array([1 + 1j]) - b = a.copy() - a.real = np.nan - b.imag = np.nan - self.assertTrue(jnp.array_equal(a, b, equal_nan=True)) - - def testArrayEquivExamples(self): - # examples from the array_equiv() docstring. - self.assertTrue(jnp.array_equiv([1, 2], [1, 2])) - self.assertFalse(jnp.array_equiv([1, 2], [1, 3])) - with jax.numpy_rank_promotion('allow'): - self.assertTrue(jnp.array_equiv([1, 2], [[1, 2], [1, 2]])) - self.assertFalse(jnp.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) - self.assertFalse(jnp.array_equiv([1, 2], [[1, 2], [1, 3]])) - - def testArrayModule(self): - if numpy_dispatch is None: - raise SkipTest('requires https://github.com/seberg/numpy-dispatch') - - jnp_array = jnp.array(1.0) - np_array = np.array(1.0) - - module = numpy_dispatch.get_array_module(jnp_array) - self.assertIs(module, jnp) - - module = numpy_dispatch.get_array_module(jnp_array, np_array) - self.assertIs(module, jnp) - - def f(x): - module = numpy_dispatch.get_array_module(x) - self.assertIs(module, jnp) - return x - jax.jit(f)(jnp_array) - jax.grad(f)(jnp_array) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in list(range(-len(shape), len(shape)))], - discont=[None, "pi", 2], - period=["2pi", "pi"], - dtype=default_dtypes, - ) - def testUnwrap(self, shape, dtype, axis, discont, period): - special_vals = {"pi": np.pi, "2pi": 2 * np.pi} - period = special_vals.get(period, period) - discont = special_vals.get(discont, discont) - - rng = jtu.rand_default(self.rng()) - - def np_fun(x): - dtype = None - if x.dtype == dtypes.bfloat16: - dtype = x.dtype - x = x.astype(np.float32) - out = np.unwrap(x, axis=axis, discont=discont, period=period) - return out if dtype is None else out.astype(dtype) - - jnp_fun = partial(jnp.unwrap, axis=axis, discont=discont, period=period) - if not dtypes.issubdtype(dtype, np.inexact): - # This case requires implicit dtype promotion - jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - atol={dtypes.bfloat16: 1e-1, np.float16: 1e-2}) - self._CompileAndCheck(jnp_fun, args_maker, atol={dtypes.bfloat16: 1e-1}) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in list(range(-len(shape), len(shape))) + [None]], - dtype=all_dtypes, - ) - def testCountNonzero(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - np_fun = lambda x: np.count_nonzero(x, axis) - jnp_fun = lambda x: jnp.count_nonzero(x, axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) - def testNonzero(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) - - @jtu.sample_product( - [dict(shape=shape, fill_value=fill_value) - for shape in nonempty_array_shapes - for fill_value in [None, -1, shape or (1,)] - ], - dtype=all_dtypes, - size=[1, 5, 10], - ) - def testNonzeroSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_fun(x): - result = np.nonzero(x) - if size <= len(result[0]): - return tuple(arg[:size] for arg in result) - else: - fillvals = fill_value if np.ndim(fill_value) else len(result) * [fill_value or 0] - return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) - for fval, arg in safe_zip(fillvals, result)) - jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) - def testFlatNonzero(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np.flatnonzero) - jnp_fun = jnp.flatnonzero - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - - # JIT compilation requires specifying the size statically: - jnp_fun = lambda x: jnp.flatnonzero(x, size=np.size(x) // 2) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=nonempty_array_shapes, - dtype=all_dtypes, - fill_value=[None, -1, 10, (-1,), (10,)], - size=[1, 5, 10], - ) - def testFlatNonzeroSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") - def np_fun(x): - result = np.flatnonzero(x) - if size <= len(result): - return result[:size] - else: - fill_val = fill_value or 0 - return np.concatenate([result, np.full(size - len(result), fill_val, result.dtype)]) - jnp_fun = lambda x: jnp.flatnonzero(x, size=size, fill_value=fill_value) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) - def testArgWhere(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) - - # JIT compilation requires specifying a size statically. Full test of this - # behavior is in testNonzeroSize(). - jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, fill_value=fill_value) - for shape in nonempty_array_shapes - for fill_value in [None, -1, shape or (1,)] - ], - dtype=all_dtypes, - size=[1, 5, 10], - ) - def testArgWhereSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_fun(x): - result = np.argwhere(x) - if size <= len(result): - return result[:size] - else: - fillvals = fill_value if np.ndim(fill_value) else result.shape[-1] * [fill_value or 0] - return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack([np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) - for fval, arg in safe_zip(fillvals, result.T)]).T - jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value) - - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name), - shape=shape, dtype=dtype, axis=axis, rng_factory=rec.rng_factory) - for rec in JAX_ARGMINMAX_RECORDS - for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) - for axis in range(-len(shape), len(shape))], - keepdims=[False, True], - ) - def testArgMinMax(self, np_op, jnp_op, rng_factory, shape, dtype, axis, keepdims): - rng = rng_factory(self.rng()) - if dtype == np.complex128 and jtu.test_device_matches(["gpu"]): - raise unittest.SkipTest("complex128 reductions not supported on GPU") - if "nan" in np_op.__name__ and dtype == jnp.bfloat16: - raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays") - kwds = {"keepdims": True} if keepdims else {} - - np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds)) - jnp_fun = partial(jnp_op, axis=axis, **kwds) - - args_maker = lambda: [rng(shape, dtype)] - try: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - except ValueError as e: - if str(e) == "All-NaN slice encountered": - self.skipTest("JAX doesn't support checking for all-NaN slices") - else: - raise - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(name=rec.name, np_op=getattr(np, rec.name), - jnp_op=getattr(jnp, rec.name)) - for rec in JAX_ARGMINMAX_RECORDS], - ) - def testArgMinMaxEmpty(self, name, np_op, jnp_op): - name = name[3:] if name.startswith("nan") else name - msg = f"attempt to get {name} of an empty sequence" - with self.assertRaisesRegex(ValueError, msg): - jnp_op(np.array([])) - with self.assertRaisesRegex(ValueError, msg): - jnp_op(np.zeros((2, 0)), axis=1) - np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0)) - jnp_fun = partial(jnp_op, axis=0) - args_maker = lambda: [np.zeros((2, 0))] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) - for lhs_shape, rhs_shape, axes in [ - [(2,), (2,), (-1, -1, -1, None)], # scalar output - [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors - [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors - [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting - [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes - [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting - [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors - [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting - [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing - [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)] # same as before - ]], - lhs_dtype=number_dtypes, - rhs_dtype=number_dtypes, - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - axisa, axisb, axisc, axis = axes - jnp_fun = lambda a, b: jnp.cross(a, b, axisa, axisb, axisc, axis) - # Note: 2D inputs to jnp.cross are deprecated in numpy 2.0. - @jtu.ignore_warning(category=DeprecationWarning, - message="Arrays of 2-dimensional vectors are deprecated.") - def np_fun(a, b): - a = a.astype(np.float32) if lhs_dtype == jnp.bfloat16 else a - b = b.astype(np.float32) if rhs_dtype == jnp.bfloat16 else b - out = np.cross(a, b, axisa, axisb, axisc, axis) - return out.astype(jnp.promote_types(lhs_dtype, rhs_dtype)) - tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15} - tol = max(jtu.tolerance(lhs_dtype, tol_spec), - jtu.tolerance(rhs_dtype, tol_spec)) - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) - for lhs_shape, rhs_shape in [ - ((3, 3), ()), - ((), (3, 3)), - ((4, 5), (5,)), - ((6,), (6, 4)), - ((3, 4), (4, 5)), - ((4, 3, 2), (2,)), - ((2,), (3, 2, 4)), - ((4, 3, 2), (2, 5)), - ((5, 2), (3, 2, 4)), - ((2, 3, 4), (5, 4, 1))]], - lhs_dtype=float_dtypes,#number_dtypes, - rhs_dtype=float_dtypes,#number_dtypes, - ) - @jax.default_matmul_precision("float32") - def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {np.float16: 1e-2, np.float32: 2e-5, np.float64: 1e-14, - np.complex128: 1e-14} - if (lhs_dtype in [np.float16, jnp.bfloat16] and - rhs_dtype in [np.float16, jnp.bfloat16]): - tol = 1e-2 - def np_dot(x, y): - x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x - y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y - return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype)) - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_dot, jnp.dot, args_maker, tol=tol) - self._CompileAndCheck(jnp.dot, args_maker, atol=tol, rtol=tol) - - @jtu.sample_product( - lhs_dtype=number_dtypes, - rhs_dtype=number_dtypes, - ) - @jax.numpy_dtype_promotion('standard') - def testMixedPrecisionDot(self, lhs_dtype, rhs_dtype): - # This test confirms that jnp.dot lowers to a single dot_general call, - # avoiding explicit type casting of inputs and outputs. - lhs = jax.ShapeDtypeStruct((5,), lhs_dtype) - rhs = jax.ShapeDtypeStruct((5,), rhs_dtype) - jaxpr = jax.make_jaxpr(jnp.dot)(lhs, rhs) - prims = [eqn.primitive for eqn in jaxpr.eqns] - self.assertIn(prims, [ - [lax.dot_general_p], - [lax.dot_general_p, lax.convert_element_type_p] - ]) - - @jtu.sample_product( - [dict(name=name, lhs_shape=lhs_shape, rhs_shape=rhs_shape) - for name, lhs_shape, rhs_shape in [ - ("vector-vector", (3,), (3,)), - ("matrix-vector", (3, 3), (3,)), - ("vector-matrix", (3,), (3, 3)), - ("matrix-matrix", (3, 3), (3, 3)), - ("vector-tensor", (3,), (5, 3, 2)), - ("tensor-vector", (5, 3, 2), (2,)), - ("matrix-tensor", (5, 2), (3, 2, 4)), - ("tensor-matrix", (5, 2, 3), (3, 2)), - ("tensor-tensor", (5, 3, 4), (5, 4, 1)), - ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]], - lhs_dtype=float_dtypes, #number_dtypes, - rhs_dtype=float_dtypes, #number_dtypes, - ) - @jax.default_matmul_precision("float32") - def testMatmul(self, name, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - def np_fun(x, y): - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.matmul(x, y).astype(dtype) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, - np.complex128: 1e-12} - - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol) - self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol) - - @jtu.sample_product( - lhs_batch=broadcast_compatible_shapes, - rhs_batch=broadcast_compatible_shapes, - axis_size=[2, 4], - axis=range(-2, 2), - dtype=float_dtypes,#number_dtypes, - ) - @jax.default_matmul_precision("float32") - @jax.numpy_rank_promotion('allow') # adopt PR#22316 - def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype): - # Construct vecdot-compatible shapes. - size = min(len(lhs_batch), len(rhs_batch)) - axis = int(np.clip(axis, -size - 1, size)) - if axis >= 0: - lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:]) - rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:]) - else: - laxis = axis + len(lhs_batch) + 1 - lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:]) - raxis = axis + len(rhs_batch) + 1 - rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:]) - - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] - @jtu.promote_like_jnp - def np_fn(x, y, axis=axis): - return np.vecdot(x, y, axis=axis).astype(x.dtype) - jnp_fn = partial(jnp.vecdot, axis=axis) - tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12, - np.complex64: 1E-3, np.complex128: 1e-12} - self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol) - self._CompileAndCheck(jnp_fn, args_maker, tol=tol) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes) - for lhs_shape, rhs_shape, axes in [ - [(3,), (), 0], - [(2, 3, 4), (5, 6, 7), 0], # from issue #740 - [(2, 3, 4), (3, 4, 5, 6), 2], - [(2, 3, 4), (5, 4, 3, 6), [1, 2]], - [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], - [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], - ]], - lhs_dtype=float_dtypes,#number_dtypes, - rhs_dtype=float_dtypes,#number_dtypes, - ) - @jax.default_matmul_precision("float32") - def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - jnp_fun = lambda a, b: jnp.tensordot(a, b, axes) - def np_fun(a, b): - a = a if lhs_dtype != jnp.bfloat16 else a.astype(np.float32) - b = b if rhs_dtype != jnp.bfloat16 else b.astype(np.float32) - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.tensordot(a, b, axes).astype(dtype) - tol = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-12, - np.complex64: 1e-3, np.complex128: 1e-12} - - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, tol=tol) - - def testTensordotErrors(self): - a = self.rng().random((3, 2, 2)) - b = self.rng().random((2,)) - self.assertRaisesRegex( - TypeError, "Number of tensordot axes.*exceeds input ranks.*", - lambda: jnp.tensordot(a, b, axes=2)) - - self.assertRaisesRegex( - TypeError, "tensordot requires axes lists to have equal length.*", - lambda: jnp.tensordot(a, b, axes=([0], [0, 1]))) - - self.assertRaisesRegex( - TypeError, "tensordot requires both axes lists to be either ints, tuples or lists.*", - lambda: jnp.tensordot(a, b, axes=('bad', 'axes'))) - - self.assertRaisesRegex( - TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*", - lambda: jnp.tensordot(a, b, axes='badaxes')) - - @jtu.sample_product( - element_shape=all_shapes, - test_shape=all_shapes, - dtype=default_dtypes, - invert=[False, True], - ) - def testIsin(self, element_shape, test_shape, dtype, invert): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] - jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert) - np_fun = lambda e, t: np.isin(e, t, invert=invert) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - ) - def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) - - @unittest.skip("JAx-metal fail.") - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - size=[1, 5, 10], - fill_value=[None, -1], - ) - def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - result = np.setdiff1d(arg1, arg2) - if size <= len(result): - return result[:size] - else: - return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0) - def jnp_fun(arg1, arg2): - return jnp.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_nonscalar_array_shapes, - shape2=nonempty_nonscalar_array_shapes, - ) - def testUnion1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - dtype = jnp.promote_types(arg1.dtype, arg2.dtype) - return np.union1d(arg1, arg2).astype(dtype) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp.union1d, args_maker) - - @unittest.skip("Jax-metal fail.") - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_nonscalar_array_shapes, - shape2=nonempty_nonscalar_array_shapes, - size=[1, 5, 10], - fill_value=[None, -1], - ) - def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - dtype = jnp.promote_types(arg1.dtype, arg2.dtype) - result = np.union1d(arg1, arg2).astype(dtype) - fv = result.min() if fill_value is None else fill_value - if size <= len(result): - return result[:size] - else: - return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) - def jnp_fun(arg1, arg2): - return jnp.union1d(arg1, arg2, size=size, fill_value=fill_value) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - assume_unique=[False, True], - ) - def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique) - def np_fun(ar1, ar2): - if assume_unique: - # pre-flatten the arrays to match with jax implementation - ar1 = np.ravel(ar1) - ar2 = np.ravel(ar2) - return np.setxor1d(ar1, ar2, assume_unique) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - - @jtu.sample_product( - dtype1=[s for s in default_dtypes if s != jnp.bfloat16], - dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, - assume_unique=[False, True], - return_indices=[False, True], - ) - def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, - return_indices): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - - @jtu.sample_product( - [dict(lhs_shape=lhs_shape, lhs_dtype=lhs_dtype, - rhs_shape=rhs_shape, rhs_dtype=rhs_dtype) - # TODO(phawkins): support integer dtypes too. - for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) - for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) - if len(jtu._dims_of_shape(lhs_shape)) == 0 - or len(jtu._dims_of_shape(rhs_shape)) == 0 - or lhs_shape[-1] == rhs_shape[-1]], - ) - @jax.default_matmul_precision("float32") - def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - def np_fun(lhs, rhs): - lhs = lhs if lhs_dtype != jnp.bfloat16 else lhs.astype(np.float32) - rhs = rhs if rhs_dtype != jnp.bfloat16 else rhs.astype(np.float32) - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.inner(lhs, rhs).astype(dtype) - jnp_fun = lambda lhs, rhs: jnp.inner(lhs, rhs) - tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13, - np.complex64: 1e-5} - tol = max(jtu.tolerance(lhs_dtype, tol_spec), - jtu.tolerance(rhs_dtype, tol_spec)) - # TODO(phawkins): there are float32/float64 disagreements for some inputs. - with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) - - @unittest.skip("MLIR translation rule for primitive 'eigh' not found for platform METAL.") - @jtu.sample_product( - dtype=[dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]], - shape=[shape for shape in one_dim_array_shapes if shape != (1,)], - deg=[1, 2, 3], - rcond=[None, -1, 10e-3, 10e-5, 10e-10], - full=[False, True], - w=[False, True], - cov=[False, True, "unscaled"], - ) - @jax.default_matmul_precision("float32") - def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): - rng = jtu.rand_default(self.rng()) - tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} - tol = jtu.tolerance(dtype, tol_spec) - _w = lambda a: abs(a) if w else None - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] - jnp_fun = lambda x, y, a: jnp.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) - np_fun = jtu.ignore_warning( - message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) - - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False, atol=tol, rtol=tol) - - args = args_maker() - if not full: - args = args_maker() - try: - np_out = np_fun(*args) - except ValueError: - return # https://github.com/numpy/numpy/issues/22380 - jnp_out = jnp_fun(*args) - self.assertAllClose(np_out, jnp_out, atol=tol, rtol=tol, - check_dtypes=False) - else: - # Don't compare the residuals because jnp.linalg.lstsq acts slightly - # differently to remain `jit`-compatible. - np_p, _, nrank, nsingular_values, nrcond = np_fun(*args) - jp_p, _, jrank, jsingular_values, jrcond = jnp_fun(*args) - self.assertAllClose( - (np_p, nrank, nsingular_values, nrcond), - (jp_p, jrank, jsingular_values, jrcond), - atol=tol, rtol=tol, check_dtypes=False) - - @jtu.sample_product( - [dict(a_min=a_min, a_max=a_max) - for a_min, a_max in [(-1, None), (None, 1), (-0.9, 1), - (-np.ones(1), None), - (None, np.ones(1)), - (np.full(1, -0.9), np.ones(1))] - ], - shape=all_shapes, - dtype=number_dtypes, - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion - def testClipStaticBounds(self, shape, dtype, a_min, a_max): - if np.issubdtype(dtype, np.unsignedinteger): - a_min = None if a_min is None else abs(a_min) - a_max = None if a_max is None else abs(a_max) - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) - jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype) - for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)], - decimals=[0, 1, -2], - ) - def testRoundStaticDecimals(self, shape, dtype, decimals): - rng = jtu.rand_default(self.rng()) - if jnp.issubdtype(dtype, np.integer) and decimals < 0: - self.skipTest("Integer rounding with decimals < 0 not implemented") - np_fun = lambda x: np.round(x, decimals=decimals) - jnp_fun = lambda x: jnp.round(x, decimals=decimals) - args_maker = lambda: [rng(shape, dtype)] - tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2} - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=check_dtypes, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=check_dtypes, - atol=tol, rtol=tol) - - @jtu.sample_product(jit=[False, True]) - def testOperatorRound(self, jit): - jround = jax.jit(round, static_argnums=1) if jit else round - self.assertAllClose(round(np.float32(7.532), 1), - jround(jnp.float32(7.5), 1)) - self.assertAllClose(round(np.float32(1.234), 2), - jround(jnp.float32(1.234), 2)) - self.assertAllClose(round(np.float32(1.234)), - jround(jnp.float32(1.234)), check_dtypes=False) - self.assertAllClose(round(np.float32(7.532), 1), - jround(jnp.array(7.5, jnp.float32), 1)) - self.assertAllClose(round(np.float32(1.234), 2), - jround(jnp.array(1.234, jnp.float32), 2)) - self.assertAllClose(round(np.float32(1.234)), - jround(jnp.array(1.234, jnp.float32)), - check_dtypes=False) - - def testRoundMethod(self): - # https://github.com/jax-ml/jax/issues/15190 - (jnp.arange(3.) / 5.).round() # doesn't crash - - @jtu.sample_product(shape=[(5,), (5, 2)]) - def testOperatorReversed(self, shape): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, 'float32')] - np_fun = lambda x: np.array(list(reversed(x))) - jnp_fun = lambda x: jnp.array(list(reversed(x))) - - self._CompileAndCheck(jnp_fun, args_maker) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - [dict(mode=mode, shape=shape, dtype=dtype, - pad_width=pad_width, constant_values=constant_values) - for mode, shapes in [ - ('constant', all_shapes), - ('wrap', nonempty_shapes), - ('edge', nonempty_shapes), - ] - for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) - for constant_values in [ - # None is used for modes other than 'constant' - None, - # constant - 0, 1, - # (constant,) - (0,), (2.718,), - # ((before_const, after_const),) - ((0, 2),), ((-1, 3.14),), - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i / 2, -3.14 * i) for i in range(len(shape))), - ] - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - if (pad_width != () and constant_values != () and - ((mode == 'constant' and constant_values is not None) or - (mode != 'constant' and constant_values is None)))], - ) - def testPad(self, shape, dtype, mode, pad_width, constant_values): - if np.issubdtype(dtype, np.unsignedinteger): - constant_values = jax.tree.map(abs, constant_values) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if constant_values is None: - np_fun = partial(np.pad, pad_width=pad_width, mode=mode) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode) - else: - np_fun = partial(np.pad, pad_width=pad_width, mode=mode, - constant_values=constant_values) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, - constant_values=constant_values) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(mode=mode, shape=shape, dtype=dtype, - pad_width=pad_width, stat_length=stat_length) - for mode in ['maximum', 'minimum', 'mean', 'median'] - for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - for stat_length in [ - None, - # ((before_1, after_1), ..., (before_N, after_N)) - tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 2),), - # (before, after) (not in the docstring but works in numpy) - (1, 1), (3, 4), - # (pad,) - (1,), (2,), - # pad - 1, 2 - ] - if (pad_width != () and stat_length != () and - not (dtype in bool_dtypes and mode == 'mean'))], - ) - def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length): - if mode == 'median' and np.issubdtype(dtype, np.complexfloating): - self.skipTest("median statistic is not supported for dtype=complex.") - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - np_fun = partial(np.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype, - pad_width=pad_width, reflect_type=reflect_type) - for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 3),), - # (before, after) (not in the docstring but works in numpy) - (2, 1), (1, 2), - # (pad,) - (1,), (2,), (3,), - # pad - 0, 5, 7, 10 - ] - for reflect_type in ['even', 'odd'] - if (pad_width != () and - # following types lack precision when calculating odd values - (reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16]))], - mode=['symmetric', 'reflect'] - ) - def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - np_fun = partial(np.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, - tol={np.float32: 1e-3, np.complex64: 1e-3}) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype, pad_width=pad_width, end_values=end_values) - for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - for end_values in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2.0, 3.14),), - # (before, after) (not in the docstring but works in numpy) - (0, 0), (-8.0, 2.0), - # (end_values,) - (1,), (2,), - # end_values - 0, 1, 100, 10.0, 3.5, 4.2, -5, -3 - ] - if (pad_width != () and end_values != () and - # following types lack precision - dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16])], - ) - def testPadLinearRamp(self, shape, dtype, pad_width, end_values): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - np_fun = partial(np.pad, pad_width=pad_width, mode="linear_ramp", - end_values=end_values) - jnp_fun = partial(jnp.pad, pad_width=pad_width, mode="linear_ramp", - end_values=end_values) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker) - - def testPadEmpty(self): - arr = np.arange(6).reshape(2, 3) - - pad_width = ((2, 3), (3, 1)) - np_res = np.pad(arr, pad_width=pad_width, mode="empty") - jnp_res = jnp.pad(arr, pad_width=pad_width, mode="empty") - - np.testing.assert_equal(np_res.shape, jnp_res.shape) - np.testing.assert_equal(arr, np_res[2:-3, 3:-1]) - np.testing.assert_equal(arr, jnp_res[2:-3, 3:-1]) - np.testing.assert_equal(np_res[2:-3, 3:-1], jnp_res[2:-3, 3:-1]) - - def testPadKwargs(self): - modes = { - 'constant': {'constant_values': 0}, - 'edge': {}, - 'linear_ramp': {'end_values': 0}, - 'maximum': {'stat_length': None}, - 'mean': {'stat_length': None}, - 'median': {'stat_length': None}, - 'minimum': {'stat_length': None}, - 'reflect': {'reflect_type': 'even'}, - 'symmetric': {'reflect_type': 'even'}, - 'wrap': {}, - 'empty': {} - } - arr = jnp.array([1, 2, 3]) - pad_width = 1 - - for mode in modes.keys(): - allowed = modes[mode] - not_allowed = {} - for kwargs in modes.values(): - if kwargs != allowed: - not_allowed.update(kwargs) - - # Test if allowed keyword arguments pass - jnp.pad(arr, pad_width, mode, **allowed) - # Test if prohibited keyword arguments of other modes raise an error - match = f"unsupported keyword arguments for mode '{mode}'" - for key, value in not_allowed.items(): - with self.assertRaisesRegex(ValueError, match): - jnp.pad(arr, pad_width, mode, **{key: value}) - - # Test if unsupported mode raise error. - unsupported_modes = [1, None, "foo"] - for mode in unsupported_modes: - match = f"Unimplemented padding mode '{mode}' for np.pad." - with self.assertRaisesRegex(NotImplementedError, match): - jnp.pad(arr, pad_width, mode) - - def testPadFunction(self): - def np_pad_with(vector, pad_width, iaxis, kwargs): - pad_value = kwargs.get('padder', 10) - vector[:pad_width[0]] = pad_value - vector[-pad_width[1]:] = pad_value - - def jnp_pad_with(vector, pad_width, iaxis, kwargs): - pad_value = kwargs.get('padder', 10) - vector = vector.at[:pad_width[0]].set(pad_value) - vector = vector.at[-pad_width[1]:].set(pad_value) - return vector - - arr = np.arange(6).reshape(2, 3) - np_res = np.pad(arr, 2, np_pad_with) - jnp_res = jnp.pad(arr, 2, jnp_pad_with) - np.testing.assert_equal(np_res, jnp_res) - - arr = np.arange(24).reshape(2, 3, 4) - np_res = np.pad(arr, 1, np_pad_with, padder=100) - jnp_res = jnp.pad(arr, 1, jnp_pad_with, padder=100) - np.testing.assert_equal(np_res, jnp_res) - - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(arr.shape, arr.dtype)] - jnp_fun = partial(jnp.pad, pad_width=1, mode=jnp_pad_with) - self._CompileAndCheck(jnp_fun, args_maker) - - def testPadWithNumpyPadWidth(self): - a = jnp.array([1, 2, 3, 4, 5]) - f = jax.jit( - partial( - jnp.pad, - pad_width=np.asarray((2, 3)), - mode="constant", - constant_values=(4, 6))) - - np.testing.assert_array_equal( - f(a), - np.pad( - a, - pad_width=np.asarray((2, 3)), - mode="constant", - constant_values=(4, 6))) - - def testPadWeakType(self): - x = jnp.array(1.0)[None] - for mode in ['constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', - 'minimum', 'reflect', 'symmetric', 'wrap', 'empty']: - y = jnp.pad(x, 0, mode=mode) - self.assertTrue(dtypes.is_weakly_typed(y)) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype) - for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes)], - reps=[(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)], - ) - def testTile(self, shape, dtype, reps): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.tile(arg, reps) - jnp_fun = lambda arg: jnp.tile(arg, reps) - - args_maker = lambda: [rng(shape, dtype)] - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) - def testExtract(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] - self._CheckAgainstNumpy(np.extract, jnp.extract, args_maker) - - @jtu.sample_product( - [dict(ncond=ncond, nfunc=nfunc) - for ncond in [1, 2, 3] - for nfunc in [ncond, ncond + 1] - ], - shape=all_shapes, - dtype=all_dtypes) - def testPiecewise(self, shape, dtype, ncond, nfunc): - rng = jtu.rand_default(self.rng()) - rng_bool = jtu.rand_int(self.rng(), 0, 2) - funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc] - args_maker = lambda: (rng(shape, dtype), [rng_bool(shape, bool) for i in range(ncond)]) - np_fun = partial(np.piecewise, funclist=funclist) - jnp_fun = partial(jnp.piecewise, funclist=funclist) - - if dtype == np.bool_: - # The `x - 1` above uses type promotion. - jnp_fun = jax.numpy_dtype_promotion('standard')(jnp_fun) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - # This is a higher-order function, so the cache miss check will fail. - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, check_cache_misses=False) - - def testPiecewiseRecompile(self): - def g(x): - g.num_traces += 1 - return x - g.num_traces = 0 - x = jnp.arange(10.0) - for i in range(5): - jnp.piecewise(x, [x < 0], [g, 0.]) - self.assertEqual(g.num_traces, 1) - - @jtu.sample_product( - [dict(shape=shape, perm=perm) - for shape in array_shapes - for perm in [ - None, - tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim)), - tuple(np.random.RandomState(0).permutation( - np.zeros(shape).ndim) - np.zeros(shape).ndim) - ] - ], - dtype=default_dtypes, - arg_type=["splat", "value"], - ) - def testTransposeTuple(self, shape, dtype, perm, arg_type): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if arg_type == "value": - np_fun = lambda x: x.transpose(perm) - jnp_fun = lambda x: jnp.array(x).transpose(perm) - else: - np_fun = lambda x: x.transpose(*(perm or ())) - jnp_fun = lambda x: jnp.array(x).transpose(*(perm or ())) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @jtu.sample_product( - shape=array_shapes, - dtype=default_dtypes, - ) - def testPermuteDims(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - axes = self.rng().permutation(len(shape)) - np_fun = partial(getattr(np, "permute_dims", np.transpose), axes=axes) - jnp_fun = partial(jnp.permute_dims, axes=axes) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @jtu.sample_product( - shape=[s for s in array_shapes if len(s) >= 2], - dtype=default_dtypes, - use_property=[True, False] - ) - def testMatrixTranspose(self, shape, dtype, use_property): - if use_property: - jnp_fun = lambda x: jnp.asarray(x).mT - else: - jnp_fun = jnp.matrix_transpose - np_fun = np.matrix_transpose - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - trim=["f", "b", "fb"], - ) - def testTrimZeros(self, a_shape, dtype, trim): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(a_shape, dtype)] - np_fun = lambda arg1: np.trim_zeros(arg1, trim) - jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - rank=(1, 2), - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - ) - @jax.default_matmul_precision("float32") - def testPoly(self, a_shape, dtype, rank): - if dtype in (np.float16, jnp.bfloat16, np.int16): - self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") - elif rank == 2 and not jtu.test_device_matches(["cpu"]): - self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") - rng = jtu.rand_default(self.rng()) - tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } - if jtu.test_device_matches(["tpu"]): - tol[np.int32] = tol[np.float32] = 1e-1 - tol = jtu.tolerance(dtype, tol) - args_maker = lambda: [rng(a_shape * rank, dtype)] - self._CheckAgainstNumpy(np.poly, jnp.poly, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp.poly, args_maker, check_dtypes=True, rtol=tol, atol=tol) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - b_shape=one_dim_array_shapes, - ) - def testPolyAdd(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polyadd(arg1, arg2) - jnp_fun = lambda arg1, arg2: jnp.polyadd(arg1, arg2) - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - b_shape=one_dim_array_shapes, - ) - def testPolySub(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polysub(arg1, arg2) - jnp_fun = lambda arg1, arg2: jnp.polysub(arg1, arg2) - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - [dict(order=order, k=k, dtype=dtype) - for dtype in default_dtypes - for order in range(5) - for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None]], - a_shape=one_dim_array_shapes, - ) - def testPolyInt(self, a_shape, order, k, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1: np.polyint(arg1, m=order, k=k) - jnp_fun = lambda arg1: jnp.polyint(arg1, m=order, k=k) - args_maker = lambda: [rng(a_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal don't support map op.") - @jtu.sample_product( - dtype=default_dtypes, - a_shape=one_dim_array_shapes, - order=list(range(5)), - ) - def testPolyDer(self, a_shape, order, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1: np.polyder(arg1, m=order) - jnp_fun = lambda arg1: jnp.polyder(arg1, m=order) - args_maker = lambda: [rng(a_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @parameterized.parameters(['int', 'np.int', 'jnp.int']) - def testIntegerPower(self, ptype): - p = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}[ptype] - jaxpr = jax.make_jaxpr(lambda x1: jnp.power(x1, p))(1) - eqns = jaxpr.jaxpr.eqns - self.assertLen(eqns, 1) - self.assertEqual(eqns[0].primitive, lax.integer_pow_p) - - @jtu.sample_product( - x=[-1, 0, 1], - y=[0, 32, 64, 128], - ) - def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/jax-ml/jax/issues/5987 - args_maker = lambda: [x, y] - self._CheckAgainstNumpy(np.power, jnp.power, args_maker) - self._CompileAndCheck(jnp.power, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(len(shape))) - ], - dtype=all_dtypes, - ) - def testCompress(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - if shape in scalar_shapes or len(shape) == 0: - cond_shape = (0,) - elif axis is None: - cond_shape = (math.prod(shape),) - else: - cond_shape = (shape[axis],) - - args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] - - np_fun = partial(np.compress, axis=axis) - jnp_fun = partial(jnp.compress, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - shape=[(2, 3)], - dtype=int_dtypes, - # condition entries beyond axis size must be zero. - condition=[[1], [1, 0, 0, 0, 0, 0, 0]], - axis=[None, 0, 1], - ) - def testCompressMismatchedShapes(self, shape, dtype, condition, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.array(condition), rng(shape, dtype)] - np_fun = partial(np.compress, axis=axis) - jnp_fun = partial(jnp.compress, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in array_shapes - for axis in [None] + list(range(len(shape))) - ], - dtype=all_dtypes, - ) - def testCompressMethod(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - if shape in scalar_shapes or len(shape) == 0: - cond_shape = (0,) - elif axis is None: - cond_shape = (math.prod(shape),) - else: - cond_shape = (shape[axis],) - - args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] - - np_fun = lambda condition, x: np.compress(condition, x, axis=axis) - jnp_fun = lambda condition, x: x.compress(condition, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in (None, *range(-len(base_shape)+1, len(base_shape))) - ], - arg_dtypes=[ - arg_dtypes - for num_arrs in [3] - for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs) - ], - dtype=[None] + default_dtypes, - ) - def testConcatenate(self, axis, dtype, base_shape, arg_dtypes): - rng = jtu.rand_default(self.rng()) - wrapped_axis = 0 if axis is None else axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] - @jtu.promote_like_jnp - def np_fun(*args, dtype=dtype): - dtype = dtype or args[0].dtype - args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32) - for x in args] - return np.concatenate(args, axis=axis, dtype=dtype, casting='unsafe') - jnp_fun = lambda *args: jnp.concatenate(args, axis=axis, dtype=dtype) - - def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - - with jtu.strict_promotion_if_dtypes_match(arg_dtypes): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in [(4, 1), (4, 3), (4, 5, 6)] - for axis in [None] + list(range(1 - len(shape), len(shape) - 1)) - ], - dtype=all_dtypes, - ) - def testConcatenateArray(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda x: np.concatenate(x, axis=axis) - jnp_fun = lambda x: jnp.concatenate(x, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testConcatenateAxisNone(self): - # https://github.com/jax-ml/jax/issues/3419 - a = jnp.array([[1, 2], [3, 4]]) - b = jnp.array([[5]]) - jnp.concatenate((a, b), axis=None) - - def testConcatenateScalarAxisNone(self): - arrays = [np.int32(0), np.int32(1)] - self.assertArraysEqual(jnp.concatenate(arrays, axis=None), - np.concatenate(arrays, axis=None)) - - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis) - for base_shape in [(), (4,), (3, 4), (2, 3, 4)] - for axis in (None, *range(-len(base_shape)+1, len(base_shape))) - ], - dtype=default_dtypes, - ) - def testConcat(self, axis, base_shape, dtype): - rng = jtu.rand_default(self.rng()) - wrapped_axis = 0 if axis is None else axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size in [3, 1, 4]] - @jtu.promote_like_jnp - def np_fun(*args): - return np.concat(args, axis=axis) - jnp_fun = lambda *args: jnp.concat(args, axis=axis) - args_maker = lambda: [rng(shape, dtype) for shape in shapes] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(base_shape=base_shape, axis=axis) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(-len(base_shape)+1, len(base_shape))], - arg_dtypes=itertools.combinations_with_replacement(default_dtypes, 2) - ) - def testAppend(self, axis, base_shape, arg_dtypes): - rng = jtu.rand_default(self.rng()) - wrapped_axis = axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] - def np_fun(arr, values): - arr = arr.astype(np.float32) if arr.dtype == jnp.bfloat16 else arr - values = (values.astype(np.float32) if values.dtype == jnp.bfloat16 - else values) - out = np.append(arr, values, axis=axis) - return out.astype(jnp.promote_types(*arg_dtypes)) - jnp_fun = lambda arr, values: jnp.append(arr, values, axis=axis) - - def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - - with jtu.strict_promotion_if_dtypes_match(arg_dtypes): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis, idx=idx) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - for idx in (range(-math.prod(shape), math.prod(shape)) - if axis is None else - range(-shape[axis], shape[axis]))], - dtype=all_dtypes, - ) - def testDeleteInteger(self, shape, dtype, idx, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, idx, axis=axis) - jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - slc=[slice(None), slice(1, 3), slice(1, 5, 2)], - ) - def testDeleteSlice(self, shape, dtype, axis, slc): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, slc, axis=axis) - jnp_fun = lambda arg: jnp.delete(arg, slc, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - idx_shape=all_shapes, - ) - def testDeleteIndexArray(self, shape, dtype, axis, idx_shape): - rng = jtu.rand_default(self.rng()) - max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] - idx = jtu.rand_int(self.rng(), low=-max_idx, high=max_idx)(idx_shape, int) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, idx, axis=axis) - jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - idx_shape=all_shapes, - ) - def testDeleteUniqueIndices(self, shape, dtype, axis, idx_shape): - rng = jtu.rand_default(self.rng()) - max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] - idx_size = np.zeros(idx_shape).size - if idx_size > max_idx: - self.skipTest("Too many indices to be unique") - def args_maker(): - x = rng(shape, dtype) - idx = self.rng().choice(max_idx, idx_shape, replace=False) - return x, idx - np_fun = partial(np.delete, axis=axis) - jnp_fun = partial(jnp.delete, axis=axis, assume_unique_indices=True) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - ) - def testDeleteMaskArray(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] - mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, mask, axis=axis) - jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("JAX-metal fail.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - ) - def testInsertInteger(self, shape, dtype, axis): - x = jnp.empty(shape) - max_ind = x.size if axis is None else x.shape[axis] - rng = jtu.rand_default(self.rng()) - i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) - args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)] - np_fun = lambda *args: np.insert(*args, axis=axis) - jnp_fun = lambda *args: jnp.insert(*args, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("Jax-metal fail.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in [None] + list(range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - ) - def testInsertSlice(self, shape, dtype, axis): - x = jnp.empty(shape) - max_ind = x.size if axis is None else x.shape[axis] - rng = jtu.rand_default(self.rng()) - i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) - slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item()) - args_maker = lambda: [rng(shape, dtype), rng((), dtype)] - np_fun = lambda x, val: np.insert(x, slc, val, axis=axis) - jnp_fun = lambda x, val: jnp.insert(x, slc, val, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.parameters([ - [[[1, 1], [2, 2], [3, 3]], 1, 5, None], - [[[1, 1], [2, 2], [3, 3]], 1, 5, 1], - [[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1], - [[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1], - [[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None], - [[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None], - [[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None], - [[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1] - ]) - def testInsertExamples(self, arr, index, values, axis): - # Test examples from the np.insert docstring - args_maker = lambda: ( - np.asarray(arr), index if isinstance(index, slice) else np.array(index), - np.asarray(values), axis) - self._CheckAgainstNumpy(np.insert, jnp.insert, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_array_shapes - for axis in range(-len(shape), len(shape)) - ], - dtype=default_dtypes, - out_dims=[0, 1, 2], - ) - def testApplyAlongAxis(self, shape, dtype, axis, out_dims): - def func(x, out_dims): - if out_dims == 0: - return x.sum(dtype=x.dtype) - elif out_dims == 1: - return x * x[0] - elif out_dims == 2: - return x[:, None] + x[None, :] - else: - raise NotImplementedError(f"{out_dims=}") - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arr: np.apply_along_axis(func, axis, arr, out_dims=out_dims) - jnp_fun = lambda arr: jnp.apply_along_axis(func, axis, arr, out_dims=out_dims) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - atol={dtypes.bfloat16: 2e-2}) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axes=axes) - for shape in nonempty_shapes - for axes in itertools.combinations(range(len(shape)), 2) - ], - func=["sum"], - keepdims=[True, False], - # Avoid low-precision types in sum() - dtype=[dtype for dtype in default_dtypes - if dtype not in [np.float16, jnp.bfloat16]], - ) - def testApplyOverAxes(self, shape, dtype, func, keepdims, axes): - f = lambda x, axis: getattr(x, func)(axis=axis, keepdims=keepdims, dtype=dtype) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: (rng(shape, dtype),) - np_fun = lambda a: np.apply_over_axes(f, a, axes) - jnp_fun = lambda a: jnp.apply_over_axes(f, a, axes) - self._CompileAndCheck(jnp_fun, args_maker) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype, axis=axis) - for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) - for axis in [None] + list(range(-len(shape), max(1, len(shape)))) - ], - repeats=[0, 1, 2], - fixed_size=[False, True], - ) - def testRepeat(self, axis, shape, dtype, repeats, fixed_size): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis) - np_fun = jtu.promote_like_jnp(np_fun) - if fixed_size: - total_repeat_length = np.repeat(np.zeros(shape), repeats, axis).shape[axis or 0] - jnp_fun = lambda arg, rep: jnp.repeat(arg, repeats=rep, axis=axis, - total_repeat_length=total_repeat_length) - jnp_args_maker = lambda: [rng(shape, dtype), repeats] - clo_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis, - total_repeat_length=total_repeat_length) - clo_fun_args_maker = lambda: [rng(shape, dtype)] - self._CompileAndCheck(jnp_fun, jnp_args_maker) - self._CheckAgainstNumpy(np_fun, clo_fun, clo_fun_args_maker) - else: - # Now repeats is in a closure, so a constant. - jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testRepeatScalarFastPath(self): - a = jnp.array([1,2,3,4]) - f = lambda a: jnp.repeat(a, repeats=2) - jaxpr = jax.make_jaxpr(f)(a) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6) - - @unittest.skip("jax-metal fail to convert sort op.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(len(shape)))], - dtype=number_dtypes, - return_index=[False, True], - return_inverse=[False, True], - return_counts=[False, True], - ) - def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - extra_args = (return_index, return_inverse, return_counts) - use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False - np_fun = jtu.with_jax_dtype_defaults(lambda x: np.unique(x, *extra_args, axis=axis), use_defaults) - jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueAll(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = np.unique_all - self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueCounts(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = np.unique_counts - self._CheckAgainstNumpy(jnp.unique_counts, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueInverse(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = np.unique_inverse - self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker) - - @jtu.sample_product(shape=all_shapes, dtype=number_dtypes) - def testUniqueValues(self, shape, dtype): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = np.unique_values - self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_array_shapes - for axis in [None] + list(range(len(shape)))], - dtype=number_dtypes, - size=[1, 5, 10], - fill_value=[None, 0, "slice"], - ) - def testUniqueSize(self, shape, dtype, axis, size, fill_value): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True) - - if fill_value == "slice": - if axis is None: - fill_value = rng((), dtype) - else: - fill_value = rng(shape[:axis] + shape[axis + 1:], dtype) - elif fill_value is not None: - fill_value = np.array(fill_value).astype(dtype) - - @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) - def np_fun(x, fill_value=fill_value): - u, ind, inv, counts = np.unique(x, **kwds) - axis = kwds['axis'] - if axis is None: - x = x.ravel() - axis = 0 - - n_unique = u.shape[axis] - if size <= u.shape[axis]: - slc = (slice(None),) * axis + (slice(size),) - u, ind, counts = u[slc], ind[:size], counts[:size] - else: - extra = (0, size - n_unique) - pads = [(0, 0)] * u.ndim - pads[axis] = extra - u = np.pad(u, pads, constant_values=0) - slices = [slice(None)] * u.ndim - slices[axis] = slice(1) - if fill_value is None: - fill_value = u[tuple(slices)] - elif np.ndim(fill_value): - fill_value = lax.expand_dims(fill_value, (axis,)) - slices[axis] = slice(n_unique, None) - u[tuple(slices)] = fill_value - ind = np.pad(ind, extra, constant_values=ind[0]) - counts = np.pad(counts, extra, constant_values=0) - return u, ind, inv, counts - - jnp_fun = lambda x: jnp.unique(x, size=size, fill_value=fill_value, **kwds) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product(dtype=inexact_dtypes) - def testUniqueNans(self, dtype): - def args_maker(): - x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan] - if np.issubdtype(dtype, np.complexfloating): - x = [complex(i, j) for i, j in itertools.product(x, repeat=2)] - return [np.array(x, dtype=dtype)] - - kwds = dict(return_index=True, return_inverse=True, return_counts=True) - jnp_fun = partial(jnp.unique, **kwds) - def np_fun(x): - dtype = x.dtype - # numpy unique fails for bfloat16 NaNs, so we cast to float64 - if x.dtype == jnp.bfloat16: - x = x.astype('float64') - u, *rest = np.unique(x, **kwds) - return (u.astype(dtype), *rest) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) - def testUniqueEqualNan(self, dtype, equal_nan): - shape = (20,) - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_fun(x): - dtype = x.dtype - # numpy unique fails for bfloat16 NaNs, so we cast to float64 - if x.dtype == jnp.bfloat16: - x = x.astype('float64') - return np.unique(x, equal_nan=equal_nan).astype(dtype) - jnp_fun = partial(jnp.unique, equal_nan=equal_nan) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product(fixed_size=[False, True]) - def testNonScalarRepeats(self, fixed_size): - ''' - Following numpy test suite from `test_repeat` at - https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_multiarray.py - ''' - tol = 1e-5 - - def test_single(m, args_maker, repeats, axis): - lax_ans = jnp.repeat(m, repeats, axis) - numpy_ans = np.repeat(m, repeats, axis) - - self.assertAllClose(lax_ans, numpy_ans, rtol=tol, atol=tol) - if fixed_size: - - # Calculate expected size of the repeated axis. - rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0] - jnp_fun = lambda arg, rep: jnp.repeat( - arg, repeats=rep, axis=axis, total_repeat_length=rep_length) - else: - jnp_fun = lambda arg: jnp.repeat(arg, repeats = repeats, axis=axis) - self._CompileAndCheck(jnp_fun, args_maker) - - m = jnp.array([1,2,3,4,5,6]) - if fixed_size: - args_maker = lambda: [m, repeats] - else: - args_maker = lambda: [m] - - for repeats in [2, jnp.array([1,3,0,1,1,2]), jnp.array([1,3,2,1,1,2]), jnp.array([2])]: - test_single(m, args_maker, repeats, axis=None) - test_single(m, args_maker, repeats, axis=0) - - m_rect = m.reshape((2,3)) - if fixed_size: - args_maker = lambda: [m_rect, repeats] - else: - args_maker = lambda: [m_rect] - - for repeats in [2, jnp.array([2,1]), jnp.array([2])]: - test_single(m_rect, args_maker, repeats, axis=0) - - for repeats in [2, jnp.array([1,3,2]), jnp.array([2])]: - test_single(m_rect, args_maker, repeats, axis=1) - - def testIssue2330(self): - ''' - Make sure return value of jnp.concatenate is a jax.ndarray and is side-effect save - ''' - def attempt_sideeffect(x): - x = [x] - x = jnp.concatenate(x) - x -= 1. - return x - - np_input = np.ones(1) - jnp_input = jnp.ones(1) - expected_np_input_after_call = np.ones(1) - expected_jnp_input_after_call = jnp.ones(1) - - out = jnp.concatenate([np_input]) - self.assertIs(type(out), array.ArrayImpl) - - attempt_sideeffect(np_input) - attempt_sideeffect(jnp_input) - - self.assertAllClose(np_input, expected_np_input_after_call) - self.assertAllClose(jnp_input, expected_jnp_input_after_call) - - @jtu.sample_product( - mode=['full', 'same', 'valid'], - op=['convolve', 'correlate'], - dtype= float_dtypes, #number_dtypes, - xshape=one_dim_array_shapes, - yshape=one_dim_array_shapes, - ) - def testConvolutions(self, xshape, yshape, dtype, mode, op): - jnp_op = getattr(jnp, op) - np_op = getattr(np, op) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] - precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None - jnp_fun = partial(jnp_op, mode=mode, precision=precision) - def np_fun(x, y): - return np_op(x, y, mode=mode).astype(dtypes.to_inexact_dtype(dtype)) - tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, - np.complex128: 1e-14} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - mode=['full', 'same', 'valid'], - op=['convolve', 'correlate'], - dtype=float_dtypes, #number_dtypes, - xshape=one_dim_array_shapes, - yshape=one_dim_array_shapes, - ) - @jtu.skip_on_devices("cuda", "rocm") # backends don't support all dtypes. - def testConvolutionsPreferredElementType(self, xshape, yshape, dtype, mode, op): - jnp_op = getattr(jnp, op) - np_op = getattr(np, op) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] - precision = lax.Precision.HIGHEST if jtu.test_device_matches(["tpu"]) else None - jnp_fun = partial(jnp_op, mode=mode, precision=precision, - preferred_element_type=dtype) - def np_fun(x, y): - return np_op(x, y, mode=mode).astype(dtype) - tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, - np.complex128: 1e-14} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(-len(shape), len(shape)))], - op=["cumsum", "cumprod"], - dtype=all_dtypes, - out_dtype=[dtype for dtype in default_dtypes if dtype != np.float16], - ) - def testCumSumProd(self, axis, shape, dtype, out_dtype, op): - jnp_op = getattr(jnp, op) - np_op = getattr(np, op) - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) - np_fun = jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered.*")(np_fun) - jnp_fun = lambda arg: jnp_op(arg, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) - - args_maker = lambda: [rng(shape, dtype)] - - tol_thresholds = {dtypes.bfloat16: 4e-2} - tol = max(jtu.tolerance(dtype, tol_thresholds), - jtu.tolerance(out_dtype, tol_thresholds)) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in all_shapes - for axis in [None] + list(range(-len(shape), len(shape)))], - op=["nancumsum", "nancumprod"], - dtype=all_dtypes, - out_dtype=default_dtypes, - ) - def testNanCumSumProd(self, axis, shape, dtype, out_dtype, op): - jnp_op = getattr(jnp, op) - np_op = getattr(np, op) - rng = jtu.rand_some_nan(self.rng()) - np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(np_fun) - np_fun = jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered.*")(np_fun) - jnp_fun = partial(jnp_op, axis=axis, dtype=out_dtype) - jnp_fun = jtu.ignore_warning(category=np.exceptions.ComplexWarning)(jnp_fun) - - args_maker = lambda: [rng(shape, dtype)] - - tol_thresholds = {dtypes.bfloat16: 4e-2, np.float16: 3e-3} - tol = max(jtu.tolerance(dtype, tol_thresholds), - jtu.tolerance(out_dtype, tol_thresholds)) - if dtype != jnp.bfloat16: - # numpy functions do not properly handle bfloat16 - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("Jax-metal fail on testEye2") - @jtu.sample_product( - dtype=default_dtypes, - n=[0, 4], - m=[None, 0, 1, 3, 4], - k=[*range(-4, 4), -2**100, 2**100], - ) - def testEye(self, n, m, k, dtype): - np_fun = lambda: np.eye(n, M=m, k=k, dtype=dtype) - jnp_fun = lambda: jnp.eye(n, M=m, k=k, dtype=dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - n=[0, 4], - m=[None, 0, 1, 3, 4], - k=range(-4, 4), - ) - def testTri(self, m, n, k, dtype): - np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype) - jnp_fun = lambda: jnp.tri(n, M=m, k=k, dtype=dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[shape for shape in all_shapes if len(shape) >= 2], - op=["tril", "triu"], - k=list(range(-3, 3)), - ) - def testTriLU(self, dtype, shape, op, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: getattr(np, op)(arg, k=k) - jnp_fun = lambda arg: getattr(jnp, op)(arg, k=k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - n=range(5), - k=range(-3, 3), - m=[None, *range(5)], - ) - def testTrilIndices(self, n, k, m): - np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m) - jnp_fun = lambda n, k, m: jnp.tril_indices(n, k=k, m=m) - args_maker = lambda: [n, k, m] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - n=range(5), - k=range(-3, 3), - m=[None, *range(5)], - ) - def testTriuIndices(self, n, k, m): - np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m) - jnp_fun = lambda n, k, m: jnp.triu_indices(n, k=k, m=m) - args_maker = lambda: [n, k, m] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)], - k=[-1, 0, 1], - ) - def testTriuIndicesFrom(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arr, k: np.triu_indices_from(arr, k=k) - jnp_fun = lambda arr, k: jnp.triu_indices_from(arr, k=k) - args_maker = lambda: [rng(shape, dtype), k] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)], - k=[-1, 0, 1], - ) - def testTrilIndicesFrom(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arr, k: np.tril_indices_from(arr, k=k) - jnp_fun = lambda arr, k: jnp.tril_indices_from(arr, k=k) - args_maker = lambda: [rng(shape, dtype), k] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)], - val_shape=[(), (1,), (2,), (1, 2), (3, 2)], - ) - def testFillDiagonal(self, dtype, a_shape, val_shape): - rng = jtu.rand_default(self.rng()) - - def np_fun(a, val): - a_copy = a.copy() - np.fill_diagonal(a_copy, val) - return a_copy - - jnp_fun = partial(jnp.fill_diagonal, inplace=False) - args_maker = lambda : [rng(a_shape, dtype), rng(val_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - ndim=[0, 1, 4], - n=[0, 1, 7], - ) - def testDiagIndices(self, ndim, n): - np.testing.assert_equal(jtu.with_jax_dtype_defaults(np.diag_indices)(n, ndim), - jnp.diag_indices(n, ndim)) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[(1,1), (2,2), (3,3), (4,4), (5,5)], - ) - def testDiagIndicesFrom(self, dtype, shape): - rng = jtu.rand_default(self.rng()) - np_fun = jtu.with_jax_dtype_defaults(np.diag_indices_from) - jnp_fun = jnp.diag_indices_from - args_maker = lambda : [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=[shape for shape in all_shapes if len(shape) in (1, 2)], - k=list(range(-4, 4)), - ) - def testDiag(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.diag(arg, k) - jnp_fun = lambda arg: jnp.diag(arg, k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - shape=all_shapes, - k=list(range(-4, 4)), - ) - def testDiagFlat(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - # numpy has inconsistencies for scalar values - # https://github.com/numpy/numpy/issues/16477 - # jax differs in that it treats scalars values as length-1 arrays - np_fun = lambda arg: np.diagflat(np.atleast_1d(arg), k) - jnp_fun = lambda arg: jnp.diagflat(arg, k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product( - dtype=default_dtypes, - a1_shape=one_dim_array_shapes, - a2_shape=one_dim_array_shapes, - ) - def testPolyMul(self, a1_shape, a2_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polymul(arg1, arg2) - jnp_fun_np = lambda arg1, arg2: jnp.polymul(arg1, arg2, trim_leading_zeros=True) - jnp_fun_co = lambda arg1, arg2: jnp.polymul(arg1, arg2) - args_maker = lambda: [rng(a1_shape, dtype), rng(a2_shape, dtype)] - tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} - self._CheckAgainstNumpy(np_fun, jnp_fun_np, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun_co, args_maker, check_dtypes=False) - - @unittest.skip("jax-metal fail.") - @jtu.sample_product( - dtype=[dtype for dtype in default_dtypes - if dtype not in (np.float16, jnp.bfloat16)], - a_shape=one_dim_array_shapes, - b_shape=one_dim_array_shapes, - ) - def testPolyDiv(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - - @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero.*") - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - @jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*") - def np_fun(arg1, arg2): - q, r = np.polydiv(arg1, arg2) - while r.size < max(arg1.size, arg2.size): # Pad residual to same size - r = np.pad(r, (1, 0), 'constant') - return q, r - - def jnp_fun(arg1, arg2): - q, r = jnp.polydiv(arg1, arg2, trim_leading_zeros=True) - while r.size < max(arg1.size, arg2.size): # Pad residual to same size - r = jnp.pad(r, (1, 0), 'constant') - return q, r - - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - tol = { - dtypes.bfloat16: 2e-1, - np.float16: 2e-1, - np.float32: 5e-2, - np.float64: 5e-7 - } - - jnp_compile = jnp.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_compile, args_maker, check_dtypes=True, atol=tol, rtol=tol) - - @jtu.sample_product( - [dict(shape=shape, axis1=axis1, axis2=axis2) - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for axis1 in range(-len(shape), len(shape)) - for axis2 in [a for a in range(-len(shape), len(shape)) - if a % len(shape) != axis1 % len(shape)] - ], - dtype=default_dtypes, - offset=list(range(-4, 4)), - ) - def testDiagonal(self, shape, dtype, offset, axis1, axis2): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2) - jnp_fun = lambda arg: jnp.diagonal(arg, offset, axis1, axis2) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtype=default_dtypes, - n=list(range(4)), - ) - def testIdentity(self, n, dtype): - np_fun = lambda: np.identity(n, dtype) - jnp_fun = lambda: jnp.identity(n, dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("jax-metal crash.") - @jtu.sample_product( - shape=nonempty_shapes, - period=[None, 0.59], - left=[None, 0], - right=[None, 1], - # Note: skip 8-bit and 16-bit types due to insufficient precision. - dtype=jtu.dtypes.integer + jtu.dtypes.floating, - target_dtype=jtu.dtypes.inexact, - ) - def testInterp(self, shape, dtype, period, left, right, target_dtype): - rng = jtu.rand_default(self.rng(), scale=10) - kwds = dict(period=period, left=left, right=right) - np_fun = partial(np.interp, **kwds) - jnp_fun = partial(jnp.interp, **kwds) - - args_maker = lambda: [rng(shape, dtype), np.unique(rng((100,), dtype))[:20], - rng((20,), target_dtype)] - - with jtu.strict_promotion_if_dtypes_match([dtype, target_dtype]): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - rtol=3e-3, atol=1e-3) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("jax-metal crash.") - @jtu.sample_product([ - dict(x=0.5, left='extrapolate', expected=5), - dict(x=1.5, left='extrapolate', expected=15), - dict(x=3.5, left='extrapolate', expected=30), - dict(x=3.9, right='extrapolate', expected=39), - ]) - def testInterpExtrapoate(self, x, expected, **kwargs): - xp = jnp.array([1.0, 2.0, 3.0]) - fp = jnp.array([10.0, 20.0, 30.0]) - actual = jnp.interp(x, xp, fp, **kwargs) - self.assertAlmostEqual(actual, expected) - - def testInterpErrors(self): - with self.assertRaisesWithLiteralMatch( - ValueError, - 'xp and fp must be one-dimensional arrays of equal size' - ): - jnp.interp(0.0, jnp.arange(2.0), jnp.arange(3.0)) - with self.assertRaisesWithLiteralMatch( - ValueError, - "the only valid string value of `left` is 'extrapolate', but got: 'interpolate'" - ): - jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), left='interpolate') - with self.assertRaisesWithLiteralMatch( - ValueError, - "the only valid string value of `right` is 'extrapolate', but got: 'interpolate'" - ): - jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), right='interpolate') - with self.assertRaisesWithLiteralMatch( - ValueError, - "jnp.interp: complex x values not supported." - ): - jnp.interp(1j, 1j * np.arange(3.0), np.arange(3.0)) - with self.assertRaisesRegex( - ValueError, - "period must be a scalar; got" - ): - jnp.interp(0.0, jnp.arange(3.0), jnp.arange(3.0), period=np.array([1.0])) - - @jtu.sample_product( - period=[None, 0.59], - left=[None, 0], - right=[None, 1], - dtype=jtu.dtypes.floating, - ) - def testInterpGradNan(self, dtype, period, left, right): - kwds = dict(period=period, left=left, right=right) - jnp_fun = partial(jnp.interp, **kwds) - # Probe values of x and xp that are close to zero and close together. - x = dtype(np.exp(np.linspace(-90, -20, 1000))) - g = jax.grad(lambda z: jnp.sum(jnp_fun(z, z, jnp.ones_like(z))))(x) - np.testing.assert_equal(np.all(np.isfinite(g)), True) - - @jtu.sample_product( - [dict(x1_shape=x1_shape, x2_shape=x2_shape) - for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(array_shapes, 2)) - ], - x1_rng_factory=[jtu.rand_some_inf_and_nan, jtu.rand_some_zero], - x2_rng_factory=[partial(jtu.rand_int, low=-1075, high=1024)], - x1_dtype=default_dtypes, - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory): - x1_rng = x1_rng_factory(self.rng()) - x2_rng = x2_rng_factory(self.rng()) - - @jtu.ignore_warning(category=RuntimeWarning, message="overflow.*") - def np_fun(x1, x2): - out_dtype = dtypes.to_inexact_dtype(x1.dtype) - return np.ldexp(x1.astype(out_dtype), x2) - - jnp_fun = jnp.ldexp - args_maker = lambda: [x1_rng(x1_shape, x1_dtype), - x2_rng(x2_shape, np.int32)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - rng_factory=[ - jtu.rand_some_inf_and_nan, - jtu.rand_some_zero, - partial(jtu.rand_not_small, offset=1e8), - ], - shape=all_shapes, - dtype=default_dtypes, - ) - def testFrexp(self, shape, dtype, rng_factory): - # integer types are converted to float64 in numpy's implementation - if (dtype not in [jnp.bfloat16, np.float16, np.float32] - and not config.enable_x64.value): - self.skipTest("Only run float64 testcase when float64 is enabled.") - rng = rng_factory(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - def np_frexp(x): - mantissa, exponent = np.frexp(x) - # NumPy is inconsistent between Windows and Linux/Mac on what the - # value of exponent is if the input is infinite. Normalize to the Linux - # behavior. - exponent = np.where(np.isinf(mantissa), np.zeros_like(exponent), exponent) - return mantissa, exponent - self._CheckAgainstNumpy(np_frexp, jnp.frexp, args_maker, - check_dtypes=np.issubdtype(dtype, np.inexact)) - self._CompileAndCheck(jnp.frexp, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis1=axis1, axis2=axis2) - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for axis1 in range(-len(shape), len(shape)) - for axis2 in range(-len(shape), len(shape)) - if (axis1 % len(shape)) != (axis2 % len(shape)) - ], - dtype=default_dtypes, - out_dtype=[None] + number_dtypes, - offset=list(range(-4, 4)), - ) - def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2): - rng = jtu.rand_default(self.rng()) - def np_fun(arg): - if out_dtype == jnp.bfloat16: - return np.trace(arg, offset, axis1, axis2, np.float32).astype(jnp.bfloat16) - else: - return np.trace(arg, offset, axis1, axis2, out_dtype) - jnp_fun = lambda arg: jnp.trace(arg, offset, axis1, axis2, out_dtype) - args_maker = lambda: [rng(shape, dtype)] - # TODO: Fails with uint8/uint16 output dtypes (integer overflow?) - if out_dtype not in (np.uint8, np.uint16, np.uint32): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - #unittest.skip("jax-metal fail with empty vshape.") - @jtu.sample_product( - ashape=[(15,), (16,), (17,)], - vshape= [(5,), (5, 5)],#[(), (5,), (5, 5)], - side=['left', 'right'], - dtype= number_dtypes, - method=['sort', 'scan', 'scan_unrolled', 'compare_all'], - ) - def testSearchsorted(self, ashape, vshape, side, dtype, method): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)] - def np_fun(a, v): - return np.searchsorted(a, v, side=side).astype('int32') - jnp_fun = lambda a, v: jnp.searchsorted(a, v, side=side, method=method) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skipIf( - platform.system() == "Windows", - "Under Windows, NumPy throws if 2**32 is converted to an int32" - ) - def testSearchsortedDtype(self): - # Test that for large arrays, int64 indices are used. We test this - # via abstract evaluation to avoid allocating a large array in tests. - a_int32 = core.ShapedArray((np.iinfo(np.int32).max,), np.float32) - a_int64 = core.ShapedArray((np.iinfo(np.int32).max + 1,), np.float32) - v = core.ShapedArray((), np.float32) - - out_int32 = jax.eval_shape(jnp.searchsorted, a_int32, v) - self.assertEqual(out_int32.dtype, np.int32) - - if config.enable_x64.value: - out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) - self.assertEqual(out_int64.dtype, np.int64) - else: - with self.assertWarnsRegex(UserWarning, "Explicitly requested dtype int64"): - with self.assertRaisesRegex(OverflowError, "Python integer 2147483648 out of bounds.*"): - out_int64 = jax.eval_shape(jnp.searchsorted, a_int64, v) - - @unittest.skip("Jax-metal fail.") - @jtu.sample_product( - dtype=inexact_dtypes, - side=['left', 'right'], - method=['sort', 'scan', 'compare_all'], - ) - def testSearchsortedNans(self, dtype, side, method): - if np.issubdtype(dtype, np.complexfloating): - raise SkipTest("Known failure for complex inputs; see #9107") - x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype) - # The sign bit should not matter for 0.0 or NaN, so argsorting the above should be - # equivalent to argsorting the following: - x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5]) - - if jnp.issubdtype(dtype, jnp.complexfloating): - x = np.array([complex(r, c) for r, c in itertools.product(x, repeat=2)]) - x_equiv = np.array([complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)]) - - fun = partial(jnp.searchsorted, side=side, method=method) - self.assertArraysEqual(fun(x, x), fun(x_equiv, x_equiv)) - self.assertArraysEqual(jax.jit(fun)(x, x), fun(x_equiv, x_equiv)) - - @jtu.sample_product( - xshape=[(20,), (5, 4)], - binshape=[(1,), (5,)], - right=[True, False], - reverse=[True, False], - dtype=default_dtypes, - ) - def testDigitize(self, xshape, binshape, right, reverse, dtype): - order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:] - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]] - np_fun = lambda x, bins: np.digitize(x, bins, right=right).astype('int32') - jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - dtypes=[ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ], - shape=[(), (2,), (3, 4), (1, 5)], - array_input=[True, False], - ) - def testColumnStack(self, shape, dtypes, array_input): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - np_fun = jtu.promote_like_jnp(np.column_stack) - jnp_fun = jnp.column_stack - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in [(), (2,), (3, 4), (1, 100)] - for axis in range(-len(shape), len(shape) + 1) - ], - dtypes=[ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ], - array_input=[True, False], - out_dtype=[np.float32, np.int32], - ) - def testStack(self, shape, axis, dtypes, array_input, out_dtype): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - - np_fun = jtu.promote_like_jnp(partial(np.stack, axis=axis, dtype=out_dtype, casting='unsafe')) - - jnp_fun = partial(jnp.stack, axis=axis, dtype=out_dtype) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - op=["hstack", "vstack", "dstack"], - dtypes=[ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ], - shape=[(), (2,), (3, 4), (1, 100), (2, 3, 4)], - array_input=[True, False], - out_dtype=[np.float32, np.int32], - ) - def testHVDStack(self, shape, op, dtypes, array_input, out_dtype): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - - if op == "dstack": - np_fun = jtu.promote_like_jnp(lambda *args: getattr(np, op)(*args).astype(out_dtype)) - else: - np_fun = partial(jtu.promote_like_jnp(getattr(np, op)), dtype=out_dtype, - casting='unsafe') - - jnp_fun = partial(getattr(jnp, op), dtype=out_dtype) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(name=name, **kwds) - for name in ['blackman', 'bartlett', 'hamming', 'hanning', 'kaiser'] - for kwds in ([dict(beta=1), dict(beta=0.5)] if name == 'kaiser' else [{}]) - ], - size = [0, 1, 5, 10], - ) - def testWindowFunction(self, name, size, **kwds): - jnp_fun = partial(getattr(jnp, name), size, **kwds) - np_fun = jtu.with_jax_dtype_defaults(partial(getattr(np, name), size, **kwds)) - args_maker = lambda: [] - tol = ( - 5e-6 if jtu.test_device_matches(['tpu']) and name == 'kaiser' else None - ) - self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker, atol=tol, rtol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, fill_value_shape=fill_value_shape) - for shape in array_shapes + [3, np.array(7, dtype=np.int32)] - for fill_value_shape in _compatible_shapes(shape)], - fill_value_dtype=default_dtypes, - out_dtype=[None] + default_dtypes, - ) - def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype) - jnp_fun = lambda fill_value: jnp.full(shape, fill_value, dtype=out_dtype) - args_maker = lambda: [rng(fill_value_shape, fill_value_dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, dtype=dtype, axis=axis) - for shape, dtype in _shape_and_dtypes(nonempty_nonscalar_array_shapes, default_dtypes) - for axis in list(range(-len(shape), max(1, len(shape)))) - ], - prepend=[None, 1, 0], - append=[None, 1, 0], - n=[0, 1, 2], - ) - def testDiff(self, shape, dtype, n, axis, prepend, append): - prepend = np.zeros(shape, dtype=dtype) if prepend == 0 else prepend - append = np.zeros(shape, dtype=dtype) if append == 0 else append - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): - if prepend is None: - prepend = np._NoValue - elif not np.isscalar(prepend) and prepend.dtype == jnp.bfloat16: - prepend = prepend.astype(np.float32) - - if append is None: - append = np._NoValue - elif not np.isscalar(append) and append.dtype == jnp.bfloat16: - append = append.astype(np.float32) - - if x.dtype == jnp.bfloat16: - return np.diff(x.astype(np.float32), n=n, axis=axis, prepend=prepend, append=append).astype(jnp.bfloat16) - else: - return np.diff(x, n=n, axis=axis, prepend=prepend, append=append) - - jnp_fun = lambda x: jnp.diff(x, n=n, axis=axis, prepend=prepend, append=append) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - def testDiffPrepoendScalar(self): - # Regression test for https://github.com/jax-ml/jax/issues/19362 - x = jnp.arange(10) - result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) - - x = np.array(x) - result_numpy = np.diff(x, prepend=x[0], append=x[-1]) - - self.assertArraysEqual(result_jax, result_numpy) - - @jtu.sample_product( - op=["zeros", "ones"], - shape=[2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32), - np.array(4, dtype=np.int32)], - dtype=all_dtypes, - ) - def testZerosOnes(self, op, shape, dtype): - np_op = getattr(np, op) - jnp_op = getattr(jnp, op) - args_maker = lambda: [] - np_op = partial(np_op, shape, dtype) - jnp_op = partial(jnp_op, shape, dtype) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testOnesWithInvalidShape(self): - with self.assertRaises(TypeError): - jnp.ones((-1, 1)) - - def test_full_like_commited(self): - x = jnp.array((1, 2, 3), dtype=np.int32) - self.assertFalse(x._committed) - self.assertFalse(lax.full_like(x, 1.1)._committed) - x = jax.device_put(x, jax.devices()[-1]) - self.assertTrue(x._committed) - y = lax.full_like(x, 1.1) - self.assertTrue(y._committed) - self.assertEqual(x.sharding, y.sharding) - - def test_zeros_like_with_explicit_device_and_jitted(self): - x = jnp.array((1, 2, 3), dtype=np.int32) - x = jax.device_put(x, jax.devices()[0]) - zeros_like_with_device = partial(jnp.zeros_like, device=jax.devices()[0]) - y = jax.jit(zeros_like_with_device)(x) - self.assertEqual(x.shape, y.shape) - self.assertEqual(y.sharding, SingleDeviceSharding(jax.devices()[0])) - - @jtu.sample_product( - [dict(shape=shape, out_shape=out_shape, fill_value_shape=fill_value_shape) - for shape in array_shapes - for out_shape in [None] + array_shapes - for fill_value_shape in _compatible_shapes(shape if out_shape is None else out_shape) - ], - in_dtype=default_dtypes, - fill_value_dtype=default_dtypes, - out_dtype=default_dtypes, - ) - def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x, fill_value: np.full_like( - x, fill_value, dtype=out_dtype, shape=out_shape) - jnp_fun = lambda x, fill_value: jnp.full_like( - x, fill_value, dtype=out_dtype, shape=out_shape) - args_maker = lambda: [rng(shape, in_dtype), rng(fill_value_shape, fill_value_dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=array_shapes, - out_shape=[None] + array_shapes, - in_dtype=default_dtypes, - func=["ones_like", "zeros_like"], - out_dtype=default_dtypes, - ) - def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape) - jnp_fun = lambda x: getattr(jnp, func)(x, dtype=out_dtype, shape=out_shape) - args_maker = lambda: [rng(shape, in_dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], - shape=array_shapes, - dtype=default_dtypes, - ) - def testArrayCreationWithDevice(self, func, shape, dtype): - device = jax.devices()[-1] - kwds = {'fill_value': 1} if func is jnp.full else {} - out = func(**kwds, shape=shape, dtype=dtype, device=device) - self.assertEqual(out.devices(), {device}) - - @jtu.sample_product( - func=[jnp.empty, jnp.zeros, jnp.ones, jnp.full], - shape=array_shapes, - dtype=default_dtypes, - ) - def testArrayCreationWithSharding(self, func, shape, dtype): - sharding = SingleDeviceSharding(jax.devices()[-1]) - kwds = {'fill_value': 1} if func is jnp.full else {} - out = func(**kwds, shape=shape, dtype=dtype, device=sharding) - self.assertEqual(out.sharding, sharding) - - @jtu.sample_product( - func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], - shape=array_shapes, - dtype=default_dtypes, - ) - def testFullLikeWithDevice(self, func, shape, dtype): - device = jax.devices()[-1] - rng = jtu.rand_default(self.rng()) - x = rng(shape, dtype) - kwds = {'fill_value': 1} if func is jnp.full_like else {} - - with self.subTest('device from keyword'): - out = func(x, **kwds, device=device) - self.assertEqual(out.devices(), {device}) - - with self.subTest('device from input array'): - out2 = func(out, **kwds) - self.assertEqual(out2.devices(), out.devices()) - - @jtu.sample_product( - func=[jnp.empty_like, jnp.zeros_like, jnp.ones_like, jnp.full_like], - shape=array_shapes, - dtype=default_dtypes, - ) - def testFullLikeWithSharding(self, func, shape, dtype): - sharding = SingleDeviceSharding(jax.devices()[-1]) - rng = jtu.rand_default(self.rng()) - x = rng(shape, dtype) - kwds = {'fill_value': 1} if func is jnp.full_like else {} - - with self.subTest('device from keyword'): - out = func(x, **kwds, device=sharding) - self.assertEqual(out.sharding, sharding) - - with self.subTest('device from input array'): - out2 = func(out, **kwds) - self.assertEqual(out2.devices(), out.devices()) - - def testDuckTypedLike(self): - x = jax.ShapeDtypeStruct((1, 2, 3), np.dtype("int32")) - self.assertArraysEqual(jnp.zeros_like(x), jnp.zeros(x.shape, x.dtype)) - self.assertArraysEqual(jnp.ones_like(x), jnp.ones(x.shape, x.dtype)) - self.assertArraysEqual(jnp.empty_like(x), jnp.empty(x.shape, x.dtype)) - self.assertArraysEqual(jnp.full_like(x, 2), jnp.full(x.shape, 2, x.dtype)) - - @jtu.sample_product( - [dict(func=func, args=args) - for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())] - ], - shape=array_shapes, - #in_dtype=[np.int32, np.float32, np.complex64], - in_dtype=[np.int32, np.float32], - weak_type=[True, False], - out_shape=[None, (), (10,)], - out_dtype=[None, float], - ) - def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype): - rng = jtu.rand_default(self.rng()) - x = lax_internal._convert_element_type(rng(shape, in_dtype), - weak_type=weak_type) - fun = lambda x: getattr(jnp, func)(x, *args, dtype=out_dtype, shape=out_shape) - expected_weak_type = weak_type and (out_dtype is None) - self.assertEqual(dtypes.is_weakly_typed(fun(x)), expected_weak_type) - self.assertEqual(dtypes.is_weakly_typed(jax.jit(fun)(x)), expected_weak_type) - - @jtu.sample_product( - funcname=["array", "asarray"], - dtype=[int, float, None], - val=[0, 1], - input_type=[int, float, np.int32, np.float32], - ) - def testArrayWeakType(self, funcname, input_type, val, dtype): - func = lambda x: getattr(jnp, funcname)(x, dtype=dtype) - fjit = jax.jit(func) - val = input_type(val) - expected_weak_type = dtype is None and input_type in set(dtypes._weak_types) - self.assertEqual(dtypes.is_weakly_typed(func(val)), expected_weak_type) - self.assertEqual(dtypes.is_weakly_typed(fjit(val)), expected_weak_type) - - @jtu.sample_product( - shape=nonempty_nonscalar_array_shapes, - #dtype=[int, float, complex], - dtype=[int, float], - weak_type=[True, False], - slc=[slice(None), slice(0), slice(3), 0, ...], - ) - def testSliceWeakTypes(self, shape, dtype, weak_type, slc): - rng = jtu.rand_default(self.rng()) - x = lax_internal._convert_element_type(rng(shape, dtype), - weak_type=weak_type) - op = lambda x: x[slc] - self.assertEqual(op(x).aval.weak_type, weak_type) - self.assertEqual(jax.jit(op)(x).aval.weak_type, weak_type) - - @jtu.sample_product( - [dict(shape=shape, axis=axis, num_sections=num_sections) - for shape, axis, num_sections in [ - ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), - ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] - ], - dtype=default_dtypes, - ) - def testSplitStaticInt(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.split(x, num_sections, axis=axis) - jnp_fun = lambda x: jnp.split(x, num_sections, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis, num_sections=num_sections) - # All testcases split the specified axis unequally - for shape, axis, num_sections in [ - ((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3), - ((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)] - ], - dtype=default_dtypes, - ) - def testArraySplitStaticInt(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.array_split(x, num_sections, axis=axis) - jnp_fun = lambda x: jnp.array_split(x, num_sections, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testSplitTypeError(self): - # If we pass an ndarray for indices_or_sections -> no error - self.assertEqual(3, len(jnp.split(jnp.zeros(3), jnp.array([1, 2])))) - - CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected." - with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): - # An abstract tracer for idx - jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), idx))(2.) - with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): - # A list including an abstract tracer - jax.jit(lambda idx: jnp.split(jnp.zeros((12, 2)), [2, idx]))(2.) - - # A concrete tracer -> no error - jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), idx), - (2.,), (1.,)) - # A tuple including a concrete tracer -> no error - jax.jvp(lambda idx: jnp.split(jnp.zeros((12, 2)), (1, idx.astype(np.int32))), - (2.,), (1.,)) - - @jtu.sample_product( - shape=[(5,), (5, 5)], - dtype=number_dtypes, - bins=[10, np.arange(-5, 6), np.array([-5, 0, 3])], - range=[None, (0, 0), (0, 10)], - weights=[True, False], - ) - def testHistogramBinEdges(self, shape, dtype, bins, range, weights): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = lambda a, w, r: np.histogram_bin_edges(a, bins=bins, range=r, - weights=_weights(w)) - jnp_fun = lambda a, w, r: jnp.histogram_bin_edges(a, bins=bins, range=r, - weights=_weights(w)) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), range] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2} - # linspace() compares poorly to numpy when using bfloat16 - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_fun, args_maker, - atol=tol, rtol=tol) - - @jtu.sample_product( - shape=[(5,), (4, 5)], - dtype=default_dtypes, - # We only test explicit integer-valued bin edges because in other cases - # rounding errors lead to flaky tests. - bins=[np.arange(-5, 6), np.array([-5, 0, 3])], - density=[True, False], - weights=[True, False], - ) - def testHistogram(self, shape, dtype, bins, density, weights): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - def np_fun(a, w): - # Numpy can't handle bfloat16 - a = a.astype('float32') if a.dtype == jnp.bfloat16 else a - w = w.astype('float32') if w.dtype == jnp.bfloat16 else w - return np.histogram(a, bins=bins, density=density, weights=_weights(w)) - jnp_fun = lambda a, w: jnp.histogram(a, bins=bins, density=density, - weights=_weights(w)) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=[(5,), (12,)], - dtype=int_dtypes, - bins=[2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]], - weights=[False, True], - density=[False, True], - range=[None, [(-1, 1), None], [(-1, 1), (-2, 2)]], - ) - def testHistogram2d(self, shape, dtype, bins, weights, density, range): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( - lambda a, b, w: np.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range)) - jnp_fun = lambda a, b, w: jnp.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - # np.searchsorted errors on bfloat16 with - # "TypeError: invalid type promotion with custom data type" - with np.errstate(divide='ignore', invalid='ignore'): - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=[(5, 3), (10, 3)], - dtype=int_dtypes, - bins=[(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]], - weights=[False, True], - density=[False, True], - range=[None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]], - ) - def testHistogramdd(self, shape, dtype, bins, weights, density, range): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( - lambda a, w: np.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range)) - jnp_fun = lambda a, w: jnp.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range) - args_maker = lambda: [rng(shape, dtype), rng((shape[0],), dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - # np.searchsorted errors on bfloat16 with - # "TypeError: invalid type promotion with custom data type" - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axis=axis, num_sections=num_sections) - for shape, axis, num_sections in [ - ((12, 4), 0, 4), ((12,), 1, 2), - ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)]], - dtype=default_dtypes, - ) - def testHVDSplit(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - def fn(module, axis): - if axis == 0: - return module.vsplit - elif axis == 1: - return module.hsplit - else: - assert axis == 2 - return module.dsplit - - np_fun = lambda x: fn(np, axis)(x, num_sections) - jnp_fun = lambda x: fn(jnp, axis)(x, num_sections) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, out_shape=out_shape) - for arg_shape, out_shape in [ - (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), - ((), (1, 1, 1)), - ((7, 0), (0, 42, 101)), - ((3, 4), 12), - ((3, 4), (12,)), - ((3, 4), -1), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ] - ], - dtype=default_dtypes, - order=["C", "F"], - ) - def testReshape(self, arg_shape, out_shape, dtype, order): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.reshape(x, out_shape, order=order) - jnp_fun = lambda x: jnp.reshape(x, out_shape, order=order) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, out_shape=out_shape) - for arg_shape, out_shape in [ - ((7, 0), (0, 42, 101)), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ] - ], - dtype=default_dtypes, - ) - def testReshapeMethod(self, arg_shape, out_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.reshape(x, out_shape) - jnp_fun = lambda x: x.reshape(*out_shape) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, out_shape=out_shape) - for arg_shape, out_shape in itertools.product(all_shapes, array_shapes)], - dtype=default_dtypes, - ) - def testResize(self, arg_shape, out_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.resize(x, out_shape) - jnp_fun = lambda x: jnp.resize(x, out_shape) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, dim=dim) - for arg_shape in [(), (3,), (3, 4)] - for dim in (list(range(-len(arg_shape)+1, len(arg_shape))) - + [np.array(0), np.array(-1), (0,), [np.array(0)], - (len(arg_shape), len(arg_shape) + 1)]) - ], - dtype=default_dtypes, - ) - def testExpandDimsStaticDim(self, arg_shape, dtype, dim): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.expand_dims(x, dim) - jnp_fun = lambda x: jnp.expand_dims(x, dim) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CompileAndCheck(jnp_fun, args_maker) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - - def testExpandDimsRepeatedAxisError(self): - x = jnp.ones((2, 3)) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: jnp.expand_dims(x, [1, 1])) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: jnp.expand_dims(x, [3, -1])) - - # ensure this is numpy's behavior too, so that we remain consistent - x = np.ones((2, 3)) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: np.expand_dims(x, [1, 1])) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: np.expand_dims(x, [3, -1])) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, ax1=ax1, ax2=ax2) - for arg_shape, ax1, ax2 in [ - ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), - ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] - ], - dtype=default_dtypes, - ) - def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.swapaxes(x, ax1, ax2) - jnp_fun = lambda x: jnp.swapaxes(x, ax1, ax2) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(arg_shape=arg_shape, ax=ax) - for arg_shape, ax in [ - ((3, 1), None), - ((3, 1), 1), - ((3, 1), -1), - ((3, 1), np.array(1)), - ((1, 3, 1), (0, 2)), - ((1, 3, 1), (0,)), - ((1, 4, 1), (np.array(0),))] - ], - dtype=default_dtypes, - ) - def testSqueeze(self, arg_shape, dtype, ax): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.squeeze(x, ax) - jnp_fun = lambda x: jnp.squeeze(x, ax) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testArrayFromMasked(self): - args_maker = lambda: [np.ma.array([1, 2], mask=[True, False])] - # Like np.array, jnp.array strips the mask from masked array inputs. - self._CheckAgainstNumpy(np.array, jnp.array, args_maker) - # Under JIT, masked arrays are flagged as invalid. - with self.assertRaisesRegex(ValueError, "numpy masked arrays are not supported"): - jax.jit(jnp.asarray)(*args_maker()) - - @jtu.sample_product( - [dict(arg=arg, dtype=dtype, ndmin=ndmin) - for arg, dtypes in [ - ([True, False, True], all_dtypes), - (3., all_dtypes), - ([1, 2, 3], all_dtypes), - (np.array([1, 2, 3], dtype=np.int64), all_dtypes), - ([1., 2., 3.], all_dtypes), - ([[1, 2], [3, 4], [5, 6]], all_dtypes), - ([[1, 2.], [3, 4], [5, 6]], all_dtypes), - ([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes), - ([[3, np.array(2, dtype=jnp.float_), 1], - np.arange(3., dtype=jnp.float_)], all_dtypes), - ] - for dtype in [None] + dtypes - for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2] - ], - ) - def testArray(self, arg, ndmin, dtype): - args_maker = lambda: [arg] - canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype) - if ndmin is not None: - np_fun = partial(np.array, ndmin=ndmin, dtype=canonical_dtype) - jnp_fun = partial(jnp.array, ndmin=ndmin, dtype=dtype) - else: - np_fun = partial(np.array, dtype=canonical_dtype) - jnp_fun = partial(jnp.array, dtype=dtype) - - # We are testing correct canonicalization behavior here, so we turn off the - # permissive canonicalization logic in the test harness. - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, - canonicalize_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product(copy=[None, True, False]) - def testAsarrayCopy(self, copy): - x_jax = jnp.arange(4) - x_np = np.arange(4) - x_list = [0, 1, 2, 3] - x_buf = make_python_array('l', x_list) - - func = partial(jnp.asarray, copy=copy) - self.assertArraysEqual(x_jax, func(x_jax)) - self.assertArraysEqual(x_jax, func(x_list), check_dtypes=False) - - if copy is False and jax.default_backend() != 'cpu': - # copy=False is strict: it must raise if the input supports the buffer protocol - # but a copy is still required. - self.assertRaises(ValueError, func, x_np) - self.assertRaises(ValueError, func, x_buf) - else: - self.assertArraysEqual(x_jax, func(x_np), check_dtypes=False) - self.assertArraysEqual(x_jax, func(x_buf), check_dtypes=False) - - @unittest.skip("Jax-metal don't support all dtypes.") - @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") - def testArrayDtypeInference(self): - def _check(obj, out_dtype, weak_type): - dtype_reference = np.array(obj, dtype=out_dtype) - - out = jnp.array(obj) - self.assertDtypesMatch(out, dtype_reference) - self.assertEqual(dtypes.is_weakly_typed(out), weak_type) - - out_jit = jax.jit(jnp.array)(obj) - self.assertDtypesMatch(out_jit, dtype_reference) - self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type) - - # Python scalars become 64-bit weak types. - _check(1, np.int64, True) - _check(1.0, np.float64, True) - _check(1.0j, np.complex128, True) - - # Lists become strongly-typed defaults. - _check([1], jnp.int64, False) - _check([1.0], jnp.float64, False) - _check([1.0j], jnp.complex128, False) - - # Lists of weakly-typed objects become strongly-typed defaults. - _check([jnp.array(1)], jnp.int64, False) - _check([jnp.array(1.0)], jnp.float64, False) - _check([jnp.array(1.0j)], jnp.complex128, False) - - # Lists of strongly-typed objects maintain their strong type. - _check([jnp.int64(1)], np.int64, False) - _check([jnp.float64(1)], np.float64, False) - _check([jnp.complex128(1)], np.complex128, False) - - # Mixed inputs use JAX-style promotion. - # (regression test for https://github.com/jax-ml/jax/issues/8945) - _check([0, np.int16(1)], np.int16, False) - _check([0.0, np.float16(1)], np.float16, False) - - @jtu.sample_product( - dtype=all_dtypes, - func=["array", "copy", "copy.copy", "copy.deepcopy"], - ) - def testArrayCopy(self, dtype, func): - x = jnp.ones(10, dtype=dtype) - if func == "copy.deepcopy": - copy_func = copy.deepcopy - elif func == "copy.copy": - copy_func = copy.copy - else: - copy_func = getattr(jnp, func) - - x_view = jnp.asarray(x) - x_view_jit = jax.jit(jnp.asarray)(x) - x_copy = copy_func(x) - x_copy_jit = jax.jit(copy_func)(x) - - _ptr = lambda x: x.unsafe_buffer_pointer() - - self.assertEqual(_ptr(x), _ptr(x_view)) - self.assertNotEqual(_ptr(x), _ptr(x_view_jit)) - self.assertNotEqual(_ptr(x), _ptr(x_copy)) - self.assertNotEqual(_ptr(x), _ptr(x_copy_jit)) - - x.delete() - - self.assertTrue(x_view.is_deleted()) - self.assertFalse(x_view_jit.is_deleted()) - - self.assertFalse(x_copy.is_deleted()) - self.assertFalse(x_copy_jit.is_deleted()) - - def testArrayCopyAutodiff(self): - f = lambda x: jnp.array(x, copy=True) - - x = jnp.ones(10) - xdot = jnp.ones(10) - y, ydot = jax.jvp(f, (x,), (xdot,)) - self.assertIsNot(x, y) - self.assertIsNot(xdot, ydot) - - ybar = jnp.ones(10) - y, f_vjp = jax.vjp(f, x) - xbar, = f_vjp(ybar) - self.assertIsNot(x, y) - self.assertIsNot(xbar, ybar) - - def testArrayCopyVmap(self): - f = lambda x: jnp.array(x, copy=True) - x = jnp.ones(10) - y = jax.vmap(f)(x) - self.assertIsNot(x, y) - - def testArrayUnsupportedDtypeError(self): - with self.assertRaisesRegex(TypeError, - "JAX only supports number and bool dtypes.*"): - jnp.array(3, [('a',' 0.: - return x * 2 - else: - return x + 2 - - self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.)) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in [(3,), (2, 3)] - for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))] # Test negative axes and tuples - ], - dtype=default_dtypes, - ) - def testFlip(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_op = lambda x: jnp.flip(x, axis) - np_op = lambda x: np.flip(x, axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=[(3,), (2, 3), (3, 2, 4)], - dtype=default_dtypes, - ) - def testFlipud(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_op = lambda x: jnp.flipud(x) - np_op = lambda x: np.flipud(x) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=[(3, 2), (2, 3), (3, 2, 4)], - dtype=default_dtypes, - ) - def testFliplr(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_op = lambda x: jnp.fliplr(x) - np_op = lambda x: np.fliplr(x) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - [dict(shape=shape, axes=axes) - for shape, axes in [ - [(2, 3), (0, 1)], - [(2, 3), (1, 0)], - [(4, 3, 2), (0, 2)], - [(4, 3, 2), (2, 1)], - ] - ], - k=range(-3, 4), - dtype=default_dtypes, - ) - def testRot90(self, shape, dtype, k, axes): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_op = lambda x: jnp.rot90(x, k, axes) - np_op = lambda x: np.rot90(x, k, axes) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - # TODO(mattjj): test infix operator overrides - - def testRavel(self): - rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - self._CompileAndCheck(lambda x: x.ravel(), args_maker) - - @jtu.sample_product( - shape=nonempty_nonscalar_array_shapes, - order=['C', 'F'], - mode=['wrap', 'clip', 'raise'], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testRavelMultiIndex(self, shape, order, mode): - # generate indices in each dimension with a few out of bounds. - rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1) - for dim in shape] - # generate multi_indices of different dimensions that broadcast. - args_maker = lambda: [tuple(rng(ndim * (3,), jnp.int_) - for ndim, rng in enumerate(rngs))] - def np_fun(x): - try: - return np.ravel_multi_index(x, shape, order=order, mode=mode) - except ValueError as err: - if str(err).startswith('invalid entry'): - # sentinel indicating expected error. - return -999 - else: - raise - def jnp_fun(x): - try: - return jnp.ravel_multi_index(x, shape, order=order, mode=mode) - except ValueError as err: - if str(err).startswith('invalid entry'): - # sentinel indicating expected error. - return -999 - else: - raise - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - if mode == 'raise': - msg = ("The error occurred because ravel_multi_index was jit-compiled " - "with mode='raise'. Use mode='wrap' or mode='clip' instead.") - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - jax.jit(jnp_fun)(*args_maker()) - else: - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - ashape=((), (4,), (3, 4)), - cshapes=[ - [(), (4,)], - [(3, 4), (4,), (3, 1)] - ], - adtype=int_dtypes, - cdtype=default_dtypes, - mode=['wrap', 'clip', 'raise'], - ) - def testChoose(self, ashape, adtype, cshapes, cdtype, mode): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]] - def np_fun(a, c): - try: - return np.choose(a, c, mode=mode) - except ValueError as err: - if mode == 'raise' and str(err).startswith('invalid entry'): - return -999 # sentinel indicating expected error. - else: - raise - def jnp_fun(a, c): - try: - return jnp.choose(a, c, mode=mode) - except ValueError as err: - if mode == 'raise' and str(err).startswith('invalid entry'): - return -999 # sentinel indicating expected error. - else: - raise - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - if mode == 'raise': - msg = ("The error occurred because jnp.choose was jit-compiled" - " with mode='raise'. Use mode='wrap' or mode='clip' instead.") - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - jax.jit(jnp_fun)(*args_maker()) - else: - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=nonempty_nonscalar_array_shapes, - dtype=int_dtypes, - idx_shape=all_shapes, - ) - def testUnravelIndex(self, shape, idx_shape, dtype): - size = math.prod(shape) - rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) - - def np_fun(index, shape): - # JAX's version outputs the same dtype as the input in the typical case - # where shape is weakly-typed. - out_dtype = index.dtype - # Adjust out-of-bounds behavior to match jax's documented behavior. - index = np.clip(index, -size, size - 1) - index = np.where(index < 0, index + size, index) - return [i.astype(out_dtype) for i in np.unravel_index(index, shape)] - - jnp_fun = jnp.unravel_index - args_maker = lambda: [rng(idx_shape, dtype), shape] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - from_dtype=['int32', 'float32'], - to_dtype=['int32', 'float32', None], - use_method=[True, False], - ) - def testAstype(self, from_dtype, to_dtype, use_method): - rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype(from_dtype)] - if not use_method: - np_op = lambda x: np.astype(x, to_dtype) - else: - np_op = lambda x: np.asarray(x).astype(to_dtype) - if use_method: - jnp_op = lambda x: jnp.asarray(x).astype(to_dtype) - else: - jnp_op = lambda x: jnp.astype(x, to_dtype) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @unittest.skip("Jax-metal don't support all dtypes") - def testAstypeInt4(self): - # Test converting from int4 to int8 - x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4) - args_maker = lambda: [x] - np_op = lambda x: np.asarray(x).astype(jnp.int8) - jnp_op = lambda x: jnp.asarray(x).astype(jnp.int8) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - # Test converting from int8 to int4 - x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int8) - args_maker = lambda: [x] - np_op = lambda x: np.asarray(x).astype(jnp.int4) - jnp_op = lambda x: jnp.asarray(x).astype(jnp.int4) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=array_shapes, - dtype=all_dtypes, - ) - def testNbytes(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - np_op = lambda x: np.asarray(x).nbytes - jnp_op = lambda x: jnp.asarray(x).nbytes - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=array_shapes, - dtype=all_dtypes, - ) - def testItemsize(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - np_op = lambda x: np.asarray(x).itemsize - jnp_op = lambda x: jnp.asarray(x).itemsize - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - shape=nonempty_array_shapes, - dtype=all_dtypes, - num_args=[0, 1, "all"], - use_tuple=[True, False] - ) - def testItem(self, shape, dtype, num_args, use_tuple): - rng = jtu.rand_default(self.rng()) - size = math.prod(shape) - - if num_args == 0: - args = () - elif num_args == 1: - args = (self.rng().randint(0, size),) - else: - args = tuple(self.rng().randint(0, s) for s in shape) - args = (args,) if use_tuple else args - - np_op = lambda x: np.asarray(x).item(*args) - jnp_op = lambda x: jnp.asarray(x).item(*args) - args_maker = lambda: [rng(shape, dtype)] - - if size != 1 and num_args == 0: - with self.assertRaises(ValueError): - jnp_op(*args_maker()) - else: - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - - @jtu.sample_product( - # Final dimension must be a multiple of 16 to ensure compatibility of all dtype pairs. - shape=[(0,), (32,), (2, 16)], - a_dtype=all_dtypes, - dtype=(*all_dtypes, None) if config.enable_x64.value else all_dtypes, - ) - def testView(self, shape, a_dtype, dtype): - if jtu.test_device_matches(["tpu"]): - if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: - self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") - # It is possible to fill bool arrays with arbitrary bits (not just 0/1 - # bytes), but the behavior is implementation-defined. We therefore only test - # the well-defined case. - rng = (jtu.rand_bool if a_dtype == np.bool_ else jtu.rand_fullrange)( - self.rng() - ) - args_maker = lambda: [rng(shape, a_dtype)] - np_op = lambda x: np.asarray(x).view(dtype) - jnp_op = lambda x: jnp.asarray(x).view(dtype) - # Above may produce signaling nans; ignore warnings from invalid values. - with np.errstate(invalid='ignore'): - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product([ - {'a_dtype': a_dtype, 'dtype': dtype} - for a_dtype in all_dtypes - for dtype in all_dtypes - if np.dtype(a_dtype).itemsize == np.dtype(dtype).itemsize - ]) - def testViewScalar(self, a_dtype, dtype): - if jtu.test_device_matches(["tpu"]): - if jnp.dtype(a_dtype).itemsize in [1, 2] or jnp.dtype(dtype).itemsize in [1, 2]: - self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") - rng = jtu.rand_fullrange(self.rng()) - args_maker = lambda: [jnp.array(rng((), a_dtype))] - np_op = lambda x: np.asarray(x).view(dtype) - jnp_op = lambda x: jnp.asarray(x).view(dtype) - # Above may produce signaling nans; ignore warnings from invalid values. - with np.errstate(invalid='ignore'): - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testPathologicalFloats(self): - args_maker = lambda: [np.array([ - 0b_0111_1111_1000_0000_0000_0000_0000_0000, # inf - 0b_1111_1111_1000_0000_0000_0000_0000_0000, # -inf - 0b_0111_1111_1100_0000_0000_0000_0000_0000, # qnan - 0b_1111_1111_1100_0000_0000_0000_0000_0000, # -qnan - 0b_0111_1111_1000_0000_0000_0000_0000_0001, # snan - 0b_1111_1111_1000_0000_0000_0000_0000_0001, # -snan - 0b_0111_1111_1000_0000_0000_1100_0000_0000, # nonstandard nan - 0b_1111_1111_1000_0000_0000_1100_0000_0000, # -nonstandard nan - 0b_0000_0000_0000_0000_0000_0000_0000_0000, # zero - 0b_1000_0000_0000_0000_0000_0000_0000_0000, # -zero - ], dtype='uint32')] - - np_op = lambda x: np.asarray(x).view('float32').view('uint32') - jnp_op = lambda x: jnp.asarray(x).view('float32').view('uint32') - - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - # TODO(mattjj): test other ndarray-like method overrides - - def testNpMean(self): - # from https://github.com/jax-ml/jax/issues/125 - x = jnp.eye(3, dtype=float) + 0. - ans = np.mean(x) - self.assertAllClose(ans, np.array(1./3), check_dtypes=False) - - def testArangeOnFloats(self): - np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/jax-ml/jax/issues/145 - self.assertAllClose(np_arange(0.0, 1.0, 0.1), - jnp.arange(0.0, 1.0, 0.1)) - # from https://github.com/jax-ml/jax/issues/3450 - self.assertAllClose(np_arange(2.5), - jnp.arange(2.5)) - self.assertAllClose(np_arange(0., 2.5), - jnp.arange(0., 2.5)) - - def testArangeTypes(self): - # Test that arange() output type is equal to the default types. - int_ = dtypes.default_int_dtype() - float_ = dtypes.default_float_dtype() - - self.assertEqual(jnp.arange(10).dtype, int_) - self.assertEqual(jnp.arange(10.).dtype, float_) - self.assertEqual(jnp.arange(10, dtype='uint16').dtype, np.uint16) - #self.assertEqual(jnp.arange(10, dtype='bfloat16').dtype, jnp.bfloat16) - - self.assertEqual(jnp.arange(0, 10, 1).dtype, int_) - with jax.numpy_dtype_promotion('standard'): - self.assertEqual(jnp.arange(0, 10, 1.).dtype, float_) - self.assertEqual(jnp.arange(0., 10, 1).dtype, float_) - - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonzerodim_shapes - for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) - ], - stable=[True, False], - dtype=all_dtypes, - ) - def testSort(self, dtype, shape, axis, stable): - rng = jtu.rand_some_equal(self.rng()) if stable else jtu.rand_some_inf_and_nan(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - kwds = {} if axis is NO_VALUE else {'axis': axis} - - def np_fun(arr): - # Note: numpy sort fails on NaN and Inf values with bfloat16 - dtype = arr.dtype - if arr.dtype == jnp.bfloat16: - arr = arr.astype('float32') - # TODO(jakevdp): switch to stable=stable when supported by numpy. - result = np.sort(arr, kind='stable' if stable else None, **kwds) - with jtu.ignore_warning(category=RuntimeWarning, message='invalid value'): - return result.astype(dtype) - jnp_fun = partial(jnp.sort, stable=stable, **kwds) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testSortStableDescending(self): - # TODO(jakevdp): test directly against np.sort when descending is supported. - x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf]) - x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan]) - argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5]) - argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6]) - - self.assertArraysEqual(jnp.sort(x), x_sorted) - self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0])) - self.assertArraysEqual(jnp.argsort(x), argsorted_stable) - self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable) - - @unittest.skip("Jax-metal don't support complex.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in one_dim_array_shapes - for axis in [None] - ], - dtype=all_dtypes, - ) - def testSortComplex(self, dtype, shape, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker, - check_dtypes=False) - self._CompileAndCheck(jnp.sort_complex, args_maker) - - @unittest.skip("Jax-metal fail to convert sort op.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in (-1, *range(len(shape) - 1)) - ], - dtype=all_dtypes, - input_type=[np.array, tuple], - ) - def testLexsort(self, dtype, shape, input_type, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [input_type(rng(shape, dtype))] - jnp_op = lambda x: jnp.lexsort(x, axis=axis) - np_op = jtu.with_jax_dtype_defaults(lambda x: np.lexsort(x, axis=axis)) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @unittest.skip("JAX-metal crash.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonzerodim_shapes - for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) - ], - dtype=all_dtypes, - ) - def testArgsort(self, dtype, shape, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - kwds = {} if axis is NO_VALUE else {'axis': axis} - - @jtu.with_jax_dtype_defaults - def np_fun(arr): - # Note: numpy sort fails on NaN and Inf values with bfloat16 - if arr.dtype == jnp.bfloat16: - arr = arr.astype('float32') - # TODO(jakevdp): switch to stable=True when supported by numpy. - return np.argsort(arr, kind='stable', **kwds) - jnp_fun = partial(jnp.argsort, stable=True, **kwds) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @unittest.skip("JAX-metal crash.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in nonempty_nonscalar_array_shapes - for axis in (NO_VALUE, None, *range(-len(shape), len(shape))) - ], - descending=[True, False], - dtype=all_dtypes, - ) - def testArgsortUnstable(self, dtype, shape, axis, descending): - # We cannot directly compare unstable argsorts, so instead check that indexed values match. - rng = jtu.rand_some_equal(self.rng()) - x = rng(shape, dtype) - kwds = {} if axis is NO_VALUE else {'axis': axis} - expected = jnp.sort(x, descending=descending, stable=False, **kwds) - indices = jnp.argsort(x, descending=descending, stable=False, **kwds) - if axis is None: - actual = jnp.ravel(x)[indices] - else: - actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis) - self.assertArraysEqual(actual, expected) - - @jtu.sample_product( - [{'shape': shape, 'axis': axis, 'kth': kth} - for shape in nonzerodim_shapes - for axis in range(-len(shape), len(shape)) - for kth in range(-shape[axis], shape[axis])], - dtype=default_dtypes, - ) - def testPartition(self, shape, dtype, axis, kth): - rng = jtu.rand_default(self.rng()) - arg = rng(shape, dtype) - jnp_output = jnp.partition(arg, axis=axis, kth=kth) - np_output = np.partition(arg, axis=axis, kth=kth) - - # Assert that pivot point is equal: - self.assertArraysEqual( - lax.index_in_dim(jnp_output, axis=axis, index=kth), - lax.index_in_dim(np_output, axis=axis, index=kth)) - - # Assert remaining values are correctly partitioned: - self.assertArraysEqual( - lax.sort(lax.slice_in_dim(jnp_output, start_index=0, limit_index=kth, axis=axis), dimension=axis), - lax.sort(lax.slice_in_dim(np_output, start_index=0, limit_index=kth, axis=axis), dimension=axis)) - self.assertArraysEqual( - lax.sort(lax.slice_in_dim(jnp_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis), - lax.sort(lax.slice_in_dim(np_output, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis)) - - #@unittest.skipIf(jtu.device_under_test=="METAL", "Jax-metal fail on empty dim shape.") - @jtu.sample_product( - [{'shape': shape, 'axis': axis, 'kth': kth} - for shape in nonempty_shapes# nonzerodim_shapes - for axis in range(-len(shape), len(shape)) - for kth in range(-shape[axis], shape[axis])], - dtype=default_dtypes, - ) - def testArgpartition(self, shape, dtype, axis, kth): - rng = jtu.rand_default(self.rng()) - arg = rng(shape, dtype) - - jnp_output = jnp.argpartition(arg, axis=axis, kth=kth) - np_output = np.argpartition(arg, axis=axis, kth=kth) - - # Assert that all indices are present - self.assertArraysEqual(jnp.sort(jnp_output, axis), np.sort(np_output, axis), check_dtypes=False) - - # Because JAX & numpy may treat duplicates differently, we must compare values - # rather than indices. - getvals = lambda x, ind: x[ind] - for ax in range(arg.ndim): - if ax != range(arg.ndim)[axis]: - getvals = jax.vmap(getvals, in_axes=ax, out_axes=ax) - jnp_values = getvals(arg, jnp_output) - np_values = getvals(arg, np_output) - - # Assert that pivot point is equal: - self.assertArraysEqual( - lax.index_in_dim(jnp_values, axis=axis, index=kth), - lax.index_in_dim(np_values, axis=axis, index=kth)) - - # Assert remaining values are correctly partitioned: - self.assertArraysEqual( - lax.sort(lax.slice_in_dim(jnp_values, start_index=0, limit_index=kth, axis=axis), dimension=axis), - lax.sort(lax.slice_in_dim(np_values, start_index=0, limit_index=kth, axis=axis), dimension=axis)) - self.assertArraysEqual( - lax.sort(lax.slice_in_dim(jnp_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis), - lax.sort(lax.slice_in_dim(np_values, start_index=kth + 1, limit_index=shape[axis], axis=axis), dimension=axis)) - - @jtu.sample_product( - [dict(shifts=shifts, axis=axis) - for shifts, axis in [ - (3, None), - (1, 1), - ((3,), (0,)), - ((-2,), (-2,)), - ((1, 2), (0, -1)), - ((4, 2, 5, 5, 2, 4), None), - (100, None), - ] - ], - dtype=all_dtypes, - shape=[(3, 4), (3, 4, 5), (7, 4, 0)], - ) - def testRoll(self, shape, dtype, shifts, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), np.array(shifts)] - jnp_op = partial(jnp.roll, axis=axis) - np_op = partial(np.roll, axis=axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - dtype=all_dtypes, - shape=[(1, 2, 3, 4)], - axis=[-3, 0, 2, 3], - start=[-4, -1, 2, 4], - ) - def testRollaxis(self, shape, dtype, start, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - jnp_op = partial(jnp.rollaxis, axis=axis, start=start) - np_op = partial(np.rollaxis, axis=axis, start=start) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @unittest.skip("jax-metal generates a different result from cpu.") - @jtu.sample_product( - dtype=[np.uint8, np.bool_], - bitorder=['big', 'little'], - shape=[(1, 2, 3, 4)], - axis=[None, 0, 1, -2, -1], - ) - def testPackbits(self, shape, dtype, axis, bitorder): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - jnp_op = partial(jnp.packbits, axis=axis, bitorder=bitorder) - np_op = partial(np.packbits, axis=axis, bitorder=bitorder) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - dtype=[np.uint8], - bitorder=['big', 'little'], - shape=[(1, 2, 3, 4)], - axis=[None, 0, 1, -2, -1], - count=[None, 20], - ) - def testUnpackbits(self, shape, dtype, axis, bitorder, count): - rng = jtu.rand_int(self.rng(), 0, 256) - args_maker = lambda: [rng(shape, dtype)] - jnp_op = partial(jnp.unpackbits, axis=axis, bitorder=bitorder, count=count) - np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder, count=count) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - #@unittest.skip("jax-metal generates a different result from cpu.") - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in [(3,), (3, 4), (3, 4, 5)] - for axis in itertools.chain(range(-len(shape), len(shape)), - [cast(Union[int, None], None)]) - ], - index_shape=scalar_shapes + [(3,), (2, 1, 3)], - dtype=all_dtypes, - index_dtype=int_dtypes, - #mode=[None, 'wrap', 'clip'], - mode=[None, 'wrap'], - ) - def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode): - def args_maker(): - x = rng(shape, dtype) - i = rng_indices(index_shape, index_dtype) - return x, i - - rng = jtu.rand_default(self.rng()) - if mode is None: - rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0]) - else: - rng_indices = jtu.rand_int(self.rng(), -5, 5) - jnp_op = lambda x, i: jnp.take(x, i, axis=axis, mode=mode) - np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testTakeEmpty(self): - np.testing.assert_array_equal( - jnp.array([], dtype=jnp.float32), - jnp.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32))) - - np.testing.assert_array_equal( - jnp.ones((2, 0, 4), dtype=jnp.float32), - jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32), - axis=1)) - - with self.assertRaisesRegex(IndexError, "non-empty jnp.take"): - jnp.take(jnp.ones((2, 0, 4), dtype=jnp.float32), - jnp.array([0], jnp.int32), axis=1) - - def testTakeOptionalArgs(self): - x = jnp.arange(5.0) - ind = jnp.array([0, 2, 4, 6]) - expected = jnp.array([0.0, 2.0, 4.0, 10.0], dtype=x.dtype) - actual = jnp.take(x, ind, unique_indices=True, - indices_are_sorted=True, fill_value=10.0) - self.assertArraysEqual(expected, actual) - - @jtu.sample_product( - [dict(x_shape=x_shape, i_shape=i_shape, axis=axis) - for x_shape, i_shape in filter( - _shapes_are_equal_length, - filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2))) - for axis in itertools.chain(range(len(x_shape)), [-1], - [cast(Union[int, None], None)]) - ], - dtype=default_dtypes, - index_dtype=int_dtypes, - ) - def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis): - rng = jtu.rand_default(self.rng()) - - i_shape = list(i_shape) - if axis is None: - i_shape = [math.prod(i_shape)] - else: - # Test the case where the size of the axis doesn't necessarily broadcast. - i_shape[axis] *= 3 - def args_maker(): - x = rng(x_shape, dtype) - n = math.prod(x_shape) if axis is None else x_shape[axis] - if np.issubdtype(index_dtype, np.unsignedinteger): - index_rng = jtu.rand_int(self.rng(), 0, n) - else: - index_rng = jtu.rand_int(self.rng(), -n, n) - i = index_rng(i_shape, index_dtype) - return x, i - - jnp_op = lambda x, i: jnp.take_along_axis(x, i, axis=axis) - np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/jax-ml/jax/issues/5088 - h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) - g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) - q0 = jnp.take_along_axis(h, g, axis=-1) - q1 = np.take_along_axis( h, g, axis=-1) - np.testing.assert_equal(q0, q1) - - @unittest.skip("Jax-metal fail.") - def testTakeAlongAxisOutOfBounds(self): - x = jnp.arange(10, dtype=jnp.float32) - idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11]) - out = jnp.take_along_axis(x, idx, axis=0) - expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan, - jnp.nan], np.float32) - np.testing.assert_array_equal(expected_fill, out) - out = jnp.take_along_axis(x, idx, axis=0, mode="fill") - np.testing.assert_array_equal(expected_fill, out) - - expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) - out = jnp.take_along_axis(x, idx, axis=0, mode="clip") - np.testing.assert_array_equal(expected_clip, out) - - def testTakeAlongAxisRequiresIntIndices(self): - x = jnp.arange(5) - idx = jnp.array([3.], jnp.float32) - with self.assertRaisesRegex( - TypeError, - "take_along_axis indices must be of integer type, got float32"): - jnp.take_along_axis(x, idx, axis=0) - - def testTakeAlongAxisWithEmptyArgs(self): - # take_along_axis should allow us to gather an empty list of indices from - # an empty input axis without raising a shape error. - x = jnp.ones((4, 0, 3), dtype=jnp.int32) - np.testing.assert_array_equal(x, jnp.take_along_axis(x, x, axis=1)) - - @jtu.sample_product( - dtype=inexact_dtypes, - shape=[0, 5], - n=[2, 4], - increasing=[False, True], - ) - def testVander(self, shape, dtype, n, increasing): - rng = jtu.rand_default(self.rng()) - def np_fun(arg): - arg = arg.astype(np.float32) if dtype == jnp.bfloat16 else arg - return np.vander(arg, N=n, increasing=increasing) - jnp_fun = lambda arg: jnp.vander(arg, N=n, increasing=increasing) - args_maker = lambda: [rng([shape], dtype)] - # np.vander seems to return float64 for all floating types. We could obey - # those semantics, but they seem like a bug. - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, - tol={np.float32: 1e-3, np.complex64: 1e-3}) - self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=False) - - @jtu.sample_product( - shape=array_shapes, - dtype=all_dtypes, - ) - def testNanToNum(self, shape, dtype): - rng = jtu.rand_some_inf_and_nan(self.rng()) - dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type - def np_fun(x): - if dtype == jnp.bfloat16: - x = np.where(np.isnan(x), dtype(0), x) - x = np.where(np.isposinf(x), jnp.finfo(dtype).max, x) - x = np.where(np.isneginf(x), jnp.finfo(dtype).min, x) - return x - else: - return np.nan_to_num(x).astype(dtype) - - args_maker = lambda: [rng(shape, dtype)] - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(np_fun, jnp.nan_to_num, args_maker, - check_dtypes=check_dtypes) - self._CompileAndCheck(jnp.nan_to_num, args_maker, - check_dtypes=check_dtypes) - - @jtu.sample_product( - [dict(shapes=shapes, dtypes=dtypes) - for shapes, dtypes in ( - ((), ()), - (((7,),), (np.int32,)), - (((3,), (4,)), (np.int32, np.int32)), - (((3,), (1,), (4,)), (np.int32, np.int32, np.int32)), - ) - ], - ) - def testIx_(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype) - for shape, dtype in zip(shapes, dtypes)] - self._CheckAgainstNumpy(np.ix_, jnp.ix_, args_maker) - self._CompileAndCheck(jnp.ix_, args_maker) - - @jtu.sample_product( - dimensions=[(), (2,), (3, 0), (4, 5, 6)], - dtype=number_dtypes, - sparse=[True, False], - ) - def testIndices(self, dimensions, dtype, sparse): - def args_maker(): return [] - np_fun = partial(np.indices, dimensions=dimensions, - dtype=dtype, sparse=sparse) - jnp_fun = partial(jnp.indices, dimensions=dimensions, - dtype=dtype, sparse=sparse) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=nonzerodim_shapes, dtype=all_dtypes, - ) - def testWhereOneArgument(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) - - # JIT compilation requires specifying a size statically. Full test of - # this behavior is in testNonzeroSize(). - jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2) - - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shapes=filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 3)), - dtypes=itertools.combinations_with_replacement(all_dtypes, 3), - ) - def testWhereThreeArgument(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes) - def np_fun(cond, x, y): - return jtu.promote_like_jnp(partial(np.where, cond))(x, y) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(np_fun, jnp.where, args_maker) - self._CompileAndCheck(jnp.where, args_maker) - - def testWhereExtraCode(self): - def f(x): - return jnp.where(x > 0, x, -x) - - # Test no comparison literal True/False in jaxpr, and hence no comparison to - # literals - jaxpr = jax.make_jaxpr(jax.grad(f))(3.) - self.assertNotIn('False', str(jaxpr)) - self.assertNotIn('True', str(jaxpr)) - - def testWhereScalarPromotion(self): - x = jnp.where(jnp.array([True, False]), 3, - jnp.ones((2,), dtype=jnp.float32)) - self.assertEqual(x.dtype, np.dtype(np.float32)) - - @jtu.sample_product( - [dict(n=n, shapes=shapes) - for n in range(1, 3) - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2 * n + 1)) - ], - # To avoid forming the full product of shapes and dtypes we always sample - # maximal set of dtypes. - dtypes=itertools.combinations_with_replacement(all_dtypes, 3), - ) - def testSelect(self, n, shapes, dtypes): - dtypes = dtypes[:n+1] - rng = jtu.rand_default(self.rng()) - n = len(dtypes) - 1 - def args_maker(): - condlist = [rng(shape, np.bool_) for shape in shapes[:n]] - choicelist = [rng(shape, dtype) - for shape, dtype in zip(shapes[n:-1], dtypes[:n])] - default = rng(shapes[-1], dtypes[-1]) - return condlist, choicelist, default - # TODO(phawkins): float32/float64 type mismatches - @jax.numpy_dtype_promotion('standard') - def np_fun(condlist, choicelist, default): - choicelist = [x if jnp.result_type(x) != jnp.bfloat16 - else x.astype(np.float32) for x in choicelist] - dtype = jnp.result_type(default, *choicelist) - return np.select(condlist, - [np.asarray(x, dtype=dtype) for x in choicelist], - np.asarray(default, dtype=dtype)) - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, - check_dtypes=False) - self._CompileAndCheck(jnp.select, args_maker, - rtol={np.float64: 1e-7, np.complex128: 1e-7}) - - def testIssue330(self): - x = jnp.full((1, 1), jnp.array([1])[0]) # doesn't crash - self.assertEqual(x[0, 0], 1) - - def testScalarDtypePromotion(self): - orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype - jax_numpy_result = (1 + jnp.eye(1, dtype=jnp.float32)).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - def testSymmetrizeDtypePromotion(self): - x = np.eye(3, dtype=np.float32) - orig_numpy_result = ((x + x.T) / 2).dtype - - x = jnp.eye(3, dtype=jnp.float32) - jax_numpy_result = ((x + x.T) / 2).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - # NOTE(mattjj): I disabled this test when removing lax._safe_mul because - # introducing the convention 0 * inf = 0 leads to silently wrong results in - # some cases. See this comment for details: - # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352 - # def testIssue347(self): - # # https://github.com/jax-ml/jax/issues/347 - # def test_fail(x): - # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) - # ones = jnp.ones_like(x) - # x = jnp.where(x > 0.5, x, ones) - # return jnp.sum(x) - # x = jnp.array([[1, 2], [3, 4], [0, 0]], dtype=jnp.float64) - # result = jax.grad(test_fail)(x) - # assert not np.any(np.isnan(result)) - - def testIssue453(self): - # https://github.com/jax-ml/jax/issues/453 - a = np.arange(6) + 1 - ans = jnp.reshape(a, (3, 2), order='F') - expected = np.reshape(a, (3, 2), order='F') - self.assertAllClose(ans, expected) - - @jtu.sample_product( - #dtype=[int, float, bool, complex], - dtype=[int, float, bool], - op=["atleast_1d", "atleast_2d", "atleast_3d"], - ) - def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/jax-ml/jax/issues/634 - np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) - jnp_fun = lambda arg: getattr(jnp, op)(arg) - args_maker = lambda: [dtype(2)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shape=[(0,), (5,), (10,)], - dtype=int_dtypes, - weights=[True, False], - minlength=[0, 20], - length=[None, 8], - ) - def testBincount(self, shape, dtype, weights, minlength, length): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None)) - - def np_fun(x, *args): - x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero. - out = np.bincount(x, *args, minlength=minlength) - if length and length > out.size: - return np.pad(out, (0, length - out.size)) - return out[:length] - jnp_fun = partial(jnp.bincount, minlength=minlength, length=length) - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - if length is not None: - self._CompileAndCheck(jnp_fun, args_maker) - - def testBincountNegative(self): - # Test that jnp.bincount ignores negative values. - x_rng = jtu.rand_int(self.rng(), -100, 100) - w_rng = jtu.rand_uniform(self.rng()) - shape = (1000,) - x = x_rng(shape, 'int32') - w = w_rng(shape, 'float32') - - xn = np.array(x) - xn[xn < 0] = 0 - wn = np.array(w) - np_result = np.bincount(xn[xn >= 0], wn[xn >= 0]) - jnp_result = jnp.bincount(x, w) - self.assertAllClose(np_result, jnp_result, check_dtypes=False) - - @jtu.sample_product( - input=[ - 3, - [3], - [np.array(3)], - [np.array([3])], - [[np.array(3)]], - [[np.array([3])]], - [3, 4, 5], - [ - [np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)], - [np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3], - ], - [np.array([1, 2, 3]), np.array([2, 3, 4]), 10], - [np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)], - [[np.array([1, 2, 3])], [np.array([2, 3, 4])]], - ], - ) - def testBlock(self, input): - args_maker = lambda: [input] - self._CheckAgainstNumpy(np.block, jnp.block, args_maker) - self._CompileAndCheck(jnp.block, args_maker) - - def testLongLong(self): - self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7))) - - @jtu.ignore_warning(category=UserWarning, - message="Explicitly requested dtype.*") - @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion - def testArange(self): - # test cases inspired by dask tests at - # https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92 - np_arange = jtu.with_jax_dtype_defaults(np.arange) - self.assertAllClose(jnp.arange(77), - np_arange(77)) - self.assertAllClose(jnp.arange(2, 13), - np_arange(2, 13)) - self.assertAllClose(jnp.arange(4, 21, 9), - np_arange(4, 21, 9)) - self.assertAllClose(jnp.arange(53, 5, -3), - np_arange(53, 5, -3)) - self.assertAllClose(jnp.arange(77, dtype=float), - np_arange(77, dtype=float)) - self.assertAllClose(jnp.arange(2, 13, dtype=int), - np_arange(2, 13, dtype=int)) - self.assertAllClose(jnp.arange(0, 1, -0.5), - np_arange(0, 1, -0.5)) - - self.assertRaises(TypeError, lambda: jnp.arange()) - - # test that jnp.arange(N) doesn't instantiate an ndarray - self.assertNotEqual(type(jnp.arange(77)), type(np.arange(77))) - self.assertEqual(type(jnp.arange(77)), type(lax.iota(np.int32, 77))) - - # test that jnp.arange(N, dtype=int32) doesn't instantiate an ndarray - self.assertNotEqual(type(jnp.arange(77, dtype=jnp.int32)), - type(np.arange(77, dtype=np.int32))) - self.assertEqual(type(jnp.arange(77, dtype=jnp.int32)), - type(lax.iota(np.int32, 77))) - - def testArangeJit(self): - ans = jax.jit(lambda: jnp.arange(5))() - expected = jtu.with_jax_dtype_defaults(np.arange)(5) - self.assertAllClose(ans, expected) - - @jtu.sample_product(args=[(5,), (0, 5)]) - def testArangeJaxpr(self, args): - jaxpr = jax.make_jaxpr(lambda: jnp.arange(*args))() - self.assertEqual(len(jaxpr.jaxpr.eqns), 1) - self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) - - @unittest.skip("Jax-metal don't support complex.") - def testIssue830(self): - a = jnp.arange(4, dtype=jnp.complex64) - self.assertEqual(a.dtype, jnp.complex64) - - def testIssue728(self): - np_eye = jtu.with_jax_dtype_defaults(np.eye) - self.assertAllClose(jnp.eye(5000), np_eye(5000)) - self.assertEqual(0, np.sum(jnp.eye(1050) - np_eye(1050))) - - def testIssue746(self): - jnp.arange(12).reshape(3, 4) # doesn't crash - - def testIssue764(self): - x = jnp.linspace(190, 200, 4) - f = jax.grad(lambda x: jnp.sum(jnp.tanh(x))) - # Expected values computed with autograd in float64 precision. - expected = np.array([3.71669453e-165, 4.72999108e-168, 6.01954653e-171, - 7.66067839e-174], np.float64) - self.assertAllClose(f(x), expected, check_dtypes=False) - - # Test removed because tie_in is deprecated. - # def testIssue776(self): - # """Tests that the scatter-add transpose rule instantiates symbolic zeros.""" - # def f(u): - # y = jnp.ones_like(u, shape=10).at[np.array([2, 4, 5])].add(u) - # # The transpose rule for lax.tie_in returns a symbolic zero for its first - # # argument. - # return lax.tie_in(y, 7.) - - # self.assertAllClose(np.zeros(3,), jax.grad(f)(np.ones(3,))) - - # NOTE(mattjj): I disabled this test when removing lax._safe_mul because this - # is a numerical stability issue that should be solved with a custom jvp rule - # of the sigmoid function being differentiated here, not by safe_mul. - # def testIssue777(self): - # x = jnp.linspace(-200, 0, 4, dtype=np.float32) - # f = jax.grad(lambda x: jnp.sum(1 / (1 + jnp.exp(-x)))) - # self.assertAllClose(f(x), np.array([0., 0., 0., 0.25], dtype=np.float32)) - - #unittest.skip("Jax-metal fail on tanh with np.nan") - @jtu.sample_product( - dtype=float_dtypes, - op=("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan", - "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp", - "log", "expm1", "log1p"), - ) - def testMathSpecialFloatValues(self, op, dtype): - np_op = getattr(np, op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="invalid value.*")(np_op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="overflow.*")(np_op) - - jnp_op = getattr(jnp, op) - dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type - for x in (-np.inf, -100., -2., -1., 0., 1., 2., 100., np.inf, - jnp.finfo(dtype).max, np.sqrt(jnp.finfo(dtype).max), - np.sqrt(jnp.finfo(dtype).max) * 2.): #np.nan - x = dtype(x) - expected = np_op(x) - actual = jnp_op(x) - tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7}) - self.assertAllClose(expected, actual, atol=tol, - rtol=tol) - - def testIssue956(self): - self.assertRaises(TypeError, lambda: jnp.ndarray((1, 1))) - - def testIssue967(self): - self.assertRaises(TypeError, lambda: jnp.zeros(1.5)) - - @jtu.sample_product( - shape=[(5,), (10, 5), (4, 10)], - dtype=number_dtypes, - rowvar=[True, False], - ) - @jax.default_matmul_precision("float32") - def testCorrCoef(self, shape, dtype, rowvar): - rng = jtu.rand_default(self.rng()) - def args_maker(): - ok = False - while not ok: - x = rng(shape, dtype) - ok = not np.any(np.isclose(np.std(x), 0.0)) - return (x,) - np_fun = partial(np.corrcoef, rowvar=rowvar) - np_fun = jtu.ignore_warning( - category=RuntimeWarning, message="invalid value encountered.*")(np_fun) - jnp_fun = partial(jnp.corrcoef, rowvar=rowvar) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - [dict(dtype=dtype, end_dtype=end_dtype, begin_dtype=begin_dtype, - shape=shape, begin_shape=begin_shape, end_shape=end_shape) - for dtype in number_dtypes - for end_dtype in [None] + [dtype] - for begin_dtype in [None] + [dtype] - for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE] - for begin_shape in ( - [None] if begin_dtype is None - else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]) - for end_shape in ( - [None] if end_dtype is None - else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]) - ], - ) - def testEDiff1d(self, shape, dtype, end_shape, end_dtype, begin_shape, - begin_dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), - (None if end_dtype is None else rng(end_shape, end_dtype)), - (None if begin_dtype is None else rng(begin_shape, begin_dtype))] - np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) - jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testEDiff1dWithDtypeCast(self): - rng = jtu.rand_default(self.rng()) - shape = jtu.NUMPY_SCALAR_SHAPE - dtype = jnp.float32 - end_dtype = jnp.int32 - args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)] - np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) - jnp_fun = lambda x, to_end, to_begin: jnp.ediff1d(x, to_end, to_begin) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - shapes=[(), (5,), (5, 3)], - dtype=number_dtypes, - indexing=['xy', 'ij'], - sparse=[True, False], - ) - def testMeshGrid(self, shapes, dtype, indexing, sparse): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes], - [dtype] * len(shapes)) - np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse) - jnp_fun = partial(jnp.meshgrid, indexing=indexing, sparse=sparse) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testMgrid(self): - # wrap indexer for appropriate dtype defaults. - np_mgrid = _indexer_with_default_outputs(np.mgrid) - assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0) - assertAllEqual(np_mgrid[()], jnp.mgrid[()]) - assertAllEqual(np_mgrid[:4], jnp.mgrid[:4]) - assertAllEqual(np_mgrid[:4,], jnp.mgrid[:4,]) - assertAllEqual(np_mgrid[:4], jax.jit(lambda: jnp.mgrid[:4])()) - assertAllEqual(np_mgrid[:5, :5], jnp.mgrid[:5, :5]) - assertAllEqual(np_mgrid[:3, :2], jnp.mgrid[:3, :2]) - assertAllEqual(np_mgrid[1:4:2], jnp.mgrid[1:4:2]) - assertAllEqual(np_mgrid[1:5:3, :5], jnp.mgrid[1:5:3, :5]) - assertAllEqual(np_mgrid[:3, :2, :5], jnp.mgrid[:3, :2, :5]) - assertAllEqual(np_mgrid[:3:2, :2, :5], jnp.mgrid[:3:2, :2, :5]) - # Corner cases - assertAllEqual(np_mgrid[:], jnp.mgrid[:]) - # When the step length is a complex number, because of float calculation, - # the values between jnp and np might slightly different. - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_mgrid[-1:1:5j], - jnp.mgrid[-1:1:5j], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[3:4:7j], - jnp.mgrid[3:4:7j], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[1:6:8j, 2:4], - jnp.mgrid[1:6:8j, 2:4], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_mgrid[0:3.5:0.5], - jnp.mgrid[0:3.5:0.5], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[1.3:4.2:0.3], - jnp.mgrid[1.3:4.2:0.3], - atol=atol, - rtol=rtol) - # abstract tracer value for jnp.mgrid slice - with self.assertRaisesRegex(core.ConcretizationTypeError, - "slice start of jnp.mgrid"): - jax.jit(lambda a, b: jnp.mgrid[a:b])(0, 2) - - def testOgrid(self): - # wrap indexer for appropriate dtype defaults. - np_ogrid = _indexer_with_default_outputs(np.ogrid) - def assertSequenceOfArraysEqual(xs, ys): - self.assertIsInstance(xs, (list, tuple)) - self.assertIsInstance(ys, (list, tuple)) - self.assertEqual(len(xs), len(ys)) - for x, y in zip(xs, ys): - self.assertArraysEqual(x, y) - - self.assertArraysEqual(np_ogrid[:5], jnp.ogrid[:5]) - self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: jnp.ogrid[:5])()) - self.assertArraysEqual(np_ogrid[1:7:2], jnp.ogrid[1:7:2]) - # List of arrays - assertSequenceOfArraysEqual(np_ogrid[:5,], jnp.ogrid[:5,]) - assertSequenceOfArraysEqual(np_ogrid[0:5, 1:3], jnp.ogrid[0:5, 1:3]) - assertSequenceOfArraysEqual(np_ogrid[1:3:2, 2:9:3], jnp.ogrid[1:3:2, 2:9:3]) - assertSequenceOfArraysEqual(np_ogrid[:5, :9, :11], jnp.ogrid[:5, :9, :11]) - # Corner cases - self.assertArraysEqual(np_ogrid[:], jnp.ogrid[:]) - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_ogrid[-1:1:5j], - jnp.ogrid[-1:1:5j], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_ogrid[0:3.5:0.3], - jnp.ogrid[0:3.5:0.3], - atol=atol, - rtol=rtol) - self.assertAllClose(np_ogrid[1.2:4.8:0.24], - jnp.ogrid[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - # abstract tracer value for ogrid slice - with self.assertRaisesRegex(core.ConcretizationTypeError, - "slice start of jnp.ogrid"): - jax.jit(lambda a, b: jnp.ogrid[a:b])(0, 2) - - def testR_(self): - a = np.arange(6).reshape((2,3)) - self.assertArraysEqual(np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])], - jnp.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])]) - self.assertArraysEqual(np.r_['-1', a, a], jnp.r_['-1', a, a]) - - self.assertArraysEqual(np.r_['0,2', [1,2,3], [4,5,6]], jnp.r_['0,2', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.r_['0,2,0', [1,2,3], [4,5,6]], jnp.r_['0,2,0', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.r_['1,2,0', [1,2,3], [4,5,6]], jnp.r_['1,2,0', [1,2,3], [4,5,6]]) - # negative 1d axis start - self.assertArraysEqual(np.r_['0,4,-1', [1,2,3], [4,5,6]], jnp.r_['0,4,-1', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.r_['0,4,-2', [1,2,3], [4,5,6]], jnp.r_['0,4,-2', [1,2,3], [4,5,6]]) - - # matrix directives - with jtu.ignore_warning(category=PendingDeprecationWarning): - self.assertArraysEqual(np.r_['r',[1,2,3], [4,5,6]], jnp.r_['r',[1,2,3], [4,5,6]]) - self.assertArraysEqual(np.r_['c', [1, 2, 3], [4, 5, 6]], jnp.r_['c', [1, 2, 3], [4, 5, 6]]) - - # bad directive - with self.assertRaisesRegex(ValueError, "could not understand directive.*"): - jnp.r_["asdfgh",[1,2,3]] - # abstract tracer value for r_ slice - with self.assertRaisesRegex(core.ConcretizationTypeError, - "slice start of jnp.r_"): - jax.jit(lambda a, b: jnp.r_[a:b])(0, 2) - - # wrap indexer for appropriate dtype defaults. - np_r_ = _indexer_with_default_outputs(np.r_) - - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_r_[-1:1:6j], - jnp.r_[-1:1:6j], - atol=atol, - rtol=rtol) - with jax.numpy_dtype_promotion('standard'): # Requires dtype promotion. - self.assertAllClose(np_r_[-1:1:6j, [0]*3, 5, 6], - jnp.r_[-1:1:6j, [0]*3, 5, 6], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_r_[1.2:4.8:0.24], - jnp.r_[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - - def testC_(self): - a = np.arange(6).reshape((2, 3)) - self.assertArraysEqual(np.c_[np.array([1,2,3]), np.array([4,5,6])], - jnp.c_[np.array([1,2,3]), np.array([4,5,6])]) - self.assertArraysEqual(np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])], - jnp.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])]) - self.assertArraysEqual(np.c_['-1', a, a], jnp.c_['-1', a, a]) - - self.assertArraysEqual(np.c_['0,2', [1,2,3], [4,5,6]], jnp.c_['0,2', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.c_['0,2,0', [1,2,3], [4,5,6]], jnp.c_['0,2,0', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.c_['1,2,0', [1,2,3], [4,5,6]], jnp.c_['1,2,0', [1,2,3], [4,5,6]]) - # negative 1d axis start - self.assertArraysEqual(np.c_['0,4,-1', [1,2,3], [4,5,6]], jnp.c_['0,4,-1', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np.c_['0,4,-2', [1,2,3], [4,5,6]], jnp.c_['0,4,-2', [1,2,3], [4,5,6]]) - # matrix directives, avoid numpy deprecation warning - with jtu.ignore_warning(category=PendingDeprecationWarning): - self.assertArraysEqual(np.c_['r',[1,2,3], [4,5,6]], jnp.c_['r',[1,2,3], [4,5,6]]) - self.assertArraysEqual(np.c_['c', [1, 2, 3], [4, 5, 6]], jnp.c_['c', [1, 2, 3], [4, 5, 6]]) - - # bad directive - with self.assertRaisesRegex(ValueError, "could not understand directive.*"): - jnp.c_["asdfgh",[1,2,3]] - # abstract tracer value for c_ slice - with self.assertRaisesRegex(core.ConcretizationTypeError, - "slice start of jnp.c_"): - jax.jit(lambda a, b: jnp.c_[a:b])(0, 2) - - # wrap indexer for appropriate dtype defaults. - np_c_ = _indexer_with_default_outputs(np.c_) - - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_c_[-1:1:6j], - jnp.c_[-1:1:6j], - atol=atol, - rtol=rtol) - - # Non-integer steps - self.assertAllClose(np_c_[1.2:4.8:0.24], - jnp.c_[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - - def testS_(self): - self.assertEqual(np.s_[1:2:20],jnp.s_[1:2:20]) - - def testIndex_exp(self): - self.assertEqual(np.index_exp[5:3:2j],jnp.index_exp[5:3:2j]) - - @jtu.sample_product( - start_shape=[(), (2,), (2, 2)], - stop_shape=[(), (2,), (2, 2)], - num=[0, 1, 2, 5, 20], - endpoint=[True, False], - retstep=[True, False], - # floating-point compute between jitted platforms and non-jit + rounding - # cause unavoidable variation in integer truncation for some inputs, so - # we currently only test inexact 'dtype' arguments. - dtype=inexact_dtypes + [None,], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype): - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = jtu.tolerance(dtype if dtype else np.float32) * 10 - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(np.shape(start + stop)) - for axis in range(-ndim, ndim): - jnp_op = lambda start, stop: jnp.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - np_op = lambda start, stop: np.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, - check_dtypes=False, tol=tol) - self._CompileAndCheck(jnp_op, args_maker, - check_dtypes=False, atol=tol, rtol=tol) - - @jtu.sample_product(dtype=number_dtypes) - def testLinspaceEndpoints(self, dtype): - """Regression test for Issue #3014.""" - rng = jtu.rand_default(self.rng()) - endpoints = rng((2,), dtype) - out = jnp.linspace(*endpoints, 10, dtype=dtype) - self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0) - - @jtu.sample_product( - start_shape=[(), (2,), (2, 2)], - stop_shape=[(), (2,), (2, 2)], - num=[0, 1, 2, 5, 20], - endpoint=[True, False], - base=[10.0, 2, np.e], - # skip 16-bit floats due to insufficient precision for the test. - dtype=jtu.dtypes.inexact + [None,], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLogspace(self, start_shape, stop_shape, num, - endpoint, base, dtype): - if (dtype in int_dtypes and - jtu.test_device_matches(["gpu", "tpu"]) and - not config.enable_x64.value): - raise unittest.SkipTest("GPUx32 truncated exponentiation" - " doesn't exactly match other platforms.") - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = {np.float32: 1e-2, np.float64: 1e-6, np.complex64: 1e-3, np.complex128: 1e-6} - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(np.shape(start + stop)) - for axis in range(-ndim, ndim): - jnp_op = lambda start, stop: jnp.logspace( - start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) - @jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered in power") - def np_op(start, stop): - return np.logspace(start, stop, num, endpoint=endpoint, - base=base, dtype=dtype, axis=axis) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None,]): - # Why do compiled and op-by-op float16 np.power numbers differ - # slightly more than expected? - atol = {np.float16: 1e-2} - self._CompileAndCheck(jnp_op, args_maker, - check_dtypes=False, atol=atol, rtol=tol) - - @jtu.sample_product( - [dict(start_shape=start_shape, stop_shape=stop_shape, axis=axis) - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for axis in range(-max(len(start_shape), len(stop_shape)), - max(len(start_shape), len(stop_shape))) - ], - num=[0, 1, 2, 5, 20], - endpoint=[True, False], - # NB: numpy's geomspace gives nonsense results on integer types - dtype=inexact_dtypes + [None,], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testGeomspace(self, start_shape, stop_shape, num, - endpoint, dtype, axis): - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = {dtypes.bfloat16: 2e-2, np.float16: 4e-3, np.float32: 2e-3, - np.float64: 1e-14, np.complex64: 2e-3, np.complex128: 1e-14} - def args_maker(): - """Test the set of inputs np.geomspace is well-defined on.""" - start, stop = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype])() - # np.geomspace can't handle differently ranked tensors - # w. negative numbers! - start, stop = jnp.broadcast_arrays(start, stop) - if dtype in complex_dtypes: - return start, stop - # to avoid NaNs, non-complex start and stop cannot - # differ in sign, elementwise - start = start * jnp.sign(start) * jnp.sign(stop) - return start, stop - start, stop = args_maker() - def jnp_op(start, stop): - return jnp.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, - axis=axis) - def np_op(start, stop): - start = start.astype(np.float32) if dtype == jnp.bfloat16 else start - stop = stop.astype(np.float32) if dtype == jnp.bfloat16 else stop - return np.geomspace( - start, stop, num, endpoint=endpoint, - dtype=dtype if dtype != jnp.bfloat16 else np.float32, - axis=axis).astype(dtype) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None,]): - self._CompileAndCheck(jnp_op, args_maker, - check_dtypes=False, atol=tol, rtol=tol) - - def testDisableNumpyRankPromotionBroadcasting(self): - with jax.numpy_rank_promotion('allow'): - jnp.ones(2) + jnp.ones((1, 2)) # works just fine - - with jax.numpy_rank_promotion('raise'): - self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) - jnp.ones(2) + 3 # don't want to raise for scalars - - with jax.numpy_rank_promotion('warn'): - self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " - r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) - jnp.ones(2) + 3 # don't want to warn for scalars - - @unittest.skip("Test fails on CI, perhaps due to JIT caching") - def testDisableNumpyRankPromotionBroadcastingDecorator(self): - with jax.numpy_rank_promotion("allow"): - jnp.ones(2) + jnp.ones((1, 2)) # works just fine - - with jax.numpy_rank_promotion("raise"): - self.assertRaises(ValueError, lambda: jnp.ones(2) + jnp.ones((1, 2))) - jnp.ones(2) + 3 # don't want to raise for scalars - - with jax.numpy_rank_promotion("warn"): - self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " - r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) - jnp.ones(2) + 3 # don't want to warn for scalars - - def testStackArrayArgument(self): - # tests https://github.com/jax-ml/jax/issues/1271 - @jax.jit - def foo(x): - return jnp.stack(x) - foo(np.zeros(2)) # doesn't crash - - @jax.jit - def foo(x): - return jnp.concatenate(x) - foo(np.zeros((2, 2))) # doesn't crash - - def testReluGradientConstants(self): - # This is a regression test that verifies that constants associated with the - # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the - # outermost jaxpr. This was producing some large materialized constants for - # every relu activation in a model. - def body(i, xy): - x, y = xy - y = y + jax.grad(lambda z: jnp.sum(jnp.maximum(z, 0.)))(x) - return x, y - - f = lambda y: lax.fori_loop(0, 5, body, (y, y)) - jaxpr = jax.make_jaxpr(f)(np.zeros((3, 4), np.float32)) - self.assertFalse( - any(np.array_equal(x, np.full((3, 4), 2., dtype=np.float32)) - for x in jaxpr.consts)) - - @jtu.sample_product( - [dict(from_shape=from_shape, to_shape=to_shape) - for from_shape, to_shape in [ - [(1, 3), (4, 3)], - [(3,), (2, 1, 3)], - [(3,), (3, 3)], - [(1,), (3,)], - [(1,), 3], - ] - ], - ) - def testBroadcastTo(self, from_shape, to_shape): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32]) - np_op = lambda x: np.broadcast_to(x, to_shape) - jnp_op = lambda x: jnp.broadcast_to(x, to_shape) - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @jtu.sample_product( - [dict(shapes=shapes, broadcasted_shape=broadcasted_shape) - for shapes, broadcasted_shape in [ - [[], ()], - [[()], ()], - [[(1, 3), (4, 3)], (4, 3)], - [[(3,), (2, 1, 3)], (2, 1, 3)], - [[(3,), (3, 3)], (3, 3)], - [[(1,), (3,)], (3,)], - [[(1,), 3], (3,)], - [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)], - [[[1], [0, 1]], (0, 1)], - [[(1,), np.array([0, 1])], (0, 1)], - ] - ], - ) - def testBroadcastShapes(self, shapes, broadcasted_shape): - # Test against np.broadcast_shapes once numpy 1.20 is minimum required version - np.testing.assert_equal(jnp.broadcast_shapes(*shapes), broadcasted_shape) - - def testBroadcastToIssue1522(self): - self.assertRaisesRegex( - ValueError, "Incompatible shapes for broadcasting: .*", - lambda: jnp.broadcast_to(np.ones((2, 3)), (1, 3))) - - def testBroadcastToIntIssue1548(self): - self.assertAllClose(jnp.broadcast_to(1, (3, 2)), np.ones((3, 2)), - check_dtypes=False) - - def testBroadcastToOnScalar(self): - self.assertIsInstance(jnp.broadcast_to(10.0, ()), jax.Array) - self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray) - - def testPrecision(self): - - ones_1d = np.ones((2,)) - ones_2d = np.ones((2, 2)) - ones_3d = np.ones((2, 2, 2)) - HIGHEST = lax.Precision.HIGHEST - - jtu.assert_dot_precision(None, jnp.dot, ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.dot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.dot, precision=HIGHEST), - ones_3d, ones_3d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.matmul, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.vdot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.vecdot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.tensordot, axes=2, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.tensordot, axes=(0, 0), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.tensordot, axes=((0,), (0,)), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.einsum, 'i,i', precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.einsum, 'ij,ij', precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(jnp.inner, precision=HIGHEST), - ones_1d, ones_1d) - - @jtu.sample_product( - funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot', 'vecdot'] - ) - def testPreferredElementType(self, funcname): - func = getattr(jnp, funcname) - kwargs = dict(axes=0) if funcname == 'tensordot' else {} - - ones_i32 = np.ones(2, dtype='int32') - ones_f32 = np.ones(2, dtype='float32') - - with jax.numpy_dtype_promotion('strict'): - jtu.assert_dot_preferred_element_type('int32', func, ones_i32, ones_i32, **kwargs) - jtu.assert_dot_preferred_element_type('float32', func, ones_f32, ones_f32, **kwargs) - jtu.assert_dot_preferred_element_type('bfloat16', func, ones_f32, ones_f32, **kwargs, - preferred_element_type='bfloat16') - with jax.numpy_dtype_promotion('standard'): - jtu.assert_dot_preferred_element_type('float32', func, ones_i32, ones_f32, **kwargs) - - @jtu.sample_product( - [dict(shape=shape, varargs=varargs, axis=axis) - for shape in [(10,), (10, 15), (10, 15, 20)] - for _num_axes in range(len(shape)) - for varargs in itertools.combinations(range(1, len(shape) + 1), _num_axes) - for axis in itertools.combinations(range(len(shape)), _num_axes) - ], - dtype=inexact_dtypes, - ) - def testGradient(self, shape, varargs, axis, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - jnp_fun = lambda y: jnp.gradient(y, *varargs, axis=axis) - np_fun = lambda y: np.gradient(y, *varargs, axis=axis) - self._CheckAgainstNumpy( - np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) - - def testZerosShapeErrors(self): - # see https://github.com/jax-ml/jax/issues/1822 - self.assertRaisesRegex( - TypeError, - "Shapes must be 1D sequences of concrete values of integer type.*", - lambda: jnp.zeros(1.)) - self.assertRaisesRegex( - TypeError, - r"Shapes must be 1D sequences of concrete values of integer type.*\n" - "If using `jit`, try using `static_argnums` or applying `jit` to " - "smaller subfunctions.", - lambda: jax.jit(jnp.zeros)(2)) - - def testTraceMethod(self): - x = self.rng().randn(3, 4).astype(jnp.float_) - self.assertAllClose(x.trace(), jnp.array(x).trace()) - self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) - - def testIntegerPowersArePrecise(self): - # See https://github.com/jax-ml/jax/pull/3036 - # Checks if the squares of float32 integers have no numerical errors. - # It should be satisfied with all integers less than sqrt(2**24). - x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) - np.testing.assert_array_equal(jnp.square(x.astype(jnp.float32)), x * x) - np.testing.assert_array_equal(x.astype(jnp.float32) ** 2, x * x) - - # Similarly for cubes. - x = jnp.arange(-2**8, 2**8, dtype=jnp.int32) - np.testing.assert_array_equal(x.astype(jnp.float32) ** 3, x * x * x) - - x = np.arange(10, dtype=np.float32) - for i in range(10): - self.assertAllClose(x.astype(jnp.float32) ** i, x ** i, - check_dtypes=False) - - def testToBytes(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - for order in ['C', 'F']: - self.assertEqual(jnp.asarray(v).tobytes(order), v.tobytes(order)) - - def testToBytesJitError(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - f = jax.jit(lambda x: x.tobytes()) - msg = r".*The tobytes\(\) method was called on traced array" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(v) - - def testToList(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - self.assertEqual(jnp.asarray(v).tolist(), v.tolist()) - - def testToListJitError(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - f = jax.jit(lambda x: x.tolist()) - msg = r".*The tolist\(\) method was called on traced array" - with self.assertRaisesRegex(core.ConcretizationTypeError, msg): - f(v) - - def testArangeConcretizationError(self): - msg = r"It arose in the jnp.arange argument '{}'".format - with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): - jax.jit(jnp.arange)(3) - - with self.assertRaisesRegex(core.ConcretizationTypeError, msg('start')): - jax.jit(lambda start: jnp.arange(start, 3))(0) - - with self.assertRaisesRegex(core.ConcretizationTypeError, msg('stop')): - jax.jit(lambda stop: jnp.arange(0, stop))(3) - - @jtu.sample_product(dtype=[None] + float_dtypes) - def testArange64Bit(self, dtype): - # Test that jnp.arange uses 64-bit arithmetic to define its range, even if the - # output has another dtype. The issue here is that if python scalar inputs to - # jnp.arange are cast to float32 before the range is computed, it changes the - # number of elements output by the range. It's unclear whether this was deliberate - # behavior in the initial implementation, but it's behavior that downstream users - # have come to rely on. - args = (1.2, 4.8, 0.24) - - # Ensure that this test case leads to differing lengths if cast to float32. - self.assertLen(np.arange(*args), 15) - self.assertLen(np.arange(*map(np.float32, args)), 16) - - jnp_fun = lambda: jnp.arange(*args, dtype=dtype) - np_fun = jtu.with_jax_dtype_defaults(lambda: np.arange(*args, dtype=dtype), dtype is None) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testIssue2347(self): - # https://github.com/jax-ml/jax/issues/2347 - object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] - self.assertRaises(TypeError, jnp.array, object_list) - - np_object_list = np.array(object_list) - self.assertRaises(TypeError, jnp.array, np_object_list) - - @unittest.skip("JAX-metal don't support complex type yet.") - @jtu.sample_product( - [dict(shapes=shapes, dtypes=dtypes) - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, complex_dtypes + [None]) for s in shapes)) - ], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLogaddexpComplex(self, shapes, dtypes): - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - def np_op(x1, x2): - return np.log(np.exp(x1) + np.exp(x2)) - - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) - if jtu.test_device_matches(["tpu"]): - tol = {np.complex64: 1e-3, np.complex128: 1e-10} - else: - tol = {np.complex64: 1e-5, np.complex128: 1e-14} - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp, args_maker, tol=tol) - self._CompileAndCheck(jnp.logaddexp, args_maker, rtol=tol, atol=tol) - - @unittest.skip("JAX-metal don't support complex type yet.") - @jtu.sample_product( - [dict(shapes=shapes, dtypes=dtypes) - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, complex_dtypes + [None]) for s in shapes)) - ], - ) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLogaddexp2Complex(self, shapes, dtypes): - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - def np_op(x1, x2): - return np.log2(np.exp2(x1) + np.exp2(x2)) - - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) - if jtu.test_device_matches(["tpu"]): - tol = {np.complex64: 1e-3, np.complex128: 1e-10} - else: - tol = {np.complex64: 1e-5, np.complex128: 1e-14} - - with jtu.strict_promotion_if_dtypes_match(dtypes): - self._CheckAgainstNumpy(jtu.promote_like_jnp(np_op), jnp.logaddexp2, args_maker, tol=tol) - self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol) - - def testDefaultDtypes(self): - precision = config.default_dtype_bits.value - assert precision in ['32', '64'] - self.assertEqual(jnp.bool_, np.bool_) - self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64) - self.assertEqual(jnp.uint, np.uint32 if precision == '32' else np.uint64) - self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64) - self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128) - - def testFromBuffer(self): - buf = b'\x01\x02\x03' - expected = np.frombuffer(buf, dtype='uint8') - actual = jnp.frombuffer(buf, dtype='uint8') - self.assertArraysEqual(expected, actual) - - def testFromFunction(self): - def f(x, y, z): - return x + 2 * y + 3 * z - shape = (3, 4, 5) - expected = np.fromfunction(f, shape=shape) - actual = jnp.fromfunction(f, shape=shape) - self.assertArraysEqual(expected, actual, check_dtypes=False) - - def testFromString(self): - s = "1,2,3" - expected = np.fromstring(s, sep=',', dtype=int) - actual = jnp.fromstring(s, sep=',', dtype=int) - self.assertArraysEqual(expected, actual) - - @jtu.sample_product( - a_shape=nonempty_nonscalar_array_shapes, - v_shape=nonempty_shapes, - dtype=jtu.dtypes.all, - ) - def testPlace(self, a_shape, v_shape, dtype): - rng = jtu.rand_default(self.rng()) - mask_rng = jtu.rand_bool(self.rng()) - - def args_maker(): - a = rng(a_shape, dtype) - m = mask_rng(a_shape, bool) - v = rng(v_shape, dtype) - return a, m, v - - def np_fun(a, m, v): - a_copy = a.copy() - np.place(a_copy, m, v) - return a_copy - - jnp_fun = partial(jnp.place, inplace=False) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @jtu.sample_product( - a_shape=nonempty_nonscalar_array_shapes, - i_shape=all_shapes, - v_shape=all_shapes, - dtype=jtu.dtypes.all, - mode=[None, 'wrap', 'clip'], - ) - def testPut(self, mode, a_shape, i_shape, v_shape, dtype): - size = math.prod(a_shape) - if math.prod(i_shape) > size: - self.skipTest("too many indices") - rng = jtu.rand_default(self.rng()) - # Must test unique integers, because overlapping updates in - # JAX have implementation-defined order - idx_rng = jtu.rand_unique_int(self.rng(), size) - - def args_maker(): - a = rng(a_shape, dtype) - i = idx_rng(i_shape, np.int32) - v = rng(v_shape, dtype) - # put some indices out of range without duplicating indices - if mode == "clip" and i.size: - np.put(i, np.argmax(i), size + 2) - np.put(i, np.argmin(i), -2) - if mode == "wrap" and i.size: - np.put(i, 0, np.take(i, 0) + size) - return a, i, v - - def np_fun(a, i, v): - a_copy = a.copy() - np.put(a_copy, i, v, mode=mode) - return a_copy - - jnp_fun = partial(jnp.put, mode=mode, inplace=False) - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def test_rot90_error(self): - with self.assertRaisesRegex( - ValueError, - "rot90 requires its first argument to have ndim at least two, " - "but got first argument of"): - jnp.rot90(jnp.ones(2)) - - @parameterized.named_parameters( - ('ones', jnp.ones), - ('zeros', jnp.zeros), - ('empty', jnp.empty)) - def test_error_hint(self, fn): - with self.assertRaisesRegex( - TypeError, - r"Did you accidentally write `jax\.numpy\..*?\(2, 3\)` " - r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"): - fn(2, 3) - - @jtu.sample_product( - dtype=jtu.dtypes.all, - kind=['bool', 'signed integer', 'unsigned integer', 'integral', - 'real floating', 'complex floating', 'numeric'] - ) - def test_isdtype(self, dtype, kind): - # Full tests also in dtypes_test.py; here we just compare against numpy - jax_result = jnp.isdtype(dtype, kind) - if dtype == dtypes.bfloat16: - # just a smoke test - self.assertIsInstance(jax_result, bool) - else: - numpy_result = np.isdtype(dtype, kind) - self.assertEqual(jax_result, numpy_result) - - -@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") -class ReportedIssuesTests(jtu.JaxTestCase): - def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): - deviceArgs = [] - for arg in args: - deviceArgs.append(jax.device_put(arg, device)) - return func(*deviceArgs) - - @staticmethod - def compile_and_exec(module, args, run_on_cpu=False): - from jax.extend.backend import get_backend - backend = get_backend('METAL') - if run_on_cpu: - backend = get_backend('cpu') - executable = backend.compile(module) - def put(arg): - return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) - arguments = [put(arg) for arg in args] - outputs = executable.execute(arguments) - return [np.asarray(x) for x in outputs] - - @staticmethod - def jax_metal_supported(target_ver): - if metal_plugin is None or not hasattr(metal_plugin, 'version'): - return False - curr_ver = metal_plugin.version() - if hasattr(jtu, 'parse_version'): - return jtu.parse_version(curr_ver) >= jtu.parse_version(target_ver) - return False - - - #https://github.com/jax-ml/jax/issues/16420 - def test_broadcast_dim(self): - x = jnp.arange(2) - f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (0,)) - res = f(x) - print(res) - res_cpu = self.dispatchOn([x],f) - jtu.check_eq(res, res_cpu) - f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (1,)) - res = f(x) - print(res) - res_cpu = self.dispatchOn([x],f) - jtu.check_eq(res, res_cpu) - - def test_identity(self): - x = jnp.identity(4) - jtu.check_eq(x, np.identity(4)) - - def test_triu(self): - x = np.ones((4,4)) - res = jnp.triu(x) - jtu.check_eq(res, np.triu(x)) - - #https://github.com/jax-ml/jax/issues/16471 - def test_matmul_1d(self): - x = np.array(np.random.rand(3, 3)) - y = np.array(np.random.rand(3)) - z = np.array(np.random.rand(3)) - res = jnp.dot(y, z) - self.assertArraysAllClose(res, np.dot(y,z)) - res = jnp.dot(x, y) - self.assertArraysAllClose(res, np.dot(x,y)) - - #https://github.com/jax-ml/jax/issues/17175 - def test_indexing(self): - x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) - @jax.vmap - def f(i): - return x[i] - f = jax.jit(f) - idx = jnp.array([1,1,2,2,0]) - res = f(idx) - jtu.check_eq(res, np.array([[4., 5., 6.], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.], [1., 2., 3.]])) - - #https://github.com/jax-ml/jax/issues/17344 - def test_take_along_axis(self): - @jax.jit - def f(): - idx = jnp.array([[0],[0],[0]]) - x = jnp.array([[0.3756883, 0.05820537, 0.7399422, 0.45242703], - [0.5848844, 0.18772626, 0.47942543, 0.20703673], - [0.1071583, 0.26139486, 0.25664794, 0.8109596]]) - return jnp.take_along_axis(x, idx, axis=1) - jtu.check_eq(f(), self.dispatchOn([], f)) - - #https://github.com/jax-ml/jax/issues/17590 - def test_in1d(self): - a = np.array([123,2,4]) - b = np.array([123,1]) - res = jnp.isin(a,b) - jtu.check_eq(res, np.isin(a, b)) - - def test_indexing_update(self): - x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) - @jax.vmap - def f(x): - return x.at[0].set(1.0) - f = jax.jit(f) - res = f(x) - jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]])) - - #https://github.com/jax-ml/jax/issues/16326 - def test_indexing_update2(self): - @jax.jit - def f(x, r): - x = x.at[:, 0].set(x[:, 0] / r) - return x - x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) - fx = f(x, jnp.array([10.0])) - jtu.check_eq(fx, np.array([[0.1, 2.0], [0.3, 4.]])) - - def test_gather_ir(self): - ir = ''' -#loc = loc(unknown) -module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<3x2x3xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x2xi32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<3x2xf32> { - %0 = "stablehlo.gather"(%arg0, %arg1) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array} : (tensor<3x2x3xf32>, tensor<3x2xi32>) -> tensor<3x2xf32> loc(#loc2) - return %0 : tensor<3x2xf32> loc(#loc) - } loc(#loc) -} loc(#loc) -#loc1 = loc("/Users/shuhan/Code/jax-metal/tests/lax_numpy_indexing_test.py":1156:0) -#loc2 = loc("jit(gather)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0, 2), start_index_map=(0, 2)) slice_sizes=(1, 2, 1) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.CLIP fill_value=None]"(#loc1)) - ''' - data = np.array([[[0.6369617, 0.26978672, 0.04097353], - [0.01652764, 0.8132702, 0.91275555]], - [[0.60663575, 0.72949654, 0.543625 ], - [0.9350724, 0.81585354, 0.0027385 ]], - [[0.8574043, 0.03358557, 0.72965544], - [0.17565562, 0.8631789, 0.5414612 ]]], dtype=np.float32) - index = np.array([[1, 0],[2, 1],[0, 2]], dtype=np.int32) - res = ReportedIssuesTests.compile_and_exec(ir, [data, index]) - res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, index], run_on_cpu = True) - print(res) - jtu.check_eq(res, res_ref) - - #https://github.com/jax-ml/jax/issues/16366 - def test_pad_interior_1(self): - if not ReportedIssuesTests.jax_metal_supported('0.0.6'): - raise unittest.SkipTest("jax-metal version doesn't support it.") - ir = ''' - module @jit_gather attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<128x7x7x64xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<128x15x15x64xf32> { - %206 = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<[0, 1, 1, 0]> : tensor<4xi64>} : (tensor<128x7x7x64xf32>, tensor) -> tensor<128x15x15x64xf32> - return %206 : tensor<128x15x15x64xf32> - } - } - ''' - data = np.random.rand(128,7,7,64).astype(np.float32) - padding = np.array(0.5, dtype=np.float32) - res = ReportedIssuesTests.compile_and_exec(ir, [data, padding]) - res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, padding], run_on_cpu = True) - jtu.check_eq(res, res_ref) - - def test_pad_interior_2(self): - if not ReportedIssuesTests.jax_metal_supported('0.0.6'): - raise unittest.SkipTest("jax-metal version doesn't support it.") - batch = 2 - seq_len = 8 - num_decode = 32 - - seq = np.random.randint(size=(batch, seq_len, num_decode), low=0, high=256, dtype=np.uint8) - res = jnp.cumsum(seq, axis=-1) - res_ref = np.cumsum(seq, axis=-1, dtype=np.uint8) - jtu.check_eq(res, res_ref) - - @unittest.expectedFailure - def test_issue_pad(self): - ir = ''' - module @jit_dummy attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor<2x2xf32> {mhlo.sharding = "{replicated}"} loc(unknown), %arg1: tensor<3x4xf32> {mhlo.sharding = "{replicated}"} loc(unknown)) -> tensor<4x4xf32> { - %12 = stablehlo.slice %arg0 [0:1, 1:2] : (tensor<2x2xf32>) -> tensor<1x1xf32> - %13 = stablehlo.reshape %12 : (tensor<1x1xf32>) -> tensor - %14 = stablehlo.pad %arg1, %13, low = [0, 0], high = [1, 0], interior = [0, 0] : (tensor<3x4xf32>, tensor) -> tensor<4x4xf32> - return %14 : tensor<4x4xf32> - } - } - ''' - data = np.array([[1, 3], [1, 3]], dtype=np.float32) - input = np.random.rand(3,4).astype(np.float32) - res = ReportedIssuesTests.compile_and_exec(ir, [data, input]) - res_ref = ReportedIssuesTests.compile_and_exec(ir, [data, input], run_on_cpu = True) - jtu.check_eq(res, res_ref) - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) From 10a43df9662420de15aecacac938d956fe8fbfdb Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Wed, 3 Dec 2025 16:33:16 -0800 Subject: [PATCH 033/315] [Pallas] Make rng tests inherit from JaxTestCase PiperOrigin-RevId: 839963758 --- jax/experimental/pallas/ops/tpu/random/philox.py | 2 +- jax/experimental/pallas/ops/tpu/random/threefry.py | 2 +- tests/pallas/tpu_pallas_random_test.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/experimental/pallas/ops/tpu/random/philox.py b/jax/experimental/pallas/ops/tpu/random/philox.py index dcdbd94779db..cb108c319507 100644 --- a/jax/experimental/pallas/ops/tpu/random/philox.py +++ b/jax/experimental/pallas/ops/tpu/random/philox.py @@ -117,7 +117,7 @@ def kernel(offset_ref, key_ref, out_ref): offset = prng_utils.compute_scalar_offset( counts_idx, unpadded_shape, block_shape) counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape) - counts_lo = counts_lo + offset + offset_ref[0] + counts_lo = counts_lo + offset.astype(jnp.uint32) + offset_ref[0] counts_lo = counts_lo.astype(jnp.uint32) # TODO(justinfu): Support hi bits on count. _zeros = jnp.zeros_like(counts_lo) diff --git a/jax/experimental/pallas/ops/tpu/random/threefry.py b/jax/experimental/pallas/ops/tpu/random/threefry.py index 71a314e09b2d..06a82f4abac8 100644 --- a/jax/experimental/pallas/ops/tpu/random/threefry.py +++ b/jax/experimental/pallas/ops/tpu/random/threefry.py @@ -63,7 +63,7 @@ def kernel(key_ref, out_ref): offset = prng_utils.compute_scalar_offset( counts_idx, unpadded_shape, block_shape) counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape) - counts_lo = counts_lo + offset + counts_lo = counts_lo + offset.astype(jnp.uint32) counts_lo = counts_lo.astype(jnp.uint32) # TODO(justinfu): Support hi bits on count. counts_hi = jnp.zeros_like(counts_lo) diff --git a/tests/pallas/tpu_pallas_random_test.py b/tests/pallas/tpu_pallas_random_test.py index 91f638a0efb4..6a675bb0c7ae 100644 --- a/tests/pallas/tpu_pallas_random_test.py +++ b/tests/pallas/tpu_pallas_random_test.py @@ -197,7 +197,7 @@ def body(key_ref, o_ref): key, shape=o_ref[0, ...].shape, minval=0.0, maxval=1.0 ) - key = jax_random.fold_in(key, 2) + key = jax_random.fold_in(key, jnp.uint32(2)) o_ref[1, ...] = jax_random.uniform( key, shape=o_ref[1, ...].shape, minval=0.0, maxval=1.0 ) @@ -243,7 +243,7 @@ def f(rng_key): self.assertGreaterEqual(jnp.max(y), jnp.min(y)) -class BlockInvarianceTest(parameterized.TestCase): +class BlockInvarianceTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(["tpu"]): @@ -290,7 +290,7 @@ def body(key_ref, o_ref): np.testing.assert_array_equal(result_16x128, result_32x256) -class ThreefryTest(parameterized.TestCase): +class ThreefryTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(["tpu"]): @@ -373,7 +373,7 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape): np.testing.assert_array_equal(jax_gen, pl_gen) -class PhiloxTest(parameterized.TestCase): +class PhiloxTest(jtu.JaxTestCase): def setUp(self): if not jtu.test_device_matches(["tpu"]): From daedbcc05be8d29d067c27d0ff8b53663b853be7 Mon Sep 17 00:00:00 2001 From: partev Date: Wed, 3 Dec 2025 19:48:08 -0500 Subject: [PATCH 034/315] Update PRNG documentation with corrected links update URLs --- docs/jep/263-prng.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/jep/263-prng.md b/docs/jep/263-prng.md index 7ef10ae0e9c4..ff58d1b7b94e 100644 --- a/docs/jep/263-prng.md +++ b/docs/jep/263-prng.md @@ -12,7 +12,7 @@ We want a PRNG design that As a corollary of these we believe the design should be functional. Another corollary is that, at least given current hardware constraints, we’re going to do the PRNG in software. > TLDR -> **JAX PRNG = [Threefry counter PRNG](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + a functional array-oriented [splitting model](https://dl.acm.org/citation.cfm?id=2503784)** +> **JAX PRNG = [Threefry counter PRNG](https://thesalmons.org/john/random123/papers/random123sc11.pdf) + a functional array-oriented [splitting model](https://dl.acm.org/doi/10.1145/2503778.2503784)** ## Contents * [Three programming models and toy example programs](#three-programming-models-and-toy-example-programs) @@ -79,7 +79,7 @@ Explicit threading is inconvenient for the programmer. But worse, it hasn’t ac In short, making the code functional by explicitly threading state isn’t enough to achieve our expressiveness (#1) and performance (#5, #6) goals. -The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use **functional [splittable](https://dl.acm.org/citation.cfm?id=2503784) PRNGs**. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like [multistreams](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)). +The key problem in both the previous models is that there’s too much sequencing. To reduce the amount of sequential dependence we use **functional [splittable](https://dl.acm.org/doi/10.1145/2503778.2503784) PRNGs**. Splitting is a mechanism to ‘fork’ a new PRNG state into two PRNG states while maintaining the usual desirable PRNG properties (the two new streams are computationally parallelizable and produce independent random values, i.e. they behave like [multistreams](https://thesalmons.org/john/random123/papers/random123sc11.pdf)). ```python def foo(rng_1): @@ -105,7 +105,7 @@ The example doesn’t show it, but as a consequence of the choice (2) the only w ## Design -We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [splittable PRNGs](https://dl.acm.org/citation.cfm?id=2503784): that is, splitting is a way to generate two new keys from an existing one. +We can use the *counter-based PRNG* design, and in particular the Threefry hash function, as described in [Parallel random numbers: as easy as 1, 2, 3](https://thesalmons.org/john/random123/papers/random123sc11.pdf). We use the counter to achieve efficient vectorization: for a given key we can generate an array of values in a vectorized fashion by mapping the hash function over a range of integers [k + 1, …, k + sample_size]. We use the key together with the hash function to implement [splittable PRNGs](https://dl.acm.org/doi/10.1145/2503778.2503784): that is, splitting is a way to generate two new keys from an existing one. ```haskell type Sample = Int256 From 80967f3d8991cddc35c4c7db521d4b7abbf807f8 Mon Sep 17 00:00:00 2001 From: Rachel Han Date: Wed, 3 Dec 2025 23:04:42 -0800 Subject: [PATCH 035/315] Change int2 shape in pallas test PiperOrigin-RevId: 840084719 --- tests/pallas/ops_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index c6b3d9e7db6e..45fb75e11fac 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -666,8 +666,10 @@ def test_cast_from_32bit(self, from_dtype, to_dtype, data): shape = (8, 128) if to_dtype in {"int2", "uint2"}: # Make sure #rows is a least the packing factor of int2. - # TODO(b/343490729): XLA convert(f32[16, 128]) fails on v5p. - shape = (32, 128) + # TODO: b/343490729 - XLA convert(f32[16, 128]) fails on v5p. + # TODO: b/459440496 - Support more shapes for int2. The number of rows is + # required to be an even multiple of 128. + shape = (128, 128) x = data.draw(hnp.arrays(from_dtype, shape, elements=elements)) x = jnp.asarray(x) def kernel(x_ref, y_ref): From c5f92284833f7768d4c4efd1f043e3c33232755f Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Wed, 3 Dec 2025 23:34:48 -0800 Subject: [PATCH 036/315] Fix Python path for executing "hi" JAX code. PiperOrigin-RevId: 840093989 --- jax/_src/stages.py | 10 ++++++++-- tests/hijax_test.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 779744b4ca99..55475a0f34dd 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -435,8 +435,14 @@ def fall(self): _, closed_over_himutables = pe.convert_const_himutables(hi_jaxpr) if closed_over_himutables: raise NotImplementedError # TODO(mattjj) lo_jaxpr = pe.lower_jaxpr(hi_jaxpr) - in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree) - out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree) + if any(a.is_high for a in hi_jaxpr.final_aval_qdds): + in_tree = lojax_pytree(hi_jaxpr.in_aval_qdds, self._in_tree) + else: + in_tree = self._in_tree + if any(a.is_high for a in hi_jaxpr.out_avals): + out_tree = lojax_pytree(hi_jaxpr.out_avals, self.out_tree) + else: + out_tree = self.out_tree params = dict(lojax_expand_params(hi_jaxpr, self._params), jaxpr=lo_jaxpr) lo_meta_tys = [mty.replace(aval=lo_ty) for mty, aq in zip(self._meta_tys_flat, hi_jaxpr.in_aval_qdds) diff --git a/tests/hijax_test.py b/tests/hijax_test.py index b4b3e11c3300..2f665ce76aed 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -538,6 +538,7 @@ def f(x): if jit: f = jax.jit(f) + self.assertEqual(f.lower(2.0).compile()(2.0), 8.0) self.assertEqual(f(2.0), 8.0) xs = jnp.arange(3.0) From 7f343072d70df146473d122bed3ee097ddad5689 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 4 Dec 2025 00:05:03 -0800 Subject: [PATCH 037/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c6315cd85539fac2c08ed33dbe25e006f41ce72b PiperOrigin-RevId: 840102440 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 0a5d05fd103d..1804835d507a 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "601dfb22e9dac2583652e0b44a0f603923e2aaa8" -XLA_SHA256 = "7e36a4a20898d27ed81517857b6670fe1a2adfb92ff0237aa801ebe5e3d7488e" +XLA_COMMIT = "c6315cd85539fac2c08ed33dbe25e006f41ce72b" +XLA_SHA256 = "d1875b989ed12f511deec609b2592698dfb3af4f49cb1a470db42c6658ac83f6" From ca87f344c6d3a5f26869be0ebaf898eff07decf4 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Thu, 4 Dec 2025 02:45:53 -0800 Subject: [PATCH 038/315] [Mosaic GPU][NFC] Clarify comments on custom call variants. PiperOrigin-RevId: 840153402 --- jaxlib/mosaic/gpu/custom_call.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 3fab5664f49c..06dc1f0491c5 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -641,9 +641,10 @@ absl::StatusOr CachedCompileAndInit(CacheKey key, return &cache.kernels.at(key); } +// TODO(b/464203195): Backward-compatible version using the legacy FFI +// API. Remove once backward compatibility window has passed. void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { - // Forward-compatible version using the legacy FFI API if (reinterpret_cast(opaque) % alignof(KernelHash)) { fprintf(stderr, "Misaligned opaque pointer\n"); abort(); @@ -680,8 +681,6 @@ absl::Status MosaicGpuExecute(cudaStream_t stream, ffi::RemainingArgs inputs, std::string_view kernel_hash, std::string_view module, bool use_custom_barrier) { - // Updated version using the new FFI API supporting custom barrier - // for distributed kernels if (use_custom_barrier) { return absl::UnimplementedError("Custom barrier is not supported on GPUs."); } From fdfffdc3308f8506d3074c2bb3d953e4255b4ea7 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 4 Dec 2025 03:51:04 -0800 Subject: [PATCH 039/315] Fix paths triggering H100/B200 presubmits. PiperOrigin-RevId: 840170495 --- .github/workflows/bazel_cuda_h100_b200.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/bazel_cuda_h100_b200.yml b/.github/workflows/bazel_cuda_h100_b200.yml index 25829ac8cf7e..60cbb01ed203 100644 --- a/.github/workflows/bazel_cuda_h100_b200.yml +++ b/.github/workflows/bazel_cuda_h100_b200.yml @@ -41,10 +41,10 @@ jobs: uses: tj-actions/changed-files@ed68ef82c095e0d48ec87eccea555d944a631a4c # v46 with: files: | - jax/jax/_src/pallas/mosaic_gpu/** - jax/jax/experimental/mosaic/gpu/** - jax/jaxlib/mosaic/dialect/gpu/** - jax/jaxlib/mosaic/gpu/** + jax/_src/pallas/mosaic_gpu/** + jax/experimental/mosaic/gpu/** + jaxlib/mosaic/dialect/gpu/** + jaxlib/mosaic/gpu/** .github/workflows/bazel_cuda_h100_b200.yml - name: List all changed files env: From d0fd8f089885d2ce4e760dbb69545e5fa9ee31d1 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 4 Dec 2025 04:24:43 -0800 Subject: [PATCH 040/315] [Mosaic GPU] Fix broken tests when jax is coupled with an old jax lib version. PiperOrigin-RevId: 840181806 --- jax/_src/pallas/mosaic_gpu/primitives.py | 32 ++++++++++++------- .../mosaic/gpu/dialect_lowering.py | 3 +- tests/mosaic/gpu_test.py | 4 +++ tests/pallas/mosaic_gpu_test.py | 4 +++ 4 files changed, 31 insertions(+), 12 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 78e9302a40e2..b9800026590d 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -249,21 +249,31 @@ def _copy_smem_to_gmem_lowering( ) assert not copy_params.get("gmem_transform") if reduction_op is not None: + # TODO(b/415721295): Call mgpu.dialect.async_store after the if, after + # the minimal jaxlib version is 0.8.2. + if not hasattr(mgpu.dialect, "TMAReduction"): + raise NotImplementedError("Reduction op is not supported yet.") reduction_op_attr = getattr( mgpu.dialect.TMAReduction, reduction_op.capitalize() ) + mgpu.dialect.async_store( + src, + dst, + indices, + slice_lengths, + predicate=predicate, + commit_group=commit_group, # type: ignore[call-arg] + reduction_op=reduction_op_attr, + ) else: - reduction_op_attr = None - - mgpu.dialect.async_store( - src, - dst, - indices, - slice_lengths, - predicate=predicate, - commit_group=commit_group, # type: ignore[call-arg] - reduction_op=reduction_op_attr, - ) + mgpu.dialect.async_store( + src, + dst, + indices, + slice_lengths, + predicate=predicate, + commit_group=commit_group, # type: ignore[call-arg] + ) return () diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 0baf0d90ef6c..24ea841e7cfb 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -991,7 +991,8 @@ def _mgpu_async_store_op_lowering_rule( # flatten -> async_copy -> unflatted here, as long as flattened size is a # multiple of 16. - if store_op.reduction_op is not None: + # TODO(b/415721295):Simplify, after the minimal jaxlib version is 0.8.2. + if hasattr(mgpu, "TMAReduction") and store_op.reduction_op is not None: reduction_op = mgpu.TMAReduction(store_op.reduction_op.value).name.lower() else: reduction_op = None diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index dea06535ee1d..98c6e6a76191 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -5348,6 +5348,10 @@ def body( @parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.float16) def test_async_store_add_reduction(self, dtype): + # TODO(b/415721295):Remove after the minimal jaxlib version is 0.8.2. + if not hasattr(mgpu_dialect, "TMAReduction"): + self.skipTest("TMAReduction op is required.") + shape = (8, 128) def body(ctx, src, dst, smem): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 5d25674c94a9..06c6adc023bd 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -654,6 +654,10 @@ def kernel(x_ref, o_ref_gmem, scratch_ref): @parameterized.parameters(jnp.bfloat16, jnp.float16, jnp.float32) def test_copy_smem_to_gmem_reduction(self, dtype): + # TODO(b/415721295):Remove after the minimal jaxlib version is 0.8.2. + if not hasattr(mgpu.dialect, "TMAReduction"): + self.skip_if_wg_semantics() + @functools.partial( self.pallas_call, grid=(200,), From 01513b2786713b64102e0c5b619a6f6f4e4cd945 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Wed, 3 Dec 2025 20:18:31 +0000 Subject: [PATCH 041/315] Fixes for Jetson Thor devices - The non-portable cluster size maximum is 8 (not 16). - Improve the error message when a kernel launch fails due to too-large cluster dimensions. - The default device-side printf buffer size is smaller; rather than resizing it, make one test case emit fewer device-size printf to fit. --- jax/_src/test_util.py | 8 +-- jaxlib/mosaic/gpu/runtime.cc | 94 +++++++++++++++------------------ tests/pallas/mosaic_gpu_test.py | 4 +- 3 files changed, 51 insertions(+), 55 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index f5217b5f6d38..5ebfe0f3821d 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -545,9 +545,11 @@ def test_method_wrapper(self, *args, **kwargs): ) def get_cuda_nonportable_max_cluster_size(): - if device_kind_match("GB10$"): - # 12 is the nonportable maximum cluster size on DGX Spark, - # determined by querying cuOccupancyMaxPotentialClusterSize. + # Per-device nonportable maximum cluster sizes for Jetson Thor and DGX + # Spark (GB10) determined by querying cuOccupancyMaxPotentialClusterSize + if device_kind_match("Thor$"): + return 8 + elif device_kind_match("GB10$"): return 12 # 16 is the nonportable maximum cluster size on: # - Hopper: https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#:~:text=cluster%20size%20of-,16,-by%20opting%20in diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index fed2b2fad804..7c12c8e2748e 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -20,6 +20,18 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "jaxlib/mosaic/gpu/nvshmem.h" +namespace { +template +void abort_on_error(CUresult result, const char* fmt, Args&&... args) { + if (result != CUDA_SUCCESS) { + const char *ptr = nullptr; + cuGetErrorString(result, &ptr); + fprintf(stderr, fmt, std::forward(args)..., ptr); + abort(); + } +} +} + extern "C" { void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, @@ -159,27 +171,18 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, fprintf(stderr, "Unsupported swizzle: %ld\n", swizzle_bytes); abort(); } - CUresult result = cuTensorMapEncodeTiled( + abort_on_error( + cuTensorMapEncodeTiled( tma_desc, data_type, rank, base_addr, tma_sizes, tma_strides, tma_window_shape, element_strides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, - CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuTensorMapEncodeTiled failed: %s\n", ptr); - abort(); - } + CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE), + "cuTensorMapEncodeTiled failed: %s\n"); } void* mosaic_gpu_module_load(void *data) { CUmodule module = nullptr; - if (auto result = cuModuleLoadData(&module, data); result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuModuleLoadData failed: %s\n", ptr); - abort(); - } - + abort_on_error(cuModuleLoadData(&module, data), + "cuModuleLoadData failed: %s\n"); { // Set the NVSHMEM state if it's used by the module. CUdeviceptr ptr = 0; size_t size = 0; @@ -200,41 +203,23 @@ void* mosaic_gpu_module_load(void *data) { void *mosaic_gpu_get_function(CUmodule module, const char *name, int32_t smem_bytes, int32_t cluster_size) { CUfunction function = nullptr; - CUresult result = cuModuleGetFunction(&function, module, name); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, - "Failed to retrieve function pointer to kernel \"%s\", " - "cuModuleGetFunction failed: %s\n", - name, ptr); - abort(); - } + abort_on_error( + cuModuleGetFunction(&function, module, name), + "Failed to retrieve function pointer to kernel \"%s\", " + "cuModuleGetFunction failed: %s\n", name); if (smem_bytes) { - result = cuFuncSetAttribute( - function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, - "Failed to set maximum dynamic shared memory size for kernel " - "\"%s\" to %d bytes, cuFuncSetAttribute failed: %s\n", - name, smem_bytes, ptr); - abort(); - } + abort_on_error( + cuFuncSetAttribute( + function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes), + "Failed to set maximum dynamic shared memory size for kernel \"%s\" " + "to %d bytes, cuFuncSetAttribute failed: %s\n", name, smem_bytes); } if (cluster_size > 8) { - result = cuFuncSetAttribute( - function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, - "Failed to set allowed cluster size for kernel \"%s\" to %d, " - "cuFuncSetAttribute failed: %s\n", - name, cluster_size, ptr); - abort(); - } + abort_on_error( + cuFuncSetAttribute( + function, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1), + "Failed to set allowed cluster size for kernel \"%s\" to %d, " + "cuFuncSetAttribute failed: %s\n", name, cluster_size); } return function; } @@ -270,11 +255,18 @@ void mosaic_gpu_launch_kernel(CUfunction function, uint32_t grid_x, config.numAttrs = 1; } CUresult result = cuLaunchKernelEx(&config, function, params, nullptr); - if (result != CUDA_SUCCESS) { - const char *ptr = nullptr; - cuGetErrorString(result, &ptr); - fprintf(stderr, "cuLaunchKernel failed: %s\n", ptr); + if (result == CUDA_ERROR_INVALID_CLUSTER_SIZE) { + int max_cluster_size; + abort_on_error(cuOccupancyMaxPotentialClusterSize(&max_cluster_size, + function, &config), + "cuOccupancyMaxPotentialClusterSize failed: %s\n"); + fprintf(stderr, + "cuLaunchKernel failed with invalid cluster size (%d, %d, %d)" + ": maximum is %d\n", cluster_x, cluster_y, cluster_z, + max_cluster_size); abort(); + } else { + abort_on_error(result, "cuLaunchKernelEx: %s\n"); } } } diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 06c6adc023bd..d2f6f56baab2 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1204,7 +1204,9 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") def test_print_wgmma_tiled_layout(self): - shape = (128, 64) + # The default printf buffer on some smaller GPUs (e.g. Thor) only has space for + # 4096 threads to printf (short) messages. Keep this shape below that. + shape = (128, 32) size = math.prod(shape) @functools.partial( From 03132bdcd66bf40de1d3b34be4d09480cd5632b5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 4 Dec 2025 06:31:26 -0800 Subject: [PATCH 042/315] Remove //jax:extend bazel target (use //jax/extend/... instead). PiperOrigin-RevId: 840217879 --- jax/BUILD | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index ed8f35c60529..029f569fef86 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -254,19 +254,6 @@ pytype_strict_library( ) # Public JAX libraries below this point. -# TODO(phawkins): remove this target in favor of the finer-grained targets in jax/extend/... -pytype_strict_library( - name = "extend", - visibility = [":jax_extend_users"], - deps = [ - "//jax/extend", - "//jax/extend:backend", - "//jax/extend:core", - "//jax/extend:linear_util", - "//jax/extend:random", - "//jax/extend:source_info_util", - ], -) # Aliases of experimental targets. # TODO(dsuo): remove these aliases/targets. From 1588d16ea9f81d359d428695404f567e134be939 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Thu, 26 Jun 2025 09:32:52 +0000 Subject: [PATCH 043/315] Test dumping HLO when reading compilation cache --- tests/compilation_cache_test.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index f00db7fe1e19..fa7ec7d81c14 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections import Counter +import glob import logging import math import os @@ -648,6 +649,35 @@ def test_persistent_cache_enable_xla_caches(self): self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + @jtu.skip_on_devices("tpu") # TPU backend does not dump on deserialize + def test_dump_on_cache_hit(self): + previous_counts = Counter(_counts) + with ( + config.persistent_cache_min_compile_time_secs(0), + config.persistent_cache_min_entry_size_bytes(0), + tempfile.TemporaryDirectory() as dump_dir1, + tempfile.TemporaryDirectory() as dump_dir2 + ): + jit(lambda x: x + 1, compiler_options={"xla_dump_to": dump_dir1})(1) + self.assertEqual( + _counts["/jax/compilation_cache/cache_hits"], + previous_counts["/jax/compilation_cache/cache_hits"], + ) + jit(lambda x: x + 1, compiler_options={"xla_dump_to": dump_dir2, "xla_dump_hlo_as_proto": True, "xla_dump_hlo_as_text": True})(1) + self.assertEqual( + _counts["/jax/compilation_cache/cache_hits"], + previous_counts["/jax/compilation_cache/cache_hits"] + 1, + 1) + dump1_files = glob.glob(os.path.join(dump_dir1, "*after_optimizations.txt")) + dump2_files = glob.glob(os.path.join(dump_dir2, "*after_optimizations.txt")) + self.assertEqual(len(dump1_files), 1) + self.assertEqual(len(dump2_files), 1) + with (open(dump1_files[0]) as file1, open(dump2_files[0]) as file2): + self.assertEqual(file1.read(), file2.read()) + dump2_pbs = glob.glob(os.path.join(dump_dir2, "*after_optimizations.hlo.pb")) + self.assertEqual(len(dump2_pbs), 1) + + @jtu.with_config( jax_enable_compilation_cache=False, jax_persistent_cache_min_compile_time_secs=0, From 9b31b8ef99ff89e2e8402926273fe918b4120c4f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 4 Dec 2025 07:48:00 -0800 Subject: [PATCH 044/315] [pallas:mosaic] Use `tree_util.register_dataclass` for `BufferedRef` This way we don't need explicit `tree_{flatten,unflatten}` implementations. PiperOrigin-RevId: 840242085 --- jax/_src/pallas/mosaic/pipeline.py | 34 ++++-------------------------- 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 8d5ef57505ee..9bd15e18f17d 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -406,7 +406,7 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase: # TODO(justinfu): Refactor and rename slot fields to reflect cumulative values # instead of slot index. -@tree_util.register_pytree_node_class +@tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class BufferedRef(BufferedRefBase): """A helper class to automate VMEM double buffering in pallas pipelines. @@ -443,9 +443,9 @@ class BufferedRef(BufferedRefBase): swap: Tracks whether the BufferedRef slots need to be swapped before next copy. """ - _spec: pl.BlockSpec # static metadata - dtype: Any # static metadata - _buffer_type: BufferType # static metadata + _spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True)) + dtype: Any = dataclasses.field(metadata=dict(static=True)) + _buffer_type: BufferType = dataclasses.field(metadata=dict(static=True)) window_ref: ArrayRef | None accum_ref: ArrayRef | None copy_in_slot: ArrayRef | None @@ -502,32 +502,6 @@ def buffer_count(self) -> int: raise ValueError("buffer count is undefined") return self.window_ref.shape[0] # type: ignore[union-attr] - def tree_flatten(self): - return ( - ( - self.window_ref, - self.accum_ref, - self.copy_in_slot, - self.wait_in_slot, - self.copy_out_slot, - self.wait_out_slot, - self._copy_in_slot_reg, - self._wait_in_slot_reg, - self._copy_out_slot_reg, - self._wait_out_slot_reg, - self.next_fetch_smem, - self.next_fetch_sreg, - self.sem_recvs, - self.sem_sends, - self.swap, - ), - (self._spec, self.dtype, self._buffer_type), - ) - - @classmethod - def tree_unflatten(cls, meta, data): - return cls(*meta, *data) - @staticmethod def buffer_types() -> type[BufferType]: return BufferType From 6832479f0702c16ae5fbfeb143eb555526b1bf4c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 4 Dec 2025 08:13:56 -0800 Subject: [PATCH 045/315] Add jax_check_static_indices configuration --- jax/_src/config.py | 9 +++++ jax/_src/numpy/indexing.py | 28 ++++++++++++++- tests/lax_numpy_indexing_test.py | 62 ++++++++++++++++++++++++++++---- 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 212902268d78..04c0d30de5a5 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1157,6 +1157,15 @@ def _validate_jax_pjrt_client_create_options(new_val): 'to disable any debuggers while leak checking is enabled.')) checking_leaks = functools.partial(check_tracer_leaks, True) +check_static_indices = bool_state( + name='jax_check_static_indices', + default=False, + help=('Turn on bounds checks for static indices during array indexing operations.' + ' These will only be checked when indexing mode is PROMISE_IN_BOUNDS, which' + ' is the default for gather-type operations.'), + include_in_jit_key=True, + include_in_trace_context=True, +) captured_constants_warn_bytes = int_state( name='jax_captured_constants_warn_bytes', diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 039bed953c20..1e65d9c37e16 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -670,7 +670,7 @@ def static_slice(arr: Array, idx: StaticIndex | tuple[StaticIndex, ...]): for axis, (ind, size) in enumerate(safe_zip(idx, arr.shape)): if isinstance(ind, (int, np.integer)): if not (-size <= ind < size): - raise IndexError(f"index {ind} out of bounds for axis {axis} with size {size}") + raise IndexError(f"index {ind} out of bounds for axis {axis} with size {size}") if ind < 0: ind += size start_indices.append(ind) @@ -701,6 +701,29 @@ def static_slice(arr: Array, idx: StaticIndex | tuple[StaticIndex, ...]): return result +def validate_static_indices( + arr: Array, + idx: Index | tuple[Index, ...], *, + normalize_indices: bool) -> None: + """Perform bounds-checks for static indices. + + Raises an IndexError if any static indices are out-of-bounds. + """ + # TODO(jakevdp): expand_bool_indices is expensive; do this more efficiently. + idx = idx if isinstance(idx, tuple) else (idx,) + idx = _expand_bool_indices(idx, arr.shape) + idx_tup = tuple(i for i in _canonicalize_tuple_index(arr.ndim, idx) + if i is not None and not isinstance(i, bool)) + def norm_index(i, size): + return i + size if normalize_indices and i < 0 else i + if len(idx_tup) != arr.ndim: + raise RuntimeError(f"Error for {idx=} and {arr.shape=}: processed {idx_tup=}") + for axis, (i, size) in enumerate(safe_zip(idx_tup, arr.shape)): + if isinstance(i, (int, np.integer)) and (norm_index(i, size) < 0 or i >= size): + raise IndexError(f"index {i} out of bounds for axis {axis} with size {size}" + f" ({normalize_indices=})") + + class IndexingStrategy(enum.Enum): AUTO = 'auto' GATHER = 'gather' @@ -726,6 +749,9 @@ def rewriting_take( if not isinstance(strategy, IndexingStrategy): raise TypeError(f"Expected strategy to be IndexingStrategy; got {strategy}") + if config.check_static_indices.value and (mode is None or slicing.GatherScatterMode.from_any(mode) == slicing.GatherScatterMode.PROMISE_IN_BOUNDS): + validate_static_indices(arr, idx, normalize_indices=normalize_indices) + if strategy == IndexingStrategy.STATIC_SLICE: if not normalize_indices: raise ValueError("strategy=STATIC_SLICE is only supported when normalize_indices=True.") diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index b1ec454d6ca2..5a1ba1ac870f 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -460,11 +460,11 @@ def test_simple_indexing(self, name, shape, dtype, indexer, strategy): self._CompileAndCheck(jnp_fun, args_maker) @parameterized.parameters( - ((2,), -4, IndexError, "index -4 out of bounds for axis 0 with size 2"), - ((2,), 4, IndexError, "index 4 out of bounds for axis 0 with size 2"), - ((2, 3), np.index_exp[:, 4], IndexError, "index 4 out of bounds for axis 1 with size 3"), - ((2, 3), np.index_exp[..., -4], IndexError, "index -4 out of bounds for axis 1 with size 3"), - ((2, 3, 5), np.index_exp[3, :, 0], IndexError, "index 3 out of bounds for axis 0 with size 2"), + ((2,), -4, IndexError, "index -4 out of bounds for axis 0 with size 2"), + ((2,), 4, IndexError, "index 4 out of bounds for axis 0 with size 2"), + ((2, 3), np.index_exp[:, 4], IndexError, "index 4 out of bounds for axis 1 with size 3"), + ((2, 3), np.index_exp[..., -4], IndexError, "index -4 out of bounds for axis 1 with size 3"), + ((2, 3, 5), np.index_exp[3, :, 0], IndexError, "index 3 out of bounds for axis 0 with size 2"), ((2, 3), ([1, 2], 0), TypeError, "static_slice: indices must be static scalars or slices."), ((2, 3), (np.arange(2), 0), TypeError, "static_slice: indices must be static scalars or slices."), ((2, 3), (None, 0), TypeError, "static_slice: got None at position 0"), @@ -1033,10 +1033,10 @@ def testIndexingEmptyDimension(self): "index .* is out of bounds for axis .* with size 0"): _ = np.ones((2, 0))[0, 0] # The numpy error with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): + "index .* out of bounds for axis .* with size 0"): _ = x[0, 0] # JAX indexing with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): + "index .* out of bounds for axis .* with size 0"): jax.jit(lambda i: x[0, i])(0) # JAX indexing under jit def testBooleanIndexingWithEmptyResult(self): @@ -1849,5 +1849,53 @@ def f(x, i): self.assertArraysEqual(jax.jacrev(f)(x, i), expected) self.assertArraysEqual(jax.jacrev(jax.vmap(f, (None, 0)))(x, i), expected) +@jtu.with_config(jax_check_static_indices=True) +class ValidateIndicesTest(jtu.JaxTestCase): + @parameterized.parameters( + ((2,), -4, IndexError, "index -4 out of bounds for axis 0 with size 2"), + ((2,), 4, IndexError, "index 4 out of bounds for axis 0 with size 2"), + ((2, 3), np.index_exp[:, 4], IndexError, "index 4 out of bounds for axis 1 with size 3"), + ((2, 3), np.index_exp[..., -4], IndexError, "index -4 out of bounds for axis 1 with size 3"), + ((2, 3, 5), np.index_exp[3, :, 0], IndexError, "index 3 out of bounds for axis 0 with size 2"), + ((2, 3, 5), np.index_exp[:5, :, 6], IndexError, "index 6 out of bounds for axis 2 with size 5"), + ((2, 3, 5), np.index_exp[np.arange(3), 6, None], IndexError, "index 6 out of bounds for axis 1 with size 3"), + ((2, 3), (1, 2, 3), IndexError, "Too many indices: 2-dimensional array indexed with 3 regular indices"), + ) + def test_out_of_bound_indices(self, shape, idx, err, msg): + """Test that out-of-bound indexing """ + arr = jnp.zeros(shape) + + with self.subTest("eager"): + with self.assertRaisesRegex(err, msg): + arr[idx] + + with self.subTest("jit"): + with self.assertRaisesRegex(err, msg): + jax.jit(lambda x: x[idx])(arr) + + with self.subTest("arr.at[idx].get()"): + with self.assertRaisesRegex(err, msg): + arr.at[idx].get() + + @jtu.sample_product( + [dict(name=name, shape=shape, indexer=indexer) + for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer, _ in index_specs], + dtype=all_dtypes + ) + def test_simple_indexing(self, name, shape, dtype, indexer): + """Test that in-bound indexing works correctly.""" + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda x: np.asarray(x)[indexer] + jnp_fun = lambda x: jnp.asarray(x)[indexer] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + # Tests x.at[...].get(...) as well. + jnp_fun = lambda x: jnp.asarray(x).at[indexer].get() + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From b82aa387a92d0484600b6b87ca25fdcbffe9d876 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 4 Dec 2025 08:18:24 -0800 Subject: [PATCH 046/315] Fix deprecated use of np.dtype --- jax/_src/scipy/optimize/_lbfgs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index 9c6df2737dae..3f4767f101a7 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -22,6 +22,7 @@ import numpy as np from jax._src import api +from jax._src import dtypes from jax._src import lax from jax._src import numpy as jnp from jax._src.numpy import linalg as jnp_linalg @@ -112,7 +113,7 @@ def _minimize_lbfgs( Optimization results. """ d = len(x0) - dtype = np.dtype(x0) + dtype = dtypes.dtype(x0) # ensure there is at least one termination condition if (maxiter is None) and (maxfun is None) and (maxgrad is None): From 53a2a5c88f2566c2419b5fd4db2500fd501bfd05 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 4 Dec 2025 08:21:22 -0800 Subject: [PATCH 047/315] Change //jax:experimental to be a target that reexports other targets. Change in preparation for removing //jax:experimental. PiperOrigin-RevId: 840253885 --- jax/BUILD | 33 +++++++++++++++++++++++++-------- jax/example_libraries/BUILD | 8 -------- jax/experimental/BUILD | 19 +++---------------- 3 files changed, 28 insertions(+), 32 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 029f569fef86..92a07f72f5e6 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -257,18 +257,35 @@ pytype_strict_library( # Aliases of experimental targets. # TODO(dsuo): remove these aliases/targets. -py_library_providing_imports_info( +pytype_strict_library( name = "experimental", - srcs = [ - "//jax/example_libraries:jax_example_libraries", - "//jax/experimental:jax_experimental", - ], visibility = ["//visibility:public"], - # NOTE: Exclude mosaic_gpu, serialize_executable, and buffer_callback. deps = [ ":jax", - "//jax/_src:buffer_callback", - ] + py_deps("absl/logging") + py_deps("numpy"), + "//jax/example_libraries:optimizers", + "//jax/example_libraries:stax", + "//jax/experimental", + "//jax/experimental:checkify", + "//jax/experimental:compute_on", + "//jax/experimental:custom_dce", + "//jax/experimental:custom_partitioning", + "//jax/experimental:fused", + "//jax/experimental:hijax", + "//jax/experimental:jet", + "//jax/experimental:layout", + "//jax/experimental:mesh_utils", + "//jax/experimental:multihost_utils", + "//jax/experimental:ode", + "//jax/experimental:pjit", + "//jax/experimental:profiler", + "//jax/experimental:rnn", + "//jax/experimental:scheduling_groups", + "//jax/experimental:shard_alike", + "//jax/experimental:shard_map", + "//jax/experimental:topologies", + "//jax/experimental:transfer", + "//jax/experimental:xla_metadata", + ], ) alias( diff --git a/jax/example_libraries/BUILD b/jax/example_libraries/BUILD index 46805163d572..133b066cb053 100644 --- a/jax/example_libraries/BUILD +++ b/jax/example_libraries/BUILD @@ -39,11 +39,3 @@ pytype_strict_library( "//jax/_src:util", ] + py_deps("numpy"), ) - -# TODO(dsuo): Remove this filegroup once experimental aliases from jax/BUILD are -# removed. -filegroup( - name = "jax_example_libraries", - srcs = glob(["*.py"]), - visibility = ["//jax:internal"], -) diff --git a/jax/experimental/BUILD b/jax/experimental/BUILD index 1038238db600..c6ee0e9b05b4 100644 --- a/jax/experimental/BUILD +++ b/jax/experimental/BUILD @@ -584,7 +584,9 @@ pytype_strict_library( # be used in new code. Use jax.shard_map instead. name = "shard_map", srcs = ["shard_map.py"], - visibility = jax_visibility("experimental/shard_map"), + visibility = [ + "//jax:internal", + ] + jax_visibility("experimental/shard_map"), deps = [ "//jax", "//jax/_src:mesh", @@ -740,18 +742,3 @@ filegroup( ], visibility = ["//jax:internal"], ) - -filegroup( - name = "jax_experimental", - srcs = glob( - [ - "*.py", - ], - exclude = [ - "buffer_callback.py", - "mental/mosaic/gpu/*.py", - "serialize_executable.py", - ], - ), - visibility = ["//jax:internal"], -) From 4fe443e367fa4eab83eeb0379dc51b223a8c1d13 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Thu, 4 Dec 2025 08:48:11 -0800 Subject: [PATCH 048/315] [Mosaic GPU] Reduce test size to make it run on GPUs with less SMEM (e.g. RTX Pro 6000 Blackwell) PiperOrigin-RevId: 840262900 --- tests/mosaic/gpu_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 98c6e6a76191..b3d6ce892c93 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -5267,10 +5267,10 @@ def body(ctx, out, scratch): self.assertArraysEqual(kernel(), expected) @parameterized.parameters( - ((4, 64, 128), [[0], [1], [2]], (4, 64, 128), False), - ((4, 64, 128), [[0], [1, 2], [3]], (4, 4, 16, 128), False), - ((4, 8, 16, 128), [[0], [1], [2, 3], [4]], (4, 8, 2, 8, 128), False), - ((4, 64, 128), [[0, 1], [2], [3]], (2, 2, 64, 128), True), + ((4, 64, 64), [[0], [1], [2]], (4, 64, 64), False), + ((4, 64, 64), [[0], [1, 2], [3]], (4, 4, 16, 64), False), + ((4, 8, 16, 64), [[0], [1], [2, 3], [4]], (4, 8, 2, 8, 64), False), + ((4, 64, 64), [[0, 1], [2], [3]], (2, 2, 64, 64), True), ) def test_memref_expand_shape( self, input_shape, reassociation, output_shape, has_transforms From eae3b49bda2a548fa3f8294b7467d5dbe77b10f3 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Thu, 4 Dec 2025 08:55:03 -0800 Subject: [PATCH 049/315] [Mosaic GPU][NFC] Add a test for legacy Mosaic GPU custom calls. PiperOrigin-RevId: 840265546 --- jaxlib/mosaic/gpu/BUILD | 4 +++- jaxlib/mosaic/gpu/custom_call_test.cc | 33 ++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 02338d0b1a08..b7359495f184 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -349,8 +349,10 @@ cc_test( srcs = ["custom_call_test.cc"], tags = ["requires-gpu-sm90"], deps = [ - ":custom_call", + ":mosaic_gpu_support", "//testing/base/public:gunit_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@xla//xla/hlo/builder:xla_computation", diff --git a/jaxlib/mosaic/gpu/custom_call_test.cc b/jaxlib/mosaic/gpu/custom_call_test.cc index 13b38194889a..e4756a394325 100644 --- a/jaxlib/mosaic/gpu/custom_call_test.cc +++ b/jaxlib/mosaic/gpu/custom_call_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include #include +#include #include #include +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/builder/xla_computation.h" @@ -32,6 +35,15 @@ limitations under the License. namespace { +using ::absl_testing::IsOk; + +absl::Status ExecuteSync(xla::PjRtLoadedExecutable* executable) { + std::vector no_buffers; + TF_ASSIGN_OR_RETURN(auto result, + executable->Execute({no_buffers}, /*options=*/{})); + return result[0][0]->GetReadyFuture().Await(); +} + TEST(CustomCallTest, MosaicGpuUsesCommandBuffers) { constexpr absl::string_view kHloModule = R"( HloModule mosaic_gpu_uses_command_buffers @@ -70,7 +82,7 @@ ENTRY main { // Ignore return value. Execution will fail because the custom calls don't // wrap any valid Mosaic code, but we only care that the chosen execution // plan uses a command buffer thunk. - (void)executable->Execute(/*argument_handles=*/{}, /*options=*/{}); + ExecuteSync(executable.get()).IgnoreError(); // Matching the name exactly is vulnerable to renaming changes, and is not // ideal. With that said, this seems like the most reasonable thing to do, and @@ -103,4 +115,23 @@ ENTRY main { EXPECT_THAT(after_contents, testing::StartsWith("000: kCommandBuffer")); } +TEST(CustomCallTest, LegacyCustomCall) { + absl::string_view hlo_string = R"hlo( + HloModule test + + ENTRY main { + ROOT result = s32[] custom-call(), custom_call_target="mosaic_gpu", api_version=API_VERSION_STATUS_RETURNING, backend_config="\220\307\037$\222=c\235\344\250\025\261Y\233.\002\264\260\013\026\305Ol\324\355\315dA-\311\3277\"builtin.module\"() <{sym_name = \"kernel\"}> ({\n \"stable_mosaic_gpu.func.func\"() ({\n }) {function_type = (!llvm.ptr, !llvm.ptr, i64, i64, !llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> (), sym_name = \"mosaic_gpu_init_tma_desc\", sym_visibility = \"private\"} : () -> ()\n \"stable_mosaic_gpu.llvm.mlir.global\"() ({\n }) {addr_space = 4 : i32, global_type = !llvm.array<0 x i8>, linkage = #llvm.linkage, sym_name = \"global_scratch\", unnamed_addr = 0 : i64, visibility_ = 0 : i64} : () -> ()\n \"stable_mosaic_gpu.func.func\"() ({\n ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):\n %0 = \"stable_mosaic_gpu.arith.constant\"() {value = 42 : i32} : () -> i32\n %1 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %2 = \"stable_mosaic_gpu.arith.constant\"() {value = 128 : index} : () -> index\n %3 = \"stable_mosaic_gpu.arith.constant\"() {value = 1 : index} : () -> index\n %4 = \"stable_mosaic_gpu.llvm.mlir.constant\"() {value = 0 : i64} : () -> i64\n %5 = \"stable_mosaic_gpu.llvm.mlir.undef\"() : () -> !llvm.struct<(ptr, ptr, i64)>\n %6 = \"stable_mosaic_gpu.builtin.unrealized_conversion_cast\"(%arg0) : (!llvm.ptr) -> !gpu.async.token\n %7 = \"stable_mosaic_gpu.llvm.load\"(%arg1) {ordering = 0 : i64} : (!llvm.ptr) -> !llvm.ptr\n %8 = \"stable_mosaic_gpu.llvm.insertvalue\"(%5, %7) {position = array} : (!llvm.struct<(ptr, ptr, i64)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64)>\n %9 = \"stable_mosaic_gpu.llvm.insertvalue\"(%8, %7) {position = array} : (!llvm.struct<(ptr, ptr, i64)>, !llvm.ptr) -> !llvm.struct<(ptr, ptr, i64)>\n %10 = \"stable_mosaic_gpu.llvm.insertvalue\"(%9, %4) {position = array} : (!llvm.struct<(ptr, ptr, i64)>, i64) -> !llvm.struct<(ptr, ptr, i64)>\n %11 = \"stable_mosaic_gpu.builtin.unrealized_conversion_cast\"(%10) : (!llvm.struct<(ptr, ptr, i64)>) -> memref\n %12 = \"stable_mosaic_gpu.gpu.launch\"(%6, %3, %3, %3, %2, %3, %3, %1) ({\n ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: index, %arg11: index, %arg12: index, %arg13: index):\n %13 = \"stable_mosaic_gpu.nvvm.elect.sync\"() : () -> i1\n %14 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %15 = \"stable_mosaic_gpu.arith.index_cast\"(%14) : (index) -> i32\n %16 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %17 = \"stable_mosaic_gpu.arith.index_cast\"(%16) : (index) -> i32\n %18 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %19 = \"stable_mosaic_gpu.arith.index_cast\"(%18) : (index) -> i32\n %20 = \"stable_mosaic_gpu.arith.muli\"(%19, %17) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %21 = \"stable_mosaic_gpu.arith.addi\"(%15, %20) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %22 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %23 = \"stable_mosaic_gpu.arith.index_cast\"(%22) : (index) -> i32\n %24 = \"stable_mosaic_gpu.arith.muli\"(%17, %23) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %25 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %26 = \"stable_mosaic_gpu.arith.index_cast\"(%25) : (index) -> i32\n %27 = \"stable_mosaic_gpu.arith.muli\"(%26, %24) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %28 = \"stable_mosaic_gpu.arith.addi\"(%21, %27) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %29 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %30 = \"stable_mosaic_gpu.arith.index_cast\"(%29) : (index) -> i32\n %31 = \"stable_mosaic_gpu.arith.muli\"(%24, %30) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %32 = \"stable_mosaic_gpu.arith.constant\"() {value = 5 : i32} : () -> i32\n %33 = \"stable_mosaic_gpu.arith.shrui\"(%28, %32) : (i32, i32) -> i32\n %34 = \"stable_mosaic_gpu.arith.constant\"() {value = -1 : i32} : () -> i32\n %35 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %36 = \"stable_mosaic_gpu.arith.constant\"() {value = 31 : i32} : () -> i32\n %37 = \"stable_mosaic_gpu.nvvm.shfl.sync\"(%34, %33, %35, %36) {kind = #nvvm} : (i32, i32, i32, i32) -> i32\n %38 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %39 = \"stable_mosaic_gpu.arith.cmpi\"(%37, %38) {predicate = 0 : i64} : (i32, i32) -> i1\n %40 = \"stable_mosaic_gpu.arith.andi\"(%39, %13) : (i1, i1) -> i1\n %41 = \"stable_mosaic_gpu.nvvm.elect.sync\"() : () -> i1\n %42 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %43 = \"stable_mosaic_gpu.arith.index_cast\"(%42) : (index) -> i32\n %44 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %45 = \"stable_mosaic_gpu.arith.index_cast\"(%44) : (index) -> i32\n %46 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %47 = \"stable_mosaic_gpu.arith.index_cast\"(%46) : (index) -> i32\n %48 = \"stable_mosaic_gpu.arith.muli\"(%47, %45) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %49 = \"stable_mosaic_gpu.arith.addi\"(%43, %48) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %50 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %51 = \"stable_mosaic_gpu.arith.index_cast\"(%50) : (index) -> i32\n %52 = \"stable_mosaic_gpu.arith.muli\"(%45, %51) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %53 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %54 = \"stable_mosaic_gpu.arith.index_cast\"(%53) : (index) -> i32\n %55 = \"stable_mosaic_gpu.arith.muli\"(%54, %52) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %56 = \"stable_mosaic_gpu.arith.addi\"(%49, %55) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %57 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %58 = \"stable_mosaic_gpu.arith.index_cast\"(%57) : (index) -> i32\n %59 = \"stable_mosaic_gpu.arith.muli\"(%52, %58) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %60 = \"stable_mosaic_gpu.arith.constant\"() {value = 5 : i32} : () -> i32\n %61 = \"stable_mosaic_gpu.arith.shrui\"(%56, %60) : (i32, i32) -> i32\n %62 = \"stable_mosaic_gpu.arith.constant\"() {value = -1 : i32} : () -> i32\n %63 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %64 = \"stable_mosaic_gpu.arith.constant\"() {value = 31 : i32} : () -> i32\n %65 = \"stable_mosaic_gpu.nvvm.shfl.sync\"(%62, %61, %63, %64) {kind = #nvvm} : (i32, i32, i32, i32) -> i32\n %66 = \"stable_mosaic_gpu.arith.constant\"() {value = 4 : i32} : () -> i32\n %67 = \"stable_mosaic_gpu.arith.remui\"(%65, %66) : (i32, i32) -> i32\n %68 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %69 = \"stable_mosaic_gpu.arith.cmpi\"(%67, %68) {predicate = 0 : i64} : (i32, i32) -> i1\n %70 = \"stable_mosaic_gpu.arith.andi\"(%69, %41) : (i1, i1) -> i1\n %71 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %72 = \"stable_mosaic_gpu.arith.index_cast\"(%71) : (index) -> i32\n %73 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %74 = \"stable_mosaic_gpu.arith.index_cast\"(%73) : (index) -> i32\n %75 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %76 = \"stable_mosaic_gpu.arith.index_cast\"(%75) : (index) -> i32\n %77 = \"stable_mosaic_gpu.arith.muli\"(%76, %74) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %78 = \"stable_mosaic_gpu.arith.addi\"(%72, %77) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %79 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %80 = \"stable_mosaic_gpu.arith.index_cast\"(%79) : (index) -> i32\n %81 = \"stable_mosaic_gpu.arith.muli\"(%74, %80) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %82 = \"stable_mosaic_gpu.gpu.thread_id\"() {dimension = #gpu} : () -> index\n %83 = \"stable_mosaic_gpu.arith.index_cast\"(%82) : (index) -> i32\n %84 = \"stable_mosaic_gpu.arith.muli\"(%83, %81) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %85 = \"stable_mosaic_gpu.arith.addi\"(%78, %84) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %86 = \"stable_mosaic_gpu.gpu.block_dim\"() {dimension = #gpu} : () -> index\n %87 = \"stable_mosaic_gpu.arith.index_cast\"(%86) : (index) -> i32\n %88 = \"stable_mosaic_gpu.arith.muli\"(%81, %87) {overflowFlags = #arith.overflow} : (i32, i32) -> i32\n %89 = \"stable_mosaic_gpu.arith.constant\"() {value = 5 : i32} : () -> i32\n %90 = \"stable_mosaic_gpu.arith.shrui\"(%85, %89) : (i32, i32) -> i32\n %91 = \"stable_mosaic_gpu.arith.constant\"() {value = 0 : i32} : () -> i32\n %92 = \"stable_mosaic_gpu.arith.cmpi\"(%90, %91) {predicate = 0 : i64} : (i32, i32) -> i1\n %93 = \"stable_mosaic_gpu.gpu.dynamic_shared_memory\"() : () -> memref>\n %94 = \"stable_mosaic_gpu.arith.index_cast\"(%1) : (i32) -> index\n %95 = \"stable_mosaic_gpu.memref.view\"(%93, %94) : (memref>, index) -> memref<0xi8, #gpu.address_space>\n %96 = \"stable_mosaic_gpu.builtin.unrealized_conversion_cast\"(%95) {transforms = []} : (memref<0xi8, #gpu.address_space>) -> memref<0xi8, #gpu.address_space>\n \"stable_mosaic_gpu.nvvm.fence.mbarrier.init\"() : () -> ()\n \"stable_mosaic_gpu.gpu.barrier\"() : () -> ()\n \"stable_mosaic_gpu.memref.store\"(%0, %11) : (i32, memref) -> ()\n \"stable_mosaic_gpu.gpu.terminator\"() : () -> ()\n }) {operandSegmentSizes = array, workgroup_attributions = 0 : i64} : (!gpu.async.token, index, index, index, index, index, index, i32) -> !gpu.async.token\n \"stable_mosaic_gpu.func.return\"() : () -> ()\n }) {function_type = (!llvm.ptr, !llvm.ptr) -> (), llvm.emit_c_interface, sym_name = \"kernel_mosaic_gpu\"} : () -> ()\n}) {stable_mosaic_gpu.version = 6 : i64} : () -> ()\n" + } + )hlo"; + ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseAndReturnUnverifiedModule(hlo_string)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); +} + } // namespace From ca4d7884bb2de13df8aae622c568c213005ac7c4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 4 Dec 2025 09:47:11 -0800 Subject: [PATCH 050/315] Make in_avals and out_avals match as well as we can with in_shardings and out_shardings respectively during lowering. PiperOrigin-RevId: 840285373 --- jax/_src/array.py | 1 - jax/_src/core.py | 13 +++--- jax/_src/interpreters/pxla.py | 12 ++++-- jax/_src/stages.py | 27 +++++++++++- tests/pjit_test.py | 81 +++++++++++++++++++++++++++++++---- 5 files changed, 113 insertions(+), 21 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 9c7dc1738343..5a1c2834feda 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1284,7 +1284,6 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): def _array_global_result_handler(global_aval, out_sharding, committed): - global_aval = core.update_aval_with_sharding(global_aval, out_sharding) if global_aval.dtype == dtypes.float0: return lambda _: np.zeros(global_aval.shape, dtypes.float0) if dtypes.issubdtype(global_aval.dtype, dtypes.extended): diff --git a/jax/_src/core.py b/jax/_src/core.py index 6da26eb1e2a5..9aa32fe6e7b5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1759,15 +1759,12 @@ def mem_space_to_kind(mem_space: MemorySpace) -> str: @cache(max_size=4096, trace_context_in_key=lambda: config.remove_size_one_mesh_axis_from_type.value) def update_aval_with_sharding(aval, sharding, vma=None): - if vma is None: - vma = aval.vma if isinstance(sharding, NamedSharding): - return aval.update( - sharding=NamedSharding( - sharding.mesh.abstract_mesh, - sharding.spec._normalized_spec_for_aval(aval.ndim)), - vma=vma, memory_space=mem_kind_to_space(sharding.memory_kind)) - return aval.update(vma=vma) + s = NamedSharding(sharding.mesh.abstract_mesh, + sharding.spec._normalized_spec_for_aval(aval.ndim)) + return aval.update(sharding=s, vma=aval.vma if vma is None else vma, + memory_space=mem_kind_to_space(sharding.memory_kind)) + return aval if vma is None else aval.update(vma=vma) # We have three flavors of abstractification APIs here which each used to have diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f12a4b518d0b..228229ac40d2 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1886,7 +1886,7 @@ def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr: class SemanticallyEqualShardings: def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], - avals: tuple[core.AbstractValue]): + avals: Sequence[core.AbstractValue]): gspmd_shardings = [ s if (isinstance(s, (UnspecifiedValue, AUTO)) or (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh))) @@ -1894,7 +1894,6 @@ def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], for s, a in zip(shardings, avals)] self._gspmd_shardings = gspmd_shardings self.shardings = shardings - self.avals = avals def __hash__(self): return hash(tuple( @@ -2374,7 +2373,14 @@ def lower_sharding_computation( out_shardings, global_out_avals, device_assignment, propagated_out_mem_kinds) - # 2. Build up the HLO + global_in_avals = [core.update_aval_with_sharding(a, sh) + if isinstance(a, core.ShapedArray) else a + for a, sh in zip(global_in_avals, in_shardings)] + global_out_avals = [core.update_aval_with_sharding(a, sh) + if isinstance(a, core.ShapedArray) else a + for a, sh in zip(global_out_avals, out_shardings)] + + ############################ Build up the stableHLO ###################### abstract_mesh = None if prim_requires_devices: diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 55475a0f34dd..013d35d5ad7c 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -378,7 +378,7 @@ def _traced_args_info(self): def _traced_out_info(self): out_shardings = [None if isinstance(s, UnspecifiedValue) else s - for s in self._params['out_shardings']] + for s in self._params['out_shardings']] out_layouts = [None if isinstance(l, AutoLayout) else l for l in self._params['out_layouts']] out = [] @@ -549,6 +549,20 @@ def __init__(self, lowering: Lowering, args_info, self._in_types = in_types # type: ignore self._out_types = out_types # type: ignore + @property + def in_avals(self): + in_avals_ = self._lowering.compile_args.get("global_in_avals", None) + if in_avals_ is None: # For old pmap code i.e. PmapComputation + return tree_util.tree_map(lambda x: x._aval, self.args_info) + kept_var_idx = self._lowering.compile_args["kept_var_idx"] + non_dce_avals = self._lowering.compile_args["all_args_info"].in_avals + if self.in_tree.num_leaves > len(in_avals_): + iter_in_avals = iter(in_avals_) + in_avals_ = [ + next(iter_in_avals) if i in kept_var_idx + else a for i, a in zip(range(self.in_tree.num_leaves), non_dce_avals)] + return self.in_tree.unflatten(in_avals_) + @property def out_info(self): # PyTree of OutInfo out_avals = self._lowering.compile_args["global_out_avals"] @@ -713,6 +727,17 @@ def memory_analysis(self) -> Any | None: except NotImplementedError: return None + @property + def in_avals(self): + in_avals_ = self._executable.in_avals + if self.in_tree.num_leaves > len(in_avals_): + iter_in_avals = iter(in_avals_) + non_dce_avals = self._executable._all_args_info.in_avals + in_avals_ = [ + next(iter_in_avals) if i in self._executable._kept_var_idx + else a for i, a in zip(range(self.in_tree.num_leaves), non_dce_avals)] + return self.in_tree.unflatten(in_avals_) + @property def out_info(self): # PyTree of jax.ShapeDtypeStruct out_avals = self._executable.out_avals diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 94daf7c9cc44..8fd2253d47ab 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -840,14 +840,13 @@ def f(x): @jtu.with_mesh([('x', 2), ('y', 2)]) def testLowerCompile(self): - @partial(pjit, - in_shardings=P(('x', 'y'),), - out_shardings=P(('x', 'y'),)) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + x = jnp.arange(64).reshape(8, 8) + + @partial(pjit, in_shardings=P(('x', 'y')), out_shardings=P(('x', 'y'))) def f(x, y): return x @ y - shape = (8, 8) - x = jnp.arange(math.prod(shape)).reshape(shape) expected = x @ (x + 1) lowered = f.lower(x, x + 1) @@ -855,9 +854,11 @@ def f(x, y): actual = compiled(x, x + 1) self.assertEqual(lowered.in_avals, compiled.in_avals) - self.assertEqual( - lowered.in_avals, - ((core.ShapedArray(x.shape, x.dtype, weak_type=False),) * 2, {})) + + abs_mesh = mesh.abstract_mesh + exp_aval = core.ShapedArray(x.shape, x.dtype, + sharding=NamedSharding(abs_mesh, P())) + self.assertEqual(lowered.in_avals, ((exp_aval,) * 2, {})) splits = np.split(expected, 4) self.assertAllClose(np.asarray(actual.addressable_shards[0].data), splits[0], @@ -9744,6 +9745,70 @@ def g(params, inputs): self.assertEqual(out.sharding, NamedSharding(mesh, P(None, None, unreduced={'x'}))) + @jtu.with_explicit_mesh((2,), 'x') + def test_out_aval_matches_out_sharding(self, mesh): + arr = jnp.arange(8) + + @jax.jit(out_shardings=P('x')) + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertEqual(out.aval.sharding, NamedSharding(mesh.abstract_mesh, P('x'))) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_out_aval_matches_out_sharding_override(self, mesh): + arr = jax.device_put(jnp.arange(8).reshape(4, 2), P('x', None)) + + @jax.jit(out_shardings=P('x', 'y')) + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + self.assertEqual(out.aval.sharding, + NamedSharding(mesh.abstract_mesh, P('x', 'y'))) + + @jax.jit + def g(x): + return x * 2 + + out = g(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out.aval.sharding, + NamedSharding(mesh.abstract_mesh, P('x', None))) + + @jtu.with_explicit_mesh((2,), 'x', axis_types=(AxisType.Auto,)) + def test_out_aval_auto_mode(self, mesh): + arr = jax.device_put(jnp.arange(8).reshape(4, 2), P('x')) + + @jax.jit + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + + @jtu.with_explicit_mesh((2,), 'x') + def test_in_aval_matches_in_sharding(self, mesh): + arr = np.arange(8) + + @jax.jit(in_shardings=P('x'), out_shardings=P('x')) + def f(x): + return x * 2 + + lowered = f.lower(arr) + l_in_aval = lowered.in_avals[0][0] + self.assertEqual(l_in_aval.sharding, + NamedSharding(mesh.abstract_mesh, P('x'))) + + compiled = lowered.compile() + c_in_aval = compiled.in_avals[0][0] + self.assertEqual(c_in_aval.sharding, + NamedSharding(mesh.abstract_mesh, P('x'))) + compiled(arr) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 7f31fa3177dfa56cbc31e7e534aa829366405b6f Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 4 Dec 2025 10:26:19 -0800 Subject: [PATCH 051/315] PR #33714: always apply stop_gradient_p, even to exact dtypes Imported from GitHub PR https://github.com/jax-ml/jax/pull/33714 fixes #33689 Copybara import of the project: -- 25ffbc808b5e05da80e6916e80b4efcd33137874 by Matthew Johnson : always apply stop_gradient_p, even to exact dtypes fixes #33689 Merging this change closes #33714 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/33714 from mattjj:issue33689 25ffbc808b5e05da80e6916e80b4efcd33137874 PiperOrigin-RevId: 840301807 --- jax/_src/lax/lax.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e06a842b733c..ba7f7998dbd3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3546,17 +3546,12 @@ def stop_gradient(x: T) -> T: the applicability of ``stop_gradient``. """ def stop(x): - # only bind primitive on inexact dtypes, to avoid some staging if dtypes.issubdtype(core.get_aval(x).dtype, dtypes.extended): return x - elif (dtypes.issubdtype(_dtype(x), np.floating) or - dtypes.issubdtype(_dtype(x), np.complexfloating)): - # break abstractions to support legacy leaked tracer use cases - if isinstance(x, ad.JVPTracer): - return stop(x.primal) - return ad_util.stop_gradient_p.bind(x) + elif isinstance(x, ad.JVPTracer): + return stop(x.primal) else: - return x + return ad_util.stop_gradient_p.bind(x) return tree_util.tree_map(stop, x) def reduce_precision(operand: float | ArrayLike, From a8bd216378b77370f4472e63d4a3fc84c8514dab Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 4 Dec 2025 10:28:29 -0800 Subject: [PATCH 052/315] [Pallas] Hardcode einsum flops estimate. PiperOrigin-RevId: 840302687 --- tests/pallas/pallas_cost_estimate_test.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/pallas/pallas_cost_estimate_test.py b/tests/pallas/pallas_cost_estimate_test.py index c9a68844a185..6b0b39ed15dc 100644 --- a/tests/pallas/pallas_cost_estimate_test.py +++ b/tests/pallas/pallas_cost_estimate_test.py @@ -62,21 +62,18 @@ def matmul(a, b): self.assertEqual(cost.bytes_accessed, 4*(b*m*k + b*n*k + b*m*n)) @parameterized.parameters( - ((10, 11, 12), (11, 12), "abc,bc->a"), - ((10, 11, 12), (13, 11, 12), "abc,dbc->ad"), - ((10, 11, 12), (9, 10, 11, 12), "abc,dabc->d"), + ((10, 11, 12), (11, 12), "abc,bc->a", 2640), + ((10, 11, 12), (13, 11, 12), "abc,dbc->ad", 34320), + ((10, 11, 12), (9, 10, 11, 12), "abc,dabc->d", 23760), ) - def test_einsum(self, a_shape, b_shape, pattern): - a = jnp.ones(a_shape, dtype=jnp.float32) - b = jnp.ones(b_shape, dtype=jnp.float32) + def test_einsum(self, a_shape, b_shape, pattern, expected_flops): def matmul(a, b): return jnp.einsum(pattern, a, b) cost = cost_estimate.estimate_cost( matmul, jax.ShapeDtypeStruct(a_shape, jnp.float32), jax.ShapeDtypeStruct(b_shape, jnp.float32)) - xla_flops = jax.jit(matmul).lower(a, b).compile().cost_analysis()['flops'] - self.assertEqual(cost.flops, int(xla_flops)) + self.assertEqual(cost.flops, expected_flops) def test_attention(self): qk_dim = 16 From 42eafcb7e41b5ee2e367858f97209e6b2861a930 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 4 Dec 2025 01:42:47 +0000 Subject: [PATCH 053/315] add test for #33714 --- tests/lax_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8e30931145cf..2ddb296df13e 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3766,6 +3766,22 @@ def g(x): cts = vjp_fn(jnp.ones((8,), dtype=jnp.float8_e4m3fn)) # Don't crash self.assertEqual(cts[0].dtype, jnp.bfloat16) + def test_stop_gradient_on_ints(self): + # https://github.com/jax-ml/jax/issues/33689 + @jax.custom_gradient + def f(x): + def fbwd(g): + return jnp.ones_like(x) + return (x, jnp.round(x).astype(jnp.int32)), fbwd + + def loss(x): + y, i = f(x) + y_nograd, i_nograd = jax.lax.stop_gradient((y, i)) + self.assertEqual(type(y_nograd), type(i_nograd)) + return jnp.sum(f(y)[0]) + + jax.grad(loss)(jnp.ones((3,))) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): From 3ba9f282c7523e78e8a9eaabcf625f8266161687 Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 4 Dec 2025 10:47:14 -0800 Subject: [PATCH 054/315] [Pallas] Skip interpret tests on accelerators PiperOrigin-RevId: 840310278 --- jax/_src/pallas/BUILD | 12 ++++ jax/_src/pallas/pallas_test_util.py | 55 +++++++++++++++++++ tests/pallas/BUILD | 18 ++++-- tests/pallas/ops_test.py | 20 ++----- tests/pallas/pallas_test.py | 46 +++++----------- tests/pallas/tpu_ops_test.py | 18 +----- tests/pallas/tpu_pallas_call_print_test.py | 15 +---- tests/pallas/tpu_pallas_test.py | 47 ++++++---------- ...pu_splash_attention_kernel_sharded_test.py | 17 +++--- .../tpu_splash_attention_kernel_test.py | 12 ++-- 10 files changed, 135 insertions(+), 125 deletions(-) create mode 100644 jax/_src/pallas/pallas_test_util.py diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 2207af4e9dc2..074002f3c935 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -66,3 +66,15 @@ py_library( "//jax/_src/lib", ] + py_deps("numpy"), ) + +py_library( + name = "pallas_test_util", + testonly = True, + srcs = [ + "pallas_test_util.py", + ], + deps = [ + ":pallas", + "//jax/_src:test_util", + ], +) diff --git a/jax/_src/pallas/pallas_test_util.py b/jax/_src/pallas/pallas_test_util.py new file mode 100644 index 000000000000..621ca70b72bd --- /dev/null +++ b/jax/_src/pallas/pallas_test_util.py @@ -0,0 +1,55 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pallas test utilities.""" +import sys + +from jax._src import test_util as jtu +from jax._src.pallas import pallas_call +from jax.experimental import pallas as pl + +use_mosaic_gpu = pallas_call._PALLAS_USE_MOSAIC_GPU.value + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasTest(jtu.JaxTestCase): + INTERPRET: bool = False + + def setUp(self): + if not jtu.test_device_matches(['cpu']) and self.INTERPRET: + self.skipTest('Only run interpret tests on CPU.') + if not self.INTERPRET: + # Running on accelerator + if jtu.test_device_matches(["cpu"]): + self.skipTest("On CPU the test works only in interpret mode") + if (jtu.test_device_matches(["cuda"]) and + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on GPU with capability >= sm80") + if (jtu.test_device_matches(["cuda"]) and use_mosaic_gpu and + not jtu.is_cuda_compute_capability_at_least("9.0")): + self.skipTest("Mosaic GPU requires capability >= sm90") + if sys.platform == "win32": + self.skipTest("Only works on non-Windows platforms") + super().setUp() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +class PallasTPUTest(PallasTest): + """A test case that only runs on TPUs or in interpret mode on CPU.""" + + def setUp(self): + if not jtu.test_device_matches(['tpu']) and not self.INTERPRET: + self.skipTest('Test requires TPUs') + super().setUp() diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index a32ea7dace43..d058ea22746c 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -61,6 +61,7 @@ jax_multiplatform_test( "//jax:pallas_gpu_ops", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/_src/pallas:pallas_test_util", ] + py_deps([ "absl/testing", "numpy", @@ -143,6 +144,7 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/_src/pallas:pallas_test_util", ] + py_deps([ "absl/testing:flagsaver", "absl/testing", @@ -187,6 +189,7 @@ jax_multiplatform_test( "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_mosaic_gpu", # build_cleaner: keep "//jax:pallas_tpu", + "//jax/_src/pallas:pallas_test_util", ] + py_deps([ "absl/testing:flagsaver", "absl/testing", @@ -418,7 +421,10 @@ jax_multiplatform_test( jax_multiplatform_test( name = "tpu_pallas_test", srcs = ["tpu_pallas_test.py"], - enable_backends = ["tpu"], + enable_backends = [ + "tpu", + "cpu", + ], enable_configs = [ "tpu_v5e", "tpu_v5p", @@ -428,6 +434,7 @@ jax_multiplatform_test( "//jax:mesh_utils", "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/_src/pallas:pallas_test_util", "//jax/extend", ] + py_deps([ "absl/testing", @@ -455,6 +462,7 @@ jax_multiplatform_test( deps = [ "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/_src/pallas:pallas_test_util", "//jax/extend", ] + py_deps([ "absl/testing", @@ -499,16 +507,14 @@ jax_multiplatform_test( srcs = [ "tpu_ops_test.py", ], - enable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["tpu"], shard_count = 8, deps = [ "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/_src/pallas:pallas_test_util", ] + py_deps([ "absl/testing", "hypothesis", @@ -757,6 +763,7 @@ jax_multiplatform_test( ], deps = [ "//jax:pallas_tpu_ops", + "//jax/_src/pallas:pallas_test_util", ] + py_deps([ "absl/testing", "numpy", @@ -775,6 +782,7 @@ jax_multiplatform_test( deps = [ "//jax:pallas_tpu", "//jax:pallas_tpu_ops", + "//jax/_src/pallas:pallas_test_util", "//jax/extend", ] + py_deps([ "absl/testing", diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 45fb75e11fac..d157c0b6c388 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -33,6 +33,7 @@ from jax._src import test_util as jtu from jax._src.pallas import pallas_call from jax._src.pallas import primitives as pallas_primitives +from jax._src.pallas import pallas_test_util as ptu from jax.experimental import pallas as pl from jax.interpreters import partial_eval as pe import jax.numpy as jnp @@ -274,21 +275,7 @@ def select_n_strategy( ] -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if not self.INTERPRET: - if jtu.device_under_test() == "cpu": - self.skipTest("Only interpret mode supported on CPU") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPUs with capability >= sm80") - if (jtu.test_device_matches(["cuda"]) and use_mosaic_gpu and - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Mosaic GPU requires capability >= sm90") - - super().setUp() +class PallasBaseTest(ptu.PallasTest): @classmethod def pallas_call(cls, *args, **kwargs): @@ -715,6 +702,9 @@ def test_cast_from_sub_32bit(self, from_dtype, to_dtype, randomize): or from_dtype in {"int2", "uint2"} ): self.skipTest("sub-byte casts are buggy on GPU") # b/391292861 + if self.INTERPRET and (to_dtype in {"int2", "uint2"} or + from_dtype in {"int2", "uint2"}): + self.skipTest("Test fails on CPU.") if from_dtype == "float16" or to_dtype == "float16" and not sut_is_mosaic_gpu: self.skipTest("float16 is only supported with Mosaic GPU") if sut_is_mosaic_gpu: diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 8b4de6a8f42e..689800eb1f81 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -33,6 +33,7 @@ from jax._src import dtypes from jax._src import hijax from jax._src import test_util as jtu +from jax._src.pallas import pallas_test_util as ptu from jax.experimental import pallas as pl import jax.export import jax.numpy as jnp @@ -92,26 +93,7 @@ def body(i, acc): return matmul_kernel(x, y) -@jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: - self.skipTest("On CPU the test works only in interpret mode") - if (jtu.test_device_matches(["cuda"]) and - not jtu.is_cuda_compute_capability_at_least("8.0")): - self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: - self.skipTest("Only works on non-Windows platforms") - - super().setUp() - - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - - -class PallasCallTest(PallasBaseTest): +class PallasCallTest(ptu.PallasTest): def test_add_one(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: @@ -757,7 +739,7 @@ class PallasCallInterpretTest(PallasCallTest): INTERPRET = True -class PallasCallElementIndexingTest(PallasBaseTest): +class PallasCallElementIndexingTest(ptu.PallasTest): def test_block_spec_element(self): def show_program_ids( @@ -893,7 +875,7 @@ class PallasCallElementIndexingInterpretTest(PallasCallElementIndexingTest): INTERPRET = True -class PallasCallBoundedSliceIndexingTest(PallasBaseTest): +class PallasCallBoundedSliceIndexingTest(ptu.PallasTest): def setUp(self): super().setUp() @@ -922,7 +904,7 @@ def kernel(x_ref, o_ref): ), )(x) -class ApiErrorTest(PallasBaseTest): +class ApiErrorTest(ptu.PallasTest): def test_pallas_call_kernel_args_mismatch(self): a = np.arange(256, dtype=np.int32) f = self.pallas_call(lambda x_ref: None, # Missing o_ref @@ -1174,7 +1156,7 @@ class ApiErrorInterpretTest(ApiErrorTest): INTERPRET = True -class PallasCallInputOutputAliasingTest(PallasBaseTest): +class PallasCallInputOutputAliasingTest(ptu.PallasTest): def test_vector_input_output_aliasing(self): # Input needs to be big so it doesn't fit in VMEM @@ -1285,11 +1267,11 @@ def f(x_scalar_in, x_vector_in): print(x_vector) -class PallasCallInputOutputAliasingInterpretTest(PallasBaseTest): +class PallasCallInputOutputAliasingInterpretTest(ptu.PallasTest): INTERPRET = True -class PallasControlFlowTest(PallasBaseTest): +class PallasControlFlowTest(ptu.PallasTest): def setUp(self): super().setUp() @@ -2078,7 +2060,7 @@ class PallasControlFlowInterpretTest(PallasControlFlowTest): ] -class PallasCallAutodifferentiationTest(PallasBaseTest): +class PallasCallAutodifferentiationTest(ptu.PallasTest): def setUp(self): super().setUp() @@ -2193,7 +2175,7 @@ class PallasCallAutodifferentiationInterpretTest(PallasCallAutodifferentiationTe INTERPRET = True -class PallasOutOfBoundsInterpretTest(PallasBaseTest): +class PallasOutOfBoundsInterpretTest(ptu.PallasTest): INTERPRET = True def test_interpret_mode_out_of_bounds_access(self): @@ -2273,7 +2255,7 @@ def _(): np.testing.assert_allclose(out, expected, atol=atol, rtol=rtol) -class PallasCheckifyTest(PallasBaseTest): +class PallasCheckifyTest(ptu.PallasTest): INTERPRET = False def test_basic_runtime_assert(self): @@ -2453,7 +2435,7 @@ class PallasCheckifyInterpretTest(PallasCheckifyTest): INTERPRET = True -class PallasCallNamedGridTest(PallasBaseTest): +class PallasCallNamedGridTest(ptu.PallasTest): def test_named_grid(self): def kernel(x_ref, y_ref): @@ -2553,7 +2535,7 @@ def kernel(x_ref, y_ref): ) -class SymbolicPallasTest(PallasBaseTest): +class SymbolicPallasTest(ptu.PallasTest): def test_simple_symbolic_matmul_export(self): if jtu.test_device_matches(["gpu"]): @@ -2747,7 +2729,7 @@ def index_to_lojax(xt: jax.Ref) -> jax.Array: index_p.to_lojax = index_to_lojax -class PallasHiJaxTest(PallasBaseTest): +class PallasHiJaxTest(ptu.PallasTest): def test_pass_weird_tuple_into_pallas_call(self): diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index aa61e0f65a23..7ed8c5c261a6 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -22,6 +22,7 @@ from jax import lax from jax._src import dtypes from jax._src import test_util as jtu +from jax._src.pallas import pallas_test_util as ptu from jax.experimental import pallas as pl import jax.numpy as jnp import numpy as np @@ -76,22 +77,8 @@ def rand( raise NotImplementedError(f"Unsupported random data generation for {dtype=}") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Test only supported on TPU.") - - super().setUp() - - @classmethod - def pallas_call(cls, *args, **kwargs): - return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) - - @jtu.thread_unsafe_test_class(condition=not jtu.hypothesis_is_thread_safe()) -class OpsTest(PallasBaseTest): +class OpsTest(ptu.PallasTPUTest): @parameterized.product( from_dtype=_JAX_DTYPES, @@ -883,5 +870,6 @@ def kernel(x_ref, o_ref): np.testing.assert_array_equal(result, expected) + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/tpu_pallas_call_print_test.py b/tests/pallas/tpu_pallas_call_print_test.py index eb9a410da811..6dba1178b536 100644 --- a/tests/pallas/tpu_pallas_call_print_test.py +++ b/tests/pallas/tpu_pallas_call_print_test.py @@ -20,6 +20,7 @@ from absl.testing import parameterized import jax from jax._src import test_util as jtu +from jax._src.pallas import pallas_test_util as ptu from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp @@ -32,20 +33,8 @@ partial = functools.partial -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET: bool = False - - def setUp(self): - if not jtu.test_device_matches(['tpu']) and not self.INTERPRET: - self.skipTest('Test requires TPUs, or interpret mode') - super().setUp() - - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - - @jtu.thread_unsafe_test_class() # debug print test is not thread safe -class PallasCallPrintTest(PallasBaseTest): +class PallasCallPrintTest(ptu.PallasTPUTest): def test_debug_print(self): @functools.partial( diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index e3b9369213e2..67a71ab452d4 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -34,6 +34,7 @@ from jax._src import state from jax._src import test_util as jtu from jax._src.interpreters import partial_eval as pe +from jax._src.pallas import pallas_test_util as ptu from jax._src.pallas.mosaic import error_handling from jax._src.state import discharge as state_discharge from jax._src.state import utils as state_utils @@ -90,19 +91,7 @@ def wrap_init(f: Callable, nr_args: int): debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {})) -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET: bool = False - - def setUp(self): - if not jtu.test_device_matches(['tpu']) and not self.INTERPRET: - self.skipTest('Test requires TPUs, or interpret mode') - super().setUp() - - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - - -class TPUPipelineModeTest(PallasBaseTest): +class TPUPipelineModeTest(ptu.PallasTPUTest): @parameterized.parameters( (pl.Buffered(2), pl.Buffered(2)), @@ -154,7 +143,7 @@ def vadd(x, y): np.testing.assert_allclose(z, x + y) -class PallasCallScalarPrefetchTest(PallasBaseTest): +class PallasCallScalarPrefetchTest(ptu.PallasTPUTest): def test_trivial_scalar_prefetch(self): def body(_, x_ref, o_ref): o_ref[...] = x_ref[...] @@ -585,7 +574,7 @@ class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): INTERPRET: bool = True -class PallasCallDynamicGridTest(PallasBaseTest): +class PallasCallDynamicGridTest(ptu.PallasTPUTest): def test_can_query_grid_statically_via_num_programs(self): @@ -835,7 +824,7 @@ class PallasCallDynamicGridInterpretTest(PallasCallDynamicGridTest): INTERPRET = True -class PallasCallDMATest(PallasBaseTest): +class PallasCallDMATest(ptu.PallasTPUTest): def setUp(self): super().setUp() @@ -1876,7 +1865,7 @@ def test_kernel(o_ref, np.testing.assert_array_equal(results, expected) -class PallasCallTest(PallasBaseTest): +class PallasCallTest(ptu.PallasTPUTest): @parameterized.parameters([ dict(shape=shape, dty=dty) @@ -2702,7 +2691,7 @@ def kernel(condlist, choicelist, out_ref): np.testing.assert_array_equal(z, jnp.where(condlist, choicelist, 0)) -class PallasScalarIOpsTest(PallasBaseTest): +class PallasScalarIOpsTest(ptu.PallasTPUTest): @staticmethod def parameterized_integer_types(func): @@ -2823,7 +2812,7 @@ def kernel(x_ref, y_ref): self._integer_ops_canonicalization_helper(kernel, 1 ^ 2, dtype) -class PallasUXTest(PallasBaseTest): +class PallasUXTest(ptu.PallasTPUTest): def test_mlir_location(self): # Make sure that MLIR locations are correctly propagated to primitives. @@ -2841,7 +2830,7 @@ def capture_as_tpu_kernel(module, *args, **kwargs): mosaic.as_tpu_kernel = as_tpu_kernel -class PallasMegacoreTest(PallasBaseTest): +class PallasMegacoreTest(ptu.PallasTPUTest): def test_megacore_splitting(self): # We want to make sure a 3-sized dimension is split across megacore @@ -2877,7 +2866,7 @@ def _(): ) -class PallasCallVmapTest(PallasBaseTest): +class PallasCallVmapTest(ptu.PallasTPUTest): def test_scratch_input_vmap(self): """Test that vmapp-ing a kernel with scratch inputs works correctly.""" @@ -2914,7 +2903,7 @@ def add_one_with_scratch(x_ref, o_ref, scratch_ref): np.testing.assert_array_equal(out, out_ref, strict=True) -class PallasCallDynamicDMATest(PallasBaseTest): +class PallasCallDynamicDMATest(ptu.PallasTPUTest): def setUp(self): super().setUp() @@ -2978,7 +2967,7 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): np.testing.assert_array_equal(out, expected) -class PallasCallRefTransformTest(PallasBaseTest): +class PallasCallRefTransformTest(ptu.PallasTPUTest): @parameterized.product(slice_first=[True, False]) def test_dma_bitcasted_ref(self, slice_first): @@ -3151,7 +3140,7 @@ def kernel(x_ref, y_ref): np.testing.assert_array_equal(y, x[8:16, :128]) -class PallasCallTraceTest(PallasBaseTest): +class PallasCallTraceTest(ptu.PallasTPUTest): @jtu.thread_unsafe_test() # stdout redirection is not thread safe def test_trace_start_stop_match(self): @@ -3201,7 +3190,7 @@ def scope2(): self.assertEqual(num_stop, 2) -class PallasCallTPUBooleanTest(PallasBaseTest): +class PallasCallTPUBooleanTest(ptu.PallasTPUTest): """Tests for loading/storing from bool memrefs on TPUs. We specifically test bools because they have special handling. @@ -3330,7 +3319,7 @@ class PallasCallTPUBooleanInterpretTest(PallasCallTPUBooleanTest): INTERPRET: bool = True -class PallasCallTPUCheckifyTest(PallasBaseTest): +class PallasCallTPUCheckifyTest(ptu.PallasTPUTest): @parameterized.parameters((2,), (5,), (6,), (7,)) def test_checkify_with_scalar_prefetch(self, threshold): def body(scalar_ref, x_ref, o_ref): @@ -3434,7 +3423,7 @@ class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest): INTERPRET: bool = True -class PrettyPrintingTest(PallasBaseTest): +class PrettyPrintingTest(ptu.PallasTPUTest): @parameterized.parameters( ( @@ -3480,7 +3469,7 @@ def inner(x_ref, sem): self.assertIn(expected, jaxpr.pretty_print(use_color=False)) -class MiscellaneousTest(PallasBaseTest): +class MiscellaneousTest(ptu.PallasTPUTest): """Tests for reported bugs. Only pass in interpret mode unless fixed.""" def test_casting_bool_to_i8(self): @@ -4162,7 +4151,7 @@ def _(y): np.testing.assert_array_equal(result, np.ones((1,), dtype=jnp.float32)) -class PallasKernelMetadataTest(PallasBaseTest): +class PallasKernelMetadataTest(ptu.PallasTPUTest): @parameterized.parameters( (dict(foo='bar'),), diff --git a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py index 9598e46653f3..4b5b9ee19a48 100644 --- a/tests/pallas/tpu_splash_attention_kernel_sharded_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_sharded_test.py @@ -20,11 +20,16 @@ import jax from jax import random from jax._src import test_util as jtu +from jax._src.pallas import pallas_test_util as ptu +from jax._src.shard_map import shard_map +from jax.experimental.pallas.ops.tpu.splash_attention import ( + CausalMask, + MultiHeadMask, + SegmentIds, + make_splash_mha, +) from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib -from jax.experimental.pallas.ops.tpu.splash_attention import ( - CausalMask, MultiHeadMask, SegmentIds, make_splash_mha) -from jax._src.shard_map import shard_map import jax.numpy as jnp from jax.sharding import PartitionSpec import numpy as np @@ -35,14 +40,10 @@ @jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False +class PallasBaseTest(ptu.PallasTPUTest): def setUp(self): super().setUp() - if not jtu.is_device_tpu(): - self.skipTest("Test requires TPU.") - if len(jax.devices()) < 4: self.skipTest("This test requires at least 4 devices.") diff --git a/tests/pallas/tpu_splash_attention_kernel_test.py b/tests/pallas/tpu_splash_attention_kernel_test.py index 178042a87ba9..0abf7f8c9c81 100644 --- a/tests/pallas/tpu_splash_attention_kernel_test.py +++ b/tests/pallas/tpu_splash_attention_kernel_test.py @@ -22,9 +22,12 @@ from absl.testing import absltest from absl.testing import parameterized +import hypothesis as hp +import hypothesis.strategies as hps import jax from jax import random from jax._src import test_util as jtu +from jax._src.pallas import pallas_test_util as ptu from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask_info import process_mask @@ -32,10 +35,6 @@ import numpy as np -import hypothesis as hp -import hypothesis.strategies as hps - - jax.config.parse_flags_with_absl() jtu.setup_hypothesis(max_examples=5) @@ -304,13 +303,10 @@ def attn_logits_soft_cap_strategy() -> hps.SearchStrategy[float | None]: @jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False +class PallasBaseTest(ptu.PallasTPUTest): def setUp(self): if not self.INTERPRET: - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only interpret mode supported on non-TPU") # TODO(b/327487669): selectively re-enable tests that works on TPU v3. if not jtu.is_device_tpu_at_least(4): self.skipTest("Not supported on TPU generations <= 3") From 766f189919ecaa086b1899185c19e60452df7c07 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 4 Dec 2025 11:11:16 -0800 Subject: [PATCH 055/315] [test] make signature error test robust to NumPy version --- tests/api_test.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/api_test.py b/tests/api_test.py index 8dde8b47d6c2..1d9bd6339d2d 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -416,12 +416,20 @@ def f(args_list): # Jit and Donate arguments def test_donate_argnames_signature_fail(self): + class NoSignature: + @property + def __signature__(self): + raise TypeError("no signature") + def __call__(self, *args, **kwargs): + return None + fun = NoSignature() + inp = np.arange(4) with self.assertRaisesRegex( ValueError, "Getting the signature of function.*failed. Pass donate_argnums " "instead of donate_argnames."): - jax.jit(np.dot, donate_argnames='a')(inp, inp) + jax.jit(fun, donate_argnames='a')(inp, inp) @parameterized.named_parameters( ("argnums", "donate_argnums", (0, 1)), From f9e9e1da31708d1523b7ee069f7c3ece058ac7bb Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Thu, 4 Dec 2025 12:06:32 -0800 Subject: [PATCH 056/315] Add `example_libraries/__init__.py` to jax wheel sources as a follow-up after the cleanup in https://github.com/jax-ml/jax/pull/33723. PiperOrigin-RevId: 840344608 --- BUILD.bazel | 1 + jax/example_libraries/BUILD | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/BUILD.bazel b/BUILD.bazel index ce175425f435..284f54495095 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -32,6 +32,7 @@ wheel_sources( "//jax", "//jax:compilation_cache", "//jax:experimental", + "//jax/example_libraries:example_libraries", "//jax:experimental_colocated_python", "//jax:experimental_sparse", "//jax:experimental_buffer_callback", diff --git a/jax/example_libraries/BUILD b/jax/example_libraries/BUILD index 133b066cb053..e740757c32a1 100644 --- a/jax/example_libraries/BUILD +++ b/jax/example_libraries/BUILD @@ -19,6 +19,14 @@ package( default_visibility = ["//jax:internal"], ) +pytype_strict_library( + name = "example_libraries", + srcs = [ + "__init__.py", + ], + visibility = ["//jax:internal"], +) + pytype_strict_library( name = "stax", srcs = [ From 334ef7ef91d0a105cd9570f26d040464460460a5 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Thu, 4 Dec 2025 13:05:54 -0800 Subject: [PATCH 057/315] Add test case to debug_print scalar in Pallas. PiperOrigin-RevId: 840367743 --- tests/pallas/tpu_pallas_call_print_test.py | 23 +++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/pallas/tpu_pallas_call_print_test.py b/tests/pallas/tpu_pallas_call_print_test.py index 6dba1178b536..5978a886576f 100644 --- a/tests/pallas/tpu_pallas_call_print_test.py +++ b/tests/pallas/tpu_pallas_call_print_test.py @@ -78,17 +78,23 @@ def kernel(x_ref, o_ref): jax.block_until_ready(compiled_kernel(x)) self.assertIn('It works!', get_output()) - def test_debug_print_with_values(self): + @parameterized.product(dtype=[jnp.int32, jnp.float32]) + def test_debug_print_with_values(self, dtype): @functools.partial( self.pallas_call, in_specs=(pl.BlockSpec(memory_space=pltpu.SMEM),), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), ) def kernel(x_ref, o_ref): - pl.debug_print('BEGIN1 x[0] == {}', x_ref[0]) - pl.debug_print('BEGIN2 x[0] == {} ; x[1] == {} ; END', x_ref[0], x_ref[1]) - - x = jnp.array([42, 24]).astype(jnp.int32) + if dtype == jnp.int32: + pl.debug_print('BEGIN1 x[0] == {}', x_ref[0]) + pl.debug_print( + 'BEGIN2 x[0] == {} ; x[1] == {} ; END', x_ref[0], x_ref[1] + ) + else: + pl.debug_print('BEGIN1 x[0] == ', x_ref[0]) + + x = jnp.array([42, 24], dtype=dtype) compiled_kernel = ( jax.jit(kernel) .lower(x) @@ -97,8 +103,11 @@ def kernel(x_ref, o_ref): with jtu.capture_stderr() as get_output: jax.block_until_ready(compiled_kernel(x)) output = get_output() - self.assertIn('BEGIN1 x[0] == 42', output) - self.assertIn('BEGIN2 x[0] == 42 ; x[1] == 24 ; END', output) + if dtype == jnp.int32: + self.assertIn('BEGIN1 x[0] == 42', output) + self.assertIn('BEGIN2 x[0] == 42 ; x[1] == 24 ; END', output) + else: + self.assertIn('BEGIN1 x[0] == f32[] 42', output) @parameterized.named_parameters( (f"{'_'.join(map(str, shape))}_{dtype.__name__}", shape, dtype) From 2f7c9073f4978373e04289c91b7fdccf0f8c06c7 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Thu, 4 Dec 2025 14:47:08 -0800 Subject: [PATCH 058/315] Change PjRt to use new copy of coordination service. PiperOrigin-RevId: 840411324 --- jaxlib/BUILD | 1 + jaxlib/jax.cc | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index e9cf8b41fc48..bfcff630cc89 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -397,6 +397,7 @@ nanobind_pywrap_extension( "@xla//xla/pjrt/distributed:key_value_store_interface", "@xla//xla/pjrt/distributed:protocol_proto_cc", "@xla//xla/pjrt/distributed:service", + "@xla//xla/pjrt/distributed/preemption:preemption_sync_manager", "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/python:logging", diff --git a/jaxlib/jax.cc b/jaxlib/jax.cc index 1571ddb29bb9..c75cd8dde6b4 100644 --- a/jaxlib/jax.cc +++ b/jaxlib/jax.cc @@ -116,6 +116,7 @@ limitations under the License. #include "xla/hlo/builder/lib/approx_topk_shape.h" #include "xla/pjrt/c_api_client/pjrt_c_api_client.h" #include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/distributed/preemption/preemption_sync_manager.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_client.h" @@ -580,6 +581,30 @@ NB_MODULE(_jax, m) { aux::RegisterTransferServerTypes(m); #endif // defined(__linux__) +#if JAX_IFRT_VERSION_NUMBER >= 38 + nb::class_ preemption_sync_manager( + m, "PreemptionSyncManager"); + preemption_sync_manager + .def( + "initialize", + [](xla::PreemptionSyncManager& manager, + xla::DistributedRuntimeClient* client) { + xla::CoordinationServiceAgent* agent = + xla::ValueOrThrow(client->GetCoordinationServiceAgent()); + xla::ThrowIfError(manager.Initialize(agent)); + }, + nb::arg("distributed_client")) + .def("reached_sync_point", + [](xla::PreemptionSyncManager& manager, int step_counter) { + return manager.ReachedSyncPoint(step_counter); + }) + .def("shutdown", [](xla::PreemptionSyncManager& manager) { + nb::gil_scoped_release gil_release; + manager.Shutdown(); + }); + m.def("create_preemption_sync_manager", + []() { return xla::CreatePreemptionSyncManager(); }); +#else nb::class_ preemption_sync_manager( m, "PreemptionSyncManager"); preemption_sync_manager @@ -602,6 +627,7 @@ NB_MODULE(_jax, m) { }); m.def("create_preemption_sync_manager", []() { return tsl::CreatePreemptionSyncManager(); }); +#endif nb::class_ distributed_runtime_service( m, "DistributedRuntimeService"); @@ -898,7 +924,6 @@ NB_MODULE(_jax, m) { nb::class_( m, "TransferServerInterfaceFactory"); - m.def("is_asan", IsAsan); m.def("is_msan", IsMsan); m.def("is_tsan", IsTsan); From 2e548e5bd72b824c0b5027e99c02a2a31fbc07b6 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 4 Dec 2025 15:43:02 -0800 Subject: [PATCH 059/315] AOT: remove explicit lojax stage (`Fallen`) and extend `Traced` instead We can simply hang lojax information off of a `Traced` instance (as `Traced.lojax`), without making it a `Stage` of its own. From an external point of view, it seems OK for our various jaxpr levels to remain bundled under the tracing stage of AOT. This change preserves some degree of explicit public control over when the work of producing lojax is carried out, by computing it lazily. Note that we had no tests exercising `fall()` or `Fallen` anyway, suggesting that we have been treating these as internal symbols rather than a tested public API. PiperOrigin-RevId: 840433294 --- jax/_src/stages.py | 76 ++++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 013d35d5ad7c..810f76cd8de8 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -402,8 +402,12 @@ class Traced(Stage): A traced computation is ready for lowering. This class carries the traced representation with the remaining information needed to later lower, compile, and execute it. + + Provides access to both the hijax (high-level) and lojax (low-level) + representations via `.jaxpr` and `.lojax` properties respectively. """ - __slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree', '_consts'] + __slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree', '_consts', + '_lojax'] def __init__(self, meta_tys_flat, params, in_tree, out_tree, consts): self._meta_tys_flat = meta_tys_flat @@ -411,6 +415,7 @@ def __init__(self, meta_tys_flat, params, in_tree, out_tree, consts): self._in_tree = in_tree self.out_tree = out_tree self._consts = consts + self._lojax = None jaxpr = property(lambda self: self._params['jaxpr']) fun_name = property(lambda self: self._params['name']) @@ -422,12 +427,18 @@ def __init__(self, meta_tys_flat, params, in_tree, out_tree, consts): def out_avals(self): return tree_unflatten(self.out_tree, self.jaxpr.out_avals) - def fall(self): + @property + def lojax(self) -> LoJax: + if self._lojax is not None: + return self._lojax + if not self.jaxpr.is_high: - return Fallen(self._meta_tys_flat, self._params, self._in_tree, - self.out_tree, (self._in_tree, self.jaxpr.in_avals), - (self.out_tree, self.jaxpr.out_avals), - self._consts) + self._lojax = LoJax( + self._meta_tys_flat, self._params, self._in_tree, self.out_tree, + (self._in_tree, self.jaxpr.in_avals), + (self.out_tree, self.jaxpr.out_avals), + self._consts) + return self._lojax # TODO(mattjj): when pmap is deleted, merge with pjit.py BUILD rule from jax._src.interpreters import partial_eval as pe # type:ignore @@ -448,16 +459,32 @@ def fall(self): for mty, aq in zip(self._meta_tys_flat, hi_jaxpr.in_aval_qdds) for lo_ty in (mty.aval.lo_ty_qdd(aq.qdd) if mty.aval.has_qdd else mty.aval.lo_ty())] - return Fallen(lo_meta_tys, params, in_tree, out_tree, - (self._in_tree, hi_jaxpr.final_aval_qdds), - (self.out_tree, hi_jaxpr.out_avals), - self._consts) + self._lojax = LoJax( + lo_meta_tys, params, in_tree, out_tree, + (self._in_tree, hi_jaxpr.final_aval_qdds), + (self.out_tree, hi_jaxpr.out_avals), + self._consts) + return self._lojax def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, _private_parameters: mlir.LoweringParameters | None = None): """Lower to compiler input, returning a ``Lowered`` instance.""" - return self.fall().lower(lowering_platforms=lowering_platforms, - _private_parameters=_private_parameters) + lo = self.lojax + if _private_parameters is None: + _private_parameters = mlir.LoweringParameters() + try: + from jax._src.pjit import _resolve_and_lower # type: ignore + lowering = _resolve_and_lower( + lo._meta_tys_flat, **lo._params, lowering_platforms=lowering_platforms, + lowering_parameters=_private_parameters, pgle_profiler=None) + except DeviceAssignmentMismatchError as e: + fails, = e.args + msg = _device_assignment_mismatch_error( + lo._params['name'], fails, lo._meta_tys_flat, 'jit', + lo.jaxpr.debug_info.safe_arg_names(len(lo.jaxpr.in_avals))) + raise ValueError(msg) from None + return Lowered(lowering, lo.args_info, lo.out_tree, + in_types=lo._in_types, out_types=lo._out_types) def lojax_expand_params(jaxpr, params): @@ -474,8 +501,7 @@ def lojax_pytree(hi_avals, tree): return tree_structure(tree_unflatten(tree, lo_avals)) -class Fallen(Stage): - """True leader of the Decepticons.""" +class LoJax: __slots__ = ['_meta_tys_flat', '_params', '_in_tree', 'out_tree', '_consts', '_in_types', '_out_types'] @@ -495,28 +521,6 @@ def __init__(self, meta_tys_flat, params, in_tree, out_tree, in_types, out_types out_info = property(_traced_out_info) _num_consts = property(lambda self: len(self._consts)) - @property - def out_avals(self): - return tree_unflatten(self.out_tree, self.jaxpr.out_avals) - - def lower(self, *, lowering_platforms: tuple[str, ...] | None = None, - _private_parameters: mlir.LoweringParameters | None = None): - """Lower to compiler input, returning a ``Lowered`` instance.""" - if _private_parameters is None: - _private_parameters = mlir.LoweringParameters() - try: - from jax._src.pjit import _resolve_and_lower # type: ignore - lowering = _resolve_and_lower( - self._meta_tys_flat, **self._params, lowering_platforms=lowering_platforms, - lowering_parameters=_private_parameters, pgle_profiler=None) - except DeviceAssignmentMismatchError as e: - fails, = e.args - msg = _device_assignment_mismatch_error( - self._params['name'], fails, self._meta_tys_flat, 'jit', - self.jaxpr.debug_info.safe_arg_names(len(self.jaxpr.in_avals))) - raise ValueError(msg) from None - return Lowered(lowering, self.args_info, self.out_tree, - in_types=self._in_types, out_types=self._out_types) class Lowered(Stage): From bf7661017b322b668de7855f7f63494c0f9ab140 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Thu, 4 Dec 2025 15:52:34 -0800 Subject: [PATCH 060/315] Add `pallas_test_util.py` to the test runfiles as a follow-up after https://github.com/jax-ml/jax/pull/33704. PiperOrigin-RevId: 840436497 --- BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/BUILD.bazel b/BUILD.bazel index 284f54495095..c9bcd69227da 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -124,6 +124,7 @@ genrule( "//jax/experimental/mosaic/gpu/examples:flash_attention.py", "//jax/experimental/mosaic/gpu/examples:matmul.py", "//jax/_src:test_multiprocess", + "//jax/_src/pallas:pallas_test_util", ], outs = ["wheel_additives.zip"], cmd = "$(location @bazel_tools//tools/zip:zipper) c $@ $(SRCS)", From 36cd74a1938b02a0084d7fb4175dabd6427e7651 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 4 Dec 2025 16:25:19 -0800 Subject: [PATCH 061/315] Rename if_cloud_tpu_at_least to is_cloud_tpu_at_least PiperOrigin-RevId: 840447696 --- jax/_src/test_util.py | 2 +- tests/lax_test.py | 2 +- tests/pallas/indexing_test.py | 3 +- tests/pallas/ops_test.py | 6 +- tests/pallas/pallas_test.py | 4 +- tests/pallas/tpu_ops_test.py | 16 +++--- tests/pallas/tpu_pallas_test.py | 64 +++++++++++----------- tests/pallas/tpu_side_effects_test.py | 2 +- tests/pallas/tpu_sparsecore_pallas_test.py | 18 +++--- tests/pjit_test.py | 2 +- 10 files changed, 59 insertions(+), 60 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 5ebfe0f3821d..366bec41611b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -431,7 +431,7 @@ def is_sanitized(): # built at least `date``. # TODO(b/327203806): after libtpu adds a XLA version and the oldest support # libtpu contains the XLA version, remove using built time to skip tests. -def if_cloud_tpu_at_least(year: int, month: int, day: int): +def is_cloud_tpu_at_least(year: int, month: int, day: int): date = datetime.date(year, month, day) if not is_cloud_tpu(): return True diff --git a/tests/lax_test.py b/tests/lax_test.py index 2ddb296df13e..8f1aa9304718 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -5025,7 +5025,7 @@ def test_ragged_dot_use_ragged_dot_instruction(self, use_instruction): {"m": 10, "k": 9, "n": 8, "num_groups": 2}, ) def test_ragged_dot_small_m(self, m, k, n, num_groups): - if not jtu.if_cloud_tpu_at_least(2025, 10, 14): + if not jtu.is_cloud_tpu_at_least(2025, 10, 14): self.skipTest("Requires libtpu built after 2025-10-14") lhs_shape = (m, k) rhs_shape = (num_groups, k, n) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 3739f6003a30..59aaa0dc4dbc 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -543,7 +543,7 @@ def body(x_ref, y_ref1, y_ref2): @hp.given(hps.data()) def test_load_and_broadcast_with_stride_0(self, data): - if not jtu.if_cloud_tpu_at_least(2025, 11, 25): + if not jtu.is_cloud_tpu_at_least(2025, 11, 25): self.skipTest("Requires libtpu built after 2025-11-25") if self.INTERPRET: self.skipTest("TODO: fails in interpret mode.") @@ -580,7 +580,6 @@ def body(x_ref, y_ref): expected = jnp.broadcast_to(x[slices], shape) self.assertAllClose(y, expected) - def test_load_with_dynamic_2nd_minor_index(self): if pltpu is None: self.skipTest("No TPU module available.") diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index d157c0b6c388..93ea67c563e3 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1462,7 +1462,7 @@ def test_dot_general_multiple_non_contracting_dims( if jtu.test_device_matches(["gpu"]): self.skipTest("TPU only test") - if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least( + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( 2025, 10, 5 ): self.skipTest("Requires libtpu built after 2025-10-05") @@ -1522,7 +1522,7 @@ def test_dot_general_non_front_batch_dims(self, shapes_and_dims_numbers): if jtu.test_device_matches(["gpu"]): self.skipTest("TPU only test") - if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least( + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( 2025, 11, 30 ): self.skipTest("Requires libtpu built after 2025-11-30") @@ -1585,7 +1585,7 @@ def test_dot_general_multiple_non_contracting_dims_with_transposes( if jtu.test_device_matches(["gpu"]): self.skipTest("TPU only test") - if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least( + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( 2025, 10, 5 ): self.skipTest("Requires libtpu built after 2025-10-05") diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 689800eb1f81..4400a5b7ac14 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1188,7 +1188,7 @@ def f(x): self.assertEqual(mem_analysis.temp_size_in_bytes, 0) def test_scalar_input_output_aliasing(self): - if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least( + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( 2025, 10, 7 ): self.skipTest("Requires libtpu built after 2025-10-07") @@ -1221,7 +1221,7 @@ def f(x_in): print(x) def test_mixed_scalar_vector_input_output_aliasing(self): - if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least( + if jtu.test_device_matches(["tpu"]) and not jtu.is_cloud_tpu_at_least( 2025, 10, 7 ): self.skipTest("Requires libtpu built after 2025-10-07") diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 7ed8c5c261a6..f4cc095f921f 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -194,7 +194,7 @@ def body(x_ref, o_ref): ) def test_sum_of_two_matmuls(self): - if not jtu.if_cloud_tpu_at_least(2025, 11, 15): + if not jtu.is_cloud_tpu_at_least(2025, 11, 15): self.skipTest("Test requires libtpu from 2025/11/15 or later") if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Test requires TPUv5+") @@ -349,7 +349,7 @@ def kernel(x, out): keepdims=[False, True], ) def test_reduce_index(self, axis, in_shape, reduce_func, keepdims): - if not keepdims and not jtu.if_cloud_tpu_at_least(2025, 11, 24): + if not keepdims and not jtu.is_cloud_tpu_at_least(2025, 11, 24): self.skipTest("Requires libtpu built after 2025-11-24") dtype = jnp.float32 rank = len(in_shape) @@ -389,7 +389,7 @@ def kernel(x, out): dtype=[jnp.float32, jnp.bfloat16], ) def test_i1_relayout_bw(self, shape, msk_dtype, dtype): - if shape[0] < 8 and not jtu.if_cloud_tpu_at_least(2025, 11, 9): + if shape[0] < 8 and not jtu.is_cloud_tpu_at_least(2025, 11, 9): self.skipTest("Requires libtpu built after 2025-11-09") msk_bitwidth = dtypes.itemsize_bits(msk_dtype) bitwidth = dtypes.itemsize_bits(dtype) @@ -424,7 +424,7 @@ def kernel(x_ref, mask_ref, o_ref): ) def test_i1_relayout_bw_tiling(self, msk_dtype, dtype): self.skipTest("TODO: jevinjiang - Enable once presubmits pass.") - if not jtu.if_cloud_tpu_at_least(2025, 10, 7): + if not jtu.is_cloud_tpu_at_least(2025, 10, 7): self.skipTest("Requires libtpu built after 2025-10-07") shape = (256, 256) bitwidth = dtypes.itemsize_bits(dtype) @@ -708,7 +708,7 @@ def else_0(): self.assertEqual(output, 0) def test_retiling_with_replicated_lane(self): - if not jtu.if_cloud_tpu_at_least(2025, 11, 5): + if not jtu.is_cloud_tpu_at_least(2025, 11, 5): self.skipTest("Test requires libtpu from 2025/11/5 or later") shape = (32, 1) broadcast_shape = (32, 256) @@ -733,7 +733,7 @@ def kernel(x_ref, o_ref): def test_stochastic_round(self, target_dtype): if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Requires TPU v5+") - if not jtu.if_cloud_tpu_at_least(2025, 10, 29): + if not jtu.is_cloud_tpu_at_least(2025, 10, 29): self.skipTest("Test requires libtpu from 2025/10/29 or later") def kernel(x_ref, b_ref, o_ref): @@ -807,7 +807,7 @@ def test_pack_elementwise(self, config, shape): unpacked_dtype, packed_dtype = config if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Requires TPU v5+") - if not jtu.if_cloud_tpu_at_least(2025, 11, 7): + if not jtu.is_cloud_tpu_at_least(2025, 11, 7): self.skipTest("Test requires libtpu from 2025/11/7 or later") bitwidth = dtypes.itemsize_bits(packed_dtype) @@ -842,7 +842,7 @@ def test_unpack_elementwise(self, config, index, shape): unpacked_dtype, packed_dtype = config if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Requires TPU v5+") - if not jtu.if_cloud_tpu_at_least(2025, 11, 7): + if not jtu.is_cloud_tpu_at_least(2025, 11, 7): self.skipTest("Test requires libtpu from 2025/11/7 or later") bitwidth = dtypes.itemsize_bits(packed_dtype) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 67a71ab452d4..7923bfa27f82 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1901,7 +1901,7 @@ def reduce(): reduce_value = jnp.sum(jnp.full(shape, x), dtype=dty) np.testing.assert_allclose(z, reduce_value) - if not jtu.if_cloud_tpu_at_least(2025, 10, 12): + if not jtu.is_cloud_tpu_at_least(2025, 10, 12): self.skipTest( 'New CompilerParams shape_invariant_numerics was added on Oct 12,' ' 2025' @@ -2012,7 +2012,7 @@ def reduce(x): expected = reduce_func(dilated_x, axis=reduced_dims).reshape(red_shape) np.testing.assert_allclose(y, expected) - if not jtu.if_cloud_tpu_at_least(2025, 10, 12): + if not jtu.is_cloud_tpu_at_least(2025, 10, 12): self.skipTest( 'New CompilerParams shape_invariant_numerics was added on Oct 12,' ' 2025' @@ -2142,7 +2142,7 @@ def kernel(x_ref, y_ref): pl.Buffered(2), ]) def test_vmem_oom_error_message_basics(self, pmode: pl.Buffered): - if not jtu.if_cloud_tpu_at_least(2025, 11, 12): + if not jtu.is_cloud_tpu_at_least(2025, 11, 12): self.skipTest('Support added on Nov 12, 2025') if jtu.is_device_tpu(version=5, variant='e') or jtu.is_device_tpu( @@ -2192,7 +2192,7 @@ def index_map(i, j): f' full shape is f32[{shape[0]},{shape[1]}].', error_message, ) - if jtu.if_cloud_tpu_at_least(2025, 11, 5): + if jtu.is_cloud_tpu_at_least(2025, 11, 5): self.assertIn( 'This allocation is single buffered.' if pmode.buffer_count == 1 @@ -2205,7 +2205,7 @@ def test_vmem_oom_error_message_dynamic_grid_scalar_prefetch_and_vmem_scratch( ): if jax.device_count() > 1: self.skipTest("Test only works with a single device.") - if not jtu.if_cloud_tpu_at_least(2025, 10, 14): + if not jtu.is_cloud_tpu_at_least(2025, 10, 14): self.skipTest('Support added on Oct 14, 2025') def body(s_ref, x_hbm_ref, o_hbm_ref, vmem_scratch_ref): @@ -2256,7 +2256,7 @@ def run(num_grid, s, x): def test_automatic_single_buffering(self,): if self.INTERPRET: self.skipTest('OOM tests need us to compile the kernels') - if not jtu.if_cloud_tpu_at_least(2025, 11, 12): + if not jtu.is_cloud_tpu_at_least(2025, 11, 12): self.skipTest('Support added on Oct 14, 2025') def body(*_): @@ -2550,7 +2550,7 @@ def test_scalar_integer_addition(self, dtype): def kernel(x_ref, y_ref): y_ref[0] = x_ref[0] + x_ref[0] - if not jtu.if_cloud_tpu_at_least(2025, 9, 13): + if not jtu.is_cloud_tpu_at_least(2025, 9, 13): self.skipTest('Scalar integer addition support was added on Sep 13, 2025') x = jnp.asarray([3], dtype=dtype) @@ -2589,7 +2589,7 @@ def test_vector_integer_addition(self, dtype): def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] + x_ref[...] - if not jtu.if_cloud_tpu_at_least(2025, 9, 15): + if not jtu.is_cloud_tpu_at_least(2025, 9, 15): self.skipTest('Descriptive message was added on Sep 15, 2025') x = jnp.full((128, 16), 7, dtype=dtype) @@ -2623,7 +2623,7 @@ def test_max_operation(self, dtype): def kernel(x_ref, y_ref): y_ref[0] = jnp.maximum(x_ref[0], x_ref[1]) - if not jtu.if_cloud_tpu_at_least(2025, 9, 20): + if not jtu.is_cloud_tpu_at_least(2025, 9, 20): self.skipTest('Support added on Sep 20, 2025') x = jnp.asarray([242, 87], dtype=dtype) @@ -2645,7 +2645,7 @@ def test_min_operation(self, dtype): def kernel(x_ref, y_ref): y_ref[0] = jnp.minimum(x_ref[0], x_ref[1]) - if not jtu.if_cloud_tpu_at_least(2025, 9, 20): + if not jtu.is_cloud_tpu_at_least(2025, 9, 20): self.skipTest('Support added on Sep 20, 2025') x = jnp.asarray([242, 87], dtype=dtype) @@ -2674,7 +2674,7 @@ def test_bool_select_operation(self, dtype): def kernel(condlist, choicelist, out_ref): out_ref[...] = jnp.where(condlist[...], choicelist[...], 0) - if not jtu.if_cloud_tpu_at_least(2025, 10, 15): + if not jtu.is_cloud_tpu_at_least(2025, 10, 15): self.skipTest('Support added on Oct 15, 2025') if dtype in [jnp.int4, jnp.uint4] and not jtu.is_device_tpu_at_least(4): @@ -2714,7 +2714,7 @@ def wrapper(*args, **kwargs): def _integer_ops_canonicalization_helper(self, kernel, result, dtype): """For integer scalar ops, only i1 and i32 are supported.""" - if not jtu.if_cloud_tpu_at_least(2025, 9, 27): + if not jtu.is_cloud_tpu_at_least(2025, 9, 27): self.skipTest('Error message was changed on Sep 27, 2025') x = jnp.arange(3, dtype=dtype) @@ -3475,7 +3475,7 @@ class MiscellaneousTest(ptu.PallasTPUTest): def test_casting_bool_to_i8(self): if not jtu.is_device_tpu_at_least(5): self.skipTest("Operation not supported on this TPU version.") - if not jtu.if_cloud_tpu_at_least(2025, 9, 12): + if not jtu.is_cloud_tpu_at_least(2025, 9, 12): self.skipTest("Needs a newer libtpu") def greater_than(x: jax.Array, y: jax.Array): @@ -3518,7 +3518,7 @@ def kernel(x_ref, y_ref, out_ref): np.testing.assert_array_equal(out, np.stack([x, y], axis=1)) def test_lane_to_chunk_reshape_bf16(self): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if not jtu.is_device_tpu_at_least(4): self.skipTest('Operation not supported on this TPU version.') @@ -3591,7 +3591,7 @@ def test_roll_partial_with_static_shift( self, shape: tuple[int, int], shift: int, axis: int ): if ( - not jtu.if_cloud_tpu_at_least(2025, 7, 19) + not jtu.is_cloud_tpu_at_least(2025, 7, 19) and shape[0] % 8 and axis == 0 ): @@ -3627,7 +3627,7 @@ def kernel(x_ref, out_ref): )(x) def test_retiling1(self): - if not jtu.if_cloud_tpu_at_least(2025, 7, 2): + if not jtu.is_cloud_tpu_at_least(2025, 7, 2): self.skipTest('Needs a newer libtpu') x = np.arange(1024, dtype=jnp.bfloat16).reshape(1024) @@ -3657,7 +3657,7 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x[:, 7, :], (1, 8, 128))) def test_sublane_adding_shape_cast_f32(self): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') x = np.arange(8 * 128, dtype=jnp.float32).reshape(8, 128) @@ -3671,7 +3671,7 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128))) def test_sublane_adding_shape_cast_bf16(self): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if not jtu.is_device_tpu_at_least(4): self.skipTest('Operation not supported on this TPU version.') @@ -3715,7 +3715,7 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.zeros((8, 2, 128), dtype=jnp.float32)) def test_transpose(self): - if not jtu.if_cloud_tpu_at_least(2025, 9, 19): + if not jtu.is_cloud_tpu_at_least(2025, 9, 19): self.skipTest('Needs a newer libTPU') x = np.zeros((8, 2, 8, 128), dtype=jnp.float32) @@ -3746,7 +3746,7 @@ def kernel(x_ref, out_ref): ) ) def test_reshape_two_minor_dims_to_R2(self, q, m, n, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -3782,7 +3782,7 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_two_minor_dims_to_R3(self, q, m, n, k, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -3815,7 +3815,7 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_four_minor_dims_to_R2(self, p, q, m, n, k, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -3845,7 +3845,7 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_two_minor_dims_preserve_rank(self, q, m, n, k, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -3887,7 +3887,7 @@ def kernel(x_ref, y_ref): def test_reshape_fold_two_leading_dims_and_two_minor_dims_R4_to_R2( self, q, m, n, k, dtype ): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -3920,7 +3920,7 @@ def kernel(x_ref, y_ref): def test_reshape_unfold_leading_dim_and_fold_two_minor_dims_R3_to_R3( self, q, m, n, k, dtype ): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -3955,7 +3955,7 @@ def kernel(x_ref, y_ref): def test_reshape_unfold_leading_and_minor_dims_R2_to_R4( self, q, m, n, k, dtype ): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -3986,7 +3986,7 @@ def kernel(x_ref, y_ref): def test_reshape_fold_leading_dims_and_unfold_minor_dim( self, q, m, n, k, dtype ): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -4015,7 +4015,7 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_fold_middle_dims(self, q, m, n, k, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -4044,7 +4044,7 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_unfold_middle_dims(self, q, m, n, k, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -4062,7 +4062,7 @@ def kernel(x_ref, y_ref): @parameterized.parameters([jnp.int8, jnp.bfloat16, jnp.float32]) def test_reshape_shift_factor_from_minor_to_major(self, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 7, 12): + if not jtu.is_cloud_tpu_at_least(2025, 7, 12): self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) @@ -4084,7 +4084,7 @@ def kernel(x_ref, y_ref): dtype=[jnp.float32, jnp.bfloat16, jnp.float8_e4m3fn], ) def test_reshape_fold_minormost_dim(self, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 10, 22): + if not jtu.is_cloud_tpu_at_least(2025, 10, 22): self.skipTest('Needs a newer libTPU') packing = 32 // (8 * np.dtype(dtype).itemsize) @@ -4105,7 +4105,7 @@ def kernel(x_ref, y_ref): def test_dynamic_grid_with_smem_output(self): if self.INTERPRET: self.skipTest('Fail on interpreter.') - if not jtu.if_cloud_tpu_at_least(2025, 11, 3): + if not jtu.is_cloud_tpu_at_least(2025, 11, 3): self.skipTest('Needs a newer libTPU') def body(_, o_ref): diff --git a/tests/pallas/tpu_side_effects_test.py b/tests/pallas/tpu_side_effects_test.py index c5ad3ac6b43f..c5109c66605f 100644 --- a/tests/pallas/tpu_side_effects_test.py +++ b/tests/pallas/tpu_side_effects_test.py @@ -30,7 +30,7 @@ def setUp(self): super().setUp() if not jtu.is_device_tpu(): self.skipTest("TPU required") - if not jtu.if_cloud_tpu_at_least(2025, 11, 11): + if not jtu.is_cloud_tpu_at_least(2025, 11, 11): self.skipTest("Newer libtpu required") @parameterized.named_parameters( diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index f68a92c53c51..c8cc4666e7c2 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -549,7 +549,7 @@ def kernel(x_hbm_ref, indices_ref, o_ref): def test_gather_1d_with_dynamically_sized_2d_ref(self): self.skip_if_tc_tiling() - if not jtu.if_cloud_tpu_at_least(2025, 10, 22): + if not jtu.is_cloud_tpu_at_least(2025, 10, 22): self.skipTest("Needs a newer libtpu") x = jnp.arange(16) @@ -770,7 +770,7 @@ def kernel(x_ref, indices_ref, o_ref): ) def test_store_scatter_2d(self): - if not jtu.if_cloud_tpu_at_least(2025, 10, 31): + if not jtu.is_cloud_tpu_at_least(2025, 10, 31): self.skipTest("Needs a newer libtpu") num_steps = 4 @@ -1093,7 +1093,7 @@ def scoped_kernel(scratch_ref): @parameterized.product(sizes=[[1, 1], [2, 2], [1, 1, 1, 1]]) def test_split_concatenate(self, sizes): - if not jtu.if_cloud_tpu_at_least(2025, 10, 26): + if not jtu.is_cloud_tpu_at_least(2025, 10, 26): self.skipTest("Test requires a newer libtpu") shape = (sum(sizes), 8) @@ -1384,7 +1384,7 @@ def kernel(x_ref, o_ref): kernel(x) def test_multiple_of(self): - if not jtu.if_cloud_tpu_at_least(2025, 10, 16): + if not jtu.is_cloud_tpu_at_least(2025, 10, 16): self.skipTest("Test requires a newer libtpu") x = jnp.arange(16) @@ -1426,7 +1426,7 @@ def _(i): np.testing.assert_array_equal(kernel(), expected) def test_barrier_via_pallas_call(self): - if not jtu.if_cloud_tpu_at_least(2025, 11, 22): + if not jtu.is_cloud_tpu_at_least(2025, 11, 22): self.skipTest("Test requires a newer libtpu") mesh = plsc.VectorSubcoreMesh( @@ -1620,7 +1620,7 @@ def kernel(in_ref, o_ref, scratch_ref): @parameterized.named_parameters( ("exp", jnp.exp), ("neg", lambda x: -x), ("abs", jnp.abs)) def test_unary_ops(self, op): - if not jtu.if_cloud_tpu_at_least(2025, 11, 30): + if not jtu.is_cloud_tpu_at_least(2025, 11, 30): self.skipTest("Test requires a newer libtpu") x = jnp.arange(8, dtype=jnp.float32) @@ -1639,7 +1639,7 @@ def sc_exp_kernel(x_hbm_ref, out_ref): @parameterized.product(dtype=[np.int32, np.float32]) def test_vector_gather(self, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 12, 2): + if not jtu.is_cloud_tpu_at_least(2025, 12, 2): self.skipTest("Test requires a newer libtpu") vec_dim = self.sc_info.num_lanes @@ -1656,7 +1656,7 @@ def kernel(x_ref, indices_ref, out_ref): @parameterized.product(dtype=[np.int32, np.float32]) def test_rev_and_sort_desc(self, dtype): - if not jtu.if_cloud_tpu_at_least(2025, 12, 2): + if not jtu.is_cloud_tpu_at_least(2025, 12, 2): self.skipTest("Test requires a newer libtpu") vec_dim = self.sc_info.num_lanes @@ -1678,7 +1678,7 @@ def kernel(x_ref, o1_ref, o2_ref): values_dtypes=[(), (np.int32,), (np.float32, np.int32)], ) def test_sort(self, keys_dtype, values_dtypes): - if not jtu.if_cloud_tpu_at_least(2025, 11, 30): + if not jtu.is_cloud_tpu_at_least(2025, 11, 30): self.skipTest("Test requires a newer libtpu") vec_dim = self.sc_info.num_lanes diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8fd2253d47ab..b1b2a9b7b81f 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -9580,7 +9580,7 @@ def body(c, _): @jtu.with_explicit_mesh((2,), ('x',)) def test_reduced_sin_fwd_mul_bwd(self, mesh): - if not jtu.if_cloud_tpu_at_least(2025, 11, 7): + if not jtu.is_cloud_tpu_at_least(2025, 11, 7): self.skipTest('Requires libtpu built after 2025-11-6') np_inp1 = np.arange(8.).reshape(4, 2) From 7af0e6439f449b2d105745a0ff819c47e0e5a56c Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Thu, 4 Dec 2025 16:43:53 -0800 Subject: [PATCH 062/315] Reverts 2f7c9073f4978373e04289c91b7fdccf0f8c06c7 PiperOrigin-RevId: 840455062 --- jaxlib/BUILD | 1 - jaxlib/jax.cc | 27 +-------------------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index bfcff630cc89..e9cf8b41fc48 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -397,7 +397,6 @@ nanobind_pywrap_extension( "@xla//xla/pjrt/distributed:key_value_store_interface", "@xla//xla/pjrt/distributed:protocol_proto_cc", "@xla//xla/pjrt/distributed:service", - "@xla//xla/pjrt/distributed/preemption:preemption_sync_manager", "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/python:logging", diff --git a/jaxlib/jax.cc b/jaxlib/jax.cc index c75cd8dde6b4..1571ddb29bb9 100644 --- a/jaxlib/jax.cc +++ b/jaxlib/jax.cc @@ -116,7 +116,6 @@ limitations under the License. #include "xla/hlo/builder/lib/approx_topk_shape.h" #include "xla/pjrt/c_api_client/pjrt_c_api_client.h" #include "xla/pjrt/distributed/key_value_store_interface.h" -#include "xla/pjrt/distributed/preemption/preemption_sync_manager.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_client.h" @@ -581,30 +580,6 @@ NB_MODULE(_jax, m) { aux::RegisterTransferServerTypes(m); #endif // defined(__linux__) -#if JAX_IFRT_VERSION_NUMBER >= 38 - nb::class_ preemption_sync_manager( - m, "PreemptionSyncManager"); - preemption_sync_manager - .def( - "initialize", - [](xla::PreemptionSyncManager& manager, - xla::DistributedRuntimeClient* client) { - xla::CoordinationServiceAgent* agent = - xla::ValueOrThrow(client->GetCoordinationServiceAgent()); - xla::ThrowIfError(manager.Initialize(agent)); - }, - nb::arg("distributed_client")) - .def("reached_sync_point", - [](xla::PreemptionSyncManager& manager, int step_counter) { - return manager.ReachedSyncPoint(step_counter); - }) - .def("shutdown", [](xla::PreemptionSyncManager& manager) { - nb::gil_scoped_release gil_release; - manager.Shutdown(); - }); - m.def("create_preemption_sync_manager", - []() { return xla::CreatePreemptionSyncManager(); }); -#else nb::class_ preemption_sync_manager( m, "PreemptionSyncManager"); preemption_sync_manager @@ -627,7 +602,6 @@ NB_MODULE(_jax, m) { }); m.def("create_preemption_sync_manager", []() { return tsl::CreatePreemptionSyncManager(); }); -#endif nb::class_ distributed_runtime_service( m, "DistributedRuntimeService"); @@ -924,6 +898,7 @@ NB_MODULE(_jax, m) { nb::class_( m, "TransferServerInterfaceFactory"); + m.def("is_asan", IsAsan); m.def("is_msan", IsMsan); m.def("is_tsan", IsTsan); From 71d9ec1448309fbbca05d489424e11392c1f1744 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 4 Dec 2025 17:01:02 -0800 Subject: [PATCH 063/315] Update references deprecated JAX build targets PiperOrigin-RevId: 840460409 --- benchmarks/mosaic/BUILD | 2 +- jax/experimental/mosaic/gpu/examples/BUILD | 8 +- tests/BUILD | 50 ++-- tests/mosaic/BUILD | 22 +- tests/multiprocess/BUILD | 2 +- tests/pallas/BUILD | 276 ++++++++++----------- 6 files changed, 180 insertions(+), 180 deletions(-) diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 39c7aa5f3395..0b14c147d571 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -35,7 +35,7 @@ jax_multiplatform_test( enable_configs = ["gpu_h100"], tags = ["notap"], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", "//third_party/py/google_benchmark", ] + py_deps("absl/testing") + py_deps("numpy"), diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 7edfe1c74db0..42d6c9555018 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -35,7 +35,7 @@ py_library( srcs = ["matmul.py"], deps = [ "//jax", - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ], ) @@ -44,7 +44,7 @@ py_library( srcs = ["matmul_blackwell.py"], deps = [ "//jax", - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ], ) @@ -53,7 +53,7 @@ py_library( srcs = ["flash_attention.py"], deps = [ "//jax", - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ], ) @@ -68,6 +68,6 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ] + py_deps("numpy"), ) diff --git a/tests/BUILD b/tests/BUILD index 9b0c54777d60..ad8a01ca70da 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -62,13 +62,13 @@ jax_multiplatform_test( enable_configs = ["tpu_v3_x4"], deps = [ "//jax:experimental", - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src:custom_transpose", "//jax/_src:shard_map", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "numpy", "absl/testing", @@ -173,7 +173,7 @@ jax_multiplatform_test( "gpu", ], deps = [ - "//jax:experimental_buffer_callback", + "//jax/experimental:buffer_callback", ] + py_deps([ "absl/testing", "numpy", @@ -347,7 +347,7 @@ jax_multiplatform_test( # using matplotlib plots # env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, deps = [ - "//jax:experimental_sparse", + "//jax/experimental:sparse", ] + py_deps([ "matplotlib", "absl/testing", @@ -485,8 +485,8 @@ jax_multiplatform_test( "multiaccelerator", ], deps = [ - "//jax:experimental_profiler", - "//jax:experimental_serialize_executable", + "//jax/experimental:profiler", + "//jax/experimental:serialize_executable", ] + py_deps([ "absl/testing", "numpy", @@ -547,9 +547,9 @@ jax_multiplatform_test( srcs = ["aot_test.py"], tags = ["multiaccelerator"], deps = [ - "//jax:experimental_pjit", - "//jax:experimental_serialize_executable", - "//jax:experimental_topologies", + "//jax/experimental:pjit", + "//jax/experimental:serialize_executable", + "//jax/experimental:topologies", ] + py_deps([ "numpy", "absl/testing", @@ -614,8 +614,8 @@ jax_multiplatform_test( "gpu": 2, }, deps = [ - "//jax:jet", - "//jax:stax", + "//jax/example_libraries:stax", + "//jax/experimental:jet", "//jax/extend:core", ] + py_deps([ "absl/testing", @@ -1085,7 +1085,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "optimizers_test", srcs = ["optimizers_test.py"], - deps = ["//jax:optimizers"] + py_deps([ + deps = ["//jax/example_libraries:optimizers"] + py_deps([ "absl/testing", "numpy", ]), @@ -1457,8 +1457,8 @@ jax_multiplatform_test( "notsan", ], # Test times out under asan/msan/tsan. deps = [ - "//jax:experimental_sparse", - "//jax:sparse_test_util", + "//jax/experimental:sparse", + "//jax/experimental:sparse_test_util", ] + py_deps([ "scipy", "absl/testing", @@ -1505,8 +1505,8 @@ jax_multiplatform_test( "notsan", ], # Test times out under asan/msan/tsan. deps = [ - "//jax:experimental_sparse", - "//jax:sparse_test_util", + "//jax/experimental:sparse", + "//jax/experimental:sparse_test_util", ] + py_deps([ "absl/testing", "numpy", @@ -1533,8 +1533,8 @@ jax_multiplatform_test( "tpu": 10, }, deps = [ - "//jax:experimental_sparse", - "//jax:sparse_test_util", + "//jax/experimental:sparse", + "//jax/experimental:sparse_test_util", ] + py_deps([ "absl/testing", "numpy", @@ -1592,7 +1592,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], - deps = ["//jax:stax"] + py_deps([ + deps = ["//jax/example_libraries:stax"] + py_deps([ "absl/testing", "numpy", ]), @@ -1720,7 +1720,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "ode_test", srcs = ["ode_test.py"], - deps = ["//jax:ode"] + py_deps([ + deps = ["//jax/experimental:ode"] + py_deps([ "absl/testing", "numpy", "scipy", @@ -1948,7 +1948,7 @@ jax_multiplatform_test( name = "colocated_python_test", srcs = ["colocated_python_test.py"], deps = [ - "//jax:experimental_colocated_python", + "//jax/experimental:colocated_python", "//jax/extend:ifrt_programs", ] + py_deps([ "absl/testing", @@ -1968,7 +1968,7 @@ jax_multiplatform_test( enable_backends = ["gpu"], shard_count = 15, deps = [ - "//jax:rnn", + "//jax/experimental:rnn", ] + py_deps([ "absl/testing", "numpy", diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index f18c01f91adc..af9aab89da16 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -52,7 +52,7 @@ jax_multiplatform_test( "noasan", # Times out. ], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", "//jax/experimental:mosaic_gpu_test_util", ] + py_deps([ "absl/testing", @@ -74,7 +74,7 @@ jax_multiplatform_test( "noasan", # ASAN is unsupported. ], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -93,7 +93,7 @@ jax_multiplatform_test( env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}, tags = ["multiaccelerator"], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -116,8 +116,8 @@ jax_multiplatform_test( tags = ["multiaccelerator"], deps = [ "//jax:experimental", - "//jax:mosaic_gpu", "//jax/_src:test_multiprocess", + "//jax/experimental:mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -130,8 +130,8 @@ jax_py_test( tags = ["gpu_mlir_deviceless_test"], deps = [ "//jax", - "//jax:mosaic_gpu", "//jax/_src:test_util", + "//jax/experimental:mosaic_gpu", ] + py_deps("absl/testing"), ) @@ -140,8 +140,8 @@ jax_py_test( srcs = ["gpu_constraints_test.py"], deps = [ "//jax", - "//jax:mosaic_gpu", "//jax/_src:test_util", + "//jax/experimental:mosaic_gpu", ] + py_deps("absl/testing"), ) @@ -152,8 +152,8 @@ jax_py_test( tags = ["gpu_mlir_deviceless_test"], deps = [ "//jax", - "//jax:mosaic_gpu", "//jax/_src:test_util", + "//jax/experimental:mosaic_gpu", "//jax/experimental:mosaic_gpu_test_util", ] + py_deps("absl/testing"), ) @@ -171,7 +171,7 @@ jax_multiplatform_test( "noasan", ], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", "//jax/experimental/mosaic/gpu/examples:matmul_blackwell", ] + py_deps([ @@ -196,7 +196,7 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ] + py_deps([ "numpy", "absl/testing", @@ -216,7 +216,7 @@ jax_multiplatform_test( "noasan", # Remove the tag once the CUPTI issue is fixed. ], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:flash_attention", ] + py_deps("absl/testing"), ) @@ -234,6 +234,6 @@ jax_multiplatform_test( "nomsan", ], deps = [ - "//jax:mosaic_gpu", + "//jax/experimental:mosaic_gpu", ] + py_deps("absl/testing"), ) diff --git a/tests/multiprocess/BUILD b/tests/multiprocess/BUILD index 1b451bd505ab..5d5d90a5f321 100644 --- a/tests/multiprocess/BUILD +++ b/tests/multiprocess/BUILD @@ -88,8 +88,8 @@ jax_multiprocess_test( ], main = "colocated_python_test.py", deps = [ - "//jax:experimental_colocated_python", "//jax/_src:test_multiprocess", + "//jax/experimental:colocated_python", ], ) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index d058ea22746c..7f1182236615 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -56,12 +56,12 @@ jax_multiplatform_test( "tpu": 4, }, deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -75,8 +75,8 @@ jax_py_test( ], args = ["--jax_test_dut=cpu"], deps = [ - "//jax:pallas", "//jax/_src:test_util", + "//jax/experimental:pallas", ] + py_deps([ "absl/testing", "numpy", @@ -98,9 +98,9 @@ jax_multiplatform_test( "gpu_b200", ], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -140,11 +140,11 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing:flagsaver", "absl/testing", @@ -185,11 +185,11 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_mosaic_gpu", # build_cleaner: keep - "//jax:pallas_tpu", "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing:flagsaver", "absl/testing", @@ -213,8 +213,8 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing", "hypothesis", @@ -235,11 +235,11 @@ jax_multiplatform_test( ], shard_count = 4, deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -268,8 +268,8 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep ] + py_deps([ "absl/testing", "numpy", @@ -287,12 +287,12 @@ jax_multiplatform_test( ], tags = [], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_mosaic_gpu", # build_cleaner: keep - "//jax:pallas_tpu_ops", # build_cleaner: keep "//jax/_src:internal_export_back_compat_test_data", "//jax/_src:internal_export_back_compat_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu_ops", # build_cleaner: keep ] + py_deps("absl/testing"), ) @@ -302,11 +302,11 @@ jax_py_test( args = ["--jax_test_dut=cpu"], main = "export_pallas_test.py", deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_mosaic_gpu", # build_cleaner: keep - "//jax:pallas_tpu", # build_cleaner: keep "//jax/_src:test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", # build_cleaner: keep ] + jax_gpu_support_deps + py_deps([ "absl/testing", "numpy", @@ -326,10 +326,10 @@ jax_multiplatform_test( ], tags = [], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_mosaic_gpu", # build_cleaner: keep - "//jax:pallas_tpu", # build_cleaner: keep + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_mosaic_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", # build_cleaner: keep ] + py_deps([ "absl/testing", "numpy", @@ -350,9 +350,9 @@ jax_multiplatform_test( ], tags = [], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu", # build_cleaner: keep + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", # build_cleaner: keep ] + py_deps([ "absl/testing", "numpy", @@ -366,9 +366,9 @@ jax_multiplatform_test( ], enable_backends = ["tpu"], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", "//jax/_src/pallas/mosaic:random", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing", "numpy", @@ -385,8 +385,8 @@ jax_multiplatform_test( "tpu_v5e_x8", ], deps = [ - "//jax:mesh_utils", - "//jax:pallas_tpu_ops", + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -409,7 +409,7 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas_tpu_ops", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "absl/flags", @@ -431,10 +431,10 @@ jax_multiplatform_test( ], shard_count = 8, deps = [ - "//jax:mesh_utils", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", "//jax/extend", ] + py_deps([ "absl/testing", @@ -460,9 +460,9 @@ jax_multiplatform_test( ), shard_count = 8, deps = [ - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", "//jax/extend", ] + py_deps([ "absl/testing", @@ -491,9 +491,9 @@ jax_multiplatform_test( "multiaccelerator", ], deps = [ - "//jax:pallas_mosaic_gpu", "//jax/_src:test_multiprocess", "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", "//jax/extend", ] + py_deps([ "portpicker", @@ -510,11 +510,11 @@ jax_multiplatform_test( enable_backends = ["tpu"], shard_count = 8, deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "hypothesis", @@ -528,9 +528,9 @@ jax_multiplatform_test( enable_backends = ["tpu"], tags = ["multiaccelerator"], deps = [ - "//jax:mesh_utils", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", "//jax/extend", ] + py_deps([ "absl/testing", @@ -562,9 +562,9 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:mesh_utils", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", + "//jax/experimental:mesh_utils", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", "//jax/extend", ] + py_deps([ "absl/testing", @@ -583,7 +583,7 @@ jax_multiplatform_test( "tpu_v5p_x4", ], deps = [ - "//jax:pallas_tpu", + "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing", "numpy", @@ -599,7 +599,7 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - "//jax:pallas_tpu", + "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing", "numpy", @@ -629,7 +629,7 @@ jax_multiplatform_test( "notsan", ], deps = [ - "//jax:pallas_tpu", + "//jax/experimental:pallas_tpu", "//jax/extend", ] + py_deps([ "absl/testing", @@ -647,10 +647,10 @@ jax_multiplatform_test( "tpu_v5p_x4", ], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src/pallas/mosaic:random", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -665,8 +665,8 @@ jax_multiplatform_test( enable_backends = ["cpu"], deps = [ "//jax:experimental", - "//jax:pallas", - "//jax:pallas_tpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing", "numpy", @@ -681,8 +681,8 @@ jax_multiplatform_test( enable_backends = ["cpu"], deps = [ "//jax:experimental", - "//jax:pallas", - "//jax:pallas_tpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing", "numpy", @@ -696,8 +696,8 @@ jax_multiplatform_test( ], enable_backends = ["cpu"], deps = [ - "//jax:pallas", - "//jax:pallas_tpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing", "numpy", @@ -719,7 +719,7 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas_tpu_ops", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -741,7 +741,7 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas_tpu_ops", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -762,8 +762,8 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas_tpu_ops", "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -780,9 +780,9 @@ jax_multiplatform_test( ], shard_count = 10, deps = [ - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/_src/pallas:pallas_test_util", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", "//jax/extend", ] + py_deps([ "absl/testing", @@ -798,8 +798,8 @@ jax_py_test( ], deps = [ "//jax", - "//jax:pallas_tpu_ops", "//jax/_src:test_util", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -823,9 +823,9 @@ jax_multiplatform_test( "noasan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", # build_cleaner: keep - "//jax:pallas_gpu_ops", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", # build_cleaner: keep + "//jax/experimental:pallas_gpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -850,9 +850,9 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -876,9 +876,9 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_gpu", - "//jax:pallas_gpu_ops", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", + "//jax/experimental:pallas_gpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -897,8 +897,8 @@ jax_multiplatform_test( ], shard_count = 1, deps = [ - "//jax:pallas", - "//jax:pallas_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_gpu", ] + py_deps([ "absl/testing", ]), @@ -918,8 +918,8 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -941,9 +941,9 @@ jax_multiplatform_test( "mosaic_gpu_test", ], deps = [ - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -963,9 +963,9 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -984,9 +984,9 @@ jax_multiplatform_test( shard_count = 8, tags = ["mosaic_gpu_test"], deps = [ - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -1004,8 +1004,8 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -1023,8 +1023,8 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -1041,8 +1041,8 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -1065,8 +1065,8 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -1088,8 +1088,8 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -1109,8 +1109,8 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:pallas", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -1130,9 +1130,9 @@ jax_multiplatform_test( "noasan", # Times out. ], deps = [ - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps([ "absl/testing", "numpy", @@ -1164,10 +1164,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", "//jax/_src:test_multiprocess", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), ) @@ -1183,9 +1183,9 @@ jax_multiplatform_test( "noasan", ], deps = [ - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("torch"), ) @@ -1211,10 +1211,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", "//jax/_src:test_multiprocess", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ], ) @@ -1240,10 +1240,10 @@ jax_multiplatform_test( ], deps = [ "//jax:experimental", - "//jax:pallas", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_mosaic_gpu", "//jax/_src:test_multiprocess", + "//jax/experimental:pallas", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", ], ) @@ -1268,9 +1268,9 @@ jax_multiplatform_test( "multiaccelerator", ], deps = [ - "//jax:pallas_mosaic_gpu", "//jax/_src:test_multiprocess", "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", "//jax/extend", ] + py_deps([ "portpicker", @@ -1294,8 +1294,8 @@ jax_multiplatform_test( "notsan", ], deps = [ - "//jax:pallas", "//jax/_src/pallas/fuser", + "//jax/experimental:pallas", ] + py_deps([ "absl/testing", "numpy", @@ -1317,8 +1317,8 @@ jax_multiplatform_test( "notsan", ], deps = [ - "//jax:pallas", - "//jax:pallas_fuser", + "//jax/experimental:pallas", + "//jax/experimental:pallas_fuser", ] + py_deps([ "absl/testing", "numpy", @@ -1339,9 +1339,9 @@ jax_multiplatform_test( "notsan", ], deps = [ - "//jax:pallas_tpu", - "//jax:pallas_tpu_ops", "//jax/experimental:pallas_fuser", + "//jax/experimental:pallas_tpu", + "//jax/experimental:pallas_tpu_ops", ] + py_deps([ "absl/testing", "numpy", @@ -1387,8 +1387,8 @@ jax_multiplatform_test( ], shard_count = 3, deps = [ - "//jax:pallas", - "//jax:pallas_tpu", + "//jax/experimental:pallas", + "//jax/experimental:pallas_tpu", "//jax/experimental:pallas_tpu_sc", ] + py_deps([ "numpy", @@ -1411,7 +1411,7 @@ jax_multiplatform_test( "tpu_v6e_x8", ], deps = [ - "//jax:mesh_utils", + "//jax/experimental:mesh_utils", "//jax/experimental:pallas_tpu", "//jax/experimental:pallas_tpu_sc", ] + py_deps([ From 6168f5cd8b4c66ae646b249a1c9edf943982272c Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 4 Dec 2025 18:52:27 -0800 Subject: [PATCH 064/315] [PjRt-IFRT] Tighten `PjRtExecutable::Create()` to take an MLIR module `xla::ifrt::PjRtExecutable::Create()` was taking already constructed `xla::PjRtExecutable`. This prevents `xla::ifrt::PjRtExecutable` from extracting useful information such as whether a default layout is used for outputs that is not available in the already constructed `xla::PjRtExecutable`. By taking a module and performing compilation inside, the `Create()` has access to the original MLIR module information. This is a prerequisite for improving the serialization of `xla::ifrt::PjRtLoadedExecutable` and `xla::ifrt::PjRtExecutable` to use `xla::ifrt::SerializedXlaExecutableMetadata` header. Ultimately, the richer executable metadata information will be used for unifying semantics for default layout handling across IFRT runtimes. PiperOrigin-RevId: 840491451 --- jaxlib/py_client.cc | 8 ++++++++ jaxlib/py_compile_only_client.cc | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index c5b0bfe91779..71b89d626cf2 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -93,6 +93,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/python/types.h" +#include "xla/python/version.h" #include "xla/service/platform_util.h" // IWYU pragma: keep #include "xla/service/spmd/shardy/utils.h" // IWYU pragma: keep #include "xla/shape.h" @@ -504,11 +505,18 @@ PyClient::CompileAndLoadIfrtProgram( client->ifrt_client()->GetTopologyForDevices(executable_devices)); auto xla_options = std::make_unique( options, std::move(executable_devices)); +#if JAX_IFRT_VERSION_NUMBER >= 38 + TF_ASSIGN_OR_RETURN( + executable_ref, + ifrt::PjRtExecutable::Create(std::move(module), std::move(options), + *topology->description())); +#else TF_ASSIGN_OR_RETURN( auto pjrt_executable, PjRtCompile(std::move(options), module, *topology->description())); TF_ASSIGN_OR_RETURN(executable_ref, ifrt::PjRtExecutable::Create( std::move(pjrt_executable))); +#endif } return make_nb_class(executable_ref); } diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc index 349d62f23271..d38f19e35c83 100644 --- a/jaxlib/py_compile_only_client.cc +++ b/jaxlib/py_compile_only_client.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/python/version.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" #include "xla/tsl/python/lib/core/numpy.h" @@ -78,11 +79,18 @@ absl::StatusOr> CompileOnlyPyClient::CompileUnloaded( auto xla_options = std::make_unique( options, std::move(executable_devices)); +#if JAX_IFRT_VERSION_NUMBER >= 38 + TF_ASSIGN_OR_RETURN( + ifrt_executable, + ifrt::PjRtExecutable::Create(std::move(module), std::move(options), + *ifrt_client->topology().description())); +#else TF_ASSIGN_OR_RETURN(auto executable, PjRtCompile(std::move(options), module, *ifrt_client->topology().description())); TF_ASSIGN_OR_RETURN(ifrt_executable, ifrt::PjRtExecutable::Create(std::move(executable))); +#endif } return make_nb_class(ifrt_executable); } From fa27af7b4c5263bae50abfff9103e312e473333d Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Thu, 4 Dec 2025 20:29:27 -0800 Subject: [PATCH 065/315] Remove `testonly` attribute from `pallas_test_util` to include the target in the additional wheel dependencies. PiperOrigin-RevId: 840521740 --- jax/_src/pallas/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index 074002f3c935..469b31d9b326 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -69,7 +69,6 @@ py_library( py_library( name = "pallas_test_util", - testonly = True, srcs = [ "pallas_test_util.py", ], From 167dc0df62a7a767f636cee3bbfff9ed4a82c01d Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 5 Dec 2025 00:06:20 -0800 Subject: [PATCH 066/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/0158b0d5b1911f7e2a8de06d3c1f855d95ca5ec6 PiperOrigin-RevId: 840582739 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 1804835d507a..0733d88d6e4e 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "c6315cd85539fac2c08ed33dbe25e006f41ce72b" -XLA_SHA256 = "d1875b989ed12f511deec609b2592698dfb3af4f49cb1a470db42c6658ac83f6" +XLA_COMMIT = "0158b0d5b1911f7e2a8de06d3c1f855d95ca5ec6" +XLA_SHA256 = "18188dd12346c55f043c9617089aae329408ddeb611b035c6852714b679e5d7a" From a83b4d6859268f7e67ed071d48ea03f7e0afb391 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Fri, 5 Dec 2025 02:00:02 -0800 Subject: [PATCH 067/315] [Mosaic GPU][NFC] Refactor: Use a macro for CUDA error checking. This change introduces a `CUDA_RETURN_IF_ERROR` macro to simplify error handling for CUDA API calls, replacing repetitive checks with a more concise form. PiperOrigin-RevId: 840628538 --- jaxlib/mosaic/gpu/custom_call.cc | 36 +++++++++++++++++--------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 06dc1f0491c5..29931d56e139 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -360,24 +360,31 @@ GetAssemblyToBinaryCompilationProvider() { return (*compilation_provider)->get(); } +std::string CUDAErrorString(CUresult result) { + const char* error; + cuGetErrorString(result, &error); + return error; +} +// Returns if the CUDA expression returns an error. +#define CUDA_RETURN_IF_ERROR(stmt) \ + do { \ + if (CUresult result = stmt; result != CUDA_SUCCESS) { \ + return absl::InternalError(CUDAErrorString(result)); \ + } \ + } while (0) + absl::StatusOr GetCudaComputeCapability() { // Assumes driver has been initialized and a context exists. XLA already has // some utilities to query this, but we try to stay runtime-agnostic, so we // build our own here. CUdevice device; - if (cuCtxGetDevice(&device) != CUDA_SUCCESS) { - return absl::InternalError("Failed to get device for current context"); - } + CUDA_RETURN_IF_ERROR(cuCtxGetDevice(&device)); int major = 0; - if (cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, - device) != CUDA_SUCCESS) { - return absl::InternalError("Failed to get major compute capability"); - } + CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute( + &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)); int minor = 0; - if (cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, - device) != CUDA_SUCCESS) { - return absl::InternalError("Failed to get minor compute capability"); - } + CUDA_RETURN_IF_ERROR(cuDeviceGetAttribute( + &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); TF_ASSIGN_OR_RETURN(std::string sm, mosaic::gpu::GetSmVersion(major, minor)); bool has_accelerated_features = absl::EndsWith(sm, "a"); @@ -692,12 +699,7 @@ absl::Status MosaicGpuExecute(cudaStream_t stream, ffi::RemainingArgs inputs, KernelHash hash; std::memcpy(hash.data(), kernel_hash.data(), sizeof(KernelHash)); CUcontext ctx; - if (auto result = cuCtxGetCurrent(&ctx); result != CUDA_SUCCESS) { - const char* error; - cuGetErrorString(result, &error); - return absl::InternalError( - absl::StrFormat("Failed to get current CUDA context: %s", error)); - } + CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); CacheKey key(hash, reinterpret_cast(ctx)); TF_ASSIGN_OR_RETURN(auto compiled_kernel, CachedCompileAndInit(key, module)); From 92621c77d9a498afafc0b38ab3980db4b0a553b1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 5 Dec 2025 02:24:34 -0800 Subject: [PATCH 068/315] [Mosaic GPU] Optimize the computation of tcgen05.mma matrix descriptors Previously we used a simple approach of computing the descriptors entirely using LLVM ops. This was convenient, but it turns out that there are two problems with it: 1. LLVM doesn't always fully constant fold properly and sometimes emits PTX that causes ptxas to generate lots of non-uniform operations. 2. LLVM is quite aggressive to hoist descriptor computation outside of loops, which blows up the register pressure. The alternative implemented here is to compute the descriptors in inline ptx, with manual constant folding, and right before the MMA operations. This seems to generate code that has extremely low register pressure and only very few uniform operations on 32-bit quantities. PiperOrigin-RevId: 840636187 --- jax/experimental/mosaic/gpu/mma_utils.py | 21 +++++++++-- jax/experimental/mosaic/gpu/tcgen05.py | 48 +++++++++++++++--------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/jax/experimental/mosaic/gpu/mma_utils.py b/jax/experimental/mosaic/gpu/mma_utils.py index d4e04fdc67ec..ebd789ca348c 100644 --- a/jax/experimental/mosaic/gpu/mma_utils.py +++ b/jax/experimental/mosaic/gpu/mma_utils.py @@ -48,6 +48,7 @@ def create_descriptor( # Soft deprecated. Use small tiling instead. large_tile: tuple[int, int] | None = None, mma_bytewidth_k: int = 32, + split_const: bool = False, ): ref_ty = ir.MemRefType(ref.type) element_bitwidth = utils.bitwidth(ref_ty.element_type) @@ -183,6 +184,7 @@ def to_byte_stride(stride: int): leading_byte_offset=leading_byte_offset, stride_byte_offset=stride_byte_offset, swizzle=swizzle, + split_const=split_const, ) mn_tiles_per_group, rem = divmod(mn_group_size, mn_tiling) @@ -221,7 +223,9 @@ def encode_descriptor( stride_byte_offset: int, swizzle: int | mgpu_dialect.SwizzlingMode | None, const_init: int = 0, + split_const: bool = False, ): + i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) if isinstance(ref_arg.type, ir.MemRefType): ptr = utils.memref_ptr(ref_arg, 3) @@ -246,7 +250,18 @@ def encode_descriptor( const_init | (encode_addr(leading_byte_offset) << 16) | (encode_addr(stride_byte_offset) << 32) + | (swizzle_encoding << 62) ) - desc = llvm.or_(arith.shli(c(swizzle_encoding), c(62)), c(desc_const)) - desc = llvm.or_(encoded_base_addr, desc) - return desc + if split_const: + # The encoded base addr fits within a single 32-bit register. + return arith.trunci(i32, encoded_base_addr), desc_const + else: + # The desc_const frequently has the MSB set, leading to errors when trying + # to create ir.IntegerAttr through the MLIR python bindings... This should + # be easy enough for LLVM to constant fold away. + if desc_const >> 63: + desc_val = c(desc_const & 0xFFFFFFFF) + desc_val = llvm.or_(desc_val, arith.shli(c(desc_const >> 32), c(32))) + else: + desc_val = c(desc_const) + return llvm.or_(encoded_base_addr, desc_val) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 0897021a28ae..79fffad7f511 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -447,6 +447,7 @@ def mma( group_size=(m_group_elems, k_group_elems // (1 + is_sparse)), logical_k_major=False, mma_bytewidth_k=32, + split_const=True, ) else: a_fastest = mma_utils.Dim.K @@ -462,6 +463,7 @@ def mma( group_size=(k_group_elems, n_group_elems), logical_k_major=True, mma_bytewidth_k=64 if is_sparse else 32, + split_const=True, ) if is_scaled and utils.bitwidth(mma_element_type) == 4: @@ -496,9 +498,9 @@ def mma( a_mk = a.slice(slice(None), utils.ds(ki * a_k_group_elems, a_k_group_elems)).address else: a_offset = mi * a_m_group_stride + ki * a_k_group_stride - a_mk = arith.addi(a_desc_base, utils.c(mma_utils.encode_addr(a_offset), i64)) + a_mk = (a_desc_base[0], a_desc_base[1] + mma_utils.encode_addr(a_offset)) b_offset = ni * b_n_group_stride + ki * b_k_group_stride - b_nk = arith.addi(b_desc_base, utils.c(mma_utils.encode_addr(b_offset), i64)) + b_nk = (b_desc_base[0], b_desc_base[1] + mma_utils.encode_addr(b_offset)) if a_sparse_addr_base is not None: if n_groups != 1 or m_groups != 1: raise NotImplementedError("A sparse metadata address calculation for multiple tiles") @@ -559,8 +561,8 @@ def mma( def _do_mma( d_addr: ir.Value, - a_desc_or_addr: ir.Value, # TMEM address if a_k_stride is None - b_desc: ir.Value, + a_desc_or_addr: tuple[ir.Value, int] | ir.Value, # TMEM address if a_k_stride is None + b_desc: tuple[ir.Value, int], a_transpose: bool, b_transpose: bool, a_k_strides: tuple[tuple[int, ...], tuple[int, ...]] | None, @@ -638,14 +640,12 @@ def create_scaled_instr_descriptor(*args): # type: ignore num_cta = 2 if collective else 1 a_in_tmem = a_k_strides is None - a_ptx = "[$1]" if a_in_tmem else "$1" - a_ptx_constraint = "r" if a_in_tmem else "l" + a_ptx = "[a_desc]" if a_in_tmem else "a_desc" sparse_mod = ".sp" if is_sparse else "" sparse_meta_ptx = "[$5], " if is_sparse else "" extra_constraints += ",r" if is_sparse else "" sparse_addr: tuple[Any, ...] = () scales_addrs: tuple[Any, ...] = () - assert a_desc_or_addr.type == ir.IntegerType.get_signless(32 if a_in_tmem else 64) def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]): assert len(idx_tiling) + 1 == len(strides) idxs = [] @@ -654,7 +654,7 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]) idx = idx % t idxs.append(idx) offset = sum(i * s for i, s in zip(idxs, strides, strict=True)) - return arith.constant(i64, offset >> 4) + return offset >> 4 for k_step in range(k // instr_k): if is_scaled: assert scale_steps is not None @@ -696,20 +696,32 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]) ) if a_in_tmem: cols_per_k_group = instr_k // packing // (1 + is_sparse) - a_desc_or_addr_instr = arith.addi( - a_desc_or_addr, arith.constant(i32, k_step * cols_per_k_group) - ) + a_offset = k_step * cols_per_k_group + assert isinstance(a_desc_or_addr, ir.Value) + assert a_desc_or_addr.type == ir.IntegerType.get_signless(32) + a_enc_addr_base = a_desc_or_addr else: assert a_k_idx_tiling is not None and a_k_strides is not None - a_desc_or_addr_instr = arith.addi( - a_desc_or_addr, _get_offset(k_step, a_k_idx_tiling, a_k_strides) - ) - b_desc_instr = arith.addi(b_desc, _get_offset(k_step, b_k_idx_tiling, b_k_strides)) + a_enc_addr_base, a_offset = a_desc_or_addr + a_offset += _get_offset(k_step, a_k_idx_tiling, a_k_strides) + b_enc_addr_base, b_offset = b_desc + b_offset += _get_offset(k_step, b_k_idx_tiling, b_k_strides) + a_offset_low, a_offset_high = a_offset & 0xFFFFFFFF, a_offset >> 32 + b_offset_low, b_offset_high = b_offset & 0xFFFFFFFF, b_offset >> 32 llvm.inline_asm( ir.Type.parse("!llvm.void"), - [d_addr, a_desc_or_addr_instr, b_desc_instr, i_desc, accumulate, *scales_addrs, *sparse_addr], - f"tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, $2, {sparse_meta_ptx}$3, {extra_ptx}$4;", - f"r,{a_ptx_constraint},l,r,b" + extra_constraints, + [d_addr, a_enc_addr_base, b_enc_addr_base, i_desc, accumulate, *scales_addrs, *sparse_addr], + f"""{{ + .reg .b32 a_desc_low, a_desc_high, b_desc_low, b_desc_high; + .reg {".b32" if a_in_tmem else ".b64"} a_desc; + .reg .b64 b_desc; + add.s32 a_desc_low, $1, {a_offset_low}; + add.s32 b_desc_low, $2, {b_offset_low}; + mov.b64 b_desc, {{b_desc_low, {b_offset_high}}}; + {"mov.b32 a_desc, a_desc_low;" if a_in_tmem else f"mov.b64 a_desc, {{a_desc_low, {a_offset_high}}};"} + tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, b_desc, {sparse_meta_ptx}$3, {extra_ptx}$4; + }}""", + "r,r,r,r,b" + extra_constraints, has_side_effects=True, ) accumulate = arith.constant(i1, 1) From dd12a48ea049c6e37ee698b63b78bb2d93797e11 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 5 Dec 2025 03:15:07 -0800 Subject: [PATCH 069/315] [Mosaic GPU] Add support for arbitrary reshapes of contiguous refs PiperOrigin-RevId: 840650264 --- jax/experimental/mosaic/gpu/utils.py | 12 ++++++++++++ tests/mosaic/gpu_test.py | 1 + 2 files changed, 13 insertions(+) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 33965db4ef6f..a027ac441e39 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -722,6 +722,18 @@ def memref_reshape( (), ref_ty.element_type, new_layout, ref_ty.memory_space ) return memref.collapse_shape(result_ty, ref, []) + # For contiguous refs we can do arbitrary reshapes easily. + strides, _ = ref_ty.get_strides_and_offset() + if all( + d == 1 or s1 == s2 + for d, s1, s2 in zip( + ref_ty.shape, + get_contiguous_strides(ref_ty.shape), + strides, + strict=True, + ) + ): + return memref_unfold(memref_fold(ref, 0, ref_ty.rank), 0, shape) return _reshape(ref, src_shape, dst_shape) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index b3d6ce892c93..ebd4e68b5324 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -394,6 +394,7 @@ def kernel(ctx, inp, out, _): ("un", (1, 10, 1), (1, 5, 2, 1,)), ("to_scalar", (1, 1, 1), ()), ("from_scalar", (), (1, 1, 1)), + ("arbitrary", (2 * 5, 7 * 3), (2, 7, 5, 3)), ) def test_reshape(self, inp_shape, out_shape): def kernel(ctx, inp, out, _): From 283e4c77cc2c16eaea23006d1a177bcf81fa103c Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Fri, 5 Dec 2025 06:25:46 -0600 Subject: [PATCH 070/315] register rocm to block_scaled_dot lowering path --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index 48c5d27c2678..cf339d6e5ed5 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -66,7 +66,7 @@ def _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type): ) -def _scaled_matmul_cuda_lowering( +def _scaled_matmul_gpu_lowering( ctx, a, b, a_scales, b_scales, preferred_element_type ): lhs_type = ir.RankedTensorType(a.type) @@ -119,8 +119,8 @@ def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type): mlir.register_lowering( _scaled_matmul_p, - _scaled_matmul_cuda_lowering, - platform="cuda", + _scaled_matmul_gpu_lowering, + platform="gpu", ) _scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper") From 529a0f4dd8f588c641a002a893c5f294d0811a0c Mon Sep 17 00:00:00 2001 From: Mohammed Anany Date: Fri, 5 Dec 2025 05:24:52 -0800 Subject: [PATCH 071/315] Remove cluster dims from Triton compilation results. These are removed as of the current integration. Submission needs to be separate as OSS looks at the old code. PiperOrigin-RevId: 840684139 --- jaxlib/gpu/gpu_plugin_extension.cc | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/jaxlib/gpu/gpu_plugin_extension.cc b/jaxlib/gpu/gpu_plugin_extension.cc index 862430257f27..d3f411cc87b6 100644 --- a/jaxlib/gpu/gpu_plugin_extension.cc +++ b/jaxlib/gpu/gpu_plugin_extension.cc @@ -48,9 +48,6 @@ namespace { struct TritonCompilationResult { std::string asm_text; int64_t smem_bytes; - int cluster_dim_x; - int cluster_dim_y; - int cluster_dim_z; }; absl::StatusOr CompileTritonToASM( @@ -77,9 +74,6 @@ absl::StatusOr CompileTritonToASM( return TritonCompilationResult{ .asm_text = asm_text, .smem_bytes = args.out_smem_bytes, - .cluster_dim_x = args.out_cluster_dim_x, - .cluster_dim_y = args.out_cluster_dim_y, - .cluster_dim_z = args.out_cluster_dim_z, }; } @@ -240,10 +234,7 @@ void BuildGpuPluginExtension(nanobind::module_& m) { nb::class_(m, "TritonCompilationResult") .def_ro("asm", &TritonCompilationResult::asm_text) - .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes) - .def_ro("cluster_dim_x", &TritonCompilationResult::cluster_dim_x) - .def_ro("cluster_dim_y", &TritonCompilationResult::cluster_dim_y) - .def_ro("cluster_dim_z", &TritonCompilationResult::cluster_dim_z); + .def_ro("smem_bytes", &TritonCompilationResult::smem_bytes); m.def("compile_triton_to_asm", [](nb::capsule c_api, nb::bytes module, std::string_view arch_name, From 6a56337d51527804bdb5144973c500721d2fce5b Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 5 Dec 2025 06:18:17 -0800 Subject: [PATCH 072/315] [Mosaic GPU] Make `mosaic_gpu.SliceSMEM` pure. This op does not perform allocation itself and needn't exist if it doesn't have useful side-effectful consumers. In a future change, this will allow it to be DCE'd. PiperOrigin-RevId: 840698679 --- jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index ee262ff05400..9fc855486ad4 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -490,7 +490,7 @@ def MosaicGPU_BroadcastInDimOp : Op { } -def MosaicGPU_SliceSMEMOp : Op { +def MosaicGPU_SliceSMEMOp : Op { let summary = "Constructs an SMEM MemRef with the requested type that begins at the specified SMEM offset address."; let arguments = (ins I32:$offset); From a4c5abc5a60086357ffae4101382d30bee79cc98 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 5 Dec 2025 07:32:11 -0800 Subject: [PATCH 073/315] Deprecate jax.core.get_aval; jax.typeof is a drop-in replacement. --- jax/core.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/jax/core.py b/jax/core.py index d82cf3592482..60d031083bb9 100644 --- a/jax/core.py +++ b/jax/core.py @@ -53,7 +53,7 @@ eval_jaxpr as eval_jaxpr, find_top_trace as find_top_trace, gensym as gensym, - get_aval as get_aval, + get_aval as _deprecated_get_aval, is_concrete as is_concrete, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, @@ -76,3 +76,21 @@ unmapped_aval as unmapped_aval, valid_jaxtype as valid_jaxtype, ) + +_deprecations = { + # Added for v0.8.2 + "get_aval": ( + "jax.core.get_aval is deprecated; use jax.typeof instead.", + _deprecated_get_aval + ), +} + +import typing as _typing +if _typing.TYPE_CHECKING: + get_aval = _deprecated_get_aval +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing +del _deprecated_get_aval From e6b7b4f11cf1ad1f08406a9627329b4a6142f92b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 5 Dec 2025 07:34:05 -0800 Subject: [PATCH 074/315] Fix another dtype deprecation --- tests/typing_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typing_test.py b/tests/typing_test.py index 4408aa3d0425..9368d11c8af8 100644 --- a/tests/typing_test.py +++ b/tests/typing_test.py @@ -46,7 +46,7 @@ def dtypelike_to_dtype(x: typing.DTypeLike) -> typing.DType: # inputs to jax primitive functions; use convert_element_type here # for simplicity. def arraylike_to_array(x: typing.ArrayLike) -> typing.Array: - return lax.convert_element_type(x, dtypes.dtype(np.result_type(x))) + return lax.convert_element_type(x, dtypes.dtype(x)) class HasDType: From 814e75eceba566f753df81757d195152a84a8166 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 5 Dec 2025 07:46:07 -0800 Subject: [PATCH 075/315] [test] fix signatures test for NumPy 2.4 --- tests/lax_numpy_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d56186f1b98a..d8ab12ea9609 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6266,24 +6266,32 @@ def testWrappedSignaturesMatch(self): # TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names. unsupported_params = { + 'arange': ['start_or_stop', 'like'], + 'array': ['ndmax', 'like', 'subok'], 'argpartition': ['kind', 'order'], 'asarray': ['like'], 'broadcast_to': ['subok'], 'clip': ['kwargs', 'out'], + 'concat': ['out', 'dtype', 'casting'], + 'concatenate': ['out', 'casting'], 'copy': ['subok'], 'corrcoef': ['ddof', 'bias'], 'cumulative_prod': ['out'], 'cumulative_sum': ['out'], + 'dot': ['out'], 'empty_like': ['subok', 'order'], 'einsum': ['kwargs'], 'einsum_path': ['einsum_call'], + 'empty': ['order', 'like'], 'eye': ['order', 'like'], 'hstack': ['casting'], 'identity': ['like'], 'isin': ['kind'], 'full': ['order', 'like'], 'full_like': ['subok', 'order'], + 'frombuffer': ['like'], 'fromfunction': ['like'], + 'frompyfunc': ['kwargs'], 'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'], 'nanpercentile': ['weights'], 'nanquantile': ['weights'], @@ -6293,15 +6301,19 @@ def testWrappedSignaturesMatch(self): 'ones_like': ['subok', 'order'], 'partition': ['kind', 'order'], 'percentile': ['weights'], + 'promote_types': ['type1', 'type2'], 'quantile': ['weights'], 'row_stack': ['casting'], 'stack': ['casting'], 'tri': ['like'], + 'unravel_index': ['order'], 'vstack': ['casting'], + 'zeros': ['order', 'like'], 'zeros_like': ['subok', 'order'] } extra_params = { + 'arange': ['start'], 'compress': ['size', 'fill_value'], 'einsum': ['subscripts', 'precision'], 'einsum_path': ['subscripts'], From 683a9e23749861b5894f24c5ae3ed3e6b0fa5de6 Mon Sep 17 00:00:00 2001 From: Yurii Topin Date: Fri, 5 Dec 2025 08:10:41 -0800 Subject: [PATCH 076/315] Update the rules_ml_toolchain dependency to its latest version. This is necessary to ensure proper compatibility with the newest XLA libraries when building on Linux AArch64. PiperOrigin-RevId: 840733660 --- WORKSPACE | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 8d4ba31a0539..87d1e6830a9f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -18,9 +18,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # Details: https://github.com/google-ml-infra/rules_ml_toolchain tf_http_archive( name = "rules_ml_toolchain", - sha256 = "b1e5e306d8b1103e73b9b778dfc3a9e069d20664437a03246a235724962b5c94", - strip_prefix = "rules_ml_toolchain-484235be45e6843db962c45d08fe4b2b65a6a24c", - urls = tf_mirror_urls("https://github.com/google-ml-infra/rules_ml_toolchain/archive/484235be45e6843db962c45d08fe4b2b65a6a24c.tar.gz"), + sha256 = "7f00b3e94bbca1a4737ded6b9ed5358f6d1c86430c2ec97c90081343c0482f18", + strip_prefix = "rules_ml_toolchain-29d54c875da37e74b8548924ed30e78cb28126b9", + urls = tf_mirror_urls("https://github.com/google-ml-infra/rules_ml_toolchain/archive/29d54c875da37e74b8548924ed30e78cb28126b9.tar.gz"), ) load( From ac2c547ec89479638defb4fb1f4c6395c4aa0972 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 5 Dec 2025 08:49:00 -0800 Subject: [PATCH 077/315] Integrate LLVM at llvm/llvm-project@ac66ae45cd22 Updates LLVM usage to match [ac66ae45cd22](https://github.com/llvm/llvm-project/commit/ac66ae45cd22) PiperOrigin-RevId: 840746128 --- jax/experimental/mosaic/gpu/dialect_lowering.py | 2 +- jax/experimental/mosaic/gpu/launch_context.py | 5 +++-- jax/experimental/mosaic/gpu/utils.py | 4 +++- jaxlib/mosaic/gpu/serde.cc | 13 ++++++++----- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 24ea841e7cfb..6a54b8fac3de 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -1339,7 +1339,7 @@ def _mgpu_arrive_expect_tx_op_lowering_rule( barrier = utils.DialectBarrierRef.from_barrier_memref( arrive_expect_tx_op.barrier ) - nvvm.mbarrier_arrive_expect_tx(barrier.get_ptr(), bytes) + nvvm.mbarrier_arrive_expect_tx(None, barrier.get_ptr(), bytes) return [] diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 0ab01a6de37f..9f03af257fca 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -1158,6 +1158,7 @@ def async_copy( if arrive: arrive_predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP) nvvm.mbarrier_arrive_expect_tx( + None, barrier_ptr, transfer_bytes, predicate=arrive_predicate, @@ -1289,7 +1290,7 @@ def async_copy( ) arrive_predicate = arith.andi(predicate, first_block) nvvm.mbarrier_arrive_expect_tx( - barrier_ptr, transfer_bytes, predicate=arrive_predicate + None, barrier_ptr, transfer_bytes, predicate=arrive_predicate ) rank = len(slice_shape) idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank)) @@ -1310,7 +1311,7 @@ def async_copy( else: if arrive: nvvm.mbarrier_arrive_expect_tx( - barrier_ptr, transfer_bytes, predicate=predicate + None, barrier_ptr, transfer_bytes, predicate=predicate ) if collective_size > 1: multicast_mask = arith.trunci( diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index a027ac441e39..5b0fe7a00291 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1083,7 +1083,9 @@ def arrive_expect_tx( elif ir.IndexType.isinstance(bytes.type): i32 = ir.IntegerType.get_signless(32) bytes = arith.index_cast(i32, bytes) - nvvm.mbarrier_arrive_expect_tx(self.get_ptr(), bytes, predicate=predicate) + nvvm.mbarrier_arrive_expect_tx( + None, self.get_ptr(), bytes, predicate=predicate + ) def get_ptr(self): i64 = ir.IntegerType.get_signless(64) diff --git a/jaxlib/mosaic/gpu/serde.cc b/jaxlib/mosaic/gpu/serde.cc index aae46b2aeee2..3330199a986f 100644 --- a/jaxlib/mosaic/gpu/serde.cc +++ b/jaxlib/mosaic/gpu/serde.cc @@ -195,13 +195,16 @@ LogicalResult nvvm_mbarrier_try_wait_parity_shared_upgrade(Operation* op, LogicalResult nvvm_mbarrier_arrive_expect_tx_shared_upgrade(Operation* op, int version, bool& erased) { - // https://github.com/llvm/llvm-project/commit/7eeae8e41d7827d84de12df7b5ecfab3058900cb + // https://github.com/llvm/llvm-project/commit/fddf7b0510e5df7a08c512a177ea9c1ec4307718 if (version < 6) { - mlir::OpBuilder b(op->getParentRegion()); + mlir::ImplicitLocOpBuilder b(op->getLoc(), op->getParentRegion()); b.setInsertionPointAfter(op); - mlir::NVVM::MBarrierArriveExpectTxOp::create( - b, op->getLoc(), op->getOperand(0), op->getOperand(1), - op->getNumOperands() < 3 ? Value{} : op->getOperand(2)); + auto new_op = mlir::NVVM::MBarrierArriveExpectTxOp::create( + b, op->getResultTypes(), op->getOperand(0), op->getOperand(1), + mlir::NVVM::MemScopeKind::CTA, + /*relaxed=*/false, + op->getNumOperands() < 3 ? mlir::Value{} : op->getOperand(2)); + op->replaceAllUsesWith(new_op); op->erase(); erased = true; } From 8d6d5716d3da35ffdfd748db6d5367908a0f4ff4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 5 Dec 2025 09:04:59 -0800 Subject: [PATCH 078/315] Raise an error if an empty mesh is passed to jit in/out_shardings PiperOrigin-RevId: 840752040 --- jax/_src/sharding_impls.py | 4 +++- tests/pjit_test.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 3f039d862672..ded173ffb653 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -472,7 +472,6 @@ def get_replicated(cls, device_assignment, *, memory_kind: str | None = None): def prepare_axis_resources(axis_resources, arg_name, allow_unconstrained_dims=False): - # PyTrees don't treat None values as leaves, so we use an is_leaf function. entries, treedef = tree_util.tree_flatten( axis_resources, is_leaf=lambda x: x is None) what = f"{arg_name} leaf specifications" @@ -485,6 +484,9 @@ def prepare_axis_resources(axis_resources, arg_name, if isinstance(entry, PmapSharding): raise ValueError(f'One of {what} got sharding {entry} which is not ' 'allowed.') + if isinstance(entry, NamedSharding) and entry.mesh.empty: + raise ValueError(f'One of {what} got an empty NamedSharding: {entry} ' + 'which is not allowed.') if (not allow_unconstrained_dims and isinstance(entry, NamedSharding) and PartitionSpec.UNCONSTRAINED in entry.spec): raise ValueError( diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b1b2a9b7b81f..4b5acdf0aabe 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1544,6 +1544,11 @@ def test_pjit_array_single_output_with_mesh_context_manager( self.assertArraysEqual(s.data, expected_matrix_mul[s.index]) self.assertArraysEqual(out._value, expected_matrix_mul) + def test_empty_mesh_to_out_sharding(self): + sharding = jax.NamedSharding(mesh_lib.empty_concrete_mesh, P()) + with self.assertRaisesRegex(ValueError, "got an empty NamedSharding"): + jax.jit(lambda x: x, out_shardings=sharding)(jnp.ones((32,))) + def test_numpy_array_input_assume_fully_replicated(self): input_shape = (8, 2) global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) From 1bab3be4db019e66b6a329c93398b9d6073b1914 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 5 Dec 2025 09:21:13 -0800 Subject: [PATCH 079/315] Clean up get_array_mapping duplication PiperOrigin-RevId: 840758280 --- jax/_src/interpreters/pxla.py | 10 ++-------- jax/experimental/multihost_utils.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 228229ac40d2..bc4407798945 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -67,10 +67,8 @@ from jax._src.mesh import (AbstractMesh, Mesh, get_abstract_mesh, get_concrete_mesh) from jax._src.sharding_impls import ( - ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UnspecifiedValue, - get_array_mapping as _get_array_mapping, array_mapping_to_axis_resources, - SingleDeviceSharding, GSPMDSharding, NamedSharding, - PartitionSpec as P) + ArrayMapping, AUTO, UnspecifiedValue, array_mapping_to_axis_resources, + SingleDeviceSharding, GSPMDSharding, NamedSharding, PartitionSpec as P) from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name, tuple_update, tuple_delete, distributed_debug_log, unzip2, HashableFunction, weakref_lru_cache, @@ -3420,7 +3418,3 @@ def batch_spec(spec, dim, val): spec += (None,) * too_short new_partitions = tuple_insert(spec, dim, val) # type: ignore return PartitionSpec(*new_partitions) - -def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: - pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") - return _get_array_mapping(pspec) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 1acb02e5b01f..ee7e9509ea3d 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -123,8 +123,9 @@ def _handle_array_process_allgather(inp, tiled): host_np_arr = np.expand_dims(host_np_arr, axis=0) aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype) + pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") global_aval = pxla.mesh_local_to_global( - global_mesh, pxla.get_array_mapping(pspec), aval) + global_mesh, sharding_impls.get_array_mapping(pspec), aval) bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( @@ -236,13 +237,15 @@ def _flatten_pspecs(name, in_tree, pspecs_thunk): @lru_cache def _local_to_global_aval(local_aval, mesh, pspec): - return pxla.mesh_local_to_global(mesh, pxla.get_array_mapping(pspec), - local_aval) + pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") + return pxla.mesh_local_to_global( + mesh, sharding_impls.get_array_mapping(pspec), local_aval) @lru_cache def _global_to_local_aval(global_aval, mesh, pspec): - return pxla.mesh_global_to_local(mesh, pxla.get_array_mapping(pspec), - global_aval) + pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping") + return pxla.mesh_global_to_local( + mesh, sharding_impls.get_array_mapping(pspec), global_aval) def host_local_array_to_global_array_impl( From a6fa073337a160653fac3a92ca3580e660e38da2 Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 5 Dec 2025 16:33:08 +0000 Subject: [PATCH 080/315] [no-thunks] Avoid thunking/linear_util in checkify --- jax/_src/checkify.py | 61 ++++++++++++++------------- jax/_src/interpreters/partial_eval.py | 11 ++--- tests/debug_info_test.py | 1 + 3 files changed, 38 insertions(+), 35 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 1fc01e00cbe6..3b76ea8724de 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -753,20 +753,27 @@ def scatter_error_check(prim, error, enabled_errors, operand, indices, updates, # HOP error check rules +@jtu.register_static +class ErrorEffects: + def __init__(self, val): + self.val = val + @weakref_lru_cache def jaxpr_to_checkify_jaxpr( jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef, *flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]: - checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr, - jaxpr.consts, enabled_errors, - err_tree) - fun = lu.wrap_init(checkify_jaxpr_partial, - debug_info=jaxpr.jaxpr.debug_info.with_unknown_names()) - fun, metadata = _flatten_and_get_error_metadata_thunk(fun) - - new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals) - checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts) - out_tree, error_effects = metadata() + + def fun_wrapped(*invals): + error, out = checkify_jaxpr_flat( + jaxpr.jaxpr, jaxpr.consts, enabled_errors, err_tree, *invals) + error_effects = ErrorEffects(set(error._pred.keys())) + return (error, out), error_effects + + debug_info = jaxpr.jaxpr.debug_info.with_unknown_names() + checked_jaxpr, full_out_tree = pe.trace_to_jaxpr( + fun_wrapped, None, flat_err_and_in_vals, debug_info) + out_tree, error_effects_treedef = full_out_tree.children() + error_effects = error_effects_treedef.unflatten(()).val return checked_jaxpr, out_tree, error_effects def cond_error_check(error: Error, enabled_errors, index, *ops, @@ -848,18 +855,16 @@ def new_body_f(*c_consts_and_vals): # This checks if the next cond application will error lax.dce_sink(cond_f(*c_consts, *out)) return out - new_body_f_ = lu.wrap_init( - new_body_f, - debug_info=body_jaxpr.jaxpr.debug_info.with_unknown_names()) c_consts_avals = cond_jaxpr.in_avals[:c_consts_num] - jaxpr, _, () = pe.trace_to_jaxpr_dynamic( - new_body_f_, [*c_consts_avals, *body_jaxpr.in_avals]) - closed_jaxpr = pe.close_jaxpr(jaxpr) + jaxpr, _ = pe.trace_to_jaxpr( + new_body_f, None, + (*c_consts_avals, *body_jaxpr.in_avals), + debug_info=body_jaxpr.jaxpr.debug_info.with_unknown_names()) err_vals, err_tree = jtu.tree_flatten(error) err_vals = map(core.get_aval, err_vals) flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals] jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr( - closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) + jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals) return jaxpr, out_tree, error_effects @@ -1004,12 +1009,10 @@ def expand_errors_leading_dim(*xs): return *errs, *outs with core.extend_axis_env_nd(mesh.shape.items()), config._check_vma(check_vma): - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(expand_errors_leading_dim, - debug_info=checked_jaxpr.jaxpr.debug_info), - checked_jaxpr.in_avals - ) - checked_jaxpr = core.ClosedJaxpr(jaxpr, consts) + checked_jaxpr, _ = pe.trace_to_jaxpr( + expand_errors_leading_dim, None, + tuple(checked_jaxpr.in_avals), + debug_info=checked_jaxpr.jaxpr.debug_info) # Update shard_map params to account for extra error values. # Use fully sharded partitioning for out errors. @@ -1235,17 +1238,15 @@ def checkify(f: Callable[..., Out], @traceback_util.api_boundary def checked_fun(*args, **kwargs): # close over all arguments so they're not turned into abstract values. - in_tree = jtu.tree_structure(((), {})) + in_tree = jtu.tree_structure(()) closed_f = lambda: f(*args, **kwargs) # stage: - debug = api_util.debug_info("checkify", f, args, kwargs) - fun_, out_tree = api_util.flatten_fun( - lu.wrap_init(closed_f, debug_info=debug.with_unknown_names()), in_tree) - jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, ()) - jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_)) + debug_info = api_util.debug_info("checkify", f, args, kwargs).with_unknown_names() + jaxpr_, out_tree = pe.trace_to_jaxpr(closed_f, in_tree, (), debug_info) + jaxpr, consts = pe.separate_consts(jaxpr_) # checkify: error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts) - return error, jtu.tree_unflatten(out_tree(), out_flat) + return error, jtu.tree_unflatten(out_tree, out_flat) return checked_fun def check(pred: Bool, msg: str, diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 37dbd8dbadfa..cbb759c43bf4 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2402,8 +2402,8 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): @weakref_lru_cache def trace_to_jaxpr( fun: Callable, - in_tree: PyTreeDef, - in_avals_flat: Sequence[AbstractValue | core.AvalQDD], + in_tree: PyTreeDef | None, + in_avals_flat: tuple[AbstractValue | core.AvalQDD, ...], debug_info: core.DebugInfo ) -> tuple[ClosedJaxpr, PyTreeDef]: config.enable_checks.value and debug_info.assert_arg_names(len(in_avals_flat)) @@ -2413,10 +2413,11 @@ def trace_to_jaxpr( # rooted at the enclosing jaxpr. with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): source_info = source_info_util.current() - in_tracers_flat = map(partial(trace.new_arg, source_info=source_info), + in_tracers = map(partial(trace.new_arg, source_info=source_info), in_avals_flat) with core.set_current_trace(trace): - in_tracers = tree_unflatten(in_tree, in_tracers_flat) + if in_tree is not None: + in_tracers = tree_unflatten(in_tree, in_tracers) ans = fun(*in_tracers) debug_info = debug_info.set_result_paths(ans) ans_flat, out_tree = tree_flatten(ans) @@ -2426,7 +2427,7 @@ def trace_to_jaxpr( _check_no_returned_refs(debug_info, out_tracers) jaxpr, consts = trace.frame.to_jaxpr(trace, out_tracers, debug_info, source_info) - del trace, fun, in_tracers_flat, in_tracers, out_tracers, ans, ans_flat + del trace, fun, in_tracers, out_tracers, ans, ans_flat config.enable_checks.value and core.check_jaxpr(jaxpr) return ClosedJaxpr(jaxpr, consts), out_tree diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 719d4a75b115..2936bb9ae743 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -1869,6 +1869,7 @@ def my_f(x, y): self.assertEqual(res[0][1], "from the argument x") self.assertRegex(res[1][1], r"named 'foo' from .*debug_info_test.py:.*my_f") + @unittest.skip("Test fails during no-thunks rewrite") def test_checkify_pmap_basic(self): if len(jax.devices()) < 2: self.skipTest("requires at least 2 devices") From d5810d061ea197041629793b0fa2cb92d784c00b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 5 Dec 2025 10:04:20 -0800 Subject: [PATCH 081/315] Support multiple MLIR Python binding versions The nvvm.mbarrier_arrive_expect_tx signature has recently changed. PiperOrigin-RevId: 840773612 --- jax/experimental/mosaic/gpu/dialect_lowering.py | 2 +- jax/experimental/mosaic/gpu/launch_context.py | 11 +++++------ jax/experimental/mosaic/gpu/utils.py | 11 +++++++++-- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 6a54b8fac3de..33e8b5fdd5e2 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -1339,7 +1339,7 @@ def _mgpu_arrive_expect_tx_op_lowering_rule( barrier = utils.DialectBarrierRef.from_barrier_memref( arrive_expect_tx_op.barrier ) - nvvm.mbarrier_arrive_expect_tx(None, barrier.get_ptr(), bytes) + utils.nvvm_mbarrier_arrive_expect_tx(barrier.get_ptr(), bytes) return [] diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 9f03af257fca..d03410bbd3a2 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -1157,8 +1157,7 @@ def async_copy( if arrive: arrive_predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP) - nvvm.mbarrier_arrive_expect_tx( - None, + utils.nvvm_mbarrier_arrive_expect_tx( barrier_ptr, transfer_bytes, predicate=arrive_predicate, @@ -1289,8 +1288,8 @@ def async_copy( arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index), ) arrive_predicate = arith.andi(predicate, first_block) - nvvm.mbarrier_arrive_expect_tx( - None, barrier_ptr, transfer_bytes, predicate=arrive_predicate + utils.nvvm_mbarrier_arrive_expect_tx( + barrier_ptr, transfer_bytes, predicate=arrive_predicate ) rank = len(slice_shape) idx_operands = ",".join(f"${i}" for i in range(4, 4 + rank)) @@ -1310,8 +1309,8 @@ def async_copy( ) else: if arrive: - nvvm.mbarrier_arrive_expect_tx( - None, barrier_ptr, transfer_bytes, predicate=predicate + utils.nvvm_mbarrier_arrive_expect_tx( + barrier_ptr, transfer_bytes, predicate=predicate ) if collective_size > 1: multicast_mask = arith.trunci( diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 5b0fe7a00291..c06555242740 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1083,8 +1083,8 @@ def arrive_expect_tx( elif ir.IndexType.isinstance(bytes.type): i32 = ir.IntegerType.get_signless(32) bytes = arith.index_cast(i32, bytes) - nvvm.mbarrier_arrive_expect_tx( - None, self.get_ptr(), bytes, predicate=predicate + nvvm_mbarrier_arrive_expect_tx( + self.get_ptr(), bytes, predicate=predicate ) def get_ptr(self): @@ -2000,3 +2000,10 @@ def nanosleep(nanos: ir.Value): "r", has_side_effects=True, ) + + +def nvvm_mbarrier_arrive_expect_tx(barrier: ir.Value, expect_tx: ir.Value, predicate: ir.Value | None = None): + try: + return nvvm.mbarrier_arrive_expect_tx(None, barrier, expect_tx, predicate=predicate) # type: ignore + except TypeError: + return nvvm.mbarrier_arrive_expect_tx(barrier, expect_tx, predicate=predicate) # pytype: disable=missing-parameter From c7d36f4b0473664dc7a589ef0d63de3e76037452 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 4 Dec 2025 22:12:20 +0000 Subject: [PATCH 082/315] [debug-info] de-thunk DebugInfo.result_paths It breaks the test mechanism more than the actual results. Will fix in follow-up once we finish de-thunkifying. Co-authored-by: Dougal Maclaurin Co-authored-by: Yash Katariya --- jax/_src/api_util.py | 19 +++++++++++++++---- jax/_src/linear_util.py | 24 ++++-------------------- jax/_src/pjit.py | 22 ++++++++++++++-------- tests/api_test.py | 12 ++++++------ tests/debug_info_test.py | 1 + tests/mutable_array_test.py | 4 ++++ tests/pjit_test.py | 1 + 7 files changed, 45 insertions(+), 38 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 77e9e1c11d72..4f8a378a14f2 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -27,16 +27,17 @@ from jax._src.state.types import AbstractRef from jax._src.tree_util import ( PyTreeDef, tree_flatten, tree_unflatten, treedef_children, - generate_key_paths, broadcast_prefix, prefix_errors, none_leaf_registry, - broadcast_flattened_prefix_with_treedef) + tree_flatten_with_path, generate_key_paths, broadcast_prefix, prefix_errors, + none_leaf_registry, broadcast_flattened_prefix_with_treedef) from jax._src import linear_util as lu from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction, - Unhashable, safe_zip as zip) + Unhashable, safe_zip, unzip2) from jax._src import traceback_util traceback_util.register_exclusion(__file__) -map = safe_map +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip def _ensure_index(x: Any) -> int | tuple[int, ...]: """Ensure x is either an index or a tuple of indices.""" @@ -75,6 +76,16 @@ def flatten_fun(f: Callable, store: lu.Store, store.store(out_tree) return ans +@lu.transformation_with_aux2 +def flatten_fun3(f: Callable, store: lu.Store, + in_tree: PyTreeDef, *args_flat): + py_args, py_kwargs = tree_unflatten(in_tree, args_flat) + ans = f(*py_args, **py_kwargs) + paths_and_ans, out_tree = tree_flatten_with_path(ans) + paths, ans = unzip2(paths_and_ans) + store.store((out_tree, paths)) + return ans + def apply_flat_fun(fun, io_tree, *py_args): in_tree_expected, out_tree = io_tree args, in_tree = tree_flatten((py_args, {})) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 7867286257e0..5433e9769698 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -77,7 +77,7 @@ def trans1(static_arg, *dynamic_args, **kwargs): from jax._src import core from jax._src import traceback_util from jax._src.tree_util import KeyPath, generate_key_paths, keystr -from jax._src.util import HashableFunction, curry, fun_name, register_cache +from jax._src.util import curry, fun_name, register_cache traceback_util.register_exclusion(__file__) @@ -405,13 +405,8 @@ def wrap_init(f: Callable, params=None, *, debug_info: DebugInfo) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) + debug_info = debug_info._replace(result_paths=None) fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info) - if debug_info.result_paths is initial_result_paths: - fun, result_paths_thunk = _get_result_paths_thunk(fun) - debug_info = debug_info._replace( - result_paths=HashableFunction(result_paths_thunk, closure=())) - fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores, - fun.params, fun.in_type, debug_info) return fun @@ -421,24 +416,13 @@ def _clean_keystr_arg_names(k: KeyPath) -> str: res = keystr(k) return _re_clean_keystr_arg_names.sub(r"\1", res) -@transformation_with_aux2 -def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs): - ans = _fun(*args, **kwargs) - result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans)) - if _store: - # In some instances a lu.WrappedFun is called multiple times, e.g., - # the bwd function in a custom_vjp - assert _store.val == result_paths, (_store, result_paths) - else: - _store.store(result_paths) - return ans - def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: assert f.in_type is None if in_type is None: return f _check_input_type(in_type) - return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, + in_type, f.debug_info) def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 54aedc578248..1e9f390702d4 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -46,10 +46,9 @@ from jax._src import xla_bridge as xb from jax._src.core import typeof, cur_qdd from jax._src.api_util import ( - argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, - donation_vector, check_callable, resolve_argnums, - argnames_partial_except, debug_info, check_no_aliased_ref_args, - _check_no_aliased_closed_over_refs) + argnums_partial_except, flatten_axes, flatten_fun3, flatten_fun_nokwargs, + donation_vector, check_callable, resolve_argnums, argnames_partial_except, + debug_info, check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec from jax._src.interpreters import ad @@ -500,7 +499,7 @@ def _infer_params_impl( del kwargs explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs)) - flat_fun, out_tree = flatten_fun(f, in_tree) + flat_fun, out_tree_and_result_paths = flatten_fun3(f, in_tree) if (ji.donate_argnums or ji.donate_argnames) and not config.debug_nans.value: donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames, in_tree) @@ -550,6 +549,7 @@ def _infer_params_impl( jaxpr, consts, out_avals = _create_pjit_jaxpr( flat_fun, in_type, qdd_token, IgnoreKey(ji.inline)) + if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), explicit_args) _qdd_cache_update(flat_fun, in_type, qdd_token, consts, @@ -557,7 +557,7 @@ def _infer_params_impl( out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef, - ji.out_layouts_leaves, HashableFunction(out_tree, closure=()), + ji.out_layouts_leaves, HashableFunction(lambda: out_tree_and_result_paths()[0], closure=()), tuple(out_avals), jaxpr.jaxpr._debug_info, device_or_backend_set) assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat) @@ -576,6 +576,11 @@ def _infer_params_impl( assert (len(in_shardings_flat) == len(in_layouts_flat) == len(donated_invars) == len(consts) + len(args_flat)) + out_tree, result_paths = out_tree_and_result_paths() + result_paths = tuple(f"result{lu._clean_keystr_arg_names(path)}" + for path in result_paths) + jaxpr.jaxpr._debug_info = jaxpr.debug_info._replace(result_paths=result_paths) + params = dict( jaxpr=jaxpr, in_shardings=in_shardings_flat, @@ -589,8 +594,9 @@ def _infer_params_impl( inline=ji.inline, compiler_options_kvs=ji.compiler_options_kvs, ) + return (PjitParams(consts, params, in_avals, - in_tree, out_tree(), dbg.safe_arg_names(len(in_avals))), + in_tree, out_tree, dbg.safe_arg_names(len(in_avals))), args_flat) @@ -1051,7 +1057,7 @@ def arg_type_to_str(at): if t[0] != ot[0]: unavailable(f"fun_transforms[{i}] transform", t, ot) continue - if t_name == "flatten_fun": + if t_name == "flatten_fun3": explain_in_tree_diff(t[1][0], ot[1][0]) continue if t_name == "_argnums_partial": diff --git a/tests/api_test.py b/tests/api_test.py index 1d9bd6339d2d..b55f844f033f 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7970,7 +7970,7 @@ def f(): def test_jit_traceback(self): # TODO(dougalm): improve this! jit can (and should) be nested a lot. - expected_depth = 14 + expected_depth = 13 init_depth = self.cur_depth() @jit def foo(x): @@ -7980,7 +7980,7 @@ def foo(x): def test_grad_traceback(self): # TODO(dougalm): improve this - expected_depth = 13 + expected_depth = 12 init_depth = self.cur_depth() def foo(x): @@ -7991,7 +7991,7 @@ def foo(x): def test_vmap_traceback(self): # TODO(dougalm): improve this - expected_depth = 8 + expected_depth = 7 init_depth = self.cur_depth() def foo(x): @@ -8002,9 +8002,9 @@ def foo(x): def test_custom_vjp_traceback(self): # TODO(dougalm): improve this - expected_depth_f = 11 - expected_depth_f_fwd = 22 - expected_depth_f_rev = 13 + expected_depth_f = 10 + expected_depth_f_fwd = 20 + expected_depth_f_rev = 12 init_depth = self.cur_depth() @jax.custom_vjp def f(x): diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 719d4a75b115..b151e24bf47f 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -115,6 +115,7 @@ def append(self, t: Any) -> None: @jtu.with_config(jax_mutable_array_checks=True) +@unittest.skip("WIP") class DebugInfoTest(jtu.JaxTestCase): def _check_tracers_and_jaxprs(self, traceable: Any, diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index fbd9a0cc3eb7..2d03c1866bbf 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -14,6 +14,8 @@ from __future__ import annotations +import unittest + from absl.testing import absltest from absl.testing import parameterized from functools import partial @@ -1045,12 +1047,14 @@ def test_return_from_jit_arg(self): r".*was passed in as the argument x_ref"): jax.jit(lambda x_ref: x_ref)(core.new_ref(jnp.arange(3))) + @unittest.skip("regressed") # TODO(mattjj): fix def test_return_from_jit_pytree(self): with self.assertRaisesRegex( ValueError, r"tree path result\['hi'\]"): jax.jit(lambda x_ref: {'hi': x_ref})(core.new_ref(jnp.arange(3))) + @unittest.skip("regressed") # TODO(mattjj): fix def test_return_from_jit_closure(self): with self.assertRaisesRegex( ValueError, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8fd2253d47ab..ca13d0081f4c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -9826,6 +9826,7 @@ def testNonDivisibleArgs(self, mesh, resources): with self.assertRaisesRegex(ValueError, error): pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x) + @unittest.skip("regressed") # TODO(mattjj): fix test @check_1d_2d_mesh(set_mesh=True) def testNonDivisibleOuts(self, mesh, resources): x = jnp.ones((3, 2)) From bd63099996e61fc097ed9e6a044cd56b895bcb61 Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Fri, 5 Dec 2025 12:58:51 -0800 Subject: [PATCH 083/315] [Mosaic TPU] Support 1d tiling for packed dtypes when transposing major/minor dims. PiperOrigin-RevId: 840842151 --- tests/pallas/tpu_pallas_test.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 7923bfa27f82..551ee97db9d5 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -3731,6 +3731,32 @@ def kernel(x_ref, out_ref): out, np.zeros((8, 8, 2, 128), dtype=jnp.float32) ) + @parameterized.parameters( + (3, 1, 2048, jnp.bfloat16), + (5, 1, 4096, jnp.int8), + ) + def test_1d_tiling_major_minor_transpose(self, q, m, n, dtype): + if not jtu.is_cloud_tpu_at_least(2025, 12, 10): + self.skipTest('Needs a newer libTPU') + + in_shape = (q, n) + mid_shape = (q, m, n) + out_shape = (m, q, n) + x = np.arange(np.prod(in_shape), dtype=dtype).reshape(in_shape) + + def kernel(x_ref, o_ref): + x = x_ref[...] + x = jnp.reshape(x, mid_shape) + o_ref[...] = jnp.transpose(x, axes=(1, 0, 2)) + + result = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(out_shape, dtype), + )(x) + np.testing.assert_array_equal( + result, np.transpose(x.reshape(mid_shape), axes=(1, 0, 2)) + ) + # (q, m, n) -> (q, m * n) where n % 128 == 0 @parameterized.parameters( (q, m, n, dtype) From 3a83fa48995650b659b9200a325b7e551640958b Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 5 Dec 2025 13:07:29 -0800 Subject: [PATCH 084/315] Re-enable TPU tests now that TPU thread stack sizes have been increased. Reverts 4d9ff5bd390d2695d3e0a6ae3c57f0f0e826efd1 PiperOrigin-RevId: 840845684 --- tests/multiprocess/pjit_test.py | 3 --- tests/pallas/ops_test.py | 3 --- tests/pjit_test.py | 3 --- tests/python_callback_test.py | 22 ---------------------- 4 files changed, 31 deletions(-) diff --git a/tests/multiprocess/pjit_test.py b/tests/multiprocess/pjit_test.py index c25e2998ecb3..79c0721ab66b 100644 --- a/tests/multiprocess/pjit_test.py +++ b/tests/multiprocess/pjit_test.py @@ -527,9 +527,6 @@ def f(x): self.assertEqual(output(), "") def test_print_in_multihost_shard_map(self): - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - devices = jax.devices() mesh = jax.sharding.Mesh(devices, ("i",)) num_devices = jax.local_device_count() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 93ea67c563e3..c345a47f9876 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -2756,9 +2756,6 @@ class OpsInterpretTest(OpsTest): INTERPRET = True def test_debug_print(self): - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7937cf75089d..0de679d43025 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4247,9 +4247,6 @@ def test_in_out_shardings_unconstrained_error(self): in_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'x'))) def test_empty_io_callback_under_shard_map(self): - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - mesh = jtu.create_mesh((4,), 'i') def empty_callback(x): diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 128632a992d3..3a70b08ea912 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -53,7 +53,6 @@ ) -@unittest.skipIf(jtu.is_cloud_tpu(), "TODO: b/465504705") class PythonCallbackTest(jtu.JaxTestCase): def setUp(self): @@ -670,7 +669,6 @@ def f(x): np.testing.assert_array_equal(x, result) -@unittest.skipIf(jtu.is_cloud_tpu(), "TODO: b/465504705") class PureCallbackTest(jtu.JaxTestCase): def setUp(self): @@ -1152,9 +1150,6 @@ def tearDown(self): dispatch.runtime_tokens.clear() def test_io_callback_can_mutate_state(self): - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - x = 0 def cb(): nonlocal x @@ -1171,9 +1166,6 @@ def f(): self.assertEqual(x, 2) def test_io_callback_can_be_batched_if_unordered(self): - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - _mut = 0 def cb(x): nonlocal _mut @@ -1282,9 +1274,6 @@ def f(x, y): def test_can_use_io_callback_in_pjit( self, *, ordered: bool, with_sharding: bool ): - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - devices = jax.devices() mesh = jax.sharding.Mesh(np.array(devices), ['dev']) @@ -1345,9 +1334,6 @@ def f(x): @jtu.ignore_warning(message='.*Please use `jax.jit` instead.*', category=DeprecationWarning) def test_sequence_pjit_io_callback_ordered(self): - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - if jtu.is_device_tpu(7, 'x'): self.skipTest('TODO(b/453664256): Failing on TPU 7x.') @@ -1409,8 +1395,6 @@ def f_base(i, x): single_device=True) ) def test_can_shard_io_callback_manually(self, single_device: bool): - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") devices = jax.devices() if single_device: @@ -1445,9 +1429,6 @@ def f(shard_ids, x): def test_batching_with_side_effects(self): # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195 - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - x_lst = [] def append_x(x): nonlocal x_lst @@ -1464,9 +1445,6 @@ def f(x): def test_batching_with_side_effects_while_loop(self): # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219 - if jtu.is_cloud_tpu(): - self.skipTest("TODO: b/465504705") - x_lst = [] def append_x(x): nonlocal x_lst From c71327c925a7a1873ff68109bee463fe1df729df Mon Sep 17 00:00:00 2001 From: Yashwant Bezawada Date: Wed, 3 Dec 2025 16:04:27 -0600 Subject: [PATCH 085/315] Fix constant folding bug in jnp.arange for non-zero start When jnp.arange is called with a non-zero start value (e.g., jnp.arange(1, n)), the array was being folded into a constant instead of using the iota primitive. This caused compilation time to explode for large arrays. Extended the optimization to handle any start value by computing the array size as ceil(stop - start) and using iota with an offset added when start is non-zero. Also adds a check to skip the optimization when start or stop is complex, since ceil doesn't support complex numbers. Test coverage includes: - testArangeJaxprNonZeroStart: Verifies non-zero start uses iota + add - testArangeRandomValues: Parameterized randomized testing for int32/float32 dtypes - testArangeComplex: Ensures complex arguments match NumPy behavior Fixes #32542 --- jax/_src/numpy/lax_numpy.py | 20 +++++++++++++++++--- tests/lax_numpy_test.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 87a3bb327f2f..338e5b6ab164 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5989,9 +5989,23 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, start = ceil_(start).astype(int) return lax.broadcasted_iota(dtype, (start,), 0, out_sharding=out_sharding) # type: ignore[arg-type] else: - if step is None and start == 0 and stop is not None: - return lax.broadcasted_iota(dtype, (np.ceil(stop).astype(int),), 0, - out_sharding=out_sharding) + if step is None and stop is not None: + # Skip optimization if start or stop is complex (ceil doesn't support complex) + start_dtype = _dtype(start) + stop_dtype = _dtype(stop) + if (dtypes.issubdtype(start_dtype, np.complexfloating) or + dtypes.issubdtype(stop_dtype, np.complexfloating)): + return array(np.arange(start, stop=stop, step=step, dtype=dtype), + device=out_sharding) + # Use iota + offset instead of creating a constant array + size = int(np.ceil(stop - start)) + if size <= 0: + return array([], dtype=dtype, device=out_sharding) + result = lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding) + if start != 0: + # Add offset if start is non-zero + result = lax.add(result, lax.convert_element_type(start, dtype)) + return result return array(np.arange(start, stop=stop, step=step, dtype=dtype), device=out_sharding) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 2ba0e3e7ec75..3a813440f554 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -4869,6 +4869,42 @@ def testArangeJaxpr(self, args, specify_device): self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs) self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) + @jtu.sample_product(specify_device=[True, False]) + def testArangeJaxprNonZeroStart(self, specify_device): + device = jax.devices()[-1] if specify_device else None + jaxpr = jax.make_jaxpr(lambda: jnp.arange(1, 5, device=device))() + # Non-zero start should produce iota + add (+ device_put if device specified) + num_eqs = 3 if device is not None else 2 + self.assertEqual(len(jaxpr.jaxpr.eqns), num_eqs) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) + self.assertEqual(jaxpr.jaxpr.eqns[1].primitive, lax.add_p) + + @jtu.sample_product( + dtype=[np.int32, np.float32], + iteration=range(10) + ) + def testArangeRandomValues(self, dtype, iteration): + del iteration # not needed: each test case gets its own random seed. + rng = jtu.rand_default(self.rng()) + start = rng((), dtype) + stop = rng((), dtype) + jax_result = jnp.arange(start, stop, dtype=dtype) + np_result = np.arange(start, stop, dtype=dtype) + self.assertAllClose(jax_result, np_result) + + def testArangeComplex(self): + test_cases = [ + (1+2j, 5+3j), + (0+0j, 5+0j), + (1.0+0j, 5.0+0j), + (0, 5, 1+1j), + ] + for args in test_cases: + with self.subTest(args=args): + jax_result = jnp.arange(*args) + np_result = np.arange(*args) + self.assertArraysEqual(jax_result, np_result) + def testIssue830(self): a = jnp.arange(4, dtype=jnp.complex64) self.assertEqual(a.dtype, jnp.complex64) From 1b53ea1be11bc4251c342986c76621cd23982df3 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 5 Dec 2025 20:11:41 -0800 Subject: [PATCH 086/315] Fix c64 -> f32 .view where a new zeros array is created without the correct sharding. Fixes https://github.com/jax-ml/jax/issues/33787 PiperOrigin-RevId: 840976307 --- jax/_src/numpy/array_methods.py | 7 ++++--- tests/pjit_test.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 320b95548fb4..e713ee8c53ed 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -557,9 +557,10 @@ def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Arr if lax_numpy.issubdtype(self.dtype, np.complexfloating): new_shape = (*self.shape[:-1], self.shape[-1] * 2) new_dtype = lax_numpy.finfo(self.dtype).dtype - self = (array_creation.zeros(new_shape, new_dtype) - .at[..., 0::2].set(self.real) - .at[..., 1::2].set(self.imag)) + new_sharding = core.typeof(self).sharding + self = (array_creation.zeros(new_shape, new_dtype, out_sharding=new_sharding) + .at[..., 0::2].set(self.real) + .at[..., 1::2].set(self.imag)) return _view(self, dtype) if dtype == bool: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0de679d43025..b55beb335dd3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -9811,6 +9811,16 @@ def f(x): NamedSharding(mesh.abstract_mesh, P('x'))) compiled(arr) # doesn't crash + @jtu.with_explicit_mesh((2,), 'x') + def test_c64_to_f32_view_rountrip(self, mesh): + x = jnp.zeros((128, 64), dtype=jnp.complex64, out_sharding=P(('x'))) + y = jax.jit(lambda _x: _x.view(jnp.float32))(x) + self.assertEqual(y.sharding, NamedSharding(mesh, P('x', None))) + + x = jnp.zeros((128, 64), dtype=jnp.float32, out_sharding=P(('x'))) + y = jax.jit(lambda _x: _x.view(jnp.complex64))(x) + self.assertEqual(y.sharding, NamedSharding(mesh, P('x', None))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From 1709d4639a3dd73d0f8748090ab9af4af9a84758 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 6 Dec 2025 00:05:14 -0800 Subject: [PATCH 087/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/e03ddca21a32a5edd8cf93a0c2f55a052cd1f66b PiperOrigin-RevId: 841029872 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 0733d88d6e4e..84eeef03d8fd 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "0158b0d5b1911f7e2a8de06d3c1f855d95ca5ec6" -XLA_SHA256 = "18188dd12346c55f043c9617089aae329408ddeb611b035c6852714b679e5d7a" +XLA_COMMIT = "e03ddca21a32a5edd8cf93a0c2f55a052cd1f66b" +XLA_SHA256 = "c66e405a25e8b48bb9cde5bc53cae7fefd46ee48ea149f3f168a9108d8efd75c" From c5e3b96ba4ada994653b2f88ba6c76daa5d4bb95 Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Sat, 6 Dec 2025 15:01:57 -0600 Subject: [PATCH 088/315] update --- jax/_src/cudnn/scaled_matmul_stablehlo.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index cf339d6e5ed5..b6b80748fb24 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -120,7 +120,12 @@ def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type): mlir.register_lowering( _scaled_matmul_p, _scaled_matmul_gpu_lowering, - platform="gpu", + platform="cuda", +) +mlir.register_lowering( + _scaled_matmul_p, + _scaled_matmul_gpu_lowering, + platform="rocm", ) _scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper") From 3aa8a6b0d4de5e554f45db638b0f3056e4c520f1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 7 Dec 2025 00:06:27 -0800 Subject: [PATCH 089/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8ef2c8582a08a2b6aa5c74773421501e45e6be1c PiperOrigin-RevId: 841317625 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 84eeef03d8fd..a597ef4a7667 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "e03ddca21a32a5edd8cf93a0c2f55a052cd1f66b" -XLA_SHA256 = "c66e405a25e8b48bb9cde5bc53cae7fefd46ee48ea149f3f168a9108d8efd75c" +XLA_COMMIT = "8ef2c8582a08a2b6aa5c74773421501e45e6be1c" +XLA_SHA256 = "acf7a420f2e0d5389a5279d6523fb973a78a7d606633ef265d1bbe9234513bf1" From 052ce57b7bf4ab6a4754a72140cc66bba3c424ec Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 8 Dec 2025 00:05:15 -0800 Subject: [PATCH 090/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/f5d6d1aae38dfa3de44c057064ec7609b9a390af PiperOrigin-RevId: 841624372 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index a597ef4a7667..a1890149fb2b 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "8ef2c8582a08a2b6aa5c74773421501e45e6be1c" -XLA_SHA256 = "acf7a420f2e0d5389a5279d6523fb973a78a7d606633ef265d1bbe9234513bf1" +XLA_COMMIT = "f5d6d1aae38dfa3de44c057064ec7609b9a390af" +XLA_SHA256 = "e867f1329105c55f34667c589ee718d10d6de378b026d8d363f13c20a83beb5d" From 128b4e99ccd1b24f965eee8b23238307d65a2f4b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 8 Dec 2025 06:02:51 -0800 Subject: [PATCH 091/315] [Pallas:MGPU] Add support for jnp.sin PiperOrigin-RevId: 841730048 --- jax/_src/pallas/mosaic_gpu/lowering.py | 13 +++++++++++++ jax/experimental/mosaic/gpu/dialect_lowering.py | 1 + jax/experimental/mosaic/gpu/layout_inference.py | 1 + tests/pallas/ops_test.py | 3 ++- 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e23416a566e5..27e9d719e1b2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2302,6 +2302,19 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x, accuracy): ) return math_dialect.exp2(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) +@register_lowering_rule(lax.sin_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.sin_p, mgpu.LoweringSemantics.Warpgroup) +def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") + [x_aval] = ctx.avals_in + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).sin(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.sin(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) + @register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 33e8b5fdd5e2..ab8ed5431f8f 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -1099,6 +1099,7 @@ def _unary_op_lowering_rule( (mlir_math.RsqrtOp, fa.FragmentedArray.rsqrt, None), (mlir_math.ExpOp, fa.FragmentedArray.exp, None), (mlir_math.Exp2Op, fa.FragmentedArray.exp2, None), + (mlir_math.SinOp, fa.FragmentedArray.sin, None), (mlir_math.LogOp, fa.FragmentedArray.log, None), (mlir_math.TanhOp, fa.FragmentedArray.tanh, None), ]: diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 5a2eaa97178b..fdbd34d2806b 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -604,6 +604,7 @@ def _pointwise_op_constraint_system( arith.XOrIOp, mlir_math.ExpOp, mlir_math.Exp2Op, + mlir_math.SinOp, mlir_math.LogOp, mlir_math.RsqrtOp, mlir_math.TanhOp, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index c345a47f9876..7c4b88e0bd71 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1083,7 +1083,8 @@ def kernel(x_ref, o_ref): for fn, dtype in itertools.product(*args) ) def test_elementwise(self, fn, dtype): - self.skip_if_mosaic_gpu() + if fn is not jnp.sin or dtype == "float64": + self.skip_if_mosaic_gpu() if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: self.skipTest("64-bit types require x64_enabled") From 25e182fe2a1caf5cc598a1f9386c6eebf09153bc Mon Sep 17 00:00:00 2001 From: Brian Patton Date: Mon, 8 Dec 2025 07:12:03 -0800 Subject: [PATCH 092/315] [Pallas:SC] Add plsc.sort_key_val to give access to the `mask` and `descending` args. PiperOrigin-RevId: 841750021 --- jax/_src/pallas/mosaic/BUILD | 1 + jax/_src/pallas/mosaic/sc_primitives.py | 67 ++++++++++++++++++++++ jax/experimental/pallas/tpu_sc.py | 1 + tests/pallas/tpu_sparsecore_pallas_test.py | 50 ++++++++++++++++ 4 files changed, 119 insertions(+) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 3538428e97f3..afcdefffa81e 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -175,6 +175,7 @@ pytype_strict_library( deps = [ ":core", ":lowering", + ":sc_core", ":sc_lowering", "//jax", "//jax/_src:core", diff --git a/jax/_src/pallas/mosaic/sc_primitives.py b/jax/_src/pallas/mosaic/sc_primitives.py index cd8a2cb303c5..7f06e37d8474 100644 --- a/jax/_src/pallas/mosaic/sc_primitives.py +++ b/jax/_src/pallas/mosaic/sc_primitives.py @@ -33,6 +33,7 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import lowering as tc_lowering +from jax._src.pallas.mosaic import sc_core from jax._src.pallas.mosaic import sc_lowering from jax._src.state import primitives as state_primitives from jax._src.state import types as state_types @@ -634,6 +635,72 @@ def _reduce_sum_lowering_rule( _cumsum_lowering_rule(ctx, x, 0, reverse=False), [], [vec_dim - 1]) +masked_sort_p = jax_core.Primitive("masked_sort") +masked_sort_p.multiple_results = True + +@masked_sort_p.def_abstract_eval +def _masked_sort_abstract_eval(keys, values, *maybe_mask, descending): + del descending # Unused. + supported_shape = (sc_core.get_sparse_core_info().num_lanes,) + if keys.dtype not in (jnp.int32, jnp.float32): + raise NotImplementedError( + f"sort_key_val: keys dtype {keys.dtype} should be int32 or float32") + if keys.shape != supported_shape: + raise ValueError(f"keys shape {keys.shape} must be {supported_shape}") + if jnp.dtype(values.dtype).itemsize != 4: + raise NotImplementedError( + f"sort_key_val: values dtype {values.dtype} should be 32 bits") + if values.shape != supported_shape: + raise ValueError(f"values shape {values.shape} must be {supported_shape}") + if maybe_mask: + [mask] = maybe_mask + if not jnp.issubdtype(mask.dtype, jnp.bool): + raise TypeError(f"mask dtype {mask.dtype} is not boolean") + if mask.shape != supported_shape: + raise ValueError(f"mask shape {mask.shape} must be {supported_shape}") + return keys, values, *maybe_mask + +@sc_lowering.register_lowering_rule(masked_sort_p) +def _masked_sort_lowering_rule( + ctx: sc_lowering.LoweringRuleContext, keys, values, *maybe_mask, descending): + del ctx # Unused. + if maybe_mask: + [mask] = maybe_mask + else: + mask_type = ir.VectorType.get( + [sc_core.get_sparse_core_info().num_lanes], + ir.IntegerType.get_signless(1)) + mask = arith.constant(mask_type, ir.DenseElementsAttr.get_splat( + mask_type, ir.BoolAttr.get(True))) + out_mask, sorted_keys, sorted_values = tpu.sort( + mask.type, keys.type, values.type, keys, values, mask=mask, + descending=descending + ) + if maybe_mask: + return sorted_keys, sorted_values, out_mask + return sorted_keys, sorted_values + +def sort_key_val( + keys: jax.Array, values: jax.Array, *, + mask: jax.Array | None = None, descending: bool = False +) -> jax.Array: + """Sorts keys and values, pushing invalid elements to the last positions. + + Args: + keys: An array of integers or floats. + values: An array of values corresponding to the keys. + mask: An optional array of booleans, which specifies which elements of + `keys` and `values` are valid. If `None`, all elements are valid. + descending: Whether to sort in descending order. + + Returns: + sorted_keys, sorted_values, [output_mask]: The sorted keys and values, and, + if a mask was given, the corresponding mask for output keys and values. + """ + maybe_mask = () if mask is None else (mask,) + return masked_sort_p.bind(keys, values, *maybe_mask, descending=descending) + + parallel_loop_p = jax_core.Primitive("parallel_loop") parallel_loop_p.is_effectful = lambda params: bool(params["jaxpr"].effects) # type: ignore parallel_loop_p.multiple_results = True diff --git a/jax/experimental/pallas/tpu_sc.py b/jax/experimental/pallas/tpu_sc.py index 68503bb15c33..e9a90ac9f7ac 100644 --- a/jax/experimental/pallas/tpu_sc.py +++ b/jax/experimental/pallas/tpu_sc.py @@ -32,6 +32,7 @@ from jax._src.pallas.mosaic.sc_primitives import PackFormat as PackFormat from jax._src.pallas.mosaic.sc_primitives import parallel_loop as parallel_loop from jax._src.pallas.mosaic.sc_primitives import scan_count as scan_count +from jax._src.pallas.mosaic.sc_primitives import sort_key_val as sort_key_val from jax._src.pallas.mosaic.sc_primitives import store_compressed as store_compressed from jax._src.pallas.mosaic.sc_primitives import store_scatter as store_scatter from jax._src.pallas.mosaic.sc_primitives import subcore_barrier as subcore_barrier diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index c8cc4666e7c2..18ce12312436 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -1654,6 +1654,56 @@ def kernel(x_ref, indices_ref, out_ref): np.testing.assert_array_equal(kernel(x, indices), x[indices]) + @parameterized.product( + keys_dtype=[np.int32, np.float32], + values_dtype=[np.int32, np.float32], + use_mask=[False, True], + descending=[False, True], + ) + def test_sort_key_val(self, keys_dtype, values_dtype, use_mask, descending): + if not jtu.is_cloud_tpu_at_least(2025, 12, 2): + self.skipTest("Test requires a newer libtpu") + + vec_dim = self.sc_info.num_lanes + keys = np.arange(vec_dim, dtype=keys_dtype) + np.random.shuffle(keys) + keys[3] = keys[1] # Verify sort stability. + values = np.arange(vec_dim, dtype=values_dtype) + np.random.shuffle(values) + mask = np.random.choice([True, False], size=vec_dim) if use_mask else None + maybe_mask_arg = (mask.astype(jnp.int32),) if use_mask else () + + @self.vector_subcore_kernel(out_shape=(keys, values, *maybe_mask_arg)) + def kernel(*args): + if use_mask: + mask_ref, *args, o_mask_ref = args + mask = mask_ref[...].astype(jnp.bool) + else: + mask, o_mask_ref = None, None + keys_ref, values_ref, o_keys_ref, o_vals_ref = args + o_keys_ref[...], o_vals_ref[...], *maybe_out_mask = plsc.sort_key_val( + keys_ref[...], values_ref[...], mask=mask, descending=descending) + if use_mask: + [out_mask] = maybe_out_mask + o_mask_ref[...] = out_mask.astype(jnp.int32) + + out_keys, out_values, *maybe_out_mask = kernel( + *maybe_mask_arg, keys, values) + + keys_arg = keys + if descending: + keys_arg = -keys_arg + if use_mask: + keys_arg = jnp.where(mask, keys_arg, 100) + _, gt_keys = jax.lax.sort_key_val(keys_arg, keys) + _, gt_values = jax.lax.sort_key_val(keys_arg, values) + if use_mask: + [out_mask] = maybe_out_mask + gt_out_mask = jnp.arange(vec_dim) < mask.sum() + np.testing.assert_array_equal(out_mask, gt_out_mask.astype(jnp.int32)) + np.testing.assert_array_equal(out_keys, gt_keys) + np.testing.assert_array_equal(out_values, gt_values) + @parameterized.product(dtype=[np.int32, np.float32]) def test_rev_and_sort_desc(self, dtype): if not jtu.is_cloud_tpu_at_least(2025, 12, 2): From fac65507464a63a878f48947a2281a90a77e3c3f Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Mon, 8 Dec 2025 08:22:49 -0800 Subject: [PATCH 093/315] Add new mandatory presubmit jobs: 1) Build CPU test targets on Windows platform. 2) Build jax, jaxlib artifacts on Linux x86 and arm6, build CUDA plugins on Linux x86. PiperOrigin-RevId: 841773766 --- .github/workflows/build_artifacts.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml index 8278ea197078..b1f3b6c8f545 100644 --- a/.github/workflows/build_artifacts.yml +++ b/.github/workflows/build_artifacts.yml @@ -115,7 +115,7 @@ jobs: name: "${{ inputs.artifact }}, ${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') || (contains(inputs.runner, 'linux-arm64') && 'linux arm64') || - (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, py ${{ inputs.python }} ${{ (contains(inputs.artifact, 'cuda') && format(', cuda {0}', inputs.cuda-version)) || '' }}, clone main XLA=${{ inputs.clone_main_xla }}" + (contains(inputs.runner, 'windows-x86') && 'windows x86') }}, py ${{ inputs.python }}${{ (contains(inputs.artifact, 'cuda') && format(', cuda {0}', inputs.cuda-version)) || '' }}, clone main XLA=${{ inputs.clone_main_xla }}" # Map the job outputs to step outputs outputs: From 32d830f9cf213f37a50a058cabda2ab8b230ee7f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 8 Dec 2025 11:47:03 -0800 Subject: [PATCH 094/315] Make sure all args passed to `lax.sort` have the same sharding just like shapes. PiperOrigin-RevId: 841855619 --- jax/_src/lax/lax.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ba7f7998dbd3..5e25498e1cbe 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8184,12 +8184,17 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): } -def _sort_abstract_eval(*args, **kwargs): - args = tuple(args) - if any(arg.shape != args[0].shape for arg in args[1:]): - shapes = " ".join(str(a.shape) for a in args) +def _sort_abstract_eval(*avals, **kwargs): + avals = tuple(avals) + if any(arg.shape != avals[0].shape for arg in avals[1:]): + shapes = " ".join(str(a.shape) for a in avals) raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}") - return args + non_empty_s = [a.sharding for a in avals if not a.sharding.mesh.empty] + if any(s != non_empty_s[0] for s in non_empty_s[1:]): + shardings = " ".join(str(s) for s in non_empty_s) + raise core.ShardingTypeError( + f'Arguments to sort must have equal shardings, got: {shardings}') + return avals def _canonicalize_float_for_sort(x): @@ -8287,7 +8292,9 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys for arg, bdim in zip(batched_args, batch_dims): if bdim is None: dims = np.delete(np.arange(prototype_arg.ndim), new_bdim) - new_args.append(broadcast_in_dim(arg, prototype_arg.shape, dims)) + new_args.append(broadcast_in_dim( + arg, prototype_arg.shape, dims, + out_sharding=core.typeof(prototype_arg).sharding)) else: new_args.append(batching.moveaxis(arg, bdim, new_bdim)) new_dimension = dimension + (new_bdim <= dimension) From c5c8af279a9fb020740c4213704319328b773ab1 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 8 Dec 2025 12:08:20 -0800 Subject: [PATCH 095/315] [Pallas:MGPU] Add support for jnp.cos PiperOrigin-RevId: 841864222 --- jax/_src/pallas/mosaic_gpu/lowering.py | 12 ++++++++++++ jax/experimental/mosaic/gpu/dialect_lowering.py | 1 + jax/experimental/mosaic/gpu/layout_inference.py | 1 + tests/pallas/ops_test.py | 2 +- 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 27e9d719e1b2..4cfd3924c8b2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2315,6 +2315,18 @@ def _sin_lowering_rule(ctx: LoweringRuleContext, x, accuracy): ) return math_dialect.sin(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) +@register_lowering_rule(lax.cos_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.cos_p, mgpu.LoweringSemantics.Warpgroup) +def _cos_lowering_rule(ctx: LoweringRuleContext, x, accuracy): + if accuracy is not None: + raise NotImplementedError("Not implemented: accuracy") + [x_aval] = ctx.avals_in + if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: + return _ensure_fa(x, x_aval.dtype).cos(approx=ctx.module_ctx.approx_math) + fastmath = ( + arith_dialect.FastMathFlags.afn if ctx.module_ctx.approx_math else None + ) + return math_dialect.cos(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) @register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule(lax.log_p, mgpu.LoweringSemantics.Warpgroup) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index ab8ed5431f8f..d7a9141fa318 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -1100,6 +1100,7 @@ def _unary_op_lowering_rule( (mlir_math.ExpOp, fa.FragmentedArray.exp, None), (mlir_math.Exp2Op, fa.FragmentedArray.exp2, None), (mlir_math.SinOp, fa.FragmentedArray.sin, None), + (mlir_math.CosOp, fa.FragmentedArray.cos, None), (mlir_math.LogOp, fa.FragmentedArray.log, None), (mlir_math.TanhOp, fa.FragmentedArray.tanh, None), ]: diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index fdbd34d2806b..569494e6d8c2 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -605,6 +605,7 @@ def _pointwise_op_constraint_system( mlir_math.ExpOp, mlir_math.Exp2Op, mlir_math.SinOp, + mlir_math.CosOp, mlir_math.LogOp, mlir_math.RsqrtOp, mlir_math.TanhOp, diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 7c4b88e0bd71..11d295ae9bdd 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1083,7 +1083,7 @@ def kernel(x_ref, o_ref): for fn, dtype in itertools.product(*args) ) def test_elementwise(self, fn, dtype): - if fn is not jnp.sin or dtype == "float64": + if fn not in (jnp.sin, jnp.cos) or dtype == "float64": self.skip_if_mosaic_gpu() if not jax.config.x64_enabled and jnp.dtype(dtype).itemsize == 8: From 31e61ae110110af6c8abb31b8b69481b6e551c25 Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Mon, 8 Dec 2025 12:33:46 -0800 Subject: [PATCH 096/315] Disable failing multiprocess array_test on TPU. PiperOrigin-RevId: 841872712 --- ci/run_bazel_test_tpu.sh | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ci/run_bazel_test_tpu.sh b/ci/run_bazel_test_tpu.sh index 19acc331c716..5c8d53e4c23b 100755 --- a/ci/run_bazel_test_tpu.sh +++ b/ci/run_bazel_test_tpu.sh @@ -73,6 +73,9 @@ echo "Running Bazel TPU tests..." # commands below. set +e +# TODO(emilyaf): Debug and re-enable this test. +IGNORE_TESTS_MULTIACCELERATOR="-//tests/multiprocess:array_test_tpu" + if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic # TPU does not guarantee anything about forward compatibility (unless @@ -142,7 +145,8 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then //tests:tpu_tests \ //tests/pallas:tpu_tests \ //tests/pallas:tpu_pallas_test_tpu \ - //tests/multiprocess:tpu_tests + //tests/multiprocess:tpu_tests \ + $IGNORE_TESTS_MULTIACCELERATOR # Store the return value of the second bazel command. second_bazel_cmd_retval=$? @@ -224,7 +228,8 @@ else //tests:pjit_test_tpu \ //tests:python_callback_test_tpu \ //tests:ragged_collective_test_tpu \ - //tests/multiprocess:tpu_tests + //tests/multiprocess:tpu_tests \ + $IGNORE_TESTS_MULTIACCELERATOR # Store the return value of the second bazel command. second_bazel_cmd_retval=$? From d5bf94e2c6c07fee9758bb923a968e06e445392b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 8 Dec 2025 12:51:26 -0800 Subject: [PATCH 097/315] [Pallas:MGPU] Add support for collective scale/sparse metadata copies to TMEM PiperOrigin-RevId: 841879848 --- jax/_src/pallas/mosaic_gpu/primitives.py | 32 ++++-- tests/mosaic/gpu_test.py | 3 +- tests/pallas/mosaic_gpu_test.py | 127 +++++++++++++++++++++++ 3 files changed, 154 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index b9800026590d..ae8a6c06dba1 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -53,6 +53,7 @@ import numpy as np +AxisName = jax_core.AxisName WARP_SIZE = 32 WARPGROUP_SIZE = 128 @@ -3202,7 +3203,10 @@ def _async_store_tmem_lowering_rule_wg( async_copy_scales_to_tmem_p = jax_core.Primitive("async_copy_scales_to_tmem") async_copy_scales_to_tmem_p.multiple_results = True -def async_copy_scales_to_tmem(smem_ref: _Ref, tmem_ref: _Ref): + +def async_copy_scales_to_tmem( + smem_ref: _Ref, tmem_ref: _Ref, collective_axis: AxisName | None = None, +): """Copies the MMA scales from SMEM to TMEM. The copy is performed asynchronously and can be awaited by calling @@ -3226,12 +3230,17 @@ def async_copy_scales_to_tmem(smem_ref: _Ref, tmem_ref: _Ref): async_copy_scales_to_tmem_p.bind( smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms, smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef, + collective_axis=collective_axis, ) + async_copy_sparse_metadata_to_tmem_p = jax_core.Primitive("async_copy_sparse_metadata_to_tmem") async_copy_sparse_metadata_to_tmem_p.multiple_results = True -def async_copy_sparse_metadata_to_tmem(smem_ref: _Ref, tmem_ref: _Ref): + +def async_copy_sparse_metadata_to_tmem( + smem_ref: _Ref, tmem_ref: _Ref, collective_axis: AxisName | None = None +): """Copies the MMA sparse metadata from SMEM to TMEM. The copy is performed asynchronously and can be awaited by calling @@ -3255,11 +3264,13 @@ def async_copy_sparse_metadata_to_tmem(smem_ref: _Ref, tmem_ref: _Ref): async_copy_sparse_metadata_to_tmem_p.bind( smem_ref, tmem_ref, *flat_smem_transforms, *flat_tmem_transforms, smem_tree=smem_transforms_treedef, tmem_tree=tmem_transforms_treedef, + collective_axis=collective_axis, ) + @async_copy_scales_to_tmem_p.def_effectful_abstract_eval @async_copy_sparse_metadata_to_tmem_p.def_effectful_abstract_eval -def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *avals_flat, smem_tree, tmem_tree): +def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *_args, **_kwargs): if smem_ref.memory_space != gpu_core.MemorySpace.SMEM: raise ValueError("async_copy_scales_to_tmem source must be an SMEM ref") if tmem_ref.memory_space != gpu_core.MemorySpace.TMEM: @@ -3267,7 +3278,7 @@ def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *avals_flat, smem_tree return (), {gpu_core._memory_effect} def _async_copy_to_tmem_lowering_rule( - impl, ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree + impl, ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree, collective_axis ): assert isinstance(tmem_ref, tcgen05.TMEMRef) smem_leaves, tmem_leaves = util.split_list(leaves, [smem_tree.num_leaves]) @@ -3279,8 +3290,17 @@ def _async_copy_to_tmem_lowering_rule( raise NotImplementedError(f"Unimplemented transforms for SMEM refs: {smem_transforms}") if tmem_transforms: raise NotImplementedError(f"Unimplemented transforms for TMEM refs: {tmem_transforms}") - with mgpu.when(ctx.module_ctx.single_lane_predicate): - impl(smem_ref, tmem_ref) + + predicate = ctx.module_ctx.single_lane_predicate + if collective_axis is not None: + is_leader_block = _collective_mma_predicate(ctx, collective_axis) + predicate = arith_dialect.andi(predicate, is_leader_block) + collective = True + else: + collective = False + + with mgpu.when(predicate): + impl(smem_ref, tmem_ref, collective=collective) return () @lowering.register_lowering_rule( diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index ebd4e68b5324..fce498539f87 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1760,11 +1760,10 @@ def format_scales(scales): @parameterized.product( m=(256,), - n=(64, 128, 256), + n=(256,), scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn), ) def test_mma_block_scaled_collective(self, m, n, scale_jax_dtype): - m, n = 256, 256 in_jax_dtype = jnp.float4_e2m1fn out_jax_dtype = jnp.float32 scale_block = 32 if scale_jax_dtype == jnp.float8_e8m0fnu else 16 diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index d2f6f56baab2..7f7e18f2fd8b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -3768,6 +3768,133 @@ def format_scales(scales): ) np.testing.assert_allclose(result, expected, rtol=1e-3) + @parameterized.product( + m=[256], + n=[256], + scale_jax_dtype=[jnp.float8_e8m0fnu, jnp.float8_e4m3fn], + ) + def test_collective_scaled_matmul(self, m, n, scale_jax_dtype): + self.skip_if_wg_semantics() + + in_jax_dtype = jnp.float4_e2m1fn + out_jax_dtype = jnp.float32 + scale_block = 32 if scale_jax_dtype == jnp.float8_e8m0fnu else 16 + swizzle = 128 + k_steps = 2 + swizzle_elems = 8 * swizzle // dtypes.itemsize_bits(in_jax_dtype) + k = swizzle_elems * k_steps + tiling = (8, swizzle_elems) + transforms = ( + plgpu.TilingTransform(tiling), plgpu.SwizzleTransform(swizzle) + ) + out_transforms = self.default_transforms(dtype=out_jax_dtype) + + m_block = m // 2 + n_block = n // 2 + + def kernel(lhs_gmem, rhs_gmem, lhs_scales_gmem, rhs_scales_gmem, out_gmem, + lhs_smem, rhs_smem, lhs_scales_smem, rhs_scales_smem, out_smem, + tma_barrier, mma_barrier, + acc_tmem, lhs_scales_tmem, rhs_scales_tmem): + plgpu.copy_gmem_to_smem(lhs_gmem, lhs_smem, tma_barrier, + collective_axes="x", partitioned_axis=0) + plgpu.copy_gmem_to_smem(rhs_gmem, rhs_smem, tma_barrier, + collective_axes="x", partitioned_axis=0) + plgpu.copy_gmem_to_smem(lhs_scales_gmem, lhs_scales_smem, tma_barrier, + collective_axes="x", partitioned_axis=0) + # RHS scales are replicated (multicast) + plgpu.copy_gmem_to_smem(rhs_scales_gmem, rhs_scales_smem, tma_barrier, + collective_axes="x", partitioned_axis=None) + cluster_idx = lax.axis_index("x") + + @pl.when(cluster_idx == 0) + def _leader_block(): + plgpu.barrier_wait(tma_barrier) + plgpu.async_copy_scales_to_tmem(lhs_scales_smem, lhs_scales_tmem, collective_axis="x") + plgpu.async_copy_scales_to_tmem(rhs_scales_smem, rhs_scales_tmem, collective_axis="x") + plgpu.tcgen05_mma( + acc_tmem, + lhs_smem, + plgpu.transpose_ref(rhs_smem, (1, 0)), + mma_barrier, + a_scale=lhs_scales_tmem, + b_scale=rhs_scales_tmem, + accumulate=False, + collective_axis="x" + ) + plgpu.barrier_wait(mma_barrier) + + out_smem[...] = plgpu.async_load_tmem(acc_tmem) + plgpu.commit_smem() + slice_out = pl.ds(cluster_idx * m_block, m_block) + plgpu.copy_smem_to_gmem(out_smem, out_gmem.at[slice_out, :]) + plgpu.wait_smem_to_gmem(0) + + scratch_shapes = [ + plgpu.SMEM((m_block, k), in_jax_dtype, transforms=transforms), + plgpu.SMEM((n_block, k), in_jax_dtype, transforms=transforms), + plgpu.SMEM((m_block // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype), + plgpu.SMEM((n // 128, k // (scale_block * 4), 32, 16), scale_jax_dtype), + plgpu.SMEM((m_block, n), out_jax_dtype, transforms=out_transforms), + plgpu.Barrier(num_arrivals=4), + plgpu.Barrier(orders_tensor_core=True), + plgpu.TMEM((m_block, n), out_jax_dtype, collective=True), + plgpu.TMEM((m_block, k // scale_block), scale_jax_dtype, + layout=plgpu.TMEMLayout.SCALES_LAYOUT, collective=True), + plgpu.TMEM((n, k // scale_block), scale_jax_dtype, + layout=plgpu.TMEMLayout.SCALES_LAYOUT, collective=True), + ] + + f = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), out_jax_dtype), + grid=(1,), + grid_names=("_",), + cluster=(2,), + cluster_names=("x",), + scratch_shapes=scratch_shapes, + ) + + x = jax.random.uniform(jax.random.key(1), shape=(m, k), dtype=jnp.float32).astype(in_jax_dtype) + y = jax.random.uniform(jax.random.key(2), shape=(n, k), dtype=jnp.float32).astype(in_jax_dtype) + + ka, kb = jax.random.split(jax.random.key(1234), 2) + if scale_jax_dtype == jnp.float8_e8m0fnu: + x_scale = jax.lax.bitcast_convert_type( + jax.random.randint(ka, (m, k // scale_block), 122, 132, dtype=jnp.uint8), + scale_jax_dtype + ) + y_scale = jax.lax.bitcast_convert_type( + jax.random.randint(kb, (n, k // scale_block), 122, 132, dtype=jnp.uint8), + scale_jax_dtype + ) + else: + x_scale = jnp.abs( + jax.random.normal(ka, (m, k // scale_block), dtype=jnp.float32).astype(scale_jax_dtype) + ) + y_scale = jnp.abs( + jax.random.normal(kb, (n, k // scale_block), dtype=jnp.float32).astype(scale_jax_dtype) + ) + + def format_scales(scales): + mn, k = scales.shape + assert mn % 128 == 0 and k % 4 == 0 + return ( + scales.reshape(mn // 128, 4, 32, k // 4, 4) + .transpose(0, 3, 2, 1, 4) + .reshape(mn // 128, k // 4, 32, 16) + ) + + result = f(x, y, format_scales(x_scale), format_scales(y_scale)) + + x_logical_scale = jnp.repeat(x_scale, scale_block, axis=1).astype(jnp.float32) + y_logical_scale = jnp.repeat(y_scale, scale_block, axis=1).astype(jnp.float32) + expected = jnp.dot( + x.astype(jnp.float32) * x_logical_scale, + (y.astype(jnp.float32) * y_logical_scale).T, + ) + np.testing.assert_allclose(result, expected, rtol=1e-3) + @parameterized.product( m=[128], n=[128, 256], From 6cb78e549691acfa880bb99c6ee022731b8962ab Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 8 Dec 2025 12:52:59 -0800 Subject: [PATCH 098/315] Add LoadedExecutable.serialize() to match Executable.serialize(). PiperOrigin-RevId: 841880441 --- jax/_src/compilation_cache.py | 5 ++++- jaxlib/_jax/__init__.pyi | 1 + jaxlib/py_executable.cc | 6 ++++++ jaxlib/xla_client.py | 2 +- 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 59d11bc95b06..bb416fd1c633 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -261,7 +261,10 @@ def put_executable_and_time( " since cache is disabled/not initialized", cache_key) return - serialized_executable = backend.serialize_executable(executable) + if hasattr(executable, "serialize") or xla_client._version >= 389: + serialized_executable = executable.serialize() + else: + serialized_executable = backend.serialize_executable(executable) executable_and_time = combine_executable_and_time( serialized_executable, compile_time) executable_and_time = compress_executable(executable_and_time) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 831f44ea2e8e..59cfcd1caf66 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -1199,6 +1199,7 @@ class LoadedExecutable: def client(self) -> Client: ... def local_devices(self) -> list[Device]: ... def get_hlo_text(self) -> str: ... + def serialize(self) -> bytes: ... def size_of_generated_code_in_bytes(self) -> int: ... def get_compiled_memory_stats(self) -> CompiledMemoryStats: ... def execute_sharded( diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index 5353883d0c33..d6ba17fd8a09 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -545,6 +545,12 @@ void PyLoadedExecutable::Register(nb::module_& m) { .def("get_hlo_text", xla::ValueOrThrowWrapper( &PyLoadedExecutable::GetHumanReadableProgramText)) + .def("serialize", + [](const PyLoadedExecutable& exec) -> nb::bytes { + std::string serialized = + xla::ValueOrThrow(exec.ifrt_loaded_executable()->Serialize()); + return nb::bytes(serialized.data(), serialized.size()); + }) .def("size_of_generated_code_in_bytes", &PyLoadedExecutable::SizeOfGeneratedCodeInBytes) .def( diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 55f0401e9a69..d0f9bceed0f8 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -47,7 +47,7 @@ # Please suffix the version number with a brief description of your change # in a comment. The goal here is to force a merge conflict if two changes # attempt to grab the same version number. -_version = 388 # Add ArrayMeta +_version = 389 # LoadedExecutable.serialize # An internal increasing version number for protecting jaxlib code against # ifrt changes. From 43bb91f753febcd825b505317146ad3b7c98ad8c Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 6 Dec 2025 00:28:49 +0000 Subject: [PATCH 099/315] improve tracer reprs to show a lil concrete info Co-authored-by: Yash Katariya --- jax/_src/api.py | 2 +- jax/_src/interpreters/ad.py | 9 ++++++++- jax/_src/interpreters/batching.py | 3 ++- jax/_src/interpreters/partial_eval.py | 2 +- jax/_src/literals.py | 3 +++ jax/_src/tree_util.py | 2 +- tests/core_test.py | 14 +++++++------- 7 files changed, 23 insertions(+), 12 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 15577426e2d7..e5c57b3d99e4 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -393,7 +393,7 @@ def disable_jit(disable: bool = True): ... return y + 3 ... >>> print(f(jax.numpy.array([1, 2, 3]))) - Value of y is JitTracer + Value of y is JitTracer(int32[3]) [5 7 9] Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 7e5aacba3d80..3ebe87212ae2 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -828,7 +828,9 @@ def __init__(self, trace, primal, tangent): self.tangent = tangent def _short_repr(self): - return f"GradTracer<{self.aval}>" + pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x) + primal, tangent = pp(self.primal), pp(self.tangent) + return f'JVPTracer({primal=!s}, {tangent=!s})' @property def aval(self): @@ -1173,6 +1175,11 @@ def __init__(self, trace, primal, tangent): self.primal = primal self.tangent = tangent + def _short_repr(self): + pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x) + primal, tangent = pp(self.primal), typeof(self.tangent).str_short(True) + return f"GradTracer({primal=!s}, typeof(tangent)={tangent!s})" + @property def aval(self): return get_aval(self.primal) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 260a43988018..9eafef3a6396 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -23,6 +23,7 @@ from jax._src import config from jax._src import core +from jax._src.core import typeof from jax._src import source_info_util from jax._src import linear_util as lu from jax._src.partition_spec import PartitionSpec as P @@ -406,7 +407,7 @@ def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, self.source_info = source_info def _short_repr(self): - return f"VmapTracer<{self.aval}>" + return f"VmapTracer(aval={self.aval}, batched={typeof(self.val)})" @property def aval(self): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 413ac19f20e4..fc4ffda4396d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1687,7 +1687,7 @@ def __init__(self, trace: DynamicJaxprTrace, self.parent = parent def _short_repr(self): - return f"JitTracer<{self.aval}>" + return f"JitTracer({self.aval})" def cur_qdd(self): return self.mutable_qdd.cur_val diff --git a/jax/_src/literals.py b/jax/_src/literals.py index 237072f0e606..5aed0f3c3256 100644 --- a/jax/_src/literals.py +++ b/jax/_src/literals.py @@ -51,6 +51,9 @@ def __new__(cls, value: float, dtype: np.dtype): def __repr__(self): return f'TypedFloat({float(self)}, dtype={self.dtype.name})' + def __str__(self): + return str(float(self)) + def __getnewargs__(self): return (float(self), self.dtype) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 4f439a770e69..7ddbb3cf55ef 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -543,7 +543,7 @@ class Partial(functools.partial): >>> print_zero() 0 >>> call_func(print_zero) # doctest:+ELLIPSIS - JitTracer<~int32[]> + JitTracer(~int32[]) """ def __new__(klass, func, *args, **kw): diff --git a/tests/core_test.py b/tests/core_test.py index a2b3da15fd2c..e0c6b8436e2e 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -411,19 +411,19 @@ def f(x): x_repr = "" jax.jit(f)(jnp.arange(10.0, dtype='float32')) - self.assertEqual(x_repr, "JitTracer") + self.assertEqual(x_repr, "JitTracer(float32[10])") jax.vmap(f)(jnp.arange(20, dtype='int32')) - self.assertEqual(x_repr, "VmapTracer") + self.assertEqual(x_repr, "VmapTracer(aval=int32[], batched=int32[20])") jax.grad(f)(jnp.float16(1.0)) - self.assertRegex(x_repr, r"(Grad)|(Linearize)Tracer") + self.assertEqual(x_repr, "GradTracer(primal=1.0, typeof(tangent)=f16[])") - jax.jacrev(f)(jnp.arange(12, dtype='float32')) - self.assertRegex(x_repr, r"(Grad)|(Linearize)Tracer") + jax.jacrev(f)(jnp.arange(4, dtype='float32')) + self.assertEqual(x_repr, "GradTracer(primal=[0. 1. 2. 3.], typeof(tangent)=f32[4])") - jax.jacfwd(f)(jnp.arange(14, dtype='float32')) - self.assertRegex(x_repr, r"(Grad)|(Linearize)Tracer") + jax.jacfwd(f)(jnp.arange(3, dtype='float32')) + self.assertEqual(x_repr, "JVPTracer(primal=[0. 1. 2.], tangent=VmapTracer(aval=float32[3], batched=float32[3,3]))") def test_verbose_tracer_reprs(self): # Verbose reprs, avaiable via tracer._pretty_print() From ab3b9b236f2cd90b29fefa2effb07350f12318d8 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 8 Dec 2025 13:08:37 -0800 Subject: [PATCH 100/315] [Mosaic GPU] Exclude strided layouts in reduction rules. We currently don't support reducing strided layouts. PiperOrigin-RevId: 841886121 --- .../mosaic/gpu/layout_inference.py | 19 +++++++++++------- tests/mosaic/gpu_layout_inference_test.py | 20 +++++++++++++++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 569494e6d8c2..4a005bf67745 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -1042,12 +1042,12 @@ def _vector_reduction_constraint_system( return cs.ConstraintSystem(), {in_variable: [in_variable.key]}, [] -def _reduction_constraint_and_hint( +def _reduction_constraints_and_hint( larger: cs.Variable, smaller: cs.Variable, larger_shape: tuple[int, ...], reduction_dims: tuple[int, ...], -) -> tuple[cs.Constraint, Hint]: +) -> tuple[list[cs.Constraint], Hint]: reduce_expr = cs.Reduce(larger, reduction_dims) # There are always many options for broadcasting a layout, so we can only # derive a broadcast hint in the out_variable -> source_variable direction. @@ -1056,7 +1056,12 @@ def _reduction_constraint_and_hint( ) broadcast_expr = cs.BroadcastInDim(smaller, broadcast_dims, larger_shape) broadcast_hint = Hint(variable=larger, expression=broadcast_expr) - return cs.Equals(lhs=smaller, rhs=reduce_expr), broadcast_hint + constraints = [ + cs.Equals(lhs=smaller, rhs=reduce_expr), + # TODO(allanrenucci): Remove once we support reduction of strided layouts. + cs.NotOfType(larger, fa.WGStridedFragLayout), + ] + return constraints, broadcast_hint @_add_constraint_system_derivation_rule(vector.MultiDimReductionOp) @@ -1071,7 +1076,7 @@ def _multi_dim_reduction_constraint_system( source_variable = cs.Variable(source) out_variable = cs.Variable(out) - reduction_constraint, broadcast_hint = _reduction_constraint_and_hint( + reduction_constraints, broadcast_hint = _reduction_constraints_and_hint( source_variable, out_variable, tuple(ir.ShapedType(op.source.type).shape), @@ -1081,7 +1086,7 @@ def _multi_dim_reduction_constraint_system( # strided layouts from being chosen---since trying to reduce a strided layout # may cause us to raise an Exception at the moment. return ( - cs.ConstraintSystem(constraints=[reduction_constraint]), + cs.ConstraintSystem(constraints=reduction_constraints), {source_variable: [source], out_variable: [acc, out]}, [broadcast_hint], ) @@ -1100,12 +1105,12 @@ def _broadcast_in_dim_constraint_system( i for i in range(len(out_shape)) if i not in op.broadcast_dimensions ) - reduction_constraint, broadcast_hint = _reduction_constraint_and_hint( + reduction_constraints, broadcast_hint = _reduction_constraints_and_hint( out_variable, source_variable, out_shape, reduction_dims ) return ( - cs.ConstraintSystem(constraints=[reduction_constraint]), + cs.ConstraintSystem(constraints=reduction_constraints), { source_variable: [source_variable.key], out_variable: [out_variable.key], diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 786c0a1f6aee..cf7147ed551c 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -327,6 +327,26 @@ def test_infer_broadcast_in_dim_layout(self, layout, axis, hint_on_input): self.checkInLayouts(bcast, [in_layout]) self.checkOutLayouts(bcast, [out_layout]) + # TODO(allanrenucci): Turn into a positive test. This is currently not + # implemented. The test checks we fail gracefully. + @parameterized.parameters(True, False) + def test_cant_infer_reduced_strided_layout(self, hint_on_input): + with ir.InsertionPoint(self.module.body): + [x] = undefs(ir.VectorType.get((128,), ir.F32Type.get())) + if hint_on_input: + layout = mgpu.WGStridedFragLayout.from_shaped_type(x.type) + x = layout_cast(x, layout) + out_type = ir.VectorType.get((128, 128), ir.F32Type.get()) + out = mgpu.dialect.broadcast_in_dim(out_type, x, [0]) + if not hint_on_input: + layout = mgpu.WGStridedFragLayout.from_shaped_type(out.type) + layout_cast(out, layout) + + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts" + ): + mgpu.infer_layout(self.module) + @parameterized.parameters( (1, mgpu.WGMMA_LAYOUT, None, None), (0, mgpu.WGMMA_LAYOUT, None, None), From a1c8fadb66e372bacf16c93aa6e90fa0aa6ac3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Mon, 8 Dec 2025 14:18:27 -0800 Subject: [PATCH 101/315] [Mosaic] Add hasVectorOperandsOrResults utility function PiperOrigin-RevId: 841913251 --- jaxlib/mosaic/dialect/tpu/util.cc | 9 +++++++++ jaxlib/mosaic/dialect/tpu/util.h | 2 ++ 2 files changed, 11 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/util.cc b/jaxlib/mosaic/dialect/tpu/util.cc index 30d64ad54a32..86f51a8ae170 100644 --- a/jaxlib/mosaic/dialect/tpu/util.cc +++ b/jaxlib/mosaic/dialect/tpu/util.cc @@ -371,4 +371,13 @@ SmallVector getNontrivialTransitiveUsers(Value v) { return users; } +bool hasVectorOperandsOrResults(Operation& op) { + for (Value value : llvm::concat(op.getOperands(), op.getResults())) { + if (isa(value.getType())) { + return true; + } + } + return false; +} + } // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/util.h b/jaxlib/mosaic/dialect/tpu/util.h index 5cfe574f3bdc..0dd447784a95 100644 --- a/jaxlib/mosaic/dialect/tpu/util.h +++ b/jaxlib/mosaic/dialect/tpu/util.h @@ -301,6 +301,8 @@ std::optional getIntConst(Value v); // results. SmallVector getNontrivialTransitiveUsers(Value v); +bool hasVectorOperandsOrResults(Operation& op); + // Return a mod b for a, b > 0, but adjusted to return b when a mod b == 0 such // that the result is strictly positive. template From 59edc3da54f7ce6754bb0e87bb538c9aeca4fc90 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 8 Dec 2025 14:50:23 -0800 Subject: [PATCH 102/315] Only check sharding in sort abstract eval if any mesh axis is Explicit PiperOrigin-RevId: 841924827 --- jax/_src/lax/lax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 5e25498e1cbe..e89d5060a92c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -8189,7 +8189,9 @@ def _sort_abstract_eval(*avals, **kwargs): if any(arg.shape != avals[0].shape for arg in avals[1:]): shapes = " ".join(str(a.shape) for a in avals) raise TypeError(f"Arguments to sort must have equal shapes, got: {shapes}") - non_empty_s = [a.sharding for a in avals if not a.sharding.mesh.empty] + non_empty_s = [ + a.sharding for a in avals + if not a.sharding.mesh.empty and a.sharding.mesh._any_axis_explicit] if any(s != non_empty_s[0] for s in non_empty_s[1:]): shardings = " ".join(str(s) for s in non_empty_s) raise core.ShardingTypeError( From a6aa5ddc68f8117cf36a18ead1100c6c15b2ac76 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Mon, 8 Dec 2025 15:15:09 -0800 Subject: [PATCH 103/315] [Mosaic GPU] Get rid of hints in layout inference. All the hints we derive can now be constructed on the fly based on `Relayout`s or `Equals(Reduce())` constraints. This greatly simplifies the necessary bookkeeping in the layout inference pass, as well as the APIs. A follow-up change will take care of deleting support for `{Least,Most}ReplicatedExpression`, which no longer need to exist. PiperOrigin-RevId: 841933775 --- .../mosaic/gpu/layout_inference.py | 417 +++++++----------- tests/mosaic/gpu_layout_inference_test.py | 127 +----- 2 files changed, 167 insertions(+), 377 deletions(-) diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 4a005bf67745..542811ee509f 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -138,52 +138,13 @@ def __str__(self): return f"{match.group(0)}:a-{self.index}" -@dataclasses.dataclass(frozen=True) -class Hint: - """Hints are used to model propagation of layouts across operations. - - Since using `relayout`s is always an option in principle, propagation across - ops can not rely only on a constraint system. Instead, we introduce hints as - a form of "soft constraints", i.e., it suggests that `variable` should be - equal to `expression`. - """ - variable: cs.Variable - expression: cs.Expression - - def __str__(self): - return f"{self.variable} ?= {self.expression}" - - -def extract_constant_from_replicated_expression_for_hint( - expression: cs.LeastReplicated | cs.MostReplicated, -) -> cs.Constant | None: - assert len(expression.expressions) >= 1 - choices: list[cs.Constant] = [] - for e in expression.expressions: - if (red := extract_constant_for_hint(e)) is not None: - choices.append(red) - - if not choices: - return None - - # We reduce the expression here in order to recover an unambiguous - # replicated layout if it exists. - maybe_choice = cs.reduce_expression(type(expression)(tuple(choices)), {}) - - if isinstance(maybe_choice, cs.Unsatisfiable): - # TODO(bchetioui): consider other choices. - return choices[0] - - assert isinstance(maybe_choice, cs.Constant) - return maybe_choice - - -def extract_constant_from_broadcast_in_dim_expression_for_hint( - e: cs.BroadcastInDim, -) -> cs.RegisterLayout | None: - if not isinstance(e.expression, cs.RegisterLayout): - return None - +def extract_assignment_candidates_from_reduce_equation( + small: cs.RegisterLayout, + large: cs.Variable, + reduction_dims: tuple[int, ...] +) -> Iterator[cs.RegisterLayout]: + """Yields layout candidates for the reduce equation `small = reduce(large, reduction_dims).""" + large_shape = large.key.value.type.shape # pytype: disable=attribute-error candidates = [ fa.WGMMA_LAYOUT, fa.WGMMA_TRANSPOSED_LAYOUT, @@ -191,66 +152,14 @@ def extract_constant_from_broadcast_in_dim_expression_for_hint( fa.TCGEN05_TRANSPOSED_LAYOUT, tcgen05.TMEM_NATIVE_LAYOUT, ] - if e.shape[-1] % 16 == 0: - candidates.append(tcgen05.fa_m64_collective_layout(e.shape[-1])) + if large_shape[-1] % 16 == 0: + candidates.append(tcgen05.fa_m64_collective_layout(large_shape[-1])) - # TODO(allanrenucci): Allow returning multiple valid candidates. - reduction_dims = tuple(d for d in range(len(e.shape)) if d not in e.axes) for candidate in candidates: - if len(candidate.base_tile_shape) > len(e.shape): + if len(candidate.base_tile_shape) > len(large_shape): continue - if candidate.reduce(reduction_dims) == e.expression.value: - return cs.RegisterLayout(candidate) - return None - - -def extract_constant_for_hint(e: cs.Expression) -> cs.Constant | None: - """Attempts to extract a `ConstantExpression` from a `Hint`'s `Expression`. - - Returns `None` if no `ConstantExpression` could be reasonably extracted. - """ - match e: - case cs.Constant(): - return e - case cs.LeastReplicated() | cs.MostReplicated(): - return extract_constant_from_replicated_expression_for_hint(e) - case cs.BroadcastInDim(): - return extract_constant_from_broadcast_in_dim_expression_for_hint(e) - case cs.Variable(): - return None - case _: - raise NotImplementedError(f"Unsupported expression type: {type(e)}") - - -def extract_variable_assignment_from_hint( - hint: Hint, -) -> tuple[cs.Variable, cs.Constant] | None: - """Attempts to extract a single variable assignment from a `Hint`.""" - # TODO(bchetioui): make this a generator. This will allow us to maybe extract - # different assignments that satisfy a replication constraint in the case - # where replicated expressions are incompatible and several extractions are - # possible. - red = extract_constant_for_hint(hint.expression) - return (hint.variable, red) if red is not None else None - - -def reduce_hints( - hints: Sequence[Hint], assignments: dict[cs.Variable, cs.Constant] -) -> list[Hint]: - """Reduces a sequence of `Hint`s. - - We reduce the `Hint`s' expressions, drop `Unsatisfiable` hints, and drop - `Hint`s pertaining to pre-existing assignments. - """ - new_hints: list[Hint] = [] - for h in hints: - if h.variable not in assignments: - reduced_expression = cs.reduce_expression(h.expression, assignments) - if isinstance(reduced_expression, cs.Unsatisfiable): - continue - new_hints.append(dataclasses.replace(h, expression=reduced_expression)) - - return new_hints + if candidate.reduce(reduction_dims) == small.value: + yield cs.RegisterLayout(candidate) def _strided_layout_for_variable( @@ -267,6 +176,19 @@ def _strided_layout_for_variable( return fa.WGStridedFragLayout.from_shaped_type(type) +def _default_tmem_layout_for_variable( + variable: cs.Variable, +) -> tcgen05.TMEMLayout | None: + """Returns a default TMEM layout for the given variable, if one is defined.""" + value = variable.key.value + parent = value.owner.opview + if isinstance(parent, mgpu.TmemAllocOp): + return tcgen05._infer_tmem_layout( + tuple(value.type.shape), parent.collective, packing=1 + ) + return None + + def _extract_tiling_candidate( divide_constraint: cs.Divides, num_tiled_dims: int ) -> Iterator[tuple[cs.Variable, cs.Constant]]: @@ -357,58 +279,73 @@ def _extract_variable_assignments_from_constraints( match c: case cs.IsTransferable(): yield from _extract_layout_candidates_from_memory_space_transfer(c, dpv) + case cs.Equals(cs.Reduce(cs.Variable() as large, axes=axes), cs.RegisterLayout() as small): + for layout in extract_assignment_candidates_from_reduce_equation(small, large, axes): + yield large, layout + case cs.Equals(cs.RegisterLayout() as small, cs.Reduce(cs.Variable() as large, axes=axes)): + for layout in extract_assignment_candidates_from_reduce_equation(small, large, axes): + yield large, layout + case cs.Relayout(cs.Variable() as var, cs.RegisterLayout() as layout): + yield var, layout + case cs.Relayout(cs.RegisterLayout() as layout, cs.Variable() as var): + yield var, layout def conjure_assignment( unknowns: Sequence[cs.Variable], constraint_system: cs.ConstraintSystem, - hints: Sequence[Hint], ) -> Iterator[tuple[cs.Variable, cs.Constant]]: """Attempts to conjure an assignment for an unknown variable.""" # TODO(allanrenucci): We should be able to short-circuit the search here if # the constraint is not satisfiable. - yield from _extract_variable_assignments_from_constraints( - constraint_system.constraints - ) - def assignment_order( - assignment: tuple[cs.Variable, cs.Constant], - ) -> int: - match assignment: - # Try TiledLayout first, before other hints, because TiledLayout` are - # usually more useful to propagate than `WGSplat`. Also this often - # improves the performance of the layout inference. - case (_, cs.RegisterLayout(fa.TiledLayout())): - return 0 + # As we extract assignment candidates from constraints, we prioritize + # candidates that are more "interesting"; e.g., in the case of registers, + # introducing splat layout candidate assignments often leads to a dead end in + # practice---as opposed to tiled layouts, which are more likely to yield + # solutions to the constraint system. + low_priority_assignments: list[tuple[cs.Variable, cs.Constant]] = [] + for variable, constant in _extract_variable_assignments_from_constraints( + constraint_system.constraints + ): + match constant: + case cs.RegisterLayout(value=value) if not isinstance(value, fa.TiledLayout): + low_priority_assignments.append((variable, constant)) case _: - return 1 + yield variable, constant - assignments = [extract_variable_assignment_from_hint(h) for h in hints] - assignments = [a for a in assignments if a is not None] - assignments = sorted(assignments, key=assignment_order) - yield from assignments + # After all high-priority assignments have been attempted, switch to using + # low-priority assignments. + for variable, constant in low_priority_assignments: + yield variable, constant # Here, we have not managed to find an assignment for all the unknown - # variables, and our hints have not proven sufficient to unblock us. We now - # try to introduce new arbitrary (valid) assignments into the system, and - # hope that they turn out to be compatible with the constraint system. + # variables. We now try to introduce new arbitrary (valid) assignments into + # the system, and hope that they turn out to be compatible with the constraint + # system. for variable in unknowns: if variable in constraint_system.assignments: continue - # Try to instantiate a single variable to a strided layout and see if it + # Try to instantiate a single variable to a default layout and see if it # reduces the system. - if variable.key.memory_space == MemorySpace.REG: - layout = _strided_layout_for_variable(variable) - if layout is not None: - yield variable, cs.RegisterLayout(layout) - elif variable.key.memory_space == MemorySpace.SMEM: - yield variable, cs.SMEMTiling(None) + match variable.key.memory_space: + case MemorySpace.REG: + layout = _strided_layout_for_variable(variable) + if layout is not None: + yield variable, cs.RegisterLayout(layout) + case MemorySpace.SMEM: + yield variable, cs.SMEMTiling(None) + case MemorySpace.TMEM: + layout = _default_tmem_layout_for_variable(variable) + if layout is not None: + yield variable, cs.TMEMLayout(layout) + case _: + raise ValueError(f"Unsupported memory space: {variable.key.memory_space}") def find_assignments_for( unknowns: Sequence[cs.Variable], constraint_system: cs.ConstraintSystem, - hints: Sequence[Hint], *, fuel: int, ) -> tuple[dict[cs.Variable, cs.Constant] | cs.Unsatisfiable, int]: @@ -418,7 +355,6 @@ def find_assignments_for( unknowns: the set of variables that are unknown. Represented as a sequence of `Variable`s for determinism purposes. constraint_system: the constraint system to satisfy. - hints: a list of hints that may be used to introduce new assignments. fuel: the fuel to use for the search. Once the fuel is exhausted, we raise an error. @@ -449,17 +385,12 @@ def find_assignments_for( v: k for v, k in constraint_system.assignments.items() if v in unknowns }, fuel - # Reduce the expressions in the remaining hints based on the current - # assignments, and eliminate hints that pertain to variables that already - # have an assignment. - hints = reduce_hints(hints, constraint_system.assignments) - # If unknowns remain and we have fully reduced the system, we may still - # be able to make progress by extracting an assignment from a `Hint`. This - # new assignment could make the system unsatisfiable, so we use a recursive + # be able to make progress by trying out potential assignments. These + # new assignments could make the system unsatisfiable, so we use a recursive # call to be able to backtrack if necessary. for assignment in conjure_assignment( - remaining_unknowns, constraint_system, hints + remaining_unknowns, constraint_system ): if fuel <= 0: raise ValueError( @@ -476,7 +407,7 @@ def find_assignments_for( # This assignment is not compatible with the constraint system. continue solution, fuel = find_assignments_for( - unknowns, new_constraint_system, hints, fuel=fuel + unknowns, new_constraint_system, fuel=fuel ) if not isinstance(solution, cs.Unsatisfiable): return solution, fuel @@ -516,8 +447,8 @@ def producer_ref(self, operand: ValueSite) -> cs.Variable: ValueSitesForVariable = dict[cs.Variable, list[ValueSite]] # A constraint system derivation rule is a function that takes an MLIR operation -# and returns a constraint system, a mapping from variables to value site -# identifiers, and a list of hints. +# and returns a constraint system, and a mapping from variables to value site +# identifiers. # # The intended meaning of the mapping is that, for each identifier in the list # keyed by a given variable, the MLIR operand/result/argument corresponding to @@ -530,7 +461,7 @@ def producer_ref(self, operand: ValueSite) -> cs.Variable: # operands/results/arguments that correspond to the given operation. ConstraintSystemDerivationRule = Callable[ [DerivationContext, ir.OpView], - tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]], + tuple[cs.ConstraintSystem, ValueSitesForVariable], ] _constraint_system_derivation_rules: dict[ str, ConstraintSystemDerivationRule @@ -561,11 +492,11 @@ def _is_tmem_ref(v: ir.Value) -> bool: def _pointwise_op_constraint_system( ctx: DerivationContext, op: ir.OpView, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx all_value_sites = vector_value_sites(op) variable = cs.Variable(all_value_sites[-1]) - return cs.ConstraintSystem(), {variable: all_value_sites}, [] + return cs.ConstraintSystem(), {variable: all_value_sites} for op in [ @@ -617,7 +548,7 @@ def _pointwise_op_constraint_system( def _vector_load_constraint_system( ctx: DerivationContext, op: mgpu.VectorLoadOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: # TODO(b/447079781): Investigate whether we should check for contiguous # strides here. An initial implementation of this failed the # test_gmem_to_smem_with_multiple_smem_indexers_and_transforms test, but @@ -638,14 +569,14 @@ def _vector_load_constraint_system( constraints.append(cs.IsTransferable(source_var, dest_var, shape)) system = cs.ConstraintSystem(constraints=constraints) - return system, value_sites_for_variable, [] + return system, value_sites_for_variable @_add_constraint_system_derivation_rule(mgpu.VectorStoreOp) def _vector_store_constraint_system( ctx: DerivationContext, op: mgpu.VectorStoreOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: # TODO(b/447079781): Investigate whether we should check for contiguous # strides here. An initial implementaiton of this failed the # test_gmem_to_smem_with_multiple_smem_indexers_and_transforms test, but @@ -666,46 +597,46 @@ def _vector_store_constraint_system( constraints.append(cs.IsTransferable(value_var, dest_var, shape)) system = cs.ConstraintSystem(constraints=constraints) - return system, value_sites_for_variable, [] + return system, value_sites_for_variable @_add_constraint_system_derivation_rule(mgpu.DebugPrintOp) def _debug_print_constraint_system( ctx: DerivationContext, op: mgpu.DebugPrintOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx value = ValueSite(op, VariableType.OPERAND, 0) - return cs.ConstraintSystem(), {cs.Variable(value): [value]}, [] + return cs.ConstraintSystem(), {cs.Variable(value): [value]} @_add_constraint_system_derivation_rule(mgpu.PrintLayoutOp) def _print_layout_constraint_system( ctx: DerivationContext, op: mgpu.PrintLayoutOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: value = ValueSite(op, VariableType.OPERAND, 0) var = cs.Variable(value) if is_vector(op.value) else ctx.producer_ref(value) - return cs.ConstraintSystem(), {var: [value]}, [] + return cs.ConstraintSystem(), {var: [value]} @_add_constraint_system_derivation_rule(mgpu.BroadcastedIotaOp) def _broadcasted_iota_constraint_system( ctx: DerivationContext, op: mgpu.BroadcastedIotaOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx value = ValueSite(op, VariableType.RESULT, 0) var = cs.Variable(value) constraints = [cs.NotOfType(var, fa.WGSplatFragLayout)] - return cs.ConstraintSystem(constraints=constraints), {var: [value]}, [] + return cs.ConstraintSystem(constraints=constraints), {var: [value]} @_add_constraint_system_derivation_rule(mgpu.OptimizationBarrierOp) def _optimization_barrier_constraint_system( ctx: DerivationContext, op: ir.OpView, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx value_sites_for_variable: ValueSitesForVariable = {} @@ -718,14 +649,14 @@ def _optimization_barrier_constraint_system( ValueSite(op, VariableType.RESULT, i) ] - return cs.ConstraintSystem(), value_sites_for_variable, [] + return cs.ConstraintSystem(), value_sites_for_variable @_add_constraint_system_derivation_rule(vector.BroadcastOp) def _vector_splat_constraint_system( ctx: DerivationContext, op: ir.OpView, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx result = ValueSite(op, VariableType.RESULT, 0) variable = cs.Variable(result) @@ -733,14 +664,14 @@ def _vector_splat_constraint_system( system = cs.ConstraintSystem( assignments={variable: cs.RegisterLayout(layout)} ) - return system, {variable: [result]}, [] + return system, {variable: [result]} @_add_constraint_system_derivation_rule(arith.ConstantOp) def _constant_constraint_system( ctx: DerivationContext, constant_op: arith.ConstantOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx value = constant_op.value result = ValueSite(constant_op, VariableType.RESULT, 0) @@ -758,7 +689,7 @@ def _constant_constraint_system( constant_is_not_splat = cs.NotOfType(variable, fa.WGSplatFragLayout) system = cs.ConstraintSystem(constraints=[constant_is_not_splat]) - return system, {variable: [result]}, [] + return system, {variable: [result]} def _terminator( @@ -777,7 +708,7 @@ def _terminator( def _for_constraint_system( ctx: DerivationContext, op: scf.ForOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: [block] = op.region.blocks yield_op = _terminator(block, scf.YieldOp) value_sites_for_variable: ValueSitesForVariable = {} @@ -799,7 +730,7 @@ def _for_constraint_system( var = cs.Variable(operand) if is_vector(o) else ctx.producer_ref(operand) value_sites_for_variable[var] = [operand, arg, result, yield_operand] - return cs.ConstraintSystem(), value_sites_for_variable, [] + return cs.ConstraintSystem(), value_sites_for_variable def prime_decomposition(n: int) -> list[int]: @@ -841,7 +772,7 @@ def dynamic_gcd(a: int, b: ir.Value) -> int: def _while_constraint_system( ctx: DerivationContext, op: scf.WhileOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx [before_block] = op.before.blocks [after_block] = op.after.blocks @@ -873,14 +804,14 @@ def _while_constraint_system( case _ as never: assert_never(never) # pytype: disable=wrong-arg-types - return cs.ConstraintSystem(), value_sites_for_variable, [] + return cs.ConstraintSystem(), value_sites_for_variable @_add_constraint_system_derivation_rule(scf.IndexSwitchOp) def _index_switch_constraint_system( ctx: DerivationContext, op: scf.IndexSwitchOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx value_sites_for_variable: ValueSitesForVariable = { cs.Variable(o): [o] for o in vector_value_sites(op) @@ -895,14 +826,14 @@ def _index_switch_constraint_system( ) value_sites_for_variable[value_site].append(yield_operand) - return cs.ConstraintSystem(), value_sites_for_variable, [] + return cs.ConstraintSystem(), value_sites_for_variable @_add_constraint_system_derivation_rule(mgpu.LayoutCastOp) def _layout_cast_constraint_system( ctx: DerivationContext, op: mgpu.LayoutCastOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx operand = ValueSite(op, VariableType.OPERAND, 0) result = ValueSite(op, VariableType.RESULT, 0) @@ -911,7 +842,6 @@ def _layout_cast_constraint_system( return ( cs.ConstraintSystem(assignments={variable: out_layout}), {variable: [operand, result]}, - [], ) @@ -981,7 +911,7 @@ def _infer_wgmma_tiling( def _wgmma_constraint_system( ctx: DerivationContext, op: mgpu.WGMMAOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: assignments: dict[cs.Variable, cs.Constant] = {} value_sites_for_variable: ValueSitesForVariable = {} @@ -1010,14 +940,14 @@ def _wgmma_constraint_system( assignments[a_var] = cs.RegisterLayout(fa.WGMMA_LAYOUT) value_sites_for_variable[a_var] = [a] - return cs.ConstraintSystem(assignments), value_sites_for_variable, [] + return cs.ConstraintSystem(assignments), value_sites_for_variable @_add_constraint_system_derivation_rule(vector.BroadcastOp) def _vector_broadcast_constraint_system( ctx: DerivationContext, op: vector.BroadcastOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx # This is not expected to be necessary at the moment. We should be using # mgpu.BroadcastInDimOp instead when dealing with broadcasting vectors. @@ -1028,7 +958,6 @@ def _vector_broadcast_constraint_system( return ( cs.ConstraintSystem(assignments={out_variable: layout}), {out_variable: [out_variable.key]}, - [], ) @@ -1036,39 +965,29 @@ def _vector_broadcast_constraint_system( def _vector_reduction_constraint_system( ctx: DerivationContext, op: vector.ReductionOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx in_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0)) - return cs.ConstraintSystem(), {in_variable: [in_variable.key]}, [] + return cs.ConstraintSystem(), {in_variable: [in_variable.key]} -def _reduction_constraints_and_hint( +def _reduction_constraints( larger: cs.Variable, smaller: cs.Variable, - larger_shape: tuple[int, ...], reduction_dims: tuple[int, ...], -) -> tuple[list[cs.Constraint], Hint]: - reduce_expr = cs.Reduce(larger, reduction_dims) - # There are always many options for broadcasting a layout, so we can only - # derive a broadcast hint in the out_variable -> source_variable direction. - broadcast_dims = tuple( - i for i in range(len(larger_shape)) if i not in reduction_dims - ) - broadcast_expr = cs.BroadcastInDim(smaller, broadcast_dims, larger_shape) - broadcast_hint = Hint(variable=larger, expression=broadcast_expr) - constraints = [ - cs.Equals(lhs=smaller, rhs=reduce_expr), +) -> list[cs.Constraint]: + return [ + cs.Equals(lhs=smaller, rhs=cs.Reduce(larger, reduction_dims)), # TODO(allanrenucci): Remove once we support reduction of strided layouts. cs.NotOfType(larger, fa.WGStridedFragLayout), ] - return constraints, broadcast_hint @_add_constraint_system_derivation_rule(vector.MultiDimReductionOp) def _multi_dim_reduction_constraint_system( ctx: DerivationContext, op: vector.MultiDimReductionOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx source = ValueSite(op, VariableType.OPERAND, 0) acc = ValueSite(op, VariableType.OPERAND, 1) @@ -1076,10 +995,9 @@ def _multi_dim_reduction_constraint_system( source_variable = cs.Variable(source) out_variable = cs.Variable(out) - reduction_constraints, broadcast_hint = _reduction_constraints_and_hint( + reduction_constraints = _reduction_constraints( source_variable, out_variable, - tuple(ir.ShapedType(op.source.type).shape), tuple(op.reduction_dims), ) # TODO(bchetioui): in the future, we may need to add rules that prevent @@ -1088,7 +1006,6 @@ def _multi_dim_reduction_constraint_system( return ( cs.ConstraintSystem(constraints=reduction_constraints), {source_variable: [source], out_variable: [acc, out]}, - [broadcast_hint], ) @@ -1096,7 +1013,7 @@ def _multi_dim_reduction_constraint_system( def _broadcast_in_dim_constraint_system( ctx: DerivationContext, op: mgpu.BroadcastInDimOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx out_variable = cs.Variable(ValueSite(op, VariableType.RESULT, 0)) source_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0)) @@ -1104,9 +1021,8 @@ def _broadcast_in_dim_constraint_system( reduction_dims = tuple( i for i in range(len(out_shape)) if i not in op.broadcast_dimensions ) - - reduction_constraints, broadcast_hint = _reduction_constraints_and_hint( - out_variable, source_variable, out_shape, reduction_dims + reduction_constraints = _reduction_constraints( + out_variable, source_variable, reduction_dims ) return ( @@ -1115,14 +1031,13 @@ def _broadcast_in_dim_constraint_system( source_variable: [source_variable.key], out_variable: [out_variable.key], }, - [broadcast_hint], ) @_add_constraint_system_derivation_rule(vector.ShapeCastOp) def _shape_cast_constraint_system( ctx: DerivationContext, op: vector.ShapeCastOp -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx in_shape = tuple(cast(ir.ShapedType, op.source.type).shape) out_shape = tuple(cast(ir.ShapedType, op.result.type).shape) @@ -1159,14 +1074,13 @@ def _shape_cast_constraint_system( ], ), {in_variable: [in_variable.key], out_variable: [out_variable.key]}, - [], ) @_add_constraint_system_derivation_rule(vector.ExtractStridedSliceOp) def _extract_strided_slice_constraint_system( ctx: DerivationContext, op: vector.ExtractStridedSliceOp -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx if any(ir.IntegerAttr(s).value != 1 for s in op.strides): raise NotImplementedError("`strides` must contain only 1s.") @@ -1186,7 +1100,6 @@ def _extract_strided_slice_constraint_system( # We use a single variable because lowering does not support two different # layouts for `source` and `result`. {variable: [operand, result]}, - [], ) @@ -1194,7 +1107,7 @@ def _extract_strided_slice_constraint_system( def _custom_primitive_constraint_system( ctx: DerivationContext, op: mgpu.CustomPrimitiveOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: assignments: dict[cs.Variable, cs.Constant] = {} constraints: list[cs.Constraint] = [] in_layouts = iter(op.in_layouts) @@ -1238,7 +1151,6 @@ def _custom_primitive_constraint_system( return ( cs.ConstraintSystem(assignments, constraints), {v: [v.key] for v in variables}, - [], ) @@ -1256,7 +1168,7 @@ def _tmem_layout_from_layout_attr( def _tmem_layout_cast_constraint_system( ctx: DerivationContext, op: mgpu.TmemLayoutCastOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: operand = ValueSite(op, VariableType.OPERAND, 0) variable = ctx.producer_ref(operand) result = ValueSite(op, VariableType.RESULT, 0) @@ -1264,7 +1176,6 @@ def _tmem_layout_cast_constraint_system( return ( cs.ConstraintSystem(assignments={variable: out_layout}), {variable: [operand, result]}, - [], ) @@ -1272,43 +1183,34 @@ def _tmem_layout_cast_constraint_system( def _tmem_alloc_constraint_system( ctx: DerivationContext, op: mgpu.TmemAllocOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx result = ValueSite(op, VariableType.RESULT, 0) result_var = cs.Variable(result) - layout = tcgen05._infer_tmem_layout( - tuple(op.result.type.shape), op.collective, packing=1 - ) - in_smem = ValueSite(op, VariableType.OPERAND, 0) in_smem_var = cs.Variable(in_smem) assignments: dict[cs.Variable, cs.Constant] = { in_smem_var: cs.SMEMTiling(None) } operands_for_variable = {result_var: [result], in_smem_var: [in_smem]} - - # This is a hint, not a hard constraint. This will be the default layout if - # none can be inferred. - hint = Hint(result_var, cs.TMEMLayout(layout)) - system = cs.ConstraintSystem(assignments=assignments) - return system, operands_for_variable, [hint] + return cs.ConstraintSystem(assignments=assignments), operands_for_variable @_add_constraint_system_derivation_rule(mgpu.TmemDeallocOp) def _tmem_dealloc_constraint_system( ctx: DerivationContext, op: mgpu.TmemDeallocOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: operand = ValueSite(op, VariableType.OPERAND, 0) variable = ctx.producer_ref(operand) - return cs.ConstraintSystem(), {variable: [operand]}, [] + return cs.ConstraintSystem(), {variable: [operand]} @_add_constraint_system_derivation_rule(mgpu.TcGen05MMAOp) def _tcgen05_mma_constraint_system( ctx: DerivationContext, op: mgpu.TcGen05MMAOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: assignments: dict[cs.Variable, cs.Constant] = {} operands_for_variable: ValueSitesForVariable = {} @@ -1364,14 +1266,14 @@ def _tcgen05_mma_constraint_system( assignments[a_var] = cs.SMEMTiling(lc.TileTransform(a_tiling)) operands_for_variable[a_var] = [a] - return cs.ConstraintSystem(assignments=assignments), operands_for_variable, [] + return cs.ConstraintSystem(assignments=assignments), operands_for_variable @_add_constraint_system_derivation_rule(mgpu.AsyncLoadTmemOp) def _async_load_tmem_constraint_system( ctx: DerivationContext, op: mgpu.AsyncLoadTmemOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: source = ValueSite(op, VariableType.OPERAND, 0) source_variable = ctx.producer_ref(source) destination = ValueSite(op, VariableType.RESULT, 0) @@ -1384,7 +1286,6 @@ def _async_load_tmem_constraint_system( return ( cs.ConstraintSystem(constraints=[constraint]), {source_variable: [source], destination_variable: [destination]}, - [], ) @@ -1392,7 +1293,7 @@ def _async_load_tmem_constraint_system( def _slice_tmem_constraint_system( ctx: DerivationContext, op: mgpu.SliceTmemOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: operand = ValueSite(op, VariableType.OPERAND, 0) operand_variable = ctx.producer_ref(operand) result = ValueSite(op, VariableType.RESULT, 0) @@ -1400,7 +1301,6 @@ def _slice_tmem_constraint_system( return ( cs.ConstraintSystem(), {operand_variable: [operand], result_variable: [result]}, - [], ) @@ -1408,7 +1308,7 @@ def _slice_tmem_constraint_system( def _async_store_tmem_constraint_system( ctx: DerivationContext, op: mgpu.AsyncStoreTmemOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: source = ValueSite(op, VariableType.OPERAND, 0) source_variable = cs.Variable(source) destination = ValueSite(op, VariableType.OPERAND, 1) @@ -1421,7 +1321,6 @@ def _async_store_tmem_constraint_system( return ( cs.ConstraintSystem(constraints=[constraint]), {source_variable: [source], destination_variable: [destination]}, - [], ) @@ -1429,18 +1328,18 @@ def _async_store_tmem_constraint_system( def _slice_smem_constraint_system( ctx: DerivationContext, op: mgpu.SliceSMEMOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx res = ValueSite(op, VariableType.RESULT, 0) res_var = cs.Variable(res) - return (cs.ConstraintSystem(), {res_var: [res]}, []) + return cs.ConstraintSystem(), {res_var: [res]} @_add_constraint_system_derivation_rule(memref.SubViewOp) def _memref_subview_constraint_system( ctx: DerivationContext, op: memref.SubViewOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: source = ValueSite(op, VariableType.OPERAND, 0) dest = ValueSite(op, VariableType.RESULT, 0) source_dest_var = ctx.producer_ref(source) @@ -1478,25 +1377,25 @@ def _memref_subview_constraint_system( constraints = [cs.Divides(source_dest_var, tuple(tiling_multiple))] system = cs.ConstraintSystem(constraints=constraints) - return system, {source_dest_var: [source, dest]}, [] + return system, {source_dest_var: [source, dest]} @_add_constraint_system_derivation_rule(memref.CastOp) def _memref_cast_op_constraint_system( ctx: DerivationContext, op: memref.CastOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: source = ValueSite(op, VariableType.OPERAND, 0) var_source_dest = ctx.producer_ref(source) dest = ValueSite(op, VariableType.RESULT, 0) - return cs.ConstraintSystem(), {var_source_dest: [source, dest]}, [] + return cs.ConstraintSystem(), {var_source_dest: [source, dest]} @_add_constraint_system_derivation_rule(memref.TransposeOp) def _memref_transpose_op_constraint_system( ctx: DerivationContext, op: memref.TransposeOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: in_ty = ir.MemRefType(op.in_.type) if len(in_ty.shape) != 2: raise NotImplementedError(f"Only 2D memrefs are supported, got {in_ty}") @@ -1509,7 +1408,7 @@ def _memref_transpose_op_constraint_system( source_var = ctx.producer_ref(source) if not transpose: - return (cs.ConstraintSystem(), {source_var: [source, dest]}, []) + return cs.ConstraintSystem(), {source_var: [source, dest]} dest_var = cs.Variable(dest) constraints = [ @@ -1517,14 +1416,14 @@ def _memref_transpose_op_constraint_system( cs.Equals(source_var, cs.Transpose(dest_var)), ] system = cs.ConstraintSystem(constraints=constraints) - return system, {source_var: [source], dest_var: [dest]}, [] + return system, {source_var: [source], dest_var: [dest]} @_add_constraint_system_derivation_rule(memref.ExpandShapeOp) def _memref_expand_shape_op_equation_system( ctx: DerivationContext, op: memref.ExpandShapeOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: if utils.is_memref_transposed(ir.MemRefType(op.src.type)): raise NotImplementedError( "Transposed memrefs are not supported in ExpandShapeOp." @@ -1545,7 +1444,7 @@ def _memref_expand_shape_op_equation_system( reverse_tiling_multiple.append(dim) constraints = [cs.Divides(var, tuple(reversed(reverse_tiling_multiple)))] - return cs.ConstraintSystem(constraints=constraints), {var: [source, dest]}, [] + return cs.ConstraintSystem(constraints=constraints), {var: [source, dest]} # `memref.load` and `memref.store` are used to load barrier phases which are @@ -1555,7 +1454,7 @@ def _memref_expand_shape_op_equation_system( def _memref_load_store_op_constraint_system( ctx: DerivationContext, op: memref.LoadOp | memref.StoreOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx ref_shape = ir.MemRefType(op.memref.type).shape @@ -1568,7 +1467,7 @@ def _memref_load_store_op_constraint_system( ref = ValueSite(op, VariableType.OPERAND, ref_op_index) var = cs.Variable(ref) assignments: dict[cs.Variable, cs.Constant] = {var: cs.SMEMTiling(None)} - return cs.ConstraintSystem(assignments=assignments), {var: [ref]}, [] + return cs.ConstraintSystem(assignments=assignments), {var: [ref]} def _extract_smem_tiling_from_custom_transform_attrs( @@ -1604,13 +1503,13 @@ def _extract_smem_tiling_from_custom_transform_attrs( def _with_transforms_constraint_system( ctx: DerivationContext, op: mgpu.WithTransformsOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: source = ValueSite(op, VariableType.OPERAND, 0) dest = ValueSite(op, VariableType.RESULT, 0) var = ctx.producer_ref(source) tiling = _extract_smem_tiling_from_custom_transform_attrs(op.ref.type, op.transforms) assignments: dict[cs.Variable, cs.Constant] = {var: tiling} - return cs.ConstraintSystem(assignments=assignments), {var: [source, dest]}, [] + return cs.ConstraintSystem(assignments=assignments), {var: [source, dest]} @_add_constraint_system_derivation_rule(mgpu.AsyncLoadOp) @@ -1618,7 +1517,7 @@ def _with_transforms_constraint_system( def _async_load_store_constraint_system( ctx: DerivationContext, op: mgpu.AsyncLoadOp | mgpu.AsyncStoreOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable, list[Hint]]: +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: tiling_multiple = [] for size, index in zip(op.slice_lengths, op.indices, strict=True): if size == -1: @@ -1630,7 +1529,7 @@ def _async_load_store_constraint_system( operand = ValueSite(op, VariableType.OPERAND, operand_index) var = ctx.producer_ref(operand) constraints = [cs.Divides(expr=var, tiling_multiple=tuple(tiling_multiple))] - return cs.ConstraintSystem(constraints=constraints), {var: [operand]}, [] + return cs.ConstraintSystem(constraints=constraints), {var: [operand]} def _ensure_all_layouts_are_set(op: ir.OpView) -> None: @@ -1858,11 +1757,10 @@ def consumer_operands(result: ValueSite) -> Sequence[ValueSite]: return consumer_operands -def derive_hints_and_constraints( +def derive_relayout_constraints( value_sites_for_variable: ValueSitesForVariable, -) -> tuple[list[Hint], list[cs.Relayout]]: - """Derives propagation hints from the given variable mapping.""" - hints: list[Hint] = [] +) -> list[cs.Relayout]: + """Derives relayout constraints from the given variable mapping.""" constraints: list[cs.Relayout] = [] variable_for_value_site: dict[ValueSite, cs.Variable] = {} for variable, value_sites in value_sites_for_variable.items(): @@ -1902,16 +1800,7 @@ def derive_hints_and_constraints( # A variable must be relayout-able to its consumers. constraints.append(cs.Relayout(variable, consumer_variable)) visited.add(variable) - - if producers: - least_replicated_producer = cs.LeastReplicated(tuple(producers)) - hint_expr = cs.MostReplicated((least_replicated_producer, *consumers)) - hints.append(Hint(variable, hint_expr)) - elif consumers: - hint_expr = cs.MostReplicated(tuple(consumers)) - hints.append(Hint(variable, hint_expr)) - - return hints, constraints + return constraints def is_terminator(op: ir.OpView) -> bool: @@ -1956,7 +1845,6 @@ def infer_layout( """ global_constraint_system: cs.ConstraintSystem | cs.Unsatisfiable global_constraint_system = cs.ConstraintSystem() - hints: list[Hint] = [] ctx = DerivationContext() def gather_constraints(op: ir.Operation): @@ -1976,11 +1864,10 @@ def gather_constraints(op: ir.Operation): rule = _constraint_system_derivation_rules.get(op.OPERATION_NAME, None) # pytype: disable=attribute-error if rule is None: raise NotImplementedError(f"No layout inference rule defined for {op}") - constraint_system, mapping, op_hints = rule(ctx, op) + constraint_system, mapping = rule(ctx, op) ctx.update(mapping) nonlocal global_constraint_system global_constraint_system &= constraint_system - hints.extend(op_hints) for op in module.body: traverse_op(op, gather_constraints) @@ -1991,8 +1878,7 @@ def gather_constraints(op: ir.Operation): "user-provided layout casts are unsatisfiable." ) - propagation_hints, constraints = derive_hints_and_constraints(ctx.value_sites_for_variable) - hints = reduce_hints(hints + propagation_hints, global_constraint_system.assignments) # pytype: disable=attribute-error + constraints = derive_relayout_constraints(ctx.value_sites_for_variable) global_constraint_system &= cs.ConstraintSystem(constraints=constraints) assert not isinstance(global_constraint_system, cs.Unsatisfiable) @@ -2010,7 +1896,6 @@ def gather_constraints(op: ir.Operation): solution, remaining_fuel = find_assignments_for( list(ctx.value_sites_for_variable.keys()), global_constraint_system, - hints, fuel=fuel, ) diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index cf7147ed551c..c52a0bdc1d46 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -65,7 +65,6 @@ def undefs(*tys: ir.Type) -> list[ir.Value]: V = cs.Variable -H = layout_inference.Hint E = cs.Equals RL = cs.RegisterLayout @@ -76,7 +75,6 @@ def _undef_constraint_system( ) -> tuple[ cs.ConstraintSystem, layout_inference.ValueSitesForVariable, - list[layout_inference.Hint], ]: del ctx # This rule is only called if the single output of the undef op is a vector or @@ -84,7 +82,7 @@ def _undef_constraint_system( result = layout_inference.ValueSite( op, layout_inference.VariableType.RESULT, 0 ) - return cs.ConstraintSystem(), {cs.Variable(result): [result]}, [] + return cs.ConstraintSystem(), {cs.Variable(result): [result]} class LayoutInferenceTest(parameterized.TestCase): @@ -289,7 +287,8 @@ def test_infer_layout_cast_layout(self): cast = mgpu.dialect.LayoutCastOp(add.result, wgmma_layout) mgpu.infer_layout(self.module) - self.checkOutLayouts(add, [splat_layout]) + # The layout of `add` may be either WGMMA or SPLAT. + self.checkOutLayouts(add, [wgmma_layout]) self.checkInLayouts(cast, [wgmma_layout]) self.checkOutLayouts(cast, [wgmma_layout]) @@ -668,27 +667,22 @@ def test_custom_primitive_op_retains_layouts(self): self.checkInLayouts(op, [wgmma_layout]) self.checkOutLayouts(op, [wgmma_row_layout]) - def test_hint_and_constraint_extraction_works_correctly(self): + def test_constraint_extraction_works_correctly(self): layout = mgpu.WGMMA_ROW_LAYOUT with ir.InsertionPoint(self.module.body): x = llvm.UndefOp(ir.VectorType.get((64,), ir.BF16Type.get())) lc = layout_cast(x.result, layouts.to_layout_attr(layout)).owner.opview ctx = layout_inference.DerivationContext() - x_system, x_mapping, _ = _undef_constraint_system(ctx, x) - lc_system, lc_mapping, _ = layout_inference._layout_cast_constraint_system( + _, x_mapping = _undef_constraint_system(ctx, x) + _, lc_mapping = layout_inference._layout_cast_constraint_system( ctx, lc ) - assignments = x_system.assignments | lc_system.assignments - hints, [constraint] = layout_inference.derive_hints_and_constraints( + [constraint] = layout_inference.derive_relayout_constraints( x_mapping | lc_mapping ) - [hint_cst] = layout_inference.reduce_hints(hints, assignments) - [x_variable] = x_mapping.keys() [lc_variable] = lc_mapping.keys() - self.assertEqual(hint_cst.variable, x_variable) - self.assertEqual(hint_cst.expression, RL(layout)) self.assertEqual(constraint, cs.Relayout(x_variable, lc_variable)) @parameterized.parameters(*layout_inference.MemorySpace) @@ -718,35 +712,15 @@ def test_relayout_only_derived_for_registers(self, memory_space): ) o_var = cs.Variable(o) - hints, relayouts = layout_inference.derive_hints_and_constraints( + relayouts = layout_inference.derive_relayout_constraints( layout_inference.ValueSitesForVariable({r_var: [r], o_var: [o]}) ) if memory_space == layout_inference.MemorySpace.REG: - hint0 = layout_inference.Hint(r_var, cs.MostReplicated((o_var,))) - hint1 = layout_inference.Hint( - o_var, cs.MostReplicated((cs.LeastReplicated((r_var,)),)) - ) - - self.assertEqual(hints, [hint0, hint1]) self.assertEqual(relayouts, [cs.Relayout(r_var, o_var)]) else: - self.assertEmpty(hints) self.assertEmpty(relayouts) - def test_unambiguous_hints_are_used_to_assign_variables_correctly(self): - v0 = V(0) - assignments, _ = layout_inference.find_assignments_for( - {v0}, - cs.ConstraintSystem(), - # Voluntarily use conflicting hints to check that we use one of them - # deterministically. This may require updating if we decide to change - # the traversal order in the future. - [H(v0, RL(mgpu.WGMMA_ROW_LAYOUT)), H(v0, RL(mgpu.WGMMA_COL_LAYOUT))], - fuel=1000, - ) - self.assertEqual(assignments, {v0: RL(mgpu.WGMMA_ROW_LAYOUT)}) - def test_find_assignments_for_is_transferable_constraints_is_deterministic( self, ): @@ -758,7 +732,6 @@ def test_find_assignments_for_is_transferable_constraints_is_deterministic( assignments, _ = layout_inference.find_assignments_for( {v0}, cs.ConstraintSystem(constraints=[constraint]), - [], fuel=1000, ) # Another valid layout is TMEM_NATIVE_LAYOUT but TCGEN05_LAYOUT is tried @@ -780,77 +753,10 @@ def test_cannot_find_assignments_for_unsatisfiable_constraint_system(self): E(variable, RL(mgpu.WGMMA_COL_LAYOUT)), ] ), - hints=[], fuel=1000, ) self.assertIsInstance(assignments, cs.Unsatisfiable) - def test_hint_that_would_make_system_unsatisfiable_is_not_used_in_solution(self): - with ir.InsertionPoint(self.module.body): - ty = ir.VectorType.get((32, 4), ir.BF16Type.get()) - op0, op1 = [llvm.mlir_undef(ty).owner.opview for _ in range(2)] - [kv0] = layout_inference.vector_value_sites(op0) - [kv1] = layout_inference.vector_value_sites(op1) - v0, v1 = cs.Variable(kv0), cs.Variable(kv1) - splat_layout = RL(mgpu.WGSplatFragLayout((3, 128))) - assignments, _ = layout_inference.find_assignments_for( - {v0}, - cs.ConstraintSystem( - constraints=[ - E( - v0, - cs.MostReplicated( - [v1, RL(mgpu.WGStridedFragLayout((3, 128), vec_size=1))] - ), - ) - ] - ), - # The first hint would make the system unsatisfiable, but the second - # hint should be used to find a solution. - hints=[H(v1, RL(mgpu.WGMMA_LAYOUT)), H(v1, splat_layout)], - fuel=1000, - ) - self.assertEqual(assignments, {v0: splat_layout}) - - def test_hint_can_be_chosen_when_constant_exists_in_least_replicated_expression(self): - v0, v1 = V(0), V(1) - layout = RL(mgpu.WGMMA_LAYOUT) - assignment = layout_inference.extract_variable_assignment_from_hint( - H(v0, cs.LeastReplicated([layout, v1])), - ) - self.assertEqual(assignment, (v0, layout)) - - def test_hint_cannot_be_chosen_when_constant_exists_in_most_replicated_expression(self): - v0, v1 = V(0), V(1) - layout = RL(mgpu.WGSplatFragLayout((1, 128))) - assignment = layout_inference.extract_variable_assignment_from_hint( - H(v0, cs.MostReplicated([layout, v1])), - ) - self.assertEqual(assignment, (v0, layout)) - - def test_hint_is_still_extracted_when_underlying_expression_is_unsatisfiable(self): - v0, v1 = V(0), V(1) - layout0 = RL(mgpu.WGSplatFragLayout((1, 128))) - layout1 = RL(mgpu.WGStridedFragLayout((1, 256), vec_size=2)) - hint_expr = cs.LeastReplicated([layout0, cs.MostReplicated([layout1, v1])]) - self.assertIsInstance( - cs.reduce_expression(hint_expr, {v1: layout1}), cs.Unsatisfiable - ) - _, expr = layout_inference.extract_variable_assignment_from_hint( - H(v0, hint_expr)) - self.assertIsNotNone(expr) - - def test_least_replicated_hint_is_still_resolved_when_all_known_choices_are_replicated( - self, - ): - v0, v1 = V(0), V(1) - layout0 = RL(mgpu.WGSplatFragLayout((1, 128))) - layout1 = RL(mgpu.WGSplatFragLayout((1, 129))) - assignment = layout_inference.extract_variable_assignment_from_hint( - H(v0, cs.LeastReplicated([v1, layout0, layout1])), - ) - self.assertIsNotNone(assignment) - def test_vector_broadcast_from_scalar_infers_splat_layout(self): shape = (128,) f32 = ir.F32Type.get() @@ -1196,7 +1102,7 @@ def test_conjure_smem_assignment_from_is_transferrable(self, transposed): def conjure(constraints) -> list[tuple[cs.Variable, cs.Constant]]: system = cs.ConstraintSystem(constraints=constraints) - return list(layout_inference.conjure_assignment({var}, system, [])) + return list(layout_inference.conjure_assignment({var}, system)) # Yield only empty tiling with no constraints. with self.subTest("no_constraints_yield_empty_tiling"): @@ -1246,8 +1152,7 @@ def conjure(constraints) -> list[tuple[cs.Variable, cs.Constant]]: ], ) - def test_conjure_orders_hints_correctly(self): - # Create a var to use in the constraint system. + def test_conjure_tries_high_priority_assignments_first(self): shape = (128, 128) f32 = ir.F32Type.get() [val] = undefs(ir.VectorType.get(shape, f32)) @@ -1258,23 +1163,23 @@ def test_conjure_orders_hints_correctly(self): ) var = cs.Variable(value_site) - hints = [ - layout_inference.Hint( + constraints = [ + cs.Relayout( var, cs.RegisterLayout(fa.WGSplatFragLayout((128, 128))), ), - layout_inference.Hint( + cs.Relayout( var, cs.RegisterLayout(fa.WGMMA_LAYOUT), ), - layout_inference.Hint( + cs.Relayout( var, cs.RegisterLayout(fa.WGStridedFragLayout(shape, vec_size=4)), ), ] - system = cs.ConstraintSystem() - ordered = list(layout_inference.conjure_assignment({var}, system, hints)) + system = cs.ConstraintSystem(constraints=constraints) + ordered = list(layout_inference.conjure_assignment({var}, system)) expected = [ (var, cs.RegisterLayout(fa.WGMMA_LAYOUT)), (var, cs.RegisterLayout(fa.WGSplatFragLayout((128, 128)))), From 0339ad16c1ffcdb3eff6818272d52c827828993f Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 8 Dec 2025 15:45:42 -0800 Subject: [PATCH 104/315] Change PjRt to use new copy of coordination service. PiperOrigin-RevId: 841944431 --- jaxlib/BUILD | 1 + jaxlib/jax.cc | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index e9cf8b41fc48..bfcff630cc89 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -397,6 +397,7 @@ nanobind_pywrap_extension( "@xla//xla/pjrt/distributed:key_value_store_interface", "@xla//xla/pjrt/distributed:protocol_proto_cc", "@xla//xla/pjrt/distributed:service", + "@xla//xla/pjrt/distributed/preemption:preemption_sync_manager", "@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options", "@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "@xla//xla/python:logging", diff --git a/jaxlib/jax.cc b/jaxlib/jax.cc index 1571ddb29bb9..86be4995b823 100644 --- a/jaxlib/jax.cc +++ b/jaxlib/jax.cc @@ -116,6 +116,7 @@ limitations under the License. #include "xla/hlo/builder/lib/approx_topk_shape.h" #include "xla/pjrt/c_api_client/pjrt_c_api_client.h" #include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/distributed/preemption/preemption_sync_manager.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_client.h" @@ -580,6 +581,30 @@ NB_MODULE(_jax, m) { aux::RegisterTransferServerTypes(m); #endif // defined(__linux__) +#if JAX_IFRT_VERSION_NUMBER >= 39 + nb::class_ preemption_sync_manager( + m, "PreemptionSyncManager"); + preemption_sync_manager + .def( + "initialize", + [](xla::PreemptionSyncManager& manager, + xla::DistributedRuntimeClient* client) { + xla::CoordinationServiceAgent* agent = + xla::ValueOrThrow(client->GetCoordinationServiceAgent()); + xla::ThrowIfError(manager.Initialize(agent)); + }, + nb::arg("distributed_client")) + .def("reached_sync_point", + [](xla::PreemptionSyncManager& manager, int step_counter) { + return manager.ReachedSyncPoint(step_counter); + }) + .def("shutdown", [](xla::PreemptionSyncManager& manager) { + nb::gil_scoped_release gil_release; + manager.Shutdown(); + }); + m.def("create_preemption_sync_manager", + []() { return xla::CreatePreemptionSyncManager(); }); +#else nb::class_ preemption_sync_manager( m, "PreemptionSyncManager"); preemption_sync_manager @@ -602,6 +627,7 @@ NB_MODULE(_jax, m) { }); m.def("create_preemption_sync_manager", []() { return tsl::CreatePreemptionSyncManager(); }); +#endif nb::class_ distributed_runtime_service( m, "DistributedRuntimeService"); @@ -898,7 +924,6 @@ NB_MODULE(_jax, m) { nb::class_( m, "TransferServerInterfaceFactory"); - m.def("is_asan", IsAsan); m.def("is_msan", IsMsan); m.def("is_tsan", IsTsan); From 2a6de35abf0159dc1c5f21dccba727deeaa91733 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 8 Dec 2025 17:09:21 -0800 Subject: [PATCH 105/315] [Pallas SC] Allow semaphores to be returned by SCS kernels PiperOrigin-RevId: 841972235 --- .../pallas/mosaic/pallas_call_registration.py | 27 +++++--- jax/_src/pallas/mosaic/sc_lowering.py | 5 +- jax/_src/tpu_custom_call.py | 3 + tests/pallas/tpu_sparsecore_pallas_test.py | 62 +++++++++++++++++++ 4 files changed, 89 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 390781e05820..4811aa9f12d6 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -62,7 +62,7 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue): def _get_memory_space_from_aval( - out_aval: jax_core.AbstractValue, + out_aval: jax_core.AbstractValue, kernel_type: tpu_core.KernelType ) -> tpu_custom_call.MemorySpace | None: if not isinstance(out_aval, jax_core.ShapedArray): raise ValueError("Memory spaces not defined for non-ShapedArrays") @@ -84,20 +84,29 @@ def _get_memory_space_from_aval( case tpu_core.MemorySpace.SMEM: return tpu_custom_call.MemorySpace.SMEM case tpu_core.MemorySpace.SEMAPHORE: - return tpu_custom_call.MemorySpace.SEMAPHORE_MEM + match kernel_type: + case tpu_core.KernelType.SC_SCALAR_SUBCORE: + return tpu_custom_call.MemorySpace.SC_SCALAR_SEMAPHORE_MEM + case tpu_core.KernelType.TC: + return tpu_custom_call.MemorySpace.SEMAPHORE_MEM + case _: + raise ValueError(f"Invalid kernel type for semaphore: {kernel_type}") case tpu_core.MemorySpace.HOST: return tpu_custom_call.MemorySpace.HOST return None def _get_memory_spaces_from_avals( - avals: Sequence[jax_core.AbstractValue], + avals: Sequence[jax_core.AbstractValue], kernel_type: tpu_core.KernelType ) -> tuple[tpu_custom_call.MemorySpace | None, ...] | None: memory_spaces = None if any( isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) for aval in avals ): - memory_spaces = tuple(map(_get_memory_space_from_aval, avals)) + memory_spaces = tuple( + _get_memory_space_from_aval(aval, kernel_type=kernel_type) + for aval in avals + ) return memory_spaces @@ -140,7 +149,7 @@ def pallas_call_tpu_lowering_rule( mlir_ctx.load_all_available_dialects() tpu.register_dialect(mlir_ctx) - match mosaic_params.kernel_type: + match (kernel_type := mosaic_params.kernel_type): case tpu_core.KernelType.TC: lower_jaxpr_to_module = lowering.lower_jaxpr_to_module case tpu_core.KernelType.SC_SCALAR_SUBCORE | tpu_core.KernelType.SC_VECTOR_SUBCORE: @@ -191,7 +200,9 @@ def _maybe_cast_inputs(*args): # Dynamic grid bounds have to go at the front. dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:] kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) - output_memory_spaces = _get_memory_spaces_from_avals(out_avals) + output_memory_spaces = _get_memory_spaces_from_avals( + out_avals, kernel_type=kernel_type + ) input_memory_spaces = None if any( isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) @@ -202,7 +213,9 @@ def _maybe_cast_inputs(*args): raise NotImplementedError( "Dynamic grid bounds are not supported when specifying memory spaces for inputs." ) - input_memory_spaces = _get_memory_spaces_from_avals(ctx.avals_in) + input_memory_spaces = _get_memory_spaces_from_avals( + ctx.avals_in, kernel_type=kernel_type + ) if cost_estimate is not None: mosaic_cost_estimate = cast( tpu_custom_call.CostEstimate, dataclasses.asdict(cost_estimate) diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py index b49c981cdc6a..c6da2956557b 100644 --- a/jax/_src/pallas/mosaic/sc_lowering.py +++ b/jax/_src/pallas/mosaic/sc_lowering.py @@ -330,7 +330,10 @@ def body_func(*args: ir.Value): mosaic_grid_mapping.block_mappings, ): d = {} - if str(arg.type.memory_space) == "#tpu.memory_space": + if ( + str(arg.type.memory_space) == "#tpu.memory_space" + or str(arg.type.memory_space) == "#tpu.memory_space" + ): d["sc.persistent"] = ir.UnitAttr.get() if isinstance(bm, sc_core.BlockMapping) and bm.indexed_by is not None: d["sc.indexed_by"] = mlir.i32_attr(bm.indexed_by) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index f9c94f860ced..81c63f94cce4 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -101,6 +101,7 @@ class MemorySpace(enum.Enum): SEMAPHORE_MEM = enum.auto() SMEM = enum.auto() HOST = enum.auto() + SC_SCALAR_SEMAPHORE_MEM = enum.auto() @property def color(self) -> int: @@ -110,6 +111,8 @@ def color(self) -> int: return 1 elif self == MemorySpace.SEMAPHORE_MEM: return 2 + elif self == MemorySpace.SC_SCALAR_SEMAPHORE_MEM: + return 8 elif self == MemorySpace.SMEM: return 4 elif self == MemorySpace.HOST: diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index 18ce12312436..17d1e3acf098 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -1923,5 +1923,67 @@ class PipelineTestWithTCTiling(TCTilingMixin, PipelineTest): pass +class PallasSparsecoreAsyncTest(PallasSCTest): + + @parameterized.product( + shape=[ + (8, 128), + (8, 256), + (8, 512), + (8, 1024), + (16, 128), + (16, 256), + (16, 512), + (16, 1024), + # TODO(sharadmv): These shapes fail right now. + # (64, 8), + ], + dtype=[jnp.int32, jnp.float32, jnp.bfloat16], + ) + def test_basic_async_kernel(self, shape, dtype): + if not jtu.is_cloud_tpu_at_least(2025, 12, 8): + self.skipTest("Need newer libtpu") + x = jnp.arange(shape[0] * shape[1], dtype=dtype).reshape(shape) + + @jax.jit + def foo(x): + sc_mesh = plsc.ScalarSubcoreMesh(axis_name="core", num_cores=1) + + sem = pl.pallas_call( + lambda _: None, + out_shape=pltpu.SemaphoreType.DMA(()), + out_specs=pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + compiler_params=pltpu.CompilerParams( + dimension_semantics=["core_parallel"], + kernel_type=pltpu.KernelType.SC_SCALAR_SUBCORE, + ), + )() + + sem_ref = jax.new_ref(sem, memory_space=pltpu.SEMAPHORE) + y_ref = pl.empty_ref_like(pltpu.HBM(x.shape, x.dtype)) + x_ref = jax.new_ref(x) + + run_kernel = pl.core_map(mesh=sc_mesh) + + @run_kernel + def _(): + pltpu.make_async_copy(x_ref, y_ref, sem_ref).start() + + @run_kernel + def _(): + pltpu.make_async_copy(x_ref, y_ref, sem_ref).wait() + + return y_ref[...] + + o = jax.block_until_ready(foo(x)) + np.testing.assert_array_equal(o, x) + + +class PallasSparsecoreAsyncTestWithTCTiling( + TCTilingMixin, PallasSparsecoreAsyncTest +): + pass + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 890ccd23c3728a152ad551e187cea3777af9e435 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 8 Dec 2025 17:27:51 -0800 Subject: [PATCH 106/315] Deprecate `with mesh:` context manager. Use `with jax.set_mesh(mesh):` instead PiperOrigin-RevId: 841977281 --- CHANGELOG.md | 2 ++ jax/_src/mesh.py | 5 +++++ jax/_src/pjit.py | 2 +- jax/experimental/jax2tf/tests/tf_test_util.py | 2 ++ tests/cache_key_test.py | 2 ++ tests/fused_attention_stablehlo_test.py | 2 ++ tests/multiprocess/multihost_utils_test.py | 2 ++ tests/multiprocess/pjit_test.py | 2 ++ tests/pjit_test.py | 7 ++++++- tests/python_callback_test.py | 2 ++ 10 files changed, 26 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cf1fd16e09b..882ea245ed5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Deprecations * `jax.lax.pvary` has been deprecated. Please use `jax.lax.pcast(..., to='varying')` as the replacement. + * `with mesh:` context manager has been deprecated. + Please use `with jax.set_mesh(mesh):` instead. * Changes: * jax's `Tracer` no longer inherits from `jax.Array` at runtime. However, diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index f0210d5d9b82..6ab9f7703324 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -24,6 +24,7 @@ import math import threading from typing import Any, NamedTuple +import warnings import numpy as np @@ -322,6 +323,10 @@ def __setattr__(self, name, value): def __enter__(self): if jax_config.disallow_mesh_context_manager.value: raise RuntimeError("Mesh context manager is disabled.") + warnings.warn( + "`with mesh:` context manager has been deprecated. Please use `with" + " jax.set_mesh(mesh):` instead.", + category=DeprecationWarning, stacklevel=2) new_env = thread_resources.stack[-1].with_mesh(self) thread_resources.stack.append(new_env) thread_resources.env = new_env diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 1e9f390702d4..d9d3aa16d35f 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -367,7 +367,7 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, 'backend and device argument on jit is deprecated. You can use' ' `jax.device_put(..., jax.local_devices(backend="cpu")[0])` on the' ' inputs to the jitted function to get the same behavior.', - DeprecationWarning, + category=DeprecationWarning, stacklevel=2 ) if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index df7e59a0d8ce..0514ae0530c9 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -188,6 +188,8 @@ def setUp(self): self.assertGreaterEqual(version, export.minimum_supported_calling_convention_version) self.enter_context(config.jax_export_calling_convention_version(version)) + self.enter_context(jtu.ignore_warning( + category=DeprecationWarning, message='`with mesh:` context manager')) logging.info( "Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)", version, diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 35ac03011a97..96faa47be7e2 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -165,6 +165,8 @@ def test_different_computations(self): # TODO(phawkins): this test flakes if test concurrency is enabled. @jtu.thread_unsafe_test() + @jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') def test_custom_partitioning_ptr_removal(self): def _partition(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 8df402e6e4ff..e968107e6f6e 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -753,6 +753,8 @@ def generate_segment_mask(segment_ids, dtype): self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2) @jtu.run_on_devices("cuda") + @jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') def test_sdpa_residual(self): k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5) query = jax.random.normal( diff --git a/tests/multiprocess/multihost_utils_test.py b/tests/multiprocess/multihost_utils_test.py index d3ce2d5d6393..fe7aec5c630d 100644 --- a/tests/multiprocess/multihost_utils_test.py +++ b/tests/multiprocess/multihost_utils_test.py @@ -171,6 +171,8 @@ def test_sync_global_devices_error(self): else: multihost_utils.sync_global_devices('test message2') + @jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') def test_sync_global_devices_mesh_context_manager(self): global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) with global_mesh: diff --git a/tests/multiprocess/pjit_test.py b/tests/multiprocess/pjit_test.py index 79c0721ab66b..4ef9e1c4c2cf 100644 --- a/tests/multiprocess/pjit_test.py +++ b/tests/multiprocess/pjit_test.py @@ -381,6 +381,8 @@ def _lower_compile(inp): for out in list(result): np.testing.assert_array_equal(out(x), expected_out) + @jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') def test_fully_sharded_on_all_devices(self): if jax.local_device_count() > 1: self.skipTest("This test only works with 1 process per device.") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b55beb335dd3..2b8ff9198420 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -95,8 +95,9 @@ def check_1d_2d_mesh(f, set_mesh): ))(jtu.with_mesh_from_kwargs(f) if set_mesh else f) -# TODO(skye): make the buffer donation utils part of JaxTestCase @jtu.pytest_mark_if_available('multiaccelerator') +@jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') class PJitTest(jtu.BufferDonationTestCase): @jtu.with_mesh([('x', 1)]) @@ -1491,6 +1492,8 @@ def test_pjit_array_error(self): @jtu.pytest_mark_if_available('multiaccelerator') +@jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') class ArrayPjitTest(jtu.JaxTestCase): @parameterized.named_parameters( @@ -9823,6 +9826,8 @@ def test_c64_to_f32_view_rountrip(self, mesh): @jtu.pytest_mark_if_available('multiaccelerator') +@jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') class PJitErrorTest(jtu.JaxTestCase): @check_1d_2d_mesh(set_mesh=True) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 3a70b08ea912..1423a8bcf268 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1138,6 +1138,8 @@ def f(x): self.assertEqual(count(), 1) +@jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') class IOCallbackTest(jtu.JaxTestCase): def setUp(self): From c74b2a5cc871342ecfa74006982225d746ff5d65 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 8 Dec 2025 17:48:52 -0800 Subject: [PATCH 107/315] Add a jax network transfer benchmark script PiperOrigin-RevId: 841983057 --- jax/BUILD | 2 +- jax/experimental/BUILD | 7 +++++++ jaxlib/jax.bzl | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 92a07f72f5e6..245d15d9de16 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -459,7 +459,7 @@ alias( alias( name = "experimental_transfer", actual = "//jax/experimental:transfer", - visibility = [":internal"], + visibility = ["//jax/experimental:experimental_transfer_users"], ) # Aliases of example_library targets. diff --git a/jax/experimental/BUILD b/jax/experimental/BUILD index c6ee0e9b05b4..7215cecc8651 100644 --- a/jax/experimental/BUILD +++ b/jax/experimental/BUILD @@ -15,6 +15,7 @@ load( "//jaxlib:jax.bzl", "buffer_callback_internal_users", + "experimental_transfer_users", "if_cuda_is_configured", "jax_visibility", "mosaic_gpu_internal_users", @@ -41,6 +42,12 @@ package_group( packages = buffer_callback_internal_users, ) +package_group( + name = "experimental_transfer_users", + includes = ["//jax:internal"], + packages = experimental_transfer_users, +) + package_group( name = "mosaic_users", includes = ["//jax:internal"], diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 49e11679da88..0fd5faf39398 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -45,6 +45,7 @@ tf_cuda_tests_tags = _tf_cuda_tests_tags jax_internal_packages = [] jax_extend_internal_users = [] +experimental_transfer_users = [] mosaic_gpu_internal_users = [] mosaic_internal_users = [] pallas_gpu_internal_users = [] From 249b7a7347e62527cf5263e8bc59e9b4321a7020 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 8 Dec 2025 19:25:49 -0800 Subject: [PATCH 108/315] Add ResultHandler.wrap which allows us to avoid exploding PRNG keys. PiperOrigin-RevId: 842010889 --- jax/_src/array.py | 33 ++++++++++++++++++++++------- jax/_src/interpreters/pxla.py | 2 +- jax/_src/lax/lax.py | 13 ++++++++---- jax/_src/prng.py | 24 +++++++++++++++------ jaxlib/_jax/__init__.pyi | 1 + jaxlib/py_array.cc | 39 ++++++++++++++++++++++++----------- jaxlib/py_array.h | 20 +++++++++++++----- jaxlib/xla_client.py | 2 +- tests/dtypes_test.py | 3 +++ tests/lax_test.py | 16 +++++++------- 10 files changed, 109 insertions(+), 44 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index f3f522441f12..50185186306f 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -39,6 +39,7 @@ from jax._src.interpreters import pxla from jax._src.layout import AutoLayout, Format, Layout from jax._src.lib import _jax +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.mesh import empty_concrete_mesh from jax._src.sharding import Sharding @@ -1285,7 +1286,14 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics): def _array_global_result_handler(global_aval, out_sharding, committed): if global_aval.dtype == dtypes.float0: - return lambda _: np.zeros(global_aval.shape, dtypes.float0) + def handler(xs): + return np.zeros(global_aval.shape, dtypes.float0) + if jaxlib_extension_version >= 390: + phys_aval = core.physical_aval(global_aval) + return xc.array_result_handler(phys_aval, out_sharding, committed=committed, + _skip_checks=True).wrap(handler) + else: + return handler if dtypes.issubdtype(global_aval.dtype, dtypes.extended): return global_aval.dtype._rules.global_sharded_result_handler( global_aval, out_sharding, committed) @@ -1297,7 +1305,14 @@ def _array_global_result_handler(global_aval, out_sharding, committed): # Only used for Arrays that come out of pmap. def _array_local_result_handler(aval, sharding, indices): if aval.dtype == dtypes.float0: - return lambda _: np.zeros(aval.shape, dtypes.float0) + def handler(xs): + return np.zeros(aval.shape, dtypes.float0) + if jaxlib_extension_version >= 390: + phys_aval = core.physical_aval(aval) + return xc.array_result_handler(phys_aval, sharding, committed=True, + _skip_checks=True).wrap(handler) + else: + return handler if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.local_sharded_result_handler( aval, sharding, indices) @@ -1326,9 +1341,13 @@ def _token_shard_arg(xs, shardings, layouts, copy_semantics): def _token_global_result_handler(global_aval, out_sharding, committed): array_handler = _array_global_result_handler( core.get_token_aval(), out_sharding, committed) - - def wrapper(*args, **kwargs): - out_buf = array_handler(*args, **kwargs) - return core.Token(out_buf) - return wrapper + if jaxlib_extension_version >= 390: + def wrapper(array): + return core.Token(array) + return array_handler.wrap(wrapper) # type: ignore + else: + def old_wrapper(*args, **kwargs): + out_buf = array_handler(*args, **kwargs) + return core.Token(out_buf) + return old_wrapper pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index bc4407798945..0570b083a77e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -302,7 +302,7 @@ def local_aval_to_result_handler( raise TypeError( f"No pxla_result_handler for type: {type(aval)}") from err -PxlaResultHandler = Callable[..., Callable[[Any], Any]] +PxlaResultHandler = Callable[..., xc._xla.ResultHandler] local_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {} diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index e89d5060a92c..689b76531cda 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -63,6 +63,7 @@ from jax._src.lax.utils import ( input_dtype, dtype_to_string, standard_multi_result_abstract_eval, standard_primitive) +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -9246,10 +9247,14 @@ def global_sharded_result_handler(aval, out_sharding, committed): else: phys_sharding = out_sharding phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) - - def handler(bufs): - return core.DArray(aval, phys_handler(bufs)) - return handler + if jaxlib_extension_version >= 390: + def handler(arr): + return core.DArray(aval, arr) + return phys_handler.wrap(handler) + else: + def handler(bufs): + return core.DArray(aval, phys_handler(bufs)) + return handler core.bint._rules = BIntRules diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 4cbaa1719ff7..7feb345ccb2e 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -42,6 +42,7 @@ from jax._src.lax import control_flow as lax_control_flow from jax._src.lax import lax from jax._src.lax import slicing as lax_slicing +from jax._src.lib import jaxlib_extension_version from jax._src.lib import gpu_prng from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir @@ -402,10 +403,16 @@ def local_sharded_result_handler(aval, sharding, indices): phys_handler = phys_handler_maker(phys_aval, phys_sharding, phys_indices) # set up a handler that calls the physical one and wraps back up - def handler(bufs): - return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) + if jaxlib_extension_version >= 390: + def handler(arr): + return PRNGKeyArray(aval.dtype._impl, arr) - return handler + return phys_handler.wrap(handler) + else: + def handler(bufs): + return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) + + return handler @staticmethod def global_sharded_result_handler(aval, out_sharding, committed): @@ -414,9 +421,14 @@ def global_sharded_result_handler(aval, out_sharding, committed): phys_sharding = physical_sharding(aval, out_sharding) phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) - def handler(bufs): - return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) - return handler + if jaxlib_extension_version >= 390: + def handler(bufs): + return PRNGKeyArray(aval.dtype._impl, bufs) + return phys_handler.wrap(handler) + else: + def handler(bufs): + return PRNGKeyArray(aval.dtype._impl, phys_handler(bufs)) + return handler @staticmethod def make_sharded_array(aval, sharding, arrays, committed): diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 59cfcd1caf66..4caad8876f06 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -1029,6 +1029,7 @@ def array_result_handler( class ResultHandler: def __call__(self, arg: Array | Sequence[Array], /) -> Array: ... + def wrap(self, wrapper: Callable) -> Any: ... class DeviceList: def __init__(self, arg: tuple[Device, ...], /) -> None: ... diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index b91115529768..37baa58c6cc9 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -584,18 +584,21 @@ PyArray PyArray::MakeFromIfrtArrayAndSharding(nb_class_ptr py_client, std::move(ifrt_array), committed, skip_checks); } -PyArrayResultHandler::PyArrayResultHandler(nb::object aval, nb::object sharding, - bool committed, bool skip_checks) +PyArrayResultHandler::PyArrayResultHandler( + nb::object aval, nb::object sharding, bool committed, bool skip_checks, + std::vector wrappers) : aval_(std::move(aval)), sharding_(std::move(sharding)), committed_(committed), - skip_checks_(skip_checks) { + skip_checks_(skip_checks), + wrappers_(std::move(wrappers)) { weak_type_ = nb::cast(aval_.attr("weak_type")); dtype_ = nb::cast(aval_.attr("dtype")); shape_ = nb::cast>(aval_.attr("shape")); } -PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { +nanobind::object PyArrayResultHandler::Call( + absl::Span py_arrays) const { auto py_device_list = GetPyDeviceList(sharding_); if (!py_device_list.ok()) { throw nb::value_error( @@ -610,15 +613,20 @@ PyArray PyArrayResultHandler::Call(absl::Span py_arrays) const { xla::Future<>()); } -PyArray PyArrayResultHandler::Call(nb_class_ptr py_client, - ifrt::ArrayRef ifrt_array, - xla::Future<> result_status) const { - return PyArray(aval_, weak_type_, dtype_, shape_, sharding_, - std::move(py_client), std::move(ifrt_array), committed_, - skip_checks_, std::move(result_status)); +nanobind::object PyArrayResultHandler::Call(nb_class_ptr py_client, + ifrt::ArrayRef ifrt_array, + xla::Future<> result_status) const { + nanobind::object result = + PyArray(aval_, weak_type_, dtype_, shape_, sharding_, + std::move(py_client), std::move(ifrt_array), committed_, + skip_checks_, std::move(result_status)); + for (auto& cb : wrappers_) { + result = cb(std::move(result)); + } + return result; } -PyArray PyArrayResultHandler::Call(PyArray py_array) const { +nanobind::object PyArrayResultHandler::Call(PyArray py_array) const { return Call(py_array.py_client(), tsl::FormRef(py_array.ifrt_array()), xla::Future<>()); } @@ -2364,7 +2372,14 @@ absl::Status PyArray::Register(nb::module_& m) { .c_str()); }, nb::sig( - "def __call__(self, arg: Array | Sequence[Array], /) -> Array")); + "def __call__(self, arg: Array | Sequence[Array], /) -> Array")) + .def("wrap", [](const PyArrayResultHandler& self, nb::callable wrapper) { + auto wrappers = self.wrappers(); + wrappers.push_back(std::move(wrapper)); + return make_nb_class( + self.aval(), self.sharding(), self.committed(), self.skip_checks(), + std::move(wrappers)); + }); return absl::OkStatus(); } diff --git a/jaxlib/py_array.h b/jaxlib/py_array.h index 3a323f53809f..f4d1d59b99e9 100644 --- a/jaxlib/py_array.h +++ b/jaxlib/py_array.h @@ -359,13 +359,22 @@ class PyArray : public nanobind::object { class PyArrayResultHandler { public: PyArrayResultHandler(nanobind::object aval, nanobind::object sharding, - bool committed, bool skip_checks); + bool committed, bool skip_checks, + std::vector wrappers = {}); - PyArray Call(absl::Span py_arrays) const; - PyArray Call(PyArray py_array) const; + nanobind::object Call(absl::Span py_arrays) const; + nanobind::object Call(PyArray py_array) const; - PyArray Call(nb_class_ptr py_client, xla::ifrt::ArrayRef ifrt_array, - xla::Future<> result_status = xla::Future<>()) const; + nanobind::object Call(nb_class_ptr py_client, + xla::ifrt::ArrayRef ifrt_array, + xla::Future<> result_status = xla::Future<>()) const; + + const std::vector& wrappers() const { return wrappers_; } + + nanobind::object aval() const { return aval_; } + nanobind::object sharding() const { return sharding_; } + bool committed() const { return committed_; } + bool skip_checks() const { return skip_checks_; } private: nanobind::object aval_; @@ -376,6 +385,7 @@ class PyArrayResultHandler { xla::nb_dtype dtype_; std::vector shape_; + std::vector wrappers_; }; absl::StatusOr CudaArrayInterfaceToBuffer( diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index d0f9bceed0f8..8ae8102af30a 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -47,7 +47,7 @@ # Please suffix the version number with a brief description of your change # in a comment. The goal here is to force a merge conflict if two changes # attempt to grab the same version number. -_version = 389 # LoadedExecutable.serialize +_version = 390 # ResultHandler.wrap # An internal increasing version number for protecting jaxlib code against # ifrt changes. diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index cc58a7011025..bbb1f0ea7afc 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -34,6 +34,7 @@ from jax._src import literals from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal +from jax._src.lib import jaxlib_extension_version config.parse_flags_with_absl() @@ -872,6 +873,8 @@ def global_sharded_result_handler(aval, out_sharding, committed): phys_aval = core.physical_aval(aval) phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) + if jaxlib_extension_version >= 390: + return phys_handler.wrap(lambda arr: earray.EArray(aval, arr)) return lambda bufs: earray.EArray(aval, phys_handler(bufs)) @dataclasses.dataclass(frozen=True) diff --git a/tests/lax_test.py b/tests/lax_test.py index 8f1aa9304718..626982b1c88d 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -48,6 +48,7 @@ from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal from jax._src.lax import utils as lax_utils +from jax._src.lib import jaxlib_extension_version from jax._src.util import safe_zip from jax._src.tree_util import tree_map @@ -3992,14 +3993,13 @@ def handler(_, buf): @staticmethod def global_sharded_result_handler(aval, out_sharding, committed): - def handler(arr): - from jax._src.array import ArrayImpl - if isinstance(arr, ArrayImpl): - buf, = arr._arrays - else: - buf, = arr - return FooArray(aval.shape, buf) - return handler + phys_sharding = out_sharding # unlike KeyTyRules, assume same shape + phys_aval = core.physical_aval(aval) + phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] + phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) + if jaxlib_extension_version >= 390: + return phys_handler.wrap(lambda arr: FooArray(aval.shape, arr)) + return lambda bufs: FooArray(aval.shape, phys_handler(bufs)) class FooTy(dtypes.ExtendedDType): From eff53a5c7a33bb7e795144b30e1cb038d8da66ad Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 8 Dec 2025 21:00:43 -0800 Subject: [PATCH 109/315] [Pallas] Device Id dict to mesh fastpath for power of twos PiperOrigin-RevId: 842041847 --- jax/_src/pallas/primitives.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 54e5cae849eb..2fdfe6b31513 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -1350,11 +1350,26 @@ def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict, ) axes_dimensions = [mesh_axis_sizes[name] for name in axis] for axis_index, axis_name in enumerate(axis): - axis_size = arith.constant(i32, mesh_axis_sizes[axis_name]) - minor_divisor = arith.constant( - i32, math.prod(axes_dimensions[axis_index + 1 :]) - ) - device_idx = arith.remsi(arith.divsi(idx, minor_divisor), axis_size) + axis_size = mesh_axis_sizes[axis_name] + inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :]) + minor_divisor = arith.constant(i32, inner_mesh_size) + + # Fast path for power of 2s + if inner_mesh_size & (inner_mesh_size - 1) == 0: + shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1 + partial_device_idx = arith.shrui(idx, arith.constant(i32, shift_len)) + else: + partial_device_idx = arith.divsi(idx, minor_divisor) + + if axis_size & (axis_size - 1) == 0: + device_idx = arith.andi( + partial_device_idx, + arith.constant(i32, mesh_axis_sizes[axis_name] - 1), + ) + else: + device_idx = arith.remsi( + partial_device_idx, arith.constant(i32, axis_size) + ) physical_axis_dict[axis_name] = device_idx else: physical_axis_dict[axis] = idx From 9be13f9c5bccfe4103972f23cad1abbe112a2b4f Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 9 Dec 2025 00:05:47 -0800 Subject: [PATCH 110/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/91b3f740b75d1d932a12fb0886338f84f856a453 PiperOrigin-RevId: 842096427 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index a1890149fb2b..8810fa765f03 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "f5d6d1aae38dfa3de44c057064ec7609b9a390af" -XLA_SHA256 = "e867f1329105c55f34667c589ee718d10d6de378b026d8d363f13c20a83beb5d" +XLA_COMMIT = "91b3f740b75d1d932a12fb0886338f84f856a453" +XLA_SHA256 = "68d6d2f66b10e826512fa6e262143425606dda30d6ab95daa83e3dfb5cf298a0" From 3bec82db421b7d3a0d5a23f57ab469cd2be1de49 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Tue, 9 Dec 2025 00:59:44 -0800 Subject: [PATCH 111/315] [Mosaic GPU] Remove `{Least,Most}ReplicatedExpression` constructors. After getting rid of hints, these constructors are no longer necessary. PiperOrigin-RevId: 842112040 --- jax/experimental/mosaic/gpu/constraints.py | 90 +--------------- jax/experimental/mosaic/gpu/layouts.py | 113 --------------------- tests/mosaic/gpu_constraints_test.py | 110 +------------------- tests/pallas/mosaic_gpu_test.py | 4 +- 4 files changed, 3 insertions(+), 314 deletions(-) diff --git a/jax/experimental/mosaic/gpu/constraints.py b/jax/experimental/mosaic/gpu/constraints.py index b1d1f3ec45a9..9a66cde1432b 100644 --- a/jax/experimental/mosaic/gpu/constraints.py +++ b/jax/experimental/mosaic/gpu/constraints.py @@ -24,7 +24,7 @@ from collections.abc import Sequence import dataclasses import math -from typing import Any, Callable, assert_never, final +from typing import Any, assert_never, final from . import fragmented_array as fa from . import launch_context as lc @@ -86,22 +86,6 @@ def __str__(self): return f"C({self.value})" -@dataclasses.dataclass(frozen=True) -class LeastReplicated: - expressions: tuple[Expression, ...] - - def __post_init__(self): - assert len(self.expressions) >= 1 - - -@dataclasses.dataclass(frozen=True) -class MostReplicated: - expressions: tuple[Expression, ...] - - def __post_init__(self): - assert len(self.expressions) >= 1 - - @dataclasses.dataclass(frozen=True) class Reduce: expression: Expression @@ -136,8 +120,6 @@ def __str__(self): Expression = ( Variable | Constant - | LeastReplicated - | MostReplicated | Reduce | BroadcastInDim | Reshape @@ -145,62 +127,6 @@ def __str__(self): ) -def reduce_replicated_expression( - input_expr: LeastReplicated | MostReplicated, - assignments: dict[Variable, Constant], - reducer: Callable[[fa.FragmentedLayout, fa.FragmentedLayout], fa.FragmentedLayout | None] -) -> Expression | Unsatisfiable: - assert input_expr.expressions - - new_expressions: list[Expression] = [] - # Use a set to eliminate duplicates, but preserve the order. - seen: set[Expression] = set() - for expr in input_expr.expressions: - reduced_expr = reduce_expression(expr, assignments) - if isinstance(reduced_expr, Unsatisfiable): - return Unsatisfiable() - if reduced_expr in seen: - continue - new_expressions.append(reduced_expr) - seen.add(reduced_expr) - - if len(new_expressions) == 1: - return new_expressions[0] - - consts = [] - unknowns = [] - for e in new_expressions: - if not isinstance(e, Constant): - unknowns.append(e) - continue - if not isinstance(e, RegisterLayout): - raise ValueError( - f"Reduction of non-register layout constant is not supported: {e}" - ) - consts.append(e) - - if consts: - const_red, *consts = consts - red = const_red - for cst in consts: - red_value = reducer(red.value, cst.value) - if red_value is None: - # The layouts are not compatible up to replication, this expression - # cannot be simplified. - return Unsatisfiable() - red = RegisterLayout(red_value) - else: - red = None - - constructor = type(input_expr) - if red is not None: - if unknowns: - return constructor((red, *unknowns)) - return red - - return constructor(tuple(unknowns)) - - def reduce_broadcast_expression( broadcast: BroadcastInDim, assignments: dict[Variable, Constant] ) -> Expression | Unsatisfiable: @@ -314,14 +240,6 @@ def reduce_expression( return expr case Variable(): return assignments.get(expr, expr) - case MostReplicated(): - return reduce_replicated_expression( - expr, assignments, layouts_lib.join_layouts - ) - case LeastReplicated(): - return reduce_replicated_expression( - expr, assignments, layouts_lib.meet_layouts - ) case Reduce(expression=expr, axes=axes): reduced_expr = reduce_expression(expr, assignments) match reduced_expr: @@ -640,12 +558,6 @@ def extract_variables(expr: Expression) -> None: free_variables.append(expr) case Constant(): ... - case MostReplicated(expressions=expressions): - for e in expressions: - extract_variables(e) - case LeastReplicated(expressions=expressions): - for e in expressions: - extract_variables(e) case Reduce(expression=e): extract_variables(e) case BroadcastInDim(expression=e): diff --git a/jax/experimental/mosaic/gpu/layouts.py b/jax/experimental/mosaic/gpu/layouts.py index 82870696eb32..1b75e4d48b47 100644 --- a/jax/experimental/mosaic/gpu/layouts.py +++ b/jax/experimental/mosaic/gpu/layouts.py @@ -15,7 +15,6 @@ """Layout utilities.""" import re -from typing import assert_never from jax._src.lib import mosaic_gpu_dialect as mgpu from jax._src.lib.mlir import ir @@ -224,118 +223,6 @@ def splat_is_compatible_with_tiled( return all(d1 % d2 == 0 for d1, d2 in zip(s1, s2)) -def meet_layouts( - layout1: fa.FragmentedLayout, layout2: fa.FragmentedLayout -) -> fa.FragmentedLayout | None: - """Returns the "meet" of two layouts that are compatible up to replication. - - The "meet" of the two layouts is the most replicated layout that is still - less replicated than the arguments. - - This is the dual of `join_layouts`. - - Returns: - The "meet" of the two layouts if both layouts are compatible up to - replication. - - Raises: - ValueError: if the two layouts are not compatible up to replication. - """ - if layout1 == layout2: - return layout1 - - match (layout1, layout2): - case (fa.WGSplatFragLayout(), _): - if isinstance(layout2, fa.TiledLayout): - if splat_is_compatible_with_tiled(layout1, layout2): - return layout2 - elif layout1.shape == layout2.shape: - return layout2 - case (_, fa.WGSplatFragLayout()): - if isinstance(layout1, fa.TiledLayout): - if splat_is_compatible_with_tiled(layout2, layout1): - return layout1 - elif layout1.shape == layout2.shape: - return layout1 - case (fa.TiledLayout(), fa.TiledLayout()): - # TODO(bchetioui): handle `TiledLayout` replication. - raise NotImplementedError("TiledLayout replication not supported yet") - - # Layouts are not compatible up to replication. - return None - -# NOTE: We say that two layouts are compatible up to replication if the two -# layouts satisfy at least one of the following conditions together: -# -# - The two layouts are equal; -# - One of the layouts is a `WGSplatFragLayout`, and -# * The other layout is a `WGStridedFragLayout` with the same shape; -# * The other layout is a `TiledLayout` that can be used to tile the shape -# embedded in the `WGSplatFragLayout`. -# -# If any of these conditions hold, then we are always able to substitute one -# layout with the other without having to reorder any data in the underlying -# array---i.e. a relayout is free. -# -# Note that there are other combinations of layouts for which relayout is free, -# but we voluntarily narrowed down our definition to span a small, useful -# subset. - -def join_layouts( - layout1: fa.FragmentedLayout, layout2: fa.FragmentedLayout -) -> fa.FragmentedLayout | None: - """Returns the "join" of two layouts that are compatible up to replication. - - The "join" of the two layouts is the least replicated layout that is still - more replicated than the arguments. - - This is the dual of `meet_layouts`. - - Returns: - The "join" of the two layouts if both layouts are compatible up to - replication. - - Raises: - ValueError: if the two layouts are not compatible up to replication. - """ - if layout1 == layout2: - return layout1 - - match (layout1, layout2): - case (fa.WGSplatFragLayout(), _): - if isinstance(layout2, fa.TiledLayout): - if splat_is_compatible_with_tiled(layout1, layout2): - return layout1 - elif layout1.shape == layout2.shape: - return layout1 - case (_, fa.WGSplatFragLayout()): - if isinstance(layout1, fa.TiledLayout): - if splat_is_compatible_with_tiled(layout2, layout1): - return layout2 - elif layout1.shape == layout2.shape: - return layout2 - case (fa.TiledLayout(), fa.TiledLayout()): - # TODO(bchetioui): handle `TiledLayout` replication. - raise NotImplementedError("TiledLayout replication not supported yet") - - # Layouts are not compatible up to replication. - return None - - -def has_any_replication(layout: fa.FragmentedLayout) -> bool: - match layout: - case fa.WGSplatFragLayout(): - return True - case fa.WGStridedFragLayout(): - return False - case fa.TiledLayout(): - is_warp_replicated = any(isinstance(d, fa.Replicated) for d in layout.warp_dims) - is_lane_replicated = any(isinstance(d, fa.Replicated) for d in layout.lane_dims) - return is_warp_replicated or is_lane_replicated - case _ as unreachable: - return assert_never(unreachable) # pytype: disable=wrong-arg-types - - _tile_transform_attr_pattern = re.compile( r"^#mosaic_gpu.tile<[^>]+>$" ) diff --git a/tests/mosaic/gpu_constraints_test.py b/tests/mosaic/gpu_constraints_test.py index 6800789ab6ec..cab748000fb0 100644 --- a/tests/mosaic/gpu_constraints_test.py +++ b/tests/mosaic/gpu_constraints_test.py @@ -127,11 +127,9 @@ def test_constraint_system_unknowns_are_all_the_variables_without_assignment( ): v0, v1, v2, v3 = V(0), V(1), V(2), V(3) layout = RL(mgpu.WGSplatFragLayout((1, 1))) - least_replicated = cs.LeastReplicated((v2, v3)) - most_replicated = cs.MostReplicated((least_replicated,)) system = cs.ConstraintSystem( assignments={v0: layout}, - constraints=[Eq(v1, most_replicated)], + constraints=[Eq(v1, v2), cs.Relayout(v2, v3)], ) self.assertSequenceEqual(system.unknowns(), [v1, v2, v3]) @@ -164,112 +162,6 @@ def test_intersection_of_compatible_systems_is_union_of_fields(self): self.assertSequenceEqual(system1.unknowns(), [v1]) self.assertSequenceEqual(system_intersection.unknowns(), [v0, v1]) - def test_reduce_extracts_most_replicated_expression_correctly(self): - v0 = V(0) - shape = (1, 128) - layout0 = RL(mgpu.WGSplatFragLayout(shape)) - layout1 = RL(mgpu.WGStridedFragLayout(shape, vec_size=1)) - with self.subTest("most-replicated-expression-exists"): - system = cs.ConstraintSystem( - constraints=[Eq(v0, cs.MostReplicated((layout0, layout1)))], - ) - self.assertEqual( - cs.reduce(system), - cs.ConstraintSystem(assignments={v0: layout0}), - ) - - with self.subTest("most-replicated-expression-is-unique-expression"): - system = cs.ConstraintSystem( - constraints=[Eq(v0, cs.MostReplicated((layout0,)))], - ) - self.assertEqual( - cs.reduce(system), - cs.ConstraintSystem(assignments={v0: layout0}), - ) - - with self.subTest("most-replicated-expression-does-not-exist"): - system = cs.ConstraintSystem( - constraints=[Eq(v0, cs.MostReplicated((layout1, v0)))], - ) - self.assertEqual(cs.reduce(system), system) - - def test_reduce_extracts_least_replicated_expression_correctly(self): - v0 = V(0) - shape = (1, 128) - layout0 = RL(mgpu.WGSplatFragLayout(shape)) - layout1 = RL(mgpu.WGStridedFragLayout(shape, vec_size=1)) - with self.subTest("least-replicated-expression-exists"): - system = cs.ConstraintSystem( - constraints=[Eq(v0, cs.LeastReplicated([layout0, layout1]))], - ) - self.assertEqual( - cs.reduce(system), - cs.ConstraintSystem(assignments={v0: layout1}), - ) - - with self.subTest("least-replicated-expression-is-unique-expression"): - system = cs.ConstraintSystem( - constraints=[Eq(v0, cs.LeastReplicated((layout0,)))], - ) - self.assertEqual( - cs.reduce(system), - cs.ConstraintSystem(assignments={v0: layout0}), - ) - - with self.subTest("least-replicated-expression-does-not-exist"): - system = cs.ConstraintSystem( - constraints=[Eq(v0, cs.LeastReplicated((layout0, v0)))], - ) - self.assertEqual(cs.reduce(system), system) - - def test_reduce_most_replicated_expression_reduces_compatible_layouts(self): - splat_layout = RL(mgpu.WGSplatFragLayout((128, 64))) - tiled_layout = RL(mgpu.WGMMA_LAYOUT) - self.assertEqual( - cs.reduce_expression( - cs.MostReplicated((splat_layout, tiled_layout)), - {}, - ), - splat_layout, - ) - - def test_reduce_most_replicated_expression_is_unsatisfiable_for_incompatible_layouts( - self, - ): - splat_layout = RL(mgpu.WGSplatFragLayout((1, 2))) - tiled_layout = RL(mgpu.WGMMA_LAYOUT) - self.assertIsInstance( - cs.reduce_expression( - cs.MostReplicated((splat_layout, tiled_layout)), - {}, - ), - cs.Unsatisfiable, - ) - - def test_reduce_least_replicated_expression_reduces_compatible_layouts(self): - splat_layout = RL(mgpu.WGSplatFragLayout((128, 64))) - tiled_layout = RL(mgpu.WGMMA_LAYOUT) - self.assertEqual( - cs.reduce_expression( - cs.LeastReplicated((splat_layout, tiled_layout)), - {}, - ), - tiled_layout, - ) - - def test_reduce_least_replicated_expression_is_unsatisfiable_for_incompatible_layouts( - self, - ): - splat_layout = RL(mgpu.WGSplatFragLayout((1, 2))) - tiled_layout = RL(mgpu.WGMMA_LAYOUT) - self.assertIsInstance( - cs.reduce_expression( - cs.LeastReplicated((splat_layout, tiled_layout)), - {}, - ), - cs.Unsatisfiable, - ) - @parameterized.named_parameters( ("reduce_to_row_layout", (1,), mgpu.WGMMA_ROW_LAYOUT), ("reduce_to_col_layout", (0,), mgpu.WGMMA_COL_LAYOUT), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7f7e18f2fd8b..fd248ea182da 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1811,7 +1811,6 @@ def kernel(o_ref): ], ) def test_transposed_layout(self, layouts): - self.skip_if_wg_semantics() # TiledLayout replication not supported. layout, transposed_layout = layouts dtype = jnp.dtype(jnp.float16) shape = (256, 192) @@ -3417,7 +3416,6 @@ def kernel(x_ref, y_ref, aliased_ref, smem_ref, barrier_ref): plgpu.Layout.TCGEN05, plgpu.Layout.TCGEN05_TMEM_NATIVE ) def test_tmem_load_layout(self, layout): - self.skip_if_wg_semantics() # TiledLayout replication not supported yet. transforms = self.default_transforms(dtype=jnp.float32) @functools.partial( self.kernel, @@ -3452,7 +3450,7 @@ def kernel(x_ref, y_ref, tmem_ref, smem_ref, barrier_ref): plgpu.Layout.TCGEN05_M64_COLLECTIVE_NATIVE(160) ) def test_tmem_store_load_collective(self, layout): - self.skip_if_wg_semantics() # TiledLayout replication not supported yet. + self.skip_if_wg_semantics() # Failed to infer a possible set of layouts. @functools.partial( self.kernel, out_shape=jax.ShapeDtypeStruct((64, 160), jnp.float32), From e288c27aea118e0f20e5196cda4affc9b5cfd5db Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 9 Dec 2025 02:44:03 -0800 Subject: [PATCH 112/315] [Mosaic GPU] Support `FragmentedArray.broadcast_in_dim` for splat to other layouts. PiperOrigin-RevId: 842145211 --- .../mosaic/gpu/fragmented_array.py | 15 ++++-------- tests/mosaic/gpu_test.py | 24 +++++++++++++++++++ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 223580bb9582..f3e7f5b477e8 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1301,9 +1301,8 @@ def to_layout(self, new_layout: FragmentedLayout) -> FragmentedArray: raise NotImplementedError( f"Cannot convert from {self.layout} to {new_layout}" ) - [reg] = self.registers.flat return type(self).splat( - reg, self.shape, new_layout, is_signed=self.is_signed + self.registers.item(), self.shape, new_layout, is_signed=self.is_signed ) def _pointwise( @@ -2502,15 +2501,9 @@ def broadcast_in_dim( f" {shape[target_dim]} in shape after broadcast" ) if isinstance(self.layout, WGSplatFragLayout): - if isinstance(layout, WGSplatFragLayout): - if layout.shape != shape: - raise ValueError( - f"Layout shape {layout.shape} does not match broadcast shape {shape}" - ) - return FragmentedArray( - _registers=self.registers, _layout=layout, _is_signed=self.is_signed, - ) - # TODO: Support splat to other layouts + return type(self).splat( + self.registers.item(), shape, layout, is_signed=self.is_signed + ) if not isinstance(self.layout, TiledLayout) or not isinstance(layout, TiledLayout): raise NotImplementedError(self.layout, layout) if any(d1 >= d2 for d1, d2 in zip(source_dimensions, source_dimensions[1:])): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index fce498539f87..b3e918c674eb 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3812,6 +3812,30 @@ def kernel(ctx, gmem_input, gmem_output, _): out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) np.testing.assert_array_equal(result, out_ref) + @parameterized.parameters(*mtu.RegisterLayout) + def test_broadcast_splat(self, layout): + out_shape = (128, 128) + + def body(ctx, out_ref, scratch): + del ctx, scratch + c42 = arith.constant(ir.IntegerType.get_signless(32), 42) + arr = mgpu.FragmentedArray.splat(c42, (128,), is_signed=True) + out_layout = layout.to_mgpu(out_shape, jnp.int32) + result = arr.broadcast_in_dim(out_shape, (0,), out_layout) + result.store_untiled(out_ref, optimized=False) + + kernel = mgpu.as_gpu_kernel( + body, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=jax.ShapeDtypeStruct(out_shape, jnp.int32), + smem_scratch_shape=[], + ) + np.testing.assert_array_equal( + kernel(), np.full(out_shape, 42, dtype=np.int32) + ) + def test_warp_tree_reduce(self): def kernel(ctx, out, *_): del ctx From 1bfd8dc167b82502eacfac0d839788815f942454 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 9 Dec 2025 04:21:03 -0800 Subject: [PATCH 113/315] [maint] clean up jnp.arange implementation --- jax/_src/numpy/lax_numpy.py | 82 ++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 338e5b6ab164..79278915a09b 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -36,7 +36,6 @@ import numpy as np from jax._src import api -from jax._src import config from jax._src import core from jax._src import deprecations from jax._src import dtypes @@ -143,7 +142,8 @@ def iscomplexobj(x: Any) -> bool: >>> jnp.iscomplexobj(jnp.array([0, 1+2j])) True """ - if x is None: + # Check for int here to avoid potential overflow in jnp.array below. + if x is None or isinstance(x, int): return False try: typ = x.dtype.type @@ -5954,60 +5954,56 @@ def arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, step: ArrayLike | None = None, dtype: DTypeLike | None = None, out_sharding: NamedSharding | None = None) -> Array: + # Validate inputs if dtype is not None: dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "arange") - if not config.dynamic_shapes.value: - util.check_arraylike("arange", start) - if stop is None and step is None: - start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'stop'") - else: - start = core.concrete_or_error(None, start, "It arose in the jnp.arange argument 'start'") - util.check_arraylike_or_none("arange", None, stop, step) + util.check_arraylike_or_none("arange", start, stop, step) + + # Ensure start/stop/step are concrete + start_name = "stop" if stop is None and step is None else "start" + start = core.concrete_or_error(None, start, f"It arose in the jnp.arange argument '{start_name}'") stop = core.concrete_or_error(None, stop, "It arose in the jnp.arange argument 'stop'") step = core.concrete_or_error(None, step, "It arose in the jnp.arange argument 'step'") - start_name = "stop" if stop is None and step is None else "start" + + # Ensure start/stop/step are scalars for name, val in [(start_name, start), ("stop", stop), ("step", step)]: if val is not None and np.ndim(val) != 0: raise ValueError(f"jax.numpy.arange: arguments must be scalars; got {name}={val}") + + # Handle symbolic dimensions if any(core.is_symbolic_dim(v) for v in (start, stop, step)): - # Some dynamic shapes - if stop is None and step is None: - stop = start - start = 0 - step = 1 - elif stop is not None and step is None: + if stop is None: + start, stop = 0, start + if step is None: step = 1 return _arange_dynamic(start, stop, step, dtype or dtypes.default_int_dtype()) + if dtype is None: - dtype = result_type(start, *(x for x in [stop, step] if x is not None)) + dtype = dtypes.result_type(start, *(x for x in [stop, step] if x is not None)) dtype = dtypes.jax_dtype(dtype) - if stop is None and step is None: - start_dtype = _dtype(start) - if (not dtypes.issubdtype(start_dtype, np.integer) and - not dtypes.issubdtype(start_dtype, dtypes.extended)): - ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil - start = ceil_(start).astype(int) - return lax.broadcasted_iota(dtype, (start,), 0, out_sharding=out_sharding) # type: ignore[arg-type] + + if iscomplexobj(start) or iscomplexobj(stop) or iscomplexobj(step): + # Complex arange is poorly defined; fall back to NumPy here. + # TODO(jakevdp): deprecate the complex case. + return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding) + + if step is not None: + # arange(N, M, K): when step is specified, fall back to NumPy. + return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding) + + if stop is None: + start, stop = 0, start + + if start == 0: + # arange(M) or arange(0, M) + size = max(0, int(np.ceil(stop))) + return lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding) + else: - if step is None and stop is not None: - # Skip optimization if start or stop is complex (ceil doesn't support complex) - start_dtype = _dtype(start) - stop_dtype = _dtype(stop) - if (dtypes.issubdtype(start_dtype, np.complexfloating) or - dtypes.issubdtype(stop_dtype, np.complexfloating)): - return array(np.arange(start, stop=stop, step=step, dtype=dtype), - device=out_sharding) - # Use iota + offset instead of creating a constant array - size = int(np.ceil(stop - start)) - if size <= 0: - return array([], dtype=dtype, device=out_sharding) - result = lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding) - if start != 0: - # Add offset if start is non-zero - result = lax.add(result, lax.convert_element_type(start, dtype)) - return result - return array(np.arange(start, stop=stop, step=step, dtype=dtype), - device=out_sharding) + # arange(N, M) + size = max(0, int(np.ceil(stop - start))) + return lax.add(lax.convert_element_type(start, dtype), + lax.broadcasted_iota(dtype, (size,), 0, out_sharding=out_sharding)) def _arange_dynamic( From ea5aee97bbfe3d2161ecbe39cb53deaf7ec88831 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Tue, 9 Dec 2025 09:05:23 -0800 Subject: [PATCH 114/315] Update generated stubs. PiperOrigin-RevId: 842264374 --- jaxlib/_jax/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 4caad8876f06..120acdb347e5 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -1029,7 +1029,7 @@ def array_result_handler( class ResultHandler: def __call__(self, arg: Array | Sequence[Array], /) -> Array: ... - def wrap(self, wrapper: Callable) -> Any: ... + def wrap(self, arg: Callable, /) -> ResultHandler: ... class DeviceList: def __init__(self, arg: tuple[Device, ...], /) -> None: ... From 2f62fb10d45c6b0a74b2f836f0b766af324cea87 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 9 Dec 2025 09:25:40 -0800 Subject: [PATCH 115/315] Fix back compat test to ignore warnings PiperOrigin-RevId: 842272730 --- jax/_src/test_util.py | 2 -- tests/export_back_compat_test.py | 8 +++++--- tests/fused_attention_stablehlo_test.py | 2 ++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 366bec41611b..30ed606b527f 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1248,8 +1248,6 @@ class JaxTestCase(parameterized.TestCase): 'jax_legacy_prng_key': 'error', } - - def setUp(self): super().setUp() self.enterContext(assert_global_configs_unchanged()) diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 43be67f393fe..a7f9fa2ea73f 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -784,6 +784,8 @@ def func(x): data = self.load_testdata(cuda_threefry2x32.data_2024_07_30) self.run_one_test(func, data) + @jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') def test_tpu_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: @@ -1006,8 +1008,10 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol): ) -@jtu.with_config(jax_use_shardy_partitioner=True) class ShardyCompatTest(bctu.CompatTestBase): + + @jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') def test_shardy_sharding_ops_with_different_meshes(self): # Tests whether we can save and load a module with meshes that have the # same axis sizes (and same order) but different axis names. @@ -1046,7 +1050,5 @@ def shard_map_func(x): # b: f32[2, 4] expect_current_custom_calls=custom_call_targets_override) - - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index e968107e6f6e..1228b2160ab1 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -264,6 +264,8 @@ def dot_product_attention_fp8(query, key, value, fp8_metas): return out[0], (query_grad, key_grad, value_grad) +@jtu.ignore_warning(category=DeprecationWarning, + message='`with mesh:` context manager') class DotProductAttentionTest(jtu.JaxTestCase): def setUp(self): super().setUp() From 50bdc72240dd75afa936b34326c1502b2b4729d4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 9 Dec 2025 09:26:10 -0800 Subject: [PATCH 116/315] [pallas:mosaic_gpu] Slightly tweaked the error messages in a few places PiperOrigin-RevId: 842272894 --- jax/_src/pallas/mosaic_gpu/core.py | 3 ++- jax/experimental/mosaic/gpu/launch_context.py | 4 ++-- jax/experimental/mosaic/gpu/wgmma.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index c8198d2cc655..097f6e47b170 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -670,7 +670,8 @@ def untransform_reshape( self, dtype: jnp.dtype, shape: tuple[int, ...] ) -> tuple[tuple[int, ...], state_types.Transform]: del dtype - raise NotImplementedError("Reshapes don't commute with transposes.") + # TODO(slebedev): Support this. + raise NotImplementedError("Reshapes don't commute with tiling.") def untransform_index( self, dtype: jnp.dtype | ir.Type, idxs: tuple[Index, ...] diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index d03410bbd3a2..88b44978897c 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -862,7 +862,7 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): if max(slice_shape) > 256: raise ValueError( "Async copies only support copying <=256 elements along each" - " dimension" + f" dimension, got {tuple(slice_shape)}" ) if (zeroth_bw := slice_shape[-1] * element_bitwidth) % 128 != 0: raise ValueError( @@ -1019,7 +1019,7 @@ def async_copy( raise ValueError( "Expected the SMEM reference to have the same shape as the" f" transformed slice: {tuple(smem_ref_ty.shape)} !=" - f" {slice_shape[len(squeezed_dims):]}" + f" {tuple(slice_shape[len(squeezed_dims):])}" ) if implementation == AsyncCopyImplementation.CP_ASYNC: diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index bdb1c5c200fe..9af9e8965ad4 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -65,7 +65,8 @@ def value(self) -> fa.FragmentedArray: @classmethod def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None): if m % 64 or n % 8: - raise ValueError + raise ValueError("WGMMA requires m and n to be multiples of 64 and 8, " + f"got {m} and {n}") if is_signed is False: # pylint: disable=g-bool-id-comparison raise TypeError("PTX does not support unsigned WGMMA accumulators") f32 = ir.F32Type.get() From 6cd6cc79227bf2f821a96efcbf5d8836c4f60221 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 5 Dec 2025 13:00:09 +0100 Subject: [PATCH 117/315] [export] Add backwards compatibility tests for serialized exports Recently we have found the need to evolve the serialization of `jax.export.Exported`. E.g., in #33942 we have added a 32-bit representation for `nr_devices`. This introduced a compatiblity bug that was found by usersm and fixed in in #33685. Here we add backwards compatibility tests. See the description in the `export_serialization_back_compat_test.py` module docstring. Note that this is separate from our previous set of backwards compatibility tests for the lowering of custom calls (in `export_back_compat_test.py`). However, we reuse some of the same ideas, and we use the same directory for storing saved old serializations. --- jax/_src/export/serialization.py | 14 +- .../export_with_specified_sharding.py | 29 +++ .../export_with_unspecified_sharding.py | 29 +++ tests/BUILD | 12 ++ .../export_serialization_back_compat_test.py | 167 ++++++++++++++++++ 5 files changed, 244 insertions(+), 7 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py create mode 100644 tests/export_serialization_back_compat_test.py diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index f479a6222d7e..4c8a65b89856 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -53,12 +53,7 @@ # Version 5, November 23rd, 2025, adds serialization for aval memory_space, # upgrade num_devices to a 32 bit value. # This version is backwards compatible with Version 2 to 4. -# TODO(necula): we cannot really store the actual serialization_version -# in the flatbuffer because prior to 11/25/2025 deserializers checked -# if the version is 2 or 3. I have now removed that check, but for the -# sake of old deserializers we can only store version 3. Starting -# on January 2026 we can store the actual version. -_SERIALIZATION_VERSION = 3 +_SERIALIZATION_VERSION = 5 def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray: """Serializes an Exported. @@ -125,7 +120,12 @@ def _serialize_exported( vjp = _serialize_exported(builder, exp.vjp(), vjp_order - 1) ser_flatbuf.ExportedStart(builder) - ser_flatbuf.ExportedAddSerializationVersion(builder, _SERIALIZATION_VERSION) + # TODO(necula): we cannot really store the actual serialization_version + # in the flatbuffer because prior to 11/25/2025 deserializers checked + # if the version is 2 or 3. I have now removed that check, but for the + # sake of old deserializers we can only store version 3. Starting + # on January 2026 we can store the actual version. + ser_flatbuf.ExportedAddSerializationVersion(builder, 3) ser_flatbuf.ExportedAddFunctionName(builder, fun_name) ser_flatbuf.ExportedAddInTree(builder, in_tree) ser_flatbuf.ExportedAddInAvals(builder, in_avals) diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py new file mode 100644 index 000000000000..4cb08b42f268 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_specified_sharding.py @@ -0,0 +1,29 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +# Pasted from the test output (see export_serialization_back_compat_test.py module docstring) +serializations = [ + dict( + serialization_version=4, + exported_serialized=bytearray(b"(\x00\x00\x00$\x00D\x00B\x00<\x008\x004\x000\x00,\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00$\x00\x00\x00@\x00\x00\x00\x00\x00\n\x00@\x00\x00\x00H\x06\x00\x00H\x06\x00\x00H\x06\x00\x00,\x06\x00\x00D\x06\x00\x00d\x06\x00\x00\x00\x00\x02\x00\x80\x06\x00\x00\xac\x06\x00\x00\xac\x06\x00\x00\xe4\x06\x00\x008\x07\x00\x00\x00\x00\x02\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf6\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc7\x9f\x11\x01y\x07\x0b\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\x0f\x17\x13\x13\x0f\x1b\x0f\x1b\x05\x19\x0b\x0f\x13\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02~\x04\x1f\x05\x15\x05\x17\x05\t\t\x07\x1d!#\x03\x07\x0f\x11\x13\x15\x17\x19\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xb1\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xb7!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x0b\x03\x83\x01\x01\x0b\x01\x01\x01\x05\x03\x7f\x01\x03A\t\r\t\x05y{\x01\tA\x01\r\t\x05{y\x01\x1dC\x03\x03\x8b\r\x03\x87\x81#\x0b\x03\x03\x91\r\x05\x93\x95\x87\x85\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\r\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\x0b\t\x03\x05\x03\x03\x0b\x06\x0b\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xba\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13QAg\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_specified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_specified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x057\x01\x05}\x07\x0b\x89\x8d\x8f\x97\x99\x03\x9b\x03\x9d\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x1c\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x01\x02J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xca\xff\xff\xff\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x00\x00\x08\x00\x07\x00\n\x00\x00\x00\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), # End paste + + + dict( + serialization_version=5, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00H\x00F\x00@\x00<\x008\x004\x000\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00\x00\x00\x00\x00,\x00*\x00\x00\x00D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00L\x06\x00\x00L\x06\x00\x00L\x06\x00\x000\x06\x00\x00H\x06\x00\x00h\x06\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00\x80\x06\x00\x00\xac\x06\x00\x00\xac\x06\x00\x00\xe4\x06\x00\x008\x07\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf6\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc7\x9f\x11\x01y\x07\x0b\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\x0f\x17\x13\x13\x0f\x1b\x0f\x1b\x05\x19\x0b\x0f\x13\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02~\x04\x1f\x05\x15\x05\x17\x05\t\t\x07\x1d!#\x03\x07\x0f\x11\x13\x15\x17\x19\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xa5\x1f+\x15-3\x1d/1\x05\'-\x03\x07k\x15]\x155;\x1d79\x05)-\x03\x07\xa9\x1d[\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x0b\x03\x83\x01\x01\x0b\x01\x01\x01\x05\x03\x7f\x01\x03A\t\r\t\x05y{\x01\tA\x01\r\t\x05{y\x01\x1dC\x03\x03\x8b\r\x03\x87\x81#\x0b\x03\x03\x91\r\x05\x93\x95\x87\x85\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\r\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\x0b\t\x03\x05\x03\x03\x0b\x06\x0b\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xba\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13QAg\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_specified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_specified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x057\x01\x05}\x07\x0b\x89\x8d\x8f\x97\x99\x03\x9b\x03\x9d\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x1c\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x01\x02J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), # End paste +] diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py new file mode 100644 index 000000000000..4a0da0b85ee9 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_unspecified_sharding.py @@ -0,0 +1,29 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + + +# Pasted from the test output (see export_serialization_back_compat_test.py module docstring) +serializations = [ + dict( + serialization_version=4, + exported_serialized=bytearray(b"(\x00\x00\x00$\x00D\x00B\x00<\x008\x004\x000\x00,\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00$\x00\x00\x00@\x00\x00\x00\x00\x00\n\x00@\x00\x00\x00D\x06\x00\x00D\x06\x00\x00D\x06\x00\x00(\x06\x00\x00@\x06\x00\x00H\x06\x00\x00\x00\x00\x02\x00d\x06\x00\x00\x90\x06\x00\x00\x90\x06\x00\x00\xc8\x06\x00\x00\x1c\x07\x00\x00\x00\x00\x02\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf1\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc5\x9d\x11\x01y\x07\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x19\x0f\x13\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02^\x04\x1f\x05\x15\x05\x17\x05\t\x1d!#\x03\x07\r\x0f\x11\x13\x15\x17\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\t\x07\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xed\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xf5!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x05\x03{\x01\x03A\t\r\x1b\x05\x7f\x83\x01\x0b\x03\x81\x01\x01\tA\x01\x0b\x01\x01\x01\x03\x03\x87\r\x03\x89}\x1dC#\x0b\x03\x03\x8f\r\x03\x91\x93\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\x0b\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\t\t\x03\x05\x03\x03\x0b\x06\t\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xca\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13UAk\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_unspecified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_unspecified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x053\x01\x05y\x07\x0b\x85\x8b\x8d\x95\x97\x03\x99\x03\x9b\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xca\xff\xff\xff\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x00\x00\x08\x00\x07\x00\n\x00\x00\x00\x00\x00\x00\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), + + dict( + serialization_version=5, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00H\x00F\x00@\x00<\x008\x004\x000\x00*\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0c\x00\n\x00\x04\x00\x00\x00\x00\x00,\x00*\x00\x00\x00D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00H\x06\x00\x00H\x06\x00\x00H\x06\x00\x00,\x06\x00\x00D\x06\x00\x00L\x06\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00d\x06\x00\x00\x90\x06\x00\x00\x90\x06\x00\x00\xc8\x06\x00\x00\x1c\x07\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00\xf1\x05\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01%\x07\x01\x05\t\x0f\x01\x03\x0f\x03\x03\x13\x05\x0b\x17\x1b\x1f#\'\x03\xc5\x9d\x11\x01y\x07\x0b\x0b\x0b\x0f#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x0b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1b\x0b\x0f\x0f\x0b\x1f\x0b\x0f\x0f\x0b\x1f\x0f\x0b\x1f\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x19\x0f\x13\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x1f\x0f\x01\x05\x0f\x0b\x05\r\x17\x07\x0f\x17\x13\x07\x02^\x04\x1f\x05\x15\x05\x17\x05\t\x1d!#\x03\x07\r\x0f\x11\x13\x15\x17\x05\x19\x11\x03\x00\x05\x1b\x11\x01\t\x05\x1d\x11\x01\x05\x05\x1f\t\x07\x1d\x1f\x01\x05!\x05#\x15%+\x1d\')\x05%-\x03\x07\xdf\x1f+\x15-3\x1d/1\x05\'-\x03\x07w\x15]\x155;\x1d79\x05)-\x03\x07\xe7!_\x15=E\x1d?A\x05+-C\x07~\x05!K\x05-\x15GM\x1dIK\x05/-\x05\x07\xca\x02\x11-\x15OW\x1dQS\x051-U\x07\xf35g\x053\x15Ya\x1d[]\x055-_\x07\xf1\x1f\x99\x057\x15ck\x1deg\x059-i\x07\x02\x08\x1f\xab\x05;\x15ms\x1doq\x05=-\x05\x07\xda\x03!_\x1duw\x05?-\x05\x07b\x05KW\x05\x03{\x01\x03A\t\r\x1b\x05\x7f\x83\x01\x0b\x03\x81\x01\x01\tA\x01\x0b\x01\x01\x01\x03\x03\x87\r\x03\x89}\x1dC#\x0b\x03\x03\x8f\r\x03\x91\x93\x1dE\x1dG\x1dI\x1dK\x1f\t\t\x00\x00\x00@\x1f\r\x01\x01\x02\x02\x01\t)\x05A\x11\x07\t)\x01\x07\x11\x03\x05\x03\x05)\x03\x01\x0f\x1d\x04s\x05\x01Q\x01\x0b\x01\x07\x04a\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04E\x03\t\x13\x03\x0b\x1d\x00\x07B\x01\x07\x03\t\tF\t\t\x03\x05\x03\x03\x0b\x06\t\x03\x05\x05\x01\x05\r\x04\x01\x03\x07\x06\x03\x01\x05\x01\x00\xca\x0fM\x0f\x0b\x0f!\x1b\x05\'E\x9b)\x9f1\x9f\x17)\xa13UAk\x17\x05\r%)9\x9d\x91\x15\x19)\x19\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00constant_v1\x00broadcast_in_dim_v1\x00multiply_v1\x00return_v1\x00/Users/necula/Source/jax/tests/export_serialization_back_compat_test.py\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/runner.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_f\x00b\x00jit(f)/mul\x00CompatTest.test_with_unspecified_sharding..f\x00CompatTest.export_and_serialize\x00CompatTest.test_with_unspecified_sharding\x00TestCaseFunction.runtest\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/_pytest/unittest.py\x00pytest_runtest_call\x00_multicall\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_callers.py\x00PluginManager._hookexec\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_manager.py\x00HookCaller.__call__\x00/Users/necula/Source/jax/.venv/lib/python3.12/site-packages/pluggy/_hooks.py\x00call_and_report..\x00CallInfo.from_call\x00x\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08#\x0b\x053\x01\x05y\x07\x0b\x85\x8b\x8d\x95\x97\x03\x99\x03\x9b\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x03\x00\x00\x00cpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x01\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x004\x00\x00\x00\x02\x00\x00\x0016\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x01\x00\x00\x00f\x00\x00\x00"), + ), +] diff --git a/tests/BUILD b/tests/BUILD index ad8a01ca70da..e27ed2957a28 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -2088,6 +2088,18 @@ jax_multiplatform_test( ]), ) +jax_multiplatform_test( + name = "export_serialization_back_compat_test", + srcs = ["export_serialization_back_compat_test.py"], + tags = [], + deps = [ + "//jax/_src:internal_export_back_compat_test_data", + ] + py_deps([ + "absl/testing", + "numpy", + ]), +) + jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], diff --git a/tests/export_serialization_back_compat_test.py b/tests/export_serialization_back_compat_test.py new file mode 100644 index 000000000000..02886cb237df --- /dev/null +++ b/tests/export_serialization_back_compat_test.py @@ -0,0 +1,167 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for backwards compatibility of serialization of JAX exports. + +Whenever we change the serialization format for jax.export.Exported +(see file jax.export.serialization), we should first save a serialization +of the current format and add a test that it can be deserialized and it has +the expected behavior. + +To add a new test: + + * Create a new test method, with a function to be serialized that exercises + the feature you want to test, and a call to self.export_and_serialize. + You can follow the model of the tests below, which are parameterized by + the test data. Use `None` for the test data to signal that you want to + use a fresh serialization. + * Run the test. This will save the serialized data in + TEST_UNDECLARED_OUTPUTS_DIR (or "/tmp/back_compat_testdata" if not set). + * Copy the test data defined in the output file, to the file + jax._src.internal_test_util.export_back_compat_test_data.export_{name}.py. + * Add a new import statement to this file to import that module + +This process will ensure that the saved serialized export can be read by +future code version (backward compatibility of the deserializer). To check +forward compatibility you'd have to check out an older version of the code +and cherry pick a new version of the directory +`jax._src.internal_test_util.export_back_compat_test_data`. +""" + +import logging +import os +import re +from typing import Any + +from absl.testing import absltest +import numpy as np + +# ruff: noqa: F401 +try: + import flatbuffers + CAN_SERIALIZE = True +except (ModuleNotFoundError, ImportError): + CAN_SERIALIZE = False + +import jax +from jax._src import config +from jax._src.export import _export +from jax._src.export.serialization import _SERIALIZATION_VERSION +from jax.sharding import PartitionSpec as P +from jax._src import test_util as jtu + +from jax._src.internal_test_util.export_back_compat_test_data import export_with_specified_sharding +from jax._src.internal_test_util.export_back_compat_test_data import export_with_unspecified_sharding + +config.parse_flags_with_absl() +jtu.request_cpu_devices(8) + + +class CompatTest(jtu.JaxTestCase): + + def setUp(self): + if not CAN_SERIALIZE: + self.skipTest("Serialization not available") + + def export_and_serialize(self, fun, *args, + vjp_order=0, + **kwargs) -> bytearray: + """Export and serialize a function. + + The test data is saved in TEST_UNDECLARED_OUTPUTS_DIR (or + "/tmp/back_compat_testdata" if not set) and should be copied as explained + in the module docstring. + """ + exp = _export.export(fun)(*args, **kwargs) + serialized = exp.serialize(vjp_order=vjp_order) + updated_testdata = f""" + # Paste to the test data file (see export_serialization_back_compat_test.py module docstring) + dict( + serialization_version={_SERIALIZATION_VERSION}, + exported_serialized={serialized!r}, + ), + +""" + # Replace the word that should not appear. + updated_testdata = re.sub(r"google.", "googlex", updated_testdata) + output_dir = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", + "/tmp/back_compat_testdata") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_file = os.path.join(output_dir, f"export_{self._testMethodName}.py") + logging.info("Writing the updated serialized Exported at %s", output_file) + with open(output_file, "w") as f: + f.write(updated_testdata) + return serialized + + @jtu.parameterized_filterable( + kwargs=[ + dict(testdata=testdata, + testcase_name=("current" if testdata is None + else f"v{testdata['serialization_version']}")) + for testdata in [None, *export_with_specified_sharding.serializations] + ] + ) + def test_with_specified_sharding(self, testdata: dict[str, Any] | None): + a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4)) + mesh = jtu.create_mesh((2,), "x") + with jax.set_mesh(mesh): + @jax.jit(in_shardings=(jax.sharding.NamedSharding(mesh, P("x", None),),), + out_shardings=jax.sharding.NamedSharding(mesh, P(None, "x"))) + def f(b): + return b * 2. + + a = jax.device_put(a, jax.sharding.NamedSharding(mesh, P("x", None))) + if testdata is None: + serialized = self.export_and_serialize(f, a) + else: + serialized = testdata["exported_serialized"] + + out = _export.deserialize(serialized).call(a) + self.assertAllClose(out, a * 2.) + self.assertEqual(out.addressable_shards[0].index, (slice(None), slice(0, 2))) + self.assertEqual(out.addressable_shards[1].index, (slice(None), slice(2, 4))) + + + @jtu.parameterized_filterable( + kwargs=[ + dict(testdata=testdata, + testcase_name=("current" if testdata is None + else f"v{testdata['serialization_version']}")) + for testdata in [None, *export_with_unspecified_sharding.serializations] + ] + ) + def test_with_unspecified_sharding(self, testdata: dict[str, Any] | None): + a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4)) + + # Output sharding is not specified + mesh = jtu.create_mesh((2,), "x") + with jax.set_mesh(mesh): + @jax.jit(in_shardings=(jax.sharding.NamedSharding(mesh, P("x", None),),)) + def f(b): + return b * 2. + + a = jax.device_put(a, jax.sharding.NamedSharding(mesh, P("x", None))) + if testdata is None: + serialized = self.export_and_serialize(f, a) + else: + serialized = testdata["exported_serialized"] + + out = _export.deserialize(serialized).call(a) + self.assertAllClose(out, a * 2.) + self.assertEqual(out.addressable_shards[0].index, (slice(0, 8), slice(None))) + self.assertEqual(out.addressable_shards[1].index, (slice(8, 16), slice(None))) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) From 14f072b764196d3abe5322860f1b523f4c2ec0f3 Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Tue, 9 Dec 2025 09:55:27 -0800 Subject: [PATCH 118/315] Use ifrt::AttributeMap::Get instead of directly accessing map Introduces a variant of Get in AttributeMap that returns the value variant as is. PiperOrigin-RevId: 842283537 --- jaxlib/BUILD | 1 + jaxlib/jax.cc | 10 ++++++---- jaxlib/py_client.cc | 10 ++++++---- jaxlib/py_device.cc | 14 ++++++++------ 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index bfcff630cc89..c50200aa6da0 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -407,6 +407,7 @@ nanobind_pywrap_extension( "@xla//xla/python:types", "@xla//xla/python:version", "@xla//xla/python/ifrt", + "@xla//xla/python/ifrt:attribute_map", "@xla//xla/python/ifrt:plugin_program", "@xla//xla/python/ifrt:plugin_program_serdes", "@xla//xla/python/pjrt_ifrt", diff --git a/jaxlib/jax.cc b/jaxlib/jax.cc index 86be4995b823..30959db5f1de 100644 --- a/jaxlib/jax.cc +++ b/jaxlib/jax.cc @@ -65,6 +65,7 @@ limitations under the License. #include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" @@ -911,11 +912,12 @@ NB_MODULE(_jax, m) { .def("__getattr__", [](xla::ifrt::Topology& topology, std::string_view name) -> nb::object { - const auto& attrs = topology.Attributes().map(); - auto it = attrs.find(name); - if (it != attrs.end()) { + auto value = + topology.Attributes().Get( + std::string(name)); + if (value.ok()) { return std::visit([](auto&& v) { return nb::cast(v.value); }, - it->second); + *value); } throw nb::attribute_error( absl::StrCat("Unknown attribute ", name).c_str()); diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index 71b89d626cf2..9be3e8598a36 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -74,6 +74,7 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -1031,11 +1032,12 @@ PyType_Slot PyClient::slots_[] = { nb::arg("dtype"), nb::arg("shard_shape"), nb::arg("device")) .def("__getattr__", [](PyClient& client, std::string_view name) -> nb::object { - const auto& attrs = client.Attributes().map(); - auto it = attrs.find(name); - if (it != attrs.end()) { + auto value = + client.Attributes().Get( + std::string(name)); + if (value.ok()) { return std::visit([](auto&& v) { return nb::cast(v.value); }, - it->second); + *value); } throw nb::attribute_error( absl::StrCat("Unknown attribute ", name).c_str()); diff --git a/jaxlib/py_device.cc b/jaxlib/py_device.cc index 7863ef21cdce..8d4bcb216dfb 100644 --- a/jaxlib/py_device.cc +++ b/jaxlib/py_device.cc @@ -40,6 +40,7 @@ limitations under the License. #include "jaxlib/py_memory_space.h" #include "jaxlib/python_ref_manager.h" #include "xla/pjrt/status_casters.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_helpers.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" @@ -278,12 +279,13 @@ PyType_Slot PyDevice::slots_[] = { } try { auto device = nb::cast(nb::handle(self)); - auto name = nb::cast(nb::handle(key)); - const auto& attrs = device->device_->Attributes().map(); - auto it = attrs.find(name); - if (it != attrs.end()) { - auto result = std::visit([](auto&& v) { return nb::cast(v.value); }, - it->second); + auto name = nb::cast(nb::handle(key)); + auto value = + device->device_->Attributes().Get( + name); + if (value.ok()) { + auto result = + std::visit([](auto&& v) { return nb::cast(v.value); }, *value); return result.release().ptr(); } PyErr_SetNone(PyExc_AttributeError); From f7e3bdb29efaac06676690e958eef607e9022f4b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 9 Dec 2025 10:30:07 -0800 Subject: [PATCH 119/315] Mention pmap is in maintenance mode and point to shard_map and the migration guide Co-authored-by: Matthew Johnson PiperOrigin-RevId: 842298387 --- docs/jax.sharding.rst | 3 --- jax/_src/api.py | 28 +++++++--------------------- jax/_src/sharding_impls.py | 1 - 3 files changed, 7 insertions(+), 25 deletions(-) diff --git a/docs/jax.sharding.rst b/docs/jax.sharding.rst index 12760d62ddb3..0146398cb12d 100644 --- a/docs/jax.sharding.rst +++ b/docs/jax.sharding.rst @@ -16,9 +16,6 @@ Classes .. autoclass:: NamedSharding :members: :show-inheritance: -.. autoclass:: PmapSharding - :members: - :show-inheritance: .. autoclass:: PartitionSpec :members: .. autoclass:: Mesh diff --git a/jax/_src/api.py b/jax/_src/api.py index e5c57b3d99e4..39097971ceee 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1345,7 +1345,13 @@ def pmap( donate_argnums: int | Iterable[int] = (), global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, ) -> Any: - """Parallel map with support for collective operations. + """Old way of doing parallel map. Use :py:func:`jax.shard_map` instead. + + .. note:: + While :py:func:`jax.pmap` works, you should probably use + :py:func:`jax.shard_map` or ``jax.smap`` instead. shard_map supports more + efficient autodiff, and is more composable in the multi-controller setting. + See https://docs.jax.dev/en/latest/notebooks/shard_map.html for examples. .. note:: :py:func:`pmap` is now implemented in terms of :py:func:`jit` and @@ -1510,26 +1516,6 @@ def pmap( are important particularly in the case of nested :py:func:`pmap` functions, where collective operations can operate over distinct axes: - >>> from functools import partial - >>> import jax - >>> - >>> @partial(pmap, axis_name='rows') - ... @partial(pmap, axis_name='cols') - ... def normalize(x): - ... row_normed = x / jax.lax.psum(x, 'rows') - ... col_normed = x / jax.lax.psum(x, 'cols') - ... doubly_normed = x / jax.lax.psum(x, ('rows', 'cols')) - ... return row_normed, col_normed, doubly_normed - >>> - >>> x = jnp.arange(8.).reshape((4, 2)) - >>> row_normed, col_normed, doubly_normed = normalize(x) # doctest: +SKIP - >>> print(row_normed.sum(0)) # doctest: +SKIP - [ 1. 1.] - >>> print(col_normed.sum(1)) # doctest: +SKIP - [ 1. 1. 1. 1.] - >>> print(doubly_normed.sum((0, 1))) # doctest: +SKIP - 1.0 - On multi-process platforms, collective operations operate over all devices, including those on other processes. For example, assuming the following code runs on two processes with 4 XLA devices each: diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index ded173ffb653..b658a4d13966 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -193,7 +193,6 @@ def pmap_sharding_devices_indices_map( @use_cpp_class(xc.PmapSharding) class PmapSharding(jsharding.Sharding): - """Describes a sharding used by :func:`jax.pmap`.""" devices: np.ndarray sharding_spec: sharding_specs.ShardingSpec _internal_device_list: xc.DeviceList From 6cb91ede15882ef60111f89877e4c98623e7d7c6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 9 Dec 2025 10:50:42 -0800 Subject: [PATCH 120/315] [mosaic] Added documentation and a few useful methods to tpu.tiled attribute PiperOrigin-RevId: 842307512 --- jaxlib/mosaic/dialect/tpu/tpu.td | 28 ++- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 222 +++++++++++++++++++++-- 2 files changed, 229 insertions(+), 21 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 892032666299..f7383756e5d6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -168,13 +168,39 @@ def TPU_VectorLayoutAttr : TPU_Attr<"VectorLayout", "vpad"> { def TPU_TiledLayoutAttr : TPU_Attr<"TiledLayout", "tiled", [DeclareAttrInterfaceMethods]> { - let description = [{TODO}]; + let description = [{ + This attribute represents tiled layouts in memrefs. + + Multiple levels of tiling are supported with the following restriction: + - Additional levels of tiling may not add any padding. + - Additional levels of tiling may not tile previously untiled dimensions, + that is, they cannot tile across first-level tiles. + + Tile strides encode the stride when moving along a given dimension. They + must have the same rank as the shape and must be decreasing with increasing + dimension number. For tiled dimensions, the stride applies only when moving + across first-level tiles. The strides are in units of the size of the first + tile, or 1 if there are no tiles. + }]; let parameters = (ins ArrayRefParameter<"::xla::Tile", "">:$tiles, ArrayRefParameter<"int64_t", "">:$tile_strides ); + let extraClassDeclaration = [{ + static ::llvm::SmallVector getDefaultTileStrides(::llvm::ArrayRef<::xla::Tile> tiles, ::llvm::ArrayRef shape); + bool tilesAreKnownContiguous(::llvm::ArrayRef shape) const; + + int64_t getRank() const { + return getTileStrides().size(); + } + int64_t getUntiledRank() const; + + ::llvm::SmallVector getExpandedShape(::llvm::ArrayRef shape) const; + ::llvm::SmallVector getExpandedStrides() const; + }]; let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; } def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [ diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index d49380877549..b0cb7d35c4d5 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -15,6 +15,8 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include +#include #include #include #include @@ -23,14 +25,16 @@ limitations under the License. #include "absl/log/log.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep. +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep. @@ -215,32 +219,210 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) { } AffineMap TiledLayoutAttr::getAffineMap() const { - AffineMap map = - AffineMap::getMultiDimIdentityMap(getTileStrides().size(), getContext()); SmallVector exprs; - for (const xla::Tile &tile : getTiles()) { - exprs.clear(); + for (int64_t i = 0; i < getRank(); ++i) { + exprs.push_back(getAffineDimExpr(i, getContext())); + } + for (const xla::Tile& tile : getTiles()) { + SmallVector new_exprs; auto dimensions = tile.dimensions(); - int64_t untiled_dims = map.getNumResults() - dimensions.size(); - if (untiled_dims < 0) { - LOG(FATAL) << "Invalid TiledLayoutAttr: Number of dims must be larger " - "or equal to the rank of the tile"; + int64_t untiled_rank = exprs.size() - dimensions.size(); + assert(untiled_rank >= 0); + for (int64_t i = 0; i < untiled_rank; ++i) { + new_exprs.push_back(exprs[i]); + } + for (int64_t i = 0; i < dimensions.size(); ++i) { + new_exprs.push_back(exprs[untiled_rank + i].floorDiv(dimensions[i])); + } + for (int64_t i = 0; i < dimensions.size(); ++i) { + new_exprs.push_back(exprs[untiled_rank + i] % dimensions[i]); + } + exprs = std::move(new_exprs); + } + int64_t num_symbols = 0; + AffineExpr result = getAffineConstantExpr(0, getContext()); + SmallVector strides = getExpandedStrides(); + assert(strides.size() == exprs.size()); + for (int64_t i = 0; i < exprs.size(); ++i) { + AffineExpr stride_expr = + ShapedType::isDynamic(strides[i]) + ? getAffineSymbolExpr(num_symbols++, getContext()) + : getAffineConstantExpr(strides[i], getContext()); + result = result + exprs[i] * stride_expr; + } + return AffineMap::get(getRank(), num_symbols, result); +} + +namespace { +int64_t getUntiledRank(ArrayRef tiles, const int64_t rank) { + // Note: This implementation does not assume there is no nested tiling across + // the first level of tiling, though this is enforced by the verifier. + int64_t untiled_rank = rank; + int64_t tiled_rank = rank; + for (const xla::Tile& tile : tiles) { + const int64_t tile_ndims = tile.dimensions().size(); + untiled_rank = std::min(untiled_rank, tiled_rank - tile_ndims); + tiled_rank += tile_ndims; + } + return untiled_rank; +} +} // namespace + +int64_t TiledLayoutAttr::getUntiledRank() const { + return mlir::tpu::getUntiledRank(getTiles(), getRank()); +} + +namespace { +FailureOr> getExpandedShape( + const ArrayRef untiled_shape, const ArrayRef tiles, + const bool require_alignment) { + SmallVector shape(untiled_shape); + for (const xla::Tile& tile : tiles) { + const int64_t tile_ndims = tile.dimensions().size(); + const llvm::ArrayRef tiled_shape = + llvm::ArrayRef(shape).take_back(tile_ndims); + llvm::SmallVector new_tiled_shape(2 * tile_ndims); + for (int64_t i = 0; i < tile_ndims; ++i) { + if (require_alignment && (ShapedType::isDynamic(tiled_shape[i]) || + tiled_shape[i] % tile.dimension(i) != 0)) { + return failure(); + } + if (ShapedType::isDynamic(tiled_shape[i])) { + new_tiled_shape[i] = ShapedType::kDynamic; + } else { + new_tiled_shape[i] = + llvm::divideCeil(tiled_shape[i], tile.dimension(i)); + } + new_tiled_shape[tile_ndims + i] = tile.dimension(i); + } + shape.pop_back_n(tile_ndims); + shape.append(new_tiled_shape); + } + return shape; +} +} // namespace + +SmallVector TiledLayoutAttr::getDefaultTileStrides( + const ArrayRef tiles, const ArrayRef shape) { + SmallVector strides(shape.size()); + int64_t stride = 1; + const xla::Tile* const first_tile = tiles.empty() ? nullptr : &tiles.front(); + const int64_t first_tile_rank = + first_tile == nullptr ? 0 : first_tile->dimensions().size(); + for (int64_t d = shape.size() - 1; d >= 0; --d) { + assert(!ShapedType::isDynamic(shape[d])); + strides[d] = stride; + if (d >= shape.size() - first_tile_rank) { + assert(first_tile != nullptr); + const int64_t tile_d = d - (shape.size() - first_tile_rank); + stride *= llvm::divideCeil(shape[d], first_tile->dimension(tile_d)); + } else { + stride *= shape[d]; } - for (int64_t i = 0; i < untiled_dims; ++i) { - exprs.push_back(getAffineDimExpr(i, getContext())); + } + return strides; +} + +bool TiledLayoutAttr::tilesAreKnownContiguous( + const ArrayRef shape) const { + const ArrayRef tiles = getTiles(); + const ArrayRef tile_strides = getTileStrides(); + int64_t stride = 1; + const xla::Tile* const first_tile = tiles.empty() ? nullptr : &tiles.front(); + const int64_t first_tile_rank = + first_tile == nullptr ? 0 : first_tile->dimensions().size(); + for (int64_t d = shape.size() - 1; d >= 0; --d) { + int64_t size_tiles; + if (d >= shape.size() - first_tile_rank && + shape[d] != ShapedType::kDynamic) { + assert(first_tile != nullptr); + const int64_t tile_d = d - (shape.size() - first_tile_rank); + size_tiles = llvm::divideCeil(shape[d], first_tile->dimension(tile_d)); + } else { + size_tiles = shape[d]; } - for (int i = 0; i < dimensions.size(); ++i) { - exprs.push_back(getAffineDimExpr(untiled_dims + i, getContext()) - .floorDiv(dimensions[i])); + // Dimensions with only one element/tile can have any stride. + if (stride != tile_strides[d] && size_tiles != 1) { + return false; } - for (int i = 0; i < dimensions.size(); ++i) { - exprs.push_back(getAffineDimExpr(untiled_dims + i, getContext()) % - dimensions[i]); + if (d == 0) { + break; } - auto tile_map = AffineMap::get(map.getNumResults(), 0, exprs, getContext()); - map = tile_map.compose(map); + // When any dimension other than the leading one has a dynamic size, we + // cannot guarantee that there are no gaps. + if (size_tiles == ShapedType::kDynamic) { + return false; + } + stride *= size_tiles; + } + return true; +} + +SmallVector TiledLayoutAttr::getExpandedShape( + ArrayRef untiled_shape) const { + // getExpandedShape should never fail without require_alignment + return *mlir::tpu::getExpandedShape(untiled_shape, getTiles(), + /*require_alignment=*/false); +} + +SmallVector TiledLayoutAttr::getExpandedStrides() const { + if (getTiles().empty()) { + return SmallVector(getTileStrides()); + } + SmallVector strides(getTileStrides()); + // Expand front tile + const xla::Tile& first_tile = getTiles().front(); + const FailureOr> failure_or_expanded_tile = + mlir::tpu::getExpandedShape(first_tile.dimensions(), + getTiles().drop_front(), + /*require_alignment=*/true); + // Verification should ensure this: + assert(succeeded(failure_or_expanded_tile)); + const SmallVector& expanded_tile = *failure_or_expanded_tile; + strides.resize_for_overwrite(getRank() + expanded_tile.size()); + int64_t first_tile_size = llvm::product_of(first_tile.dimensions()); + int64_t tile_size = 1; + for (int64_t d = strides.size() - 1; d >= 0; --d) { + if (d >= getRank()) { + const int64_t new_stride = tile_size; + tile_size *= expanded_tile[d - getRank()]; + strides[d] = new_stride; + } else { + strides[d] *= first_tile_size; + } + } + return strides; +} + +LogicalResult TiledLayoutAttr::verify( + function_ref emitError, + const llvm::ArrayRef tiles, + const llvm::ArrayRef tile_strides) { + if (llvm::any_of(tile_strides, ShapedType::isDynamic)) { + return emitError() << "Not implemented: Dynamic tile strides"; + } + if (tiles.empty()) { + return success(); + } + const int64_t rank = tile_strides.size(); + const xla::Tile& first_tile = tiles.front(); + const int64_t first_tile_rank = first_tile.dimensions().size(); + // The interpretation of tile strides is unclear if there is nested tiling + // across first tiles (e.g. T(8, 128)(2, 4, 64)), and this has no applications + // anyway. + if (mlir::tpu::getUntiledRank(tiles, rank) != rank - first_tile_rank) { + return emitError() << "Not implemented: Nested tiling across first tiles"; + } + // Check that nested tiles evenly divide previous tiles (so they don't add any + // padding or change the tile size) + if (failed(mlir::tpu::getExpandedShape(first_tile.dimensions(), + tiles.drop_front(), + /*require_alignment=*/true))) { + return emitError() << "Not implemented: Nested tiles must evenly divide " + << "the first tile " << first_tile.ToString() + << " but they do not (would add padding)"; } - return map; + return success(); } MemRefType getMemRefType(Value value) { From 863e4e752da0e4081ccec90f6f9ad1d85c3f96ac Mon Sep 17 00:00:00 2001 From: Emily Fertig Date: Tue, 9 Dec 2025 11:03:52 -0800 Subject: [PATCH 121/315] Remove multiprocess tests from TPU presubmit due to latency. They still run in CI. PiperOrigin-RevId: 842313199 --- ci/run_bazel_test_tpu.sh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ci/run_bazel_test_tpu.sh b/ci/run_bazel_test_tpu.sh index 5c8d53e4c23b..570b990e5484 100755 --- a/ci/run_bazel_test_tpu.sh +++ b/ci/run_bazel_test_tpu.sh @@ -227,9 +227,7 @@ else //tests:layout_test_tpu \ //tests:pjit_test_tpu \ //tests:python_callback_test_tpu \ - //tests:ragged_collective_test_tpu \ - //tests/multiprocess:tpu_tests \ - $IGNORE_TESTS_MULTIACCELERATOR + //tests:ragged_collective_test_tpu # Store the return value of the second bazel command. second_bazel_cmd_retval=$? From 3a366567bbae1235848dd07e785c84f60143f960 Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Tue, 30 Sep 2025 18:23:14 +0000 Subject: [PATCH 122/315] Adds docs on multi-controller JAX fault tolerance. --- .../fault_tolerance/cancel_collectives.py | 60 + docs/_static/fault_tolerance/collectives.py | 50 + .../fault_tolerance/data_parallelism.py | 129 ++ .../data_parallelism_with_recovery.py | 182 ++ docs/_static/fault_tolerance/dont_fail.py | 45 + .../fault_tolerance/fault_tolerance.css | 283 +++ .../fault_tolerance/fault_tolerance.js | 377 ++++ docs/_static/fault_tolerance/live_devices.py | 64 + docs/_static/fault_tolerance/while_loop.py | 41 + docs/advanced_guides.rst | 1 + docs/fault_tolerance.rst | 1524 +++++++++++++++++ 11 files changed, 2756 insertions(+) create mode 100644 docs/_static/fault_tolerance/cancel_collectives.py create mode 100644 docs/_static/fault_tolerance/collectives.py create mode 100644 docs/_static/fault_tolerance/data_parallelism.py create mode 100644 docs/_static/fault_tolerance/data_parallelism_with_recovery.py create mode 100644 docs/_static/fault_tolerance/dont_fail.py create mode 100644 docs/_static/fault_tolerance/fault_tolerance.css create mode 100644 docs/_static/fault_tolerance/fault_tolerance.js create mode 100644 docs/_static/fault_tolerance/live_devices.py create mode 100644 docs/_static/fault_tolerance/while_loop.py create mode 100644 docs/fault_tolerance.rst diff --git a/docs/_static/fault_tolerance/cancel_collectives.py b/docs/_static/fault_tolerance/cancel_collectives.py new file mode 100644 index 000000000000..42c75c015112 --- /dev/null +++ b/docs/_static/fault_tolerance/cancel_collectives.py @@ -0,0 +1,60 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = ' '.join([ + '--xla_gpu_nccl_terminate_on_error=false', + '--xla_gpu_nccl_async_execution=true', + '--xla_gpu_nccl_blocking_communicators=false', +]) +os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1' +os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1' + +from absl import app +from absl import flags +from collections.abc import Sequence +import jax +import jax.numpy as jnp +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + # Don't do this. Use live_devices instead. + from jax.experimental.multihost_utils import _live_devices + _live_devices(jax._src.distributed.global_state.client, jax.devices()) + + n = jax.device_count() + jax.set_mesh(jax.make_mesh((n,), ("i",))) + x = jax.device_put(jnp.arange(n), jax.P("i")) + while True: + print(jnp.sum(x)) + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/collectives.py b/docs/_static/fault_tolerance/collectives.py new file mode 100644 index 000000000000..0f120f47271f --- /dev/null +++ b/docs/_static/fault_tolerance/collectives.py @@ -0,0 +1,50 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = '--xla_gpu_nccl_terminate_on_error=false' + +from absl import app +from absl import flags +from collections.abc import Sequence +import jax +import jax.numpy as jnp +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + n = jax.device_count() + jax.set_mesh(jax.make_mesh((n,), ("i",))) + x = jax.device_put(jnp.arange(n), jax.P("i")) + while True: + print(jnp.sum(x)) + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/data_parallelism.py b/docs/_static/fault_tolerance/data_parallelism.py new file mode 100644 index 000000000000..c70d52751ecb --- /dev/null +++ b/docs/_static/fault_tolerance/data_parallelism.py @@ -0,0 +1,129 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = ' '.join([ + '--xla_gpu_nccl_terminate_on_error=false', + '--xla_gpu_nccl_async_execution=true', + '--xla_gpu_nccl_blocking_communicators=false', +]) +os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1' +os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1' + +from absl import app +from absl import flags +from collections.abc import Sequence +from jax.experimental.multihost_utils import live_devices +import jax +import jax.numpy as jnp +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + +def replicated(x: jax.Array, devices: list[jax.Device]): + """Return x replicated across the provided devices. + + Note that replicated(x) doesn't actually move any data. It simply creates a + logically replicated array with x as the local replica. + """ + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec(None) + sharding = jax.sharding.NamedSharding(mesh, spec) + shards = [ + jax.device_put(x.addressable_shards[0].data, d) for d in devices + if d.process_index == jax.process_index() + ] + return jax.make_array_from_single_device_arrays(x.shape, sharding, shards) + + +def sharded(x: jax.Array, devices: list[jax.Device]): + """Return x sharded across the provided devices. + + Note that sharded(x) doesn't actually move any data. It simply creates a + logically sharded array. x should have the same shape as the global array. + """ + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec("i") + sharding = jax.sharding.NamedSharding(mesh, spec) + m = sharding.addressable_devices_indices_map(x.shape) + shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()] + return jax.make_array_from_single_device_arrays(x.shape, sharding, shards) + + +def main(_: Sequence[str]) -> None: + # Parse command line arguments and initialize multi-controller JAX. + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize(coordinator_address="localhost:8000", + process_id=_PROCESS_ID.value, + num_processes=_NUM_PROCESSES.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + # Initialize the model's weights. + keys = iter(jax.random.split(jax.random.key(seed=42), num=3)) + weights = jax.random.normal(next(keys), shape=(1, )) + + # We'll learn a trivial linear model: a*x. + def predict(weights, X): + return weights * X + + # We'll use mean squared error loss. + def loss(weights, X, Y): + return jnp.mean((predict(weights, X) - Y)**2) + + # Initialize the (noisy) training data with a=10. + X = jax.random.permutation(next(keys), jnp.arange(-300., 300.)) + Y = 10 * X + jax.random.normal(next(keys), X.shape) + + # Hyperparameters. + loss_and_grad = jax.jit(jax.value_and_grad(loss)) + learning_rate = 1e-6 + device_batch_size = 10 + + step = 0 + while True: + try: + with live_devices(jax.devices()) as devices: + print(f'=== Running step {step} with live devices = {devices} ===') + + # Replicate the model weights. + weights = replicated(weights, devices) + + # Shard the batch. + batch_size = device_batch_size * len(devices) + start = (step * batch_size) % len(X) + stop = start + batch_size + X_batch = sharded(X[start:stop], devices) + Y_batch = sharded(Y[start:stop], devices) + + # Compute gradients and update weights. + l, grad = loss_and_grad(weights, X_batch, Y_batch) + new_weights = jax.block_until_ready(weights - learning_rate * grad) + except Exception as e: + print(f'Step {step} failed: {e}') + else: + print(f'Step {step} succeeded: loss = {l}') + step += 1 + weights = new_weights + + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/data_parallelism_with_recovery.py b/docs/_static/fault_tolerance/data_parallelism_with_recovery.py new file mode 100644 index 000000000000..b97461a20773 --- /dev/null +++ b/docs/_static/fault_tolerance/data_parallelism_with_recovery.py @@ -0,0 +1,182 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = ' '.join([ + '--xla_gpu_nccl_terminate_on_error=false', + '--xla_gpu_nccl_async_execution=true', + '--xla_gpu_nccl_blocking_communicators=false', +]) +os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1' +os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1' + +from absl import app +from absl import flags +from collections.abc import Sequence +from jax.experimental.multihost_utils import live_devices +from jax.experimental import shard_map +import jax +import jax.numpy as jnp +import numpy as np +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + +def replicated(x: jax.Array, devices: list[jax.Device]): + """Return x replicated across the provided devices. + + Note that replicated(x) doesn't actually move any data. It simply creates a + logically replicated array with x as the local replica. + """ + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec(None) + sharding = jax.sharding.NamedSharding(mesh, spec) + shards = [ + jax.device_put(x.addressable_shards[0].data, d) for d in devices + if d.process_index == jax.process_index() + ] + return jax.make_array_from_single_device_arrays(x.shape, sharding, shards) + + +def sharded(x: jax.Array, devices: list[jax.Device]): + """Return x sharded across the provided devices. + + Note that sharded(x) doesn't actually move any data. It simply creates a + logically sharded array. x should have the same shape as the global array. + """ + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec("i") + sharding = jax.sharding.NamedSharding(mesh, spec) + m = sharding.addressable_devices_indices_map(x.shape) + shards = [jax.device_put(x[m[d]], d) for d in jax.local_devices()] + return jax.make_array_from_single_device_arrays(x.shape, sharding, shards) + + +def send(x: jax.Array, from_device: jax.Device, to_device: jax.Device): + """Sends x from one device to another.""" + assert isinstance(x, jax.Array) + devices = [from_device, to_device] + psum = lambda x: jax.lax.psum(x, "i") + mesh = jax.make_mesh((2, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec(None) + x = replicated(x, [from_device, to_device]) + shard_map.shard_map(psum, mesh=mesh, in_specs=spec, out_specs=spec)(x) + + +def recv(x: jax.Array, from_device: jax.Device, to_device: jax.Device): + """Receives x from a matching send.""" + assert isinstance(x, jax.Array) + to_device = jax.local_devices()[0] + devices = [from_device, to_device] + psum = lambda x: jax.lax.psum(x, "i") + mesh = jax.make_mesh((2, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec(None) + x = jnp.zeros_like(x) + x = replicated(x, [from_device, to_device]) + return shard_map.shard_map(psum, mesh=mesh, in_specs=spec, out_specs=spec)(x) + + +def allgather(x: float, devices: list[jax.Device]) -> list[float]: + """Performs an AllGather across the provided devices.""" + n = len(devices) + mesh = jax.make_mesh((n, ), ("i", ), devices=devices) + spec = jax.sharding.PartitionSpec('i') + p = lambda x: jax.lax.all_gather(x, "i", tiled=True) + f = jax.shard_map(p, mesh=mesh, in_specs=spec, out_specs=spec) + return jax.block_until_ready(f(np.array([x] * len(devices)))).addressable_shards[0].data + + +def main(_: Sequence[str]) -> None: + # Parse command line arguments and initialize multi-controller JAX. + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize(coordinator_address="localhost:8000", + process_id=_PROCESS_ID.value, + num_processes=_NUM_PROCESSES.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + # Initialize the model's weights. + keys = iter(jax.random.split(jax.random.key(seed=42), num=3)) + weights = jax.random.normal(next(keys), shape=(1, )) + + # We'll learn a trivial linear model: a*x. + def predict(weights, X): + return weights * X + + # We'll use mean squared error loss. + def loss(weights, X, Y): + return jnp.mean((predict(weights, X) - Y)**2) + + # Initialize the (noisy) training data with a=10. + X = jax.random.permutation(next(keys), jnp.arange(-300., 300.)) + Y = 10 * X + jax.random.normal(next(keys), X.shape) + + # Hyperparameters. + loss_and_grad = jax.jit(jax.value_and_grad(loss)) + learning_rate = 1e-6 + device_batch_size = 10 + + step = 0 + while True: + try: + with live_devices(jax.devices()) as devices: + print(f'=== Running step {step} with live devices = {devices} ===') + + # Handle recovering devices. A device is recovering if its step doesn't + # match process 0's step. We assume process 0 never fails. + print('all gathering steps...') + steps = allgather(step, devices) + print(f'{steps=}') + recovering = [d for d, s in zip(devices, steps) if s != steps[0]] + for d in recovering: + # Process 0 sends weights and step to the recovering devices. + if jax.process_index() == 0: + print('sending...') + send(weights, jax.devices()[0], d) + send(jnp.array([step]), jax.devices()[0], d) + elif d.process_index == jax.process_index(): + print('receiving...') + weights = recv(weights, jax.devices()[0], d) + step = recv(jnp.array([step]), jax.devices()[0], d)[0] + + # Replicate the model weights. + weights = replicated(weights, devices) + + # Shard the batch. + batch_size = device_batch_size * len(devices) + start = (step * batch_size) % len(X) + stop = start + batch_size + X_batch = sharded(X[start:stop], devices) + Y_batch = sharded(Y[start:stop], devices) + + # Compute gradients and update weights. + l, grad = loss_and_grad(weights, X_batch, Y_batch) + new_weights = jax.block_until_ready(weights - learning_rate * grad) + except Exception as e: + print(f'Step {step} failed: {e}') + else: + print(f'Step {step} succeeded: loss = {l}') + step += 1 + weights = new_weights + + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/dont_fail.py b/docs/_static/fault_tolerance/dont_fail.py new file mode 100644 index 000000000000..a44514d65c71 --- /dev/null +++ b/docs/_static/fault_tolerance/dont_fail.py @@ -0,0 +1,45 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = '--xla_gpu_nccl_terminate_on_error=false' + +from absl import app +from absl import flags +from collections.abc import Sequence +import jax +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + while True: + print(time.time()) + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/fault_tolerance.css b/docs/_static/fault_tolerance/fault_tolerance.css new file mode 100644 index 000000000000..86f26e5e842a --- /dev/null +++ b/docs/_static/fault_tolerance/fault_tolerance.css @@ -0,0 +1,283 @@ +.cluster { + margin: 1em; + font-size: smaller; + position: relative; + height: 20em; +} + +.server-box { + position: absolute; + top: 0%; + left: 0%; + width: 100%; + display: flex; + align-items: flex-start; +} + +.proc-box { + display: flex; + flex-direction: column; + max-width: 33%; + transform: translate(0%, -100%); +} + +.p0-box { + position: absolute; + top: 100%; + left: 0%; +} + +.p1-box { + position: absolute; + top: 100%; + left: 35%; +} + +.p2-box { + position: absolute; + top: 100%; + left: 70%; +} + +.proc-box div { + margin: 1pt; +} + +.proc-box button { + margin: 1pt; +} + +.server { + padding: 0.5em; + border: 2pt solid black; + background-color: #FEEFC3; + display: flex; + justify-content: center; + align-items: center; + text-align: center; + font-weight: bold; + z-index: 99; +} + +.proc { + border: 2pt solid black; + border-radius: 50%; + width: 2rem; + height: 2rem; + display: flex; + align-items: center; + justify-content: center; + font-weight: bold; + z-index: 99; +} + +.p0 { + background-color: #D2E3FC; +} + +.p1 { + background-color: #FAD2CF; +} + +.p2 { + background-color: #CEEAD6; +} + +.alive { + color: green; +} + +.dead { + color: red; +} + +.failed { + background-color: gray; +} + +.msg { + position: absolute; + animation-timing-function: linear; + font-size: large; + z-index: -1; +} + +.p0_to_pserver { + animation-duration: 1s; + animation-name: p0_to_pserver_keyframes; +} + +.p1_to_pserver { + animation-duration: 1.5s; + animation-name: p1_to_pserver_keyframes; +} + +.p2_to_pserver { + animation-duration: 2.0s; + animation-name: p2_to_pserver_keyframes; +} + +.pserver_to_p0 { + animation-duration: 1s; + animation-name: pserver_to_p0_keyframes; +} + +.pserver_to_p1 { + animation-duration: 1.5s; + animation-name: pserver_to_p1_keyframes; +} + +.pserver_to_p2 { + animation-duration: 2.0s; + animation-name: pserver_to_p2_keyframes; +} + +@keyframes p0_to_pserver_keyframes { + from { top: 75%; left: 1%; } + to { top: 0%; left: 1%; } +} + +@keyframes p1_to_pserver_keyframes { + from { top: 75%; left: 36%; } + to { top: 0%; left: 1%; } +} + +@keyframes p2_to_pserver_keyframes { + from { top: 75%; left: 71%; } + to { top: 0%; left: 1%; } +} + +@keyframes pserver_to_p0_keyframes { + from { top: 0%; left: 1%; } + to { top: 75%; left: 1%; } +} + +@keyframes pserver_to_p1_keyframes { + from { top: 0%; left: 1%; } + to { top: 75%; left: 36%; } +} + +@keyframes pserver_to_p2_keyframes { + from { top: 0%; left: 1%; } + to { top: 75%; left: 71%; } +} + +.p0_to_pserver_tall { + animation-duration: 1s; + animation-name: p0_to_pserver_keyframes_tall; +} + +.p1_to_pserver_tall { + animation-duration: 1.5s; + animation-name: p1_to_pserver_keyframes_tall; +} + +.p2_to_pserver_tall { + animation-duration: 2.0s; + animation-name: p2_to_pserver_keyframes_tall; +} + +.pserver_to_p0_tall { + animation-duration: 1s; + animation-name: pserver_to_p0_keyframes_tall; +} + +.pserver_to_p1_tall { + animation-duration: 1.5s; + animation-name: pserver_to_p1_keyframes_tall; +} + +.pserver_to_p2_tall { + animation-duration: 2.0s; + animation-name: pserver_to_p2_keyframes_tall; +} + +@keyframes p0_to_pserver_keyframes_tall { + from { top: 55%; left: 1%; } + to { top: 0%; left: 1%; } +} + +@keyframes p1_to_pserver_keyframes_tall { + from { top: 55%; left: 36%; } + to { top: 0%; left: 1%; } +} + +@keyframes p2_to_pserver_keyframes_tall { + from { top: 55%; left: 71%; } + to { top: 0%; left: 1%; } +} + +@keyframes pserver_to_p0_keyframes_tall { + from { top: 0%; left: 1%; } + to { top: 55%; left: 1%; } +} + +@keyframes pserver_to_p1_keyframes_tall { + from { top: 0%; left: 1%; } + to { top: 55%; left: 36%; } +} + +@keyframes pserver_to_p2_keyframes_tall { + from { top: 0%; left: 1%; } + to { top: 55%; left: 71%; } +} + + +.svgbox { + margin-bottom: 1.15em; +} + +.svgbox svg { + margin-left: auto; + margin-right: auto; + display: block; + width: 100%; + height: 100%; +} + +.svgbox svg .proc { + font-family: monospace; + dominant-baseline: middle; + text-anchor: middle; +} + +.svgbox svg .proc-axis { + stroke: gray; + stroke-width: 0.5; + stroke-linecap: round; +} + +.svgbox svg .event { + dominant-baseline: middle; + text-anchor: middle; + stroke: black; +} + +.svgbox svg .p0-color { + stroke: #D2E3FC; +} + +.svgbox svg .p1-color { + stroke: #FAD2CF; +} + +.svgbox svg .p2-color { + stroke: #CEEAD7; +} + +.svgbox svg .rpc { + stroke-width: 12; + stroke-linecap: round; +} + +.svgbox svg .reply { + font-family: monospace; + font-size: smaller; + dominant-baseline: middle; + text-anchor: middle; +} + +.svgbox svg .snapshot { + stroke-width: 2; + stroke: red; +} diff --git a/docs/_static/fault_tolerance/fault_tolerance.js b/docs/_static/fault_tolerance/fault_tolerance.js new file mode 100644 index 000000000000..25a10b23c67e --- /dev/null +++ b/docs/_static/fault_tolerance/fault_tolerance.js @@ -0,0 +1,377 @@ +// Helpers ///////////////////////////////////////////////////////////////////// + +// Returns a random float between min and max. +function rand(min, max) { + return Math.random() * (max - min) + min; +} + +// Formats the provided time as hh:mm:ss. +function formatTime(date) { + // https://stackoverflow.com/a/25279399 + return date.toISOString().substring(11, 19); +} + +// Periodically runs f with a delay between min_delay and max_delay. +// setIntervalWithJitter returns a cancel function that, when called, cancels +// the interval. +function setIntervalWithJitter(f, min_delay, max_delay) { + let handle = null; + + f(); + const helper = () => { + const g = () => { + f(); + helper(); + }; + handle = setTimeout(g, rand(min_delay, max_delay)); + return () => { + clearTimeout(handle); + }; + }; + + return helper(); +} + +// Coordination Service //////////////////////////////////////////////////////// + +class CoordinationService { + constructor(network, options) { + const now = new Date(); + this.network = network; + this.options = options; + this.heartbeats = [now, now, now]; + this.alive = [true, true, true]; + this.in_barrier = []; + + // Periodically refresh state. + setInterval(() => this.refresh(), 100); + } + + receive(msg) { + const {src, dst, type, payload} = msg; + switch (type) { + case 'heartbeat': + this.heartbeats[src] = new Date(); + return []; + case 'live_devices': + if (this.options.barrier) { + if (!this.in_barrier.includes(src)) { + this.in_barrier.push(src); + this.refresh_live_devices(); + } + } else { + this.network.push({ + src: 'server', + dst: msg.src, + type: 'live_devices', + payload: this.live_devices(), + }) + } + break; + default: + console.log(`Unknown message type ${type}`) + } + } + + time_since_heartbeat(i) { + return (new Date() - this.heartbeats[i]) / 1000; + } + + detect_failures() { + let something_failed = false; + for (let i = 0; i < 3; ++i) { + if (this.time_since_heartbeat(i) > 6) { + if (this.alive[i]) { + something_failed = true; + } + this.alive[i] = false; + } else { + this.alive[i] = true; + } + } + + if (something_failed && this.options.share_fate) { + for (let i = 0; i < 3; ++i) { + if (this.alive[i]) { + this.network.push({ + src: 'server', + dst: i, + type: 'fail', + payload: '💀', + }) + } + } + } + } + + live_devices() { + let devices = []; + for (let i = 0; i < 3; ++i) { + if (this.alive[i]) { + devices.push(i); + } + } + return devices; + } + + refresh_live_devices() { + // Check dst see if the live_devices barrier is done. + for (let i = 0; i < 3; ++i) { + if (this.alive[i] && !this.in_barrier.includes(i)) { + // The barrier isn't done. + return; + } + } + + // The barrier is done! Send the set of live devices dst all live devices. + let live = this.live_devices(); + for (let i of live) { + this.network.push({ + src: 'server', + dst: i, + type: 'live_devices', + payload: live, + }) + } + this.in_barrier = []; + } + + refresh() { + this.detect_failures(); + this.refresh_live_devices(); + } + + update_html(container) { + for (let i = 0; i < 3; ++i) { + // Update time since last heartbeat. + const now = new Date(); + const time_since = + container.getElementsByClassName(`p${i}-time-since-heartbeat`)[0]; + time_since.textContent = + ((now - this.heartbeats[i]) / 1000).toFixed(1) + ' s'; + + // Update health. + const health = container.getElementsByClassName(`p${i}-health`)[0]; + if (this.alive[i]) { + health.textContent = 'alive'; + health.classList.add('alive'); + time_since.classList.add('alive'); + health.classList.remove('dead'); + time_since.classList.remove('dead'); + } else { + health.textContent = 'dead'; + health.classList.add('dead'); + time_since.classList.add('dead'); + health.classList.remove('alive'); + time_since.classList.remove('alive'); + } + + } + + // Update processes in barrier. + const in_barrier = container.getElementsByClassName('in-barrier')[0]; + if (in_barrier) { + in_barrier.textContent = `In barrier = [${this.in_barrier}]`; + } + } +} + +// Process + +class Process { + constructor(network, options, i) { + this.network = network; + this.options = options; + this.i = i; + this.alive = true; + this.live_devices = null; + this.heartbeat_cancel = + setIntervalWithJitter(() => this.send_heartbeat(), 3000, 4000); + } + + receive(msg) { + const {src, dst, type, payload} = msg; + switch (type) { + case 'live_devices': + if (this.alive) { + this.live_devices = payload; + } + break; + case 'fail': + this.fail(); + break; + default: + console.log(`Unknown message type ${type}`) + } + } + + send_heartbeat() { + this.network.push({ + src: this.i, + dst: 'server', + type: 'heartbeat', + payload: '❤️', + }) + } + + send_live_devices() { + this.network.push({ + src: this.i, + dst: 'server', + type: 'live_devices', + payload: '⚫', + }) + } + + fail() { + this.alive = false; + this.live_devices = null; + this.heartbeat_cancel(); + } + + update_html(container) { + const live_devices = + container.getElementsByClassName(`p${this.i}-live-devices`)[0]; + if (this.options.live_devices) { + if (this.live_devices == null) { + live_devices.textContent = 'live processes = 0,1,2'; + } else { + live_devices.textContent = `live processes = ${this.live_devices}`; + } + } + + if (!this.alive) { + const node = container.getElementsByClassName(`p${this.i}`)[0]; + node.classList.add('failed'); + + const ld_button = + container.getElementsByClassName(`p${this.i}-ld-button`)[0]; + if (ld_button) { + ld_button.disabled = true; + } + + const fail_button = + container.getElementsByClassName(`p${this.i}-fail-button`)[0]; + if (fail_button) { + fail_button.disabled = true; + } + } + } +} + + +// Network communication. + +function send(container, tall, text, src, dst, after) { + const msg = document.createElement('div'); + msg.textContent = text; + msg.classList.add('msg'); + if (tall) { + msg.classList.add(`${src}_to_${dst}_tall`); + } else { + msg.classList.add(`${src}_to_${dst}`); + } + msg.addEventListener('animationend', (_) => { + msg.remove(); + after(); + }); + container.appendChild(msg); +} + +// { +// share_fate: false, +// live_devices: false, +// barrier: false, +// } +function init_cluster(id, options) { + const container = document.getElementById(id); + container.innerHTML = ` +
+
Coordination Service
+
+
    +
  • Process 0: 0s (alive)
  • +
  • Process 1: 0s (alive)
  • +
  • Process 2: 0s (alive)
  • +
  • In barrier: []
  • +
+
+
+ +
+
0
+
live processes = 0,1,2
+ + +
+ +
+
1
+
live processes = 0,1,2
+ + +
+ +
+
2
+
live processes = 0,1,2
+ + +
+ `; + + // Create the cluster. + let network = []; + let server = new CoordinationService(network, options); + const processes = [ + new Process(network, options, 0), new Process(network, options, 1), + new Process(network, options, 2) + ]; + + // Set up the live_devices button. + for (let i = 0; i < 3; ++i) { + const button = container.getElementsByClassName(`p${i}-ld-button`)[0]; + if (options.live_devices) { + button.addEventListener('click', () => processes[i].send_live_devices()); + } else { + button.remove(); + } + } + + // Set up the fail button. + const button = container.querySelectorAll('.p2-fail-button')[0]; + button.addEventListener('click', () => processes[2].fail()); + + // Remove live_devices display if needed. + if (!options.live_devices) { + for (let i = 0; i < 3; ++i) { + container.getElementsByClassName(`p${i}-live-devices`)[0].remove(); + } + } + if (!options.barrier) { + container.getElementsByClassName('in-barrier')[0].remove(); + } + + // Periodically process network messages. + setInterval(() => { + while (network.length > 0) { + const msg = network.shift(); + const tall = options.live_devices; + send(container, tall, msg.payload, `p${msg.src}`, `p${msg.dst}`, () => { + if (msg.dst == 'server') { + server.receive(msg); + } else { + processes[msg.dst].receive(msg); + } + }); + } + }, 10) + + // Periodically update HTML. + setInterval(() => { + server.update_html(container); + for (let proc of processes) { + proc.update_html(container); + } + }, 50); +} diff --git a/docs/_static/fault_tolerance/live_devices.py b/docs/_static/fault_tolerance/live_devices.py new file mode 100644 index 000000000000..9f41a2bdac6a --- /dev/null +++ b/docs/_static/fault_tolerance/live_devices.py @@ -0,0 +1,64 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['XLA_FLAGS'] = ' '.join([ + '--xla_gpu_nccl_terminate_on_error=false', + '--xla_gpu_nccl_async_execution=true', + '--xla_gpu_nccl_blocking_communicators=false', +]) +os.environ['XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE'] = '1' +os.environ['XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'] = '1' + +from absl import app +from absl import flags +from collections.abc import Sequence +from jax.experimental.multihost_utils import live_devices +import jax +import jax.numpy as jnp +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.config.update("jax_enable_recoverability", True) + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + + while True: + try: + with live_devices(jax.devices()) as devices: + print(f'{devices=}') + n = len(devices) + jax.set_mesh(jax.make_mesh((n,), ("i",), devices=devices)) + x = jax.device_put(jnp.arange(n), jax.P("i")) + print(jnp.sum(x)) + except Exception as e: + print('FAIL:', e) + else: + print('PASS') + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/_static/fault_tolerance/while_loop.py b/docs/_static/fault_tolerance/while_loop.py new file mode 100644 index 000000000000..0dbac58b528d --- /dev/null +++ b/docs/_static/fault_tolerance/while_loop.py @@ -0,0 +1,41 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import app +from absl import flags +from collections.abc import Sequence +import jax +import time + +_PROCESS_ID = flags.DEFINE_integer("i", -1, "Process id") +_NUM_PROCESSES = flags.DEFINE_integer("n", -1, "Number of processes") + + +def main(_: Sequence[str]) -> None: + jax.distributed.initialize( + coordinator_address="localhost:9000", + num_processes=_NUM_PROCESSES.value, + process_id=_PROCESS_ID.value, + local_device_ids=[_PROCESS_ID.value], + heartbeat_timeout_seconds=10, + ) + print(f'{jax.devices()=}') + print(f'{jax.local_devices()=}') + while True: + print(time.time()) + time.sleep(1) + + +if __name__ == "__main__": + app.run(main) diff --git a/docs/advanced_guides.rst b/docs/advanced_guides.rst index 7a9d3d95a58d..e090efa67b29 100644 --- a/docs/advanced_guides.rst +++ b/docs/advanced_guides.rst @@ -17,6 +17,7 @@ operations. notebooks/layout notebooks/host-offloading multi_process + fault_tolerance distributed_data_loading notebooks/colocated-python diff --git a/docs/fault_tolerance.rst b/docs/fault_tolerance.rst new file mode 100644 index 000000000000..153b3c159399 --- /dev/null +++ b/docs/fault_tolerance.rst @@ -0,0 +1,1524 @@ +.. raw:: html + + + + + +Fault Tolerant Distributed JAX +============================== + +Recall that `multi-controller JAX`_ allows you to run a JAX program distributed +across multiple machines. By default, if *any* of these machines fail, then +*every* machine will fail. That is, multi-controller JAX is not +**fault-tolerant** by default. + +This article has three parts. In the first part, we'll explain the basics of +how to write fault tolerant multi-controller JAX programs. In the second part, +we'll show some example fault-tolerant multi-controller JAX programs. In the +third part, we'll take a look under the covers at how multi-controller JAX +implements fault tolerance. + +.. warning:: + + JAX's support for fault tolerance is still experimental. It currently only + works fully on GPUs. It has rough edges, is probably buggy, and is subject + to change. Use at your own risk. + + +.. _part1: + +Part 1: Fault Tolerance Basics +------------------------------ + +Fault Intolerant By Default +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, multi-controller JAX programs are not fault tolerant. If *any* +process crashes, then *all* other processes will also intentionally crash. To +make this concrete, consider the following trivial script, ``example.py``, that +initializes multi-controller JAX by calling ``jax.distributed.initialize`` and +then enters an infinite loop. + +.. literalinclude:: _static/fault_tolerance/while_loop.py + :language: python + :emphasize-lines: 12-18 + :lines: 15- + :linenos: + :caption: ``example.py`` + +Run ``example.py`` across four processes on a VM with four GPUs by running +the following four commands, each in a different terminal. The +``local_device_ids`` argument to ``jax.distributed.initialize`` ensures each +process is assigned only one of the four GPUs. We'll explain the +``heartbeat_timeout_seconds`` argument in just a second. + +.. code-block:: shell + + python example.py --i=0 --n=4 # in terminal 1 + python example.py --i=1 --n=4 # in terminal 2 + python example.py --i=2 --n=4 # in terminal 3 + python example.py --i=3 --n=4 # in terminal 4 + +When you run these commands, you'll see the processes dutifully printing out +the current time every second. Next, fail the fourth process: ``pkill -9 -f +'python example.py --i=3 --n=4'``. After about ten seconds, the other +processes will also terminate and spit out error messages that look something +like this: + +.. code-block:: + + E0926 17:26:32.075402 157988 coordination_service_agent.cc:332] Polled an error from coordination service (this can be an error from this or another task). + F0926 17:26:32.075587 157988 client.h:77] Terminating process because the JAX distributed service detected fatal errors. This most likely indicates that another task died; see the other task logs for more details. Disable Python buffering, i.e. `python -u`, to be sure to see all the previous output. absl::Status: UNAVAILABLE: The following tasks are unhealthy (stopped sending heartbeats): + /job:jax_worker/replica:0/task:3 + The tasks have crashed. Check the task logs for an earlier error, or scheduler events (e.g. preemption, eviction) to debug further. + + RPC: /tensorflow.CoordinationService/PollForError [type.googleapis.com/tensorflow.CoordinationServiceError=''] + +When a process in a multi-controller JAX program notices that a peer process +has crashed, it decides to crash as well. The processes `share fate`_. The +``heartbeat_timeout_seconds`` argument to ``jax.distributed.initialize`` +determines how long a process waits before concluding a peer process has died. +The first three processes crash about ten seconds after you kill the fourth +because we passed ``heartbeat_timeout_seconds=10`` as an argument to +``jax.distributed.initialize``. + +Surviving Faults +^^^^^^^^^^^^^^^^ + +We can disable fate-sharing by adding the +``--xla_gpu_nccl_terminate_on_error=false`` flag and the +``jax_enable_recoverability`` configuration option to ``example.py``, as shown +below: + +.. literalinclude:: _static/fault_tolerance/dont_fail.py + :language: python + :emphasize-lines: 1-2,15 + :linenos: + :lines: 15- + +Again run the script across four processes and then kill the fourth. Notice +that now, the other three processes happily continue executing. + +Next try failing process 0. Notice that all four processes terminate with +error messages that look something like the following: + +.. code-block:: + + E0929 17:42:48.594192 1044529 coordination_service_agent.cc:332] Polled an error from coordination service (this can be an error from this or another task). + F0929 17:42:48.594200 1044529 client.h:77] Terminating process because the JAX distributed service detected fatal errors. This most likely indicates that another task died; see the other task logs for more details. Disable Python buffering, i.e. `python -u`, to be sure to see all the previous output. absl::Status: UNAVAILABLE: Failed to send RPC to coordination service. Either the leader task was preempted/died/restarted unexpectedly or this task is experiencing network issues. Check earlier logs from 1) this task, 2) the leader (usually slice 0 task 0), and 3) cluster scheduler to debug further. + Additional GRPC error information from remote target coordination_service while calling /tensorflow.CoordinationService/PollForError: + :UNKNOWN:Error received from peer {grpc_message:"Socket closed", grpc_status:14} + +Process 0 is special. If process 0 fails, every process will fail, even with +fate-sharing disabled. Why? Process 0 runs an RPC service called the +coordination service that all processes use to coordination with each other. If +the coordination service fails, all other processes have no choice but to fail. +See :ref:`part3` for more details. + +Getting Stuck in Collectives +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``example.py`` is now able to survive faults, but the processes do not +communicate with each other at all. Any realistic multi-controller JAX program +would involve communication between the processes (otherwise, what's the point +of using multi-controller JAX?). Let's edit ``example.py`` so that the +processes perform a collective ``jnp.sum`` in every iteration of the loop. + +.. literalinclude:: _static/fault_tolerance/collectives.py + :language: python + :emphasize-lines: 27-32 + :linenos: + :lines: 15- + +In the highlighted code above, the processes create an array ``x`` sharded +across the four processes and then perform a distributed ``jnp.sum``. Again run +the program and fail the fourth process. You'll notice that the first three +process do not crash, but they do get *stuck*. By default, if a process fails +while participating in a distributed computation (like ``jnp.sum``), then the +rest of the processes participating in the computation will get stuck +*forever*. + +.. _`canceling_collectives`: + +Cancelling Collectives +^^^^^^^^^^^^^^^^^^^^^^ + +We can avoid getting stuck by cancelling collectives with a failed participant. +We can enable collective cancelling by providing a few more flags and +environment variables, highlighted below. + +.. literalinclude:: _static/fault_tolerance/cancel_collectives.py + :language: python + :emphasize-lines: 1-8,22,33-35 + :linenos: + :lines: 15- + +We also need to insert a call to +``jax.experimental.multihost_utils._live_devices`` to make the script work. You +should normally not do this. You should instead use the ``live_devices`` API +that we'll introduce momentarily. For now, ``_live_devices`` is a hack to get +the script working before we explain the proper API. + +Again run the script and fail the fourth process. The first three processes +will be stuck in their call to ``jnp.sum``, but after about ten seconds, the +call will be cancelled and ``jnp.sum`` will raise an exception that looks +something like this: + +.. code-block:: + + jaxlib._jax.XlaRuntimeError: FAILED_PRECONDITION: Task with incarnation id 3446767950926952685 is not connected + + +Knowing Who's Alive +^^^^^^^^^^^^^^^^^^^ + +After a process dies, the remaining *alive* procesess need to learn who is dead +and who is alive. For this, we can use the core JAX fault tolerance API: +``live_devices``. ``live_devices`` is a context manager that takes a list of +devices as an argument and returns the subset of these devices that are alive. +Below, we edit ``example.py`` to call ``live_devices``. + +.. literalinclude:: _static/fault_tolerance/live_devices.py + :language: python + :emphasize-lines: 34-46 + :linenos: + :lines: 15- + +In the highlighted code above, we call ``live_devices`` with all devices +(``jax.devices()``) to get the set ``devices`` of live devices. We then shard +array ``x`` over these devices and perform a ``jnp.sum``. If a process fails +while executing the ``jnp.sum``, then ``jnp.sum`` will be cancelled and raise +an exception on the remaining live devices. Technically, the collective is not +guaranteed to fail. We'll revisit this in :ref:`atomicity`. For now, assume it +will fail. + +.. note:: + + ``jax.devices()`` always returns the set of *all* devices, even if some of + these devices are on failed processes. Use + ``jax.experimental.multihost_utils.live_devices`` to learn which of these + devices are live. + +Again run the script and fail the fourth process. Notice that the remaining +three alive processes catch the exception raised by ``jnp.sum`` and continue to +the next iteration of the while loop. In this next iteration, ``devices`` does +not include the device on the failed fourth process. The three alive processes +continue to execute correctly even though the fourth process is dead. + +Next, restart the fourth process. Notice that after the fourth process +restarts, its device is again included in the set of alive devices returned by +``live_devices``. All four processes then continue executing normally. + +At first blush, ``live_devices`` seems trivial. You give it a list of devices, +and it returns the ones that are alive. How complicated can that be? +Unfortunately, as with `many things in distributed systems`_, there are a lot +subtleties to iron out. Next, we explain the **barrier** semantics and +**atomicity** properties of ``live_devices``. + +Barrier Semantics +^^^^^^^^^^^^^^^^^ + +Recall that every process in a `multi-controller JAX`_ program should run in +lockstep. The processes should execute the same instructions in the same order. +Failing to do so will *almost certainly* lead to deadlocks, crashes, or +anomalous behavior. + +In the context of ``live_devices``, we need to ensure that every process agrees +on which processes are currently alive. This is difficult to ensure because +every process is executing independently at potentially different speeds and +processes can fail at any time. Consider again the ``example.py`` script from +above running on four processes. Imagine process 1 and 2 call ``live_devices``, +then process 4 fails, and then process 3 calls ``live_devices``. Process 1 and +2 might think process 4 is alive while process 3 thinks it is dead. + +To avoid situations like these, ``live_devices`` guarantees that it returns the +same set of live devices to every process. It accomplishes this using a +barrier. A call to ``live_devicess(devices)`` blocks until every live process +hosting a device in ``devices`` has also called ``live_devices``. Once every +live process is in the ``live_devices`` barrier, ``live_devices`` returns the +same set of live devices to every process. + +.. important:: + + ``live_devices`` uses a barrier to ensure that it will *always* return the + same set of live devices to every live process. + +Because ``live_devices`` implements a barrier it is susceptible to deadlock if +used improperly. We recommend only having a single ``with live_devices`` block +in a program. Multiple calls to ``live_devices`` is hard to reason about and +can lead to deadlock. + +See :ref:`part3` for details on how the ``live_devices`` barrier is implemented +as well as a formal semantics based on `linearizability`_. + +.. _atomicity: + +Atomicity +^^^^^^^^^ + +A distributed computation is **atomic** if every participant in the computation +agrees on whether the operation succeeds or fails. In the ``example.py`` script +above, we saw that when a process failed during the execution of a ``jnp.sum``, +then ``jnp.sum`` would abort and raise an exception on the remaining live +processes. So ``jnp.sum`` is atomic? + +Unfortunately, it's not. + +When a process fails during the execution of a collective operation (like +``jnp.sum``), the remaining processes may cancel the operation and raise an +exception or they may complete the operation successfully. Collective +operations in JAX do not have any inherent atomicity properties. + +If collective operations are not atomic, however, then multi-controller JAX +processes might diverge. For example, if a process fails during a training step +of a machine learning model, some processes might detect the failure and roll +the model back to a checkpoint while other processes might think the step +succeeded and keep training. + +To avoid the complexities of non-atomic execution, ``live_devices`` provides +its own atomicity guarantees despite the fact that collectives are not atomic. +Specifically, the body of a ``with live_devices`` block is guaranteed to either +complete successfully on all processes or raise an exception on all processes. +More concretely, if we consider the code snippet below, either every process +executes branch A or every process executes branch B. It is impossible for some +processes to execute A while others execute B. + +.. code-block:: python + + try: + with live_devices(jax.live_devices()) as devices: + ... + except Exception as e: + ... # Branch A + else: + ... # Branch B + +.. warning:: + + A ``with live_devices`` block does not guarantee atomicity if the code + block non-deterministically raises exceptions for reasons other than + collectives that fail because of a crashed process. For example, if one + process raises an exception because it runs out of memory, this exception + will not be propagated to the other processes. + +Recall that JAX uses `asynchronous dispatch`_. Operations like ``jnp.sum`` do +not block until the operation is complete. Instead, they return ``jax.Arrays`` +that act as futures. This asynchrony can interact with ``live_devices`` in +unexpected ways. For example, consider the following code that performs a +``jnp.sum``, assigns the result to ``y``, and then prints ``y``: + +.. code-block:: python + + x = ... + y = ... + try: + with live_devices(jax.live_devices()) as devices: + y = jnp.sum(x) + except Exception as e: + ... # Branch A + else: + ... # Branch B + print(y) + +Imagine that the ``with live_devices`` block executes successfully on all +processes. That is, all processes execute branch B. This only guarantees that +every process successfully created a future and assigned it to ``y``. The +actual computation of the ``jnp.sum`` may be delayed until outside the block. +Thus, some processes might successfully complete the ``jnp.sum`` and print the +value of ``y`` while other processes fail to complete the ``jnp.sum`` and raise +an exception when trying to print ``y``. + +To avoid this, use ``jax.block_until_ready`` to ensure that computations are +performed within the ``with live_devices`` block. The code snippet below, which +now calls ``jax.block_until_ready`` when assigning to ``y``, guarantees that +every process will successfully execute the ``jnp.sum`` or every process will +raise an exception. + +.. code-block:: python + + x = ... + y = ... + try: + with live_devices(jax.live_devices()) as devices: + y = jax.block_until_ready(jnp.sum(x)) + except Exception as e: + ... # Branch A + else: + ... # Branch B + print(y) + +See :ref:`part3` for details on how atomicity is implemented. + +Part 2: Examples +---------------- + +``live_devices`` is not a panacea; it is a tool. It does not magically make +multi-controller JAX programs fault tolerant. Rather, it allows you to +implement fault tolerance yourself in the way that is best for your +application. + +The exact details of how you implement fault-tolerance will vary greatly based +on the nature of your application. In this section, we present some examples of +how to use ``live_devices``. The examples are meant to be illustrative but not +prescriptive. There are many other ways to implement fault tolerance. + +Example 1: Fault Tolerant Data Parallel Training +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In this example, we train a trivial single-parameter linear model (:math:`y = +\alpha x`) with data parallelism across four processes. The example is +contrived---you would never train a model with a single parameter across four +machines---but we intentionally keep the model simple to focus on fault +tolerance. + +Data parallelism makes implementing fault tolerance relatively straightforward. +Because every process has a full copy of the model weights, if a process fails, +we can simply ignore it and continue training. This example tolerates an +arbitrary number of process failures (excluding process 0), but once a process +fails, we assume it does not recover. The next example shows how to handle +process recovery. + +First, we set some flags to disable fate-sharing and enable collective +cancelling. We also make the necessary imports and define some flags. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 15-33 + :lineno-start: 1 + +Next, we define a ``replicated`` function that returns an array replicated +across a set of devices. Note that ``replicated`` doesn't actually move any +data. It assumes the argument ``x`` already has equal value across all +processes. It simply returns a new view of that data, in a process-spanning +`jax.Array` with a replicated sharding. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 35-49 + :lineno-start: 21 + +We define a similar ``sharded`` function that returns an array sharded across a +set of devices. Again, ``sharded`` is not actually moving any data between +processes. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 52-64 + :lineno-start: 38 + +Now, we're ready to start writing our training loop. We begin by initializing +multi-controller JAX by calling ``jax.distributed.initialize``. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 67-76 + :lineno-start: 53 + +Then, we define our simple linear model, generate some random training data, +and initialize some basic hyperparameters. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 78-97 + :lineno-start: 64 + +Finally, we enter the main training loop. + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :lines: 99-125 + :lineno-start: 85 + +- Every iteration of the loop, we call ``live_devices`` to learn which devices + are currently alive. +- We then ensure that the model weights are replicated across these devices and + ensure that the training data is sharded across these devices. Note that this + doesn't actually move any data between the devices; it simply creates JAX + arrays with the appropriate replication and sharding metadata. +- We call ``loss_and_grad`` to compute the gradient of the weights with respect + to the current batch of data and then compute the new weights. Notice that we + assign the new weights to ``new_weights`` rather than assigning to + ``weights`` in case the training step fails. We also call + ``jax.block_until_ready`` to ensure that every process has computed the new + weights when we exit the ``live_devices`` block. +- If no processes failed during the execution of the training step, then the + ``else`` branch is taken. The step is incremented, and ``weights`` is + updated. Otherwise, an exception will be raised and the ``except`` branch is + taken. In this case, we do not update ``step`` or ``weights`` and retry the + step on the next iteration with the new set of live devices. + +Here is the full example: + +.. literalinclude:: _static/fault_tolerance/data_parallelism.py + :language: python + :linenos: + :lines: 15- + +Example 2: Fault Tolerant Data Parallel Training With Recovery +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Now, we modify the example above to allow failed processes to recover. When a +process recovers, it needs to receive the current step and model weights. +Because we assume process 0 never fails---recall that if process 0 fails, every +process will fail---we have process 0 send the current step and weights to +recovering processes. + +First, we define ``send`` and ``recv`` functions that use a ``shard_map`` to +send data from one device to another. The sender calls ``send``, and the +receiver calls ``recv``. + +.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py + :language: python + :lines: 69-90 + :lineno-start: 55 + +``allgather`` performs an AllGather of a single float across a set of devices. + +.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py + :language: python + :lines: 93-100 + :lineno-start: 79 + +Finally, we modify the training loop to handle recovering processes, as shown +in the highlighted code below. + +.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py + :language: python + :lines: 135-178 + :lineno-start: 121 + :emphasize-lines: 7-22 + +Recovery is a two-step process. First, we need to detect which processes are +recovering. Second, we need process 0 to send the step and weights to the +recovering processes. + +1. To detect which processes are recovering, we perform an AllGather on all + live processes' steps. When a failed process recovers, its ``step`` will be + ``0``, while the ``step`` on process ``0`` will be some positive number, so + if a process' step is not equal to process 0's step, then it is recovering. +2. Then, we call the ``send`` and ``recv`` functions we defined above to + transfer the current step and model weights from process 0 to the recovering + processes. + +Here is the full example: + +.. literalinclude:: _static/fault_tolerance/data_parallelism_with_recovery.py + :language: python + :linenos: + :lines: 15- + +.. _part3: + + +Part 3: Implementation Details +------------------------------ + +We now take a deep dive into the architecture of multi-controller JAX and the +semantics and implementation of ``live_devices``. If you're only interested in +writing fault-tolerant multi-controller JAX programs, the first two parts of +this article suffice. + +The Coordination Service +^^^^^^^^^^^^^^^^^^^^^^^^ + +When you launch a multi-controller JAX program, the first process (i.e. process +0) runs a standalone RPC server called the **coordination service**. Moreover, +all processes (including process 0) create an RPC client to the coordination +service. Concretely, the ``coordinator_address`` argument of +:func:`jax.distributed.initialize` is the address of the coordination service. +This argument lets process 0 know on what address to run the server, and it +lets all processes know which address to connect to. + +The coordination service implements the multi-controller JAX **control plane**. +For example, it can perform a distributed barrier across all processes, and it +implements a key-value store that processes can use to exchange small amounts +of metadata. Note, however, that the **data plane** (e.g., all collective +operations on program data) is implemented directly between the processes and +does not involve the coordination service. + +One of the most important functionalities of the coordination service is health +checking. Every process periodically sends a heartbeat to the coordination +service. If a process fails, it stops sending heartbeats. If the coordination +service hasn't received a heartbeat from a process for a while, it assumes the +process has failed. + +This is shown in the interactive visualization below. The coordination service +is shown at the top and three multi-controller JAX processes are shown at the +bottom. Note how the processes periodically send heartbeats to the controller, +and the controller keeps track of the health of each process based on when it +last received a heartbeat. Try failing process 2 by clicking the "Fail" button. +Observe how the process stops sending heartbeats and the coordination service +eventually considers the process dead. + +.. raw:: html + +
+ + +By default, when the coordination service detects that a process has failed, it +sends a message to all other processes requesting that they self-terminate. In +other words, all processes in a multi-controller JAX program `share fate`_. +Again fail process 2 in the visualization below by clicking the "Fail" button +and observe how the coordination service notifies the other processes to fail. + +.. raw:: html + +
+ + +This fate sharing means that multi-controller JAX programs are not at all +fault-tolerant. They are fault-*intolerant*. To enable fault-tolerance, we +need to do two things: + +- First, we need to remove fate sharing and allow processes to continue + executing even when a peer process has died. This can be enabled using the + ``jax_enable_recoverability`` option, as described in :ref:`part1`. We'll + assume that this option is set. +- Second, we need to provide an API that processes can use to learn which + processes are alive and which have failed. This is the ``live_devices`` API + introduced in :ref:`part1`. + +There is a surprising amount of technical depth and subtlety in implementing +the ``live_devices`` API. We'll walk through the design and implementation of +the API step-by-step. We'll begin by introducing a simpler ``live_processes`` +API and slowly improve it until we arrive at the ``live_devices`` API. + +Live Processes +^^^^^^^^^^^^^^ + +Let's try to design a new hypothetical JAX API: ``jax.live_processes``. As the +name suggests, we want ``jax.live_processes()`` to return the set of all +currently alive processes. Here is a naive but (as we'll see momentarily) +incorrect implementation. When a process calls ``jax.live_processes()``, it +sends an RPC request to the coordination service. Remember that the +coordination service already uses heartbeats to keep track of which processes +are dead and which are alive, so when it receives a ``jax.live_processes`` +request, it responds with the set of processes it thinks are alive. + +This is illustrated below. Below each process is a "Call live_processes" +button. You can click this button to make the process call +``jax.live_processes``. Note how the coordination service replies to a +``live_processess`` request with the set of alive processes. Fail process 2 by +clicking the "Fail" button and see how it affects later calls to +``jax.live_processes``. + +.. raw:: html + +
+ + +This naive implementation is simple but incorrect. It is crucial that all +processes in a multi-controller JAX job execute the same instructions in the +same order. If the processes start to diverge, by executing different code +paths in the JAX program, the job will behave erratically. Most likely, it will +crash or hang or produce garbage values, and most certainly it will be very +hard to reason about. + +Our naive implementation of ``jax.live_processes`` can very easily lead to +divergence. For example, consider a multi-controller JAX job with three +processes. If process 0 and 1 both call ``jax.live_processes`` around the same +time that process 2 fails, the coordination service might report to process 0 +that all processes are alive but report to process 1 that only processes 0 and +1 are alive. Try to produce this scenario in the visualization below: + +.. raw:: html + +
+ + +If processes disagree on which processes are alive, they will almost certainly +diverge. Thankfully, we can avoid this divergence by augmenting +``jax.live_processes`` with barrier semantics. + +Barrier Semantics +^^^^^^^^^^^^^^^^^ + +Let's change the implementation of ``jax.live_processes`` so that when the +coordination service receives a ``jax.live_processes()`` request, it does not +reply right away. Instead, the coordination service only replies once *every* +live process has called ``jax.live_processes()``. Once every alive process has +entered the ``jax.live_processess()`` barrier, the coordination service returns +the set of live processes. Crucially, the coordination service returns the +*same* set of live processes to all processes, which prevents the processes +from diverging. + +This is illustrated below. Note that coordination server now keeps track of +which devices are in the ``live_processes`` barrier. Try calling +``live_processes`` from every process. Notice how the coordination service +doesn't respond until every process has entered the barrier. Then fail process +2 and call ``live_processes`` from process 0 and process 1. + +.. raw:: html + +
+ + +Formal Semantics +^^^^^^^^^^^^^^^^ + +Distributed systems are notoriously complex. Machines can fail at arbitrary +times, and network messages can be dropped, delayed, and reordered. In this +section, we introduce a formal semantics of the ``jax.live_processes`` API to +help tame this complexity. Thinking rigorously about the semantics of +``jax.live_processes`` will help us understand the behavior of the API even in +pathological executions. + +We'll base the formal semantics of ``jax.live_processes`` on +`linearizability`_: a popular formalism used to define the semantics of many +distributed APIs. Concretely, we model our distributed system as a number of +processes. Each process serially performs a number of events. There are four +types of events: + +1. A process can **start** (👶). We'll assume that when a process starts, it + connects to the coordination service, so the coordination service is aware + that is has started. +2. A process can **fail** (💀). Unlike starting, the coordination service may + not immediately be aware that a process has failed. +3. A process can **send** a ``jax.live_processes`` request to the coordination + service. +4. A process can **receive** a reply to a ``jax.live_processes`` request from + the coordination service. + +Below is a diagram of an execution of three processes: 0, 1, and 2. Time +progresses from left to right. First, all three processes start. This is shown +with the baby emojis. Then all three processes send ``jax.live_processes`` +requests to the coordination service. This is shown as the start of the thick +colored regions. Later, all three processes receive a reply from the +coordination service with ``0,1,2`` as the set of live devices. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + +
+ +In this simple execution, it is clear that ``jax.live_processes`` is behaving +correctly. We can formalize this intuition with the following formal semantics. + +.. attention:: + + An execution is valid if whenever ``jax.live_processes`` returns a set ``P`` + of live processes, there exists an instantaneous moment in time at which + every process in ``P`` was in the ``live_processes`` barrier and every other + process was dead. An implementation of ``live_processes`` is correct if + it only allows for valid executions. + +Later, we will amend these formal semantics to cover some subtle corner cases, +but assume this simplified semantics for now. + +In the example above, ``live_processes`` returns ``0,1,2``. In the +visualization below, we show that there does exist an instantaneous moment of +time in which processes 0, 1, and 2 are all in the barrier and all other +processes (there are none) are dead. The moment in time is drawn as a vertical +red bar. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + + +
+ +There is nothing special about the specific moment in time we chose in the +visualization above. All that's important is that *there exists some* moment in +time where all processes in `P` are in the barrier and all other processes are +dead. There are many moments in time that satisfy this property, as shown +below. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + 👶 + + 0,1,2 + + + + + + + +
+ +In the next example, processes 0 and 1 start, call ``jax.live_devices``, and +receive ``0,1`` as a reply. Process 2 is dead throughout the execution. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1 + + + 👶 + + 0,1 + + + 💀 + +
+ +This is a valid execution under our formal semantics because there exists a +moment a time in which processes 0 and 1 are in the barrier and process 2 is +dead. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,1 + + + 👶 + + 0,1 + + + 💀 + + + + +
+ +In the following execution, process 0 calls ``jax.live_processes`` and receives +a reply of ``0``. Process 1 calls ``jax.live_processes``, but dies before +receiving a reply. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0 + + + 👶 + + 💀 + +
+ +Is this a valid execution? Yes. There exists a moment in time at which process +0 is in the barrier and process 1 is dead, as shown below. Even though process +1 called ``jax.live_processes``, it is not guaranteed that process 1 will be +included in the coordination service's response. + +For example, process 1's ``jax.live_processes`` request may have been dropped +by the network and never received by the coordination service. So from the +coordination service's perspective, process 1 is thoroughly dead and never even +entered the ``live_processes`` barrier. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0 + + + 👶 + + 💀 + + + + +
+ +What about the same exact execution, except that process 0 now receives the +reply ``0,1`` from the coordination service? + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + + 💀 + +
+ +Again, this is a valid execution, as witnessed below. Intuitively, the +coordination service could have received ``jax.live_processes`` requests from +both processes 0 and 1 and sent the reply ``0,1`` to both. While this reply was +in the network, process 1 failed. Thus, even though process 1 is dead when +process 0 receives a reply, the execution is still valid. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + + 💀 + + + + +
+ +This point bears repeating. If ``jax.live_processes`` returns a set ``P`` of +processes, it does not mean that all processes in ``P`` are *currently* alive +and all other processes are *currently* dead. It only means that *there existed +a point in time* when this was true. + +In the following execution, process 1 calls ``jax.live_processes`` and fails. +Later, process 0 starts, calls ``jax.live_processes``, and receives ``0,1`` as +a reply. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + + 💀 + +
+ +Using the formal semantics described thus far, this is *not* a valid execution. +There is never a point in time where process 0 and 1 are both alive. However, +this *should* be a valid execution. + +The reason has to do with the unavoidable fact that in a distributed system, it +is impossible to detect failures with 100% accuracy. If the coordination +service hasn't received heartbeats from a process in a while, it considers the +process dead. But, the coordination service cannot determine with 100% +certainty when the process died or if the process is actually dead at all. +Maybe the process died a long time ago, or maybe it died very recently, or +maybe it is alive but on the other side of a network partition. + +Let's return to the execution above for a concrete example. Imagine the +coordination service successfully received process 1's ``live_processes`` +request. Then, process 1 failed but the coordination service didn't detect the +failure immediately. In the meantime, the coordination service received process +0's ``live_processes`` request. At this point, the coordination service thought +both processes were alive and saw that both processes were in the barrier, so +it naturally returned ``0,1`` to both processes (though only process 0 received +the reply because process 1 was dead). + +The coordination service thought process 1 was alive when it was dead. And +sometimes the coordination service might think a process is dead when it is +alive. Though not ideal, we need to accommodate executions like this because +they are unavoidable. + +We amend our formal semantics and allow ourselves to move a failure either +earlier or later in time, though we cannot move a failure past a different +event from the same process. Intuitively, we can move a failure from when it +actually happened to the point in time when the coordination service thought it +happened. Continuing the example above, we can delay the failure of process 1 +to create a moment in time in which both processes 0 and 1 are in the barrier, +witnessing the fact that the execution is valid. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + + + + + + 💀 + + + + + + + +
+ +Consider a similar execution below. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0 + + + 👶 + + 💀 + +
+ +As is, there is no moment in time in which process 0 is alive and process 1 is +dead. However, if we move the failure of process 1 leftwards, there is. How +might such an execution arise? Imagine process 1 is partitioned from the +coordination service. The coordination service doesn't receive any messages +from process 1, including its heartbeats. This leads the coordination service +to conclude that process 1 is dead, even though it isn't. Then, the +coordination service receives process 0's ``live_processes`` request and +responds with ``0``. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0 + + + 👶 + + + + + + 💀 + + + + + + + +
+ +We cannot move a process failure past the process' other events, however. For +example, the following execution is *invalid* because no matter where we move +the failure of process 1, there is never a moment in time where both processes +are in the barrier. + +.. raw:: html + +
+ + + 0 + 1 + + + + + + + 👶 + + 0,1 + + + 👶 + 👶 + + + + + + 💀 + + +
+ +With these formal semantics, we can make sense of even complex executions. For +example, consider the following execution. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0 + + 0,2 + + + 👶 + + 💀 + 👶 + + 💀 + + + 👶 + + 💀 + +
+ + +After moving some process failures, we see the execution is valid. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0 + + 0,2 + + + 👶 + + 💀 + 👶 + + 💀 + + + 👶 + + 💀 + + + + + +
+ +The following execution, on the other hand, is invalid. + +.. raw:: html + +
+ + + 0 + 1 + 2 + + + + + + + + 👶 + + 0,2 + + + 👶 + + 1 + 💀 + + + 👶 + + 💀 + +
+ + +Atomicity +^^^^^^^^^ + +Equipped with ``jax.live_processes``, let's try to write some fault-tolerant +multi-controller JAX code. + +.. code-block:: python + + step = 0 + while True: + # Get the devices on all live processes. + procs = jax.live_processes() + devices = [d for d in jax.devices() if d.process_index in procs] + + # Shard array x over these devices. + mesh = jax.make_mesh((len(devices),), ("i",), devices=devices) + spec = jax.sharding.PartitionSpec("i") + sharding = jax.sharding.NamedSharding(mesh, spec) + x = jax.make_array_from_process_local_data(sharding, np.ones(1)) + + # Try to perform a jnp.sum. + try: + print(jnp.sum(x)) + except: + # jnp.sum failed. + pass + else: + # jnp.sum succeeded. + step += 1 + +The code repeatedly + +- calls ``jax.live_processes`` to learn which processes are alive, +- computes the set of devices on the healthy processes, +- shards an array across these healthy devices, +- performs a ``jnp.sum`` (i.e. AllReduce) on the array, and +- increments ``step`` if the ``jnp.sum`` succeeds. + +This code *looks* correct, but it has a very subtle bug. Assume the ``jnp.sum`` +is being performed across a set of processes ``P``. If one (or more) of the +processes in ``P`` fails during the execution of the ``jnp.sum``, then +``jnp.sum`` can behave differently on different processes. Some processes in +``P`` might see ``jnp.sum`` return the correct result. Other processes might +see ``jnp.sum`` raise an exception. Others might see ``jnp.sum`` return an +incorrect result. + +.. warning:: + + If a process fails during a collective operation, the operation may behave + differently on different processes. + +This means that the processes executing the code example above might diverge. +Some might increment ``step``, and some might not. In the trivial code example +above, this divergence is benign, but in a real program, the divergence would +likely lead to a crash, a deadlock, or garbage outputs. For example, if a +multi-controller JAX program is training a model with data parallelism and +starts to diverge, some processes might roll back their model weights to a +previous checkpoint while others continue training, leading to a +"franken-model" where nobody agrees on what the model weights are supposed to +be. + +To write fault-tolerant code that does not diverge, we want **atomicity**. When +executing a block of code (like the ``jnp.sum`` above), we either want *every* +process to run the code successfully, or *every* process to learn that the code +failed to execute successfully. We don't want some processes succeeding and +others failing. + +Thankfully, we can achieve atomicity with a very simple trick: call +``live_processes`` twice, once before a code block and once after. If all the +processes that were alive before the block are also alive after the block, then +the code block executed successfully on all live processes. On the other hand, +if any process died, then all remaining processes can agree the code block +failed to execute properly. Here's a sketch of what that might look like: + +.. code-block:: python + + # Get the set of live processes before the code block. + procs_before = jax.live_processes() + + # Execute the code block. + ... + + # Get the set of live processes after the code block + procs_after = jax.live_processes() + if procs_before == procs_after: + # The code block executed successfully on all processes in + # procs_before. + pass + else: + # The code block did not execute successfully. All processes will + # agree it failed. + pass + +The code above should give you a rough idea of how to use two calls to +``live_processes`` to achieve atomicity, but there are still a handful of small +issues we need to address before it is fully correct. For example, + +- What if the code block throws an exception? We need to catch the exception + and still call ``live_processess`` the second time and then re-raise the + exception. +- What if a process fails after the first call to ``live_processes`` and + recovers before the second call? Wouldn't the code block fail but the + processes before and after be the same? Every time a process starts, it + generates a random **incarnation id**. In addition to checking that the set + of processes hasn't changed, we also check that their incarnation ids haven't + changed. +- What if a process recovers and its first call to ``live_processes`` matches + up with a different process' second call to ``live_processes``? Couldn't this + lead to a deadlock? Yes. We can avoid the problem by only calling + ``live_processes`` at a single program point. We can be clever and use a + single call to ``live_processes`` for two purposes. It can be used to check + that the set of processes hasn't changed since the previous call to + ``live_processes``, and it can be used to generate the set of live processes + that should be used the next time the atomic code block is executed. + +All these details are handled and abstracted away by the ``jax.live_devices`` +API introduced in :ref:`part1`. ``jax.live_devices`` is a context manager that +guarantees the atomic execution of a block of code. In the code snippet below, +``devices`` is a list of the devices on all live processes. The code block +``A`` will execute atomically across these processes. That is, either every +process will see the code raise an exception (branch ``B``) or every process +will see the code succeed (branch ``C``). + +.. code-block:: python + + try: + with live_devices() as devices: + pass # A + except Exception as e: + pass # B + else: + pass # C + +Cancelling Collectives +^^^^^^^^^^^^^^^^^^^^^^ + +As mentioned in :ref:`canceling_collectives`, if a process participating in a +collective fails, then the other participating processes get stuck forever. We +need to explicitly cancel these collectives to allow the alive participants to +make progress. While the ``live_devices`` API is supported on all JAX backends +(i.e. CPU, GPU, TPU), cancelling collectives is only supported by the GPU +backend. Here, we briefly explain some of the implementation details behind +collective cancelling. + +The GPU backend implements collectives using `NCCL`_, NVIDIA's collective +communication library. When a set of processes wants to perform a collective, +they form a **NCCL communicator**. Processes can then repeatedly perform +collectives using this communicator. Creating a communicator is expensive---it +requires network communication---so the JAX backend caches communicators keyed +by the set of participating processes and their incarnation ids. + +Internally, a JAX client polls the coordination service for the current status +of every process. If a client ever detects that a process is dead or has +restarted with a new incarnation id, then the client aborts all communicators +with the failed incarnation id in its cache key. + +.. _asynchronous dispatch: https://docs.jax.dev/en/latest/async_dispatch.html +.. _linearizability: https://cs.brown.edu/~mph/HerlihyW90/p463-herlihy.pdf +.. _many things in distributed systems: https://en.wikipedia.org/wiki/Fallacies_of_distributed_computing +.. _multi-controller JAX: https://docs.jax.dev/en/latest/multi_process.html +.. _NCCL: https://developer.nvidia.com/nccl +.. _reference: https://docs.jax.dev/en/latest/config_options.html#jax_enable_recoverability +.. _share fate: https://en.wikipedia.org/wiki/Fate-sharing From f1bc1a5ca412e8555113420982f287762c993de1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 9 Dec 2025 11:29:38 -0800 Subject: [PATCH 123/315] [test] fix signatures test for NumPy nightly --- tests/lax_numpy_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6156f8996994..8f974208e4f6 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6328,6 +6328,7 @@ def testWrappedSignaturesMatch(self): 'frombuffer': ['like'], 'fromfunction': ['like'], 'frompyfunc': ['kwargs'], + 'fromstring': ['like'], 'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'], 'nanpercentile': ['weights'], 'nanquantile': ['weights'], From b8cab948c86ebbcf6a36f064d88976ec81612e07 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 9 Dec 2025 11:37:53 -0800 Subject: [PATCH 124/315] [test] remove pre-NumPy 2.0 API skips --- tests/lax_numpy_test.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6156f8996994..6bbbf41a2494 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6254,43 +6254,6 @@ def testWrappedSignaturesMatch(self): 'trapz', 'typename'} - # symbols removed in NumPy 2.0 - skip |= {'add_docstring', - 'add_newdoc', - 'add_newdoc_ufunc', - 'alltrue', - 'asfarray', - 'byte_bounds', - 'compare_chararrays', - 'cumproduct', - 'deprecate', - 'deprecate_with_doc', - 'disp', - 'fastCopyAndTranspose', - 'find_common_type', - 'get_array_wrap', - 'geterrobj', - 'issctype', - 'issubclass_', - 'issubsctype', - 'lookfor', - 'mat', - 'maximum_sctype', - 'msort', - 'obj2sctype', - 'product', - 'recfromcsv', - 'recfromtxt', - 'round_', - 'safe_eval', - 'sctype2char', - 'set_numeric_ops', - 'set_string_function', - 'seterrobj', - 'sometrue', - 'source', - 'who'} - self.assertEmpty(skip.intersection(dir(jnp))) names = (name for name in dir(np) if not (name.startswith('_') or name in skip)) From 75697ef8680b51b8badc039d6fa3fbd7523960e0 Mon Sep 17 00:00:00 2001 From: Dougal Date: Sun, 23 Nov 2025 22:47:11 +0000 Subject: [PATCH 125/315] Remove dynamic shapes. Dead weight at this point. --- ci/run_bazel_test_tpu.sh | 1 - jax/_src/api.py | 16 +- jax/_src/config.py | 10 - jax/_src/core.py | 310 +-- jax/_src/dispatch.py | 19 - jax/_src/interpreters/ad.py | 29 +- jax/_src/interpreters/batching.py | 444 +---- jax/_src/interpreters/mlir.py | 75 +- jax/_src/interpreters/partial_eval.py | 470 +---- jax/_src/interpreters/pxla.py | 8 +- jax/_src/lax/control_flow/conditionals.py | 1 - jax/_src/lax/control_flow/loops.py | 6 +- jax/_src/lax/lax.py | 559 +----- jax/_src/lax/slicing.py | 136 +- jax/_src/lax/utils.py | 5 - jax/_src/linear_util.py | 29 +- jax/_src/numpy/array_methods.py | 1 - jax/_src/numpy/einsum.py | 3 +- jax/_src/numpy/indexing.py | 14 - jax/_src/numpy/util.py | 25 +- jax/_src/pallas/core.py | 16 +- jax/_src/pallas/hlo_interpreter.py | 3 - .../pallas/mosaic/pallas_call_registration.py | 3 +- jax/_src/pallas/pallas_call.py | 370 +--- jax/_src/pallas/primitives.py | 2 - jax/_src/pjit.py | 184 +- jax/_src/shard_map.py | 2 - jax/_src/stages.py | 2 - jax/_src/state/primitives.py | 11 - jax/core.py | 1 - .../jax2tf/examples/saved_model_main_test.py | 6 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 5 +- jax/interpreters/batching.py | 35 - jax/interpreters/partial_eval.py | 40 - tests/BUILD | 12 - tests/api_test.py | 31 - tests/core_test.py | 191 +- tests/dynamic_api_test.py | 1770 ----------------- tests/pallas/BUILD | 24 - tests/pallas/pallas_jumble_test.py | 373 ---- 40 files changed, 261 insertions(+), 4981 deletions(-) delete mode 100644 tests/dynamic_api_test.py delete mode 100644 tests/pallas/pallas_jumble_test.py diff --git a/ci/run_bazel_test_tpu.sh b/ci/run_bazel_test_tpu.sh index 570b990e5484..09dfebfbd526 100755 --- a/ci/run_bazel_test_tpu.sh +++ b/ci/run_bazel_test_tpu.sh @@ -188,7 +188,6 @@ else //tests/pallas:tpu_pallas_call_print_test_tpu \ //tests/pallas:indexing_test_tpu \ //tests/pallas:pallas_error_handling_test_tpu \ - //tests/pallas:pallas_jumble_test_tpu \ //tests/pallas:pallas_shape_poly_test_tpu \ //tests/pallas:tpu_all_gather_test_tpu \ //tests/pallas:tpu_fusible_matmul_test_tpu \ diff --git a/jax/_src/api.py b/jax/_src/api.py index 39097971ceee..1ad727b5cf2a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -172,7 +172,6 @@ def jit( device: xc.Device | None = ..., backend: str | None = ..., inline: bool = ..., - abstracted_axes: Any | None = ..., compiler_options: dict[str, Any] | None = ..., ) -> pjit.JitWrapped: ... @@ -189,7 +188,6 @@ def jit( device: xc.Device | None = ..., backend: str | None = ..., inline: bool = ..., - abstracted_axes: Any | None = ..., compiler_options: dict[str, Any] | None = ..., ) -> Callable[[Callable], pjit.JitWrapped]: ... @@ -205,7 +203,6 @@ def jit( device: xc.Device | None = None, backend: str | None = None, inline: bool = False, - abstracted_axes: Any | None = None, compiler_options: dict[str, Any] | None = None, ) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]: """Sets up ``fun`` for just-in-time compilation with XLA. @@ -350,8 +347,7 @@ def jit( static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, device=device, backend=backend, inline=inline, - abstracted_axes=abstracted_axes, compiler_options=compiler_options, - use_resource_env=False) + compiler_options=compiler_options, use_resource_env=False) if isinstance(fun, NotSpecified): return lambda fun: pjit.make_jit(fun, **kwds) else: @@ -2563,13 +2559,13 @@ def transposed_fun(const, out_cotangent): return Partial(transposed_fun, const) -def _flat_axes_specs(abstracted_axes, *args, **kwargs +def _flat_axes_specs(*args, **kwargs ) -> list[pe.AbstractedAxesSpec]: if kwargs: raise NotImplementedError def ax_leaf(l): return (isinstance(l, dict) and all_leaves(l.values()) or isinstance(l, tuple) and all_leaves(l, lambda x: x is None)) - return broadcast_prefix(abstracted_axes, args, ax_leaf) + return broadcast_prefix(args, ax_leaf) @overload @@ -2578,7 +2574,6 @@ def make_jaxpr( static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = ..., - abstracted_axes: Any | None = None, ) -> Callable[..., core.ClosedJaxpr]: ... @@ -2588,7 +2583,6 @@ def make_jaxpr( static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = ..., - abstracted_axes: Any | None = None, ) -> Callable[..., tuple[core.ClosedJaxpr, Any]]: ... @@ -2598,7 +2592,6 @@ def make_jaxpr( static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: bool = False, - abstracted_axes: Any | None = None, ) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]: """Create a function that returns the jaxpr of ``fun`` given example args. @@ -2666,8 +2659,7 @@ def make_jaxpr( @api_boundary def make_jaxpr_f(*args, **kwargs): with core.extend_axis_env_nd(axis_env or []): - traced = jit(fun, static_argnums=static_argnums, - abstracted_axes=abstracted_axes).trace(*args, **kwargs) + traced = jit(fun, static_argnums=static_argnums).trace(*args, **kwargs) # `jit` converts tracers in consts to args but `make_jaxpr` callers expect # consts not to be converted. num_consts = traced._num_consts diff --git a/jax/_src/config.py b/jax/_src/config.py index 04c0d30de5a5..e7cf1c772cf3 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1785,16 +1785,6 @@ def _validate_default_device(val): default=False, help=('Enables lowering BCOO ops to cuSparse.')) -# TODO(mattjj): remove this flag when we ensure we only succeed at trace-staging -# if the intended backend can handle lowering the result -dynamic_shapes = bool_state( - name='jax_dynamic_shapes', - default=False, - help=('Enables experimental features for staging out computations with ' - 'dynamic shapes.'), - include_in_jit_key=True, - include_in_trace_context=True) - # This is for stackless backward compat with e.g. equinox eager_constant_folding = bool_state( name='eager_constant_folding', diff --git a/jax/_src/core.py b/jax/_src/core.py index 9aa32fe6e7b5..48a8db1c5e5f 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -731,7 +731,7 @@ def read(v: Atom) -> Any: return v.val if isinstance(v, Literal) else env[v] def write(v: Var, val: Any) -> None: - if config.enable_checks.value and not config.dynamic_shapes.value: + if config.enable_checks.value: assert typecheck(v.aval, val), (v.aval, get_aval(val), val) env[v] = val @@ -1689,38 +1689,8 @@ def lo_ty_qdd(self, qdd): def str_short(self, short_dtypes=False, mesh_axis_types=False): return str(self) -# For type signatures involving dynamic shapes, we use lists of abstract values -# which may contain (reverse) de Bruijn indices in their shapes. -class DBIdx(NamedTuple): - val: int - -@dataclass(frozen=True) -class InDBIdx: - val: int - -@dataclass(frozen=True) -class OutDBIdx: - val: int - -# For annotating input types of callables (i.e. linear_util.WrappedFuns), we use -# a sequence of pairs where the first element of each pair is an AbstractValue -# (possibly containing DBIdx instances in its shape) and the second is a boolean -# indicating whether that argument is explicit (i.e. passed to the callable). -InputType = tuple[tuple[AbstractValue, bool], ...] # DBIdx in shapes - -# For annotating jaxpr output types, we use a sequence of pairs where the first -# element of each pair is an AbstractValue (possibly containing InDBIdx and/or -# OutDBIdx instances in its shape) and the second is a boolean indicating -# whether that argument is explicit (i.e. returned by the callable). -OutputType = tuple[tuple[AbstractValue, bool], ...] # InDBIdx / OutDBIdx shapes - - -def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType: - idxs = {v: DBIdx(i) for i, v in enumerate((*jaxpr.constvars, *jaxpr.invars))} - out = [(v.aval.update(shape=tuple(idxs.get(d, d) for d in v.aval.shape)) # type: ignore - if type(v.aval) is DShapedArray else v.aval, True) - for v in jaxpr.invars] - return tuple(out) +InputType = tuple[AbstractValue] +OutputType = tuple[AbstractValue] # For use in typing annotations to denote either a Tracer or a `valid_jaxtype`. Value = Any @@ -1957,21 +1927,17 @@ def cur_aval_qdd(x): @overload def physical_aval(aval: ShapedArray) -> ShapedArray: ... -@overload -def physical_aval(aval: DShapedArray) -> DShapedArray: ... @overload # TODO(frostig): remove this case def physical_aval(aval: AbstractValue) -> AbstractValue: ... def physical_aval(aval): - if (isinstance(aval, (ShapedArray, DShapedArray)) and + if (isinstance(aval, ShapedArray) and isinstance(aval.dtype, dtypes.ExtendedDType)): elt_aval = physical_element_aval(aval.dtype) - if isinstance(aval, ShapedArray): - from jax._src.sharding_impls import physical_sharding # type: ignore - return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, - sharding=physical_sharding(aval, aval.sharding), - vma=aval.vma) - return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) + from jax._src.sharding_impls import physical_sharding # type: ignore + return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype, + sharding=physical_sharding(aval, aval.sharding), + vma=aval.vma) return aval def physical_shape(logical_shape, dtype): @@ -1993,15 +1959,7 @@ def _canonicalize_dimension(dim: DimSize) -> DimSize: return operator.index(dim) except TypeError as e: type_error = e - if isinstance(dim, Tracer) and config.dynamic_shapes.value: - if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer) - or isinstance(dim.dtype, bint))): - raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}") - return dim - elif (config.dynamic_shapes.value and isinstance(dim, DArray) and - type(dim._aval.dtype) is bint and not dim._aval.shape): - return dim - elif is_dim(dim): + if is_dim(dim): return dim else: raise type_error @@ -2035,16 +1993,11 @@ def canonicalize_dim(d: DimSize, context: str="") -> DimSize: return canonicalize_shape((d,), context)[0] def _invalid_shape_error(shape: Shape, context: str=""): - if config.dynamic_shapes.value: - msg = ("Shapes must be 1D sequences of integer scalars, " - f"got {shape}") - else: - msg = ("Shapes must be 1D sequences of concrete values of integer type, " - f"got {shape}.") + msg = ("Shapes must be 1D sequences of concrete values of integer type, " + f"got {shape}.") if context: msg += f" {context}." - if not config.dynamic_shapes.value and any( - isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) + if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray) and not is_concrete(x) for x in shape): msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to " "smaller subfunctions.") @@ -2485,149 +2438,6 @@ def standard_vma_rule(prim_name, *avals, **kwargs) -> frozenset[AxisName]: 'workaround pass the check_vma=False argument to `jax.shard_map`') return vma -# Dynamic shape stuff below here! We keep the abstract values distinct just so -# as not to interfere with any static shape machinery. - -# We have a convention of reusing AbsractValues as types, even though we could -# make a distinction and use abstract values during tracing only. This reuse -# becomes a bit more extreme with DShapedArrays. A DShapedArray's shape -# attribute is a tuple which can contain several different types: int, DArray -# (scalar and with dtype of bint type), Tracer (while tracing), Var (when used -# as jaxpr type annotations), or DBIdx/InDBIdx/OutDBIdx (when used in InputType -# or OutputType). We could reduce this polymorphism if it seems cleaner, though -# it's kind of convenient! -class DShapedArray(AbstractValue): - __slots__ = ['shape', 'dtype', 'weak_type'] - shape: tuple[AxisSize, ...] # noqa: F821 - array_abstraction_level: int = 3 - - def __init__(self, shape, dtype, weak_type=False): - assert not any(isinstance(d, Literal) for d in shape) - self.shape = shape - self.dtype = dtype - self.weak_type = weak_type - - ndim = property(lambda self: len(self.shape)) - size = property(lambda self: - 0 if any(type(d) is int and d == 0 for d in self.shape) - else math.prod(self.shape)) - - def str_short(self, short_dtypes=False, mesh_axis_types=False) -> str: - del short_dtypes # ignored - shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else '' - dtype = dtypes.short_dtype_name(self.dtype) - return f'{dtype}[{shape}]' - __str__ = __repr__ = str_short - - def update(self, shape=None, dtype=None, weak_type=None): - if shape is None: - shape = self.shape - if dtype is None: - dtype = self.dtype - if weak_type is None: - weak_type = self.weak_type - return DShapedArray(shape, dtype, weak_type) - - @property - def sharding(self): - return NamedSharding(mesh_lib.empty_abstract_mesh, P()) - - @property - def vma(self): - return frozenset() - - def _len(self, tracer): - return self.shape[0] - - def __eq__(self, other): - return (type(self) is type(other) - and self.dtype == other.dtype and self.shape == other.shape - and self.weak_type == other.weak_type) - - def __hash__(self): - # We don't hash the contents of the shape because it may contain tracers. - return hash((len(self.shape), self.dtype, self.weak_type)) - - def __ne__(self, other): - return not self == other - - def to_tangent_aval(self): - return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type) - - def update_vma(self, vma): - return self - - def update_weak_type(self, weak_type): - return self.update(weak_type=weak_type) - - _bool = concretization_function_error(bool) - _int = concretization_function_error(int, True) - _float = concretization_function_error(float, True) - _complex = concretization_function_error(complex, True) - _hex = concretization_function_error(hex) - _oct = concretization_function_error(oct) - _index = concretization_function_error(operator.index) - - -class DArray: - _aval: DShapedArray - _data: Any # standard array type - def __init__(self, aval, data): - pad_shape = tuple(d.dtype.bound if type(d) is DArray and - type(d.dtype) is bint else d for d in aval.shape) - assert data.shape == pad_shape - self._aval = aval - self._data = data - - shape = property(lambda self: self._aval.shape) - dtype = property(lambda self: self._aval.dtype) - aval = property(lambda self: self._aval) - def __repr__(self) -> str: - if not self.shape and type(self.dtype) is bint: - # special-case scalar bints - return f'{int(self._data)}{{≤{self.dtype.bound}}}' - - dtypestr = dtypes.short_dtype_name(self._aval.dtype) - shapestr = ','.join(map(str, self.shape)) - data = self.data - return f'{dtypestr}[{shapestr}] with value: {data}' - - def __hash__(self) -> int: - if not self.shape: - return hash((self._aval, int(self._data))) - raise TypeError("unhashable type: DArray") - - def __eq__(self, other): - if isinstance(other, DArray) and self._aval == other._aval: - return self._data == other._data - return False - - def __len__(self): - return self.shape[0] - - @property - def data(self): - if not self.shape and type(self.dtype) is bint: - # special-case scalar bints - return self._data - - slices = tuple( - slice(int(d._data)) - if type(d) is DArray and type(d.dtype) is bint - else slice(None) - for d in self.shape - ) - data = self._data[slices] - return data - -def _darray_aval(x): - return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) - -pytype_aval_mappings[DArray] = _darray_aval -dtypes.canonicalize_value_handlers[DArray] = lambda x: x - - @dataclass(frozen=True) class bint(dtypes.ExtendedDType): bound: int @@ -2643,7 +2453,7 @@ def name(self) -> str: def __str__(self) -> str: return self.name -AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx] +AxisSize = Union[int, Tracer, Var] class RefMeta(type): @@ -3097,8 +2907,6 @@ def get_bind_params(self, params): jaxpr = new_params.pop('call_jaxpr') subfun = lu.hashable_partial( lu.wrap_init(eval_jaxpr, debug_info=jaxpr.debug_info), jaxpr, ()) - if config.dynamic_shapes.value: - subfun = lu.annotate(subfun, _jaxpr_type_to_callable_annotation(jaxpr)) return [subfun], new_params def call_impl(f: lu.WrappedFun, *args, **params): @@ -3194,25 +3002,8 @@ def _unmap_shaped_array( else: raise TypeError(axis) -def _map_dshaped_array( - size: AxisSize, axis: int | None, aval: DShapedArray) -> DShapedArray: - if axis is None: return aval - return DShapedArray(tuple_delete(aval.shape, axis), aval.dtype, - aval.weak_type) - -def _unmap_dshaped_array( - size: AxisSize, axis: int | None, explicit_mesh_axis, aval: DShapedArray - ) -> DShapedArray: - if axis is None: return aval - elif type(axis) is int: - return DShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, - weak_type=aval.weak_type) - else: - raise TypeError(axis) - AvalMapHandlerPair = tuple[Callable, Callable] aval_mapping_handlers: dict[type, AvalMapHandlerPair] = { - DShapedArray: (_map_dshaped_array, _unmap_dshaped_array), ShapedArray: (_map_shaped_array, _unmap_shaped_array), AbstractToken: (lambda _, __, a: a, lambda _, __, ____, a: a) } @@ -3296,10 +3087,7 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: from jax._src.state.types import AbstractRef # pytype: disable=import-error if t1 == t2: return True - elif (isinstance(t1, (ShapedArray, DShapedArray)) and - isinstance(t2, (ShapedArray, DShapedArray))): - # This case handles DShapedArray and shape polynomials. Alternatively we - # could try normalizing first and then doing simple equality. + elif isinstance(t1, ShapedArray) and isinstance(t2, ShapedArray): cmp = (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) and t1.vma == t2.vma and t1.memory_space == t2.memory_space) # type: ignore # TODO(yashkatariya): Expand this to Manual and Auto mode. @@ -3518,7 +3306,6 @@ def write(v: Var, a: AvalQDD) -> None: f"Jaxpr effects: {jaxpr.effects}") # Check out_type matches the let-binders' annotation (after substitution). - out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars) out_type = [t if isinstance(t, AvalQDD) else AvalQDD(t, None) for t in out_type] foreach(write, eqn.outvars, out_type) @@ -3545,51 +3332,7 @@ def check_type( env: dict[Var, Atom | MutableTypecheckVal], ty: AbstractValue, ) -> None: - if isinstance(ty, DShapedArray): - # Check all elements in the shape tuple are well-typed. - for d in ty.shape: - if (isinstance(d, int) or - isinstance(d, DArray) and not d.shape and type(d.dtype) == bint): - continue - elif isinstance(d, Var): - if d not in env: - ctx, _ = ctx_factory() - raise JaxprTypeError(f"unbound axis size: '{pp_var(d, ctx)}'") - if not isinstance(d.aval, (ShapedArray, DShapedArray)): - raise JaxprTypeError(f"axis size with unexpected type annotation: " - f"{d.aval} of type {type(d.aval)}") - if isinstance(d.aval, ShapedArray): - shape, dtype = d.aval.shape, d.aval.dtype - if shape: raise JaxprTypeError(f"axis size nonscalar: {d.aval}") - if not dtypes.issubdtype(dtype, np.integer): - raise JaxprTypeError(f"axis size with non-integer dtype: {d.aval}") - else: - assert isinstance(d.aval, DShapedArray) - shape, dtype = d.aval.shape, d.aval.dtype - if shape: raise JaxprTypeError(f"axis size nonscalar: {d.aval}") - if type(dtype) is not bint: - raise JaxprTypeError( - f"DArray axis size with non-bint dtype: {d.aval}") - else: - raise JaxprTypeError(f"unexpected type in shape: {type(d)}") - else: - return # Except in above case(s), all syntactic forms are valid - -def substitute_vars_in_output_ty( - out_type: Sequence[AbstractValue], # shapes may contain InDBIdx / OutDBIdx - in_atoms: Sequence[Atom], - out_binders: Sequence[Var], - ) -> list[AbstractValue]: # shapes may contain Vars - in_atoms = [x.val if type(x) is Literal else x for x in in_atoms] - result = [] - for aval in out_type: - if type(aval) is DShapedArray: - shape = [in_atoms[d.val] if type(d) is InDBIdx else - out_binders[d.val] if type(d) is OutDBIdx else - d for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - result.append(aval) - return result + return # Except in above case(s), all syntactic forms are valid def check_eqn(prim, in_avals, params): for jaxpr in jaxprs_in_params(params): @@ -3616,29 +3359,19 @@ def _check_call(ctx_factory, prim, in_atoms, params): # Check `call_jaxpr` can be applied to in_atoms. env: dict[Var, Atom | MutableTypecheckVal] = {} - def substitute(aval: AbstractValue): - if isinstance(aval, DShapedArray): - aval = aval.update(shape=tuple(env.get(d, d) for d in aval.shape)) # type: ignore - return aval for v, x in zip(call_jaxpr.invars, in_atoms): - if not typecompat(substitute(v.aval), x.aval): + if not typecompat(v.aval, x.aval): # TODO(mattjj): vars in error message are confusing b/c of Var.__repr__ raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type " f"{x.aval} to jaxpr expecting type " - f"{substitute(v.aval)}") + f"{v.aval}") env[v] = x.val if type(x) is Literal else x check_jaxpr(call_jaxpr) invars, outvars = call_jaxpr.invars, call_jaxpr.outvars - in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)} - out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars) - if type(x) is Var} out_avals = [x.aval for x in call_jaxpr.outvars] - out_type = [a.update(shape=tuple(in_map.get(d, out_map.get(d)) - if type(d) is Var else d for d in a.shape)) - if type(a) is DShapedArray else a for a in out_avals] - + out_type = out_avals # jaxpr input effects are indexed to include jaxpr.constvars, but the eqn # should have effects indexed only on its explicit arguments effs = {e.replace(input_index=e.input_index - len(call_jaxpr.constvars)) @@ -3952,12 +3685,7 @@ def pp_var(v: Var | Literal, context: JaxprPpContext, *, return v.pretty_print(context, print_dtype=print_literal_dtype) def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str: - if isinstance(a, DShapedArray): - shape = [pp_var(d, context) if type(d) is Var else str(d) for d in a.shape] - dtype = dtypes.short_dtype_name(a.dtype) - return f'{dtype}[{",".join(shape)}]' - else: - return a.str_short(short_dtypes=True) + return a.str_short(short_dtypes=True) def pp_vars(vs: Sequence[Atom], context: JaxprPpContext, *, separator="", print_shapes: bool = False) -> pp.Doc: diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index cdd9d3462027..1c50cf5bd373 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -19,7 +19,6 @@ from collections.abc import Sequence import dataclasses from functools import partial -import itertools import logging import threading import time @@ -286,24 +285,6 @@ def get_intermediate_shardings( return out -def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool: - return (any(type(v.aval.dtype) is core.bint for v in jaxpr.invars - if isinstance(v.aval, (core.ShapedArray, core.DShapedArray))) or - any(_is_bint_axis_size(d) - for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr)) - for e in j.eqns for v in e.outvars - if isinstance(v.aval, core.DShapedArray) for d in v.aval.shape)) - -def _is_bint_axis_size(d: core.AxisSize) -> bool: - if isinstance(d, core.DArray): - assert not d.shape - return type(d.dtype) is core.bint - elif isinstance(d, core.Var): - return (isinstance(d.aval, core.DShapedArray) and - type(d.aval.dtype) is core.bint) - return False - - def check_arg(arg: Any): if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)): raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid " diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 3ebe87212ae2..863839cdc080 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -51,18 +51,12 @@ def identity(x): return x def _update_annotation( f: lu.WrappedFun, - orig_type: tuple[tuple[core.AbstractValue, bool], ...] | None, - explicit_nonzeros: list[bool] + orig_type: tuple[core.AbstractValue, ...] | None, + nonzeros: list[bool] ) -> lu.WrappedFun: if orig_type is None: return f - # By convention, `explicit_nonzeros` only accounts for explicit arguments. - assert len(explicit_nonzeros) == sum(explicit for _, explicit in orig_type) - # Implicit arguments never have tangents, so generate the tangent part of the - # type annotation from explicit arguments only. - explicit_avals = [aval for aval, explicit in orig_type if explicit] - tan_types = [(aval.to_tangent_aval(), True) - for nz, aval in zip(explicit_nonzeros, explicit_avals) if nz] + tan_types = [aval.to_tangent_aval() for nz, aval in zip(nonzeros, orig_type) if nz] return lu.annotate(f, (*orig_type, *tan_types)) def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, @@ -343,11 +337,6 @@ def write_cotangent(prim, v, ct): # assert v.aval == ct.aval, (prim, v.aval, ct.aval) return ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct - # TODO(mattjj): add back these checks for dynamic shapes - # if config.enable_checks.value: - # ct_aval = core.get_aval(ct_env[v]) - # joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type() - # assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval) def read_cotangent(v): return ct_env.pop(v, Zero(v.aval.to_tangent_aval())) @@ -357,9 +346,6 @@ def read_primal(v): return v.val else: a = v.aval - if type(a) is core.DShapedArray: - shape = [primal_env[d] if type(d) is core.Var else d for d in a.shape] - a = a.update(shape=tuple(shape)) return primal_env.get(v, UndefinedPrimal(a)) def write_primal(v, val): @@ -1341,15 +1327,6 @@ def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): if update_params: params = update_params(params, map(is_undefined_primal, args), [type(x) is not Zero for x in ct]) - if config.dynamic_shapes.value: - # TODO(mattjj,dougalm): handle consts, for now assume just args - which_lin = [is_undefined_primal(x) for x in args] - res_invars, _ = partition_list(which_lin, call_jaxpr.invars) - new_invars = [*res_invars, *call_jaxpr.outvars] - dbidx_map = {v: core.DBIdx(i) for i, v in enumerate(new_invars)} - in_type = [(v.aval.update(shape=tuple(dbidx_map.get(d, d) for d in v.aval.shape)) # type: ignore[arg-type] - if type(v.aval) is core.DShapedArray else v.aval, True) for v in new_invars] - fun = lu.annotate(fun, tuple(in_type)) out_flat = primitive.bind(fun, *all_args, **params) return tree_unflatten(out_tree(), out_flat) primitive_transposes[core.call_p] = partial(call_transpose, call_p) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 9eafef3a6396..abdcddfd41f5 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -import collections from collections.abc import Callable, Sequence import dataclasses from functools import partial @@ -32,8 +31,7 @@ from jax._src.ad_util import Zero, SymbolicZero, add_jaxvals, add_jaxvals_p from jax._src.core import Trace, Tracer, TraceTag, AxisName from jax._src.interpreters import partial_eval as pe -from jax._src.tree_util import (tree_unflatten, tree_flatten, - register_pytree_node, PyTreeDef) +from jax._src.tree_util import (tree_unflatten, tree_flatten, PyTreeDef) from jax._src.typing import Array from jax._src.util import (unzip2, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, @@ -44,198 +42,6 @@ zip, unsafe_zip = safe_zip, zip -# Jumbles - -# i:(Fin 3) => f32[[3, 1, 4].i] -@dataclasses.dataclass(frozen=True) -class JumbleTy: - binder: core.Var - length: int | Tracer | core.Var - elt_ty: core.DShapedArray - def __repr__(self) -> str: - return f'Var{id(self.binder)}:{self.length} => {self.elt_ty}' - replace = dataclasses.replace - -# [3, 1, 4].i -@dataclasses.dataclass(frozen=True) -class IndexedAxisSize: - idx: core.Var - lengths: Array | core.Var | Tracer - def __repr__(self) -> str: - return f'{self.lengths}.Var{id(self.idx)}' - replace = dataclasses.replace - -# Jumble(aval=a:3 => f32[[3 1 4].a], -# data=Array([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32)) -@dataclasses.dataclass(frozen=True) -class Jumble: - aval: JumbleTy - data: Array - -# To vmap over a jumble, one must specify the axis as JumbleAxis. -class JumbleAxis: pass -jumble_axis = JumbleAxis() - -# As a temporary measure before we have more general JITable / ADable interfaces -# (analogues to vmappable), to enable Jumbles to be used with other -# transformations and higher-order primitives (primarily jit, though also grad -# with allow_int=True) we register them as pytrees. -# TODO(mattjj): add JITable / ADable interfaces, remove this pytree registration -def _jumble_flatten(jumble): - lengths = [] - new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths)) - if type(d) is IndexedAxisSize else d - for d in jumble.aval.elt_ty.shape] - elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape)) - aval = jumble.aval.replace(elt_ty=elt_ty) - return (lengths, jumble.data), aval - - -def _ragged_axis_parts(dim: RaggedAxis) -> tuple[int, int, int]: - stacked_axis = dim.stacked_axis - ragged_axes = dim.ragged_axes - if len(ragged_axes) != 1: - raise ValueError('Multiple ragged axes not yet implemented.') - ragged_axis_dim = ragged_axes[0][0] - ragged_axis_length = ragged_axes[0][1] - return stacked_axis, ragged_axis_dim, ragged_axis_length - - -def _jumble_unflatten(aval, x): - lengths, data = x - new_shape = [d.replace(lengths=lengths[d.lengths - 1]) - if type(d) is IndexedAxisSize else d - for d in aval.elt_ty.shape] - elt_ty = aval.elt_ty.update(shape=tuple(new_shape)) - aval = aval.replace(elt_ty=elt_ty) - return Jumble(aval, data) -register_pytree_node(Jumble, _jumble_flatten, _jumble_unflatten) - -def _jumble_result(axis_size, stacked_axis, ragged_axes, x): - binder = core.Var(core.ShapedArray((), np.dtype('int32'))) - if stacked_axis != 0: - raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0 - shape = list(x.shape) - del shape[0] - for ragged_axis, segment_lens in ragged_axes: - shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens) - elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type) - return Jumble(JumbleTy(binder, axis_size, elt_ty), x) - - -@dataclasses.dataclass(frozen=True) -class RaggedAxis: - stacked_axis: int - # For each axis, we store its index and the corresponding segment lengths. - # For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i] - # would be represented with ragged_axes = [(1, lens1), (3, lens2)] - ragged_axes: tuple[tuple[int, Any], ...] - - @property - def size(self): - # TODO(mattjj, axch): All the segment lengths arrays better be the - # same length! - return len(self.ragged_axes[0][1]) - - def move_stacked_axis(self: RaggedAxis, dst: int) -> RaggedAxis: - # Assumes that all stored and incoming axes are already canonicalized - def move_axis(ax): - if self.stacked_axis > ax and ax >= dst: - return ax + 1 - if self.stacked_axis < ax and ax <= dst: - return ax - 1 - return ax - new_axes = tuple((move_axis(ax), sizes) for ax, sizes in self.ragged_axes) - return RaggedAxis(dst, new_axes) - - -def transpose_ragged_axes(dim: RaggedAxis, perm: tuple[int, ...]) -> RaggedAxis: - new_ragged_axes = [] - for idx, old_idx in enumerate(perm): - for ax, size in dim.ragged_axes: - if old_idx == ax: - new_ragged_axes.append((idx, size)) - break - return _sorted_ragged_axis(dim.stacked_axis, new_ragged_axes) - -def _sorted_ragged_axis(stacked_axis, ragged_axes): - return RaggedAxis(stacked_axis, tuple(sorted(ragged_axes, key=lambda p: p[0]))) - -def make_batch_axis( - ndim: int, - stacked_axis: int, - ragged_axes: list[tuple[int, Array | core.Var]], -) -> int | RaggedAxis: - if ragged_axes: - canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes] - return _sorted_ragged_axis(canonicalize_axis(stacked_axis, ndim), canonical) - else: - return canonicalize_axis(stacked_axis, ndim) - -def bdim_as_shape( - bdim: int | RaggedAxis, data_shape: core.Shape) -> core.Shape: - if isinstance(bdim, RaggedAxis): - result = list(data_shape) - binder = core.Var(core.ShapedArray((), np.dtype('int32'))) - for ragged_axis, segment_lens in bdim.ragged_axes: - result[ragged_axis] = IndexedAxisSize(binder, segment_lens) - return tuple(result) - else: - return data_shape - -def shape_as_bdim( - stacked_axis: int, data_shape: core.Shape) -> int | RaggedAxis: - # This assumes that there is only one binder in the data_shape. - ragged_axes = [(i, size.lengths) for i, size in enumerate(data_shape) - if isinstance(size, IndexedAxisSize)] - return make_batch_axis(len(data_shape), stacked_axis, ragged_axes) - - -def _update_annotation( - f: lu.WrappedFun, orig_type: core.InputType | None, - axis_size: core.AxisSize, axis_name: AxisName, - explicit_in_dims: Sequence[int | RaggedAxis | None], - segment_lens: Sequence[Array], - ) -> lu.WrappedFun: - if orig_type is None: return f - # By convention, `explicit_in_dims` only accounts for explicit arguments. - assert len(explicit_in_dims) == sum(explicit for _, explicit in orig_type) - # We need to: - # * if `axis_size` is dynamic, add a new implicit binder (type) for it; - # * for each element of `segment_lengths`, add a new explicit binder for it; - # * drop other implicit binders, replacing DBIdx which refer to them with - # Name objects; - # * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int - # size if `axis_size` is int, otherwise Name); if RaggedAxis-valued in_dim, - # add batch axis (int if corresponding segment_lengths is concrete, Name if - # not); - # * generate full in_type with implicit args too. - - class Name: - def __init__(self, a): self.a = a - names = [Name(a) for a, _ in orig_type] - avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d - for d in a.shape)) - if type(a) is core.DShapedArray else a for a, e in orig_type if e] - - new_avals = [core.get_aval(s) for s in segment_lens] - sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size - for a, d in zip(avals, explicit_in_dims): - if isinstance(d, RaggedAxis): - raise NotImplementedError - else: - new_avals.append(core.unmapped_aval(sz, d, a)) # type: ignore - - mentioned = {d for a in new_avals if type(a) is core.DShapedArray - for d in a.shape if type(d) is Name} - expl_names = set(map(Name, new_avals)) - impl_names = mentioned - expl_names # type: ignore - impl_part = [(n.a, False) for n in impl_names] # type: ignore - name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))} - expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape)) - if type(a) is core.DShapedArray else a, True) for a in new_avals] - return lu.annotate(f, (*impl_part, *expl_part)) - ### vmappable typeclass Vmappable = Any @@ -252,26 +58,11 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: handler = to_elt_handlers.get(type(x)) if handler: return handler(partial(to_elt, trace, get_idx), get_idx, x, spec) - elif type(x) is Jumble: - if spec is not jumble_axis: - raise TypeError("jumble input without using jumble_axis in_axes spec") - ias: IndexedAxisSize # Not present in the AxisSize union in core.py - (d, ias), = ((i, sz) # type: ignore - for i, sz in enumerate(x.aval.elt_ty.shape) - if type(sz) is IndexedAxisSize) - batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)]) - return BatchTracer(trace, x.data, batch_axis) elif isinstance(spec, int) or spec is None: spec = spec and canonicalize_axis(spec, len(np.shape(x))) return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: - if isinstance(trace, BatchTrace) and isinstance(spec, JumbleAxis): - # TODO(mvoz): A vaguely questionable assumption that it is always - # sound to have a 0 axis here. This is true for the current use cases - # and comes from how we handle intermediary products of jumbles in - # vmap. - return BatchTracer(trace, x, 0, source_info_util.current()) # TODO(mvoz): This is a terrible place to fall into if you pass # a non jumble type in, make it clearer what went wrong. assert False, f'Unexpected type in ELT? {type(x)}' @@ -287,17 +78,11 @@ def _cont(axis_size, elt, axis): return from_elt(trace, axis_size, mesh_axis, i, elt, axis) return handler(_cont, axis_size, x, spec) val, bdim = trace.to_batch_info(x) - if type(bdim) is RaggedAxis: - if spec is not jumble_axis: - # TODO(mattjj): improve this error message - raise TypeError("ragged output without using jumble_axis out_axes spec") - return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val) - else: - try: - return matchaxis(trace.axis_data.name, axis_size, mesh_axis, - bdim, spec, val) - except SpecMatchError: - raise SpecMatchError(i, x.batch_dim, spec) from None + try: + return matchaxis(trace.axis_data.name, axis_size, mesh_axis, + bdim, spec, val) + except SpecMatchError: + raise SpecMatchError(i, x.batch_dim, spec) from None from_elt_handlers: dict[type, FromEltHandler] = {} def make_iota(axis_size: AxisSize) -> Array: @@ -320,7 +105,7 @@ def register_vmappable(data_type: type, spec_type: type, axis_size_type: type, from_elt_handlers[data_type] = from_elt if make_iota: make_iota_handlers[axis_size_type] = make_iota vmappables: dict[type, tuple[type, type]] = {} -spec_types: set[type] = {JumbleAxis} +spec_types: set[type] = set() def unregister_vmappable(data_type: type) -> None: _, axis_size_type = vmappables.pop(data_type) @@ -330,11 +115,11 @@ def unregister_vmappable(data_type: type) -> None: del make_iota_handlers[axis_size_type] global spec_types spec_types = ( - {JumbleAxis} | {spec_type for spec_type, _ in vmappables.values()} + set() | {spec_type for spec_type, _ in vmappables.values()} ) def is_vmappable(x: Any) -> bool: - return type(x) is Jumble or type(x) in vmappables + return type(x) in vmappables @lu.transformation_with_aux2 def flatten_fun_for_vmap(f: Callable, @@ -345,44 +130,6 @@ def flatten_fun_for_vmap(f: Callable, store.store(out_tree) return ans -# Propagate ragged masking rules from invars to outvars -# rule([params], [raggedness_per_invar], outvars) -> -# [raggedness_per_invar, raggedness_per_outvar] -RaggedMaskingRule = Callable[ - [list[Any], list[Any], list[Any]], tuple[list[Any], list[Any]] -] - -ragged_prop_rules: dict[core.Primitive, RaggedMaskingRule] = {} - - -def ragged_mask_elementwise_rule(eqn_params, invar_raggedness, outvars): - # TODO(mvoz): A util for getting the ragged representations - first_invar_raggedness = invar_raggedness[0] - for other_invar_raggedness in invar_raggedness[1:]: - if other_invar_raggedness != first_invar_raggedness: - raise ValueError(f'{other_invar_raggedness} != {first_invar_raggedness}') - - outvar_raggedness = [first_invar_raggedness] * len(outvars) - return invar_raggedness, outvar_raggedness - - -def ragged_mask_assert_no_op_rule(eqn_params, invar_raggedness, outvars): - if any(invar_raggedness): - raise ValueError(f'unexpected invar_raggedness: {invar_raggedness}') - return invar_raggedness, [None] * len(outvars) - - -def ragged_mask_no_op_rule(eqn_params, invar_raggedness, outvars): - return invar_raggedness, [None] * len(outvars) - - -def ragged_mask_transfer_identity( - eqn_params, invar_raggedness, outvar_raggedness -): - assert len(invar_raggedness) == 1, invar_raggedness - outvar_raggedness = invar_raggedness - return invar_raggedness, outvar_raggedness - ### tracer @@ -394,10 +141,10 @@ def ragged_mask_transfer_identity( class BatchTracer(Tracer): __slots__ = ['val', 'batch_dim', 'source_info'] - def __init__(self, trace, val, batch_dim: NotMapped | int | RaggedAxis, + def __init__(self, trace, val, batch_dim: NotMapped | int, source_info: source_info_util.SourceInfo | None = None): if config.enable_checks.value: - assert type(batch_dim) in (NotMapped, int, RaggedAxis) + assert type(batch_dim) in (NotMapped, int) if type(batch_dim) is int: aval = core.get_aval(val) assert 0 <= batch_dim < len(aval.shape) @@ -420,17 +167,8 @@ def aval(self): return aval elif type(self.batch_dim) is int: return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval) - elif type(self.batch_dim) is RaggedAxis: - new_aval = core.mapped_aval( - aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval) - shape = list(new_aval.shape) # pytype: disable=attribute-error - for ragged_axis, segment_lengths in self.batch_dim.ragged_axes: - size_tracer = BatchTracer(self._trace, segment_lengths, 0) - if self.batch_dim.stacked_axis < ragged_axis: - ragged_axis -= 1 - shape[ragged_axis] = size_tracer - return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype, - weak_type=aval.weak_type) + else: + raise Exception("batch dim should be int or `not_mapped`") def full_lower(self): if self.batch_dim is not_mapped: @@ -450,7 +188,7 @@ def _contents(self): def get_referent(self): if self.batch_dim is None or type(self.batch_dim) is int: return core.get_referent(self.val) - else: # TODO(mattjj): could handle the RaggedAxis case? + else: return self @dataclasses.dataclass(frozen=True) @@ -510,8 +248,6 @@ def to_batch_info(self, val): return val, not_mapped def process_primitive(self, p, tracers, params): - if config.dynamic_shapes.value: - p.abstract_eval(*(map(core.get_aval, tracers)), **params) vals_in, dims_in = unzip2(map(self.to_batch_info, tracers)) args_not_mapped = all(bdim is not_mapped for bdim in dims_in) if p in fancy_primitive_batchers: @@ -546,16 +282,12 @@ def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results params = dict(params, name=params.get('name', f.__name__)) vals, dims = unzip2(map(self.to_batch_info, tracers)) - segment_lens, dims = indirectify_ragged_axes(dims) f_, dims_out = batch_subtrace(f, self.tag, self.axis_data, tuple(dims)) - f_ = _update_annotation( - f_, f.in_type, self.axis_data.size, self.axis_data.name, dims, segment_lens) with core.set_current_trace(self.parent_trace): - vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params) - vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out()) + vals_out = call_primitive.bind(f_, *vals, **params) src = source_info_util.current() - return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)] + return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())] def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): vals, dims = unzip2(map(self.to_batch_info, tracers)) @@ -708,43 +440,12 @@ def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals): trace = BatchTrace(parent_trace, tag, axis_data) with core.set_current_trace(trace): in_dims = in_dims() if callable(in_dims) else in_dims - in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) if dim is not None else x for x, dim in zip(in_vals, in_dims)] outs = f(*in_tracers) out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) - segment_lens, out_dims = indirectify_ragged_axes(out_dims) store.store(out_dims) - return (*segment_lens, *out_vals) - -def indirectify_ragged_axes(dims): - if not any(type(d) is RaggedAxis for d in dims): - return [], dims - axis_map : dict[int, tuple[Array, pe.DBIdx]] = collections.OrderedDict() - def canonicalize_segment_lengths(d: RaggedAxis) -> RaggedAxis: - new_ragged_axes = [] - for ragged_axis, segment_lengths in d.ragged_axes: - _, dbidx = axis_map.setdefault( - id(core.get_referent(segment_lengths)), - (segment_lengths, pe.DBIdx(len(axis_map)))) - new_ragged_axes.append((ragged_axis, dbidx)) - return RaggedAxis(d.stacked_axis, tuple(new_ragged_axes)) - new_dims = [canonicalize_segment_lengths(d) - if isinstance(d, RaggedAxis) else d for d in dims] - segment_lens = [s for s, _ in axis_map.values()] - return segment_lens, new_dims - -def indirectify_ragged_axes_against_inputs_outputs(dims, in_vals, out_vals): - def canonicalize_segment_lengths(d: RaggedAxis) -> RaggedAxis: - new_ragged_axes = [] - for ragged_axis, segment_lengths in d.ragged_axes: - key = id(core.get_referent(segment_lengths)) - value = _locate_value(key, in_vals, out_vals) - new_ragged_axes.append((ragged_axis, value)) - return RaggedAxis(d.stacked_axis, tuple(new_ragged_axes)) - new_dims = [canonicalize_segment_lengths(d) - if isinstance(d, RaggedAxis) else d for d in dims] - return new_dims + return out_vals def _locate_value(key, in_vals, out_vals): for ix, candidate in enumerate(in_vals): @@ -755,58 +456,27 @@ def _locate_value(key, in_vals, out_vals): return pe.OutDBIdx(ix) assert False, "Could not find segment lengths" -def resolve_ragged_axes(vals, dims): - idxs = {lengths_idx.val for d in dims if isinstance(d, RaggedAxis) - for (_, lengths_idx) in d.ragged_axes} - dims = [RaggedAxis(d.stacked_axis, - tuple((ragged_axis, vals[lengths_idx.val]) - for ragged_axis, lengths_idx in d.ragged_axes)) - if isinstance(d, RaggedAxis) else d for d in dims] - vals = [x for i, x in enumerate(vals) if i not in idxs] - return vals, dims - -def resolve_ragged_axes_against_inputs_outputs(in_vals, out_vals, dims): - def fetch(idx): - if isinstance(idx, pe.InDBIdx): - return in_vals[idx.val] - else: - assert isinstance(idx, pe.OutDBIdx) - return out_vals[idx.val] - - dims = [RaggedAxis(d.stacked_axis, - tuple((ragged_axis, fetch(lengths_idx)) - for ragged_axis, lengths_idx in d.ragged_axes)) - if isinstance(d, RaggedAxis) else d for d in dims] - return dims - ### API for batching jaxprs -# TODO(axch): parameterize RaggedAxis annotations by a type parameter so as to -# indicate whether we're dealing with instances that contain Arrays or DBIdx. -# Can reuse same pattern for all dynamic shape stuff. def batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, axis_data, - in_axes: tuple[int | NotMapped | RaggedAxis, ...], - ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]: + in_axes: tuple[int | NotMapped, ...], + ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped ]]: return _batch_jaxpr2(closed_jaxpr, axis_data, tuple(in_axes)) @weakref_lru_cache def _batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, axis_data, - in_axes: tuple[int | NotMapped | RaggedAxis, ...], + in_axes: tuple[int | NotMapped ], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]: f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr), debug_info=closed_jaxpr.jaxpr.debug_info) f, out_axes = _batch_jaxpr_inner(f, axis_data) f = _batch_jaxpr_outer(f, axis_data, in_axes) - in_axes2, avals_in = unzip2([ - handle_ragged(closed_jaxpr.in_avals, dim, aval) - if isinstance(dim, RaggedAxis) else (dim, aval) - for dim, aval in zip(in_axes, closed_jaxpr.in_avals)]) avals_in2 = [] - for aval, b in unsafe_zip(avals_in, in_axes2): + for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes): if b is not_mapped: avals_in2.append(aval) else: @@ -819,14 +489,6 @@ def _batch_jaxpr2( jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in2) return core.ClosedJaxpr(jaxpr_out, consts), out_axes() -def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis, - aval: core.ShapedArray) -> tuple[int, core.ShapedArray]: - new_shape = list(aval.shape) - for i, dbi in dim.ragged_axes: - new_shape[i - (dim.stacked_axis < i)] = in_avals[dbi.val].dtype.bound - new_aval = aval.update(shape=tuple(new_shape)) - return dim.stacked_axis, new_aval - def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate): inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), inst) @@ -864,7 +526,6 @@ def _batch_jaxpr_axes(closed_jaxpr: core.ClosedJaxpr, def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals): with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) - _, in_axes = resolve_ragged_axes(in_vals, in_axes) in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val for val, dim in zip(in_vals, in_axes)] # TODO(yashkatariya): Instead of `add_explicit_mesh_axis_names`, we should @@ -875,9 +536,7 @@ def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals): core.add_explicit_mesh_axis_names(axis_data.explicit_mesh_axis)): outs = f(*in_tracers) out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) - new_out_axes = indirectify_ragged_axes_against_inputs_outputs( - out_axes, in_vals, out_vals) - store.store(new_out_axes) + store.store(out_axes) return out_vals @lu.transformation_with_aux2 @@ -1079,23 +738,6 @@ def out_axis(axes, axis): if 'input_shape' in params: params = dict(params, input_shape=operand.shape) return prim.bind(operand, axes=axes, **params), bdim_out - elif isinstance(bdim, RaggedAxis): - assert ident is not None, "TODO Ragged batching a reduction requires an identity" - axes = tuple(np.where(np.less(axes, bdim.stacked_axis), axes, np.add(axes, 1))) - bdim_out = out_axis(axes, bdim.stacked_axis) - # For each ragged_axis, we either mask the operand there or append - # it to the set of axes that will be ragged in the result. - axes_to_mask = [] - ragged_axes_out = [] - for ragged_axis, segment_lengths in bdim.ragged_axes: - if ragged_axis in axes: - axes_to_mask.append((ragged_axis, segment_lengths)) - else: - ragged_axes_out.append((out_axis(axes, ragged_axis), segment_lengths)) - operand = mask_ragged_axes( - operand, ident, RaggedAxis(bdim.stacked_axis, tuple(axes_to_mask))) - result = prim.bind(operand, axes=axes, **params) - return result, make_batch_axis(operand.ndim, bdim_out, ragged_axes_out) else: assert False @@ -1108,42 +750,6 @@ def expand_dims_batcher(prim, args, dims, **params): out = prim.bind(*args, **params) return (out, (0,) * len(out)) if prim.multiple_results else (out, 0) -def mask_ragged_axes(operand: Array, ident, axis_spec: RaggedAxis) -> Array: - # TODO(mattjj, axch) Can we mask multiple axes more efficiently at - # once, rather than one at a time? - for ragged_axis, segment_lengths in axis_spec.ragged_axes: - this_axis_spec = RaggedAxis( - axis_spec.stacked_axis, ((ragged_axis, segment_lengths),)) - operand = _mask_one_ragged_axis(operand, ident, this_axis_spec) - return operand - -def _mask_one_ragged_axis( - operand: Array, ident, axis_spec: RaggedAxis) -> Array: - # Callers of this utility, via reducer_batcher() or defreducer(), - # must be in a context where lax is importable. - from jax import lax # pytype: disable=import-error - assert len(axis_spec.ragged_axes) == 1, "Mask just one ragged axis at a time" - ragged_axis, segment_lengths = axis_spec.ragged_axes[0] - value = ident(operand.dtype) - positions = lax.broadcasted_iota('int32', operand.shape, ragged_axis) - # TODO(mattjj, axch) can't get ._data, need to convert it - # lengths = lax.convert_element_type(segment_lengths._data, 'int32') - lengths = lax.convert_element_type(segment_lengths, 'int32') - limits = lax.broadcast_in_dim( - lengths, operand.shape, [axis_spec.stacked_axis]) - mask = positions < limits - return lax.select(mask, operand, lax.broadcast(value, operand.shape)) - -def move_stacked_axis(operand, bdim, dst): - dst = canonicalize_axis(dst, operand.ndim) - if isinstance(bdim, int): - return moveaxis(operand, bdim, dst), dst - elif isinstance(bdim, RaggedAxis): - result = moveaxis(operand, bdim.stacked_axis, dst) - return result, bdim.move_stacked_axis(dst) - else: - raise TypeError(f"Unrecognized batch dimension type {bdim}") - ### general utilities for manipulating axes on jaxpr types (not vmappables) def broadcast(x, sz, axis, mesh_axis): @@ -1175,12 +781,6 @@ def matchaxis2(axis_data, src, dst, x, sum_match=False): src, dst, x, sum_match) def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): - if dst == jumble_axis: - x = bdim_at_front(x, src, sz) - elt_ty = x.aval.update(shape=x.shape[1:]) - aval = JumbleTy(core.Var(core.ShapedArray((), np.dtype('int32'))), - x.shape[0], elt_ty) - return Jumble(aval, x) try: _ = core.get_aval(x) except TypeError as e: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f9fca20bfa65..2be7f5088609 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -184,7 +184,7 @@ def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: f"No dtype_to_ir_type handler for dtype: {dtype}") from err return ir_type_factory() -def _array_ir_types(aval: core.ShapedArray | core.DShapedArray) -> ir.Type: +def _array_ir_types(aval: core.ShapedArray) -> ir.Type: aval = core.physical_aval(aval) # type: ignore if not core.is_constant_shape(aval.shape): return _dynamic_array_ir_types(aval) # type: ignore @@ -209,7 +209,6 @@ def aval_to_ir_type(aval: core.AbstractValue) -> IrTypes: ir_type_handlers[core.ShapedArray] = _array_ir_types ir_type_handlers[core.AbstractToken] = lambda _: hlo.TokenType.get() -ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types # This is a backwards compatibility shim for external users of jax.mlir apis. def aval_to_ir_types(aval: core.AbstractValue) -> tuple[ir.Type, ...]: @@ -974,27 +973,23 @@ def sharded_aval(aval: core.AbstractValue, return aval if isinstance(aval, core.AbstractToken): return aval - if not isinstance(aval, (core.ShapedArray, core.DShapedArray)): + if not isinstance(aval, core.ShapedArray): raise NotImplementedError return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore def eval_dynamic_shape(ctx: LoweringRuleContext, shape: core.Shape) -> tuple[int | Value, ...]: - if config.dynamic_shapes.value: - assert ctx.axis_size_env is not None - return tuple(ctx.axis_size_env.get(d, d) for d in shape) # type: ignore - else: - ctx = ctx.replace( - primitive="eval_dynamic_shape", - avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars), - tokens_out=None) + ctx = ctx.replace( + primitive="eval_dynamic_shape", + avals_in=[core.dim_value_aval()] * len(ctx.module_context.shape_poly_state.dim_vars), + tokens_out=None) - res = lower_fun( - partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars), - multiple_results=True)(ctx, *ctx.dim_var_values) - return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir - for d, d_ir in zip(shape, flatten_ir_values(res))) + res = lower_fun( + partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars), + multiple_results=True)(ctx, *ctx.dim_var_values) + return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir + for d, d_ir in zip(shape, flatten_ir_values(res))) # TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext, @@ -1083,7 +1078,7 @@ def _to_physical_op_sharding( assert isinstance(sharding, JSharding) if isinstance(aval, AbstractRef): return _to_physical_op_sharding(ctx, aval.inner_aval, sharding) - assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) + assert isinstance(aval, core.ShapedArray) if dtypes.issubdtype(aval.dtype, dtypes.extended): sharding = sharding_impls.physical_sharding(aval, sharding) aval = core.physical_aval(aval) @@ -1288,16 +1283,12 @@ def lower_jaxpr_to_module( # Create a keepalives list that will be mutated during the lowering. keepalives: list[Any] = [] host_callbacks: list[Any] = [] + # Find the dimension variables + all_dim_poly = [d for aval in sharded_in_avals if hasattr(aval, "shape") + for d in aval.shape if not core.is_constant_dim(d)] + dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new._get_vars()), + all_dim_poly, set()))) - dim_vars: Sequence[str] - if not config.dynamic_shapes.value: - # Find the dimension variables - all_dim_poly = [d for aval in sharded_in_avals if hasattr(aval, "shape") - for d in aval.shape if not core.is_constant_dim(d)] - dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new._get_vars()), - all_dim_poly, set()))) - else: - dim_vars = () ctx = ModuleContext(backend=backend, platforms=platforms, axis_context=axis_context, @@ -1974,7 +1965,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). # The below custom call achieves the sharding like above example. - assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) + assert isinstance(aval, core.ShapedArray) if config.use_shardy_partitioner.value: physical_ndim = core.physical_aval(aval).ndim s = SdyArray( @@ -2065,8 +2056,7 @@ def write(v: core.Var, node: IrValues): eqn.ctx.manager): # TODO(mattjj, phawkins): support caching for dynamic shapes. can_cache_lowering = ( - eqn.primitive not in _uncacheable_primitives and - not config.dynamic_shapes.value) + eqn.primitive not in _uncacheable_primitives) if can_cache_lowering: loc = source_info_to_location(ctx, None, eqn_name_stack, eqn.source_info.traceback) @@ -2077,10 +2067,6 @@ def write(v: core.Var, node: IrValues): else: # If we cannot cache the lowering, lower inline. axis_size_env = None - if config.dynamic_shapes.value: - axis_size_env = {d: read(d) - for a in avals_in if type(a) is core.DShapedArray - for d in a.shape if type(d) is core.Var} rule_ctx = LoweringRuleContext( module_context=ctx, primitive=eqn.primitive, name_stack=eqn_name_stack, @@ -2465,26 +2451,9 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): wrapped_fun = lu.wrap_init(f, params, debug_info=api_util.debug_info("lower_fun", fun, args, {})) - if config.dynamic_shapes.value: - # We might be applying this function to arguments with dynamic shapes, - # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that - # case, we need to form a jaxpr with leading binders for those axis size - # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2), - # and we need to call jaxpr_subcomp with these arguments made explicit. - assert ctx.axis_size_env is not None - args = (*ctx.axis_size_env.values(), *args) - idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)} - i32_aval = core.ShapedArray((), np.dtype('int32')) - implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env) - explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) # type: ignore - if type(a) is core.DShapedArray else a, True) - for a in ctx.avals_in] - wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args)) - jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic2(wrapped_fun) - else: - jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic(wrapped_fun, - ctx.avals_in) - # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out? + jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic(wrapped_fun, + ctx.avals_in) + # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out? if ctx.platforms is not None: sub_context = ctx.module_context.replace(platforms=ctx.platforms) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index fc4ffda4396d..0dfffb0f3efa 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -16,7 +16,7 @@ from __future__ import annotations from collections import namedtuple -from collections.abc import Callable, Sequence, Hashable +from collections.abc import Callable, Sequence import contextlib from dataclasses import dataclass from functools import partial @@ -40,11 +40,10 @@ from jax._src.core import ( Trace, Tracer, TraceTag, Jaxpr, Literal, get_aval, AbstractValue, ClosedJaxpr, new_jaxpr_eqn, Var, DropVar, Atom, JaxprEqn, Primitive, - ShapedArray, DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx, - OutDBIdx, InputType, OutputType, get_referent, JaxprEqnContext, typeof) + mapped_aval, unmapped_aval, get_referent, JaxprEqnContext, typeof) from jax._src.source_info_util import SourceInfo from jax._src.state.types import AbstractRef, ReadEffect -from jax._src.tree_util import (PyTreeDef, treedef_tuple, register_static, +from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_flatten, tree_unflatten) from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, @@ -64,42 +63,6 @@ def identity(x): return x AttrKind = Any PyTree = Any -def _update_annotation_known( - f: lu.WrappedFun, - orig_type: InputType | None, - in_knowns: list[bool] - ) -> lu.WrappedFun: - if orig_type is None: return f - # orig_type might contain DBIdx, but we're tossing out some args so we have to - # re-index. moreover some of the implicit args may not be needed anymore. - # so we basically just re-infer the lambda input type - if (all(e for _, e in orig_type) and - not any(type(d) is DBIdx for a, _ in orig_type for d in a.shape - if type(a) is DShapedArray)): - new_type = [ty for ty, known in zip(orig_type, in_knowns) if known] - return lu.annotate(f, tuple(new_type)) - - # Replace DBIdx with names, prune down to explicit only. - class Name: - def __init__(self, a): self.a = a - names = [Name(a) for a, _ in orig_type] - avals = [a.update(shape=tuple(names[d.val] if type(d) is DBIdx else d - for d in a.shape)) - if type(a) is DShapedArray else a for a, e in orig_type if e] - avals = [a for a, known in zip(avals, in_knowns) if known] - # Figure out the implicit part: names which aren't explicit and known. - expl_names = [o for o, (_, e) in zip(names, orig_type) if e] - expl_names = [o for o, k in zip(expl_names, in_knowns) if k] - expl_names_ = set(expl_names) - impl_names = {d for a in avals if type(a) is DShapedArray for d in a.shape - if type(d) is Name and d not in expl_names_} - impl_part = [(n.a, False) for n in impl_names] # type: ignore - # Figure out the explicit part: known explicit avals, replacing names w/ dbidx - name_map = {n: DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))} - expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape)) - if type(a) is DShapedArray else a, True) for a in avals] - return lu.annotate(f, (*impl_part, *expl_part)) - class PartialVal(tuple): """Partial value: either a known value or an unknown (abstract) value. @@ -187,14 +150,6 @@ def new_arg(self, pval: PartialVal) -> JaxprTracer: # known inputs (if it needs them, then they get passed through residuals). if const is None: aval = pval.get_aval() - if type(aval) is DShapedArray: - # TODO(dougalm): Fix the type error and remove the pytype pragmas. - # pytype: disable=attribute-error - shape = [self.new_instantiated_const(d) - if isinstance(d, Tracer) and d._trace.level < self.level else d - for d in aval.shape] - # pytype: enable=attribute-error - aval = aval.update(shape=tuple(shape)) return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding()) else: return self.new_const(const) @@ -282,27 +237,12 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): const_params = update_params(params, in_knowns, 0) # Run the call, getting known out vals and aux data used for staged-out call - fun_and_args = (_update_annotation_known(f_, f.in_type, in_knowns),) + tuple(in_consts) + fun_and_args = (f_,) + tuple(in_consts) out = primitive.bind_with_trace(self.parent_trace, fun_and_args, const_params) fwds, out_knowns, out_type, jaxpr, env = aux() # Split apart known outputs from the original call and non-fwded residuals. out_consts, non_fwd_res = split_list(out, [sum(out_knowns)]) - - # Form the complete list of residuals by forwarding some inputs. - if config.dynamic_shapes.value: - # With dynamic shapes, we may need to forward implicit arguments. - assert f.in_type is not None, "f must be annotated with lu.annotate()" - in_consts_, in_knowns_ = iter(in_consts), iter(in_knowns) - in_consts_full = [None] * len(f.in_type) - for idx, (aval, explicit) in enumerate(f.in_type): - if explicit and next(in_knowns_): - c = in_consts_full[idx] = next(in_consts_) - if aval.shape: - for d1, d2 in zip(aval.shape, c.shape): - if type(d1) is DBIdx: - in_consts_full[d1.val] = d2 - else: - in_consts_full = in_consts + in_consts_full = in_consts res = subs_list(fwds, in_consts_full, non_fwd_res) # Create the input tracers for the staged-out (unknown-value) call. @@ -317,19 +257,8 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): staged_params = dict(params, call_jaxpr=new_jaxpr) staged_params = update_params(staged_params, map(op.not_, in_knowns), num_new_args) - # The outputs of the staged-out call are Tracers with the new eqn as recipe. - if config.dynamic_shapes.value: - # With dynamic shapes, we may need to substitute Tracers into avals. - out_tracers = [] - for aval, _ in out_type: - if type(aval) is DShapedArray: - shape = [[*res_tracers, *env_tracers, *unknown_arg_tracers][d.val] - if type(d) is InDBIdx else d for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - out_tracers.append(JaxprTracer(self, PartialVal.unknown(aval), None)) - else: - out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) - for a in out_type] + out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) + for a in out_type] name_stack = self._current_truncated_name_stack() source = source_info_util.current().replace(name_stack=name_stack) eqn = new_eqn_recipe(self, (*res_tracers, *env_tracers, *unknown_arg_tracers), @@ -568,8 +497,6 @@ def parents(self) -> Sequence[JaxprTracer]: if isinstance(self.recipe, JaxprEqnRecipe): # TODO broadcast_in_dim can create a new tracer... return self.recipe.in_tracers - elif isinstance(self.aval, DShapedArray): - return [d for d in self.aval.shape if isinstance(d, JaxprTracer)] else: return [] @@ -814,19 +741,11 @@ def get_atom(t: JaxprTracer) -> Atom: def newvar(t: JaxprTracer | None) -> Var: assert t is not None - var = gensym(type_substitute(t.aval)) + var = gensym(t.aval) var_ = t_to_var.setdefault(id(t), var) assert var is var_ return var - def type_substitute(aval: AbstractValue) -> AbstractValue: - if isinstance(aval, DShapedArray): - # Replace any Tracers in aval.shape with Vars or Literal values - shape = [get_atom(d) if type(d) is JaxprTracer else d for d in aval.shape] - shape = [d.val if type(d) is Literal else d for d in shape] - aval = aval.update(shape=tuple(shape)) - return aval - processed_eqn_ids = set() eqns: list[core.JaxprEqn] = [] @@ -843,7 +762,7 @@ def sort_key(t): # TODO broadcast_in_dim can create a new tracer, not present in parents if r.eqn_id not in processed_eqn_ids: in_atoms = map(get_atom, r.in_tracers) - outvars = [DropVar(type_substitute(a)) if rf() is None else newvar(rf()) + outvars = [DropVar(a) if rf() is None else newvar(rf()) for a, rf in zip(r.out_avals, r.out_tracer_refs)] eqns.append(new_jaxpr_eqn(in_atoms, outvars, r.primitive, r.params, r.effects, r.source_info, r.ctx)) @@ -1874,25 +1793,7 @@ def to_jaxpr( jaxpr = Jaxpr(constvars, self.invars, outvars, eqns, effs, debug_info, is_high) return jaxpr, list(constvals) - def to_jaxpr2(self, out_tracers: Sequence[core.Tracer], - debug_info: core.DebugInfo): - eqns = self.get_eqns() - outvars = [t.val for t in out_tracers] - constvars, constvals = unzip2(self.constvar_to_val.copy().items()) - constvals = [c.canonical for c in constvals] - constvars, constvals = _drop_unused_vars(constvars, constvals, eqns, outvars) - effs = make_jaxpr_effects(constvars, self.invars, outvars, eqns) - jaxpr = Jaxpr(constvars, self.invars, outvars, eqns, effs, debug_info) - jaxpr, out_type = _add_implicit_outputs(jaxpr) - config.enable_checks.value and core.check_jaxpr(jaxpr) - return jaxpr, out_type, constvals - def newvar(self, aval): - if isinstance(aval, DShapedArray): - # this aval may have tracers in it, so we replace those with variables - new_shape = [d.val if isinstance(d, Tracer) else d for d in aval.shape] - new_shape = [d.val if isinstance(d, Literal) else d for d in new_shape] - aval = aval.update(shape=tuple(new_shape)) if isinstance(aval, core.AvalQDD): return self.gensym(aval.aval, initial_qdd=aval.qdd) else: @@ -1937,8 +1838,6 @@ def vars(atom: Atom) -> list[Var]: if isinstance(atom, Literal): return [] aval = atom.aval - if isinstance(aval, DShapedArray): - return [atom] + [d for d in aval.shape if isinstance(d, Var)] return [atom] used: set[Var] = {v for atom in outvars for v in vars(atom)} for eqn in eqns[::-1]: @@ -2072,7 +1971,6 @@ def new_const(self, c, source_info: SourceInfo, if aval.has_qdd: with core.set_current_trace(self.parent_trace or core.eval_trace): aval = core.AvalQDD(aval, core.cur_qdd(c)) # type: ignore - aval = self._lift_tracers_in_aval(aval, source_info) tracer = self._new_const(aval, c, source_info) return tracer @@ -2109,14 +2007,6 @@ def get_const(self, tracer) -> Any: const = const.canonical return const - def _lift_tracers_in_aval(self, aval, source_info: SourceInfo): - if (not isinstance(aval, DShapedArray) or - not any(isinstance(d, Tracer) for d in aval.shape)): - return aval - shape = [self.to_jaxpr_tracer(d, source_info) if isinstance(d, Tracer) else d - for d in aval.shape] - return aval.update(shape=tuple(shape)) - def cur_qdd(self, x): source_info = source_info_util.current() return self.to_jaxpr_tracer(x, source_info=source_info).mutable_qdd.cur_val @@ -2144,7 +2034,7 @@ def default_process_primitive(self, primitive, tracers, params, # TODO(mattjj,dougalm): clean up how we check for new-style hi primitives if primitive is call_hi_primitive_p: out_avals, effs = params['prim'].out_avals_flat, set() # TODO effs - elif (primitive.name == "custom_lin" or config.dynamic_shapes.value or + elif (primitive.name == "custom_lin" or primitive.is_effectful and primitive.is_effectful(params)): out_avals, effs = primitive.abstract_eval(*aval_qdds, **params) else: @@ -2181,37 +2071,33 @@ def default_process_primitive(self, primitive, tracers, params, self.frame.add_eqn(eqn) return out_tracers if primitive.multiple_results else out_tracers.pop() - def process_call(self, call_primitive, f: lu.WrappedFun, explicit_tracers, + def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, params): source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) - in_type = (tuple((get_aval(t), True) for t in explicit_tracers) + in_type = (tuple(get_aval(t) for t in in_tracers) if f.in_type is None else f.in_type) f.in_type = None assert in_type is not None - implicit_tracers = _extract_implicit_args(self, in_type, explicit_tracers, - source_info) - in_tracers = map(to_jaxpr_tracer, [*implicit_tracers, *explicit_tracers]) + in_tracers = map(to_jaxpr_tracer, in_tracers) # TODO(mattjj): check in_tracers are consistent with f.in_type annotation - jaxpr, out_type, consts = _cached_trace_to_jaxpr(f, in_type) + jaxpr, out_avals, consts = _cached_trace_to_jaxpr(f, in_type) if params.get('inline', False): return core.eval_jaxpr(jaxpr, consts, *in_tracers, propagate_source_info=False) - out_avals = [aval for aval, _ in out_type] new_jaxpr = convert_constvars_jaxpr(jaxpr) if isinstance(call_primitive, core.ClosedCallPrimitive): new_jaxpr = close_jaxpr(new_jaxpr) # type: ignore new_params = dict(params, call_jaxpr=new_jaxpr) update_params = call_param_updaters.get(call_primitive) if update_params: - new_params = update_params(new_params, [True] * len(explicit_tracers), - len(consts) + len(implicit_tracers)) + new_params = update_params(new_params, [True] * len(in_tracers), + len(consts)) const_tracers = map(to_jaxpr_tracer, consts) - out_tracers = self.emit_eqn( + return self.emit_eqn( [*const_tracers, *in_tracers], out_avals, call_primitive, new_params, new_params['call_jaxpr'].effects, source_info=source_info) - return [t for t, (_, keep) in zip(out_tracers, out_type) if keep] def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): source_info = source_info_util.current() @@ -2365,10 +2251,9 @@ def to_jaxpr(self, out_tracers: Sequence[Tracer], return self.frame.to_jaxpr(self, out_tracers, debug_info, source_info) - @lu.cache def _cached_trace_to_jaxpr(f, in_type): - jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(lu.annotate(f, in_type)) + jaxpr, out_type, consts = trace_to_jaxpr_dynamic(lu.annotate(f, in_type), in_type) return jaxpr, out_type, consts @@ -2457,8 +2342,7 @@ def trace_to_jaxpr_dynamic( # rooted at the enclosing jaxpr. with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): source_info = source_info_util.current() - in_tracers = _input_type_to_tracers( - partial(trace.new_arg, source_info=source_info), in_avals) + in_tracers = map(partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] with core.set_current_trace(trace): @@ -2516,176 +2400,6 @@ def _check_no_returned_refs( f"a mutable array reference of type {a.str_short()}{loc}, but " f"mutable array references cannot be returned.{origin_info}") -@profiler.annotate_function -def trace_to_jaxpr_dynamic2( - fun: lu.WrappedFun, - ) -> tuple[Jaxpr, OutputType, list[Any]]: - assert fun.in_type is not None, "fun must be annotated with lu.annotate()" - config.enable_checks.value and fun.debug_info.assert_arg_names(len(fun.in_type)) - - parent_trace = core.trace_ctx.trace - trace = DynamicJaxprTrace(fun.debug_info, parent_trace=parent_trace) - with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): - source_info = source_info_util.current() - in_avals, keep_inputs = unzip2(fun.in_type) - in_tracers = _input_type_to_tracers( - partial(trace.new_arg, source_info=source_info), in_avals) - in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - with core.set_current_trace(trace): - ans = fun.call_wrapped(*in_tracers) - out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans) - jaxpr = trace.frame.to_jaxpr2(out_tracers, fun.debug_info) - del trace, in_tracers, out_tracers, ans - return jaxpr - -AbstractedAxisName = Hashable -AbstractedAxesSpec = Union[ - dict[int, AbstractedAxisName], - tuple[AbstractedAxisName, ...], -] - -@register_static -class DoesNotExist: ... -dne_sentinel = DoesNotExist() - - -def infer_lambda_input_type( - axes_specs: Sequence[AbstractedAxesSpec] | None, - args: Sequence[Any] - ) -> InputType: - ndims = [getattr(get_aval(x), 'ndim', 0) for x in args] - partial_specs = _canonicalize_specs(ndims, axes_specs) - specs = _complete_specs(args, partial_specs) - idxs, implicit_types = _collect_implicit(args, specs) - implicit_sig = [(ty, False) for ty in implicit_types] - explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)] - input_type = (*implicit_sig, *explicit_sig) - lu._check_input_type(input_type) - return input_type - -def _spec_to_dict(spec: AbstractedAxesSpec) -> dict[int, AbstractedAxisName]: - if isinstance(spec, tuple): - return {i: d for i, d in enumerate(spec) if d is not None} - else: - return spec - -def _canonicalize_specs( - ndims: Sequence[int], specs: Sequence[AbstractedAxesSpec] | None - ) -> list[dict[int, AbstractedAxisName]]: - if specs is None: - return [{}] * len(ndims) - else: - return [_spec_to_dict(s) for n, s in zip(ndims, specs)] - -def _complete_specs( - args: Sequence[Any], partial_specs: list[dict[int, AbstractedAxisName]] - ) -> list[dict[int, AbstractedAxisName]]: - # The abstracted axes specification in `partial_specs` is partial in the sense - # that there could be additional axis abstraction represented in `args` due to - # Tracers existing in the shapes of elements of `args`. The purpose of this - # function is to produce a full specification, for each argument mapping any - # abstracted axis positions to a name, introducing new names as needed for - # Tracers in axis sizes which don't already correspond to abstracted axis - # names (with one new name per unique Tracer object id). - - # Identify each user-supplied name in partial_specs with a size. - sizes: dict[AbstractedAxisName, int | DynamicJaxprTracer] = {} - for x, spec in zip(args, partial_specs): - for i, name in spec.items(): - d = sizes.setdefault(name, x.shape[i]) - if d is not x.shape[i] and d != x.shape[i]: - raise TypeError(f"Provided size {d} for {name} does not match prior associated name for {name} : {x.shape[i]}") - - # Introduce new names as needed for Tracers in shapes. - named_tracers: dict[TracerId, AbstractedAxisName] = { - id(d): name for name, d in sizes.items() if isinstance(d, Tracer)} - specs: list[dict[int, AbstractedAxisName]] = [] - for x, spec in zip(args, partial_specs): - if isinstance(get_aval(x), DShapedArray): - spec = dict(spec) - for i, d in enumerate(x.shape): - if isinstance(d, Tracer): - spec[i] = named_tracers.get(id(d), TracerAsName(d)) - specs.append(spec) - - # Assert that `specs` is now complete in the sense that there are no Tracers - # which don't correspond to an AbstractedAxisName. - assert all(not spec or not any(isinstance(d, Tracer) and i not in spec - for i, d in enumerate(x.shape)) - for x, spec in zip(args, specs)) - return specs - - -def _collect_implicit( - args: Sequence[Any], specs: list[dict[int, AbstractedAxisName]] - ) -> tuple[dict[AbstractedAxisName, DBIdx], list[AbstractValue]]: - # Given an explicit argument list and a specification of abstracted axes, we - # want to produce an InputType by identifying AbstractedAxisNames with DBIdxs - # and figuring out which AbstractedAxisNames correspond to implicit arguments. - - idxs: dict[AbstractedAxisName, DBIdx] = {} - implicit_types: list[AbstractValue] = [] - explicit_tracers: dict[TracerId, int] = {} - counter = it.count() - - # Add implicit arguments to idxs. - for explicit_idx, (x, spec) in enumerate(zip(args, specs)): - for i, name in spec.items(): - if name not in idxs and id(x.shape[i]) not in explicit_tracers: - idxs[name] = DBIdx(next(counter)) - implicit_types.append(get_aval(x.shape[i])) - if isinstance(x, Tracer): - explicit_tracers.setdefault(id(x), explicit_idx) # use the first - - # Now that we know the implicit args, add explicit args to idxs. - offset = len(implicit_types) - for x, spec in zip(args, specs): - for i, name in spec.items(): - if id(x.shape[i]) in explicit_tracers: - idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])])) - - return idxs, implicit_types - -def _arg_type( - idxs: dict[AbstractedAxisName, DBIdx], x: Any, - spec: dict[int, AbstractedAxisName] - ) -> AbstractValue: - # Produce an AbstractValue by substituting DBIdxs for AbstractedAxisNames. - aval = get_aval(x) # aval.shape could contain Tracers - if not spec: return aval - shape: list[int | DBIdx] = [idxs[spec[i]] if i in spec else d - for i, d in enumerate(aval.shape)] - assert not any(isinstance(d, Tracer) for d in shape) - return DShapedArray(tuple(shape), aval.dtype, False) - -def _add_implicit_outputs(jaxpr: Jaxpr) -> tuple[Jaxpr, OutputType]: - invars = [*jaxpr.constvars, *jaxpr.invars] - expl_outvars = jaxpr.outvars - - # First do a pass to collect implicit outputs, meaning variables which occur - # in explicit_outvars types but not in invars or to the left in outvars. - seen: set[Var] = set(invars) - impl_outvars = [seen.add(d) or d for x in expl_outvars if type(x) is Var and # type: ignore - (seen.add(x) or type(x.aval) is DShapedArray) # type: ignore - for d in x.aval.shape if type(d) is Var and d not in seen] - outvars = [*impl_outvars, *expl_outvars] - - # Now assemble an OutputType by mapping vars in shapes to InDBIdx/OutDBIdx. - in_map : dict[Var, InDBIdx] = {v: InDBIdx(i) for i, v in enumerate( invars)} - out_map: dict[Var, OutDBIdx] = {x: OutDBIdx(i) for i, x in enumerate(outvars) - if type(x) is Var} - out_avals_ = (x.aval for x in outvars) - out_avals = [a.update(shape=tuple(in_map.get(d, out_map.get(d)) - if type(d) is Var else d for d in a.shape)) - if type(a) is DShapedArray else a for a in out_avals_] - kept_outs = [False] * len(impl_outvars) + [True] * len(expl_outvars) - out_type = tuple(zip(out_avals, kept_outs)) - - new_jaxpr = jaxpr.replace(outvars=outvars) - config.enable_checks.value and core.check_jaxpr(jaxpr) - return new_jaxpr, out_type - - class TracerAsName: ref: Any def __init__(self, tracer): @@ -2695,155 +2409,9 @@ def __eq__(self, other): def __hash__(self): return id(self.ref) -def _extract_implicit_args( - trace: DynamicJaxprTrace, in_type: Sequence[tuple[AbstractValue, bool]], - explicit_tracers: Sequence[DynamicJaxprTracer], source_info: SourceInfo, - ) -> Sequence[DynamicJaxprTracer]: - # First, construct a list to represent the full argument list, leaving the - # implicit arguments as Nones for now. - explicit_tracers_ = iter(explicit_tracers) - tracers = [next(explicit_tracers_) if expl else None for _, expl in in_type] - assert next(explicit_tracers_, None) is None - del explicit_tracers_ - - # Next, populate the implicit arguments using DBIdxs in in_type. - for i, (aval, explicit) in enumerate(in_type): - if not explicit or not isinstance(aval, DShapedArray): - continue # can't populate an implicit argument - tracer = tracers[i] - assert tracer is not None - for d1, d2 in zip(aval.shape, tracer.aval.shape): - if isinstance(d1, DBIdx): - if tracers[d1.val] is None: - tracers[d1.val] = trace.to_jaxpr_tracer(d2, source_info) - assert tracers[d1.val] is trace.to_jaxpr_tracer(d2, source_info) - assert all(t is not None for t in tracers) - return [t for t, (_, e) in zip(tracers, in_type) if not e] # type: ignore - -def _input_type_to_tracers( - new_arg: Callable[[AbstractValue | core.AvalQDD], Tracer], - in_avals: Sequence[AbstractValue | core.AvalQDD] - ) -> Sequence[Tracer]: - # Create input Tracers given input AbstractValues, each of which can contain - # DeBruijn indices which refer to positions in the input argument list. That - # is, each element `a` of `in_avals` can have DBIdx instances in its shape, - # which must refer to positions left of `a`'s. - in_tracers: list[Tracer] = [] - - def _substitute_tracers_in_aval(a): - if isinstance(a, DShapedArray) and any(type(d) is DBIdx for d in a.shape): - shape = [in_tracers[d.val] if type(d) is DBIdx else d for d in a.shape] - return a.update(shape=tuple(shape)) - return a - - for a in in_avals: - in_tracers.append(new_arg(_substitute_tracers_in_aval(a))) - return in_tracers - Const = Any Val = Any -def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const] - ) -> tuple[Jaxpr, list[Const]]: - bounds = {v: v.aval.dtype.bound for v in jaxpr.invars - if isinstance(v.aval, (core.ShapedArray, core.DShapedArray)) and - type(v.aval.dtype) is core.bint and not v.aval.shape} - idxs = {v: DBIdx(i) for i, v in enumerate(jaxpr.invars)} - - def substitute(aval: AbstractValue) -> AbstractValue: - if (isinstance(aval, (core.ShapedArray, core.DShapedArray)) - and type(aval.dtype) is core.bint and not aval.shape): - return ShapedArray((), dtypes.scalar_type_to_dtype(int)) - elif isinstance(aval, DShapedArray): - shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape] # type: ignore - typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray - return typ(tuple(shape), aval.dtype, aval.weak_type) - else: - return aval - - in_avals = [substitute(v.aval) for v in jaxpr.invars] - eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts), - debug_info=jaxpr.debug_info) - padded_jaxpr, _, padded_consts = trace_to_jaxpr_dynamic(eval_padded, in_avals) - return padded_jaxpr, padded_consts - -class BoundedAxisSize(NamedTuple): - val: int | DynamicJaxprTracer - bound: int - -def _eval_jaxpr_padded( - jaxpr: Jaxpr, consts: Sequence[Const], *args: DynamicJaxprTracer - ) -> list[Const | DynamicJaxprTracer]: - env: dict[Var, Val] = {} - - def read(x): - return x.val if type(x) is Literal else env[x] - - def write(v, val) -> None: - env[v] = val - - foreach(write, jaxpr.constvars, consts) - foreach(write, jaxpr.invars, args) - last_used = core.last_used(jaxpr) - for eqn in jaxpr.eqns: - in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars] - out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars] - rule = padding_rules[eqn.primitive] - outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params) - foreach(write, eqn.outvars, outs) - core.clean_up_dead_vars(eqn, env, last_used) - return map(read, jaxpr.outvars) - -def _substitute_axis_sizes(env: dict, aval: AbstractValue) -> AbstractValue: - if isinstance(aval, DShapedArray): - shp = [] - for d in aval.shape: - if isinstance(d, core.DArray): - assert not d.shape and type(d.dtype) is core.bint - shp.append(BoundedAxisSize(int(d._data), int(d.dtype.bound))) - elif (type(d) is core.Var and isinstance(d.aval, core.DShapedArray) and - type(d.aval.dtype) is core.bint): - assert not d.aval.shape - shp.append(BoundedAxisSize(env[d], d.aval.dtype.bound)) - else: - shp.append(env.get(d, d)) - return DShapedArray(tuple(shp), aval.dtype, aval.weak_type) - else: - return aval - -def _is_bint_axis_size(d: int | core.DArray | core.Var) -> bool: - if isinstance(d, core.DArray): - assert not d.shape # pytype: disable=attribute-error - return type(d.dtype) is core.bint # pytype: disable=attribute-error - elif isinstance(d, core.Var): - return (isinstance(d.aval, core.DShapedArray) and # pytype: disable=attribute-error - type(d.aval.dtype) is core.bint) # pytype: disable=attribute-error - return False - - -padding_rules: dict[Primitive, Callable] = {} - -def def_trivial_padding(prim: Primitive) -> None: - if prim.multiple_results: - padding_rules[prim] = partial(_trivial_padding_rule_multi, prim) - else: - padding_rules[prim] = partial(_trivial_padding_rule, prim) - -def _trivial_padding_rule(prim, _, __, *args, **params): - return [prim.bind(*args, **params)] - -def _trivial_padding_rule_multi(prim, _, __, *args, **params): - return prim.bind(*args, **params) - -def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): - if call_jaxpr.constvars: raise NotImplementedError - padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ()) - if padded_consts: raise NotImplementedError - new_params = dict(params, call_jaxpr=padded_jaxpr) - subfuns, bind_params = prim.get_bind_params(new_params) - return prim.bind(*subfuns, *args, **bind_params) - - def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): if instantiate: return trace.instantiate_const(tracer) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 0570b083a77e..54ff65bd0123 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -51,7 +51,6 @@ from jax._src import util from jax._src import xla_bridge as xb from jax._src.abstract_arrays import array_types -from jax._src.core import DShapedArray from jax._src.core import ShapedArray from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -228,11 +227,6 @@ def _shard_typed_scalar(xs, shardings, layouts, copy_semantics): for _t in literals.typed_scalar_types: shard_arg_handlers[_t] = _shard_typed_scalar -def _shard_darray(xs, shardings, layouts, copy_semantics): - bufs = [x._data for x in xs] - return shard_args(shardings, layouts, copy_semantics, bufs) -shard_arg_handlers[core.DArray] = _shard_darray - def _shard_mutable_array(xs, shardings, layouts, copy_semantics): bufs = [x._refs._buf for x in xs] return shard_args(shardings, layouts, copy_semantics, bufs) @@ -2460,7 +2454,7 @@ def _to_logical_sharding( return None if isinstance(sharding, AUTO): return sharding - elif isinstance(aval, (ShapedArray, DShapedArray, AbstractRef)): + elif isinstance(aval, (ShapedArray, AbstractRef)): assert isinstance(sharding, JSharding) return sharding elif isinstance(aval, core.AbstractToken): diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 0a3b103e9e24..f904c08dd2db 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -999,7 +999,6 @@ def _cond_typecheck(bind_time, *in_atoms, branches, **params): core.custom_typechecks[cond_p] = partial(_cond_typecheck, False) pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom pe.dce_rules[cond_p] = _cond_dce_rule -batching.ragged_prop_rules[cond_p] = batching.ragged_mask_assert_no_op_rule def _cond_is_high(*_, branches, **__) -> bool: return any(j.jaxpr.is_high for j in branches) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 0ca5cf773743..a9a8f5a2b74e 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1299,9 +1299,6 @@ def _scan_batching_rule(axis_data, args, def _cached_scan_pad_jaxpr(jaxpr): return ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)) -def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): - return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params) - def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: if not any(used_outputs) and not pe.has_effects(eqn): @@ -1598,7 +1595,6 @@ def rearrange(lst): batching.fancy_primitive_batchers[scan_p] = _scan_batching_rule core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom -pe.padding_rules[scan_p] = _scan_padding_rule pe.dce_rules[scan_p] = _scan_dce_rule state_discharge.register_partial_discharge_rule(scan_p)(_scan_state_partial_discharge_rule) @@ -2516,7 +2512,7 @@ def _pred_bcast_select_hlo(ctx, pred_aval.shape, x_y_aval) x_y_aval = core.physical_aval(x_y_aval) bcast_pred = mlir.broadcast_in_dim( - ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)), + ctx, pred, core.ShapedArray(x_y_aval.shape, np.dtype(np.bool_)), broadcast_dimensions=list(range(len(pred_aval.shape)))) return hlo.SelectOp(bcast_pred, x, y).results diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 689b76531cda..b75c719ef098 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -56,14 +56,12 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla -from jax._src.interpreters.batching import RaggedAxis from jax._src.lax import slicing from jax._src.lax import utils as lax_utils from jax._src.mesh import get_abstract_mesh, get_concrete_mesh from jax._src.lax.utils import ( input_dtype, dtype_to_string, standard_multi_result_abstract_eval, standard_primitive) -from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -104,11 +102,7 @@ def _check_static_shape(shape: Shape): raise TypeError(msg) assert shapes - if config.dynamic_shapes.value: - # pass dynamic shapes through unchecked - return - else: - foreach(_check_static_shape, shapes) + foreach(_check_static_shape, shapes) def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]: """ @@ -246,39 +240,6 @@ def broadcast_shardings(*avals): def _identity(x, **_): return x -def _extract_tracers_dyn_shape( - shape: Sequence[int | core.Tracer] - ) -> tuple[list[core.Tracer], list[int | None]]: - # Given a sequence representing a shape, pull out Tracers, replacing with None - if config.dynamic_shapes.value: - # We must gate this behavior under a flag because otherwise the errors - # raised are different (and have worse source provenance information). - dyn_shape = [d for d in shape if isinstance(d, core.Tracer)] - static_shape = [None if isinstance(d, core.Tracer) else d for d in shape] - return dyn_shape, static_shape - else: - return [], list(shape) # type: ignore - -def _merge_dyn_shape( - static_shape: Sequence[int | None], - dyn_shape: Sequence[Any], - ) -> tuple[int | mlir.Value | core.Tracer, ...]: - # Replace Nones in static_shape with elements of dyn_shape, in order - dyn_shape_it = iter(dyn_shape) - shape = tuple(next(dyn_shape_it) if d is None else d for d in static_shape) - assert next(dyn_shape_it, None) is None - return shape - -def _dyn_shape_staging_rule(trace, source_info, prim, out_aval, *args, - **params): - var = trace.frame.newvar(out_aval) - eqn = pe.new_jaxpr_eqn([x.val for x in args], - [var], - prim, params, core.no_effects, source_info) - out_tracer = pe.DynamicJaxprTracer(trace, out_aval, var, source_info) - trace.frame.add_eqn(eqn) - return out_tracer - ### traceables @@ -2749,14 +2710,8 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array) and out_sharding is None): return operand - if config.dynamic_shapes.value: - # We must gate this behavior under a flag because otherwise the errors - # raised are different (and have worse source provenance information). - dyn_shape, static_shape = _extract_tracers_dyn_shape(shape) - else: - dyn_shape, static_shape = [], shape # type: ignore return broadcast_in_dim_p.bind( - operand, *dyn_shape, shape=tuple(static_shape), + operand, shape=tuple(shape), broadcast_dimensions=tuple(broadcast_dimensions), sharding=out_sharding) @@ -2824,9 +2779,8 @@ def reshape(operand: ArrayLike, new_sizes: Shape, isinstance(operand, Array)): return operand else: - dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) return reshape_p.bind( - operand, *dyn_shape, new_sizes=tuple(static_new_sizes), + operand, new_sizes=tuple(new_sizes), dimensions=None if dims is None or same_dims else dims, sharding=out_sharding) @@ -3452,12 +3406,10 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, """Convenience wrapper around ``iota``.""" dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "broadcasted_iota") shape = canonicalize_shape(shape) - dynamic_shape = [d for d in shape if isinstance(d, core.Tracer)] - static_shape = [None if isinstance(d, core.Tracer) else d for d in shape] dimension = core.concrete_or_error( int, dimension, "dimension argument of lax.broadcasted_iota") out_sharding = canonicalize_sharding(out_sharding, 'broadcasted_iota') - return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), + return iota_p.bind(dtype=dtype, shape=shape, dimension=dimension, sharding=out_sharding) def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array: @@ -3915,7 +3867,6 @@ def _iter(tracer): else: return (slicing.index_in_dim(tracer, i, keepdims=False) for i in range(n)) ShapedArray._iter = staticmethod(_iter) -core.DShapedArray._iter = staticmethod(_iter) def zeros_like_array(x: ArrayLike) -> Array: return full_like(x, 0) @@ -3977,7 +3928,6 @@ def unop(result_dtype, accepted_dtypes, name, supports_narrow_ints=True): vma_rule=_attrgetter('vma'), reduced_rule=unop_reduced_rule) batching.defvectorized(prim) - pe.def_trivial_padding(prim) return prim standard_unop = partial(unop, _identity) @@ -4102,7 +4052,6 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, vma_rule=partial(core.standard_vma_rule, name), unreduced_rule=unreduced_rule, reduced_rule=nary_reduced_rule) batching.defbroadcasting(prim) - pe.def_trivial_padding(prim) return prim standard_naryop = partial(naryop, input_dtype) @@ -4110,7 +4059,7 @@ def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, # Like autograd.numpy.numpy_vjps.unbroadcast, this utility handles transposition # involving linear primitives with implicit broadcasting. def _unbroadcast(aval, x): - if not isinstance(aval, (core.DShapedArray, ShapedArray)): + if not isinstance(aval, ShapedArray): raise TypeError("transpose with implicit broadcasting of unshaped values") x_shape = np.shape(x) if (core.definitely_equal_shape(aval.shape, x_shape) and @@ -4243,7 +4192,6 @@ def _round_lower(ctx, x, *, rounding_method): exp_p = standard_unop(_float | _complex, 'exp') ad.defjvp2(exp_p, lambda g, ans, x, **kwargs: mul(g, ans)) mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) -batching.ragged_prop_rules[exp_p] = batching.ragged_mask_elementwise_rule core.pp_eqn_rules[exp_p] = _unary_with_accuracy_pp_rule exp2_p = standard_unop(_float | _complex, 'exp2') @@ -4354,7 +4302,6 @@ def _sin_lin(nzs, x, accuracy): ad.primitive_linearizations[sin_p] = _sin_lin mlir.register_lowering(sin_p, _sin_lowering) core.pp_eqn_rules[sin_p] = _unary_with_accuracy_pp_rule -batching.ragged_prop_rules[sin_p] = batching.ragged_mask_elementwise_rule def _cos_complex(x): # cos(x) = complex(cos(real(x)) * cosh(imag(x)), -sin(real(x)) * sinh(imag(x))) @@ -4598,7 +4545,6 @@ def _integer_pow_jvp(g, x, *, y): sharding_rule=_attrgetter('sharding'), vma_rule=_attrgetter('vma')) batching.defvectorized(integer_pow_p) ad.defjvp(integer_pow_p, _integer_pow_jvp) -pe.def_trivial_padding(integer_pow_p) def _integer_pow(x, *, y): # This should be kept in sync with the jax2tf translation rule. @@ -4713,7 +4659,6 @@ def _add_unreduced_rule(out_sharding, x, y): ad.primitive_jvps[add_p] = _add_jvp ad.primitive_transposes[add_p] = _add_transpose mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add)) -batching.ragged_prop_rules[add_p] = batching.ragged_mask_elementwise_rule def _sub_jvp(primals, tangents): x, y = primals @@ -4743,7 +4688,6 @@ def _sub_transpose(t, x, y): ad.primitive_jvps[sub_p] = _sub_jvp ad.primitive_transposes[sub_p] = _sub_transpose mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract)) -batching.ragged_prop_rules[sub_p] = batching.ragged_mask_elementwise_rule def _mul_unreduced_rule(out_sharding, x, y): x_ur, y_ur = x.sharding.spec.unreduced, y.sharding.spec.unreduced @@ -4781,7 +4725,6 @@ def _mul_unreduced_rule(out_sharding, x, y): ad.defbilinear(mul_p, lambda ct, x, y: _unbroadcast(x.aval, mul(ct, y)), lambda ct, x, y: _unbroadcast(y.aval, mul(x, ct))) mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply)) -batching.ragged_prop_rules[mul_p] = batching.ragged_mask_elementwise_rule def _div_transpose_rule(cotangent, x, y): assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y) @@ -4795,7 +4738,6 @@ def _div_transpose_rule(cotangent, x, y): lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2))) ad.primitive_transposes[div_p] = _div_transpose_rule mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide)) -batching.ragged_prop_rules[div_p] = batching.ragged_mask_elementwise_rule rem_p = standard_naryop([_int | _float, _int | _float], 'rem') ad.defjvp( @@ -4819,14 +4761,12 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo)) -batching.ragged_prop_rules[max_p] = batching.ragged_mask_elementwise_rule min_p: core.Primitive = standard_naryop([_any, _any], 'min') ad.defjvp2(min_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo)) -batching.ragged_prop_rules[min_p] = batching.ragged_mask_elementwise_rule shift_left_p = standard_naryop([_int, _int], 'shift_left') ad.defjvp_zero(shift_left_p) @@ -4893,7 +4833,6 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True) ad.defjvp_zero(eq_p) mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False)) -batching.ragged_prop_rules[eq_p] = batching.ragged_mask_elementwise_rule ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True) ad.defjvp_zero(ne_p) @@ -4914,7 +4853,6 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt') ad.defjvp_zero(lt_p) mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False)) -batching.ragged_prop_rules[lt_p] = batching.ragged_mask_elementwise_rule eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to') ad.defjvp_zero(eq_to_p) @@ -5067,11 +5005,7 @@ def _convert_element_type_batching_rule( pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule pe.forwarding_rules[convert_element_type_p] = _convert_elt_type_fwd_rule -pe.def_trivial_padding(convert_element_type_p) core.pp_eqn_rules[convert_element_type_p] = _convert_elt_type_pp_rule -batching.ragged_prop_rules[convert_element_type_p] = ( - batching.ragged_mask_elementwise_rule -) def _real_dtype(dtype): return np.finfo(dtype).dtype @@ -5094,7 +5028,7 @@ def _to_edtype_abstract_eval(x, *, edtype): not isinstance(x.dtype, dtypes.ExtendedDType)) # For backward compatibility, if the edtype rules have a `convert_to` method, # use that rather than looking for an `allow_conversion: bool` attribute. - if not isinstance(x, (ShapedArray, core.DShapedArray)): + if not isinstance(x, ShapedArray): raise TypeError("can only convert to an extended dtype on an array type," f"but got {type(x)}") if convert_to := getattr(edtype._rules, 'convert_to', None): @@ -5139,8 +5073,6 @@ def _to_edtype_abstract_eval(x, *, edtype): f"shape {rep_aval.shape}") return x.update(shape=shape_prefix, dtype=edtype, sharding=x.sharding.update(spec=spec_prefix)) - elif isinstance(x, core.DShapedArray): - return x.update(shape=shape_prefix, dtype=edtype) else: assert False # unreachable, see isinstance check above @@ -5159,7 +5091,7 @@ def _to_edtype_abstract_eval(x, *, edtype): def _from_edtype_abstract_eval(x, *, dtype): assert (isinstance(x.dtype, dtypes.ExtendedDType) and not isinstance(dtype, dtypes.ExtendedDType)) - if not isinstance(x, (ShapedArray, core.DShapedArray)): + if not isinstance(x, ShapedArray): raise TypeError("can only convert from an extended dtype on an array type," f"but got {type(x)}") if convert_from := getattr(x.dtype._rules, 'convert_from', None): @@ -5180,11 +5112,6 @@ def _from_edtype_abstract_eval(x, *, dtype): f"{dtype_to_string(rep_aval.dtype)}.") if isinstance(x, ShapedArray): return x.update(shape=(*x.shape, *rep_aval.shape), dtype=dtype) - elif isinstance(x, core.DShapedArray): - if all(isinstance(d, int) for d in x.shape): - return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype) - else: - raise NotImplementedError else: assert False # unreachable, see isinstance check above @@ -5562,30 +5489,14 @@ def _dot_batch_rule( lhs, rhs = unpack_args(batched_args) lbd, rbd = unpack_dims(batch_dims) - left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd - right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd new_dimension_numbers, result_stack_dim = _dot_general_batch_dim_nums( - (np.ndim(lhs), np.ndim(rhs)), (left_stack_dim, right_stack_dim), + (np.ndim(lhs), np.ndim(rhs)), (lbd, rbd), dimension_numbers) - # TODO Should probably check that any ragged dimensions have corresponding - # sizes, because otherwise the dot product is technically undefined. - # - # This masking is not strictly necessary for non-contraction dimensions; - # we could micro-optimize here by avoiding computing that mask. - if type(lbd) is RaggedAxis: - lhs = batching.mask_ragged_axes(lhs, _get_sum_identity, lbd) - lhs_shape = batching.bdim_as_shape(lbd, lhs.shape) - else: - lhs_shape = np.shape(lhs) - if type(rbd) is RaggedAxis: - rhs = batching.mask_ragged_axes(rhs, _get_sum_identity, rbd) - rhs_shape = batching.bdim_as_shape(rbd, rhs.shape) - else: - rhs_shape = np.shape(rhs) - result_batch_dim = batching.shape_as_bdim( - result_stack_dim, - _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers)) + lhs_shape = np.shape(lhs) + rhs_shape = np.shape(rhs) + result_shape = _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers) + result_batch_dim = canonicalize_axis(result_stack_dim, len(result_shape)) if out_sharding is not None: out_sharding = batching.get_sharding_for_vmap( @@ -5673,15 +5584,6 @@ def bump_dims(dims, b): ) return new_dimension_numbers, result_batch_dim -def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *, - dimension_numbers, **params): - lhs_aval, _ = in_avals - (lhs_contract, _), _ = dimension_numbers - padded_axes = [(i, lhs_aval.shape[i].val) for i in lhs_contract - if isinstance(lhs_aval.shape[i], pe.BoundedAxisSize)] - lhs_ = _replace_masked_values(lhs, 0, padded_axes) - return [dot_general(lhs_, rhs, dimension_numbers=dimension_numbers, **params)] - def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: # * suppress printing precision or preferred_element_type when None. # * print dimension_numbers as list-of-lists to be shorter. @@ -5692,59 +5594,6 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc: return core._pp_eqn(eqn.replace(params=printed_params), context, settings) -def _dot_general_ragged_prop_rule(eqn_params, invar_raggedness, outvars): - assert len(invar_raggedness) == 2 - assert len(outvars) == 1 - invar_raggedness_lhs = invar_raggedness[0] - invar_raggedness_rhs = invar_raggedness[1] - - dimension_numbers = eqn_params['dimension_numbers'] - (lhs_contracting, rhs_contracting), (_, _) = dimension_numbers - - if not invar_raggedness_lhs and not invar_raggedness_rhs: - # Both are dense - it is valid to reach here, because dense operations - # are legal in code running under ragged prop. - return invar_raggedness, [None] - - if not invar_raggedness_lhs or not invar_raggedness_rhs: - # One ragged, one dense - if not invar_raggedness_lhs: - # left is dense, right is ragged - _, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs - if rhs_contracting != ragged_axis_dim_rhs: - # Contraction is on a dense dimension, this is valid! - return invar_raggedness, [None] - if not invar_raggedness_rhs: - # left is ragged, right is dense - _, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs - if lhs_contracting != ragged_axis_dim_lhs: - # Contraction is on a dense dimension, this is valid! - return invar_raggedness, [None] - - raise NotImplementedError('NYI - dense and ragged dim contraction') - - stacked_axis_lhs, ragged_axis_dim_lhs, _, _ = invar_raggedness_lhs - stacked_axis_rhs, ragged_axis_dim_rhs, _, _ = invar_raggedness_rhs - - if stacked_axis_rhs != 0 or stacked_axis_lhs != 0: - raise NotImplementedError( - 'Dot general ragged prop for non 0 stacked axis, NYI' - ) - - # We only support ragged k atm, that is, lhs is (m, ragged_k) and rhs is - # (ragged_k, n), meaning the output is dense. - if ragged_axis_dim_lhs != 2 or ragged_axis_dim_rhs != 1: - raise NotImplementedError( - 'Dot general ragged prop for non contraction raggedness, NYI' - ) - - assert len(outvars) == 1 - - # TODO(mvoz): A constant on batching.* ? - # Dense (m, n) - no jumble only atm - return invar_raggedness, [None] - - dot_general_p = standard_primitive( _dot_general_shape_rule, _dot_general_dtype_rule, @@ -5776,9 +5625,7 @@ def _dot_general_batch_unpack_dims(batch_dims): ) batching.fancy_primitive_batchers[dot_general_p] = _dot_general_batch_rule batching.skippable_batchers[dot_general_p] = lambda _: () -pe.padding_rules[dot_general_p] = _dot_general_padding_rule core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule -batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule def _full_precision(precision: Precision) -> tuple[Precision, Precision]: @@ -6582,92 +6429,46 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions, spec=operand.sharding.spec.update(partitions=new_spec)) def _broadcast_in_dim_typecheck_rule( - _, operand, *dyn_shape, shape, broadcast_dimensions, sharding): - if not dyn_shape: - out_aval, effects = broadcast_in_dim_p.abstract_eval( - operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=sharding) - return [out_aval], effects - else: - # TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule - out_shape = _merge_dyn_shape(shape, dyn_shape) - out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error - out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype, - operand.aval.weak_type) - return [out_aval], core.no_effects - -def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape, + _, operand, shape, broadcast_dimensions, sharding): + out_aval, effects = broadcast_in_dim_p.abstract_eval( + operand.aval, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) + return [out_aval], effects + +def _broadcast_in_dim_transpose_rule(ct, operand, shape, broadcast_dimensions, sharding): if type(ct) is ad_util.Zero: return [ad_util.Zero(operand.aval)] if not isinstance(operand, ad.UndefinedPrimal): - return [None] * (1 + len(dyn_shape)) # transpose wrt literal + return [None] # transpose wrt literal unit_dims = [i for i, s in enumerate(operand.aval.shape) if core.definitely_equal(s, 1)] bdims = tuple(np.delete(broadcast_dimensions, unit_dims)) axes = tuple(np.delete(range(len(shape)), bdims)) - return ([expand_dims(reduce_sum(ct, axes), unit_dims)] + - [None] * len(dyn_shape)) + return [expand_dims(reduce_sum(ct, axes), unit_dims)] def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape, broadcast_dimensions, sharding): - # `dyn_shape` is the dynamic portion of the target shape. `shape` - # is the target shape, with `None` for dynamic sections. - # broadcast_dimensions gives indices where dimensions of the input - # have to go: dimension i of the input becomes dimension - # broadcast_dimensions[i] of the output. - operand, *dyn_shape = batched_args - operand_bdim, *dyn_shape_bdims = batch_dims - - stacked_size = None - if operand_bdim is not None: - if isinstance(operand_bdim, RaggedAxis): - stacked_axis = operand_bdim.stacked_axis - stacked_size = operand_bdim.size - else: - stacked_axis = operand_bdim - stacked_size = operand.shape[stacked_axis] - new_operand = batching.moveaxis(operand, stacked_axis, 0) - new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions)) - else: - new_operand = operand - new_broadcast_dimensions = tuple(np.add(1, broadcast_dimensions)) - - # TODO(mattjj,axch) This section assumes that the shape of the operand is - # broadcast-compatible with the requested shape. We should tweak vmap to run - # the abstract_eval rule so this can be checked while the raggedness - # information is available. - dyn_limits = [] - out_ragged_sizes = [] - for sizes, bdim in zip(dyn_shape, dyn_shape_bdims): - if bdim is None: - # TODO(mattjj,axch) Is this what bdim == None means? - assert isinstance(sizes, int) - bound = sizes - else: - bound = sizes.dtype.bound - out_ragged_sizes.append(sizes) - if stacked_size is None: - stacked_size = len(sizes) - else: - msg = "All segments lengths arrays must be the same length" - assert len(sizes) == stacked_size, msg - dyn_limits.append(bound) - new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits) + # `shape` is the target shape. broadcast_dimensions gives indices where + # dimensions of the input have to go: dimension i of the input becomes + # dimension broadcast_dimensions[i] of the output. + operand, = batched_args + operand_bdim, = batch_dims + assert operand_bdim is not None + new_operand = batching.moveaxis(operand, operand_bdim, 0) + new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions)) + new_shape = (operand.shape[operand_bdim],) + shape if sharding is not None: sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0) result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions, out_sharding=sharding) - out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None] - out_bdim = batching.make_batch_axis( - result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes)) - return result, out_bdim + return result, 0 def _broadcast_in_dim_fwd_rule(eqn): - v, *dyn = eqn.invars - if (not dyn and core.definitely_equal_shape(eqn.params['shape'], v.aval.shape) + v, = eqn.invars + if (core.definitely_equal_shape(eqn.params['shape'], v.aval.shape) and (eqn.params['sharding'] is None or eqn.params['sharding'] == v.aval.sharding)): return [0], None @@ -6675,103 +6476,51 @@ def _broadcast_in_dim_fwd_rule(eqn): return [None], eqn def _broadcast_in_dim_staging_rule( - trace, source_info, x, *dyn, shape, broadcast_dimensions, sharding): + trace, source_info, x, shape, broadcast_dimensions, sharding): params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) - if not dyn: - return trace.default_process_primitive(broadcast_in_dim_p, (x,), params, - source_info=source_info) - aval = core.DShapedArray(_merge_dyn_shape(shape, dyn), x.dtype, x.weak_type) - return _dyn_shape_staging_rule(trace, source_info, broadcast_in_dim_p, aval, - x, *dyn, **params) - -def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape, - shape, broadcast_dimensions): - del in_avals, dyn_shape - out_aval, = out_avals - new_shape = [] - new_dyn_shape = [] - for d in out_aval.shape: - if type(d) is pe.BoundedAxisSize: - new_shape.append(d.bound) - elif type(d) is int: - new_shape.append(d) - else: - assert isinstance(d, core.Tracer) - new_shape.append(None) - new_dyn_shape.append(d) - return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape), - broadcast_dimensions=broadcast_dimensions)] + return trace.default_process_primitive(broadcast_in_dim_p, (x,), params, + source_info=source_info) def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions, sharding): - operand, *dyn_shape = primals + operand, = primals operand_dot, *_ = tangents - y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape, + y = broadcast_in_dim_p.bind(operand, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) if type(operand_dot) is ad_util.Zero: y_dot = ad_util.Zero.from_primal_value(y) else: - y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape, + y_dot = broadcast_in_dim_p.bind(operand_dot, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=sharding) return y, y_dot def _broadcast_in_dim_partial_eval( - trace, operand, *dyn_shape, shape, broadcast_dimensions, sharding): - if not dyn_shape: - return trace.default_process_primitive( - broadcast_in_dim_p, (operand, *dyn_shape), - dict(shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=sharding)) - assert all(t.pval.is_known() for t in dyn_shape) - operand_tracer = trace.instantiate_const(operand) - dyn_shape_tracers = map(trace.instantiate_const, dyn_shape) - dyn_shape_tracers_ = iter(dyn_shape_tracers) - shape_ = [next(dyn_shape_tracers_) if d is None else d for d in shape] - out_aval = core.DShapedArray(tuple(shape_), operand.dtype, operand.weak_type) - out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None) - eqn = pe.new_eqn_recipe( - trace, [operand_tracer, *dyn_shape_tracers], [out_tracer], broadcast_in_dim_p, + trace, operand, shape, broadcast_dimensions, sharding): + return trace.default_process_primitive( + broadcast_in_dim_p, (operand,), dict(shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=None), - core.no_effects, source_info_util.current()) - out_tracer.recipe = eqn - return out_tracer + sharding=sharding)) -def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, +def _broadcast_in_dim_lower(ctx, x, shape, broadcast_dimensions, sharding) -> Sequence[ir.Value]: aval_out, = ctx.avals_out - if dyn_shape: - aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) out = mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=broadcast_dimensions) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] -def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, +def _broadcast_in_dim_abstract_eval(x, shape, broadcast_dimensions, sharding): - if (not dyn_shape and - not any(isinstance(d, core.DArray) and - type(core.get_aval(d).dtype) is core.bint for d in shape)): - shape = _broadcast_in_dim_shape_rule( # error checking - x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None) - new_sharding = _broadcast_in_dim_sharding_rule( - x, shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=sharding) - new_vma = core.standard_vma_rule('broadcast_in_dim', x) - return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, - vma=new_vma, memory_space=x.memory_space) - # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray - # (even if x is a ShapedArray) - # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code - return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), x.dtype, x.weak_type) - - -def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars): - assert len(invar_raggedness) == 1 - assert not isinstance(invar_raggedness[0], core.Var) - return invar_raggedness, [None] * len(outvars) + shape = _broadcast_in_dim_shape_rule( # error checking + x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None) + new_sharding = _broadcast_in_dim_sharding_rule( + x, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) + new_vma = core.standard_vma_rule('broadcast_in_dim', x) + return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding, + vma=new_vma, memory_space=x.memory_space) broadcast_in_dim_p = core.Primitive('broadcast_in_dim') @@ -6784,12 +6533,8 @@ def _broadcast_in_dim_ragged_prop_rule(eqn_params, invar_raggedness, outvars): pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule pe.custom_partial_eval_rules[broadcast_in_dim_p] = _broadcast_in_dim_partial_eval pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule -pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower) -batching.ragged_prop_rules[broadcast_in_dim_p] = ( - _broadcast_in_dim_ragged_prop_rule -) def _clamp_shape_rule(min, operand, max): @@ -6858,7 +6603,6 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): select(lt(max, operand), g, _zeros(operand))) batching.primitive_batchers[clamp_p] = _clamp_batch_rule mlir.register_lowering(clamp_p, partial(_nary_lower_hlo, hlo.clamp)) -pe.def_trivial_padding(clamp_p) def _concatenate_shape_rule(*operands, **kwargs): dimension = kwargs.pop('dimension') @@ -6949,7 +6693,6 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): ad.deflinear2(concatenate_p, _concatenate_transpose_rule) ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule -pe.padding_rules[concatenate_p] = _concatenate_pad_rule def _concatenate_lower(ctx, *xs, dimension): aval_out, = ctx.avals_out @@ -7169,12 +6912,11 @@ def _squeeze_transpose_rule(t, operand, *, dimensions): def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): operand, = batched_args bdim, = batch_dims - operand, bdim = batching.move_stacked_axis(operand, bdim, 0) + operand = batching.moveaxis(operand, bdim, 0) dimensions = tuple(np.add(1, dimensions)) - out_stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim - bdim_out = batching.shape_as_bdim( - out_stack_dim, - _compute_squeeze_shape(batching.bdim_as_shape(bdim, operand.shape), dimensions)) + + result_shape = _compute_squeeze_shape(operand.shape, dimensions) + bdim_out = canonicalize_axis(0, len(result_shape)) return squeeze(operand, dimensions=dimensions), bdim_out squeeze_p = standard_primitive( @@ -7184,8 +6926,6 @@ def _squeeze_batch_rule(batched_args, batch_dims, *, dimensions): reduced_rule=_squeeze_reduced_rule) ad.deflinear2(squeeze_p, _squeeze_transpose_rule) batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule -pe.def_trivial_padding(squeeze_p) -batching.ragged_prop_rules[squeeze_p] = batching.ragged_mask_no_op_rule def _squeeze_lower(ctx, operand, *, dimensions): del dimensions # Implied by the output aval. @@ -7217,12 +6957,6 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions, sharding): # TODO(necula): re-enable this check operand_size = math.prod(np.shape(operand)) new_size = math.prod(new_sizes) - if (not config.dynamic_shapes.value and - not operand_size == new_size): - msg = (f"reshape total size must be unchanged, got new_sizes {new_sizes} " - f"(of total size {new_size}) for shape {np.shape(operand)} " - f"(of total size {operand_size}).") - raise TypeError(msg) if dimensions is not None: if set(dimensions) != set(range(np.ndim(operand))): msg = ('reshape dimensions must be a permutation of operand dimensions, ' @@ -7366,20 +7100,12 @@ def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions): return operand.sharding.update(spec=new_spec) -def _reshape_typecheck_rule(_, operand, *dyn_shape, new_sizes, dimensions, +def _reshape_typecheck_rule(_, operand, new_sizes, dimensions, sharding): - if not dyn_shape: - out_aval, effects = reshape_p.abstract_eval( - operand.aval, new_sizes=new_sizes, dimensions=dimensions, - sharding=sharding) - return [out_aval], effects - else: - # TODO(mattjj, necula): perform more checks like _reshape_shape_rule - out_shape = _merge_dyn_shape(new_sizes, dyn_shape) - out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error - out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype, - operand.aval.weak_type) - return [out_aval], core.no_effects + out_aval, effects = reshape_p.abstract_eval( + operand.aval, new_sizes=new_sizes, dimensions=dimensions, + sharding=sharding) + return [out_aval], effects def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding): @@ -7413,24 +7139,18 @@ def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes, return out, 0 -def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding): +def _reshape_lower(ctx, x, new_sizes, dimensions, sharding): aval_out, = ctx.avals_out if dimensions is not None: x = hlo.transpose(x, mlir.dense_int_array(dimensions)) - if dyn_shape: - aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) out = mlir.reshape(ctx, x, aval_out) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] def _reshape_staging_rule( - trace, source_info, x, *dyn, new_sizes, dimensions, sharding): + trace, source_info, x, new_sizes, dimensions, sharding): params = dict(new_sizes=new_sizes, dimensions=dimensions, sharding=sharding) - if not dyn: - return trace.default_process_primitive(reshape_p, (x,), params, - source_info=source_info) - av = core.DShapedArray(_merge_dyn_shape(new_sizes, dyn), x.dtype, x.weak_type) - return _dyn_shape_staging_rule(trace, source_info, reshape_p, av, x, *dyn, - **params) + return trace.default_process_primitive(reshape_p, (x,), params, + source_info=source_info) reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule, 'reshape', sharding_rule=_reshape_sharding_rule, @@ -7504,12 +7224,8 @@ def _transpose_reduced_rule(out_s, operand, *, permutation): def _transpose_batch_rule(batched_args, batch_dims, *, permutation): operand, = batched_args bdim, = batch_dims - stack_dim = bdim.stacked_axis if isinstance(bdim, RaggedAxis) else bdim - perm = (stack_dim,) + tuple(i if i < stack_dim else i+1 for i in permutation) - if isinstance(bdim, RaggedAxis): - res_bdim = batching.transpose_ragged_axes(bdim.move_stacked_axis(0), perm) - else: - res_bdim = 0 + perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation) + res_bdim = 0 return transpose(operand, perm), res_bdim def _transpose_lower(ctx, x, *, permutation): @@ -7531,7 +7247,6 @@ def _transpose_lower(ctx, x, *, permutation): lambda t, _, permutation: [transpose(t, np.argsort(permutation))]) batching.primitive_batchers[transpose_p] = _transpose_batch_rule mlir.register_lowering(transpose_p, _transpose_lower) -pe.def_trivial_padding(transpose_p) def _select_shape_rule(which, *cases): @@ -7716,7 +7431,6 @@ def _select(offset, cases): batching.fancy_primitive_batchers[select_n_p] = _select_batch_rule batching.skippable_batchers[select_n_p] = lambda _: () mlir.register_lowering(select_n_p, _select_hlo_lowering) -pe.def_trivial_padding(select_n_p) def _reduce_shape_rule(*avals, computation, jaxpr, dimensions): @@ -7930,9 +7644,6 @@ def _reduce_sum_reduced_rule(out_s, operand, *, axes, **kwargs): reduced_rule=_reduce_sum_reduced_rule) ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule) batching.defreducer(reduce_sum_p, _get_sum_identity) -pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, reduce_sum, - _get_sum_identity) -batching.ragged_prop_rules[reduce_sum_p] = batching.ragged_mask_elementwise_rule def _reduce_prod_jvp_rule(primals, tangents, *, axes): reducer = lambda x, y: [mul(x, y)] @@ -7952,9 +7663,6 @@ def _reduce_op_sharding_rule(operand, *, axes): vma_rule=partial(core.standard_vma_rule, 'reduce_prod')) ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule batching.defreducer(reduce_prod_p, _get_prod_identity) -pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, reduce_prod, - _get_prod_identity) - def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): # TODO(mattjj): an alternative is to use variadic reduce to compute the chosen @@ -7973,9 +7681,6 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): vma_rule=partial(core.standard_vma_rule, 'reduce_max')) ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_max_p, _get_max_identity) -pe.padding_rules[reduce_max_p] = partial(_reducer_padding, reduce_max, - _get_max_identity) -batching.ragged_prop_rules[reduce_max_p] = batching.ragged_mask_elementwise_rule reduce_min_p = standard_primitive( @@ -7984,9 +7689,6 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes): vma_rule=partial(core.standard_vma_rule, 'reduce_min')) ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule) batching.defreducer(reduce_min_p, _get_min_identity) -pe.padding_rules[reduce_min_p] = partial(_reducer_padding, reduce_min, - _get_min_identity) - def _argminmax_shape_rule(operand, *, axes, index_dtype): axis, = axes @@ -8105,7 +7807,6 @@ def _reduce_or_lin(nzs, x, *, axes): weak_type_rule=_strip_weak_type, sharding_rule=_reduce_logical_sharding_rule, vma_rule=partial(core.standard_vma_rule, 'reduce_and')) batching.defreducer(reduce_and_p, _get_bitwise_and_identity) -batching.ragged_prop_rules[reduce_and_p] = batching.ragged_mask_elementwise_rule reduce_xor_p = standard_primitive( @@ -8446,7 +8147,6 @@ def _stop_gradient_batch_rule(batched_args, batch_dims): ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule -pe.def_trivial_padding(ad_util.stop_gradient_p) def create_token(_=None): @@ -8699,7 +8399,6 @@ def _copy_impl(prim, *args, **kwargs): copy_p.def_abstract_eval(lambda x: x) mlir.register_lowering(copy_p, lambda ctx, x: [x]) ad.deflinear(copy_p, lambda t: [copy_p.bind(t)]) -pe.def_trivial_padding(copy_p) batching.defvectorized(copy_p) # The dce_sink_p primitive marks a value as "used" from the perspective of DCE @@ -8719,7 +8418,6 @@ class NoDCEEffect(effects.Effect): dce_sink_p.def_effectful_abstract_eval(lambda _: ([], {no_dce_effect})) mlir.register_lowering(dce_sink_p, lambda ctx, _: []) ad.deflinear(dce_sink_p, lambda _: []) -pe.def_trivial_padding(dce_sink_p) batching.primitive_batchers[dce_sink_p] = lambda x, bd: (x, bd) def rng_bit_generator(key, shape, dtype=np.uint32, @@ -8750,10 +8448,9 @@ def rng_bit_generator(key, shape, dtype=np.uint32, out_sharding=out_sharding)) -def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): - if not dyn_shape: - # TODO(mattjj) Generalize shape_like checking to permit dynamic shapes - _check_shapelike("iota", "shape", shape) +def _iota_abstract_eval(dtype, shape, dimension, sharding): + # TODO(mattjj) Generalize shape_like checking to permit dynamic shapes + _check_shapelike("iota", "shape", shape) if not any(dtypes.issubdtype(dtype, t) for t in _num): msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.' typename = dtype_to_string(dtype) @@ -8762,88 +8459,35 @@ def _iota_abstract_eval(*dyn_shape, dtype, shape, dimension, sharding): if not 0 <= dimension < len(shape): raise ValueError("iota dimension must be between 0 and len(shape), got " f"{dimension=} for {shape=}") - if (not dyn_shape and - not any(isinstance(d, core.DArray) and - type(core.get_aval(d).dtype) is core.bint for d in shape)): - if sharding is None: - sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape))) - return ShapedArray(shape, dtype, sharding=sharding) - # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code - return core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) - + if sharding is None: + sharding = core.get_cur_mesh_sharding(spec=core.P(*[None] * len(shape))) + return ShapedArray(shape, dtype, sharding=sharding) iota_p = Primitive('iota') iota_p.def_impl(partial(dispatch.apply_primitive, iota_p)) iota_p.def_abstract_eval(_iota_abstract_eval) -batching.ragged_prop_rules[iota_p] = batching.ragged_mask_no_op_rule -def _iota_staging_rule(trace, source_info, *dyn_shape, dtype, shape, dimension, +def _iota_staging_rule(trace, source_info, dtype, shape, dimension, sharding): params = dict(dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) - if not dyn_shape: - return trace.default_process_primitive(iota_p, (), params, + return trace.default_process_primitive(iota_p, (), params, source_info=source_info) - aval = core.DShapedArray(_merge_dyn_shape(shape, dyn_shape), dtype, False) - return _dyn_shape_staging_rule(trace, source_info, iota_p, aval, *dyn_shape, - **params) pe.custom_staging_rules[iota_p] = _iota_staging_rule -def _iota_typecheck_rule(_, *dyn_shape, dtype, shape, dimension, sharding): - if not dyn_shape: - out_aval, effects = iota_p.abstract_eval( - dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) - return [out_aval], effects - else: - out_shape = _merge_dyn_shape(shape, dyn_shape) - out_shape = [x.val if type(x) is core.Literal else x for x in out_shape] # pytype: disable=attribute-error - out_aval = core.DShapedArray(tuple(out_shape), dtype, False) - return [out_aval], core.no_effects +def _iota_typecheck_rule(_, dtype, shape, dimension, sharding): + out_aval, effects = iota_p.abstract_eval( + dtype=dtype, shape=shape, dimension=dimension, sharding=sharding) + return [out_aval], effects core.custom_typechecks[iota_p] = _iota_typecheck_rule -def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding): +def _iota_lower(ctx, dtype, shape, dimension, sharding): del dtype aval_out, = ctx.avals_out - if dyn_shape: - aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) out = mlir.iota(ctx, aval_out, dimension=dimension) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] mlir.register_lowering(iota_p, _iota_lower) -def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension, - sharding): - (segment_lengths,), (ax,) = in_vals, in_dims - assert ax == 0 - bound = segment_lengths.dtype.bound - ragged_axis, = (i for i, dim in enumerate(shape) if dim is None) - shape = (len(segment_lengths),) + _merge_dyn_shape(shape, (bound,)) - if sharding is not None: - raise NotImplementedError('Please file an issue if you want this support') - iota = broadcasted_iota(dtype, shape, dimension+1) - return iota, batching.RaggedAxis(ax, ((ragged_axis+1, segment_lengths),)) -batching.primitive_batchers[iota_p] = _iota_batching_rule - -def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension, - sharding): - out_aval, = out_avals - new_shape = [] - new_dyn_shape = [] - for d in out_aval.shape: - if type(d) is pe.BoundedAxisSize: - new_shape.append(d.bound) - elif type(d) is int: - new_shape.append(d) - else: - assert isinstance(d, core.Tracer) - new_shape.append(None) - new_dyn_shape.append(d) - if sharding is not None: - raise NotImplementedError('Please file an issue if you want this support') - return [iota_p.bind(*new_dyn_shape, shape=tuple(new_shape), - dtype=dtype, dimension=dimension, sharding=sharding)] -pe.padding_rules[iota_p] = _iota_padding_rule - - ### util _ndim = np.ndim @@ -8986,9 +8630,6 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False): # bool(obj) for an ndarray raises an error, so we check len if not len(obj): # pylint: disable=g-explicit-length-test return - if (config.dynamic_shapes.value and isinstance(obj, (tuple, list)) and - any(isinstance(d, (core.Tracer, core.DArray)) for d in obj)): - return # TODO(mattjj): handle more checks in the dynamic shape case obj_arr = np.array(obj) if obj_arr.ndim != 1: msg = "{} {} must be 1-dimensional, got {}." @@ -9219,47 +8860,9 @@ def _empty2_lower(ctx, *, dtype, memory_space): ad.primitive_jvps[tie_p] = \ lambda primals, tangents: (tie_p.bind(*primals), tangents[-1]) ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct] -pe.def_trivial_padding(tie_p) batching.defvectorized(tie_p) -class BIntRules: - allow_conversion: bool = True - - @staticmethod - def physical_element_aval(dtype) -> core.ShapedArray: - return core.ShapedArray((), np.dtype('int32')) - - @staticmethod - def result_handler(sticky_device, aval): - def handler(_, buf): - buf.aval = core.ShapedArray(buf.shape, buf.dtype) - return core.DArray(aval, buf) - return handler - - @staticmethod - def global_sharded_result_handler(aval, out_sharding, committed): - phys_aval = core.physical_aval(aval) - phys_handler_maker = pxla.global_result_handlers[core.ShapedArray] - - if not dispatch.is_single_device_sharding(out_sharding): - raise NotImplementedError # TODO(mattjj) - else: - phys_sharding = out_sharding - phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed) - if jaxlib_extension_version >= 390: - def handler(arr): - return core.DArray(aval, arr) - return phys_handler.wrap(handler) - else: - def handler(bufs): - return core.DArray(aval, phys_handler(bufs)) - return handler - - -core.bint._rules = BIntRules - - def optimization_barrier(operand, /): """Prevents the compiler from moving operations across the barrier. diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index efcb930012bb..e102951cbf58 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -26,7 +26,6 @@ from jax._src import ad_util from jax._src import api -from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes @@ -174,15 +173,9 @@ def dynamic_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) - if config.dynamic_shapes.value: - dynamic_sizes, static_sizes = lax._extract_tracers_dyn_shape(slice_sizes) - else: - dynamic_sizes = [] - static_sizes = core.canonicalize_shape(slice_sizes) # type: ignore - operand, *start_indices = core.standard_insert_pvary( - operand, *start_indices) - return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes, - slice_sizes=tuple(static_sizes)) + sizes = core.canonicalize_shape(slice_sizes) # type: ignore + operand, *start_indices = core.standard_insert_pvary(operand, *start_indices) + return dynamic_slice_p.bind(operand, *start_indices, slice_sizes=tuple(sizes)) def dynamic_update_slice( @@ -1369,11 +1362,10 @@ def _slice_shape_rule(operand, *, start_indices, limit_indices, strides): msg = ("slice start_indices must be greater than or equal to zero, " "got start_indices of {}.") raise TypeError(msg.format(start_indices)) - if not config.dynamic_shapes.value: - if not all(map(operator.ge, limit_indices, start_indices)): - msg = ("slice limit_indices must be greater than or equal to start_indices," - " got start_indices {} and limit_indices {}.") - raise TypeError(msg.format(start_indices, limit_indices)) + if not all(map(operator.ge, limit_indices, start_indices)): + msg = ("slice limit_indices must be greater than or equal to start_indices," + " got start_indices {} and limit_indices {}.") + raise TypeError(msg.format(start_indices, limit_indices)) diff = tuple(map(operator.sub, limit_indices, start_indices)) if strides is None or tuple(strides) == (1,) * len(operand.shape): return diff @@ -1485,9 +1477,6 @@ def _slice_batching_rule(batched_args, batch_dims, *, start_indices, ad.deflinear2(slice_p, _slice_transpose_rule) ad.fancy_transposes[slice_p] = _slice_transpose_fancy batching.primitive_batchers[slice_p] = _slice_batching_rule -# TODO(mvoz): A better slice rule for ragged prop, enforcing boundaries -# or supporting nested jumbles. NYI. -batching.ragged_prop_rules[slice_p] = batching.ragged_mask_no_op_rule # Override the standard impl to defer to dynamic_slice whenever possible. # This lets us reuse the same program for many applications of slicing for as @@ -1514,28 +1503,19 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): mlir.register_lowering(slice_p, _slice_lower) -def _dynamic_slice_shape_rule(operand, *starts_and_dyn_sizes, slice_sizes): - start_indices, dyn = util.split_list(starts_and_dyn_sizes, [operand.ndim]) - if operand.ndim != len(start_indices): - msg = ("dynamic_slice start_indices must have length equal to the number " - "of dimensions of the operand, got indices {} for operand shape {}.") - raise TypeError(msg.format(start_indices, operand.shape)) - if len(start_indices) != len(slice_sizes): - msg = ("dynamic_slice slice_sizes must have the same length as " - "start_indices, got start_indices length {} and slice_sizes {}.") - raise TypeError(msg.format(len(start_indices), slice_sizes)) - if not dyn and not all(map(operator.ge, operand.shape, slice_sizes)): +def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes): + if not all(map(operator.ge, operand.shape, slice_sizes)): msg = ("slice slice_sizes must be less than or equal to operand shape, " "got slice_sizes {} for operand shape {}.") raise TypeError(msg.format(slice_sizes, operand.shape)) - if not dyn and not all(ssz >= 0 for ssz in slice_sizes): + if not all(ssz >= 0 for ssz in slice_sizes): msg = ("slice slice_sizes must be greater than or equal to zero, " "got slice_sizes of {}.") raise TypeError(msg.format(slice_sizes)) if any(idx.ndim != 0 for idx in start_indices): raise TypeError("start_indices arguments to dynamic_slice must be scalars, " f" got indices {start_indices}") - return tuple(lax._merge_dyn_shape(slice_sizes, dyn)) + return tuple(slice_sizes) def _dynamic_slice_sharding_rule(operand, *starts_and_dyn_sizes, slice_sizes): out_shape = _dynamic_slice_shape_rule( @@ -1626,42 +1606,16 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True, mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None) -def _dynamic_slice_staging_rule(trace, source_info, x, *starts_and_dyn_sizes, +def _dynamic_slice_staging_rule(trace, source_info, x, *start_indices, slice_sizes): - start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.ndim]) - if not dyn: - return trace.default_process_primitive( - dynamic_slice_p, (x, *start_indices), dict(slice_sizes=slice_sizes), - source_info=source_info) - shape = lax._merge_dyn_shape(slice_sizes, dyn) - aval = core.DShapedArray(shape, x.dtype, False) - return lax._dyn_shape_staging_rule(trace, source_info, dynamic_slice_p, aval, - x, *starts_and_dyn_sizes, - slice_sizes=slice_sizes) - -def _dynamic_slice_typecheck_rule(_, x, *starts_and_dyn_sizes, slice_sizes): - start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x.aval.ndim]) - if not dyn: - out_aval, effects = dynamic_slice_p.abstract_eval( - x.aval, *(d.aval for d in start_indices), slice_sizes=slice_sizes) - return [out_aval], effects - else: - # TODO(mattjj): perform more checks - out_shape = lax._merge_dyn_shape(slice_sizes, dyn) - out_shape = [d.val if type(d) is core.Literal else d for d in out_shape] - out_aval = core.DShapedArray(tuple(out_shape), x.aval.dtype, - x.aval.weak_type) - return [out_aval], core.no_effects - -def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, - slice_sizes): - x_aval, start_indices_avals, dyn_avals = util.split_list(in_avals, [1, x.ndim]) - start_indices, dyn = util.split_list(starts_and_dyn, [x.ndim]) - dyn_ = [a.dtype.bound if type(a.dtype) is core.bint else d - for a, d in zip(dyn_avals, dyn)] - slice_sizes_ = lax._merge_dyn_shape(slice_sizes, dyn_) - start_idx = [d.val if type(d) is core.DArray else d for d in start_indices] - return [dynamic_slice(x, start_idx, slice_sizes_)] + return trace.default_process_primitive( + dynamic_slice_p, (x, *start_indices), dict(slice_sizes=slice_sizes), + source_info=source_info) + +def _dynamic_slice_typecheck_rule(_, x, *start_indices, slice_sizes): + out_aval, effects = dynamic_slice_p.abstract_eval( + x.aval, *(d.aval for d in start_indices), slice_sizes=slice_sizes) + return [out_aval], effects dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', @@ -1675,14 +1629,10 @@ def _dynamic_slice_padding_rule(in_avals, out_avals, x, *starts_and_dyn, batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule pe.custom_staging_rules[dynamic_slice_p] = _dynamic_slice_staging_rule core.custom_typechecks[dynamic_slice_p] = _dynamic_slice_typecheck_rule -pe.padding_rules[dynamic_slice_p] = _dynamic_slice_padding_rule -def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes): +def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes): x_aval, *_ = ctx.avals_in - start_indices, dyn = util.split_list(starts_and_dyn_sizes, [x_aval.ndim]) aval_out, = ctx.avals_out - if dyn: - aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn)) out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices) return [mlir.lower_with_sharding_in_types(ctx, out, aval_out)] @@ -2249,12 +2199,12 @@ def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): - operand, indices, *dyn_slice_sizes = batched_args - operand_bdim, indices_bdim, *dyn_slice_size_bds = batch_dims - dyn_slice_size_bounds = [b.dtype.bound for b in dyn_slice_sizes] + operand, indices = batched_args + operand_bdim, indices_bdim = batch_dims if operand_bdim is not None and indices_bdim is None: - operand, operand_bdim = batching.move_stacked_axis(operand, operand_bdim, 0) + operand = batching.moveaxis(operand, operand_bdim, 0) + operand_bdim = 0 slice_sizes = (operand.shape[0],) + slice_sizes offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) @@ -2269,29 +2219,10 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, operand_batching_dims=operand_batching_dims, start_indices_batching_dims=dimension_numbers.start_indices_batching_dims, ) - if isinstance(operand_bdim, batching.RaggedAxis): - ragged_slice_sizes = batching.bdim_as_shape(operand_bdim, slice_sizes) - for orig, fabricated in zip( - lax._merge_dyn_shape(slice_sizes, dyn_slice_sizes), - ragged_slice_sizes): - if isinstance(fabricated, batching.IndexedAxisSize): - if not core.same_referent(orig, fabricated.lengths): - # Don't know what to do when slicing a ragged dimension with a - # different size. To wit, if the client tries to index outside the - # ragged size, the resulting element should be determined by the - # out of bounds `mode`, but the underlying gather will only do that - # if the client tries to index outside the _padded_ array. I guess - # we should read the mode and apply a mask that writes the correct - # fill element into all out-of-bounds locations? - raise NotImplementedError - bdim_out = batching.shape_as_bdim( - operand_bdim.stacked_axis, - _gather_shape_computation(indices, dnums, ragged_slice_sizes)) - else: - bdim_out = operand_bdim + bdim_out = operand_bdim return gather( operand, indices, dimension_numbers=dnums, - slice_sizes=lax._merge_dyn_shape(slice_sizes, dyn_slice_size_bounds), + slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), bdim_out @@ -2348,18 +2279,6 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), 0 -def _gather_pad_rule(in_avals, out_avals, operand, indices, *, - dimension_numbers, slice_sizes, unique_indices, - indices_are_sorted, mode, fill_value): - operand_aval, indices_aval = in_avals - if any(isinstance(d, pe.BoundedAxisSize) for d in operand_aval.shape): - raise NotImplementedError - if mode != GatherScatterMode.PROMISE_IN_BOUNDS: - # with fill, jnp.where on operand; with clip, jnp.where on indices - raise NotImplementedError - return [gather(operand, indices, dimension_numbers=dimension_numbers, - slice_sizes=slice_sizes, mode=mode, fill_value=fill_value)] - gather_p = standard_primitive( _gather_shape_rule, _gather_dtype_rule, 'gather', weak_type_rule=_argnum_weak_type(0), sharding_rule=_gather_sharding_rule, @@ -2367,7 +2286,6 @@ def _gather_pad_rule(in_avals, out_avals, operand, indices, *, ad.defjvp(gather_p, _gather_jvp_rule, None) ad.primitive_transposes[gather_p] = _gather_transpose_rule batching.primitive_batchers[gather_p] = _gather_batching_rule -pe.padding_rules[gather_p] = _gather_pad_rule def _gather_lower_opaque(ctx, operand, indices, *, diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index e24bea8fb6ff..669ffc510ae0 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -199,11 +199,6 @@ def standard_abstract_eval( vma=out_vma, memory_space=out_mem_space) core.check_avals_context_mesh([out_aval], prim.name) return out_aval - elif least_specialized is core.DShapedArray: - shape = shape_rule(*avals, **kwargs) - ty = (core.ShapedArray if all(type(d) is int for d in shape) - else core.DShapedArray) - return ty(shape, dtype_rule(*avals, **kwargs), weak_type) else: raise TypeError(avals, least_specialized) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 5433e9769698..bf45604c2f5e 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -426,33 +426,8 @@ def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed - assert type(in_type) is tuple and all(type(e) is tuple for e in in_type) - assert all(isinstance(a, core.AbstractValue) and type(b) is bool - for a, b in in_type) - - def valid_size(d) -> bool: - if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0: - return True - return (isinstance(d, (int, core.DBIdx, core.DArray)) and - (not isinstance(d, core.DArray) or type(d) is core.bint and not d.shape)) - assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray - for d in a.shape) - - # Check that all DBIdx point to positions to the left of the input on which - # they appear. - assert all(d.val < i for i, (aval, _) in enumerate(in_type) - if isinstance(aval, core.DShapedArray) for d in aval.shape - if isinstance(d, core.DBIdx)) - - # Check that all implicit arguments have at least one DBIdx pointing to them. - provided = [e for _, e in in_type] - for aval, _ in in_type: - if type(aval) is core.DShapedArray: - for d in aval.shape: - if isinstance(d, core.DBIdx): - provided[d.val] = True - assert all(provided) - + assert type(in_type) is tuple + assert all(isinstance(a, core.AbstractValue) for a in in_type) def cache(call: Callable, *, explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None): diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index e713ee8c53ed..e0496d157b67 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -1212,7 +1212,6 @@ def _set_array_abstract_methods(basearray): def register_jax_array_methods(): """Call this function once to register methods of JAX arrays""" _set_shaped_array_attributes(core.ShapedArray) - _set_shaped_array_attributes(core.DShapedArray) _set_array_base_attributes(ArrayImpl, exclude={'__getitem__'}) _set_tracer_aval_forwarding(core.Tracer, exclude={*_impl_only_array_methods, "at"}) diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 761756f56780..b8f72081df62 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -20,7 +20,6 @@ import opt_einsum from jax._src import api -from jax._src import config from jax._src import core from jax._src import dtypes from jax._src.export import shape_poly @@ -537,7 +536,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): # NOTE(mattjj): this can fail non-deterministically in python3, maybe # due to opt_einsum - assert config.dynamic_shapes.value or all( + assert all( name in lhs_names and name in rhs_names and lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)] for name in contracted_names), ( diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index 1e65d9c37e16..cf58764c6293 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -765,20 +765,6 @@ def rewriting_take( if result is not None: return result - # otherwise, strategy is GATHER or SCATTER - - # TODO(mattjj,dougalm): expand dynamic shape indexing support - if config.dynamic_shapes.value and arr.ndim > 0: - try: aval = core.get_aval(idx) - except: pass - else: - if (isinstance(aval, core.DShapedArray) and aval.shape == () and - dtypes.issubdtype(aval.dtype, np.integer) and - not dtypes.issubdtype(aval.dtype, dtypes.bool_) and - isinstance(arr.shape[0], int)): - assert isinstance(idx, (int, Array)) - return slicing.dynamic_index_in_dim(arr, idx, keepdims=False) - treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape) internal_gather = partial( _gather, treedef=treedef, static_idx=static_idx, diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 4d7d992545bc..53a78cf6c8ac 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -49,23 +49,16 @@ def promote_shapes(fun_name: str, *args: ArrayLike) -> list[Array]: return [lax.asarray(arg) for arg in args] else: shapes = [np.shape(arg) for arg in args] - if config.dynamic_shapes.value: - # With dynamic shapes we don't support singleton-dimension broadcasting; - # we instead broadcast out to the full shape as a temporary workaround. - # TODO(mattjj): revise this workaround - res_shape = lax.broadcast_shapes(*shapes) # Can raise an error! - return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)] + if all(len(shapes[0]) == len(s) for s in shapes[1:]): + return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion + nonscalar_ranks = {len(shp) for shp in shapes if shp} + if len(nonscalar_ranks) < 2: + return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion else: - if all(len(shapes[0]) == len(s) for s in shapes[1:]): - return [lax.asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion - nonscalar_ranks = {len(shp) for shp in shapes if shp} - if len(nonscalar_ranks) < 2: - return [lax.asarray(arg) for arg in args] # rely on lax scalar promotion - else: - if config.numpy_rank_promotion.value != "allow": - _rank_promotion_warning_or_error(fun_name, shapes) - result_rank = len(lax.broadcast_shapes(*shapes)) - return [lax.broadcast_to_rank(arg, result_rank) for arg in args] + if config.numpy_rank_promotion.value != "allow": + _rank_promotion_warning_or_error(fun_name, shapes) + result_rank = len(lax.broadcast_shapes(*shapes)) + return [lax.broadcast_to_rank(arg, result_rank) for arg in args] def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]): diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 8a16559a08ae..d6747b1d8ed9 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -529,15 +529,7 @@ def to_block_mapping( ) ref_block_shape = _get_ref_block_shape(block_shape) - if isinstance(array_aval, jax_core.DShapedArray): - # Get the "max" shape for the ragged array. - block_array_aval = array_aval.update(shape=ref_block_shape) - block_array_aval = jax_core.ShapedArray( - block_array_aval.shape, - block_array_aval.dtype, - block_array_aval.weak_type, - ) - elif isinstance(array_aval, ShapedArrayWithMemorySpace): + if isinstance(array_aval, ShapedArrayWithMemorySpace): block_array_aval = jax_core.ShapedArray( ref_block_shape, array_aval.dtype, array_aval.weak_type ) @@ -618,10 +610,6 @@ def to_block_mapping( f"{origin} must not capture constants: {consts}" ) - if isinstance(array_aval, (jax_core.ShapedArray, jax_core.DShapedArray)): - array_aval_shape = _max_shape_from_aval(array_aval) - array_aval = array_aval.update(shape=array_aval_shape) - mapping = BlockMapping( block_shape=block_shape, transformed_block_aval=block_aval, # There are no transforms by default @@ -1064,8 +1052,6 @@ def _max_shape_from_aval(array_aval: jax_core.ShapedArray): for i, s in enumerate(array_aval.shape): try: aval = jax_core.get_aval(s) - if isinstance(aval, jax_core.DShapedArray): - array_aval_shape[i] = aval.dtype.bound except OverflowError as e: # Note - there are annoying cases where on 32 bit hardware, # a flattened index space may overflow - for these cases, diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index ec100e1d796e..f3faddbc13dd 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -411,9 +411,6 @@ def pallas_call_hlo_interpret( # to catch OOB accesses. for carry_element in carry: aval = carry_element.aval - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype) - carry_element.aval = aval carry = map(_pad_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 4811aa9f12d6..85d3d4cf88a9 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -46,8 +46,7 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue): after loading from a memref inside of the kernel. """ assert isinstance( - x, (jax.Array, jax_core.ShapedArray, jax_core.DShapedArray, - state_types.AbstractLinVal) + x, (jax.Array, jax_core.ShapedArray, state_types.AbstractLinVal) ), type(x) if isinstance(x, jax.Array): if dtypes.issubdtype(x.dtype, jax.numpy.bool_): diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 1621ba411369..57884b5c67e2 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -43,7 +43,6 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core -from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas import hlo_interpreter from jax._src.pallas import primitives from jax._src.state import discharge as state_discharge @@ -329,17 +328,12 @@ def _pallas_call_jvp_rule( def _batch_block_mapping( grid_mapping: GridMapping, axis_size: int, - for_ragged: bool, aval: jax_core.ShapedArray, dim: int | batching.NotMapped, block_mapping: BlockMapping, - ragged_axis_values, ) -> BlockMapping: def _block_map_function(new_idx, *args): - if for_ragged: - drop_last_args = args[:-1] - else: - drop_last_args = args + drop_last_args = args indices = jax_core.eval_jaxpr( block_mapping.index_map_jaxpr.jaxpr, @@ -352,27 +346,10 @@ def _block_map_function(new_idx, *args): unflat_indices = (unflat_indices,) unflat_indices = list(unflat_indices) if dim is not batching.not_mapped: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - stacked_axis = dim.stacked_axis - unflat_indices.insert(stacked_axis, new_idx) - else: - unflat_indices.insert(dim, new_idx) + unflat_indices.insert(dim, new_idx) return tuple(unflat_indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] - if for_ragged: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - _, _, _, lengths_aval = ragged_axis_values - idx_avals = [*idx_avals, lengths_aval] - else: - i32_aval_memref = state.AbstractRef( - jax_core.ShapedArray(([axis_size]), jnp.int32), - pallas_core.MemorySpace.INDEX, - ) - idx_avals = [*idx_avals, i32_aval_memref] - block_mapping_flat_fn, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(_block_map_function, debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info.with_unknown_names()), @@ -387,23 +364,10 @@ def _block_map_function(new_idx, *args): new_block_shape = shape new_array_aval = block_mapping.array_aval else: - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - new_block_shape = shape - stacked_axis = dim.stacked_axis - new_block_shape = tuple_insert( - new_block_shape, stacked_axis, pallas_core.squeezed - ) - else: - new_block_shape = tuple_insert(shape, dim, pallas_core.squeezed) + new_block_shape = tuple_insert(shape, dim, pallas_core.squeezed) array_shape = block_mapping.array_aval.shape - if isinstance(dim, batching.RaggedAxis): - assert for_ragged, "Ragged axis not supported for non-ragged batching." - stacked_axis = dim.stacked_axis - array_shape = tuple_insert(array_shape, stacked_axis, axis_size) - else: - array_shape = tuple_insert(array_shape, dim, axis_size) + array_shape = tuple_insert(array_shape, dim, axis_size) new_array_aval = jax_core.ShapedArray( array_shape, block_mapping.array_aval.dtype @@ -437,12 +401,6 @@ def _broadcast_input_output_aliases( for input_index, _ in input_output_aliases: dim = dims_[input_index] dims_[input_index] = 0 - if isinstance(dim, batching.RaggedAxis): - stacked_axis = dim.stacked_axis - if stacked_axis != 0: - raise NotImplementedError("Ragged aliasing on non 0 dim NYI") - return tuple(args_), tuple(dims_) - if dim is batching.not_mapped: args_[input_index] = batching.broadcast( args_[input_index], axis_size, 0, None) @@ -585,16 +543,8 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) - def get_size(i, x, d): - if not isinstance(d, batching.RaggedAxis): - return x.shape[d] - return x.aval.shape[d.stacked_axis] - - (axis_size,) = { - get_size(i=i, x=x, d=d) - for i, (x, d) in enumerate(zip(args, dims)) - if d is not batching.not_mapped - } + axis_size, = {x.shape[d] for i, (x, d) in enumerate(zip(args, dims)) + if d is not batching.not_mapped} if axis_size == 1: # Why are we even vmapping? args = map(_maybe_squeeze_out_bdim, args, dims) @@ -702,30 +652,7 @@ def get_size(i, x, d): args, dims, input_output_aliases=input_output_aliases, axis_size=axis_size ) - # Each dim either has data about its ragged axis, or None - ragged_axis_values = [] - for d in dims: - if isinstance(d, batching.RaggedAxis): - stacked_axis, ragged_axis_dim, ragged_axis_length = ( - batching._ragged_axis_parts(d) - ) - aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) - if isinstance(aval, jax_core.DShapedArray): - aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) - lengths_aval = state.AbstractRef( - aval, - pallas_core.MemorySpace.INDEX, - ) - # TODO(mvoz): Give this its own type - ragged_axis_values.append( - (stacked_axis, ragged_axis_dim, ragged_axis_length, lengths_aval) - ) - else: - ragged_axis_values.append(None) # type: ignore[arg-type] - all_dims = list(dims) + [0] * grid_mapping.num_outputs - ragged_axis_values = ragged_axis_values + [None] * grid_mapping.num_outputs - num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands @@ -739,34 +666,16 @@ def get_size(i, x, d): _batch_block_mapping, grid_mapping, axis_size, - any(ragged_axis_values), ), avals_to_batch, all_dims[num_index_operands:], block_mappings, - ragged_axis_values[num_index_operands:], ) index_map_tree_args, index_map_tree_kwargs = grid_mapping.index_map_tree.unflatten( grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args - - lengths_aval = None # type: ignore[assignment] - - # Check all the ragged axis values, ensure their raggedness pattern - # is identical (consider moving this check up!) - for rav in ragged_axis_values: - if rav is not None: - if lengths_aval is None: - lengths_aval = rav[3] - else: - assert lengths_aval == rav[3], "NYI - different lengths in ragged batch" - - if lengths_aval: - batched_index_map_args = batched_index_map_args + (lengths_aval,) - num_index_operands += 1 - batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( (batched_index_map_args, {})) @@ -791,261 +700,6 @@ def get_size(i, x, d): else: batched_cost_estimate = None - # Start the ragged handling code - # Here, we: - # - Rewrite the indexer to save memory (skip indices outside the ragged bounds) - # - Rewrite the kernel to save compute (skip elements outside the ragged bounds) - # - Update various internal structures/metadata to account for the new - # block spec. - # - Set the hacky flag of ragged_originating on the mapping, to signal to - # the lowering code to treat mapped dimensions as part of the user grid. - if lengths_aval: - batched_grid_mapping = batched_grid_mapping.replace( - get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, - local_grid_env=lambda loop_idx, grid: tuple( - pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid) - ), - ) - - # Note - on zero filling counterfactuals - # A debug util to produce a counterfactual version of the when - # gating, where for all values that don't pass the @when check, - # we write 0s. This is useful for debugging, as certain lowering paths - # like mosaic will write the last data as passthrough, leading to - # potentially confusing results. - block_mapped_dim_idxs = [] - for block_mapping in batched_grid_mapping.block_mappings: - mapped_dim_idxs = [] - for i, d in enumerate(block_mapping.block_shape): - if isinstance(d, pallas_core.Squeezed): - mapped_dim_idxs.append(i) - else: - mapped_dim_idxs.append(None) # type: ignore[arg-type] - block_mapped_dim_idxs.append(mapped_dim_idxs) - - mapped_dim_idx = None - for rav, mapped_dim_idxs in zip(ragged_axis_values, block_mapped_dim_idxs): - if rav is not None: - stacked_axis = rav[0] - if mapped_dim_idx is None: - mapped_dim_idx = mapped_dim_idxs[stacked_axis] - if mapped_dim_idxs[stacked_axis] is None: - raise ValueError( - f"Expected mapped dim to be {stacked_axis}, but got" - f" {mapped_dim_idxs[stacked_axis]}" - ) - else: - assert mapped_dim_idx == mapped_dim_idxs[stacked_axis], ( - f"Different mapped dims - expected {mapped_dim_idx}, but got" - f" {mapped_dim_idxs[stacked_axis]}" - ) - - # This is the blockspec size of the dimension - block_shapes = [b.block_shape for b in batched_grid_mapping.block_mappings] - - # Parse out the operations from the jaxpr to determine how to mask the output - # NOTE! while this *could* be a default dict of None, and None is sound, as - # it denotes that there is no raggedness for the given var, we explicitly - # do not do this, so as to get better signal on implementation of rules - # A misimplemented rule that does not account for new vars being introduced - # will result in an error on the next op using the new var. The benefit of - # of forcing implementers to account for all outputs and intermediaries is - # a very nice one. - - var_to_raggedness = {} - for invar, rav in zip(jaxpr.invars, ragged_axis_values): - var_to_raggedness[invar] = rav - - for eqn in jaxpr.eqns: - prim = eqn.primitive - if prim not in batching.ragged_prop_rules: - raise NotImplementedError(f"Not implemented - ragged prop for {prim}") - rule = batching.ragged_prop_rules[prim] - - invar_raggedness = [ - ( - var_to_raggedness.get(invar, None) - if isinstance(invar, jax_core.Var) - else None - ) - for invar in eqn.invars - ] - try: - invar_raggedness, outvar_raggedness = rule( - eqn.params, invar_raggedness, eqn.outvars # type: ignore[arg-type] - ) - except Exception as e: - raise RuntimeError( - f"Failed to run rule for {prim}. invars: {eqn.invars}, outvars:" - f" {eqn.outvars}. Underlying reason: {e}" - ) from e - - for invar, rav in zip(eqn.invars, invar_raggedness): # type: ignore[assignment] - if isinstance(invar, jax_core.Var): - var_to_raggedness[invar] = rav - for outvar, rav in zip(eqn.outvars, outvar_raggedness): - if isinstance(outvar, jax_core.Var): - var_to_raggedness[outvar] = rav - - for pos, invar in enumerate(jaxpr.invars): - ragged_axis_values[pos] = var_to_raggedness[invar] - - per_input_ragged_axis_dim: list[int | None] = [] - for rav in ragged_axis_values: - if rav is not None: - per_input_ragged_axis_dim.append(rav[1]) - else: - per_input_ragged_axis_dim.append(None) - - def when_wrapped_kernel(lengths_ref, *args, **kwargs): - b_idx = primitives.program_id(mapped_dim_idx) - - b_len = lengths_ref[b_idx] - run_kernel = jnp.array(True) - for i, _ in enumerate(args): - ragged_axis_dim = per_input_ragged_axis_dim[i] - if ragged_axis_dim is None: - continue - arg_i_idx = ( - primitives.program_id(ragged_axis_dim) - * pallas_core.get_block_dim_size(block_shapes[i][ragged_axis_dim]) - ) - run_kernel = jnp.logical_and(run_kernel, arg_i_idx < b_len) - - # TODO(mvoz): Unimplemented primitive in pallas - # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) - # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") - - @pallas_helpers.when(run_kernel) - def f(): - # Important! This allows us to trace the inner kernel with the correct - # grid to preserve user program_id semantics. Ex: program_id(0) will - # always be analogous to program_id(1) in the outer kernel. - with pallas_core.tracing_grid_env(grid_mapping.grid, ()): - jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) - - kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] - flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( - list(kernel_avals) - ) - - def _rewrite_index_jaxpr(enumerate_batched_block_mapping): - arg_pos, batched_block_mapping = enumerate_batched_block_mapping - indexer_avals = [ - v.aval for v in batched_block_mapping.index_map_jaxpr.jaxpr.invars - ] - flat_indexer_avals, indexer_in_tree = tree_util.tree_flatten( - list(indexer_avals) - ) - - def index_rewrite_kernel(*indexer_args): - ragged_axis_dim = per_input_ragged_axis_dim[arg_pos] - - # the problem here seems to be that we are rnning this for all inputs, per input, because they each have an indexer - which means - # that the indexer for output isn't getting written - before, it always was - - lengths_ref = indexer_args[-1] - rest_indexer_args = indexer_args[:-1] - # Lengths are always the last argument of the indexer. - # lengths_ref = args[-1] - # Invariant: Stacked axis is enforced to be the mapped axis above. - b_idx = indexer_args[mapped_dim_idx] - - nargs = list(rest_indexer_args) - - if ragged_axis_dim is not None: - val_at_ragged_dim = pallas_core.get_block_dim_size( - batched_block_mapping.block_shape[ragged_axis_dim]) - - # The current index into the ragged dimension. - # Invariant: There is only one ragged dimension, enforced above. - i_idx = indexer_args[ragged_axis_dim] - - # grid space -> element space - i_len = i_idx * val_at_ragged_dim - - # The length of the current batch. - b_len = lengths_ref[b_idx] - - # Have we reached the end of the current batch? - not_done = i_len < b_len - - am_last_batch = b_idx == axis_size - 1 - last_good_block = lax.div(b_len, val_at_ragged_dim) - 1 - - # The logic below can be thought of as: - # if index_oob_ragged: - # if not last_batch: - # batch_idx += 1 - # ragged_idx = 0 - # else: - # ragged_idx = last_good_block - # - # wherein we find the next good block by incrementing the batch index - # and setting the ragged index to 0 if we are not in the last batch. - # Otherwise, we set the ragged index to the last good block. - b_next = jnp.where( - not_done, b_idx, jnp.where(am_last_batch, b_idx, b_idx + 1) - ) - i_next = jnp.where( - not_done, i_idx, jnp.where(am_last_batch, last_good_block, 0) - ) - nargs[ragged_axis_dim] = i_next - nargs[mapped_dim_idx] = b_next - - nargs = nargs + [lengths_ref] - return jax_core.eval_jaxpr( - batched_block_mapping.index_map_jaxpr.jaxpr, - batched_block_mapping.index_map_jaxpr.consts, - *nargs, - ) - index_jaxpr, _ = _trace_kernel_to_jaxpr( - index_rewrite_kernel, - batched_block_mapping.index_map_jaxpr.jaxpr.debug_info, - batched_grid_mapping, - tuple(flat_indexer_avals), - indexer_in_tree, - tuple(() for _ in flat_indexer_avals), - indexer=True, - ) - - batched_block_mapping = batched_block_mapping.replace( - index_map_jaxpr=pe.close_jaxpr(index_jaxpr) - ) - return batched_block_mapping - - # Important! This allows us to trace the outer kernel with the correct grid - # to enable accessing the batch program_id. - with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): - batched_block_mappings = map( - _rewrite_index_jaxpr, enumerate(batched_block_mappings) - ) - - batched_grid_mapping = batched_grid_mapping.replace( - block_mappings=tuple(batched_block_mappings), - ) - - jaxpr, consts = _trace_kernel_to_jaxpr( - when_wrapped_kernel, - jaxpr.debug_info, - batched_grid_mapping, - tuple(flat_kernel_avals), - kernel_in_tree, - tuple(() for _ in flat_kernel_avals), - ) - if consts: - raise NotImplementedError("consts not supported in pallas_call") - - # We need to rewrite the input_output_aliases here, the initial call - # to broadcast is done, and we have inserted a new input (lengths), so - # there's an off-by-one here now. - new_input_output_aliases = [] - for k, v in input_output_aliases: - new_input_output_aliases.append((k + 1, v)) - input_output_aliases = tuple(new_input_output_aliases) - - # assert ragged_axis_length is not None - args = (ragged_axis_length, *args) assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals) batched_out_avals = [] @@ -1887,18 +1541,6 @@ def wrapped(*args): f"[0, {len(flat_out_avals)})") in_aval = flat_in_avals[i_idx] out_aval = flat_out_avals[o_idx] - if isinstance(in_aval, jax_core.DShapedArray): - new_shape = [] - for d in in_aval.shape: - if isinstance(d, int): - new_shape.append(d) - else: - new_shape.append(d.dtype.bound) - - in_aval = jax_core.ShapedArray( - tuple(new_shape), in_aval.dtype, in_aval.weak_type - ) - if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype: raise ValueError( f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 2fdfe6b31513..b612c139acda 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -39,7 +39,6 @@ from jax._src import state from jax._src import util from jax._src.interpreters import ad -from jax._src.interpreters import batching from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -59,7 +58,6 @@ zip, unsafe_zip = util.safe_zip, zip program_id_p = jax_core.Primitive("program_id") -batching.ragged_prop_rules[program_id_p] = batching.ragged_mask_no_op_rule def program_id(axis: int) -> jax_typing.Array: """Returns the kernel execution position along the given axis of the grid. diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index d9d3aa16d35f..43671aef48d7 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -21,7 +21,7 @@ import inspect import logging import weakref -from typing import NamedTuple, Any, Union, cast +from typing import NamedTuple, Any, Union import warnings import numpy as np @@ -71,8 +71,8 @@ from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, - treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr, - PyTreeDef, none_leaf_registry as none_lr, tree_map) + treedef_children, prefix_errors, keystr, PyTreeDef, + none_leaf_registry as none_lr, tree_map) from jax._src.typing import ArrayLike from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, @@ -119,7 +119,6 @@ class PjitInfo(NamedTuple): backend: str | None keep_unused: bool inline: bool - abstracted_axes: Any | None use_resource_env: bool # False for jit, True for pjit compiler_options_kvs: tuple[tuple[str, Any], ...] @@ -187,7 +186,7 @@ def _need_to_rebuild_with_fdo(pgle_profiler): def _get_fastpath_data( executable, out_tree, args_flat, out_flat, effects, consts_for_constvars, - abstracted_axes, pgle_profiler, const_args: Sequence[ArrayLike] + pgle_profiler, const_args: Sequence[ArrayLike] ) -> pxla.MeshExecutableFastpathData | None: if ( executable is None @@ -196,7 +195,6 @@ def _get_fastpath_data( # No effects in computation or executable.unsafe_call.ordered_effects or executable.unsafe_call.has_unordered_effects - or abstracted_axes is not None # no ref state effects or any(isinstance(e, RefEffect) for e in effects) # no prng reuse checking @@ -265,7 +263,7 @@ def cache_miss(*args, **kwargs): maybe_fastpath_data = _get_fastpath_data( executable, out_tree, args_flat, out_flat, jaxpr.effects, jaxpr.consts, - jit_info.abstracted_axes, pgle_profiler, + pgle_profiler, const_args) return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) @@ -349,7 +347,6 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, donate_argnames: str | Iterable[str] | None, keep_unused: bool, device: xc.Device | None, backend: str | None, inline: bool, - abstracted_axes: Any | None, compiler_options: dict[str, Any] | None, use_resource_env: bool) -> PjitInfo: """Parses the arguments to jit/pjit. @@ -357,9 +354,6 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, Performs any preprocessing and validation of the arguments that we can do ahead of time before the jit()-ed function is invoked. """ - if abstracted_axes and not config.dynamic_shapes.value: - raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes") - check_callable(fun) if backend is not None or device is not None: @@ -427,7 +421,6 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, - abstracted_axes=abstracted_axes, use_resource_env=use_resource_env, compiler_options_kvs=compiler_options_kvs) @@ -443,7 +436,6 @@ def make_jit(fun: Callable, device: xc.Device | None, backend: str | None, inline: bool, - abstracted_axes: Any | None, compiler_options: dict[str, Any] | None, use_resource_env: bool) -> Any: """jit() and pjit() are thin wrappers around this function.""" @@ -452,7 +444,7 @@ def make_jit(fun: Callable, static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, device=device, backend=backend, inline=inline, - abstracted_axes=abstracted_axes, compiler_options=compiler_options, + compiler_options=compiler_options, use_resource_env=use_resource_env) return _cpp_pjit(fun, jit_info) @@ -490,8 +482,6 @@ def _infer_params_impl( "Mesh context manager should not be used with jit when backend or " "device is also specified as an argument to jit.") - axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs) - f = lu.wrap_init(fun, debug_info=dbg) f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True) del args @@ -530,14 +520,9 @@ def _infer_params_impl( assert None not in out_shardings_leaves in_type: core.InputType | tuple[core.AbstractValue, ...] - if config.dynamic_shapes.value: - assert in_avals is None - in_type = pe.infer_lambda_input_type(axes_specs, explicit_args) - in_avals = tuple(a for a, e in in_type if e) - else: - in_type = in_avals # type: ignore - in_type = tuple(core.AvalQDD(a, cur_qdd(x)) if a.has_qdd # type: ignore - else a for a, x in zip(in_type, explicit_args)) + in_type = in_avals # type: ignore + in_type = tuple(core.AvalQDD(a, cur_qdd(x)) if a.has_qdd # type: ignore + else a for a, x in zip(in_type, explicit_args)) assert in_avals is not None in_shardings_flat, in_layouts_flat = _process_in_axis_resources( @@ -562,14 +547,9 @@ def _infer_params_impl( assert len(explicit_args) == len(in_shardings_flat) == len(in_layouts_flat) - if config.dynamic_shapes.value: - implicit_args = _extract_implicit_args( - cast(core.InputType, in_type), explicit_args) - else: - implicit_args = [] - args_flat = [*implicit_args, *explicit_args] + args_flat = explicit_args - num_extra_args = len(implicit_args) + len(consts) + num_extra_args = len(consts) in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat in_layouts_flat = (None,) * num_extra_args + in_layouts_flat donated_invars = (False,) * num_extra_args + donated_invars @@ -645,11 +625,6 @@ def _infer_params_internal( static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo, signature=ji.fun_signature) - if config.dynamic_shapes.value: # don't use the cache - p, args_flat = _infer_params_impl(fun, ji, ctx_mesh, dbg_fn(), - args, kwargs, in_avals=None) - return p, p.consts + args_flat - signature, dynargs = jax_jit.parse_arguments( args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, ji.static_argnames, tree_util.default_registry) @@ -693,45 +668,6 @@ def _infer_input_type(fun: Callable, dbg_fn: Callable[[], core.DebugInfo], check_no_aliased_ref_args(dbg_fn, avals, explicit_args) return tuple(avals) -def _extract_implicit_args( - in_type: Sequence[tuple[core.AbstractValue, bool]], - explicit_args: Sequence[Any] -) -> Sequence[core.Tracer]: - """ - Given an input type and explicitly-passed arguments (per the user-facing API - calling convention), extract implicit axis size arguments from shapes of - explicit arguments (for the trace-time / jaxpr-level calling convention). - """ - # First, using `in_type` construct a list to represent the full argument list, - # leaving the implicit arguments as None placeholders for now. - explicit_args_ = iter(explicit_args) - args = [next(explicit_args_) if expl else None for _, expl in in_type] - assert next(explicit_args_, None) is None - del explicit_args, explicit_args_ - - # Next, populate the implicit arguments using the DBIdxs in `in_type`. - for i, (aval, explicit) in enumerate(in_type): - if not explicit or not isinstance(aval, core.DShapedArray): - continue # can't populate an implicit argument - arg = args[i] - assert arg is not None - for d1, d2 in zip(aval.shape, arg.aval.shape): - if isinstance(d1, core.DBIdx): - if args[d1.val] is None: - args[d1.val] = d2 - assert core.same_referent(args[d1.val], d2) - assert all(x is not None for x in args) - return [x for x, (_, e) in zip(args, in_type) if not e] # type: ignore - -def _flat_axes_specs(abstracted_axes, *args, **kwargs - ) -> list[pe.AbstractedAxesSpec] | None: - if abstracted_axes is None: return None - if kwargs: raise NotImplementedError - def ax_leaf(l): - return (isinstance(l, dict) and all_leaves(l.values()) or - isinstance(l, tuple) and all_leaves(l, lambda x: x is None)) - return broadcast_prefix(abstracted_axes, args, ax_leaf) - class JitWrapped(stages.Wrapped): @@ -758,7 +694,6 @@ def pjit( device: xc.Device | None = None, backend: str | None = None, inline: bool = False, - abstracted_axes: Any | None = None, compiler_options: dict[str, Any] | None = None, ) -> JitWrapped: """`jax.experimental.pjit.pjit` has been deprecated. Please use `jax.jit`.""" @@ -767,8 +702,7 @@ def pjit( static_argnums=static_argnums, static_argnames=static_argnames, donate_argnums=donate_argnums, donate_argnames=donate_argnames, keep_unused=keep_unused, device=device, backend=backend, inline=inline, - abstracted_axes=abstracted_axes, compiler_options=compiler_options, - use_resource_env=True) + compiler_options=compiler_options, use_resource_env=True) def hashable_pytree(pytree): @@ -890,13 +824,12 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, in_layouts_flat = flatten_axis_resources( "pjit in_layouts", in_tree, in_layouts, tupled_args=True) - if not config.dynamic_shapes.value: - pjit_check_aval_sharding(in_shardings_flat, in_avals, - debug_info.safe_arg_names(len(in_avals)), - "pjit arguments", allow_uneven_sharding=False) - check_aval_layout_compatibility( - in_layouts_flat, in_avals, - debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type] + pjit_check_aval_sharding(in_shardings_flat, in_avals, + debug_info.safe_arg_names(len(in_avals)), + "pjit arguments", allow_uneven_sharding=False) + check_aval_layout_compatibility( + in_layouts_flat, in_avals, + debug_info.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type] return in_shardings_flat, in_layouts_flat callsites_with_tracing_cache_miss: set[str] = set() @@ -1182,12 +1115,7 @@ def _create_pjit_jaxpr( with dispatch.log_elapsed_time( "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec", fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT): - if config.dynamic_shapes.value: - assert isinstance(in_type, core.InputType) - jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2( - lu.annotate(fun, in_type)) - else: - jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type) + jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type) if config.debug_key_reuse.value: # Import here to avoid circular imports @@ -1226,15 +1154,14 @@ def _check_and_canonicalize_out_shardings( out_layouts_flat = flatten_axis_resources( "pjit out_layouts", out_tree(), out_layouts, tupled_args=False) - if not config.dynamic_shapes.value: - pjit_check_aval_sharding( - out_shardings_flat, out_avals, - debug_info.safe_result_paths(len(out_avals)), - "pjit outputs", allow_uneven_sharding=False) - check_aval_layout_compatibility( - out_layouts_flat, out_avals, - debug_info.safe_result_paths(len(out_avals)), - "jit outputs") + pjit_check_aval_sharding( + out_shardings_flat, out_avals, + debug_info.safe_result_paths(len(out_avals)), + "pjit outputs", allow_uneven_sharding=False) + check_aval_layout_compatibility( + out_layouts_flat, out_avals, + debug_info.safe_result_paths(len(out_avals)), + "jit outputs") return out_shardings_flat, out_layouts_flat _seen_qdds = weakref.WeakKeyDictionary() # type: ignore @@ -1678,7 +1605,7 @@ def call_impl_cache_miss(*args_, **kwargs_): inline=inline, compiler_options_kvs=compiler_options_kvs) fastpath_data = _get_fastpath_data( compiled, tree_structure(out_flat), args, out_flat, - jaxpr.effects, jaxpr.consts, None, pgle_profiler, + jaxpr.effects, jaxpr.consts, pgle_profiler, const_args) return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) @@ -1744,36 +1671,12 @@ def pjit_staging_rule(trace, source_info, *args, **params): all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): jaxpr = params["jaxpr"] - if config.dynamic_shapes.value: - # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic - # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, - # but redundantly performs abstract evaluation again. - with core.set_current_trace(trace): - out = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args, - propagate_source_info=False) - else: - out = pe.inline_jaxpr_into_trace( - trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args) + out = pe.inline_jaxpr_into_trace( + trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args) return [trace.to_jaxpr_tracer(x, source_info) for x in out] jaxpr = params['jaxpr'] - if config.dynamic_shapes.value: - jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( - jaxpr, params['out_shardings'], params['out_layouts']) - params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, - out_layouts=out_layouts) - outvars = map(trace.frame.newvar, _out_type(jaxpr)) - eqn = core.new_jaxpr_eqn( - [arg.var for arg in args], outvars, jit_p, params, - jaxpr.effects, source_info) - trace.frame.add_eqn(eqn) - out_tracers = [pe.DynamicJaxprTracer(trace, v.aval, v, source_info) - for v in outvars] - out_tracers_ = iter(out_tracers) - out_tracers = [args[f] if type(f) is int else next(out_tracers_) - for f in in_fwd] - assert next(out_tracers_, None) is None - elif any(isinstance(c, core.Ref) for c in jaxpr.consts): + if any(isinstance(c, core.Ref) for c in jaxpr.consts): jaxpr, consts = pxla._move_mutable_consts(jaxpr) consts = [trace.new_const(c, source_info) for c in consts] in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts) @@ -1801,15 +1704,7 @@ def _pjit_forwarding(jaxpr, out_shardings, out_layouts): return jaxpr, in_fwd, out_shardings, out_layouts def pjit_forwarding_rule(eqn): - if not config.dynamic_shapes.value: - return [None] * len(eqn.outvars), eqn - jaxpr, in_fwd, out_shardings, out_layouts = _pjit_forwarding( - eqn.params['jaxpr'], eqn.params['out_shardings'], eqn.params['out_layouts']) - new_outvars = [v for v, f in zip(eqn.outvars, in_fwd) if f is None] - new_params = dict(eqn.params, jaxpr=jaxpr, out_shardings=out_shardings, - out_layouts=out_layouts) - new_eqn = eqn.replace(params=new_params, outvars=new_outvars) - return in_fwd, new_eqn + return [None] * len(eqn.outvars), eqn # TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule. pe.forwarding_rules[jit_p] = pjit_forwarding_rule @@ -1823,11 +1718,6 @@ def _out_type(jaxpr: core.ClosedJaxpr) -> list[core.AbstractValue]: if type(x) is core.Var} for x in jaxpr.jaxpr.outvars: aval = x.aval - if type(aval) is core.DShapedArray: - shape = [core.InDBIdx(in_idx[d]) if d in in_idx else - core.OutDBIdx(out_idx[d]) if d in out_idx else - d for d in x.aval.shape] - aval = aval.update(shape=tuple(shape)) out.append(aval) return out @@ -1955,10 +1845,7 @@ def _pjit_batcher(axis_data, vals_in, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs): - segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in) - - # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs in_shardings = tuple( _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, ctx_mesh, aval.ndim) @@ -1989,16 +1876,13 @@ def _pjit_batcher(axis_data, vals_in, inline=inline, compiler_options_kvs=compiler_options_kvs) - resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs( - vals_in, vals_out, axes_out) - return vals_out, resolved_axes_out + return vals_out, axes_out batching.fancy_primitive_batchers[jit_p] = _pjit_batcher -batching.ragged_prop_rules[jit_p] = batching.ragged_mask_no_op_rule def _pjit_batcher_for_sharding( - s, dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None, + s, dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int): if isinstance(s, UnspecifiedValue): return s diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index dbcfadb3fc95..d2682253b6be 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1381,8 +1381,6 @@ def _shard_map_batch( in_specs, out_specs_thunk, check_vma: bool, manual_axes: frozenset ) -> Sequence[batching.BatchTracer]: in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers)) - if any(isinstance(d, batching.RaggedAxis) for d in in_dims): - raise NotImplementedError spmd_axis_name = trace.axis_data.spmd_name explicit_mesh_axis = trace.axis_data.explicit_mesh_axis if spmd_axis_name is not None: diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 810f76cd8de8..af87fc736210 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -818,8 +818,6 @@ def call(*args, **kwargs): # which might conflict here. params = args[0] args = args[1:] # Not including const_args - if config.dynamic_shapes.value: - raise NotImplementedError if params.no_kwargs and kwargs: kws = ', '.join(kwargs.keys()) raise NotImplementedError( diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 3b94cd5e705e..d3b7e6fae71b 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -70,7 +70,6 @@ get_p = core.Primitive("get") get_p.is_effectful = lambda params: True # type: ignore get_p.def_impl(partial(dispatch.apply_primitive, get_p)) -batching.ragged_prop_rules[get_p] = batching.ragged_mask_transfer_identity get_p.is_high = lambda ref_aval, *_, tree: ref_aval.is_high # type: ignore def _get_to_lojax(ref, *idx, tree): @@ -192,16 +191,6 @@ def _swap_to_lojax(ref, val, *idx, tree): swap_p.to_lojax = _swap_to_lojax # type: ignore -def swap_ragged_prop_rule(eqn_params, invar_raggedness, outvars): - assert len(invar_raggedness) == 2 - invar_raggedness_lhs = invar_raggedness[0] - invar_raggedness_rhs = invar_raggedness[1] - - return [invar_raggedness_rhs, invar_raggedness_lhs], [None] - - -batching.ragged_prop_rules[swap_p] = swap_ragged_prop_rule - @partial(traceback_util.api_boundary, repro_api_name="jax.ref.swap") def ref_swap( ref: core.Ref | TransformedRef, diff --git a/jax/core.py b/jax/core.py index 60d031083bb9..7fca3e07dd99 100644 --- a/jax/core.py +++ b/jax/core.py @@ -21,7 +21,6 @@ Atom as Atom, CallPrimitive as CallPrimitive, DebugInfo as DebugInfo, - DShapedArray as DShapedArray, DropVar as DropVar, Effect as Effect, Effects as Effects, diff --git a/jax/experimental/jax2tf/examples/saved_model_main_test.py b/jax/experimental/jax2tf/examples/saved_model_main_test.py index 28bce3d014e6..b515ced90ca5 100644 --- a/jax/experimental/jax2tf/examples/saved_model_main_test.py +++ b/jax/experimental/jax2tf/examples/saved_model_main_test.py @@ -48,9 +48,11 @@ def setUp(self): def test_train_and_save_full(self, model="mnist_flax", serving_batch_size=-1): + self.skipTest("no more dynamic shapes") if (serving_batch_size == -1 and - config.jax2tf_default_native_serialization.value and - not config.dynamic_shapes.value): + config.jax2tf_default_native_serialization.value + # and not config.dynamic_shapes.value + ): self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.") FLAGS.model = model FLAGS.model_classifier_layer = True diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index fe1ee305eb3a..e3c80e865b12 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1229,8 +1229,9 @@ def assertAllOperationStartWith(self, g: "tf.Graph", scope_name: str): self.fail(f"{op.name} does not start with {scope_name}.") def test_name_scope_polymorphic(self): - if not config.dynamic_shapes.value: - self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.") + self.skipTest("no more dynamic shapes") + # if not config.dynamic_shapes.value: + # self.skipTest("shape polymorphism but --jax_dynamic_shapes is not set.") def func_jax(x, y): return jnp.sin(x) + jnp.cos(y) diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 3eda5ab9bbbe..4db86da4d806 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -54,18 +54,6 @@ "jax.interpreters.batching.BatchingRule is deprecated.", _src_batching.BatchingRule, ), - "Jumble": ( - "jax.interpreters.batching.Jumble is deprecated.", - _src_batching.Jumble, - ), - "JumbleAxis": ( - "jax.interpreters.batching.JumbleAxis is deprecated.", - _src_batching.JumbleAxis, - ), - "JumbleTy": ( - "jax.interpreters.batching.JumbleTy is deprecated.", - _src_batching.JumbleTy, - ), "Elt": ( "jax.interpreters.batching.Elt is deprecated.", _src_batching.Elt, @@ -78,10 +66,6 @@ "jax.interpreters.batching.GetIdx is deprecated.", _src_batching.GetIdx, ), - "IndexedAxisSize": ( - "jax.interpreters.batching.IndexedAxisSize is deprecated.", - _src_batching.IndexedAxisSize, - ), "MakeIotaHandler": ( "jax.interpreters.batching.MakeIotaHandler is deprecated.", _src_batching.MakeIotaHandler, @@ -94,10 +78,6 @@ "jax.interpreters.batching.NotMapped is deprecated.", _src_batching.NotMapped, ), - "RaggedAxis": ( - "jax.interpreters.batching.RaggedAxis is deprecated.", - _src_batching.RaggedAxis, - ), "ToEltHandler": ( "jax.interpreters.batching.ToEltHandler is deprecated.", _src_batching.ToEltHandler, @@ -130,10 +110,6 @@ "jax.interpreters.batching.batch_jaxpr is deprecated. It is an internal API.", _src_batching.batch_jaxpr, ), - "batch_jaxpr2": ( - "jax.interpreters.batching.batch_jaxpr2 is deprecated. It is an internal API.", - _src_batching.batch_jaxpr2, - ), "batch_jaxpr_axes": ( "jax.interpreters.batching.batch_jaxpr_axes is deprecated. It is an internal API.", _src_batching.batch_jaxpr_axes, @@ -162,10 +138,6 @@ "jax.interpreters.batching.is_vmappable is deprecated. It is an internal API.", _src_batching.is_vmappable, ), - "jumble_axis": ( - "jax.interpreters.batching.jumble_axis is deprecated. It is an internal API.", - _src_batching.jumble_axis, - ), "make_iota": ( "jax.interpreters.batching.make_iota is deprecated. It is an internal API.", _src_batching.make_iota, @@ -224,17 +196,12 @@ BatchTrace = _src_batching.BatchTrace BatchTracer = _src_batching.BatchTracer BatchingRule = _src_batching.BatchingRule - Jumble = _src_batching.Jumble - JumbleAxis = _src_batching.JumbleAxis - JumbleTy = _src_batching.JumbleTy Elt = _src_batching.Elt FromEltHandler = _src_batching.FromEltHandler GetIdx = _src_batching.GetIdx - IndexedAxisSize = _src_batching.IndexedAxisSize MakeIotaHandler = _src_batching.MakeIotaHandler MapSpec = _src_batching.MapSpec NotMapped = _src_batching.NotMapped - RaggedAxis = _src_batching.RaggedAxis ToEltHandler = _src_batching.ToEltHandler Vmappable = _src_batching.Vmappable Zero = _src_batching.Zero @@ -243,7 +210,6 @@ batch_custom_jvp_subtrace = _src_batching.batch_custom_jvp_subtrace batch_custom_vjp_bwd = _src_batching.batch_custom_vjp_bwd batch_jaxpr = _src_batching.batch_jaxpr - batch_jaxpr2 = _src_batching.batch_jaxpr2 batch_jaxpr_axes = _src_batching.batch_jaxpr_axes batch_subtrace = _src_batching.batch_subtrace broadcast_batcher = _src_batching.broadcast_batcher @@ -251,7 +217,6 @@ from_elt = _src_batching.from_elt from_elt_handlers = _src_batching.from_elt_handlers is_vmappable = _src_batching.is_vmappable - jumble_axis = _src_batching.jumble_axis make_iota = _src_batching.make_iota make_iota_handlers = _src_batching.make_iota_handlers matchaxis = _src_batching.matchaxis diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 8b1022f73baa..67a246eba226 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -34,18 +34,6 @@ _deprecations = { # Deprecated for JAX v0.7.1; finalize in JAX v0.9.0. - "AbstractedAxesSpec": ( - "jax.interpreters.partial_eval.AbstractedAxesSpec is deprecated.", - _pe_src.AbstractedAxesSpec, - ), - "AbstractedAxisName": ( - "jax.interpreters.partial_eval.AbstractedAxisName is deprecated.", - _pe_src.AbstractedAxisName, - ), - "BoundedAxisSize": ( - "jax.interpreters.partial_eval.BoundedAxisSize is deprecated.", - _pe_src.BoundedAxisSize, - ), "Const": ( "jax.interpreters.partial_eval.Const is deprecated.", _pe_src.Const, @@ -130,10 +118,6 @@ "jax.interpreters.partial_eval.abstract_eval_fun is deprecated.", _pe_src.abstract_eval_fun, ), - "call_padding_rule": ( - "jax.interpreters.partial_eval.call_padding_rule is deprecated.", - _pe_src.call_padding_rule, - ), "call_param_updaters": ( "jax.interpreters.partial_eval.call_param_updaters is deprecated.", _pe_src.call_param_updaters, @@ -178,10 +162,6 @@ "jax.interpreters.partial_eval.custom_staging_rules is deprecated.", _pe_src.custom_staging_rules, ), - "def_trivial_padding": ( - "jax.interpreters.partial_eval.def_trivial_padding is deprecated.", - _pe_src.def_trivial_padding, - ), "forwarding_rules": ( "jax.interpreters.partial_eval.forwarding_rules is deprecated.", _pe_src.forwarding_rules, @@ -190,10 +170,6 @@ "jax.interpreters.partial_eval.has_effects is deprecated.", _pe_src.has_effects, ), - "infer_lambda_input_type": ( - "jax.interpreters.partial_eval.infer_lambda_input_type is deprecated.", - _pe_src.infer_lambda_input_type, - ), "instantiate_const_at": ( "jax.interpreters.partial_eval.instantiate_const_at is deprecated.", _pe_src.instantiate_const_at, @@ -214,14 +190,6 @@ "jax.interpreters.partial_eval.new_eqn_recipe is deprecated.", _pe_src.new_eqn_recipe, ), - "pad_jaxpr": ( - "jax.interpreters.partial_eval.pad_jaxpr is deprecated.", - _pe_src.pad_jaxpr, - ), - "padding_rules": ( - "jax.interpreters.partial_eval.padding_rules is deprecated.", - _pe_src.padding_rules, - ), "partial_eval_jaxpr_custom": ( "jax.interpreters.partial_eval.partial_eval_jaxpr_custom is deprecated.", _pe_src.partial_eval_jaxpr_custom, @@ -246,10 +214,6 @@ "jax.interpreters.partial_eval.recipe_to_eqn is deprecated.", _pe_src.recipe_to_eqn, ), - "trace_to_jaxpr_dynamic2": ( - "jax.interpreters.partial_eval.trace_to_jaxpr_dynamic2 is deprecated.", - _pe_src.trace_to_jaxpr_dynamic2, - ), "trace_to_subjaxpr_nounits": ( "jax.interpreters.partial_eval.trace_to_subjaxpr_nounits is deprecated.", _pe_src.trace_to_subjaxpr_nounits, @@ -289,7 +253,6 @@ TracerAsName = _pe_src.TracerAsName TracerId = _pe_src.TracerId abstract_eval_fun = _pe_src.abstract_eval_fun - call_padding_rule = _pe_src.call_padding_rule call_param_updaters = _pe_src.call_param_updaters call_partial_eval_custom_rule = _pe_src.call_partial_eval_custom_rule call_partial_eval_rules = _pe_src.call_partial_eval_rules @@ -301,7 +264,6 @@ convert_envvars_to_constvars = _pe_src.convert_envvars_to_constvars convert_invars_to_constvars = _pe_src.convert_invars_to_constvars custom_staging_rules = _pe_src.custom_staging_rules - def_trivial_padding = _pe_src.def_trivial_padding forwarding_rules = _pe_src.forwarding_rules has_effects = _pe_src.has_effects infer_lambda_input_type = _pe_src.infer_lambda_input_type @@ -310,8 +272,6 @@ move_binders_to_back = _pe_src.move_binders_to_back move_binders_to_front = _pe_src.move_binders_to_front new_eqn_recipe = _pe_src.new_eqn_recipe - pad_jaxpr = _pe_src.pad_jaxpr - padding_rules = _pe_src.padding_rules partial_eval_jaxpr_custom = _pe_src.partial_eval_jaxpr_custom partial_eval_jaxpr_custom_rule_not_implemented = _pe_src.partial_eval_jaxpr_custom_rule_not_implemented partial_eval_jaxpr_nounits = _pe_src.partial_eval_jaxpr_nounits diff --git a/tests/BUILD b/tests/BUILD index e27ed2957a28..3736f44b8b5d 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -81,18 +81,6 @@ jax_multiplatform_test( deps = py_deps("absl/testing"), ) -jax_py_test( - name = "dynamic_api_test", - srcs = ["dynamic_api_test.py"], - deps = [ - "//jax", - "//jax/_src:test_util", - ] + py_deps([ - "absl/testing", - "numpy", - ]), -) - jax_multiplatform_test( name = "api_util_test", srcs = ["api_util_test.py"], diff --git a/tests/api_test.py b/tests/api_test.py index b55f844f033f..4f0c85cdc36e 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -5134,37 +5134,6 @@ def f(inputs): jtu.check_grads(f, (list(jnp.arange(float(num_args))),), order=1, modes=['rev'], atol=1e-3, rtol=1e-3) - @jtu.run_on_devices("cpu") - def test_inner_jit_forwarding_happens(self): - if not config.dynamic_shapes.value: - self.skipTest("Only works for dynamic shapes") - jaxpr = jax.make_jaxpr(lambda: jax.jit(lambda x: x)(3))() - self.assertLen(jaxpr.jaxpr.outvars, 1) - self.assertIsInstance(jaxpr.jaxpr.outvars[0], core.Literal) - self.assertEqual(jaxpr.jaxpr.outvars[0].val, 3) - - @parameterized.parameters(range(8)) - @jtu.run_on_devices("cpu") - def test_inner_jit_forwarding_correctness(self, num_input_fwd): - if not config.dynamic_shapes.value: - self.skipTest("Only works for dynamic shapes") - num_args = 8 - rng = np.random.RandomState(0) - - @jax.jit - def f(inputs): - inputs = [inputs[i] for i in rng.permutation(num_args)] - outputs = (inputs[:num_input_fwd] + - [jnp.sin(inputs[i]) for i in range(num_args - num_input_fwd)]) - return [outputs[i] for i in rng.permutation(num_args)] - - f2 = jax.jit(f) - inputs = list(jnp.arange(float(num_args))) - expected = f(inputs) - ans = f2(inputs) - for a, b in zip(ans, expected): - self.assertAllClose(a, b) - @unittest.skip # TODO(dougalm): figure out with Matt what to do with this feature def test_inner_jit_forwarded_consts_stay_const(self): out = jax.jit(lambda: int(jax.jit(lambda x: x)(3)))() # don't crash diff --git a/tests/core_test.py b/tests/core_test.py index e0c6b8436e2e..86430593c50a 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from collections import namedtuple from functools import partial import gc @@ -32,7 +31,7 @@ from jax._src import linear_util as lu from jax._src import util from jax._src import test_util as jtu -from jax._src.core import ShapedArray, DBIdx +from jax._src.core import ShapedArray from jax._src.interpreters import partial_eval as pe from jax._src.lax import control_flow as lax_control_flow @@ -598,193 +597,5 @@ def f(x): core.check_jaxpr(jaxpr) -@unittest.skip("currently unmaintained") -@jtu.with_config(jax_dynamic_shapes=True) -class DynamicShapesTest(jtu.JaxTestCase): - - def test_staging_basic(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - - def f(x, y): - return x, y - - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (1, 2), {})), - [n, a, b], keep_inputs=[False, True, True]) - - self.assertLen(jaxpr.invars, 3) - self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape) - self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape) - - self.assertLen(jaxpr.outvars, 2) - self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape) - self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape) - - @unittest.skip('This test does not work with nested pjit and DShapedArray') - def test_staging_nested(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - - def f(x, y): - @jax.jit - def g(x, y, z, w): - return (x, w) - return g(x, y, x, y) - - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (0, 1), {})), - [n, a, b], keep_inputs=[False, True, True]) - - self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs - self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape) - self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape) - - self.assertLen(jaxpr.outvars, 2) - self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape) - self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape) - - self.assertLen(jaxpr.eqns, 1) - eqn = jaxpr.eqns[0] - self.assertIsInstance(eqn.primitive, core.CallPrimitive) - inner_jaxpr = eqn.params['call_jaxpr'] - self.assertIsInstance(inner_jaxpr, core.Jaxpr) - - self.assertLen(inner_jaxpr.invars, 1 + 4) # one axis size var - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape) - - @unittest.skip('This test does not work with nested pjit and DShapedArray') - def test_staging_nested_including_shape_arg(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - - def f(x, y): - @jax.jit - def g(_, x, y, z, w): - return (x, w) - return g(x.shape[0], x, y, x, y) - - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (1, 2), {})), - [n, a, b], keep_inputs=[False, True, True]) - - # { lambda ; a:i32[] b:f32[a] c:f32[a]. let - # d:f32[a] e:f32[a] = xla_call[ - # call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let - # - # in (h, k) } - # name=g - # ] a a b c b c - # in (d, e) } - - self.assertLen(jaxpr.eqns, 1) - eqn = jaxpr.eqns[0] - self.assertIsInstance(eqn.primitive, core.CallPrimitive) - inner_jaxpr = eqn.params['call_jaxpr'] - self.assertIsInstance(inner_jaxpr, core.Jaxpr) - - self.assertLen(inner_jaxpr.invars, 1 + 4) # one axis size var - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape) - self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape) - - def test_staging_primitive_applications(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - - def f(x, y): - z = lax.mul(x, y) - w = lax.sin(z) - u = lax.reduce_sum(w, [0]) - return (u,) - - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (1, 2), {})), - [n, a, b], keep_inputs=[False, True, True]) - - self.assertLen(jaxpr.invars, 1 + 2) # one axis size var, two other inputs - self.assertLen(jaxpr.eqns, 3) - self.assertLen(jaxpr.eqns[0].outvars, 1) - self.assertEqual(jaxpr.eqns[0].outvars[0].aval.shape, - jaxpr.invars[1].aval.shape) - - self.assertLen(jaxpr.outvars, 1) - self.assertEqual(jaxpr.outvars[0].aval.shape, ()) - - @unittest.skip('This test does not work with nested pjit and DShapedArray') - def test_typecheck_staging_nested(self): - n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False) - a = core.DShapedArray((DBIdx(0),), jnp.dtype('float32'), weak_type=False) - b = core.DShapedArray((DBIdx(1),), jnp.dtype('float32'), weak_type=False) - - def f(a, b): - @jax.jit - def g(x): return x - return g(a), - - jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(f, - debug_info=debug_info("test", f, (1, 2), {})), - [n, m, a, b], keep_inputs=[False, False, True, True]) - # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let - # e:f32[a] = xla_call[ - # call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) } - # name=g - # ] a c - # in (e,) } - core.check_jaxpr(jaxpr) # no problems here... - - # Let's introduce a type error by applying the called jaxpr to arguments - # with types which aren't consistent with its input binders: - _, _, c, d = jaxpr.invars - jaxpr.eqns[0].invars[1] = d - # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let - # e:f32[a] = xla_call[ - # call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) } - # name=g - # ] a d !!! type error here !!! - # in (e,) } - with self.assertRaisesRegex(TypeError, "passes operand"): - core.check_jaxpr(jaxpr) - - # Restore the original jaxpr: - jaxpr.eqns[0].invars[1] = c - core.check_jaxpr(jaxpr) # no problems here... - - # Let's introduce another type error by setting the call result let binders - # to have the wrong type: - jaxpr.eqns[0].outvars[0] = core.Var('', d.aval) - # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let - # e:f32[b] = xla_call[ !!! type error here !!! - # call_jaxpr={ lambda ; f:i32[] g:f32[f]. let in (g,) } - # name=g - # ] a c - # in (h,) } - with self.assertRaisesRegex(TypeError, "inconsistently typed as"): - core.check_jaxpr(jaxpr) - - def test_check_jaxpr_key_reuse(self): - with config.debug_key_reuse(True): - def f(seed): - key = jax.random.key(seed) - return jax.random.uniform(key) + jax.random.normal(key) - with jax.enable_checks(True): - with self.assertRaises(jax.errors.KeyReuseError): - jax.jit(f)(0) - - if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py deleted file mode 100644 index 1c3dc17b79de..000000000000 --- a/tests/dynamic_api_test.py +++ /dev/null @@ -1,1770 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -import re -import unittest -import numpy as np - -from absl.testing import absltest -from absl.testing import parameterized - -import jax -import jax.numpy as jnp -from jax import lax -from jax.interpreters import batching - -from jax._src import core -from jax._src import test_util as jtu - -jax.config.parse_flags_with_absl() - - -@unittest.skip("currently unmaintained") -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class DynamicShapeStagingTest(jtu.JaxTestCase): - def test_basic_staging(self): - def f(x, _): - return x - - x = jnp.arange(3) - y = jnp.ones((3, 4)) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x, y) - - # { lambda ; a:i32[] b:i32[a] c:f32[a,4]. let in (b,) } - self.assertLen(jaxpr.in_avals, 3) - self.assertLen(jaxpr.in_avals[0].shape, 0) - self.assertLen(jaxpr.in_avals[1].shape, 1) - self.assertLen(jaxpr.in_avals[2].shape, 2) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[1].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0]) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_staging_repeated(self): - def f(x, _): - return x - - x = jnp.arange(3) - y = jnp.ones((3, 3)) - jaxpr = jax.make_jaxpr(f, abstracted_axes=(('n',), ('n', 'n')))(x, y) - - # { lambda ; a:i32[] b:i32[a] c:f32[a,a]. let in (b,) } - self.assertLen(jaxpr.in_avals, 3) - self.assertLen(jaxpr.in_avals[0].shape, 0) - self.assertLen(jaxpr.in_avals[1].shape, 1) - self.assertLen(jaxpr.in_avals[2].shape, 2) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[1].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[1]) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_staging_multiple_shape_vars(self): - def f(x, _): - return x - - x = jnp.arange(3) - y = jnp.ones((4, 3)) - jaxpr = jax.make_jaxpr(f, abstracted_axes=(('n',), ('m', 'n')))(x, y) - - # { lambda ; a:i32[] b: i32[] c:i32[a] d:f32[b,a]. let in (c,) } - self.assertLen(jaxpr.in_avals, 4) - self.assertLen(jaxpr.in_avals[0].shape, 0) - self.assertLen(jaxpr.in_avals[1].shape, 0) - self.assertLen(jaxpr.in_avals[2].shape, 1) - self.assertLen(jaxpr.in_avals[3].shape, 2) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[3].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[3].shape[1]) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_add(self): - def f(x, y): - return x + y - - x = jnp.arange(3) - y = jnp.arange(1, 4) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x, y) - - # { lambda ; a:i32[] b:i32[a] c:i32[a]. let d:i32[a] = add b c in (d,) } - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_jnp(self): - def f(x): - y = x + jnp.sin(x) - return y.sum() - - x = jnp.ones((3, 4)) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x) - - # { lambda ; a:i32[] b:f32[a,4]. let - # c:f32[a,4] = sin b - # d:f32[a,4] = add b c - # e:f32[] = reduce_sum[axes=(0, 1)] d - # in (e,) } - self.assertLen(jaxpr.in_avals, 2) - self.assertLen(jaxpr.eqns, 3) # sin, add, and reduce_sum - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 0) - - def test_shape_errors_var_and_lit(self): - def f(x, y): - return jnp.sin(x) + y - - x = np.ones(3) - y = np.ones(3) - with self.assertRaisesRegex( - Exception, '[Ii]ncompatible shapes for broadcasting'): - _ = jax.make_jaxpr(f, abstracted_axes=({0: 'n'}, {}))(x, y) - - def test_shape_errors_distinct_vars(self): - def f(x, y): - return jnp.sin(x) + y - - x = np.ones(3) - y = np.ones(3) - with self.assertRaisesRegex( - Exception, '[Ii]ncompatible shapes for broadcasting'): - _ = jax.make_jaxpr(f, abstracted_axes=({0: 'n'}, {0: 'm'}))(x, y) - - def test_basic_dot(self): - A = jnp.ones((3, 4)) - x = jnp.ones(4) - jaxpr = jax.make_jaxpr(jnp.dot, abstracted_axes=(('m', 'n'), ('n',)))(A, x) - - # { lambda ; a:i32[] b:i32[] c:f32[a,b] d:f32[b]. let - # e:f32[a] = dot_general[dimension_numbers=(((1,), (0,)), ((), ()))] c d - # in (e,) } - self.assertLen(jaxpr.in_avals, 4) - self.assertLen(jaxpr.in_avals[0].shape, 0) # two shape vars - self.assertLen(jaxpr.in_avals[1].shape, 0) - self.assertLen(jaxpr.in_avals[2].shape, 2) # one matrix - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[2].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[2].shape[1]) - self.assertLen(jaxpr.in_avals[3].shape, 1) # one vector - self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.in_avals[3].shape[0]) - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 1) # output vector - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.out_avals[0].shape[0]) - - def test_basic_broadcast(self): - def f(x, n): - return lax.broadcast(x, (n,)) - - jaxpr = jax.make_jaxpr(f)(jnp.ones(4), 3) - - # { lambda ; a:f32[4] b:i32[]. let - # c:f32[b,4] = broadcast_in_dim[bcast_dims=(1,) shape=(None, 4)] a b - # in (c,) } - self.assertLen(jaxpr.in_avals, 2) - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 2) - self.assertIs(jaxpr.jaxpr.invars[1], jaxpr.out_avals[0].shape[0]) - self.assertEqual(4, jaxpr.out_avals[0].shape[1]) - - def test_basic_batchpoly_neuralnet(self): - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.tanh(outputs) - return outputs - - def loss(params, batch): - inputs, targets = batch - preds = predict(params, inputs) - return jnp.sum((preds - targets) ** 2) - - sizes = [784, 128, 128, 10] - params = [(jnp.ones((input_dim, output_dim)), jnp.ones(output_dim)) - for input_dim, output_dim in zip(sizes[:-1], sizes[1:])] - batch = (jnp.ones((32, 784)), jnp.ones((32, 10))) - - # Mainly we want to test that make_jaxpr doesn't crash here. - jaxpr = jax.make_jaxpr(loss, abstracted_axes=({}, {0: 'n'}))(params, batch) - self.assertLen(jaxpr.in_avals, 9) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[-2].shape[0]) - self.assertIs(jaxpr.jaxpr.invars[0], jaxpr.in_avals[-1].shape[0]) - self.assertLen(jaxpr.out_avals, 1) - self.assertLen(jaxpr.out_avals[0].shape, 0) - - def test_closing_over_polymorphic_shape(self): - def f(n): - x = jnp.zeros(n) - return jax.jit(lambda: x)() - - jaxpr = jax.make_jaxpr(f)(3) - - # { lambda ; a:i32[]. let - # b:f32[a] = bcast[dims=() shape=(None,)] 0.0 a - # c:f32[a] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let in (e,) } - # name= - # ] a b - # in (c,) } - a, = jaxpr.jaxpr.invars - c, = jaxpr.jaxpr.outvars - self.assertLen(c.aval.shape, 1) - self.assertIs(a, c.aval.shape[0]) - - def test_closing_over_dynamic_shape(self): - def f(n): - m = 2 * n - x = jnp.zeros(m) - return jax.jit(jnp.sin)(x) - - # { lambda ; a:i32[]. let - # b:i32[] = mul a 2 - # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 0.0 b - # d:f32[b] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e]. let in (f,) } - # name= - # ] b c - # in (d,) } - jaxpr = jax.make_jaxpr(f)(3) - b, = jaxpr.jaxpr.eqns[0].outvars - c, = jaxpr.jaxpr.eqns[1].outvars - d, = jaxpr.jaxpr.eqns[2].outvars - self.assertLen(c.aval.shape, 1) - self.assertIs(b, c.aval.shape[0]) - self.assertLen(d.aval.shape, 1) - self.assertIs(b, d.aval.shape[0]) - - def test_closing_over_polymorphic_shape_and_adding(self): - def f(n): - x = jnp.zeros(n) - y = jnp.zeros(n) - - @jax.jit - def g(): - return x + y - return g() - - # { lambda ; a:i32[]. let - # b:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a - # c:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a - # d:f32[a] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let - # h:f32[e] = add f g - # in (h,) } - # name=g - # ] a b c - # in (d,) } - jaxpr = jax.make_jaxpr(f)(3) # doesn't fail on the addition! - a, = jaxpr.jaxpr.invars - b, = jaxpr.jaxpr.eqns[0].outvars - c, = jaxpr.jaxpr.eqns[1].outvars - d, = jaxpr.jaxpr.eqns[2].outvars - self.assertIs(a, b.aval.shape[0]) - self.assertIs(a, c.aval.shape[0]) - self.assertIs(a, d.aval.shape[0]) - - def test_passing_in_equal_polymorphic_shapes_and_adding(self): - def f(n): - x = jnp.zeros(n) - - @jax.jit - def g(x, y): - return x + y - return g(x, x) - - # { lambda ; a:i32[]. let - # b:f32[a] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 0.0 a - # c:f32[a] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d] f:f32[d]. let - # g:f32[d] = add e f - # in (g,) } - # name=g - # ] a b b - # in (c,) } - jaxpr = jax.make_jaxpr(f)(3) - a, = jaxpr.jaxpr.invars - c, = jaxpr.jaxpr.outvars - self.assertLen(c.aval.shape, 1) - self.assertIs(a, c.aval.shape[0]) - - @unittest.skip("doesn't work yet: shape error b/c we don't notice x and y same") - def test_closing_over_and_passing_arg_addition(self): - # TODO(mattjj,dougalm): currently fails to notice equal shapes, fix! - def f(n): - x = jnp.zeros(n) - - @jax.jit - def g(y): - return x + y - return g(x) - - _ = jax.make_jaxpr(f)(3) - - @unittest.skip("doesn't work yet: shape error b/c we don't notice x and jnp.zeros(m) same") - def test_closing_over_and_passing_size_addition(self): - # TODO(mattjj,dougalm): currently fails to notice equal shapes, fix! - def f(n): - x = jnp.zeros(n) - - @jax.jit - def g(m): - return jnp.zeros(m) + x - return g(n) - - _ = jax.make_jaxpr(f)(3) - - def test_closing_over_and_broadcasting_polymorphic_shape(self): - def f(n): - x = jnp.zeros(n) - @jax.jit - def g(): - return jnp.zeros(n) + x - return g() - - # { lambda ; a:i32[]. let - # b:f32[a] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 a - # c:f32[a] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] = bcast[broadcast_dimensions=() shape=(None,)] 0.0 d - # g:f32[d] = add f e - # in (g,) } - # name=g - # ] a b - # in (c,) } - jaxpr = jax.make_jaxpr(f)(3) - - a, = jaxpr.jaxpr.invars - c, = jaxpr.jaxpr.outvars - self.assertLen(c.aval.shape, 1) - self.assertIs(a, c.aval.shape[0]) - - def test_closing_over_repeated_shapes(self): - def zeros(shape): - if not isinstance(shape, (tuple, list)): - shape = shape, - return lax.broadcast(0., shape) - - def f(n): - m = 2 * n - x = zeros((m, m)) - return jax.jit(lambda: x.sum(0))() - - # { lambda ; a:i32[]. let - # b:i32[] = mul a 2 - # c:f32[b,b] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 0.0 - # b b - # d:f32[b] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e,e]. let - # g:f32[e] = reduce_sum[axes=(0,)] f - # in (g,) } - # name= - # ] b c - # in (d,) } - jaxpr = jax.make_jaxpr(f)(3) - a, = jaxpr.jaxpr.invars - b, = jaxpr.jaxpr.eqns[0].outvars - c, = jaxpr.jaxpr.eqns[1].outvars - d, = jaxpr.jaxpr.eqns[2].outvars - b_, c_ = jaxpr.jaxpr.eqns[2].invars - self.assertLen(c.aval.shape, 2) - self.assertIs(c.aval.shape[0], b) - self.assertIs(c.aval.shape[1], b) - self.assertIs(b, b_) - self.assertIs(c, c_) - self.assertLen(d.aval.shape, 1) - self.assertIs(d.aval.shape[0], b) - - def test_staging_repeated_nested(self): - def zeros(shape): - if not isinstance(shape, (tuple, list)): - shape = shape, - return lax.broadcast(jnp.float32(0.), shape) - - def f(n): - m = 2 * n - x = zeros((m, n)) - y = zeros(m) - return jax.jit(lambda x, y: x.sum(1) + y)(x, y) - - # { lambda ; a:i32[]. let - # b:i32[] = mul a 2 - # c:f32[b,a] = broadcast_in_dim[broadcast_dimensions=() shape=(None, None)] 0.0 - # b a - # d:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b - # e:f32[b] = pjit[ - # jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f,g] i:f32[f]. let - # j:f32[f] = reduce_sum[axes=(1,)] h - # k:f32[f] = add j i - # in (k,) } - # name= - # ] b a c d - # in (e,) } - jaxpr = jax.make_jaxpr(f)(jnp.int32(3)) - a, = jaxpr.jaxpr.invars - b, = jaxpr.jaxpr.eqns[0].outvars - c, = jaxpr.jaxpr.eqns[1].outvars - d, = jaxpr.jaxpr.eqns[2].outvars - e, = jaxpr.jaxpr.eqns[3].outvars - b_, a_, c_, d_ = jaxpr.jaxpr.eqns[3].invars - self.assertLen(c.aval.shape, 2) - self.assertIs(c.aval.shape[0], b) - self.assertIs(c.aval.shape[1], a) - self.assertLen(e.aval.shape, 1) - self.assertIs(e.aval.shape[0], b) - self.assertIs(a, a_) - self.assertIs(b, b_) - self.assertIs(c, c_) - self.assertIs(d, d_) - - def test_jit_abstracted_axes_staging(self): - # We just test make_jaxpr-of-jit because dynamic shape compilation/execution - # may not be supported. - @jax.jit(abstracted_axes=('n',)) - def f(x): - return jnp.sum(x) - jaxpr = jax.make_jaxpr(f)(jnp.ones(3, jnp.dtype('float32'))) - # { lambda ; a:f32[3]. let - # b:f32[] = pjit[ - # jaxpr={ lambda ; c:i32[] d:f32[c]. let - # e:f32[] = reduce_sum[axes=(0,)] d - # in (e,) } - # name=f - # ] 3 a - # in (b,) } - a, = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - self.assertLen(e.invars, 2) - self.assertIsInstance(e.invars[0], core.Literal) - self.assertIs(e.invars[1], a) - b, = e.outvars - self.assertLen(b.aval.shape, 0) - - subjaxpr = e.params['jaxpr'] - c, d = subjaxpr.jaxpr.invars - self.assertLen(c.aval.shape, 0) - self.assertLen(d.aval.shape, 1) - self.assertIs(d.aval.shape[0], c) - - def test_jit_abstracted_axes_staging2(self): - @jax.jit(abstracted_axes=('n',)) - def fun(x): - return jnp.sum(x) - jaxpr = jax.make_jaxpr(lambda n: fun(jnp.ones(n + n, jnp.dtype('float32'))) - )(3) - # { lambda ; a:i32[]. let - # b:i32[] = add a a - # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 b - # d:f32[] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e]. let - # g:f32[] = reduce_sum[axes=(0,)] f - # in (g,) } - # name=f - # ] b c - # in (d,) } - a, = jaxpr.jaxpr.invars - e1, e2, e3 = jaxpr.jaxpr.eqns - b, = e1.outvars - c, = e2.outvars - b_, c_ = e3.invars - self.assertIs(b, b_) - self.assertIs(c, c_) - - subjaxpr = e3.params['jaxpr'] - e, f = subjaxpr.jaxpr.invars - self.assertLen(e.aval.shape, 0) - self.assertLen(f.aval.shape, 1) - self.assertIs(f.aval.shape[0], e) - - def test_jit_abstracted_axes_staging3(self): - f = jax.jit(jnp.sum, abstracted_axes=('n',)) - jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3.)) - # { lambda ; a:i32[] b:f32[a]. let - # c:f32[] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[] = reduce_sum[axes=(0,)] e - # in (f,) } - # name=sum - # ] a b - # in (c,) } - a, b = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - self.assertIs(e.invars[0], a) - self.assertIs(e.invars[1], b) - c, = e.outvars - self.assertLen(c.aval.shape, 0) - - subjaxpr = e.params['jaxpr'] - d, e = subjaxpr.jaxpr.invars - self.assertLen(d.aval.shape, 0) - self.assertLen(e.aval.shape, 1) - self.assertIs(e.aval.shape[0], d) - - def test_jit_abstracted_axes_return_polymorphic_shape(self): - f = jax.jit(lambda x: jnp.sin(x), abstracted_axes=('n',)) - jaxpr = jax.make_jaxpr(f)(jnp.arange(3)) # doesn't crash - # { lambda ; a:i32[3]. let - # b:i32[3] = pjit[ - # jaxpr={ lambda ; c:i32[] d:i32[c]. let in (d,) } - # name= - # ] 3 a - # in (b,) } - a, = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - three, a_ = e.invars - b, = e.outvars - self.assertIsInstance(three, core.Literal) - self.assertEqual(three.val, 3) - self.assertIs(a_, a) - self.assertLen(b.aval.shape, 1) - self.assertEqual(b.aval.shape[0], 3) - - def test_jit_abstracted_axes_return_polymorphic_shape2(self): - f = jax.jit(lambda n: jnp.ones(n)) - # TODO(mattjj,dougalm): support dynamic shapes in type checker - with jax.enable_checks(False): - jaxpr = jax.make_jaxpr(f)(3) - # { lambda ; a:i32[]. let - # b:f32[a] = pjit[ - # jaxpr={ lambda ; c:i32[]. let - # d:f32[c] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 - # c - # in (d,) } - # name= - # ] a - # in (b,) } - a, = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - a_, = e.invars - self.assertIs(a, a_) - b, = e.outvars - a__, = b.aval.shape - self.assertIs(a, a__) - - with jax.enable_checks(False): - jaxpr = jax.make_jaxpr(lambda: f(3))() - # { lambda ; . let - # a:f32[3] = pjit[ - # jaxpr={ lambda ; b:i32[]. let - # c:f32[b] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] 1.0 - # b - # in (c,) } - # name= - # ] 3 - # in (a,) } - () = jaxpr.jaxpr.invars - e, = jaxpr.jaxpr.eqns - three, = e.invars - self.assertIsInstance(three, core.Literal) - self.assertEqual(three.val, 3) - b, = e.outvars - three_, = b.aval.shape - self.assertIsInstance(three_, int) - self.assertEqual(three_, 3) - - def test_zero_size_checking(self): - def f(x): - if core.definitely_equal(x.size, 0): - return x - else: - return -x - - x = jnp.zeros(1) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(x) # doesn't crash - self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 1) - - y = jnp.zeros((2, 0)) - jaxpr = jax.make_jaxpr(f, abstracted_axes={0: 'n'})(y) # doesn't crash - self.assertLen(jaxpr.jaxpr.eqns, 0) - - def test_flattening_basic(self): - x = jnp.zeros((2, 3, 4, 5)) - - # don't need to divide or multiply any dynamic axis sizes - jaxpr = jax.make_jaxpr(lambda x: x.reshape(x.shape[0], -1), - abstracted_axes={0: 'n'})(x) - self.assertLen(jaxpr.jaxpr.eqns, 1) - jaxpr = jax.make_jaxpr(lambda x: x.reshape(3, x.shape[0], -1), - abstracted_axes={0: 'n'})(x) - self.assertLen(jaxpr.jaxpr.eqns, 1) - jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, x.shape[0]), - abstracted_axes={0: 'n'})(x) - self.assertLen(jaxpr.jaxpr.eqns, 1) - - # don't need to divide but do need a dynamic axis size in multiplication - # (so to typecheck we'd need nontrivial reductions) - jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1), - abstracted_axes={0: 'n'})(x) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) # may have mul with 1 - self.assertEqual(str(jaxpr.jaxpr.eqns[-2].primitive), 'mul') - self.assertEqual(str(jaxpr.jaxpr.eqns[-1].primitive), 'reshape') - jaxpr = jax.make_jaxpr(lambda x: x.reshape(2, -1), - abstracted_axes={0: 'n'})(x) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) - jaxpr = jax.make_jaxpr(lambda x: x.reshape(-1, 12), abstracted_axes={0: 'n'})(x) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) - - def test_shape_validation(self): - # Regression test for https://github.com/jax-ml/jax/issues/18937 - msg = r"Shapes must be 1D sequences of integer scalars, got .+" - with self.assertRaisesRegex(TypeError, msg): - jax.make_jaxpr(jnp.ones)(5.0) - with self.assertRaisesRegex(TypeError, msg): - jax.make_jaxpr(jnp.ones)(jnp.ones((2, 2))) - - def test_matmul_two_arg(self): - def f(x, y): - return jnp.matmul(x, y) - - jaxpr = jax.make_jaxpr(f, abstracted_axes=({0: 'a_0', 1: 'a_1'}, {0: 'a_1', 1: 'a_2'}),)(jnp.ones((8, 4)), jnp.ones((4, 8))) - - def test_matmul_two_arg_size_mismatch_name_validation(self): - def f(x, y): - return jnp.matmul(x, y) - - with self.assertRaisesRegex(TypeError, - re.escape("Provided size 4 for a_1 does not match prior associated name for a_1 : 8")): - jaxpr = jax.make_jaxpr(f, abstracted_axes=({0: 'a_0', 1: 'a_1'}, {0: 'a_1', 1: 'a_2'}),)(jnp.ones((8, 4)), jnp.ones((8, 4))) - -@unittest.skip("Test does not work with jax.Array") -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class DynamicShapeAutodiffTest(jtu.JaxTestCase): - def test_jvp_broadcast(self): - @jax.jit - def fn(n, x): - return lax.broadcast_in_dim(x, (n,), ()) - - outer_jaxpr = jax.make_jaxpr( - lambda x, t: jax.jvp(lambda y: fn(3, y), (x,), (t,)) - )(3., 4.) - # { lambda ; a:f32[] b:f32[]. let - # c:f32[3] d:f32[3] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[] g:f32[]. let - # h:f32[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] f e - # i:f32[e] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] g e - # in (h, i) } - # name=f - # ] 3 a b - # in (c, d) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 1) - eqn, = outer_jaxpr.jaxpr.eqns - self.assertIn('jaxpr', eqn.params) - jaxpr = eqn.params['jaxpr'].jaxpr - self.assertLen(jaxpr.invars, 3) - e, f, g = jaxpr.invars - self.assertEqual(e.aval.shape, ()) - self.assertEqual(f.aval.shape, ()) - self.assertEqual(g.aval.shape, ()) - self.assertLen(jaxpr.outvars, 2) - h, i = jaxpr.outvars - self.assertEqual(h.aval.shape, (e,)) - self.assertEqual(i.aval.shape, (e,)) - self.assertLen(eqn.outvars, 2) - c, d = eqn.outvars - self.assertEqual(c.aval.shape, (3,)) - self.assertEqual(d.aval.shape, (3,)) - - def test_jvp_basic(self): - @jax.jit(abstracted_axes=('n',)) - def foo(x): - return jnp.sin(x) - - x = t = jnp.arange(3.) - outer_jaxpr = jax.make_jaxpr(lambda x, t: jax.jvp(foo, (x,), (t,)))(x, t) - # { lambda ; a:f32[3] b:f32[3]. let - # c:f32[3] d:f32[3] = pjit[ - # jaxpr={ lambda ; e:i32[] f:f32[e] g:f32[e]. let - # h:f32[e] = sin f - # i:f32[e] = cos f - # j:f32[e] = mul g i - # in (h, j) } - # name=f - # ] 3 a b - # in (c, d) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 1) - eqn, = outer_jaxpr.eqns - self.assertIn('jaxpr', eqn.params) - jaxpr = eqn.params['jaxpr'].jaxpr - self.assertLen(jaxpr.invars, 3) - e, f, g = jaxpr.invars - self.assertEqual(e.aval.shape, ()) - self.assertEqual(f.aval.shape, (e,)) - self.assertEqual(g.aval.shape, (e,)) - self.assertLen(jaxpr.outvars, 2) - self.assertLen(eqn.outvars, 2) - c, d = eqn.outvars - self.assertEqual(c.aval.shape, (3,)) - self.assertEqual(d.aval.shape, (3,)) - - def test_linearize_basic(self): - @jax.jit(abstracted_axes=('n',)) - def foo(x): - return jax.lax.sin(x) - - x = jnp.arange(3.) - - # primal computation - outer_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(foo, x))(x) - # { lambda ; a:f32[3]. let - # b:f32[3] c:f32[3] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] = sin e - # g:f32[d] = cos e - # in (f, g) } - # name=foo - # ] 3 a - # in (b, c) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 1) - eqn, = outer_jaxpr.jaxpr.eqns - self.assertIn('jaxpr', eqn.params) - jaxpr = eqn.params['jaxpr'].jaxpr - self.assertLen(jaxpr.invars, 2) - d, e = jaxpr.invars - self.assertEqual(d.aval.shape, ()) - self.assertEqual(e.aval.shape, (d,)) - self.assertLen(jaxpr.eqns, 2) - self.assertLen(jaxpr.outvars, 2) - f, g = jaxpr.outvars - self.assertEqual(jaxpr.eqns[0].outvars, [f]) - self.assertEqual(jaxpr.eqns[1].outvars, [g]) - self.assertLen(eqn.outvars, 2) - b, c = eqn.outvars - self.assertEqual(b.aval.shape, (3,)) - self.assertEqual(c.aval.shape, (3,)) - - # primal and tangent computation - outer_jaxpr = jax.make_jaxpr( - lambda x, xdot: jax.linearize(foo, x)[1](xdot))(x, x) - # { lambda ; a:f32[3] b:f32[3]. let - # _:f32[3] c:f32[3] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] = sin e - # g:f32[d] = cos e - # in (f, g) } - # name=foo - # ] 3 a - # h:f32[3] = pjit[ - # jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[i]. let - # l:f32[i] = mul k j - # in (l,) } - # name=foo - # ] 3 c b - # in (h,) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 2) - _, eqn = outer_jaxpr.jaxpr.eqns - self.assertIn('jaxpr', eqn.params) - jaxpr = eqn.params['jaxpr'].jaxpr - self.assertLen(jaxpr.invars, 3) - i, j, k = jaxpr.invars - self.assertEqual(i.aval.shape, ()) - self.assertEqual(j.aval.shape, (i,)) - self.assertEqual(k.aval.shape, (i,)) - self.assertLen(eqn.outvars, 1) - h, = eqn.outvars - self.assertEqual(h.aval.shape, (3,)) - - def test_linearize_basic2(self): - @jax.jit(abstracted_axes=('n',)) - def foo(x): - return jax.jit(jax.lax.sin)(x) - - x = jnp.arange(3.) - outer_jaxpr = jax.make_jaxpr(lambda x: jax.linearize(foo, x))(x) - # { lambda ; a:f32[3]. let - # b:f32[3] c:f32[3] = pjit[ - # jaxpr={ lambda ; d:i32[] e:f32[d]. let - # f:f32[d] g:f32[d] = pjit[ - # jaxpr={ lambda ; h:i32[] i:f32[h]. let - # j:f32[h] = sin i - # k:f32[h] = cos i - # in (j, k) } - # name=sin - # ] d e - # in (f, g) } - # name=foo - # ] 3 a - # in (b, c) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 1) - eqn, = outer_jaxpr.jaxpr.eqns - self.assertLen(eqn.outvars, 2) - b, c = eqn.outvars - self.assertEqual(b.aval.shape, (3,)) - self.assertEqual(c.aval.shape, (3,)) - - def test_grad_basic(self): - @jax.jit(abstracted_axes=('n',)) - def foo(x): - y = jax.lax.sin(x) - return y.sum() - - x = jnp.arange(3.) - outer_jaxpr = jax.make_jaxpr(jax.grad(foo))(x) - # { lambda ; a:f32[3]. let - # _:f32[] b:f32[3] = pjit[ - # jaxpr={ lambda ; c:i32[] d:f32[c]. let - # e:f32[c] = sin d - # f:f32[c] = cos d - # g:f32[] = reduce_sum[axes=(0,)] e - # in (g, f) } - # name=foo - # ] 3 a - # h:f32[3] = pjit[ - # jaxpr={ lambda ; i:i32[] j:f32[i] k:f32[]. let - # l:f32[i] = broadcast_in_dim[broadcast_dimensions=() shape=(None,)] k i - # m:f32[i] = mul l j - # in (m,) } - # name=foo - # ] 3 b 1.0 - # in (h,) } - self.assertLen(outer_jaxpr.jaxpr.eqns, 2) - fwd_eqn, bwd_eqn = outer_jaxpr.jaxpr.eqns - self.assertIn('jaxpr', fwd_eqn.params) - fwd_jaxpr = fwd_eqn.params['jaxpr'].jaxpr - self.assertLen(fwd_jaxpr.invars, 2) - c, d = fwd_jaxpr.invars - self.assertEqual(c.aval.shape, ()) - self.assertEqual(d.aval.shape, (c,)) - self.assertLen(fwd_jaxpr.outvars, 2) - g, f = fwd_jaxpr.outvars - self.assertEqual(g.aval.shape, ()) - self.assertEqual(f.aval.shape, (c,)) - self.assertLen(fwd_eqn.outvars, 2) - _, b = fwd_eqn.outvars - self.assertEqual(b.aval.shape, (3,)) - self.assertIn('jaxpr', bwd_eqn.params) - bwd_jaxpr = bwd_eqn.params['jaxpr'].jaxpr - self.assertLen(bwd_jaxpr.invars, 3) - i, j, k = bwd_jaxpr.invars - self.assertEqual(i.aval.shape, ()) - self.assertEqual(j.aval.shape, (i,)) - self.assertEqual(k.aval.shape, ()) - self.assertLen(bwd_jaxpr.outvars, 1) - m, = bwd_jaxpr.outvars - self.assertEqual(m.aval.shape, (i,)) - self.assertLen(bwd_eqn.outvars, 1) - h, = bwd_eqn.outvars - self.assertEqual(h.aval.shape, (3,)) - - def test_mlp_autodiff_dynamic_batch_toplevel(self): - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs - - def loss(params, batch): - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) - - batch = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10))) - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] - - # jvp - def loss_jvp(params, batch): - return jax.jvp(loss, (params, batch), (params, batch)) - jaxpr = jax.make_jaxpr(loss_jvp, abstracted_axes=({}, {0: 'n'}))(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - # linearize - def loss_lin(params, batch): - y, f_lin = jax.linearize(loss, params, batch) - y_dot = f_lin(params, batch) - return y, y_dot - jaxpr = jax.make_jaxpr(loss_lin, abstracted_axes=({}, {0: 'n'}))(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - # grad - jaxpr = jax.make_jaxpr(jax.grad(loss), abstracted_axes=({}, {0: 'n'}))(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - def test_mlp_autodiff_dynamic_batch_inner(self): - # This is like the above 'toplevel' test, but instead of introducing - # abstracted axes on the make_jaxpr call, we do it on a jit. - - @jax.jit(abstracted_axes=({}, {0: 'n'})) - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs - - def loss(params, batch): - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) - - batch = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10))) - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] - - # jvp - def loss_jvp(params, batch): - return jax.jvp(loss, (params, batch), (params, batch)) - jaxpr = jax.make_jaxpr(loss_jvp)(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - # linearize - def loss_lin(params, batch): - y, f_lin = jax.linearize(loss, params, batch) - y_dot = f_lin(params, batch) - return y, y_dot - jaxpr = jax.make_jaxpr(loss_lin)(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - # grad - jaxpr = jax.make_jaxpr(jax.grad(loss))(params, batch) - core.check_jaxpr(jaxpr.jaxpr) - - def test_bint_broadcast(self): - d = lax.convert_element_type(3, core.bint(5)) - bint = lambda x, b: lax.convert_element_type(x, core.bint(b)) - - x = lax.broadcast_in_dim(0, (d,), ()) # doesn't crash - self.assertIsInstance(x, core.DArray) - self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False) - self.assertEqual( - x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, True)) - - def f(n): - return jnp.zeros(n) - x = jax.jit(f)(d) - self.assertIsInstance(x, core.DArray) - self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False) - self.assertEqual( - x._aval, core.DShapedArray((bint(3, 5),), x._data.dtype, False)) - - jaxpr = jax.make_jaxpr(f)(d).jaxpr - # { lambda ; a:bint{≤5}[]. let - # b:f32[a] = broadcast_in_dim[...] 0.0 a - # in (b,) } - self.assertLen(jaxpr.invars, 1) - a, = jaxpr.invars - self.assertEqual(a.aval, core.DShapedArray((), core.bint(5))) - self.assertLen(jaxpr.eqns, 1) - eqn, = jaxpr.eqns - self.assertLen(eqn.outvars, 1) - b, = eqn.outvars - self.assertEqual(b.aval.shape, (a,)) - - def test_vmap_abstracted_axis(self): - def foo(x, y): - z = jax.vmap(jnp.sin)(x) * y - return jax.vmap(jnp.add)(x, z) - - x = jnp.arange(3.) - jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n',))(x, x).jaxpr - self.assertLen(jaxpr.invars, 3) - a, b, c = jaxpr.invars - self.assertEqual(a.aval.shape, ()) - self.assertEqual(b.aval.shape, (a,)) - self.assertEqual(c.aval.shape, (a,)) - self.assertLen(jaxpr.eqns, 3) - self.assertLen(jaxpr.outvars, 1) - f, = jaxpr.outvars - self.assertEqual(f.aval.shape, (a,)) - - def test_vmap_abstracted_axes_2d(self): - def foo(x, y): - z = jax.vmap(jax.vmap(jnp.sin))(x) * y - return jax.vmap(jax.vmap(jnp.add))(x, z) - - x = jnp.arange(12.).reshape(3, 4) - jaxpr = jax.make_jaxpr(foo, abstracted_axes=('n', 'm'))(x, x).jaxpr - self.assertLen(jaxpr.invars, 4) - a, b, c, d = jaxpr.invars - self.assertEqual(a.aval.shape, ()) - self.assertEqual(b.aval.shape, ()) - self.assertEqual(c.aval.shape, (a, b)) - self.assertEqual(c.aval.shape, (a, b)) - self.assertLen(jaxpr.eqns, 3) - self.assertLen(jaxpr.outvars, 1) - f, = jaxpr.outvars - self.assertEqual(f.aval.shape, (a, b)) - - def test_vmap_of_indexing_basic(self): - x = jnp.arange(3.) - - def f(idxs): - return jax.vmap(lambda i: x[i])(idxs) - - idxs = jnp.arange(3) - jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(idxs).jaxpr - # { lambda a:f32[3]; b:i32[] c:i32[b]. let - # d:bool[b] = lt c 0 - # e:i32[b] = add c 3 - # f:i32[b] = select_n d c e - # g:i32[b,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None, 1)] f b - # h:f32[b,1] = gather[ - # dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)) - # fill_value=None - # indices_are_sorted=False - # mode=GatherScatterMode.PROMISE_IN_BOUNDS - # slice_sizes=(1,) - # unique_indices=False - # ] a g - # i:f32[b] = squeeze[dimensions=(1,)] h - # in (i,) } - b, _ = jaxpr.invars - e, = (e for e in jaxpr.eqns if str(e.primitive) == 'gather') - h, = e.outvars - self.assertEqual(h.aval.shape, (b, 1)) - - def test_einsum_basic(self): - x = jnp.arange(20.).reshape(4, 5) - - def f(x): - return jnp.einsum('ij,kj->ik', x, x) - - jaxpr = jax.make_jaxpr(f, abstracted_axes=('n', 'm'))(x).jaxpr - # { lambda ; a:i32[] b:i32[] c:f32[a,b]. let - # d:f32[a,a] = pjit[ - # jaxpr={ lambda ; e:i32[] f:i32[] g:f32[e,f] h:f32[e,f]. let - # i:f32[e,e] = dot_general[ - # dimension_numbers=(((1,), (1,)), ((), ())) - # precision=None - # preferred_element_type=None - # ] g h - # in (i,) } - # name=_einsum - # ] a b c c - # in (d,) } - self.assertLen(jaxpr.invars, 3) - a, b, c = jaxpr.invars - self.assertEqual(c.aval.shape[0], a) - self.assertLen(jaxpr.eqns, 1) - self.assertLen(jaxpr.eqns[0].outvars, 1) - d, = jaxpr.eqns[0].outvars - self.assertEqual(d.aval.shape, (a, a)) - - def test_inferring_valid_subjaxpr_type_add(self): - def f(x): - return x + x.shape[0] - - jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash - - def test_slicing_basic_jaxpr(self): - def f(x): - return x[0] - - jaxpr = jax.make_jaxpr(f, abstracted_axes=(None, 'n'))(jnp.zeros((3, 4))) - # { lambda ; a:i32[] b:f32[3,a]. let - # c:f32[1,a] = dynamic_slice[slice_sizes=(1, None)] b 0 0 a - # d:f32[a] = squeeze[dimensions=(0,)] c - # in (d,) } - self.assertLen(jaxpr.jaxpr.invars, 2) - a, _ = jaxpr.jaxpr.invars - self.assertLen(jaxpr.jaxpr.outvars, 1) - d, = jaxpr.jaxpr.outvars - self.assertLen(d.aval.shape, 1) - self.assertEqual(d.aval.shape, (a,)) - - def test_shape_tuple_argument_to_zeros(self): - @jax.jit(abstracted_axes=(('n',), ('n',))) - def f(x, y): - zero = jnp.zeros(jnp.shape(x)) - return zero * y - - x = jnp.arange(3.0) - y = jnp.arange(3.0) + 1 - jax.make_jaxpr(f)(x, y) # doesn't crash - -@unittest.skip("Test does not work with jax.Array") -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow") -class DynamicShapeExecutionTest(jtu.JaxTestCase): - def test_jit_basic(self): - @jax.jit - def f(i): - return jnp.sum(jnp.ones(i, dtype='float32')) - self.assertAllClose(f(3), jnp.array(3., dtype='float32'), check_dtypes=True) - - def test_jit_basic_2(self): - count = 0 - - @jax.jit(abstracted_axes=('n',)) - def f(x): - nonlocal count - count += 1 - return jnp.sum(x) - - x = f(np.arange(3)) - y = f(np.arange(4)) - self.assertAllClose(x, 3., check_dtypes=False) - self.assertAllClose(y, 6., check_dtypes=False) - self.assertEqual(count, 1) - - def test_jit_polymorphic_output(self): - # like test_jit_basic, but without the jnp.sum! - count = 0 - - @jax.jit - def f(i): - nonlocal count - count += 1 - return jnp.ones(i, dtype='float32') - - self.assertAllClose(f(3), np.ones(3, dtype='float32'), check_dtypes=True) - self.assertAllClose(f(4), np.ones(4, dtype='float32'), check_dtypes=True) - self.assertEqual(count, 1) - - @unittest.skip('TODO: need typechecking rule for concatenate') - def test_concatenate(self): - @jax.jit(abstracted_axes=({0: 'n'},)) - def f(x): # x: f32[n, 4] - return jnp.concatenate([x, x, x], axis=0) - - f(np.ones((5, 4), dtype=np.float32)) - # TODO: add assertions - - def test_reshape(self): - @jax.jit(abstracted_axes=({0: 'n'},)) - def f(x): # x: f32[n, 4] - return jnp.reshape(x, (2, -1)) - - f(np.ones((5, 4), dtype=np.float32)) - # TODO: add assertions - - def test_nested(self): - @jax.jit - def nested_f(x): # f32[h, v] -> f32[h, v] - # A nested call that needs shape variables - return jnp.sin(x) - - @jax.jit(abstracted_axes=({0: 'h', 1: 'v'},)) - def f(x): # f32[h, w] -> f32[h, w] - return jnp.sin(x) + jax.jit(nested_f)(x) - f(np.ones((3, 5), dtype=np.float32)) - # TODO: add assertions - - def test_nested_arange(self): - def nested_f(x): # f32[h, v] -> f32[h, v] - # A nested call that needs to compute with shapes - return jnp.arange(x.shape[0] * x.shape[1], dtype=x.dtype).reshape(x.shape) - - @jax.jit(abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> f32[h, w] - return x + jax.jit(nested_f)(x) - f(np.ones((3, 5), dtype=np.float32)) - # TODO: add assertions - - def test_transpose(self): - # see also https://github.com/iree-org/iree-jax/issues/57 - @jax.jit(abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> f32[w, h] - return x.T - - f(np.ones((3, 5), dtype=np.float32)) # doesn't crash - # TODO: add assertions - - def test_matmul(self): - @jax.jit(abstracted_axes=({0: 'w', 1: 'w'},)) - def f(x): # f32[w, w] -> f32[w, w] - return jnp.matmul(x, x) - - f(np.ones((5, 5), dtype=np.float32)) - # TODO: add assertions - - def test_matmul_shape_error(self): - @jax.jit(abstracted_axes=({0: 'h', 1: 'w'},)) - def f(x): # f32[h, w] -> error - return jnp.matmul(x, x) - - # TODO(necula): improve error message, print actual shapes - with self.assertRaisesRegex(TypeError, - re.escape("dot_general requires contracting dimensions to have the same shape, got")): - f(np.ones((5, 5), dtype=np.float32)) - - @unittest.skip("TODO: investigate failure") - def test_cond(self): - @jax.jit(abstracted_axes=({0: 'w', 1: 'w'},)) - def f(x): # f32[w, w] -> f32[w, w] - return lax.cond(True, - lambda x: jnp.sin(x), - lambda x: jnp.matmul(x, x), x) - f(np.ones((5, 5), dtype=np.float32)) - # TODO: add assertions - - def test_arange(self): - @jax.jit(abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w] - return jnp.arange(x.shape[0], dtype=x.dtype) + x - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - def test_broadcast(self): - @jax.jit(abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w, w] - return jnp.broadcast_to(x, (x.shape[0], x.shape[0])) - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - def test_zeros(self): - @jax.jit(abstracted_axes=({0: 'w'},)) - def f(x): # f32[w] -> f32[w] - return jnp.zeros(x.shape[0], dtype=x.dtype) + x - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - def test_stack(self): - @jax.jit(abstracted_axes=({0: 'w'},)) - def f(x): - return jnp.stack([jnp.sin(x), jnp.cos(x)]) - - f(np.ones((5,), dtype=np.float32)) - # TODO: add assertions - - def test_jit_dependent_pair_output(self): - # Like the above 'polymorhpic output' test, but now with a `2 * n`! - count = 0 - - @jax.jit - def f(n): - nonlocal count - count += 1 - return jnp.arange(2 * n) - - x = f(3) - y = f(4) - self.assertAllClose(x, jnp.arange(2 * 3), check_dtypes=False) - self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False) - self.assertEqual(count, 1) - - @unittest.skip("revising slicing logic") - def test_slicing_basic(self): - f = jax.jit(lambda x, n: jnp.sum(x[:n])) - # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks - with jax.enable_checks(False): - ans = f(jnp.arange(10), 3) - expected = jnp.sum(jnp.arange(10)[:3]) - self.assertAllClose(ans, expected, check_dtypes=True) - - # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize - # operation 'while' that was explicitly marked illegal" - @unittest.skip("revising slicing logic") - def test_scan_basic(self): - def cumsum(x): - def body(i, _): - return i + 1, jnp.sum(x[:i+1]) - _, ans = lax.scan(body, 0, None, length=len(x)) - return ans - x = jnp.array([3, 1, 4, 1, 5, 9]) - with jax.enable_checks(False): - ans = cumsum(x) - expected = jnp.cumsum(x) - self.assertAllClose(ans, expected, check_dtypes=False) - - def test_jit_of_broadcast(self): - x = jax.jit(jnp.ones)(3) - self.assertAllClose(x, jnp.ones(3)) - - def test_jit_of_broadcast2(self): - x = jax.jit(lambda n: jnp.ones(2 * n))(3) - self.assertAllClose(x, jnp.ones(2 * 3)) - - def test_mlp_autodiff_dynamic_batch(self): - count = 0 - - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs - - def loss_ref(params, batch): - nonlocal count - count += 1 # count retraces - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) - - loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'})) - - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] - - # two different size batches - batch1 = (inputs, targets) = (jnp.ones((128, 784)), jnp.ones((128, 10))) - batch2 = (inputs, targets) = (jnp.ones((32, 784)), jnp.ones((32, 10))) - - _ = loss(params, batch1) - _ = loss(params, batch2) - self.assertEqual(count, 1) - - _ = jax.grad(loss)(params, batch1) - _ = jax.grad(loss)(params, batch2) - self.assertEqual(count, 2) - - ans = loss( params, batch1) - expected = loss_ref(params, batch1) - self.assertAllClose(ans, expected) - - ans = jax.grad(loss )(params, batch1) - expected = jax.grad(loss_ref)(params, batch1) - self.assertAllClose(ans, expected) - - @jax.enable_checks(False) # TODO(mattjj): upgrade typecompat to handle bints - def test_mlp_autodiff_dynamic_batch_bint(self): - count = 0 - - def predict(params, inputs): - for W, b in params: - outputs = jnp.dot(inputs, W) + b - inputs = jnp.maximum(0, outputs) - return outputs - - def loss_ref(params, batch): - nonlocal count - count += 1 # count traces - inputs, targets = batch - predictions = predict(params, inputs) - return jnp.sum((predictions - targets) ** 2) - - loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'})) - - params = [(jnp.ones((784, 256)), jnp.ones(256)), - (jnp.ones((256, 10)), jnp.ones( 10))] - - # two different batch sizes *with bints* - bs1 = jax.lax.convert_element_type(128, core.bint(128)) - batch1 = (jnp.ones((bs1, 784)), jnp.ones((bs1, 10))) - - bs2 = jax.lax.convert_element_type(32, core.bint(128)) - batch2 = (jnp.ones((bs2, 784)), jnp.ones((bs2, 10))) - - # count retraces (and don't crash) - self.assertEqual(count, 0) - _ = jax.grad(loss)(params, batch1) - self.assertEqual(count, 1) - g2 = jax.grad(loss)(params, batch2) - self.assertEqual(count, 1) # cache hit! - - # check the numbers make sense - batch = (jnp.ones((32, 784)), jnp.ones((32, 10))) - g2_expected = jax.grad(loss_ref)(params, batch) - self.assertAllClose(g2, g2_expected, check_dtypes=False, - atol=1e-3, rtol=1e-3) - - def test_bint_basic(self): - d = lax.convert_element_type(3, core.bint(5)) - self.assertEqual(str(d), '3{≤5}') - - @jax.jit - def f(d): - jnp.sin(3.) # don't have an empty jaxpr - return d - f(d) # doesn't crash - - def test_bint_iota(self): - def f(d): - return jnp.arange(d, dtype='int32') - - y = f(lax.convert_element_type(3, core.bint(5))) - self.assertIsInstance(y, core.DArray) - self.assertAllClose(y._data, np.arange(5), check_dtypes=False) - - d = lax.convert_element_type(3, core.bint(5)) - y = jax.jit(f)(d) - self.assertIsInstance(y, core.DArray) - self.assertAllClose(y._data, np.arange(5), check_dtypes=False) - - def test_bint_compilation_cache(self): - count = 0 - - @jax.jit - def f(n): - nonlocal count - count += 1 - return jnp.zeros(n) - f(lax.convert_element_type(3, core.bint(5))) - f(lax.convert_element_type(4, core.bint(5))) - self.assertEqual(count, 1) - - def test_bint_compilation_cache2(self): - count = 0 - - @jax.jit(abstracted_axes=('n',)) - def f(x): - nonlocal count - count += 1 - return x.sum() - - d = lax.convert_element_type(3, core.bint(5)) - x = jnp.arange(d) - y = f(x) - self.assertEqual(y, 3) - self.assertEqual(count, 1) - - d = lax.convert_element_type(4, core.bint(5)) - x = jnp.arange(d) - y = f(x) - self.assertEqual(y, 6) - self.assertEqual(count, 1) - - d = lax.convert_element_type(4, core.bint(6)) - x = jnp.arange(d) - y = f(x) - self.assertEqual(y, 6) - self.assertEqual(count, 2) - - @unittest.skip('do we want to support this?') - def test_bint_add(self): - d = lax.convert_element_type(4, core.bint(6)) - x = jnp.arange(d) - - @jax.jit - def f(x): - return x + x - - f(x) # doesn't crash - - def test_lower_abstracted_axes(self): - @jax.jit(abstracted_axes=('n',)) - def f(x): - return x.sum() - - f_lowered = f.lower(np.arange(3, dtype='int32')) - mlir_str = f_lowered.compiler_ir() - self.assertIn('tensor', str(mlir_str)) - - def test_lower_abstracted_axes_shapedtypestruct(self): - @jax.jit(abstracted_axes=('n',)) - def f(x): - return x.sum() - - f_lowered = f.lower(jax.ShapeDtypeStruct((3,), np.int32)) - mlir_str = f_lowered.compiler_ir() - self.assertIn('tensor', str(mlir_str)) - - def test_slicing_basic_lower(self): - @jax.jit(abstracted_axes=(None, 'n')) - def f(x): - return x[0] - f.lower(jnp.zeros((3, 4))).compiler_ir() # doesn't crash - - def test_slicing_basic_execute(self): - @jax.jit(abstracted_axes=(None, 'n')) - def f(x): - return x[0] - - y = f(jnp.arange(3 * 4).reshape(3, 4)) - self.assertAllClose(y, jnp.array([0, 1, 2, 3])) - - def test_gather_basic_bounded(self): - x = jnp.arange(3. * 4.).reshape(3, 4) - - def f(i): - return x[i] - - sz = jax.lax.convert_element_type(2, core.bint(3)) - idx = jnp.arange(sz) - y = jax.jit(jax.vmap(f), abstracted_axes=('n',))(idx) - - self.assertIsInstance(y, core.DArray) - self.assertEqual(y.shape, (sz, 4)) - self.assertAllClose(y._data, x) - -@unittest.skip("currently unmaintained") -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow", - jax_traceback_filtering='off') -class JumbleTest(jtu.JaxTestCase): - - def setUp(self): - super().setUp() - if jax.config.x64_enabled: raise unittest.SkipTest() - - @parameterized.parameters((True,), (False,)) - def test_internal_jumble(self, disable_jit): - with jax.disable_jit(disable_jit): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - xs = jax.vmap(lambda n: jax.lax.iota('int32', n).sum())(ins) - self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False) - - def test_jumble_escapes(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - xs = jax.vmap(jax.jit(lambda n: jax.lax.iota('int32', n)), - out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(xs, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) - self.assertAllClose(xs.data, data, check_dtypes=False) - - def test_make_jumble_from_dynamic_shape(self): - # We may not want to support returning jumbles from vmapped functions - # (instead preferring to have a separate API which allows jumbles). But for - # now it makes for a convenient way to construct jumbles for the other - # tests! - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - p = jax.vmap(partial(jnp.arange, dtype='int32'), - out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]') - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) - self.assertAllClose(p.data, data, check_dtypes=False) - - def test_jumble_map_eltwise(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - p = jax.vmap(partial(jnp.arange, dtype='int32'), - out_axes=batching.jumble_axis)(ins) - p = jumble_map(jax.jit(lambda x: x * 3))(p) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]') - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) * 3 - self.assertAllClose(p.data, data, check_dtypes=False) - - def test_jumble_map_vector_dot(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - p = jax.vmap(partial(jnp.arange, dtype='int32'), - out_axes=batching.jumble_axis)(ins) - y = jumble_map(jnp.dot)(p, p) - self.assertIsInstance(y, batching.Jumble) - self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32')) - - @parameterized.parameters((True,), (False,)) - def test_jumble_map_matrix_dot_ragged_contract(self, disable_jit): - with jax.disable_jit(disable_jit): - sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.jumble_axis - )(sizes) - p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.jumble_axis - )(sizes) - y = jax.vmap(jnp.dot, in_axes=batching.jumble_axis, out_axes=0, - axis_size=3)(p1, p2) - self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)), - check_dtypes=False) - - @parameterized.parameters((True,), (False,)) - def test_jumble_map_matrix_dot_ragged_tensor(self, disable_jit): - with jax.disable_jit(disable_jit): - sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - lhs_one_d = jnp.arange(size, dtype='int32') + 1 - lhs_two_d = jax.lax.broadcast_in_dim(lhs_one_d, (size, 2), (0,)) - rhs = jax.lax.broadcasted_iota('int32', (2, 4), 0) + 1 - return jnp.dot(lhs_two_d, rhs) - p = jax.vmap(func, out_axes=batching.jumble_axis)(sizes) - self.assertIsInstance(p, batching.Jumble) - self.assertEqual(p.data.shape, (3, 5, 4)) - - def test_broadcast_in_dim_while_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jax.lax.broadcast_in_dim(one_d, (size, 7), (0,)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1) - self.assertAllClose(p.data, data) - - def test_broadcast_in_dim_to_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(12, dtype='int32') - two_d = jax.lax.broadcast_in_dim(one_d, (size, 12), (1,)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 12), 2) - self.assertAllClose(p.data, data) - - def test_broadcast_in_dim_ragged_to_static_error(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - # Broadcast should error even if the target shape is the same as the - # underlying data shape, because the semantic size doesn't match. - two_d = jax.lax.broadcast_in_dim(one_d, (4, 5), (1,)) - return two_d - msg = r"got operand of shape \(\[dynamic\],\), target broadcast shape \(4, 5\)" - with self.assertRaisesRegex(TypeError, msg): - jax.vmap(func, out_axes=batching.jumble_axis)(ins) - - def test_broadcast_in_dim_to_doubly_ragged(self): - ins1 = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - ins2 = lax.convert_element_type(jnp.array([2, 5, 1]), core.bint(6)) - def func(size1, size2): - one_d = jnp.arange(size1, dtype='int32') - two_d = jax.lax.broadcast_in_dim(one_d, (size1, size2), (0,)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins1, ins2) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 6), 1) - self.assertAllClose(p.data, data) - - def test_squeeze_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jax.lax.broadcast_in_dim(one_d, (size, 1), (0,)) - one_again = jax.lax.squeeze(two_d, dimensions=[1]) - return one_again - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) - self.assertAllClose(p.data, data) - - def test_broadcast_to_while_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jnp.broadcast_to(one_d, (4, size)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 4, 5), 2) - self.assertAllClose(p.data, data) - - def test_broadcast_to_doubly_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jnp.broadcast_to(one_d, (size, size)) - return two_d - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 5), 2) - self.assertAllClose(p.data, data) - - def test_transpose_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jnp.broadcast_to(one_d, (7, size)) - return jnp.transpose(two_d, [1, 0]) - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1) - self.assertAllClose(p.data, data) - - def test_einsum_with_ragged_tensor_dimension(self): - x_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def fprop_layer(x_size): - one_d = jnp.arange(x_size, dtype='int32') - x = jax.lax.broadcast_in_dim(one_d, (x_size, 11), [0]) - wqkv = jax.lax.broadcasted_iota('int32', (3, 2, 7, 11), 1) - qkv = jnp.einsum('te,ihqe->ithq', x, wqkv) - return qkv - p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(x_sizes) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]') - self.assertEqual(p.data.shape, (3, 3, 5, 2, 7)) - - @parameterized.parameters((True,), (False,)) - def test_einsum_with_ragged_tensor_and_contract_dimensions(self, disable_jit): - with jax.disable_jit(disable_jit): - ragged_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def fprop_layer(ragged_size): - one_d = jnp.arange(ragged_size, dtype='int32') - alpha = jax.lax.broadcast_in_dim(one_d, (ragged_size, ragged_size, 2), [1]) - v = jax.lax.broadcast_in_dim(one_d, (ragged_size, 2, 7), [0]) - inner = jnp.einsum('tsh,shq->thq', alpha, v) - return inner - p = jax.vmap(fprop_layer, out_axes=batching.jumble_axis)(ragged_sizes) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]') - self.assertEqual(p.data.shape, (3, 5, 2, 7)) - - def test_split_while_ragged(self): - ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5)) - def func(size): - one_d = jnp.arange(size, dtype='int32') - two_d = jnp.broadcast_to(one_d, (2, size)) - part_1, part_2 = two_d - return part_1 - p = jax.vmap(func, out_axes=batching.jumble_axis)(ins) - self.assertIsInstance(p, batching.Jumble) - self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]') - data = jax.lax.broadcasted_iota('int32', (3, 5), 1) - self.assertAllClose(p.data, data) - - @parameterized.parameters((True,), (False,)) - @unittest.skip("test fails at head") - def test_jumble_map_end_to_end_fprop_layer(self, disable_jit): - - def fprop_layer(params, x): - ((xnorm_scale, xnorm_bias), (wqkv, wqkv_bias), (wo, wo_bias), - (ynorm_scale, ynorm_bias), (w_i, w_i_bias), (w_o, w_o_bias)) = params - xnorm = jax.nn.standardize(x) * xnorm_scale + xnorm_bias - qkv = jnp.einsum('te,ihqe->ithq', xnorm, wqkv) + wqkv_bias[:, None] - q, k, v = qkv - outer = jnp.einsum('thq,shq->tsh', q, k) / jnp.asarray( - jnp.sqrt(v.shape[-1]), dtype=x.dtype) - - alpha = jax.nn.softmax(outer, 2) - inner = jnp.einsum('tsh,shq->thq', alpha, v) - y = jnp.einsum('thq,hqe->te', inner, wo) + wo_bias + x - ynorm = jax.nn.standardize(y) * ynorm_scale + ynorm_bias - act = jax.nn.gelu(jnp.einsum('te,ef->tf', ynorm, w_i) + w_i_bias) - z = jnp.einsum('tf,fe->te', act, w_o) + w_o_bias + y - return z - - params = [ - (jnp.ones(128), jnp.zeros(128)), # xnorm_scale, xnorm_bias - (jnp.ones((3, 16, 64, 128)), jnp.zeros((3, 16, 64))), # wqkv, wqkv_bias - (jnp.ones((16, 64, 128)), jnp.zeros(128)), # wo, wo_bias - (jnp.ones(128), jnp.zeros(128)), # ynorm_scale, ynorm_bias - (jnp.ones((128, 4096)), jnp.zeros(4096)), # w_i, w_i_bias - (jnp.ones((4096, 128)), jnp.zeros(128)), # w_o, w_o_bias - ] - - xs = [ - jnp.zeros((512, 128)), - jnp.zeros((386, 128)), - jnp.zeros((420, 128)), - ] - - def jumble_stack(xs: list[jax.Array]) -> batching.Jumble: - max_length = max(len(x) for x in xs) - lengths = jnp.array([len(x) for x in xs]) - lengths = jax.lax.convert_element_type(lengths, core.bint(max_length)) - xs_padded = jnp.stack([jnp.zeros((max_length, 128), dtype=x.dtype - ).at[:x.shape[0]].set(x) for x in xs]) - - # binder = i - binder = core.Var('', core.ShapedArray((), np.dtype('int32'))) - # elt_ty = f32[[3, 1, 4].i, 128] - elt_ty = core.DShapedArray((batching.IndexedAxisSize(binder, lengths), 128), - xs_padded.dtype) - # aval = i:(Fin 3) => f32[[3, 1, 4].i, 128] - aval = batching.JumbleTy(binder, len(xs), elt_ty) - xs_jumble = batching.Jumble(aval, xs_padded) - return xs_jumble - - with jax.disable_jit(disable_jit): - xs_jumble = jumble_stack(xs) - - fprop_batched = jax.vmap(fprop_layer, - in_axes=(None, batching.jumble_axis), - out_axes=batching.jumble_axis, - axis_size=3) - result_jumble = fprop_batched(params, xs_jumble) - self.assertIsInstance(result_jumble, batching.Jumble) - regex = r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]' - self.assertRegex(str(result_jumble.aval), regex) - self.assertAllClose(result_jumble.data.shape, (3, 512, 128)) - -def jumble_map(f): - def mapped(*jumbles): - return jax.vmap(f, in_axes=batching.jumble_axis, out_axes=batching.jumble_axis, - axis_size=jumbles[0].aval.length)(*jumbles) - return mapped - -if __name__ == '__main__': - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 7f1182236615..e160f38133ca 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -83,30 +83,6 @@ jax_py_test( ]), ) -jax_multiplatform_test( - name = "pallas_jumble_test", - srcs = [ - "pallas_jumble_test.py", - ], - disable_configs = [ - "gpu_v100", - "gpu_v100_x32", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_h100", - "gpu_b200", - ], - deps = [ - "//jax/experimental:pallas", - "//jax/experimental:pallas_tpu", - "//jax/experimental:pallas_tpu_ops", - ] + py_deps([ - "absl/testing", - "numpy", - ]), -) - jax_multiplatform_test( name = "ops_test", srcs = [ diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py deleted file mode 100644 index f570554083f0..000000000000 --- a/tests/pallas/pallas_jumble_test.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys -import unittest - -os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" - -from absl.testing import absltest -import jax -from jax import lax -from jax._src import config -from jax._src import core -from jax._src import dtypes -from jax._src import test_util as jtu -from jax._src.interpreters import batching -from jax.experimental import pallas as pl -import jax.numpy as jnp -import numpy as np - - -# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs. -# pylint: disable=no-value-for-parameter - -config.parse_flags_with_absl() - - -intx = dtypes.default_int_dtype() -floatx = dtypes.default_float_dtype() - - -def _assert_ragged_equal_with_elementwise_mask( - row_count, col_grid_size, ragged_shape, res, ref -): - total_columns = col_grid_size * 128 - mask = jnp.zeros((len(ragged_shape), row_count, total_columns), dtype=bool) - - for i, r in enumerate(ragged_shape): - mask = mask.at[i, :, : r * 128].set(True) - - res_valid = jnp.where(mask, res, -1) - ref_valid = jnp.where(mask, ref, -1) - - np.testing.assert_allclose(res_valid, ref_valid) - - -@unittest.skip("broken by https://github.com/jax-ml/jax/pull/29937") # TODO(mattjj): revive -@jtu.with_config(jax_traceback_filtering="off") -class PallasBaseTest(jtu.JaxTestCase): - INTERPRET = False - - def setUp(self): - if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: - self.skipTest("On CPU the test works only in interpret mode") - if jtu.test_device_matches( - ["cuda"] - ) and not jtu.is_cuda_compute_capability_at_least("8.0"): - self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: - self.skipTest("Only works on non-Windows platforms") - - super().setUp() - - def pallas_call(self, *args, **kwargs): - return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) - - -@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_dtype_promotion="standard") -class PallasCallRaggedVmapTest(PallasBaseTest): - - def test_vmap_jumble_over_sin_kernel(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct( - (8, col_grid_size * 128), dtype=jnp.float32 - ), - grid=(1, col_grid_size), - interpret=self.INTERPRET, - # See note - on zero filling counterfactuals - debug=True, - )(x) - - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - ref = jax.vmap( - jnp.sin, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - _assert_ragged_equal_with_elementwise_mask( - row_count, col_grid_size, ragged_shape, res.data, ref.data - ) - - def test_vmap_jumble_over_add_kernel(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - y = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, y_ref, o_ref): - o_ref[...] = x_ref[...] + y_ref[...] - - def invoke_kernel(x, y): - return pl.pallas_call( - kernel, - in_specs=[ - pl.BlockSpec((8, 128), lambda j, k: (j, k)), - pl.BlockSpec((8, 128), lambda j, k: (j, k)), - ], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct( - (8, col_grid_size * 128), dtype=jnp.float32 - ), - grid=(1, col_grid_size), - interpret=self.INTERPRET, - )(x, y) - - # We've had this test fail with data corruption due to multiple - # invocations, so we run it k times to make sure it's not setting up - # memory incorrectly for subsequent invocations. - for _ in range(4): - res = jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) - - res = res.data - total = len(ragged_shape) * row_count * col_grid_size * 128 - res_total = np.prod(res.shape) - self.assertEqual(res_total, total) - - ref = jax.vmap( - lambda x, y: x + y, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) - _assert_ragged_equal_with_elementwise_mask( - row_count, col_grid_size, ragged_shape, res, ref.data - ) - - def test_vmap_jumble_over_sin_kernel_grid_remapping(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) * pl.program_id(2) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), - grid=(1, 5), - interpret=self.INTERPRET, - )(x) - - with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): - jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - def test_vmap_jumble_over_matmul_kernel(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - if jtu.is_device_tpu(version=4): - self.skipTest("Flaky 15% of the time on tpuv4?") - - m = 128 - k = 640 - n = 640 - - def matmul_kernel(x_ref, y_ref, x_sentinel, z_ref): - # weird little once-only reset - @pl.when(x_sentinel[...][0][0] == 1.0) - def _(): - z_ref[...] = jnp.zeros_like(z_ref) - x_sentinel[...] = jnp.zeros_like(x_sentinel) - - z_ref[...] += x_ref[...] @ y_ref[...] - - def matmul( - x: jax.Array, - y: jax.Array, - x_sentinel: jax.Array, - *, - bm: int = 128, - bk: int = 128, - bn: int = 640, - ): - # m, k = x.shape - # _, n = y.shape - # a (1, 5) grid - # TODO(mvoz): parameterize this grid? - grid = (n // bn, k // bk) - return pl.pallas_call( - matmul_kernel, - out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), - in_specs=[ - pl.BlockSpec( - (bm, bk), - lambda j, k: (0, k), - ), - pl.BlockSpec( - (bk, bn), - lambda j, k: (k, j), - ), - pl.BlockSpec( - (bm, bn), - lambda j, k: (0, j), - ), - ], - out_specs=pl.BlockSpec( - (bm, bn), - lambda j, k: (0, j), - ), - grid=grid, - input_output_aliases={2: 0}, - interpret=self.INTERPRET, - )(x, y, x_sentinel) - - # TODO(mvoz): parameterize this shape? - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([128 * x for x in ragged_shape]), - core.bint(k), - ) - x = jax.vmap(lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis)( - sizes - ) - x_sentinel = jax.vmap( - lambda k_: jnp.ones((m, k_)), out_axes=batching.jumble_axis - )(sizes) - y = jax.vmap(lambda k_: jnp.ones((k_, n)), out_axes=batching.jumble_axis)( - sizes - ) - - res = jax.vmap( - matmul, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y, x_sentinel) - - ref = jax.vmap( - jnp.dot, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x, y) - - ref = ref.data - res = res.data - np.testing.assert_allclose(ref, res) - - def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Only tested on TPU") - - self.skipTest("Checkify NYI") - - row_count = 8 - col_grid_size = 5 - ragged_shape = [3, 1, 4] - sizes = lax.convert_element_type( - jnp.array([(128 * x) - 1 for x in ragged_shape]), - core.bint(col_grid_size * 128), - ) - x = jax.vmap( - lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis - )(sizes) - - def kernel(x_ref, o_ref): - o_ref[...] = jnp.sin(x_ref[...]) - - def invoke_kernel(x): - return pl.pallas_call( - kernel, - in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], - out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), - out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), - grid=(1, 5), - interpret=False, - )(x) - - with self.assertRaisesRegex( - ValueError, - "Ragged input shape must be evenly divisible by the grid" # noqa: W605 - " size at the ragged dimension 2", - ): - jax.vmap( - invoke_kernel, - out_axes=batching.jumble_axis, - in_axes=batching.jumble_axis, - axis_size=3, - )(x) - - -class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest): - INTERPRET = True - - -if __name__ == "__main__": - absltest.main() From c006c304ba639fe29aab1b898d3350fb4f67aa74 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 9 Dec 2025 12:01:07 -0800 Subject: [PATCH 126/315] jnp.arange: deprecate passing complex arguments --- CHANGELOG.md | 2 ++ jax/_src/deprecations.py | 3 ++- jax/_src/numpy/lax_numpy.py | 8 ++++++++ tests/lax_numpy_test.py | 20 +++++++++++++------- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 882ea245ed5a..a4896846a144 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. Please use `jax.lax.pcast(..., to='varying')` as the replacement. * `with mesh:` context manager has been deprecated. Please use `with jax.set_mesh(mesh):` instead. + * Complex arguments passed to {func}`jax.numpy.arange` now result in a + deprecation warning, because the output is poorly-defined. * Changes: * jax's `Tracer` no longer inherits from `jax.Array` at runtime. However, diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 3ae3e19cb8c1..7cdbcb64791f 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -128,7 +128,8 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-lax-dot-positional-args') register('jax-lib-module') register('jax-nn-one-hot-float-input') -register("jax-numpy-astype-complex-to-real") +register('jax-numpy-arange-complex') +register('jax-numpy-astype-complex-to-real') register('jax-numpy-clip-args') register('jax-scipy-special-sph-harm') register('safer-randint-config') diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 79278915a09b..4bdc2e31d37b 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5983,6 +5983,14 @@ def _arange(start: ArrayLike | DimSize, stop: ArrayLike | DimSize | None = None, dtype = dtypes.jax_dtype(dtype) if iscomplexobj(start) or iscomplexobj(stop) or iscomplexobj(step): + deprecations.warn( + "jax-numpy-arange-complex", + ( + "Passing complex start/stop/step to jnp.arange is deprecated;" + " in the future this will result in a ValueError." + ), + stacklevel=3 + ) # Complex arange is poorly defined; fall back to NumPy here. # TODO(jakevdp): deprecate the complex case. return array(np.arange(start, stop, step, dtype=dtype), device=out_sharding) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6156f8996994..f3c828927314 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -46,6 +46,7 @@ from jax._src import array from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal @@ -4892,18 +4893,23 @@ def testArangeRandomValues(self, dtype, iteration): np_result = np.arange(start, stop, dtype=dtype) self.assertAllClose(jax_result, np_result) - def testArangeComplex(self): - test_cases = [ + @parameterized.parameters( (1+2j, 5+3j), (0+0j, 5+0j), (1.0+0j, 5.0+0j), (0, 5, 1+1j), - ] - for args in test_cases: - with self.subTest(args=args): + ) + def testArangeComplex(self, *args): + dep_id = "jax-numpy-arange-complex" + msg = "Passing complex start/stop/step to jnp.arange is deprecated" + if deprecations.is_accelerated(dep_id): + with self.assertRaisesRegex(ValueError, msg): + jax_result = jnp.arange(*args) + else: + with self.assertWarnsRegex(DeprecationWarning, msg): jax_result = jnp.arange(*args) - np_result = np.arange(*args) - self.assertArraysEqual(jax_result, np_result) + np_result = np.arange(*args) + self.assertArraysEqual(jax_result, np_result) def testIssue830(self): a = jnp.arange(4, dtype=jnp.complex64) From 548eaa5b53afeba91518d4d9274f7198b55cc308 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 9 Dec 2025 15:53:13 -0800 Subject: [PATCH 127/315] [Pallas TPU] Allow closed over scalars in core_map code This allows doing things like dynamic indexing of Refs using just regular scalars from outside the kernel. PiperOrigin-RevId: 842429684 --- jax/_src/pallas/core.py | 26 ++++++++++++---- jax/_src/pallas/mosaic/core.py | 45 +++++++++++++++++++++++++++ jax/_src/pallas/mosaic/sc_core.py | 5 +++ jax/_src/pallas/mosaic_gpu/core.py | 5 +++ tests/pallas/tpu_pallas_state_test.py | 31 +++++++++++++++++- 5 files changed, 105 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index d6747b1d8ed9..32ceb1c21dc2 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1535,8 +1535,13 @@ def default_mesh_discharge_rule( scratch_shapes, ): """Discharges a ``core_map`` over a mesh to a ``pallas_call``.""" - del out_avals # Unused. default_memory_space = memory_space + if not all( + isinstance(aval, state.AbstractRef) for aval in (in_avals + out_avals) + ): + raise ValueError( + "default_mesh_discharge_rule only supports Ref inputs/outputs." + ) def body(*args): # Due to aliasing, ``args`` contains aliased inputs and outputs so we @@ -1605,15 +1610,24 @@ def _core_map_discharge_rule(in_avals, out_avals, *args_flat, jaxpr, debug_info, for var in jaxpr.constvars if not isinstance(aval := var.aval, state.AbstractRef) ] - if consts_avals: + is_scalar_const_aval = [ + isinstance(aval, jax_core.ShapedArray) and not aval.shape + for aval in consts_avals + ] + if not all(is_scalar_const_aval): ctx = jax_core.JaxprPpContext() - pp_const_avals = ", ".join( - jax_core.pp_aval(aval, ctx) for aval in consts_avals + non_scalar_const_avals = [ + aval + for aval, is_scalar in zip(consts_avals, is_scalar_const_aval) + if not is_scalar + ] + non_scalar_const_pp_avals = ", ".join( + jax_core.pp_aval(aval, ctx) for aval in non_scalar_const_avals ) raise ValueError( "The kernel function in core_map" - f" {debug_info.func_src_info} captures constants" - f" [{pp_const_avals}]. You should pass them as inputs." + f" {debug_info.func_src_info} captures non-scalar constants" + f" [{non_scalar_const_pp_avals}]. You should pass them as inputs." ) return _core_map_mesh_rules[type(mesh)]( in_avals, out_avals, *args_flat, jaxpr=jaxpr, mesh=mesh, **kwargs diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 58643515b239..4547af632d4f 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -26,9 +26,11 @@ import jax.numpy as jnp from jax.extend import backend as jex_backend from jax._src import core as jax_core +from jax._src import linear_util as lu from jax._src import state from jax._src import util from jax._src.frozen_dict import FrozenDict +from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core import numpy as np @@ -336,6 +338,49 @@ def _tensorcore_mesh_discharge_rule( "TensorCoreMesh does not support VMEM inputs/outputs when there are" " >1 cores. Use HBM or ANY instead." ) + def allowed_aval(aval): + if isinstance(aval, state.AbstractRef): + return True + if isinstance(aval, jax_core.ShapedArray): + # Only scalars are allowed. + return not aval.shape + return False + assert all(allowed_aval(v.aval) for v in jaxpr.constvars + jaxpr.invars) + + is_scalar_const = [ + isinstance(v.aval, jax_core.ShapedArray) and not v.aval.shape + for v in jaxpr.constvars + ] + if any(is_scalar_const): + # Rewrite body jaxpr to take in scalar values as Refs. + def new_body(*args): + args = [ + a[0] if is_scalar else a + for a, is_scalar in zip(args, is_scalar_const) + ] + return jax_core.eval_jaxpr(jaxpr, args) + # TODO(sharadmv): Remove this once Mosaic support passing scalars as values. + new_trace_avals = [ + state.AbstractRef( # pylint: disable=g-long-ternary + jax_core.ShapedArray((1,), v.aval.dtype), + memory_space=MemorySpace.SMEM, + ) + if is_scalar + else v.aval + for v, is_scalar in zip(jaxpr.constvars, is_scalar_const) + ] + new_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init( + new_body, debug_info=jaxpr.debug_info.with_unknown_names() + ), + new_trace_avals, + ) + jaxpr = new_jaxpr.replace(invars=[], constvars=new_jaxpr.invars) + args = tuple( + a[None] if is_scalar else a + for a, is_scalar in zip(args, is_scalar_const) + ) + in_avals, out_avals = util.split_list(new_trace_avals, [len(in_avals)]) return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, diff --git a/jax/_src/pallas/mosaic/sc_core.py b/jax/_src/pallas/mosaic/sc_core.py index 2eaab2546e8a..2d5670e8461a 100644 --- a/jax/_src/pallas/mosaic/sc_core.py +++ b/jax/_src/pallas/mosaic/sc_core.py @@ -219,6 +219,11 @@ def _scalar_subcore_mesh_discharge_rule( compiler_params = tpu_core.CompilerParams() if compiler_params.dimension_semantics is not None: raise ValueError("ScalarSubcoreMesh does not support dimension_semantics=") + sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)] + if sa_avals: + raise NotImplementedError( + f"Cannot close over values in core_map: {sa_avals}" + ) return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 097f6e47b170..532c86d98af7 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -1368,6 +1368,11 @@ def _gpu_mesh_discharge_rule( ) if not compiler_params: compiler_params = CompilerParams() + sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)] + if sa_avals: + raise NotImplementedError( + f"Cannot close over values in core_map: {sa_avals}" + ) return pallas_core.default_mesh_discharge_rule( in_avals, out_avals, diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index f127bb1d49d8..f7821d922249 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -324,9 +324,38 @@ def kernel(x_ref, out_ref, tmp_ref): return kernel(x) x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) - with self.assertRaisesRegex(Exception, "core_map .* captures constants"): + with self.assertRaisesRegex( + Exception, "core_map .* captures non-scalar constants" + ): f(x) + def test_capture_scalar(self): + @jax.jit + def f(x, i): + @pl.kernel(out_shape=jax.ShapeDtypeStruct(x.shape[1:], jnp.int32), + mesh=pltpu.create_tensorcore_mesh("x", num_cores=1)) + def kernel(x_ref, out_ref): + pltpu.sync_copy(x_ref.at[i], out_ref) + return kernel(x) + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((4, 8, 128)) + for i in range(x.shape[0]): + out = f(x, i) + np.testing.assert_array_equal(out, x[i]) + + @jax.jit + def g(x, i): + @pl.kernel(out_shape=jax.ShapeDtypeStruct((2, *x.shape[1:]), jnp.int32), + mesh=pltpu.create_tensorcore_mesh("x", num_cores=1)) + def kernel(x_ref, out_ref): + pltpu.sync_copy(x_ref.at[pl.ds(i, 2)], out_ref) + return kernel(x) + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((4, 8, 128)) + for i in range(3): + out = g(x, i) + np.testing.assert_array_equal(out, x[i:i+2]) + def test_kernel_helper_with_scratch(self): mesh = pltpu.create_tensorcore_mesh("x") def body(x_ref, o_ref, scratch_ref): From 4257c62b3cb35886fa1af08e1a56da08f2b46f6f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 9 Dec 2025 17:49:14 -0800 Subject: [PATCH 128/315] Remove one sized mesh axis from spmd_axis_name during comparison with explicit axes if remove_size_one_mesh_axis_from_type is turned on. PiperOrigin-RevId: 842467979 --- jax/_src/api.py | 5 ++++- tests/pjit_test.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 1ad727b5cf2a..08a9dd8387d2 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -68,7 +68,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib from jax._src.sharding import Sharding -from jax._src.mesh import get_concrete_mesh +from jax._src.mesh import get_concrete_mesh, get_abstract_mesh from jax._src.sharding_impls import (PmapSharding, PartitionSpec as P, NamedSharding) from jax._src.layout import Format @@ -1191,6 +1191,9 @@ def vmap_f(*args, **kwargs): _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap")) explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat) if spmd_axis_name is not None and explicit_mesh_axis is not None: + spmd_axis_name = ( + tuple(core.remove_size_one_mesh_axis(P(spmd_axis_name), get_abstract_mesh())) + if config.remove_size_one_mesh_axis_from_type.value else spmd_axis_name) if spmd_axis_name == explicit_mesh_axis: spmd_axis_name = None else: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 2b8ff9198420..70e2e1d3503b 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7268,6 +7268,25 @@ def f(x): "Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"): f(arr) + @config.remove_size_one_mesh_axis_from_type(True) + @jtu.with_explicit_mesh((2, 1), ('x', 'y')) + def test_spmd_axis_name_explicit_mode_assert_remove_one_size(self, mesh): + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y'), None))) + + @jax.jit + @partial(jax.vmap, spmd_axis_name=('x', 'y')) + def f(x): + # breakpoint() + self.assertEqual(x.aval.sharding.spec, P(None)) + out = x * 2 + self.assertEqual(out.aval.sharding.spec, P(None)) + return out + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertArraysEqual(out, np_inp * 2) + @jtu.with_explicit_mesh((2,), ('x',)) def test_unmapped_last_vmap(self, mesh): np_inp = np.arange(8) From 8186c19d305d2553bea21b70aa0197af4172319b Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 10 Dec 2025 01:18:54 +0000 Subject: [PATCH 129/315] [mutable-arrays] allow internal ref effects in mlir.lower_fun Co-authored-by: Sharad Vikram --- jax/_src/interpreters/mlir.py | 11 +++++++++-- tests/mutable_array_test.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2be7f5088609..4be9077d4e34 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2451,8 +2451,15 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): wrapped_fun = lu.wrap_init(f, params, debug_info=api_util.debug_info("lower_fun", fun, args, {})) - jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic(wrapped_fun, - ctx.avals_in) + jaxpr, _, consts_for_constvars = pe.trace_to_jaxpr_dynamic( + wrapped_fun, ctx.avals_in) + + if any(isinstance(e, core.InternalMutableArrayEffect) for e in jaxpr.effects): + from jax._src.interpreters import pxla # type: ignore + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts_for_constvars) + closed_jaxpr = pxla._discharge_internal_refs(closed_jaxpr) + jaxpr, consts_for_constvars = closed_jaxpr.jaxpr, closed_jaxpr.consts + # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out? if ctx.platforms is not None: diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 2d03c1866bbf..ed23bbe8604e 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -27,6 +27,7 @@ from jax._src import test_util as jtu from jax._src.api import vjp3 from jax._src.util import safe_map, safe_zip +from jax._src.interpreters import mlir from jax.sharding import NamedSharding, PartitionSpec as P, AxisType import jax.numpy as jnp @@ -1031,6 +1032,28 @@ def test_none_index(self): y = ref[None] self.assertEqual(y.shape, (1, 3)) + def test_what_if_you_lower_fun_something_with_internal_effects(self): + bjp_p = core.Primitive('bjp') + + @bjp_p.def_abstract_eval + def _(aval): + return aval + + def lowering(x): + x_ref = jax.new_ref(x) + x_ref[...] += 1 + x_ref[...] += -1 + return jax.freeze(x_ref) + + mlir.register_lowering(bjp_p, mlir.lower_fun(lowering, multiple_results=False)) + + @jax.jit + def f(x): + return bjp_p.bind(x) + + f(3.) # don't crash + + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): def test_return_from_jit(self): From 4952b21de2d27fd8a0264cd25ea3ae0f902009f0 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 10 Dec 2025 00:05:06 -0800 Subject: [PATCH 130/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/63413fe5a5ce541ba0e076f25264d11f0311fac5 PiperOrigin-RevId: 842575869 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 8810fa765f03..b7788d301fd9 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "91b3f740b75d1d932a12fb0886338f84f856a453" -XLA_SHA256 = "68d6d2f66b10e826512fa6e262143425606dda30d6ab95daa83e3dfb5cf298a0" +XLA_COMMIT = "63413fe5a5ce541ba0e076f25264d11f0311fac5" +XLA_SHA256 = "4381718d7981a6d866171fe206d07ce7cf58295c7aeaef9fdc5c44741f22b585" From d9a7388943120ba9037e53cd4afe2df4fc0d15b6 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 10 Dec 2025 02:27:28 -0800 Subject: [PATCH 131/315] [Mosaic GPU][NFC] Refactor `IsTransferable.holds` to use pattern matching. PiperOrigin-RevId: 842626789 --- jax/experimental/mosaic/gpu/constraints.py | 38 ++++++++++++---------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/jax/experimental/mosaic/gpu/constraints.py b/jax/experimental/mosaic/gpu/constraints.py index 9a66cde1432b..aeed84694eb1 100644 --- a/jax/experimental/mosaic/gpu/constraints.py +++ b/jax/experimental/mosaic/gpu/constraints.py @@ -389,25 +389,29 @@ def holds(self) -> bool | None: Returns `None` if the constraint can't be checked. """ - source = self.source - target = self.target - if isinstance(source, TMEMLayout) and isinstance(target, RegisterLayout): - return self._is_valid_tmem_transfer(source.value, target.value) - if isinstance(target, TMEMLayout) and isinstance(source, RegisterLayout): - return self._is_valid_tmem_transfer(target.value, source.value) - if isinstance(source, TMEMLayout) and isinstance(target, TMEMLayout): - return source == target - if isinstance(source, SMEMTiling) and isinstance(target, RegisterLayout): - return self._is_valid_smem_transfer(source.value, target.value) - if isinstance(target, SMEMTiling) and isinstance(source, RegisterLayout): - return self._is_valid_smem_transfer(target.value, source.value) - if isinstance(target, Constant) and isinstance(source, Constant): - source_type = type(source).__name__ - target_type = type(target).__name__ - raise NotImplementedError(f"Unsupported transfer: {source_type} -> {target_type}") + assert self.source != self.target, ( + "IsTransferable constraints within the same memory space are not" + " supported." + ) - return None + match self.source, self.target: + case TMEMLayout(value=src), RegisterLayout(value=dst): + return self._is_valid_tmem_transfer(src, dst) + case RegisterLayout(value=src), TMEMLayout(value=dst): + return self._is_valid_tmem_transfer(dst, src) + case SMEMTiling(value=src), RegisterLayout(value=dst): + return self._is_valid_smem_transfer(src, dst) + case RegisterLayout(value=src), SMEMTiling(value=dst): + return self._is_valid_smem_transfer(dst, src) + case Constant(), Constant(): + source_type = type(self.source).__name__ + target_type = type(self.target).__name__ + raise NotImplementedError( + f"Unsupported transfer: {source_type} -> {target_type}" + ) + case _: + return None def __str__(self): return f"IsTransferable({self.source} ⟶ {self.target})" From 3aeefe5ee2a7e4244d651481c18b7082b6aaf761 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 10 Dec 2025 02:36:53 -0800 Subject: [PATCH 132/315] jnp.arange: avoid device transfer when possible --- jax/_src/numpy/lax_numpy.py | 7 +++++-- tests/lax_numpy_test.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4bdc2e31d37b..29ddd583cc02 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -142,9 +142,12 @@ def iscomplexobj(x: Any) -> bool: >>> jnp.iscomplexobj(jnp.array([0, 1+2j])) True """ - # Check for int here to avoid potential overflow in jnp.array below. - if x is None or isinstance(x, int): + # Fast path for common types. + if isinstance(x, (complex, np.complexfloating)): + return True + if x is None or isinstance(x, (bool, int, float, str, np.generic)): return False + # Fall back to dtype attribute lookup. try: typ = x.dtype.type except AttributeError: diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index fbdd46de4cd0..677f0555878d 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3613,6 +3613,15 @@ def testIsComplexObj(self, val): self._CheckAgainstNumpy(np.iscomplexobj, jnp.iscomplexobj, args_maker) self._CompileAndCheck(jnp.iscomplexobj, args_maker) + @parameterized.parameters( + None, bool(1), int(1), float(1), complex(1), + np.int32(0), np.float32(1), np.complex64(1), + (np.arange(5),) + ) + def testIsComplexObjTransferGuard(self, val): + with jax.transfer_guard("disallow"): + jnp.iscomplexobj(val) + def testIsClose(self): c_isclose = jax.jit(jnp.isclose) c_isclose_nan = jax.jit(partial(jnp.isclose, equal_nan=True)) @@ -4911,6 +4920,12 @@ def testArangeComplex(self, *args): np_result = np.arange(*args) self.assertArraysEqual(jax_result, np_result) + @parameterized.parameters(int, float, np.int32, np.float32) + def testArangeTransferGuard(self, typ): + # Ensure that simple arange calls avoid host-to-device transfer. + with jax.transfer_guard("disallow"): + jnp.arange(typ(5)) + def testIssue830(self): a = jnp.arange(4, dtype=jnp.complex64) self.assertEqual(a.dtype, jnp.complex64) From 723b90e04fb2665d8f657350dddb205c15fad599 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 10 Dec 2025 10:42:10 +0000 Subject: [PATCH 133/315] [export] Add backwards compatibility test for memory_space This is in preparation for the upcoming serialization version 6 (#33597). --- .../export_with_memory_space.py | 23 ++++++++++ .../export_serialization_back_compat_test.py | 45 +++++++++++++++++-- 2 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py new file mode 100644 index 000000000000..3d89168e0004 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/export_with_memory_space.py @@ -0,0 +1,23 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +# Pasted from the test output (see export_serialization_back_compat_test.py module docstring) +serializations = [ + dict( + serialization_version=5, + exported_serialized=bytearray(b"0\x00\x00\x00\x00\x00*\x00L\x00J\x00D\x00@\x00<\x008\x004\x00.\x00(\x00$\x00 \x00\x1c\x00\x18\x00\x14\x00\x10\x00\x0e\x00\x08\x00\x07\x00\x00\x000\x00*\x00\x00\x00\x00\x00\x00\x01D\x00\x00\x00\x00\x00\n\x00D\x00\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00\x84\x02\x00\x00X\x02\x00\x00\x80\x02\x00\x00\x88\x02\x00\x00\x00\x00\x02\x00\x02\x00\x00\x00\xa0\x02\x00\x00\xcc\x02\x00\x00\xcc\x02\x00\x00\x04\x03\x00\x00X\x03\x00\x00\x00\x00\x03\x00\x01\x00\x00\x00\x00\x00\x00\x00 \x02\x00\x00ML\xefR\rStableHLO_v1.13.0\x00\x01\x1f\x07\x01\x05\t\t\x01\x03\x0f\x03\x03\x13\x05\x05\x17\x1b\x03kE\x0f\x01\x1b\x07\x0b#\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0b\x0f\x0b\x03\r\x13\x0f\x1b\x17\x0f\x13\x05\x1f\x0b\x0b\x13\x13\x0b\x0b\x1b\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x01\x05\x0f\x0b\x05\x0b\x17\x0f\x1b\x07\x07\x02\xf9\x1f\x05\t\x03\x07\x07\t\x0b\r\x0f\x11\x05\x0f\x11\x03\x01\x05\x11\x11\x01\t\x05\x13\x11\x01\x05\x05\x15\t\x03\x1d\x19\x01\x05\x17\x05\x03\x1d\x01\x03\x17\t\r\x15\x05!%\x01\x0b\x03#\x01\x01\t\x17\x01\x0b\x01\x01\x01\x1d\x19\x1d\x1b\x03\x05-3\r\x03/1\x1d\x1d\x1d\x1f\r\x05\')5\x1f\x1d!#\t\x03\x03;\r\x05=?\')\x1d#\x1d%\x1d\'\x1d)\x01\x02\x02\x01\t)\x05\t\r\r)\x01\x0b\x11\x05\x07\x05\x03\x05\x1b\t\x04I\x05\x01Q\x01\x05\x01\x07\x047\x03\x01\t\x03@\x01\x03\x05P\x01\x05\x07\x04\x1b\x03\x05\x07\x05\r\x0b\x17\x00\x07\x04\x01\x03\x03\x06\x03\x01\x05\x01\x00\x1a\x04+\x0f\x0b\x0f!\x1b!)\x19#\x05\x19%)9\x15\x11\x0b\x0f\x0b\t\x11builtin\x00sdy\x00vhlo\x00module\x00mesh\x00func_v1\x00return_v1\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda\x00x\x00mhlo.memory_kind\x00pinned_host\x00jax.global_constant\x00_platform_index\x00sdy.sharding\x00jax.result_info\x00result\x00main\x00public\x00\x08\x1b\x07\x05\'\x01\x05\x1b\x03\x0b+79AC\x02\x00\x00\x00\x14\x00\x00\x00\x04\x00\x00\x00\x04\x00\x00\x00cuda\x00\x00\x00\x00\x03\x00\x00\x00tpu\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x18\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00@\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x01\x0c\x00\x00\x00\x08\x03\x1a\x02\x02\x01J\x01\x02R\x01\x00\x01\x00\x00\x00\x04\x00\x00\x00\xcc\xff\xff\xff\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00p\xff\xff\xff\x01\x00\x00\x00\x10\x00\x00\x00\x0c\x00\x0c\x00\x00\x00\x08\x00\x07\x00\x06\x00\x0c\x00\x00\x00\x00\x00\x02\n\x04\x00\x00\x00\x02\x00\x00\x00\x10\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x003\x00\x00\x00\x01\x00\x00\x002\x00\x00\x00\xcc\xff\xff\xff\x08\x00\x00\x00\x00\x00\x00\x02\x02\x00\x00\x00,\x00\x00\x00\x10\x00\x00\x00\x00\x00\n\x00\x0c\x00\x0b\x00\x00\x00\x04\x00\n\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x08\x00\x0c\x00\x0b\x00\x04\x00\x08\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x02\x01\x00\x00\x00\x08\x00\x00\x00\x04\x00\x04\x00\x04\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00"), + ), +] diff --git a/tests/export_serialization_back_compat_test.py b/tests/export_serialization_back_compat_test.py index 02886cb237df..b11858db3c2b 100644 --- a/tests/export_serialization_back_compat_test.py +++ b/tests/export_serialization_back_compat_test.py @@ -23,8 +23,8 @@ * Create a new test method, with a function to be serialized that exercises the feature you want to test, and a call to self.export_and_serialize. You can follow the model of the tests below, which are parameterized by - the test data. Use `None` for the test data to signal that you want to - use a fresh serialization. + the testdata. Use only `None` for the testdata parameter to signal that + you want to use a current serialization and not a saved one. * Run the test. This will save the serialized data in TEST_UNDECLARED_OUTPUTS_DIR (or "/tmp/back_compat_testdata" if not set). * Copy the test data defined in the output file, to the file @@ -55,6 +55,7 @@ import jax from jax._src import config +from jax._src import core from jax._src.export import _export from jax._src.export.serialization import _SERIALIZATION_VERSION from jax.sharding import PartitionSpec as P @@ -62,6 +63,7 @@ from jax._src.internal_test_util.export_back_compat_test_data import export_with_specified_sharding from jax._src.internal_test_util.export_back_compat_test_data import export_with_unspecified_sharding +from jax._src.internal_test_util.export_back_compat_test_data import export_with_memory_space config.parse_flags_with_absl() jtu.request_cpu_devices(8) @@ -75,6 +77,7 @@ def setUp(self): def export_and_serialize(self, fun, *args, vjp_order=0, + platforms=None, **kwargs) -> bytearray: """Export and serialize a function. @@ -82,7 +85,7 @@ def export_and_serialize(self, fun, *args, "/tmp/back_compat_testdata" if not set) and should be copied as explained in the module docstring. """ - exp = _export.export(fun)(*args, **kwargs) + exp = _export.export(fun, platforms=platforms)(*args, **kwargs) serialized = exp.serialize(vjp_order=vjp_order) updated_testdata = f""" # Paste to the test data file (see export_serialization_back_compat_test.py module docstring) @@ -98,7 +101,8 @@ def export_and_serialize(self, fun, *args, "/tmp/back_compat_testdata") if not os.path.exists(output_dir): os.makedirs(output_dir) - output_file = os.path.join(output_dir, f"export_{self._testMethodName}.py") + output_file_basename = f"export_{self._testMethodName.replace('test_', '')}.py" + output_file = os.path.join(output_dir, output_file_basename) logging.info("Writing the updated serialized Exported at %s", output_file) with open(output_file, "w") as f: f.write(updated_testdata) @@ -163,5 +167,38 @@ def f(b): self.assertEqual(out.addressable_shards[1].index, (slice(8, 16), slice(None))) + @jtu.parameterized_filterable( + kwargs=[ + dict(testdata=testdata, + testcase_name=("current" if testdata is None + else f"v{testdata['serialization_version']}")) + for testdata in [None, *export_with_memory_space.serializations] + ] + ) + def test_with_memory_space(self, testdata: dict[str, Any] | None): + # This test is based on export_test.py::test_memory_space_from_arg + mesh = jtu.create_mesh((2,), "x") + with jax.set_mesh(mesh): + shd = jax.sharding.NamedSharding(mesh, P("x", None), + memory_kind="pinned_host") + a = jax.device_put(np.ones((2, 3), dtype=np.float32), shd) + f = jax.jit(lambda x: x) + + if testdata is None: + serialized = self.export_and_serialize( + f, a, platforms=("tpu", "cuda")) + else: + serialized = testdata["exported_serialized"] + + exported = _export.deserialize(serialized) + self.assertEqual(exported.in_avals[0].memory_space, core.MemorySpace.Host) + self.assertEqual(exported.out_avals[0].memory_space, core.MemorySpace.Host) + + if jtu.device_under_test() in ("tpu", "gpu"): + b = exported.call(a) + self.assertEqual(b.aval.memory_space, core.MemorySpace.Host) + self.assertEqual(b.sharding, a.sharding) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From 3a28e93e6a4b33fa5f3b4a194c213d85bc67d286 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 10 Dec 2025 03:24:07 -0800 Subject: [PATCH 134/315] [Pallas:MGPU] Add a lowering rule for lax.clamp PiperOrigin-RevId: 842643534 --- jax/_src/pallas/mosaic_gpu/lowering.py | 8 ++++++++ tests/pallas/ops_test.py | 23 +++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 4cfd3924c8b2..d853cf24e7e1 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2217,6 +2217,14 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): return res +@register_lowering_rule(lax.clamp_p, mgpu.LoweringSemantics.Lane) +@register_lowering_rule(lax.clamp_p, mgpu.LoweringSemantics.Warpgroup) +def _clamp_lowering_rule(ctx: LoweringRuleContext, l, x, u): + return _lower_fun( + lambda l, x, u: lax.min(lax.max(x, l), u), multiple_results=False + )(ctx, l, x, u) + + @register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Lane) @register_lowering_rule(lax.square_p, mgpu.LoweringSemantics.Warpgroup) def _square_lowering_rule(ctx: LoweringRuleContext, x): diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 11d295ae9bdd..2d87545e86ef 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1017,6 +1017,29 @@ def kernel(x_ref, o_ref): expected = lax.is_finite(x) self.assertArraysEqual(out, expected) + @parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.int32, jnp.int16) + def test_clamp(self, dtype): + if dtype == jnp.int16 and jtu.test_device_matches(["tpu"]): + self.skipTest("int16 is not supported on TPU") + + k1, k2, k3 = random.split(jax.random.key(0), num=3) + if jnp.issubdtype(dtype, jnp.floating): + lo_ = random.normal(k1, (8, 128), dtype=dtype) + hi_ = random.normal(k2, (8, 128), dtype=dtype) + x = random.normal(k3, (8, 128), dtype=dtype) + else: + lo_ = random.randint(k1, (8, 128), -100, 100, dtype=dtype) + hi_ = random.randint(k2, (8, 128), -100, 100, dtype=dtype) + x = random.randint(k3, (8, 128), -100, 100, dtype=dtype) + lo = jnp.minimum(lo_, hi_) + hi = jnp.maximum(lo_, hi_) + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + ) + def kernel(lo_ref, x_ref, hi_ref, o_ref): + o_ref[...] = lax.clamp(lo_ref[...], x_ref[...], hi_ref[...]) + np.testing.assert_array_equal(kernel(lo, x, hi), lax.clamp(lo, x, hi)) + @parameterized.named_parameters( (dtype.__name__, dtype) for dtype in (jnp.float32, jnp.float16, jnp.bfloat16) From 15ba1b72c98345d5d9b106f736f58bf15ad882b9 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 10 Dec 2025 03:53:53 -0800 Subject: [PATCH 135/315] [Mosaic GPU][NFC] Remove `Tautological` from constraint reduction. PiperOrigin-RevId: 842652535 --- jax/experimental/mosaic/gpu/constraints.py | 35 ++++++++++------------ 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/jax/experimental/mosaic/gpu/constraints.py b/jax/experimental/mosaic/gpu/constraints.py index aeed84694eb1..409dd1d7982b 100644 --- a/jax/experimental/mosaic/gpu/constraints.py +++ b/jax/experimental/mosaic/gpu/constraints.py @@ -490,10 +490,9 @@ def __str__(self): def reduce_constraint( constraint: Constraint, assignments: dict[Variable, Constant] -) -> Constraint | Tautological | Unsatisfiable: +) -> Constraint | Unsatisfiable: """Reduces a constraint.""" - new_constraint: Constraint match constraint: case Equals(lhs=lhs, rhs=rhs): lhs_red = reduce_expression(lhs, assignments) @@ -502,7 +501,7 @@ def reduce_constraint( rhs_red = reduce_expression(rhs, assignments) if isinstance(rhs_red, Unsatisfiable): return Unsatisfiable() - new_constraint = Equals(lhs_red, rhs_red) + return Equals(lhs_red, rhs_red) case Relayout(source=source, target=target): source_red = reduce_expression(source, assignments) target_red = reduce_expression(target, assignments) @@ -510,31 +509,26 @@ def reduce_constraint( target_red, Unsatisfiable ): return Unsatisfiable() - new_constraint = Relayout(source_red, target_red) + return Relayout(source_red, target_red) case NotOfType(expr=expr, type=type): expr_red = reduce_expression(expr, assignments) if isinstance(expr_red, Unsatisfiable): return Unsatisfiable() - new_constraint = NotOfType(expr_red, type) + return NotOfType(expr_red, type) case IsTransferable(source=source, target=target, shape=shape): source_red = reduce_expression(source, assignments) target_red = reduce_expression(target, assignments) if isinstance(source_red, Unsatisfiable) or isinstance(target_red, Unsatisfiable): return Unsatisfiable() - new_constraint = IsTransferable(source_red, target_red, shape) + return IsTransferable(source_red, target_red, shape) case Divides(expr=expr, tiling_multiple=tiling_multiple): expr_red = reduce_expression(expr, assignments) if isinstance(expr_red, Unsatisfiable): return Unsatisfiable() - new_constraint = Divides(expr_red, tiling_multiple) + return Divides(expr_red, tiling_multiple) case _ as never: assert_never(never) - constraint_holds = new_constraint.holds() - if constraint_holds is None: - return new_constraint - return Tautological() if constraint_holds else Unsatisfiable() - @dataclasses.dataclass class ConstraintSystem: @@ -620,10 +614,6 @@ def __and__(self, other: ConstraintSystem | Unsatisfiable) -> Unsatisfiable: return self -class Tautological: - ... - - def non_splat_variables( constraints: Sequence[Constraint], ) -> set[Variable]: @@ -832,11 +822,16 @@ def try_assign(var: Variable, cst: Constant) -> bool: if not try_assign(var, cst): return Unsatisfiable() changed = True - case Tautological(): - changed = True case _ as new_constraint: - changed |= new_constraint != constraint - constraints.append(new_constraint) + assert isinstance(new_constraint, Constraint) # make pytype happy + match new_constraint.holds(): + case None: + constraints.append(new_constraint) + changed |= new_constraint != constraint + case False: + return Unsatisfiable() + case True: + changed = True new_constraints = merge_divides_constraints(constraints) changed |= len(new_constraints) != len(constraints) From 9ec840fa83e00999f09e475137d86d9f30d69756 Mon Sep 17 00:00:00 2001 From: Brian Patton Date: Wed, 10 Dec 2025 04:18:03 -0800 Subject: [PATCH 136/315] [Pallas:SC] Adds a test that verifies pl.kernel outputs are placed in the proper memory spaces. PiperOrigin-RevId: 842659708 --- tests/pallas/tpu_pallas_state_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index f7821d922249..fa5ce0778ecc 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -14,6 +14,7 @@ import functools from absl.testing import absltest +from absl.testing import parameterized import jax from jax._src import test_util as jtu from jax._src.state.primitives import pin, unpin @@ -382,6 +383,28 @@ def body(x_ref, o1_ref, o2_ref, scratch_ref): np.testing.assert_array_equal(result1, x) np.testing.assert_array_equal(result2, x + 1) + @parameterized.named_parameters( + ("HBM", pltpu.HBM, 0), + ("VMEM", pltpu.VMEM, 1), + ("SMEM", pltpu.SMEM, 4), + ("SEMAPHORE", pltpu.SEMAPHORE, 2), + ) + def test_kernel_with_output_memory_space(self, memory_space, color): + if not jtu.is_device_tpu_at_least(5): + self.skipTest("Only supported on TPU v5+") + mesh = pltpu.create_tensorcore_mesh("x", num_cores=1) + def body(x_ref, o_ref): + pltpu.sync_copy(x_ref, o_ref) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + text = pl.kernel( + body, out_shape=memory_space(x.shape, x.dtype), mesh=mesh, + ).lower(x).as_text() + custom_call = [l for l in text.split("\n") if "@tpu_custom_call" in l] + self.assertLen(custom_call, 1) + custom_call = custom_call[0] + self.assertRegex(custom_call, + r".*output_memory_colors\\22: \[" + str(color) + r"\].*") + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From e43f4cb6af49d0dbb50f64c65ea9e75f6b722329 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 10 Dec 2025 05:09:18 -0800 Subject: [PATCH 137/315] [Pallas:MGPU] Expose `fragmented_array.Replicated` as part of the public API. PiperOrigin-RevId: 842675493 --- jax/experimental/pallas/mosaic_gpu.py | 1 + tests/pallas/mosaic_gpu_test.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/jax/experimental/pallas/mosaic_gpu.py b/jax/experimental/pallas/mosaic_gpu.py index 8adcd2da2521..eccc06881936 100644 --- a/jax/experimental/pallas/mosaic_gpu.py +++ b/jax/experimental/pallas/mosaic_gpu.py @@ -87,6 +87,7 @@ from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait from jax._src.pallas.mosaic_gpu.torch import as_torch_kernel as as_torch_kernel from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics +from jax.experimental.mosaic.gpu.fragmented_array import Replicated as Replicated from jax.experimental.mosaic.gpu.fragmented_array import Tiling as Tiling diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index fd248ea182da..35496cf5f8c9 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2563,6 +2563,23 @@ def kernel(dst, collective_barrier): )() np.testing.assert_array_equal(y, np.ones((), dtype=np.int32)) + def test_replicated_layout(self): + shape = (32,) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + ) + def kernel(src_ref, dst_ref): + layout = plgpu.Layout.TILED( + plgpu.Tiling(((32,), (1,))), + warp_dims=(plgpu.Replicated(4),), + lane_dims=(-2,), + vector_dim=-1, + ) + dst_ref[...] = plgpu.load(src_ref, (), layout=layout) + src = jnp.arange(shape[0], dtype=jnp.float32) + np.testing.assert_array_equal(kernel(src), src) + class PallasCallWarpPrimitiveSemanticsTest(PallasTest): def setUp(self): From 637982ed33a9099daa6009e319367558e7692b35 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Wed, 10 Dec 2025 05:28:09 -0800 Subject: [PATCH 138/315] [MGPU] Add support for broadcast on major dim in WGStridedFragLayout. PiperOrigin-RevId: 842680640 --- jax/_src/pallas/mosaic_gpu/lowering.py | 8 ++++ .../mosaic/gpu/fragmented_array.py | 14 ++++++ tests/mosaic/gpu_test.py | 47 ++++++++++++++++--- tests/pallas/mosaic_gpu_test.py | 16 +++++++ 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index d853cf24e7e1..3c484df435c9 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1860,6 +1860,14 @@ def _broadcast_in_dim_lowering_rule( if (isinstance(x.layout, mgpu.WGSplatFragLayout) and broadcast_dimensions == tuple(range(rank_diff, rank_diff + x_aval.ndim))): return x.broadcast(shape) + if ( + isinstance(x.layout, mgpu.WGStridedFragLayout) + and broadcast_dimensions == tuple(range(rank_diff, y_aval.ndim)) + ): + new_layout = mgpu.WGStridedFragLayout( + shape=y_aval.shape, vec_size=x.layout.vec_size + ) + return x.broadcast_in_dim(y_aval.shape, broadcast_dimensions, new_layout) if not isinstance(layout := x.layout, mgpu.TiledLayout): raise NotImplementedError(f"Unsupported layout: {x.layout}") if any(d1 >= d2 for d1, d2 in zip(broadcast_dimensions[:-1], broadcast_dimensions[1:])): diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index f3e7f5b477e8..f2c58a9b354f 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2504,6 +2504,20 @@ def broadcast_in_dim( return type(self).splat( self.registers.item(), shape, layout, is_signed=self.is_signed ) + if isinstance(self.layout, WGStridedFragLayout) and isinstance(layout, WGStridedFragLayout): + new_dims = set(range(len(shape))) - set(source_dimensions) + vec_match = self.layout.vec_size == layout.vec_size + broadcast_dim_match = new_dims == set(range(len(new_dims))) + assert layout.shape == shape, (layout.shape, shape) + if vec_match and broadcast_dim_match: + return FragmentedArray( + _registers=np.tile( + self.registers, + np.prod(shape[:len(new_dims)]), + ), + _layout=layout, + _is_signed=self.is_signed, + ) if not isinstance(self.layout, TiledLayout) or not isinstance(layout, TiledLayout): raise NotImplementedError(self.layout, layout) if any(d1 >= d2 for d1, d2 in zip(source_dimensions, source_dimensions[1:])): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index b3e918c674eb..5218160a6a77 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3796,16 +3796,49 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(result, inp) - @parameterized.parameters((128, 128), (128, 64), (64, 128)) - def test_broadcast_major(self, m, n): + @parameterized.product( + mns=((128, 128), (128, 64), (64, 128)), + layout=(mtu.RegisterLayout.WG_STRIDED, mtu.RegisterLayout.WGMMA), + ) + def test_broadcast_major(self, mns, layout): + m, n = mns + + if n < 128 and layout == mtu.RegisterLayout.WG_STRIDED: + self.skipTest(f"{n=} < 128 not supported for {layout=}") + + dtype = jnp.float16 + load_layout = ( + layout.to_mgpu((n,), dtype) + if layout == mtu.RegisterLayout.WG_STRIDED + else mgpu.WGMMA_COL_LAYOUT + ) + broadcast_layout = ( + mgpu.WGStridedFragLayout((m, n), load_layout.vec_size) + if layout == mtu.RegisterLayout.WG_STRIDED + else layout.to_mgpu((m, n), dtype) + ) + + def load(gmem_input): + match layout: + case mtu.RegisterLayout.WG_STRIDED: + return mgpu.FragmentedArray.load_strided( + gmem_input, vec_size=load_layout.vec_size + ) + case mtu.RegisterLayout.WGMMA: + return mgpu.FragmentedArray.load_untiled( + gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False + ) + case _: + raise NotImplementedError(f"Unsupported layout: {layout}") + def kernel(ctx, gmem_input, gmem_output, _): - t = mgpu.FragmentedArray.load_untiled( - gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False + t = load(gmem_input) + t.broadcast_in_dim((m, n), (1,), broadcast_layout).store_untiled( + gmem_output, optimized=False ) - t.broadcast_in_dim((m, n), (1,), mgpu.WGMMA_LAYOUT).store_untiled(gmem_output, optimized=False) - inp = self.prng.uniform(-1, 1, (n,)).astype(jnp.float16) - out_shape = jax.ShapeDtypeStruct((m, n), jnp.float16) + inp = self.prng.uniform(-1, 1, (n,)).astype(dtype) + out_shape = jax.ShapeDtypeStruct((m, n), dtype) result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp )(inp) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 35496cf5f8c9..e35b6a4d5816 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2407,6 +2407,22 @@ def test_broadcast_in_dim_does_not_crash_on_small_shape(self): shape, plgpu.Layout.TCGEN05_TMEM_NATIVE, axis=1, hint=False ) + def test_broadcast_in_dim_wg_strided_majormost_dim(self): + self.skip_if_wg_semantics() + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((256, 128), jnp.float32), + ) + def kernel(x_ref, y_ref): + to_be_broadcasted = plgpu.load( + x_ref, (), layout=plgpu.Layout.WG_STRIDED((128,), 1) + ) + broadcasted = lax.broadcast_in_dim(to_be_broadcasted, (256, 128), (1,)) + y_ref[...] = broadcasted + + result = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32) + np.testing.assert_array_equal(kernel(result), jnp.broadcast_to(result[None,:], (256, 128))) + def test_broadcast_in_dim_tcgen05_native_layout(self): @functools.partial( self.kernel, From c7bc9b2866c8c3a2c9c8b90dbecd5dfe37ba3812 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 10 Dec 2025 06:03:46 -0800 Subject: [PATCH 139/315] [Pallas:MGPU] Properly restore the pytree structure when unflattening unions Previously we returned all the leaves in a flat list, which is unhelpful. PiperOrigin-RevId: 842691111 --- jax/_src/pallas/mosaic_gpu/core.py | 22 ++++++++------ tests/pallas/mosaic_gpu_test.py | 49 ++++++++++++++++-------------- 2 files changed, 39 insertions(+), 32 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 532c86d98af7..0167c96d04ce 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -431,7 +431,8 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]: union_bytes = 0 for ref_group in ref_union.refs: byte_offset = 0 - for ref in jax.tree.leaves(ref_group): + def unflatten(ref): + nonlocal byte_offset byte_offset = align_to(byte_offset, SMEM_ALIGNMENT) assert isinstance(ref, state.AbstractRef) or isinstance( ref, pallas_core.TransformedRef @@ -439,10 +440,8 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]: if not isinstance(ref, pallas_core.TransformedRef): ref = pallas_core.TransformedRef(ref, transforms=()) transform = ExtractAliasedRef.from_transformed_ref(ref, byte_offset) - flat_refs.append( - pallas_core.TransformedRef( - ref_union, transforms=(transform, *ref.transforms) - ) + result = pallas_core.TransformedRef( + ref_union, transforms=(transform, *ref.transforms) ) if jnp.issubdtype(ref.dtype, jnp.integer): nbits = jnp.iinfo(ref.dtype).bits @@ -457,13 +456,16 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]: f" {ref.dtype}{ref.shape}" ) byte_offset += ref_bits // 8 + return result + flat_refs.append(jax.tree.map(unflatten, ref_group)) union_bytes = max(union_bytes, byte_offset) assert union_bytes == ref_union.shape[0] elif ref_union.memory_space == TMEM: union_cols = 0 for ref_group in ref_union.refs: col_offset = 0 - for ref in jax.tree.leaves(ref_group): + def unflatten(ref): + nonlocal col_offset col_offset = align_to(col_offset, TMEM_COL_ALIGNMENT) if not isinstance(ref, pallas_core.TransformedRef): ref = pallas_core.TransformedRef(ref, transforms=()) @@ -471,12 +473,12 @@ def flatten_ref_union(ref_union: AbstractRefUnion) -> tuple[_Ref, ...]: dtypes.itemsize_bits(ref.dtype)) transform = ExtractAliasedRef.from_transformed_ref( ref, col_offset, layout=ref.layout) - flat_refs.append( - pallas_core.TransformedRef( - ref_union, transforms=(transform, *ref.transforms) - ) + result = pallas_core.TransformedRef( + ref_union, transforms=(transform, *ref.transforms) ) col_offset += ncols + return result + flat_refs.append(jax.tree.map(unflatten, ref_group)) union_cols = max(union_cols, col_offset) assert union_cols == ref_union.shape[1], (union_cols, ref_union.shape[1]) else: diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e35b6a4d5816..951bb3af2031 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1907,7 +1907,7 @@ def kernel(x_ref, y_ref, o_ref): y = jax.lax.iota(jnp.float32, 128) * 3 np.testing.assert_array_equal(kernel(x, y), x + y) - def test_smem_aliasing_works(self): + def test_smem_aliasing_works_basic(self): self.skip_if_wg_semantics() in_shape = (2, 256) @@ -1938,17 +1938,16 @@ def test_smem_aliasing_works(self): plgpu.SMEM( (128,), jnp.float32, - transforms=(plgpu.TilingTransform((64,)),), - ), + transforms=(plgpu.TilingTransform((64,)),)), ] ], ) ], ) def kernel(x_ref, o_ref128, aliased_ref): - smem_ref256, _, smem_ref128 = aliased_ref + smem_ref256, [_, [smem_ref128]] = aliased_ref # Ensure that extraction via index works the same as unfolding. - smem_ref128_2 = aliased_ref[2] + smem_ref128_2 = aliased_ref[1][1][0] self.assertIsInstance(smem_ref128, state_types.TransformedRef) self.assertIsInstance(smem_ref128_2, state_types.TransformedRef) self.assertIs(smem_ref128.ref, smem_ref128_2.ref) @@ -2005,7 +2004,7 @@ def test_smem_aliasing_works_with_subbyte_dtypes(self): ], ) def kernel(x_ref, o_refi4, aliased_ref): - _, smem_refi8, _, smem_refi4 = aliased_ref + [_, smem_refi8], [_, smem_refi4] = aliased_ref smem_refi8[...] = x_ref[...] plgpu.commit_smem() plgpu.copy_smem_to_gmem(smem_refi4, o_refi4) @@ -3415,7 +3414,7 @@ def test_tmem_ref_aliasing(self): thread_name="x", ) def kernel(x_ref, y_ref, aliased_ref, smem_ref, barrier_ref): - tmem_128x32a, tmem_128x32b, tmem_128x64 = aliased_ref + [tmem_128x32a, tmem_128x32b], tmem_128x64 = aliased_ref plgpu.copy_gmem_to_smem(x_ref, smem_ref, barrier_ref) plgpu.barrier_wait(barrier_ref) # Test tmem_128x32 a and b @@ -4268,7 +4267,7 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64, plgpu.barrier_wait(tma_barrier) plgpu.copy_gmem_to_smem(b_gmem, b_smem, tma_barrier) plgpu.barrier_wait(tma_barrier) - acc_128, lhs_128, lhs_64, acc_64, _ = aliased_refs + [acc_128, lhs_128], [lhs_64, acc_64], _ = aliased_refs # Do 128x128 @ 128x128 matmul plgpu.async_store_tmem(lhs_128, plgpu.load(a_smem, (), layout=plgpu.Layout.TCGEN05)) @@ -4305,21 +4304,27 @@ def kernel(a_gmem, b_gmem, out_gmem128, out_gmem64, f = self.kernel( kernel, - out_shape=[jax.ShapeDtypeStruct(shape, dtype), - jax.ShapeDtypeStruct(shape, dtype)], + out_shape=[ + jax.ShapeDtypeStruct(shape, dtype), + jax.ShapeDtypeStruct(shape, dtype), + ], scratch_shapes=[ - plgpu.SMEM(shape, dtype, transforms=transforms), # a_smem - plgpu.SMEM(shape, dtype, transforms=transforms), # b_smem - plgpu.SMEM(shape, dtype, transforms=transforms), # out_smem - plgpu.Barrier(), # tma_barrier - plgpu.Barrier(orders_tensor_core=True), # mma_barrier - plgpu.RefUnion( # aliased_refs - [plgpu.TMEM((128, 128), jnp.float32), # acc - plgpu.TMEM((128, 128), dtype, packed=True)], # lhs - [plgpu.TMEM((128, 64), dtype, packed=True), # lhs - plgpu.TMEM((128, 128), jnp.float32)], # acc - plgpu.TMEM((128, 128), jnp.float32) # unused - ), + plgpu.SMEM(shape, dtype, transforms=transforms), # a_smem + plgpu.SMEM(shape, dtype, transforms=transforms), # b_smem + plgpu.SMEM(shape, dtype, transforms=transforms), # out_smem + plgpu.Barrier(), # tma_barrier + plgpu.Barrier(orders_tensor_core=True), # mma_barrier + plgpu.RefUnion( # aliased_refs + [ + plgpu.TMEM((128, 128), jnp.float32), # acc + plgpu.TMEM((128, 128), dtype, packed=True), # lhs + ], + [ + plgpu.TMEM((128, 64), dtype, packed=True), # lhs + plgpu.TMEM((128, 128), jnp.float32), # acc + ], + plgpu.TMEM((128, 128), jnp.float32), # unused + ), ], ) x = jax.random.uniform(jax.random.key(0), shape=shape, dtype=dtype) From 11bb9812588ba1e26ed68965d8e6cf7817accf60 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 10 Dec 2025 08:35:43 -0800 Subject: [PATCH 140/315] [Pallas:MGPU] Support not tiled transposed loads in `swap` LANE lowering rule. PiperOrigin-RevId: 842740700 --- jax/_src/pallas/mosaic_gpu/lowering.py | 7 ++++++- tests/pallas/mosaic_gpu_test.py | 6 ------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 3c484df435c9..af4aebd172bd 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1717,9 +1717,12 @@ def _swap_lowering_rule( layout=value.layout, ) value.store_tiled(x_smem, swizzle=swizzle) - case (): + case () | (gpu_core.TransposeRef((1, 0)),): + transposed = bool(transforms) match value.layout: case mgpu.TiledLayout(): + if transposed: + x_smem = mgpu.memref_transpose(x_smem, (1, 0)) old_value = mgpu.FragmentedArray.load_untiled( x_smem, layout=value.layout, @@ -1728,6 +1731,8 @@ def _swap_lowering_rule( ) value.store_untiled(x_smem, optimized=False) case _: + if transposed: + raise NotImplementedError(f"Unsupported transforms: {transforms}") old_value = mgpu.FragmentedArray.load_strided( x_smem, is_signed=mgpu_utils.is_signed(v_aval.dtype) ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 951bb3af2031..9efd7e7ad438 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1056,12 +1056,6 @@ def test_transposed_load_store(self, src_layout, dst_layout): def is_transposed(layout): return layout == plgpu.Layout.WGMMA_TRANSPOSED - if ( - self.LOWERING_SEMANTICS == mgpu.LoweringSemantics.Lane - and is_transposed(dst_layout) - ): - self.skipTest("Not implemented: transposed, not tiled") - shape, dtype = (128, 128), jnp.float32 @functools.partial( From 1a78757ec4d49c67f629d44c73abe33d7193d637 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 10 Dec 2025 08:48:31 -0800 Subject: [PATCH 141/315] Add a missing libtpu version skip in SC Pallas tests PiperOrigin-RevId: 842744528 --- tests/pallas/tpu_sparsecore_pallas_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index 17d1e3acf098..7157018ff8ba 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -1925,6 +1925,11 @@ class PipelineTestWithTCTiling(TCTilingMixin, PipelineTest): class PallasSparsecoreAsyncTest(PallasSCTest): + def setUp(self): + super().setUp() + if not jtu.is_cloud_tpu_at_least(2025, 12, 14): + self.skipTest("Needs a newer libtpu") + @parameterized.product( shape=[ (8, 128), From 3a37e92770906b853dbeeb16e9e320462ad53299 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 10 Dec 2025 09:29:55 -0800 Subject: [PATCH 142/315] [pmap] Add `default_pmap_sharding` to migrate users away from `PmapSharding.default`. We leave this function undocumented because it's not intended for new users, but is public to migrate existing users. PiperOrigin-RevId: 842759555 --- jax/_src/sharding_impls.py | 46 ++++++++++++++++++++++++++++ jax/sharding.py | 1 + tests/documentation_coverage_test.py | 2 +- 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index b658a4d13966..2ddf717eae39 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -351,6 +351,52 @@ def shard_shape(self, global_shape: Shape) -> Shape: PmapSharding.__module__ = 'jax.sharding' +def default_pmap_sharding( + shape: Shape, + sharded_dim: int | None = 0, + devices: Sequence[xc.Device] | None = None, +) -> NamedSharding | PmapSharding: + """Creates a NamedSharding equivalent to PmapSharding.default. + + This function provides the same sharding semantics as PmapSharding.default + but returns a NamedSharding when jax_pmap_shmap_merge is enabled, which is + compatible with the shard_map-based pmap implementation. + + Args: + shape: The shape of the input array. + sharded_dim: Dimension the input array is sharded on. Defaults to 0. + If None, the array is fully replicated. + devices: Optional sequence of devices to use. If omitted, uses + jax.local_devices(). + + Returns: + A NamedSharding if jax_pmap_shmap_merge is enabled, otherwise a + PmapSharding. + """ + if not config.pmap_shmap_merge.value: + return PmapSharding.default(shape, sharded_dim=sharded_dim, devices=devices) + + if sharded_dim is None: + if devices is None: + raise ValueError("One of sharded_dim or devices must be set.") + mesh = mesh_lib.Mesh(np.array(devices), ('_default_pmap_sharding',)) + return NamedSharding(mesh, PartitionSpec()) + + if len(shape) == 0: + raise ValueError("shape must be non-empty for sharded_dim != None") + + num_ways_sharded = shape[sharded_dim] + + if devices is None: + pmap_devices = np.array(xb.local_devices()[:num_ways_sharded]) + else: + pmap_devices = np.array(devices) + + mesh = mesh_lib.Mesh(pmap_devices, ('_default_pmap_sharding',)) + spec_list: list[str | None] = [None] * len(shape) + spec_list[sharded_dim] = '_default_pmap_sharding' + return NamedSharding(mesh, PartitionSpec(*spec_list)) + def _unpickle_gspmd_sharding(devices, op_sharding, memory_kind): return GSPMDSharding(devices, op_sharding, memory_kind=memory_kind) diff --git a/jax/sharding.py b/jax/sharding.py index c592abec393f..e98ad72be036 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -20,6 +20,7 @@ NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, PmapSharding as _deprecated_PmapSharding, + default_pmap_sharding as default_pmap_sharding, set_mesh as set_mesh, get_mesh as get_mesh, ) diff --git a/tests/documentation_coverage_test.py b/tests/documentation_coverage_test.py index 83ae55a7423c..316f925a6ec8 100644 --- a/tests/documentation_coverage_test.py +++ b/tests/documentation_coverage_test.py @@ -68,7 +68,7 @@ def jax_docs_dir() -> str: 'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'], 'jax.random': ['key_impl', 'random_gamma_p'], 'jax.scipy.special': ['bessel_jn', 'sph_harm_y'], - 'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh', 'get_mesh'], + 'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'default_pmap_sharding', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh', 'get_mesh'], 'jax.stages': ['ArgInfo', 'CompilerOptions'], 'jax.tree_util': ['DictKey', 'FlattenedIndexKey', 'GetAttrKey', 'PyTreeDef', 'SequenceKey', 'default_registry'], } From d774b64cdc412ce7e0099e90fab3118af0893868 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 10 Dec 2025 09:45:44 -0800 Subject: [PATCH 143/315] Remove some deprecated BUILD aliases, close visibility of some others. PiperOrigin-RevId: 842764995 --- BUILD.bazel | 50 ++++++++---- jax/BUILD | 163 +++------------------------------------ tests/BUILD | 53 +++++-------- tests/mosaic/BUILD | 1 - tests/multiprocess/BUILD | 3 +- tests/pallas/BUILD | 7 +- 6 files changed, 69 insertions(+), 208 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index c9bcd69227da..f7dee0cb7bb0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -30,21 +30,43 @@ wheel_sources( data_srcs = ["//jax"], py_srcs = [ "//jax", - "//jax:compilation_cache", - "//jax:experimental", "//jax/example_libraries:example_libraries", - "//jax:experimental_colocated_python", - "//jax:experimental_sparse", - "//jax:experimental_buffer_callback", - "//jax:experimental_serialize_executable", - "//jax:pallas_experimental_gpu_ops", - "//jax:pallas_fuser", - "//jax:pallas_gpu_ops", - "//jax:pallas_mosaic_gpu", - "//jax:pallas_tpu_ops", - "//jax:pallas_triton", - "//jax:source_mapper", - "//jax:sparse_test_util", + "//jax/example_libraries:optimizers", + "//jax/example_libraries:stax", + "//jax/experimental:buffer_callback", + "//jax/experimental:checkify", + "//jax/experimental:colocated_python", + "//jax/experimental:compilation_cache", + "//jax/experimental:compute_on", + "//jax/experimental:custom_dce", + "//jax/experimental:custom_partitioning", + "//jax/experimental:fused", + "//jax/experimental:hijax", + "//jax/experimental:jet", + "//jax/experimental:layout", + "//jax/experimental:mesh_utils", + "//jax/experimental:multihost_utils", + "//jax/experimental:ode", + "//jax/experimental:pallas_experimental_gpu_ops", + "//jax/experimental:pallas_fuser", + "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", + "//jax/experimental:pallas_tpu_ops", + "//jax/experimental:pallas_triton", + "//jax/experimental:pjit", + "//jax/experimental:profiler", + "//jax/experimental:rnn", + "//jax/experimental:scheduling_groups", + "//jax/experimental:serialize_executable", + "//jax/experimental:shard_alike", + "//jax/experimental:shard_map", + "//jax/experimental:source_mapper", + "//jax/experimental:sparse_test_util", + "//jax/experimental:sparse", + "//jax/experimental:topologies", + "//jax/experimental:transfer", + "//jax/experimental:xla_metadata", + "//jax/experimental", "//jax/_src:lax_reference", "//jax/_src:internal_export_back_compat_test_util", "//jax/_src:internal_export_back_compat_test_data", diff --git a/jax/BUILD b/jax/BUILD index 245d15d9de16..928c67fadf85 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -20,6 +20,7 @@ load( "jax_extend_internal_users", "jax_extra_deps", "jax_internal_packages", + "jax_visibility", "py_deps", "py_library_providing_imports_info", "pytype_library", @@ -259,7 +260,7 @@ pytype_strict_library( # TODO(dsuo): remove these aliases/targets. pytype_strict_library( name = "experimental", - visibility = ["//visibility:public"], + visibility = jax_visibility("experimental_deprecated_alias"), deps = [ ":jax", "//jax/example_libraries:optimizers", @@ -288,190 +289,44 @@ pytype_strict_library( ], ) -alias( - name = "experimental_buffer_callback", - actual = "//jax/experimental:buffer_callback", - visibility = ["//jax/experimental:buffer_callback_users"], -) - -alias( - name = "experimental_colocated_python", - actual = "//jax/experimental:colocated_python", - visibility = ["//visibility:public"], -) - -alias( - name = "experimental_compute_on", - actual = "//jax/experimental:compute_on", - visibility = ["//visibility:public"], -) - -alias( - name = "compilation_cache", - actual = "//jax/experimental:compilation_cache", - visibility = ["//visibility:public"], -) - -alias( - name = "jet", - actual = "//jax/experimental:jet", - visibility = ["//visibility:public"], -) - alias( name = "mesh_utils", actual = "//jax/experimental:mesh_utils", - visibility = ["//visibility:public"], -) - -alias( - name = "experimental_mesh_utils", - actual = "//jax/experimental:mesh_utils", - visibility = ["//visibility:public"], -) - -alias( - name = "mosaic", - actual = "//jax/experimental:mosaic", - visibility = ["//jax/experimental:mosaic_users"], -) - -alias( - name = "mosaic_gpu", - actual = "//jax/experimental:mosaic_gpu", - visibility = ["//jax/experimental:mosaic_gpu_users"], -) - -alias( - name = "experimental_multihost_utils", - actual = "//jax/experimental:multihost_utils", - visibility = ["//visibility:public"], -) - -alias( - name = "ode", - actual = "//jax/experimental:ode", - visibility = ["//visibility:public"], + visibility = jax_visibility("mesh_utils_deprecated_alias"), ) alias( name = "pallas", actual = "//jax/experimental:pallas", - visibility = ["//visibility:public"], + visibility = jax_visibility("pallas_deprecated_alias"), ) alias( name = "pallas_fuser", actual = "//jax/experimental:pallas_fuser", - visibility = ["//jax/experimental:pallas_fuser_users"], -) - -alias( - name = "pallas_gpu", - actual = "//jax/experimental:pallas_gpu", - visibility = ["//jax/experimental:pallas_gpu_users"], -) - -alias( - name = "pallas_gpu_ops", - actual = "//jax/experimental:pallas_gpu_ops", - visibility = ["//jax/experimental:pallas_gpu_users"], + visibility = jax_visibility("pallas_fuser_deprecated_alias"), ) alias( name = "pallas_mosaic_gpu", actual = "//jax/experimental:pallas_mosaic_gpu", - visibility = ["//jax/experimental:mosaic_gpu_users"], + visibility = jax_visibility("pallas_mosaic_gpu_deprecated_alias"), ) alias( name = "pallas_tpu", actual = "//jax/experimental:pallas_tpu", - visibility = ["//visibility:public"], -) - -alias( - name = "pallas_tpu_ops", - actual = "//jax/experimental:pallas_tpu_ops", - visibility = ["//visibility:public"], -) - -alias( - name = "pallas_triton", - actual = "//jax/experimental:pallas_triton", - visibility = ["//jax/experimental:pallas_gpu_users"], -) - -alias( - name = "pallas_experimental_gpu_ops", - actual = "//jax/experimental:pallas_experimental_gpu_ops", - visibility = ["//jax/experimental:mosaic_gpu_users"], -) - -alias( - name = "experimental_profiler", - actual = "//jax/experimental:profiler", - visibility = ["//visibility:public"], -) - -alias( - name = "experimental_pjit", - actual = "//jax/experimental:pjit", - visibility = ["//visibility:public"], -) - -alias( - name = "rnn", - actual = "//jax/experimental:rnn", - visibility = ["//visibility:public"], -) - -alias( - name = "experimental_serialize_executable", - actual = "//jax/experimental:serialize_executable", - visibility = ["//jax/experimental:serialize_executable_users"], -) - -alias( - name = "source_mapper", - actual = "//jax/experimental:source_mapper", - visibility = ["//visibility:public"], -) - -alias( - name = "experimental_sparse", - actual = "//jax/experimental:sparse", - visibility = ["//visibility:public"], -) - -alias( - name = "sparse_test_util", - actual = "//jax/experimental:sparse_test_util", - visibility = [":internal"], -) - -alias( - name = "experimental_topologies", - actual = "//jax/experimental:topologies", - visibility = ["//visibility:public"], + visibility = jax_visibility("pallas_tpu_deprecated_alias"), ) alias( name = "experimental_transfer", actual = "//jax/experimental:transfer", - visibility = ["//jax/experimental:experimental_transfer_users"], + visibility = jax_visibility("experimental_transfer_deprecated_alias"), ) -# Aliases of example_library targets. -# TODO(dsuo): remove these aliases. alias( name = "optimizers", actual = "//jax/example_libraries:optimizers", - visibility = ["//visibility:public"], -) - -alias( - name = "stax", - actual = "//jax/example_libraries:stax", - visibility = ["//visibility:public"], + visibility = jax_visibility("optimizers_deprecated_alias"), ) diff --git a/tests/BUILD b/tests/BUILD index 3736f44b8b5d..b723dc69120c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -36,9 +36,7 @@ jax_multiplatform_test( srcs = ["api_test.py"], enable_configs = ["tpu_v3_x4"], shard_count = 5, - deps = [ - "//jax:experimental", - ] + py_deps([ + deps = py_deps([ "absl/testing", "numpy", ]), @@ -48,8 +46,8 @@ jax_multiplatform_test( name = "custom_api_test", srcs = ["custom_api_test.py"], deps = [ - "//jax:experimental", "//jax/_src:custom_derivatives", + "//jax/experimental:custom_dce", ] + py_deps([ "absl/testing", "numpy", @@ -61,9 +59,10 @@ jax_multiplatform_test( srcs = ["debug_info_test.py"], enable_configs = ["tpu_v3_x4"], deps = [ - "//jax:experimental", "//jax/_src:custom_transpose", "//jax/_src:shard_map", + "//jax/experimental:checkify", + "//jax/experimental:custom_dce", "//jax/experimental:pallas", "//jax/experimental:pallas_gpu", "//jax/experimental:pallas_gpu_ops", @@ -373,7 +372,7 @@ jax_multiplatform_test( srcs = ["memories_test.py"], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", + "//jax/experimental:compute_on", ] + py_deps([ "absl/testing", "numpy", @@ -393,7 +392,7 @@ jax_multiplatform_test( ], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", + "//jax/experimental:custom_partitioning", ] + py_deps([ "absl/testing", "numpy", @@ -421,7 +420,8 @@ jax_multiplatform_test( }, tags = ["multiaccelerator"], deps = [ - "//jax:experimental", + "//jax/experimental", + "//jax/experimental:multihost_utils", ] + py_deps([ "absl/testing", "numpy", @@ -439,7 +439,7 @@ jax_multiplatform_test( ], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", + "//jax/experimental:layout", ] + py_deps([ "absl/testing", "numpy", @@ -451,7 +451,7 @@ jax_multiplatform_test( srcs = ["shard_alike_test.py"], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", + "//jax/experimental:shard_alike", ] + py_deps([ "absl/testing", "numpy", @@ -488,9 +488,7 @@ jax_multiplatform_test( tags = [ "config-cuda-only", ], - deps = [ - "//jax:experimental", - ] + py_deps([ + deps = py_deps([ "absl/testing", "numpy", ]), @@ -506,9 +504,7 @@ jax_multiplatform_test( tags = [ "config-cuda-only", ], - deps = [ - "//jax:experimental", - ] + py_deps("absl/testing"), + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -522,8 +518,8 @@ jax_multiplatform_test( ], tags = ["multiaccelerator"], deps = [ - "//jax:experimental", "//jax/_src:internal_test_util", + "//jax/experimental:multihost_utils", ] + py_deps([ "numpy", "absl/testing", @@ -1083,9 +1079,7 @@ jax_multiplatform_test( name = "pickle_test", srcs = ["pickle_test.py"], enable_backends = ["cpu"], - deps = [ - "//jax:experimental", - ] + py_deps([ + deps = py_deps([ "cloudpickle", "numpy", "absl/testing", @@ -1736,7 +1730,7 @@ jax_multiplatform_test( srcs = ["x64_context_test.py"], enable_backends = ["cpu"], deps = [ - "//jax:experimental", + "//jax/experimental", ] + py_deps([ "absl/testing", "numpy", @@ -1824,7 +1818,7 @@ jax_multiplatform_test( }, tags = ["multiaccelerator"], deps = [ - "//jax:experimental", + "//jax/experimental", ] + py_deps([ "absl/testing", "numpy", @@ -1881,9 +1875,7 @@ jax_multiplatform_test( tags = [ "multiaccelerator", ], - deps = [ - "//jax:experimental", - ] + py_deps("absl/testing"), + deps = py_deps("absl/testing"), ) jax_multiplatform_test( @@ -1907,8 +1899,8 @@ jax_multiplatform_test( "notsan", ], # Times out under *SAN. deps = [ - "//jax:experimental", "//jax/_src:tree_util", + "//jax/experimental:custom_partitioning", ] + py_deps([ "absl/testing", "numpy", @@ -1925,7 +1917,7 @@ jax_multiplatform_test( name = "hijax_test", srcs = ["hijax_test.py"], deps = [ - "//jax:experimental", + "//jax/experimental:hijax", ] + py_deps([ "numpy", "absl/testing", @@ -2103,7 +2095,7 @@ jax_multiplatform_test( jax_multiplatform_test( name = "xla_metadata_test", srcs = ["xla_metadata_test.py"], - deps = ["//jax:experimental"] + py_deps("absl/testing"), + deps = ["//jax/experimental:xla_metadata"] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -2112,9 +2104,7 @@ jax_multiplatform_test( enable_backends = [ "tpu", ], - deps = [ - "//jax:experimental", - ] + py_deps([ + deps = py_deps([ "absl/testing", "numpy", ]), @@ -2194,7 +2184,6 @@ jax_py_test( srcs = ["custom_partitioning_sharding_rule_test.py"], deps = [ "//jax", - "//jax:experimental", "//jax/_src:custom_partitioning_sharding_rule", "//jax/_src:test_util", ] + py_deps("absl/testing"), diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index af9aab89da16..1bcc601f2b45 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -115,7 +115,6 @@ jax_multiplatform_test( env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0 --xla_gpu_experimental_enable_nvshmem=true"}, tags = ["multiaccelerator"], deps = [ - "//jax:experimental", "//jax/_src:test_multiprocess", "//jax/experimental:mosaic_gpu", ] + py_deps([ diff --git a/tests/multiprocess/BUILD b/tests/multiprocess/BUILD index 5d5d90a5f321..73fd9bc5c7a9 100644 --- a/tests/multiprocess/BUILD +++ b/tests/multiprocess/BUILD @@ -107,8 +107,9 @@ jax_multiprocess_test( srcs = ["host_callback_test.py"], main = "host_callback_test.py", deps = [ - "//jax:experimental", "//jax/_src:test_multiprocess", + "//jax/experimental", + "//jax/experimental:multihost_utils", ], ) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index e160f38133ca..4128e6dffda2 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -574,7 +574,6 @@ jax_multiplatform_test( "tpu_v5p", ], deps = [ - "//jax:experimental", "//jax/experimental:pallas_tpu", ] + py_deps([ "absl/testing", @@ -640,7 +639,6 @@ jax_multiplatform_test( ], enable_backends = ["cpu"], deps = [ - "//jax:experimental", "//jax/experimental:pallas", "//jax/experimental:pallas_tpu", ] + py_deps([ @@ -656,7 +654,7 @@ jax_multiplatform_test( ], enable_backends = ["cpu"], deps = [ - "//jax:experimental", + "//jax/experimental", "//jax/experimental:pallas", "//jax/experimental:pallas_tpu", ] + py_deps([ @@ -1139,7 +1137,6 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:experimental", "//jax/_src:test_multiprocess", "//jax/experimental:pallas", "//jax/experimental:pallas_experimental_gpu_ops", @@ -1186,7 +1183,6 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:experimental", "//jax/_src:test_multiprocess", "//jax/experimental:pallas", "//jax/experimental:pallas_experimental_gpu_ops", @@ -1215,7 +1211,6 @@ jax_multiplatform_test( "notap", ], deps = [ - "//jax:experimental", "//jax/_src:test_multiprocess", "//jax/experimental:pallas", "//jax/experimental:pallas_experimental_gpu_ops", From 6d55ebdad529f5778b53a8b0eb04767a8cd40d5b Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Wed, 10 Dec 2025 09:53:59 -0800 Subject: [PATCH 144/315] Remove another instance of access to AttributeMap::Map PiperOrigin-RevId: 842768120 --- jaxlib/call_location.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/jaxlib/call_location.cc b/jaxlib/call_location.cc index df335af1c46e..b556f5e6ee62 100644 --- a/jaxlib/call_location.cc +++ b/jaxlib/call_location.cc @@ -124,9 +124,14 @@ void PopulateCallLocation(xla::ifrt::ExecuteOptions& options, } if (!call_location_str.empty()) { + // Simplify this to use AttributeMap::Set(). xla::ifrt::AttributeMap::Map attrs_map; if (options.custom_options.has_value()) { - attrs_map = options.custom_options->map(); + options.custom_options->ForEach( + [&](const std::string& key, + const xla::ifrt::AttributeMap::Value& value) { + attrs_map.insert({key, value}); + }); } attrs_map.insert( {std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation), From cd4c077f3193f535a766ccc923204f1a93f9a251 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 10 Dec 2025 11:06:26 -0800 Subject: [PATCH 145/315] Widen visibility of //jax/experimental:transfer. Change in preparation for removing the alias //jax:experimental_transfer, which currently has this wider visibility. PiperOrigin-RevId: 842799409 --- jax/experimental/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/experimental/BUILD b/jax/experimental/BUILD index 7215cecc8651..bb18041e83dd 100644 --- a/jax/experimental/BUILD +++ b/jax/experimental/BUILD @@ -696,7 +696,10 @@ pytype_strict_library( pytype_strict_library( name = "transfer", srcs = ["transfer.py"], - visibility = ["//jax:internal"], + visibility = [ + ":experimental_transfer_users", + "//jax:internal", + ], deps = [ "//jax", "//jax/_src:util", From 76cb1828c78517f1e39f3df5652298a484b77e33 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 10 Dec 2025 20:59:35 +0000 Subject: [PATCH 146/315] [hijax] handle symbolic zeros generically in hijax lin rule Soon we'll upgrade the user-facing rule api to support symbolic zeros. That's in the custom-vjp3 branch. But to unblock things now, and as a convenience wrapper later, we can support not differentiating with respect to some inputs without the user rule being aware: we just use the black-hole /dev/null NullAccum in the right places, where the right places are indicated by the in_nz on the linearize pass. This is less good than telling the user rule about in_nz because it doesn't let the user rule save any work or memory. (At first we tried just instantiating zeros, but that was worse because we lose information as to whether those zeros are linear or nonlinear, complicating things downstream.) Co-authored-by: Robert Dyro --- jax/_src/hijax.py | 22 +++++++++++++--------- jax/_src/interpreters/ad.py | 9 ++++----- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py index 80729e9b53d2..d35da8f88752 100644 --- a/jax/_src/hijax.py +++ b/jax/_src/hijax.py @@ -414,8 +414,7 @@ def _call_hi_primitive_batcher(axis_data, args_flat, dims_flat, prim): return ans_flat, dims_flat batching.fancy_primitive_batchers[call_hi_primitive_p] = _call_hi_primitive_batcher -def _call_hi_primitive_linearize(nz_in, *args_flat, prim): - assert all(nz_in) +def _call_hi_primitive_linearize(nz_in_flat, *args_flat, prim): args = tree_unflatten(prim.in_tree, args_flat) ans, residuals = prim.vjp_fwd(*args) # TODO(dougalm): does the fwd/bwd API force us to assume the nzs_out are all False @@ -423,25 +422,30 @@ def _call_hi_primitive_linearize(nz_in, *args_flat, prim): # LinearizeTrace.ProcessPrimitive)? ans_flat = tree_leaves_checked(prim.out_tree, ans) nzs_out = [True for _ in ans_flat] - return (ans_flat, nzs_out, residuals, partial(fake_linear_op, prim)) + return (ans_flat, nzs_out, residuals, partial(fake_linear_op, prim, nz_in_flat)) -def fake_linear_op(prim, rs, *tangents): +def fake_linear_op(prim, nz_in_flat, rs, *tangents): residuals_flat, residuals_tree = tree_flatten(rs) - return call_hi_primitive_linearized_p.bind(*residuals_flat, *tangents, - residuals_tree=residuals_tree, prim=prim) + tangents_flat, _ = tree_flatten(tangents) # prune symbolic zeros + return call_hi_primitive_linearized_p.bind( + *residuals_flat, *tangents_flat, + residuals_tree=residuals_tree, nz_in_flat=tuple(nz_in_flat), prim=prim) ad.primitive_linearizations[call_hi_primitive_p] = _call_hi_primitive_linearize call_hi_primitive_linearized_p = core.Primitive("call_hi_primitive_linearized") call_hi_primitive_linearized_p.multiple_results = True -call_hi_primitive_linearized_p.is_high = lambda *args, prim, residuals_tree: True # type: ignore +call_hi_primitive_linearized_p.is_high = lambda *args, prim, **_: True # type: ignore @call_hi_primitive_linearized_p.def_abstract_eval -def _call_hi_primitive_linearized_abstract_eval(*_args, prim, residuals_tree): +def _call_hi_primitive_linearized_abstract_eval(*_args, prim, residuals_tree, nz_in_flat): return [t.to_tangent_aval() for t in prim.out_avals_flat] # TODO(dougalm): handle nonzeros -def _call_hi_primitive_linearized_transpose(cts_flat, *args, prim, residuals_tree): +def _call_hi_primitive_linearized_transpose(cts_flat, *args, prim, residuals_tree, nz_in_flat): residuals_flat, accums_flat = split_list(args, [residuals_tree.num_leaves]) residuals = tree_unflatten(residuals_tree, residuals_flat) + accums_flat_ = iter(accums_flat) + accums_flat = [next(accums_flat_) if nz else ad.NullAccum() for nz in nz_in_flat] + assert next(accums_flat_, None) is None accums = tree_unflatten(prim.in_tree, accums_flat) cts = tree_unflatten(prim.out_tree, cts_flat) none = prim.vjp_bwd(residuals, cts, *accums) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 863839cdc080..75d1862b9809 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -598,11 +598,10 @@ def accum(self, x): def freeze(self): return self.val -# class NullAccum(GradAccum): -# aval: core.AbstractValue -# def __init__(self, aval): self.aval = aval -# def accum(self, x): return -# def freeze(self): assert False +class NullAccum(GradAccum): + def __init__(self): pass + def accum(self, x): return + def freeze(self): assert False fancy_transposes: dict[core.Primitive, Callable] = {} From 64a8e0de42681bdc26b89dba30fc2c979869ea82 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Wed, 10 Dec 2025 13:48:46 -0800 Subject: [PATCH 147/315] Temporarily pin nightly libtpu version to `0.0.31.dev20251209` to unblock TPU presubmit. PiperOrigin-RevId: 842864247 --- .github/workflows/bazel_test_tpu.yml | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/workflows/bazel_test_tpu.yml b/.github/workflows/bazel_test_tpu.yml index 6459c45475e0..15895f995fd3 100644 --- a/.github/workflows/bazel_test_tpu.yml +++ b/.github/workflows/bazel_test_tpu.yml @@ -122,8 +122,21 @@ jobs: mkdir -p $(pwd)/dist $JAXCI_PYTHON -m pip install --upgrade pip echo "Download the wheel into a local directory" + # TODO(ybaturina): Remove this once the libtpu wheel is updated. if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then - $JAXCI_PYTHON -m pip download -d $(pwd)/dist --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + version="" + suffix="" + full_python_version="${{ inputs.python }}" + + if [[ "$full_python_version" == *-* ]]; then + version="${full_python_version%%-*}" + suffix="t" + else + version="$full_python_version" + suffix="" + fi + version_no_dots="${version//./}" + wget -P $(pwd)/dist https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu/libtpu-0.0.31.dev20251209+nightly-cp${version_no_dots}-cp${version_no_dots}${suffix}-manylinux_2_31_x86_64.whl elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then echo "Using latest libtpu from PyPI" $JAXCI_PYTHON -m pip download -d $(pwd)/dist libtpu From cdf4a66e8c38e213d6cdf6c32c5e97dfa7268da7 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Wed, 10 Dec 2025 14:46:01 -0800 Subject: [PATCH 148/315] Remove of disassemble_into_single_device_arrays in favor of a new pre_wrap function for result handlers. Similarly remove disassemble_prefix_into_single_device_arrays for strict=True/False on consume_with_handlers. PiperOrigin-RevId: 842888651 --- jax/_src/dispatch.py | 9 ++++++++ jax/_src/interpreters/pxla.py | 20 +++++++++++++----- jaxlib/_jax/__init__.pyi | 3 ++- jaxlib/py_array.cc | 23 ++++++++++++++------- jaxlib/py_executable.cc | 39 ++++++++++++++++++++++++++--------- jaxlib/py_executable.h | 3 ++- jaxlib/xla_client.py | 2 +- 7 files changed, 74 insertions(+), 25 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 1c50cf5bd373..ff40a53aec48 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -299,6 +299,15 @@ def check_special(name: str, bufs: Sequence[basearray.Array]) -> None: for buf in bufs: _check_special(name, buf.dtype, buf) + +def check_special_array(name: str, arr: array.ArrayImpl) -> array.ArrayImpl: + if needs_check_special(): + if dtypes.issubdtype(arr.dtype, np.inexact): + for buf in arr._arrays: + _check_special(name, buf.dtype, buf) + return arr + + def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None: if dtypes.issubdtype(dtype, np.inexact): if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 54ff65bd0123..e2d8fce03b6b 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -58,6 +58,7 @@ from jax._src.interpreters import mlir from jax._src.layout import Layout, AutoLayout, Format from jax._src.lib import _jax +from jax._src.lib import jaxlib_extension_version from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -1363,20 +1364,29 @@ def __call__(self, *args): input_bufs = self._add_tokens_to_inputs(input_bufs) results = self.xla_executable.execute_sharded(input_bufs, with_tokens=True) - result_token_bufs = results.disassemble_prefix_into_single_device_arrays( - len(self.ordered_effects)) + if jaxlib_extension_version >= 391: + result_token_bufs = results.consume_with_handlers( + [lambda xs: xs] * len(self.ordered_effects), strict=False) + else: + result_token_bufs = results.disassemble_prefix_into_single_device_arrays( + len(self.ordered_effects)) sharded_runtime_token = results.consume_token() self._handle_token_bufs(result_token_bufs, sharded_runtime_token) else: results = self.xla_executable.execute_sharded(input_bufs) - if dispatch.needs_check_special(): + if jaxlib_extension_version >= 391 or not dispatch.needs_check_special(): + handlers = self.out_handler.handlers + if dispatch.needs_check_special(): + special_check = functools.partial( + dispatch.check_special_array, self.name) + handlers = [h.pre_wrap(special_check) for h in handlers] + out = results.consume_with_handlers(handlers) + else: out_arrays = results.disassemble_into_single_device_arrays() for arrays in out_arrays: dispatch.check_special(self.name, arrays) out = self.out_handler(out_arrays) - else: - out = results.consume_with_handlers(self.out_handler.handlers) if (self.pgle_profiler is not None and self.pgle_profiler.is_running() and len(out) > 0): diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 120acdb347e5..405c485e1f76 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -1030,6 +1030,7 @@ def array_result_handler( class ResultHandler: def __call__(self, arg: Array | Sequence[Array], /) -> Array: ... def wrap(self, arg: Callable, /) -> ResultHandler: ... + def pre_wrap(self, arg: Callable, /) -> ResultHandler: ... class DeviceList: def __init__(self, arg: tuple[Device, ...], /) -> None: ... @@ -1228,7 +1229,7 @@ class ExecuteResults: self, arg: int, / ) -> list[list[Array]]: ... def consume_with_handlers( - self, arg: Sequence[ResultHandler | object], / + self, out_handlers: Sequence[ResultHandler | object], strict: bool = ... ) -> list[object]: ... def consume_token(self) -> ShardedToken: ... diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 37baa58c6cc9..c8984d7581a8 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -2373,13 +2373,22 @@ absl::Status PyArray::Register(nb::module_& m) { }, nb::sig( "def __call__(self, arg: Array | Sequence[Array], /) -> Array")) - .def("wrap", [](const PyArrayResultHandler& self, nb::callable wrapper) { - auto wrappers = self.wrappers(); - wrappers.push_back(std::move(wrapper)); - return make_nb_class( - self.aval(), self.sharding(), self.committed(), self.skip_checks(), - std::move(wrappers)); - }); + .def("wrap", + [](const PyArrayResultHandler& self, nb::callable wrapper) { + auto wrappers = self.wrappers(); + wrappers.push_back(std::move(wrapper)); + return make_nb_class( + self.aval(), self.sharding(), self.committed(), + self.skip_checks(), std::move(wrappers)); + }) + .def("pre_wrap", + [](const PyArrayResultHandler& self, nb::callable wrapper) { + auto wrappers = self.wrappers(); + wrappers.insert(wrappers.begin(), std::move(wrapper)); + return make_nb_class( + self.aval(), self.sharding(), self.committed(), + self.skip_checks(), std::move(wrappers)); + }); return absl::OkStatus(); } diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index d6ba17fd8a09..18f09389497f 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -256,16 +256,34 @@ PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays(size_t n) { std::vector PyExecuteResults::ConsumeWithHandlers( std::vector> - out_handlers) { + out_handlers, + bool strict) { std::vector outputs; - auto ifrt_arrays = Consume(); - int num_output_buffers = ifrt_arrays.size(); - outputs.reserve(num_output_buffers); - if (out_handlers.size() != num_output_buffers) { - throw nb::value_error( - absl::StrCat("Mismatch between out_handlers and num_results: ", - out_handlers.size(), " vs ", num_output_buffers) - .c_str()); + int num_output_buffers = out_handlers.size(); + std::vector ifrt_arrays; + if (strict) { + if (out_handlers.size() != ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " vs ", ifrt_arrays_.size()) + .c_str()); + } + ifrt_arrays = Consume(); + } else { + if (out_handlers.size() > ifrt_arrays_.size()) { + throw nb::value_error( + absl::StrCat("Mismatch between out_handlers and num_results: ", + out_handlers.size(), " > ", ifrt_arrays_.size()) + .c_str()); + } + CheckNotDisassembled(); + ifrt_arrays.reserve(ifrt_arrays_.size() - num_output_buffers); + for (size_t i = num_output_buffers; i < ifrt_arrays_.size(); ++i) { + ifrt_arrays.push_back(std::move(ifrt_arrays_[i])); + } + ifrt_arrays_.erase(ifrt_arrays_.begin() + ifrt_arrays_.size(), + ifrt_arrays_.end()); + std::swap(ifrt_arrays_, ifrt_arrays); } for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) { auto& handler = out_handlers[buffer_id]; @@ -305,7 +323,8 @@ void PyExecuteResults::Register(nb::module_& m) { &PyExecuteResults::DisassembleIntoSingleDeviceArrays) .def("disassemble_prefix_into_single_device_arrays", &PyExecuteResults::DisassemblePrefixIntoSingleDeviceArrays) - .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers) + .def("consume_with_handlers", &PyExecuteResults::ConsumeWithHandlers, + nb::arg("out_handlers"), nb::arg("strict") = true) .def("consume_token", &PyExecuteResults::ConsumeToken); } diff --git a/jaxlib/py_executable.h b/jaxlib/py_executable.h index ffce042355cd..2e3f33a1222c 100644 --- a/jaxlib/py_executable.h +++ b/jaxlib/py_executable.h @@ -103,7 +103,8 @@ class PyExecuteResults { std::vector ConsumeWithHandlers( std::vector> - out_handlers); + out_handlers, + bool strict); std::vector Consume(); diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index 8ae8102af30a..c26c41ccf4fb 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -47,7 +47,7 @@ # Please suffix the version number with a brief description of your change # in a comment. The goal here is to force a merge conflict if two changes # attempt to grab the same version number. -_version = 390 # ResultHandler.wrap +_version = 391 # ResultHandler.pre_wrap # An internal increasing version number for protecting jaxlib code against # ifrt changes. From 93b138c35450e0e4caba6bc5d74e955209516724 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 10 Dec 2025 14:59:20 -0800 Subject: [PATCH 149/315] [pmap] Remove `default_pmap_sharding` Reverts 3a37e92770906b853dbeeb16e9e320462ad53299 PiperOrigin-RevId: 842893715 --- jax/_src/sharding_impls.py | 46 ---------------------------- jax/sharding.py | 1 - tests/documentation_coverage_test.py | 2 +- 3 files changed, 1 insertion(+), 48 deletions(-) diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 2ddf717eae39..b658a4d13966 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -351,52 +351,6 @@ def shard_shape(self, global_shape: Shape) -> Shape: PmapSharding.__module__ = 'jax.sharding' -def default_pmap_sharding( - shape: Shape, - sharded_dim: int | None = 0, - devices: Sequence[xc.Device] | None = None, -) -> NamedSharding | PmapSharding: - """Creates a NamedSharding equivalent to PmapSharding.default. - - This function provides the same sharding semantics as PmapSharding.default - but returns a NamedSharding when jax_pmap_shmap_merge is enabled, which is - compatible with the shard_map-based pmap implementation. - - Args: - shape: The shape of the input array. - sharded_dim: Dimension the input array is sharded on. Defaults to 0. - If None, the array is fully replicated. - devices: Optional sequence of devices to use. If omitted, uses - jax.local_devices(). - - Returns: - A NamedSharding if jax_pmap_shmap_merge is enabled, otherwise a - PmapSharding. - """ - if not config.pmap_shmap_merge.value: - return PmapSharding.default(shape, sharded_dim=sharded_dim, devices=devices) - - if sharded_dim is None: - if devices is None: - raise ValueError("One of sharded_dim or devices must be set.") - mesh = mesh_lib.Mesh(np.array(devices), ('_default_pmap_sharding',)) - return NamedSharding(mesh, PartitionSpec()) - - if len(shape) == 0: - raise ValueError("shape must be non-empty for sharded_dim != None") - - num_ways_sharded = shape[sharded_dim] - - if devices is None: - pmap_devices = np.array(xb.local_devices()[:num_ways_sharded]) - else: - pmap_devices = np.array(devices) - - mesh = mesh_lib.Mesh(pmap_devices, ('_default_pmap_sharding',)) - spec_list: list[str | None] = [None] * len(shape) - spec_list[sharded_dim] = '_default_pmap_sharding' - return NamedSharding(mesh, PartitionSpec(*spec_list)) - def _unpickle_gspmd_sharding(devices, op_sharding, memory_kind): return GSPMDSharding(devices, op_sharding, memory_kind=memory_kind) diff --git a/jax/sharding.py b/jax/sharding.py index e98ad72be036..c592abec393f 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -20,7 +20,6 @@ NamedSharding as NamedSharding, SingleDeviceSharding as SingleDeviceSharding, PmapSharding as _deprecated_PmapSharding, - default_pmap_sharding as default_pmap_sharding, set_mesh as set_mesh, get_mesh as get_mesh, ) diff --git a/tests/documentation_coverage_test.py b/tests/documentation_coverage_test.py index 316f925a6ec8..83ae55a7423c 100644 --- a/tests/documentation_coverage_test.py +++ b/tests/documentation_coverage_test.py @@ -68,7 +68,7 @@ def jax_docs_dir() -> str: 'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'], 'jax.random': ['key_impl', 'random_gamma_p'], 'jax.scipy.special': ['bessel_jn', 'sph_harm_y'], - 'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'default_pmap_sharding', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh', 'get_mesh'], + 'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh', 'get_mesh'], 'jax.stages': ['ArgInfo', 'CompilerOptions'], 'jax.tree_util': ['DictKey', 'FlattenedIndexKey', 'GetAttrKey', 'PyTreeDef', 'SequenceKey', 'default_registry'], } From b12fd6bd440360da6cf4ab02c29374e1e3ef2474 Mon Sep 17 00:00:00 2001 From: Ashish Rao Date: Mon, 8 Dec 2025 20:20:31 +0000 Subject: [PATCH 150/315] Modify _batched_device_put_impl to batch cross-host transfers + enable test_cross_host_transfer_batched on GPU platform --- jax/_src/dispatch.py | 47 ++++++++++++++++++++++++++++---- tests/multiprocess/array_test.py | 13 +-------- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index ff40a53aec48..d3fca5107b01 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -396,6 +396,25 @@ def result_handler(self, shard_arg_result): return pxla.global_aval_to_result_handler( self.aval, self.s, self.committed)(shard_arg_result) +@dataclasses.dataclass(frozen=True) +class _DeferredCrossHostTransferArg: + """Deferred call to `xc.batched_copy_array_to_devices_with_sharding` for + cross-host data transfers. + + Per-array impls return this object instead of a result array to indicate a + deferred `batched_copy_array_to_devices_with_sharding` call for a cross-host + data transfer. `_batched_device_put_impl` then batches all + `_DeferredCrossHostTransferArg` objects into a single + `_batched_device_put_impl` call. + + For any _DeferredCrossHostTransferArg, _is_supported_cross_host_transfer( + x.ndim, x.sharding, dst_sharding) == True. + """ + + x: array.ArrayImpl + dst_sharding: Sharding + copy_semantics: ArrayCopySemantics + def _device_put_sharding_impl( x: Any, @@ -434,9 +453,7 @@ def _device_put_sharding_impl( if (x_is_jax_array and x._committed and xla_bridge.process_count() > 1 and _is_supported_cross_host_transfer(x.ndim, x_sharding, s)): - return xc.batched_copy_array_to_devices_with_sharding( - [x], [s._internal_device_list], [s], # pytype: disable=attribute-error - [copy])[0] + return _DeferredCrossHostTransferArg(x, s, copy) if not s_is_fully_addressable: # If both the source and target shardings are not fully addressable and @@ -552,7 +569,14 @@ def _batched_device_put_impl( copy_semantics: Sequence[ArrayCopySemantics], dst_avals: Sequence[core.ShapedArray | None]): ys = [] + + # Used to batch transfers when _device_put_impl returns a _DeferredShardArg. dsa_indices, dsa_xs, dsa_shardings, dsa_copy_semantics = [], [], [], [] + # Used to batch transfers when _device_put_impl returns a + # _DeferredCrossHostTransferArg. + dca_indices, dca_xs, dca_shardings, dca_device_lists, dca_copy_semantics = \ + [], [], [], [], [] + for i, (x, device, src, cp, aval) in enumerate( zip(xs, devices, srcs, copy_semantics, dst_avals)): y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval) @@ -561,11 +585,17 @@ def _batched_device_put_impl( dsa_xs.append(y.x) dsa_shardings.append(y.s) dsa_copy_semantics.append(y.copy_semantics) + elif isinstance(y, _DeferredCrossHostTransferArg): + dca_indices.append(i) + dca_xs.append(y.x) + dca_shardings.append(y.dst_sharding) + dca_device_lists.append(y.dst_sharding._internal_device_list) # pytype: disable=attribute-error + dca_copy_semantics.append(y.copy_semantics) ys.append(y) + # Batch shard_arg / batched_copy_array_to_devices_with_sharding calls. Helps + # improve efficiency for backends that support efficient batch transfer. if dsa_xs: - # Batch shard_arg calls. Helps improve efficiency for backends that support - # efficient batch transfer. # device_put handles `Format` via a different path, so just pass `None` as # the layout here. shard_arg_results = pxla.shard_args(dsa_shardings, [None] * len(dsa_xs), @@ -573,6 +603,13 @@ def _batched_device_put_impl( for i, shard_arg_result in zip(dsa_indices, shard_arg_results): assert isinstance(ys[i], _DeferredShardArg) ys[i] = ys[i].result_handler(shard_arg_result) + if dca_xs: + copy_array_results = xc.batched_copy_array_to_devices_with_sharding( + dca_xs, dca_device_lists, dca_shardings, dca_copy_semantics) + for i, copy_array_result in zip(dca_indices, copy_array_results): + assert isinstance(ys[i], _DeferredCrossHostTransferArg) + ys[i] = copy_array_result + return ys def batched_device_put_impl( diff --git a/tests/multiprocess/array_test.py b/tests/multiprocess/array_test.py index 420b838cdf1e..83cdf7913708 100644 --- a/tests/multiprocess/array_test.py +++ b/tests/multiprocess/array_test.py @@ -23,8 +23,6 @@ from jax._src import sharding_impls from jax._src import test_multiprocess as jt_multiprocess from jax._src import test_util as jtu -from jax._src import xla_bridge as xb -from jax._src.lib import xla_client as xc import jax.numpy as jnp from jax.sharding import PartitionSpec as P import numpy as np @@ -979,12 +977,6 @@ def test_cross_host_transfer_named_sharding_replicated(self): np.testing.assert_array_equal(shard.data, x[shard.index]) def test_cross_host_transfer_batched(self): - backend = xb.get_backend() - if "cuda" in backend.platform_version: - self.skipTest( - "The CUDA plugin does not support batched cross-host transfers." - ) - num_arrays = 3 xs = [] for i in range(1, num_arrays + 1): @@ -1010,10 +1002,7 @@ def test_cross_host_transfer_batched(self): P("x")) ys = jax.device_put(xs, src_sharding) - copy_semantics = xc.ArrayCopySemantics.ALWAYS_COPY - zs = xc.batched_copy_array_to_devices_with_sharding( - ys, [dst_sharding._internal_device_list] * num_arrays, - [dst_sharding] * num_arrays, [copy_semantics] * num_arrays) + zs = jax.device_put(ys, dst_sharding) for (x, z) in zip(xs, zs): if jax.process_index() == dst_pid: self.assertLen(z.addressable_shards, n_local) From 86fe2aba9161764ef8b0d1bda4114f5c17a40d64 Mon Sep 17 00:00:00 2001 From: Quoc Truong Date: Wed, 10 Dec 2025 15:48:54 -0800 Subject: [PATCH 151/315] Add TPU v7 runners to nightly and continuous job Use the new TPU v7 runners (4 chips, 8 cores) in nightly and continuous jobs. PiperOrigin-RevId: 842912906 --- .github/actionlint.yaml | 1 + .github/workflows/cloud-tpu-ci-nightly.yml | 3 ++- .github/workflows/wheel_tests_continuous.yml | 3 ++- .github/workflows/wheel_tests_nightly_release.yml | 3 ++- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 0fb4d3579612..6307be466509 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -16,4 +16,5 @@ self-hosted-runner: - "linux-x86-ct6e-180-8tpu" # Linux X86 TPU runner using ct6e-hightpu-8t machine with 2x4 topology. - "linux-x86-ct6e-180-4tpu" # Linux X86 TPU runner using ct6e-hightpu-4t machine with 2x2 topology. - "linux-x86-ct4p-240-4tpu" # Linux X86 TPU runner using ct4p-hightpu-4t machine with 2x2x1 topology. + - "linux-x86-tpu7x-224-4tpu" # Linux X86 TPU runner using tpu7x-224 machine with 4 TPU chips (8 cores) and 2x2x1 topology. - "linux-x86_64-cirrascale-64-8gpu-amd-mi250" # AMD runner diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 7feb2d2ad4aa..73ccb3d7b5b4 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -35,7 +35,8 @@ jobs: tpu: [ {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, - {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}, + {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"} ] python-version: ["3.11"] # Exclude v6e-8 tests for pypi_latest for resource constraints. diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 7aa14307d96c..07e2eb4fd978 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -241,7 +241,8 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, - {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}, + {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"} ] libtpu-version-type: ["nightly"] name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})" diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index b53be566a378..8a5703412936 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -177,7 +177,8 @@ jobs: tpu-specs: [ # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, - {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} + {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"}, + {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"} ] libtpu-version-type: ["pypi_latest", "nightly"] exclude: From 1e14a1503b8b9ef631f539436992b26a0d4ebc89 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Wed, 10 Dec 2025 19:22:44 -0500 Subject: [PATCH 152/315] Add missing description of numpy.concatenate behavior for axis=None. --- jax/_src/numpy/lax_numpy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 29ddd583cc02..96d100e5534b 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4557,7 +4557,8 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], except along the specified axis. If a single array is given it will be treated equivalently to `arrays = unstack(arrays)`, but the implementation will avoid explicit unstacking. - axis: specify the axis along which to concatenate. + axis: specify the axis along which to concatenate. If None, the arrays are + flattened before concatenation. dtype: optional dtype of the resulting array. If not specified, the dtype will be determined via type promotion rules described in :ref:`type-promotion`. From a80c856dc1a51f589dfa87730cfbbf75e16c1178 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 10 Dec 2025 18:48:16 -0800 Subject: [PATCH 153/315] Update EnzymeJaX visibility PiperOrigin-RevId: 842981186 --- jaxlib/gpu/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 9d6111fcadaa..6169c882aa05 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -82,6 +82,10 @@ proto_library( cc_proto_library( name = "triton_cc_proto", compatible_with = None, + visibility = [ + "//jax:internal", + "//third_party/py/enzyme_ad:__subpackages__", + ], deps = [":triton_proto"], ) From 66b796c35901dccf4c74001b54b7774edaffd9dc Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Wed, 10 Dec 2025 19:55:15 -0800 Subject: [PATCH 154/315] Check that axis_name input to pcast is either a tuple or a str. PiperOrigin-RevId: 843002874 --- jax/_src/lax/parallel.py | 2 ++ tests/shard_map_test.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index f1c916e39923..b06e543ae481 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -2830,6 +2830,8 @@ def _get_from(aval, axes: tuple[AxisName, ...], name) -> str: _allowed_pcast_to = {'unreduced', 'reduced', 'varying'} def pcast(x, axis_name, *, to: str): + if isinstance(axis_name, (set, frozenset)): + raise TypeError(f"{axis_name=} must be a tuple or a str. Got {axis_name}") axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name if not axis_name: return x diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index bbf95102f057..6e0e56818c08 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2275,6 +2275,24 @@ def f(x): f(jnp.arange(8.)) jax.grad(lambda x: f(x).sum())(jnp.arange(8.)) + @jtu.with_explicit_mesh((2,), 'x') + def test_pcast_axis_name_is_not_set(self, mesh): + def f(axis_name_type, x): + with self.assertRaisesRegex(TypeError, 'must be a tuple or a str'): + if axis_name_type == 'str': + jax.lax.pcast(x, {'x'}, to='varying') + elif axis_name_type == 'aval.vma': + jax.lax.pcast(x, x.aval.vma, to='varying') + + jax.shard_map(partial(f, 'str'), mesh=mesh, in_specs=P(), + out_specs=None)(np.arange(8.)) + jax.shard_map(partial(f, 'aval.vma'), mesh=mesh, in_specs=P(), + out_specs=None)(np.arange(8.)) + jax.jit(jax.shard_map(partial(f, 'str'), mesh=mesh, in_specs=P(), + out_specs=None))(np.arange(8.)) + jax.jit(jax.shard_map(partial(f, 'aval.vma'), mesh=mesh, in_specs=P(), + out_specs=None))(np.arange(8.)) + def test_rewrite_binops(self): mesh = jtu.create_mesh((4,), ('x',)) From e45f9d97c35ddb7ff9d443a76c307f124353ad27 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 10 Dec 2025 21:04:23 -0800 Subject: [PATCH 155/315] If there is a mesh in ctx and operand.mesh is empty, then make sure that output aval of broadcast_in_dim is on the mesh in ctx. PiperOrigin-RevId: 843027342 --- jax/_src/lax/lax.py | 4 +++- jax/_src/lax/slicing.py | 11 +++++++---- tests/pjit_test.py | 11 +++++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b75c719ef098..c598b6d4e61f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6425,8 +6425,10 @@ def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions, orig_spec = iter(operand.sharding.spec) new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))] assert next(orig_spec, None) is None + mesh = (get_abstract_mesh() if operand.sharding.mesh.empty else + operand.sharding.mesh) return operand.sharding.update( - spec=operand.sharding.spec.update(partitions=new_spec)) + mesh=mesh, spec=operand.sharding.spec.update(partitions=new_spec)) def _broadcast_in_dim_typecheck_rule( _, operand, shape, broadcast_dimensions, sharding): diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index e102951cbf58..9ff745e922ce 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1572,13 +1572,16 @@ def _batch_dynamic_slice_indices(indices, bdims): empty_marker = object() size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), empty_marker) + out = next(((core.typeof(x).sharding.mesh, core.typeof(x).sharding.spec[i]) + for x, i in zip(indices, bdims) if i is not None), None) if size is empty_marker: return lax.concatenate([lax.broadcast(i, (1,)) for i in indices], 0), None + out_s = None if out is None else NamedSharding(out[0], P(out[1], None)) indices = lax.concatenate( - [lax.broadcast_in_dim(x, (size, 1), - broadcast_dimensions=((0,) if i is not None else ())) - for x, i in zip(indices, bdims)], - dimension=1) + [lax.broadcast_in_dim( + x, (size, 1), broadcast_dimensions=((0,) if i is not None else ()), + out_sharding=out_s) + for x, i in zip(indices, bdims)], dimension=1) return indices, 0 def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 70e2e1d3503b..0f7e9f866371 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -9843,6 +9843,17 @@ def test_c64_to_f32_view_rountrip(self, mesh): y = jax.jit(lambda _x: _x.view(jnp.complex64))(x) self.assertEqual(y.sharding, NamedSharding(mesh, P('x', None))) + @jtu.with_explicit_mesh((2,), 'x') + def test_jnp_ones_mesh_ctx_aval(self, mesh): + @jax.jit + def f(): + out = jnp.ones((2,)) + self.assertEqual(out.aval.sharding.mesh, mesh.abstract_mesh) + self.assertEqual(out.aval.sharding.spec, P(None)) + return out + + self.assertEqual(f().sharding, NamedSharding(mesh, P(None))) + @jtu.pytest_mark_if_available('multiaccelerator') @jtu.ignore_warning(category=DeprecationWarning, From aad732503e80b2b8ec01b5e22aee1ec677615154 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 10 Dec 2025 22:31:31 -0800 Subject: [PATCH 156/315] Fix spmd_axis_name == explicit_mesh_axes assert when there are multiple mesh axes PiperOrigin-RevId: 843053541 --- jax/_src/api.py | 2 +- tests/pjit_test.py | 22 +++++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 08a9dd8387d2..5dde45b9e004 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1192,7 +1192,7 @@ def vmap_f(*args, **kwargs): explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat) if spmd_axis_name is not None and explicit_mesh_axis is not None: spmd_axis_name = ( - tuple(core.remove_size_one_mesh_axis(P(spmd_axis_name), get_abstract_mesh())) + tuple(*core.remove_size_one_mesh_axis(P(spmd_axis_name), get_abstract_mesh())) if config.remove_size_one_mesh_axis_from_type.value else spmd_axis_name) if spmd_axis_name == explicit_mesh_axis: spmd_axis_name = None diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0f7e9f866371..3517ea41ba14 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7268,23 +7268,27 @@ def f(x): "Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"): f(arr) + @parameterized.parameters( + (('x', 'y', 'z'), ('x', 'y')), + (('x', 'z'), 'x') + ) @config.remove_size_one_mesh_axis_from_type(True) - @jtu.with_explicit_mesh((2, 1), ('x', 'y')) - def test_spmd_axis_name_explicit_mode_assert_remove_one_size(self, mesh): - np_inp = np.arange(16).reshape(8, 2) - arr = jax.device_put(np_inp, NamedSharding(mesh, P(('x', 'y'), None))) + @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) + def test_spmd_axis_name_explicit_mode_assert_remove_one_size( + self, in_spec, out_spec, mesh): + np_inp = np.arange(16).reshape(4, 2, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P(in_spec, None))) @jax.jit - @partial(jax.vmap, spmd_axis_name=('x', 'y')) + @partial(jax.vmap, spmd_axis_name=in_spec) def f(x): - # breakpoint() - self.assertEqual(x.aval.sharding.spec, P(None)) + self.assertEqual(x.aval.sharding.spec, P(None, None)) out = x * 2 - self.assertEqual(out.aval.sharding.spec, P(None)) + self.assertEqual(out.aval.sharding.spec, P(None, None)) return out out = f(arr) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + self.assertEqual(out.sharding, NamedSharding(mesh, P(out_spec, None, None))) self.assertArraysEqual(out, np_inp * 2) @jtu.with_explicit_mesh((2,), ('x',)) From f030b733bce7f372e58d788b7bb260b4e0ebfbd8 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 11 Dec 2025 00:06:17 -0800 Subject: [PATCH 157/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/5d63679b85e9808398a1bb725365dda4b23594e4 PiperOrigin-RevId: 843081363 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index b7788d301fd9..adfbfd05131d 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "63413fe5a5ce541ba0e076f25264d11f0311fac5" -XLA_SHA256 = "4381718d7981a6d866171fe206d07ce7cf58295c7aeaef9fdc5c44741f22b585" +XLA_COMMIT = "5d63679b85e9808398a1bb725365dda4b23594e4" +XLA_SHA256 = "e92ba838cd10126e7580435ff089f22ed6d2a7afc53ae6e5308e39ef9184e847" From db65e542202bf022c3e9dc713eabea061948c5a7 Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 11 Dec 2025 07:58:00 +0000 Subject: [PATCH 158/315] [export] Fix export back compat serialization test. Make sure that the tests run only on the platform for which they were serialized. --- jax/_src/test_util.py | 2 +- tests/BUILD | 5 +++++ tests/export_serialization_back_compat_test.py | 6 +++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 30ed606b527f..246ae14e6023 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1554,7 +1554,7 @@ def mesh_fn(*args, **kwargs): def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None): size = math.prod(mesh_shape) if len(xla_bridge.devices()) < size: - raise unittest.SkipTest(f"Test requires {size} global devices.") + raise unittest.SkipTest(f"Test requires {size} global devices and found {len(xla_bridge.devices())}.") if iota_order: devices = sorted(xla_bridge.devices(), key=lambda d: d.id) mesh_devices = np.array(devices[:size]).reshape(mesh_shape) diff --git a/tests/BUILD b/tests/BUILD index b723dc69120c..51975ea6c48f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -2071,6 +2071,11 @@ jax_multiplatform_test( jax_multiplatform_test( name = "export_serialization_back_compat_test", srcs = ["export_serialization_back_compat_test.py"], + enable_backends = ["cpu", "gpu", "tpu"], + enable_configs = [ + "tpu_v3_x4", + "gpu_h100x2", + ], tags = [], deps = [ "//jax/_src:internal_export_back_compat_test_data", diff --git a/tests/export_serialization_back_compat_test.py b/tests/export_serialization_back_compat_test.py index b11858db3c2b..d23b51a1d34c 100644 --- a/tests/export_serialization_back_compat_test.py +++ b/tests/export_serialization_back_compat_test.py @@ -117,6 +117,8 @@ def export_and_serialize(self, fun, *args, ] ) def test_with_specified_sharding(self, testdata: dict[str, Any] | None): + if jtu.device_under_test() != "cpu": + self.skipTest("Testing only the CPU serialization") a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4)) mesh = jtu.create_mesh((2,), "x") with jax.set_mesh(mesh): @@ -146,6 +148,8 @@ def f(b): ] ) def test_with_unspecified_sharding(self, testdata: dict[str, Any] | None): + if jtu.device_under_test() != "cpu": + self.skipTest("Testing only the CPU serialization") a = np.arange(16 * 4, dtype=np.float32).reshape((16, 4)) # Output sharding is not specified @@ -197,7 +201,7 @@ def test_with_memory_space(self, testdata: dict[str, Any] | None): if jtu.device_under_test() in ("tpu", "gpu"): b = exported.call(a) self.assertEqual(b.aval.memory_space, core.MemorySpace.Host) - self.assertEqual(b.sharding, a.sharding) + self.assertEqual(b.sharding.memory_kind, a.sharding.memory_kind) if __name__ == "__main__": From b6fb016741f61d0eec2a36806ee9a47c5c5b6b71 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 11 Dec 2025 06:05:04 -0800 Subject: [PATCH 159/315] [mosaic] Add a canonicalization pattern pushing memref.dim through tpu.memref_squeeze PiperOrigin-RevId: 843187331 --- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 45 +++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index b0cb7d35c4d5..3bfde20840a3 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -46,6 +47,7 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.cc.inc" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.cc.inc" +#include "jaxlib/mosaic/dialect/tpu/util.h" #include "xla/layout.h" // This is a bit unclean, but we need to squat the xla namespace to make sure @@ -122,9 +124,50 @@ struct MemRefCastEraseLayout : public OpRewritePattern { } }; +// Rewrites memref.dim(tpu.memref_squeeze(x)) to memref.dim(x) with the +// dimension index adjusted to account for squeezed dimensions. +struct MemRefDimOfSqueeze : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DimOp dim_op, + PatternRewriter& rewriter) const override { + auto squeeze_op = dim_op.getSource().getDefiningOp(); + if (!squeeze_op) { + return failure(); + } + const std::optional maybe_dim = + getConstantIntValue(dim_op.getDimension()); + if (!maybe_dim) { + return failure(); + } + const int64_t dim = *maybe_dim; + MemRefType result_type = squeeze_op.getType(); + if (dim < 0 || result_type.getRank() <= dim) { + return dim_op.emitWarning("Dimension index is out of bounds"); + } + if (result_type.getDimSize(dim) != ShapedType::kDynamic) { + return failure(); + } + MemRefType source_type = getMemRefType(squeeze_op.getInput()); + FAILUREOR_ASSIGN_OR_RETURN( + SmallVector squeezed, + computeSqueezedDimsChecked(squeeze_op, source_type.getShape(), + result_type.getShape())); + int64_t source_dim = dim; + for (int squeezed_dim : squeezed) { + if (squeezed_dim <= source_dim) { + ++source_dim; + } + } + rewriter.replaceOpWithNewOp(dim_op, squeeze_op.getInput(), + source_dim); + return success(); + } +}; + void TPUDialect::getCanonicalizationPatterns(RewritePatternSet& results) const /*override*/ { - results.add(getContext()); + results.add(getContext()); } FailureOr GetCoreTypeOfParentFunc(Operation &op) { From b30c9de54296dcb6b878bff10a7e3909da787d02 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 11 Dec 2025 06:53:35 -0800 Subject: [PATCH 160/315] Remove some deprecated BUILD aliases. PiperOrigin-RevId: 843203193 --- jax/BUILD | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 928c67fadf85..16b88eef390f 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -301,30 +301,12 @@ alias( visibility = jax_visibility("pallas_deprecated_alias"), ) -alias( - name = "pallas_fuser", - actual = "//jax/experimental:pallas_fuser", - visibility = jax_visibility("pallas_fuser_deprecated_alias"), -) - -alias( - name = "pallas_mosaic_gpu", - actual = "//jax/experimental:pallas_mosaic_gpu", - visibility = jax_visibility("pallas_mosaic_gpu_deprecated_alias"), -) - alias( name = "pallas_tpu", actual = "//jax/experimental:pallas_tpu", visibility = jax_visibility("pallas_tpu_deprecated_alias"), ) -alias( - name = "experimental_transfer", - actual = "//jax/experimental:transfer", - visibility = jax_visibility("experimental_transfer_deprecated_alias"), -) - alias( name = "optimizers", actual = "//jax/example_libraries:optimizers", From 81fe5dd8b4413dfb59ec5d1fcd9006377a9293ab Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Thu, 11 Dec 2025 07:03:23 -0800 Subject: [PATCH 161/315] Integrate Triton up to 8d445186 https://github.com/openxla/triton/tree/triton_integrate_branch-1.15 PiperOrigin-RevId: 843206657 --- jaxlib/gpu/triton.cc | 4 ++-- jaxlib/gpu/triton.proto | 1 + jaxlib/gpu/triton_kernels.cc | 37 +++++++++++++++++------------------- jaxlib/gpu/triton_kernels.h | 12 +++++------- 4 files changed, 25 insertions(+), 29 deletions(-) diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 42c58eb613a2..a1bb10ed510f 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -45,8 +45,8 @@ namespace jax::JAX_GPU_NAMESPACE { NB_MODULE(_triton, m) { nb::class_(m, "TritonKernel") - .def(nb::init()); + .def(nb::init()); nb::class_(m, "TritonParameter"); diff --git a/jaxlib/gpu/triton.proto b/jaxlib/gpu/triton.proto index 786b07afbdbe..553f95dd5f17 100644 --- a/jaxlib/gpu/triton.proto +++ b/jaxlib/gpu/triton.proto @@ -5,6 +5,7 @@ package jax_triton; message TritonKernel { string kernel_name = 1; // Kernel function name within module. uint32 num_warps = 2; + optional uint32 num_ctas = 10; uint32 shared_mem_bytes = 3; string ptx = 4; string ttir = 5; diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 1961ada1bf76..0ad86f522d9d 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -315,17 +315,16 @@ class ModuleImage { ABSL_GUARDED_BY(mutex_); }; -Kernel::Kernel(std::string kernel_name, uint32_t num_warps, +Kernel::Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas, uint32_t shared_mem_bytes, std::string ptx, std::string ttir, - int compute_capability, uint32_t cluster_dim_0, - uint32_t cluster_dim_1, uint32_t cluster_dim_2) + int compute_capability) : kernel_name_(std::move(kernel_name)), block_dim_x_(num_warps * kNumThreadsPerWarp), + num_ctas_(num_ctas), shared_mem_bytes_(shared_mem_bytes), ptx_(std::move(ptx)), ttir_(std::move(ttir)), - compute_capability_(compute_capability), - cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {} + compute_capability_(compute_capability) {} absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], void** params) { @@ -362,9 +361,7 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel, module_image_->GetFunctionForContext(context)); - const uint32_t cluster_size = - cluster_dims_[0] * cluster_dims_[1] * cluster_dims_[2]; - if (cluster_size <= 1) { + if (num_ctas_ == 1) { return JAX_AS_STATUS(gpuLaunchKernel( kernel, grid[0], grid[1], grid[2], block_dim_x_, /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params, @@ -372,16 +369,16 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], } CUlaunchAttribute launch_attrs[2]; launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launch_attrs[0].value.clusterDim.x = cluster_dims_[0]; - launch_attrs[0].value.clusterDim.y = cluster_dims_[1]; - launch_attrs[0].value.clusterDim.z = cluster_dims_[2]; + launch_attrs[0].value.clusterDim.x = num_ctas_; + launch_attrs[0].value.clusterDim.y = 1; + launch_attrs[0].value.clusterDim.z = 1; launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; launch_attrs[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; CUlaunchConfig launch_config = { - /*gridDimX=*/grid[0] * cluster_dims_[0], - /*gridDimY=*/grid[1] * cluster_dims_[1], - /*gridDimZ=*/grid[2] * cluster_dims_[2], + /*gridDimX=*/grid[0] * num_ctas_, + /*gridDimY=*/grid[1], + /*gridDimZ=*/grid[2], /*blockDimX=*/block_dim_x_, /*blockDimY=*/1, /*blockDimZ=*/1, @@ -396,23 +393,23 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], } /*static*/ Kernel Kernel::FromProto(const jax_triton::TritonKernel& proto) { - return Kernel(proto.kernel_name(), proto.num_warps(), + // Use 1 as default value if not specified in already serialized kernels. + int num_ctas = proto.has_num_ctas() ? proto.num_ctas() : 1; + + return Kernel(proto.kernel_name(), proto.num_warps(), num_ctas, proto.shared_mem_bytes(), proto.ptx(), proto.ttir(), - proto.compute_capability(), proto.cluster_dim_0(), - proto.cluster_dim_1(), proto.cluster_dim_2()); + proto.compute_capability()); } jax_triton::TritonKernel Kernel::ToProto() const { jax_triton::TritonKernel proto; proto.set_kernel_name(kernel_name_); proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp); + proto.set_num_ctas(num_ctas_); proto.set_shared_mem_bytes(shared_mem_bytes_); proto.set_ptx(ptx_); proto.set_ttir(ttir_); proto.set_compute_capability(compute_capability_); - proto.set_cluster_dim_0(cluster_dims_[0]); - proto.set_cluster_dim_1(cluster_dims_[1]); - proto.set_cluster_dim_2(cluster_dims_[2]); return proto; } diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index 3ab3e9143fb8..08320a104183 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -38,10 +38,9 @@ class ModuleImage; class Kernel { public: - Kernel(std::string kernel_name, uint32_t num_warps, uint32_t shared_mem_bytes, - std::string ptx, std::string ttir, int compute_capability, - uint32_t cluster_dim_0, uint32_t cluster_dim_1, - uint32_t cluster_dim_2); + Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas, + uint32_t shared_mem_bytes, std::string ptx, std::string ttir, + int compute_capability); absl::Status Launch(gpuStream_t stream, uint32_t grid[3], void** params); @@ -54,11 +53,11 @@ class Kernel { private: std::string kernel_name_; uint32_t block_dim_x_; + uint32_t num_ctas_; uint32_t shared_mem_bytes_; std::string ptx_; std::string ttir_; int compute_capability_; - uint32_t cluster_dims_[3]; ModuleImage* module_image_ = nullptr; }; @@ -107,8 +106,7 @@ class AutotunedKernelCall { AutotunedKernelCall( std::string name, std::vector configs, - std::vector> input_output_aliases); + std::vector> input_output_aliases); static absl::StatusOr Autotune(AutotunedKernelCall kernel_call, gpuStream_t stream, From 03256b3a4d30ba1521f9377584d556cce871f7ca Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Thu, 11 Dec 2025 07:24:09 -0800 Subject: [PATCH 162/315] [Mosaic GPU] Support loading transposed refs to WGMMA_TRANSPOSED layout. PiperOrigin-RevId: 843212997 --- jax/_src/pallas/mosaic_gpu/lowering.py | 48 ++++++++++++++++++++++---- tests/pallas/mosaic_gpu_test.py | 30 ++++++++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index af4aebd172bd..272a9d249eeb 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1557,9 +1557,16 @@ def _get_lowering_rule( dtype = ctx.avals_out[0].dtype transforms = jax.tree.unflatten(tree, leaves) + transposed = ctx.out_layout_hint and ctx.out_layout_hint in ( + mgpu.WGMMA_TRANSPOSED_LAYOUT, + mgpu.TCGEN05_TRANSPOSED_LAYOUT, + ) + transposed = bool(transposed) x_smem, transforms = _handle_transforms( - ctx, x_ref, transforms, allow_peer_refs=True + ctx, x_ref, transforms, handle_transposes=not transposed, + allow_peer_refs=True ) + x_smem = cast(ir.Value, x_smem) del x_ref # Don't use x_ref anymore. Use x_smem instead! is_signed = mgpu_utils.is_signed(dtype) @@ -1569,20 +1576,49 @@ def _get_lowering_rule( return mgpu.FragmentedArray.splat(val, shape=(), is_signed=is_signed) match transforms: - case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)): + case ( + gpu_core.UnswizzleRef(swizzle), + gpu_core.UntileRef(tiling), + *maybe_transpose, + ): if len(tiling) != 2: raise NotImplementedError(f"Only 2D tiling is supported, got: {tiling}") - expected_minor_tiling = swizzle * 8 // dtypes.itemsize_bits(dtype) + bw = dtypes.itemsize_bits(ctx.avals_out[0].dtype) + expected_minor_tiling = swizzle * 8 // bw if tiling[-1] != expected_minor_tiling: raise NotImplementedError( "Minor tiling dimension does not fit swizzle: " f" expected {expected_minor_tiling}, got {tiling[-1]}" ) - layout = ctx.out_layout_hint or mgpu.WGMMA_LAYOUT + + if transposed != bool(maybe_transpose): + raise ValueError( + "Either both the ref and the value are transposed or neither is." + ) + + if maybe_transpose: + if maybe_transpose != [gpu_core.TransposeRef((1, 0))]: + raise NotImplementedError( + f"Unsupported transforms: {transforms} ({maybe_transpose})" + ) + + x_smem = mgpu.memref_transpose(x_smem, (1, 0, 3, 2)) return mgpu.FragmentedArray.load_tiled( - x_smem, is_signed=is_signed, swizzle=swizzle, layout=layout, optimized=optimized + x_smem, + is_signed=is_signed, + swizzle=swizzle, + layout=ctx.out_layout_hint or mgpu.WGMMA_LAYOUT, + optimized=optimized, ) - case (): + case (*maybe_transpose,): + if maybe_transpose: + if len(maybe_transpose) != 1 or not isinstance( + maybe_transpose[0], gpu_core.TransposeRef + ): + raise NotImplementedError( + f"Unsupported transforms: {transforms} ({maybe_transpose})" + ) + x_smem = mgpu.memref_transpose(x_smem, maybe_transpose[0].permutation) match ctx.out_layout_hint: case mgpu.WGStridedFragLayout(shape=shape, vec_size=vec_size): ref_ty = ir.MemRefType(x_smem.type) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 9efd7e7ad438..6e906c75839a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -3251,6 +3251,36 @@ def compute(acc_ref): ) np.testing.assert_allclose(f(a, b), out_ref, rtol=1e-3) + def test_load_store_wgmma_transposed(self): + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + self.skipTest("Doesn't work in WG semantics") + transforms = (plgpu.TilingTransform((8, 16)), + plgpu.SwizzleTransform(64)) + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct([8, 64], jnp.float32), + in_specs=[ + pl.BlockSpec(memory_space=plgpu.GMEM), + ], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[ + plgpu.SMEM((8, 64), jnp.float32, transforms=transforms), + plgpu.Barrier(), + ], + ) + def kernel(x_gmem, o_ref, x_smem, barrier): + plgpu.copy_gmem_to_smem(x_gmem, x_smem, barrier) + plgpu.barrier_wait(barrier) + x = plgpu.load(x_smem.T, (), layout=plgpu.Layout.WGMMA_TRANSPOSED) + x_smem.T[...] = x + 1 + plgpu.commit_smem() + plgpu.copy_smem_to_gmem(x_smem, o_ref) + plgpu.wait_smem_to_gmem(0) + + x = jax.random.uniform(jax.random.key(42), shape=(8, 64), dtype=jnp.float32) + result = kernel(x) + np.testing.assert_array_equal(result, x + 1) + class PallasCallSm90AWGTest( PallasCallSm90ATest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup From 35e2b33853353d7e9157e2545477d13d2f73c3db Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 11 Dec 2025 08:03:58 -0800 Subject: [PATCH 163/315] Add a vlog for the launch_id_key JAX chooses to aid debugging. PiperOrigin-RevId: 843226195 --- jaxlib/py_executable.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index 18f09389497f..c7a1e19e94e5 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -100,6 +100,10 @@ uint64_t GetBaseLaunchId(std::optional fingerprint, ret += executable->devices()->fingerprint(); } #endif + VLOG(1) << "Get base launch id: " << ret << " from fingerprint: " + << (fingerprint.has_value() + ? absl::StrCat(tsl::Fingerprint64(*fingerprint)) + : ""); return ret; } @@ -540,7 +544,7 @@ int32_t PyLoadedExecutable::GetNextLaunchId() { launch_id = absl::bit_cast(it->second++); } VLOG(1) << "Launching executable " << ifrt_loaded_executable_->name() - << " with launch ID: " << launch_id; + << " with launch ID: " << launch_id << " key: " << launch_id_key_; #if JAX_IFRT_VERSION_NUMBER >= 37 VLOG(2) << "Executable devices for launch ID " << launch_id << ": " << (ifrt_loaded_executable_->devices().has_value() From 1d4be40bc98a85dcf8ff9a63699ebb7824ce6d81 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 11 Dec 2025 08:47:55 -0800 Subject: [PATCH 164/315] [test] assertAllClose: improve error message on failure --- jax/_src/test_util.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 246ae14e6023..503f100efdd4 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -1339,14 +1339,16 @@ def assertAllClose(self, actual, desired, *, check_dtypes=True, atol=None, rtol= rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) elif is_sequence(actual) and not hasattr(actual, '__array__'): - self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__')) + self.assertTrue(is_sequence(desired) and not hasattr(desired, '__array__'), + msg=f"Expected sequence, got {desired}") self.assertEqual(len(actual), len(desired)) for actual_elt, desired_elt in zip(actual, desired): self.assertAllClose(actual_elt, desired_elt, check_dtypes=check_dtypes, atol=atol, rtol=rtol, canonicalize_dtypes=canonicalize_dtypes, err_msg=err_msg) elif hasattr(actual, '__array__') or np.isscalar(actual): - self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired)) + self.assertTrue(hasattr(desired, '__array__') or np.isscalar(desired), + msg=f"Expected array-like, got {desired}") if check_dtypes: self.assertDtypesMatch(actual, desired, canonicalize_dtypes=canonicalize_dtypes) actual = np.asarray(actual) From 4efd7828b041f4d1a9cdd8b5c61a31cda378414a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 11 Dec 2025 10:48:36 -0800 Subject: [PATCH 165/315] Fix _split_transpose_rule to correctly instantiate zeros PiperOrigin-RevId: 843289556 --- jax/_src/lax/lax.py | 6 ++---- tests/shard_map_test.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c598b6d4e61f..ca0e2b1f0a8f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6728,10 +6728,8 @@ def _split_transpose_rule(cotangents, operand, *, sizes, axis): assert ad.is_undefined_primal(operand) if all(type(t) is ad_util.Zero for t in cotangents): return ad_util.Zero(operand.aval), - cotangents = [ - _zeros(t.aval) if type(t) is ad_util.Zero else t - for t in cotangents - ] + cotangents = [t.instantiate() if type(t) is ad_util.Zero else t + for t in cotangents] return concatenate(cotangents, dimension=axis), def _split_batch_rule(batched_args, batch_dims, *, sizes, axis): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 6e0e56818c08..742c976f3ca6 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -4679,6 +4679,22 @@ def g(x, y): self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) + @jtu.with_explicit_mesh((2,), 'x') + def test_split_with_unused_result_in_shardmap(self, mesh): + arr = jax.device_put(jnp.ones(8), P('x')) + + @jax.shard_map(in_specs=P('x'), out_specs=P('x')) + def f(x): + a, _ = jnp.split(x, 2, axis=0) # Important that one result is unused. + return a + + def g(x): + a = f(x) + b = 0.1 * a.mean(keepdims=True) + return b.squeeze(0) + + jax.jit(jax.grad(g))(arr) # doesn't crash + class FunSpec(NamedTuple): name: str From 1c0bb945e8b3872a38d390e897bea0b8a76675cc Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 11 Dec 2025 11:15:28 -0800 Subject: [PATCH 166/315] Colocated python perf optimization. PiperOrigin-RevId: 843301183 --- jax/experimental/colocated_python/func.py | 306 ++++++++++++++++------ 1 file changed, 221 insertions(+), 85 deletions(-) diff --git a/jax/experimental/colocated_python/func.py b/jax/experimental/colocated_python/func.py index a7fe5a52ba39..220f0cfdf540 100644 --- a/jax/experimental/colocated_python/func.py +++ b/jax/experimental/colocated_python/func.py @@ -17,11 +17,12 @@ from collections.abc import Callable, Sequence import dataclasses -import functools import inspect import random import threading from typing import Any +import uuid +import weakref import jax from jax._src import api @@ -375,41 +376,154 @@ def specialized_func(*args, **kwargs): return specialized_func -class _CachedGetSpecializedFunction: - """Manages cached versions of `_uncached_get_specialized_func`. +class _SpecializedCollection: + """Collection of specialized functions for a single unspecialized function. - This class holds a collection of caches, each identified by a unique ID, and - presents itself as a single cache to JAX's `register_backend_cache`. One can - clear individual caches identified by the UID, using the `cache_remove(uid)` - method. JAX's `clear_backend_cache()` will clear all caches. + The `get()` method retrieves the specialized function for the provided input + spec, either by looking up a cache or by compiling the specialized function. + + Looking up a cache with an input spec as a key can be slow, because + `Sharding`'s equivalence comparison is slow. Instead, we maintain two caches + for the same value: we use the ID of the sharding object (via `WeakSpec`) as + the key in one cache, and the corresponding strong references to the sharding + object (via `StrongSpec`) as the key in another cache. Looking up the + `WeakSpec`-keyed cache is fast. Note that the ID integer in the `WeakSpec` + cache will remain valid as long as a strong-ref exists in the `StrongSpec` + cache. + + The `StrongSpec`-keyed cache is unbounded, while the `WeakSpec`-keyed cache + is LRU(1): if there is a miss in the `WeakSpec` cache but a hit in the + `StrongSpec` cache, the strong-ref is the `StrongSpec` cache and the ID + integer in the `WeakSpec` cache are both updated. """ + @dataclasses.dataclass(slots=True, unsafe_hash=True) + class WeakSpec: + """WeakSpec stores just the `id()` of the input spec sharding.""" + + dtypes: tuple[jax.numpy.dtype, ...] + shapes: tuple[tuple[int, ...], ...] + sharding_ids: tuple[int, ...] + treedef: tree_util.PyTreeDef + + def __init__( + self, args_leaves: Sequence[jax.Array], treedef: tree_util.PyTreeDef + ): + self.dtypes = tuple(x.dtype for x in args_leaves) + self.shapes = tuple(x.shape for x in args_leaves) + self.sharding_ids = tuple(id(x.sharding) for x in args_leaves) + self.treedef = treedef + + @dataclasses.dataclass(slots=True, unsafe_hash=True) + class StrongSpec: + """StrongSpec stores the full input spec sharding.""" + + in_specs_treedef: tree_util.PyTreeDef | None = None + in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None + + def __init__( + self, args_leaves: Sequence[jax.Array], pytreedef: tree_util.PyTreeDef + ): + self.in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) + self.in_specs_treedef = pytreedef + def __init__(self): + CompiledId = int + + self._weak_to_id: dict[_SpecializedCollection.WeakSpec, CompiledId] = {} + self._id_to_weak: dict[CompiledId, _SpecializedCollection.WeakSpec] = {} + self._strong_to_id: dict[_SpecializedCollection.StrongSpec, CompiledId] = {} + self._id_to_compiled: dict[CompiledId, Callable[..., Any]] = {} + + self._counter = 0 + self._mu = threading.Lock() + + def get( + self, + args_leaves: Sequence[jax.Array], + pytreedef: tree_util.PyTreeDef, + func_info: FunctionInfo, + specialization: Specialization, + ) -> Callable[..., Any]: + # TODO(hyeontaek): Allow Python values in args_leaves, similar to the todo + # in _get_spec(). + + # Attempt fast-path cache hit. + weak_spec = _SpecializedCollection.WeakSpec(args_leaves, pytreedef) + compiled_id = self._weak_to_id.get(weak_spec) + if compiled_id is not None: + return self._id_to_compiled[compiled_id] + + with self._mu: + # Attempt slow-path cache hit. + strong_spec = _SpecializedCollection.StrongSpec(args_leaves, pytreedef) + compiled_id = self._strong_to_id.pop(strong_spec, None) + if compiled_id is not None: + # Update the caches so that the fast-path cache stores the `id()` of the + # shardings presented by the current invocation. + old_weak = self._id_to_weak.pop(compiled_id) + del self._weak_to_id[old_weak] + + self._strong_to_id[strong_spec] = compiled_id + self._weak_to_id[weak_spec] = compiled_id + self._id_to_weak[compiled_id] = weak_spec + + return self._id_to_compiled[compiled_id] + + # Cache-miss: compile. + if specialization.devices is None: + result = _uncached_get_specialized_func( + func_info, + specialization.update( + in_specs_treedef=strong_spec.in_specs_treedef, + in_specs_leaves=strong_spec.in_specs_leaves, + devices=_infer_devices_from_args(args_leaves), + ), + ) + else: + result = _uncached_get_specialized_func( + func_info, + specialization.update( + in_specs_treedef=strong_spec.in_specs_treedef, + in_specs_leaves=strong_spec.in_specs_leaves, + ), + ) + + compiled_id = self._counter + self._counter += 1 + + self._weak_to_id[weak_spec] = compiled_id + self._strong_to_id[strong_spec] = compiled_id + self._id_to_weak[compiled_id] = weak_spec + self._id_to_compiled[compiled_id] = result + return result + + +class _JaxSecondLevelCaches: + """Manages second-level caches registered as a single cache with JAX.""" + + def __init__(self, name: str): self._lock = threading.Lock() - self._caches: dict[int, Any] = {} - jax_register_backend_cache(self, "colocated_python_specialized_func_cache") + self._callbacks: dict[int, Callable[..., Any]] = {} + jax_register_backend_cache(self, name) def cache_clear(self): - self._caches.clear() + """Meant to be invoked by JAX internals.""" + for callback in self._callbacks.values(): + callback() + self._callbacks.clear() - def cache_remove(self, held_by: int): + def register_second_level( + self, uid: int, cache_clear_callback: Callable[..., Any] + ): + self._callbacks[uid] = cache_clear_callback + + def remove_second_level(self, uid: int): try: - self._caches.pop(held_by) + self._callbacks.pop(uid) except KeyError: pass - def get(self, held_by: int) -> Callable[..., Any]: - with self._lock: - try: - return self._caches[held_by] - except KeyError: - cache = functools.cache(_uncached_get_specialized_func) - self._caches[held_by] = cache - return cache - - -_SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION = _CachedGetSpecializedFunction() - class _CachedColocatedFunctionMaker: """Function maker for colocated Python functions. @@ -418,20 +532,32 @@ class _CachedColocatedFunctionMaker: reused, until the cache is dropped. """ + JAX_CACHE = _JaxSecondLevelCaches("colocated_python_specialized_func_cache") + def __init__(self, held_by: int | None): - self.held_by = held_by - if held_by is None: - self._get_specialized_func = jax._src.util.cache( - max_size=None, trace_context_in_key=False - )(_uncached_get_specialized_func) - else: - self._get_specialized_func = ( - _SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION.get(held_by) - ) + self.held_by = held_by if held_by is not None else uuid.uuid4().int + specialized_collections: list[_SpecializedCollection] = [] + specialized_functions: list[Callable[..., Any]] = [] + + def clear_caches(): + specialized_collections.clear() + specialized_functions.clear() + + _CachedColocatedFunctionMaker.JAX_CACHE.register_second_level( + self.held_by, + clear_caches, + ) + self.specialized_collections = specialized_collections + self.specialized_functions = specialized_functions def __del__(self): - if self.held_by is not None: - _SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION.cache_remove(self.held_by) + self.specialized_collections.clear() + self.specialized_functions.clear() + try: + _CachedColocatedFunctionMaker.JAX_CACHE.remove_second_level(self.held_by) + except AttributeError: + # Ignore error during python finalization. + pass def _make_callable( self, @@ -482,6 +608,14 @@ def specialize( ), ) + # Caches for a collection of specialized functions or a specialized function + # itself. The latter is used as a performance optimization when the input + # spec is explicitly specified and can skip a collection lookup. The caches + # use weakrefs so that we avoid creating cyclic references. + specialized_collections_wref = lambda: None + specialized_functions_wref = lambda: None + wref_mu = threading.Lock() + @api_boundary def __call__(*args, **kwargs): """Executes the given Python function on the same devices as the arguments or as specialized. @@ -489,63 +623,65 @@ def __call__(*args, **kwargs): If the callable has not been specialized with output shapes and shardings (see `specialize` above), the very first call will run synchronously to discover output shapes and shardings, and will run asynchronously after. - If - specialized with output shapes and shardings, every execution of the + If specialized with output shapes and shardings, every execution of the callable will be asynchronous. """ args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs)) - in_specs_leaves = tuple(_get_spec(x) for x in args_leaves) - if specialization.in_specs_treedef is None: - # Allow input polymorphism by applying input_specs specialization - # temporarily for this call. - return self._make_callable( - info, - specialization.update( - in_specs_treedef=in_specs_treedef, - in_specs_leaves=in_specs_leaves, - ), - )(*args, **kwargs) - - if specialization.devices is None: - devices = _infer_devices_from_args(args_leaves) - if devices is None: - raise ValueError( - "No devices found. colocated_python function without input" - " arguments must be first specialized with devices." - ) - # Allow device polymorphism by applying devices specialization temporarily - # for this call. - return self._make_callable( - info, - specialization.update(devices=devices), - )(*args, **kwargs) - - # Assertion is added to silence mypy error: Unsupported operand types for != - # ("PyTreeDef" and "None") [operator] - assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef) - - # If input_specs is known, verify that it matches actual inputs. - if ( - specialization.in_specs_treedef != in_specs_treedef - or specialization.in_specs_leaves != in_specs_leaves - ): + no_input = len(args_leaves) == 0 + if no_input and specialization.devices is None: raise ValueError( - "Input specs in specialization and input specs of arguments must" - " have the same pytree structure, but they have the following" - " structural differences:\n" - + ( - "\n".join( - f" - {tree_util.keystr(path)} is a {thing1} in value 1" - f" and a {thing2} in value 2, so {explanation}.\n" - for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef( - specialization.in_specs_treedef, in_specs_treedef - ) - ) - ) + "No devices found. colocated_python function without input" + " arguments must be first specialized with devices." ) - return self._get_specialized_func(info, specialization)(*args, **kwargs) + fully_specified_in_spec = ( + specialization.in_specs_treedef is not None + and specialization.in_specs_leaves is not None + ) + + if not fully_specified_in_spec and not no_input: + # We need to handle input polymorphism + nonlocal specialized_collections_wref + with wref_mu: + collection: _SpecializedCollection = specialized_collections_wref() + if collection is None: + collection = _SpecializedCollection() + self.specialized_collections.append(collection) + specialized_collections_wref = weakref.ref(collection) + result = collection.get( + args_leaves, in_specs_treedef, info, specialization + )(*args, **kwargs) + del collection + return result + + # No input polymorphism -- exactly one compiled function is possible. + with wref_mu: + nonlocal specialized_functions_wref + func: Callable[..., Any] = specialized_functions_wref() + if func is None: + if fully_specified_in_spec and specialization.devices is not None: + func = _uncached_get_specialized_func(info, specialization) + elif fully_specified_in_spec: + func = _uncached_get_specialized_func( + info, + specialization.update( + devices=_infer_devices_from_args(args_leaves) + ), + ) + elif no_input: + func = _uncached_get_specialized_func( + info, + specialization.update( + in_specs_leaves=tuple(), + in_specs_treedef=in_specs_treedef, + ), + ) + self.specialized_functions.append(func) + specialized_functions_wref = weakref.ref(func) + result = func(*args, **kwargs) + del func + return result __call__ = wraps(info.fun)(__call__) __call__.specialize = specialize From d7e026f992e361b43e50005f278e7cee8ba92308 Mon Sep 17 00:00:00 2001 From: Aditya Jha Date: Wed, 26 Nov 2025 06:18:19 +0000 Subject: [PATCH 167/315] Add entropy function to jax.scipy.stats.poisson MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Shannon entropy calculation for Poisson distribution with three-regime computational strategy for numerical stability: - Small μ < 10: Direct PMF summation with adaptive bounds - Medium 10 ≤ μ < 100: Adaptive bounds based on standard deviation - Large μ ≥ 100: Asymptotic Stirling approximation Features: - Matches SciPy behavior with <1e-6 relative error - Handles loc parameter for API compatibility - Supports broadcasting for array inputs - Returns NaN for invalid μ values (μ ≤ 0) - JIT-compatible with static bounds - Compatible with both JAX_ENABLE_X64=0 and JAX_ENABLE_X64=1 Closes #29596 --- docs/jax.scipy.rst | 1 + jax/_src/scipy/stats/poisson.py | 127 +++++++++++++++++++++++++++++++- jax/scipy/stats/poisson.py | 1 + tests/scipy_stats_test.py | 55 ++++++++++++++ 4 files changed, 182 insertions(+), 2 deletions(-) diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index aad146eb71c5..6cf14389adcd 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -463,6 +463,7 @@ jax.scipy.stats.poisson logpmf pmf cdf + entropy jax.scipy.stats.t ~~~~~~~~~~~~~~~~~ diff --git a/jax/_src/scipy/stats/poisson.py b/jax/_src/scipy/stats/poisson.py index bf314842d3ff..bb2f9399dfdc 100644 --- a/jax/_src/scipy/stats/poisson.py +++ b/jax/_src/scipy/stats/poisson.py @@ -17,8 +17,8 @@ from jax._src import lax from jax._src import numpy as jnp from jax._src.lax.lax import _const as _lax_const -from jax._src.numpy.util import promote_args_inexact -from jax._src.scipy.special import xlogy, gammaln, gammaincc +from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact, ensure_arraylike +from jax._src.scipy.special import xlogy, entr, gammaln, gammaincc from jax._src.typing import Array, ArrayLike @@ -114,3 +114,126 @@ def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array: x = lax.sub(k, loc) p = gammaincc(jnp.floor(1 + x), mu) return jnp.where(lax.lt(x, zero), zero, p) + +def entropy(mu: ArrayLike, loc: ArrayLike = 0) -> Array: + r"""Shannon entropy of the Poisson distribution. + + JAX implementation of :obj:`scipy.stats.poisson` ``entropy``. + + The entropy :math:`H(X)` of a Poisson random variable + :math:`X \sim \text{Poisson}(\mu)` is defined as: + + .. math:: + + H(X) = -\sum_{k=0}^\infty p(k) \log p(k) + + where :math:`p(k) = e^{-\mu} \mu^k / k!` for + :math:`k \geq \max(0, \lfloor \text{loc} \rfloor)`. + + This implementation uses **regime switching** for numerical stability + and performance: + + - **Small** :math:`\mu < 10`: Direct summation over PMF with adaptive + upper bound :math:`k \leq \mu + 20` + - **Medium** :math:`10 \leq \mu < 100`: Summation with bound + :math:`k \leq \mu + 10\sqrt{\mu} + 20` + - **Large** :math:`\mu \geq 100`: Asymptotic Stirling approximation: + :math:`H(\mu) \approx \frac{1}{2} \log(2\pi e \mu) - \frac{1}{12\mu}` + + Matches SciPy to relative error :math:`< 10^{-5}` across all regimes. + + Args: + mu: arraylike, mean parameter of the Poisson distribution. + Must be ``> 0``. + loc: arraylike, optional location parameter (default: 0). + Accepted for API compatibility with scipy but does not + affect the entropy + + Returns: + Array of entropy values with shape broadcast from ``mu`` and ``loc``. + Returns ``NaN`` for ``mu <= 0``. + + Examples: + >>> from jax.scipy.stats import poisson + >>> poisson.entropy(5.0) + Array(2.204394, dtype=float32) + >>> poisson.entropy(jax.numpy.array([1, 10, 100])) + Array([1.3048419, 2.5614073, 3.7206903], dtype=float32) + + See Also: + - :func:`jax.scipy.stats.poisson.pmf` + - :func:`jax.scipy.stats.poisson.logpmf` + - :obj:`scipy.stats.poisson` + """ + mu, loc = ensure_arraylike("poisson.entropy", mu, loc) + promoted_mu, promoted_loc = promote_dtypes_inexact(mu, loc) + + #Note: loc does not affect the entropy - translation invariant + #it has only been taken to maintain compatibility with scipy api + result_shape = jnp.broadcast_shapes( + promoted_mu.shape, + promoted_loc.shape + ) + + mu_flat = jnp.ravel(promoted_mu) + zero_result = jnp.zeros_like(mu_flat) + + + # Choose the computation regime based on mu value + result = jnp.where( + mu_flat == 0, + zero_result, + jnp.where( + mu_flat < 10, + _entropy_small_mu(mu_flat), + jnp.where( + mu_flat < 100, + _entropy_medium_mu(mu_flat), + _entropy_large_mu(mu_flat) + ) + ) + ) + + result_mu_shape = jnp.reshape(result, promoted_mu.shape) + + # Restore original shape + return jnp.broadcast_to(result_mu_shape, result_shape) + +def _entropy_small_mu(mu: Array) -> Array: + """Entropy via direct PMF summation for small μ (< 10). + Uses adaptive upper bound k ≤ μ + 20 to capture >99.999% of mass. + """ + max_k = 35 + + k = jnp.arange(max_k, dtype=mu.dtype)[:, None] + probs = pmf(k, mu, 0) + + # Mask: only compute up to mu + 20 for each value + upper_bounds = jnp.ceil(mu + 20).astype(k.dtype) + mask = k < upper_bounds[None, :] + probs_masked = jnp.where(mask, probs, 0.0) + + return jnp.sum(entr(probs_masked), axis=0) + +def _entropy_medium_mu(mu: Array) -> Array: + """Entropy for medium mu (10-100): Adaptive bounds based on std dev. + + Bounds: k ≤ μ + 10√μ + 20. Caps at k=250 for JIT compatibility. + """ + max_k = 250 # Static bound for JIT. For mu<100, upper bound < 220 + + k = jnp.arange(max_k, dtype=mu.dtype)[:, None] + probs = pmf(k, mu, 0) + + upper_bounds = jnp.ceil(mu + 10 * jnp.sqrt(mu) + 20).astype(k.dtype) + mask = k < upper_bounds[None, :] + probs_masked = jnp.where(mask, probs, 0.0) + + return jnp.sum(entr(probs_masked), axis=0) + +def _entropy_large_mu(mu: Array) -> Array: + """Entropy for large mu (>= 100): Asymptotic approximation. + + Formula: H(λ) ≈ 0.5*log(2πeλ) - 1/(12λ) + O(λ^-2) + """ + return 0.5 * jnp.log(2 * np.pi * np.e * mu) - 1.0 / (12 * mu) diff --git a/jax/scipy/stats/poisson.py b/jax/scipy/stats/poisson.py index 5fcde905f89b..ac7cfa141063 100644 --- a/jax/scipy/stats/poisson.py +++ b/jax/scipy/stats/poisson.py @@ -19,4 +19,5 @@ logpmf as logpmf, pmf as pmf, cdf as cdf, + entropy as entropy ) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index d38260a7522f..e9d1b77b7cf8 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -2070,5 +2070,60 @@ def testSEM(self, shape, dtype, axis, ddof, nan_policy, keepdims): atol=tol) self._CompileAndCheck(lax_fun, args_maker, atol=tol) + @jtu.sample_product( + shape=[(), (5,), (3, 4)], + dtype=jtu.dtypes.floating, + ) + def testPoissonEntropy(self, shape, dtype): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.poisson.entropy + lax_fun = lsp_stats.poisson.entropy + + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,check_dtypes=False, rtol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=1e-4) + + @genNamedParametersNArgs(2) + def testPoissonEntropyWithLoc(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = lambda mu, loc: osp_stats.poisson.entropy(mu, loc=loc) + lax_fun = lambda mu, loc: lsp_stats.poisson.entropy(mu, loc) + + args_maker = lambda: [rng(shapes[0], dtypes[0]), rng(shapes[1], dtypes[1])] + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, rtol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=1e-4) + + @jtu.sample_product( + dtype=jtu.dtypes.floating, + ) + def testPoissonEntropyEdgeCases(self, dtype): + """Test edge cases: invalid mu and very small mu""" + # Invalid mu (should return NaN) + invalid_mu = jnp.array([-1.0, 0.0, -5.0], dtype=dtype) + jax_result = lsp_stats.poisson.entropy(invalid_mu) + scipy_result = osp_stats.poisson.entropy(np.array(invalid_mu)) + self.assertAllClose(jax_result, scipy_result, check_dtypes=False, rtol=1e-4) + + # Very small mu + small_mu = jnp.array([0.01, 0.1, 0.5], dtype=dtype) + jax_result = lsp_stats.poisson.entropy(small_mu) + scipy_result = osp_stats.poisson.entropy(np.array(small_mu)) + self.assertAllClose(jax_result, scipy_result,check_dtypes=False, rtol=1e-4) + + @jtu.sample_product( + dtype=jtu.dtypes.floating, + ) + def testPoissonEntropyRegimes(self, dtype): + """Test all three computational regimes""" + mu = jnp.array([2.0, 5.0, 9.0, 15.0, 50.0, 99.0, 100.0, 200.0, 500.0], dtype=dtype) + scipy_fun = lambda m: osp_stats.poisson.entropy(m) + lax_fun = lambda m: lsp_stats.poisson.entropy(m) + args_maker = lambda: [mu] + + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,check_dtypes=False, rtol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=1e-4) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From f3d83cd55ecc87dc08cb3a39a42a499dacc84d3d Mon Sep 17 00:00:00 2001 From: Davis Yoshida Date: Thu, 11 Dec 2025 12:51:01 -0800 Subject: [PATCH 168/315] Support Hijax types in emit_pipeline. PiperOrigin-RevId: 843337588 --- jax/_src/pallas/core.py | 21 ++ jax/_src/pallas/mosaic/pipeline.py | 138 ++++++++---- jax/_src/pallas/mosaic/primitives.py | 86 ++++++++ jax/_src/pallas/primitives.py | 11 + jax/_src/state/types.py | 5 +- jax/experimental/hijax.py | 1 + jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 22 +- jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 2 +- tests/pallas/BUILD | 1 + tests/pallas/tpu_pallas_pipeline_test.py | 255 ++++++++++++++++++++++- 10 files changed, 494 insertions(+), 48 deletions(-) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 32ceb1c21dc2..6c995002976b 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1365,6 +1365,27 @@ def _get_sds(aval: jax_core.AbstractValue): core_map_p = jax_core.Primitive("core_map") core_map_p.multiple_results = True +def _core_map_is_high(*avals, jaxpr, **params): + del avals, params + return jaxpr.is_high +core_map_p.is_high = _core_map_is_high # type: ignore[method-assign] + +def _core_map_to_lojax(*consts, jaxpr, mesh, **params): + closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) + with ( + tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): + closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr) + assert not closed_lo_jaxpr.is_high + return core_map_p.bind( + *closed_lo_jaxpr.consts, + jaxpr=closed_lo_jaxpr.jaxpr, + mesh=mesh, + **params, + ) +core_map_p.to_lojax = _core_map_to_lojax + def core_map( mesh, diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 9bd15e18f17d..e519a323ca20 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -23,6 +23,7 @@ from typing import Any, Union import jax +from jax import core as jax_core from jax import lax from jax import tree_util from jax._src import util as jax_util @@ -30,11 +31,11 @@ from jax._src.pallas import primitives as primitives from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import helpers as tpu_helpers -from jax._src.pallas.mosaic import tpu_info from jax._src.pallas.mosaic import primitives as tpu_primitives +from jax._src.pallas.mosaic import tpu_info +from jax._src.state import types as state_types from jax.experimental import pallas as pl import jax.numpy as jnp -import numpy as np SMEM = tpu_core.MemorySpace.SMEM @@ -79,17 +80,32 @@ def add_leaves(i, x): def _get_tpu_generation() -> int: return tpu_info.get_tpu_info().generation -def _make_tiling(shape: tuple[int, ...], dtype: np.dtype) -> tuple[int, ...]: - # For a n-dimensional shape, returns (8, 128) for the last 2 dimensions - # and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and - # (2, 3, 128, 128) -> (1, 1, 8, 128). + +def _make_tiling( + shape: tuple[int, ...], ty: jax_core.AbstractValue +) -> tuple[int | None, ...]: + """Compute a tiling for the given shape and type. + + For a n-dimensional shape, returns (8, 128) for the last 2 dimensions + and 1 for the leading n - 2. For example, (256, 256) -> (8, 128) and + (2, 3, 128, 128) -> (1, 1, 8, 128). + + Types are not required to have a dtype, so for such types we return None for + all dimensions because their tiling is unknown. + """ + if len(shape) < 2: raise ValueError(f"Shape must have at least 2 dimensions: {shape=}") + + if not hasattr(ty, 'dtype'): + return (None,) * len(shape) + leading_dims, final_dims = shape[:-2], shape[-2:] # We want to find the minimum power of 2 that fits the second-minor dimension # of shape, with maximum value 8. second_minor, _ = final_dims - packing = 4 // dtype.itemsize + + packing = 4 // ty.dtype.itemsize max_tiling = _TILING[0] second_minor_tiling = (1 + int(_get_tpu_generation() < 4)) * packing while second_minor_tiling < min(second_minor, max_tiling): @@ -114,13 +130,18 @@ def _make_block_ds( assert isinstance(out, pl.Slice) return out -def _create_blocked_slice(block_index: jax.Array | int, - block_size: int, - dim_size: int, - tiling: int): + +def _create_blocked_slice( + block_index: jax.Array | int, + block_size: int, + dim_size: int, + tiling: int | None, +): block_start = block_size * block_index if (dim_rem := dim_size % block_size) == 0: return pl.ds(block_start, block_size) + if tiling is None: + raise ValueError("If tiling is None, block_size must divide dim_size.") if block_size % tiling != 0: raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") num_blocks = pl.cdiv(dim_size, block_size) @@ -137,12 +158,15 @@ def _create_bounded_slice(slice_start: jax.Array | int, slice_size: jax.Array | int, block_size: int, dim_size: int, - tiling: int): - if block_size % tiling != 0: + tiling: int | None): + if tiling is not None and block_size % tiling != 0: raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}") # We assume by construction that slice_size <= block_size. We also assume # that the slice_start is already aligned to the tiling. + if tiling is None: + return pl.ds(slice_start, slice_size) + # If we are out of bound, we need to round the slice size down to the nearest # multiple of the tiling. is_oob = slice_start + slice_size > dim_size @@ -157,7 +181,7 @@ def _create_bounded_slice(slice_start: jax.Array | int, def _make_block_slice( block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int, - tiling: int + tiling: int | None ) -> pl.Slice | slice | int | jax.Array: # Computes a slice given a block index and block size. In the default case, # we return slice(block_index * block_size, (block_index + 1) * block_size). @@ -332,7 +356,7 @@ def block_shape(self) -> Sequence[pl.BlockDim | int | None] | None: def compute_index(self): return self.spec.index_map - def get_dma_slice(self, src_shape, src_dtype, grid_indices): + def get_dma_slice(self, src_ty, grid_indices): # We need to handle blocks that might go OOB in the src array. An in bounds # block looks like this (for array shape (600, 600) and block shape # (256, 256)): @@ -379,10 +403,14 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices): # Suppose A is now (601, 600), instead of picking a (88, 256)-sized block # for the last iteration on that dimension, we will pick the next highest # tile multiple, i.e. (96, 256). + + if (src_shape := getattr(src_ty, "shape", None)) is None: + raise ValueError(f'Type {src_ty} does not have a type.') + if len(src_shape) < 2: raise NotImplementedError("Must use >1D values.") - tiling = _make_tiling(src_shape, src_dtype) + tiling = _make_tiling(src_shape, src_ty) block_indices = self.compute_index(*grid_indices) return tuple( _make_block_slice(bi, bs, ss, t) @@ -403,6 +431,14 @@ def with_spec(self, spec: pl.BlockSpec) -> BufferedRefBase: """Returns a new BufferedRefBase with the given block spec.""" raise NotImplementedError() +def _ref_to_value_aval(ref): + """Return the inner of a ref, or a ShapedArray for TransformedRefs.""" + return ( + jax_core.ShapedArray(shape=ref.shape, dtype=ref.dtype) + if isinstance(ref, state_types.TransformedRef) + else jax.typeof(ref).inner_aval + ) + # TODO(justinfu): Refactor and rename slot fields to reflect cumulative values # instead of slot index. @@ -413,7 +449,6 @@ class BufferedRef(BufferedRefBase): Attributes: spec: pallas blockspec. - dtype: dtype for buffers. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. window_ref: a multiple-buffer to hold the working and dirty buffers used @@ -444,7 +479,6 @@ class BufferedRef(BufferedRefBase): copy. """ _spec: pl.BlockSpec = dataclasses.field(metadata=dict(static=True)) - dtype: Any = dataclasses.field(metadata=dict(static=True)) _buffer_type: BufferType = dataclasses.field(metadata=dict(static=True)) window_ref: ArrayRef | None accum_ref: ArrayRef | None @@ -507,7 +541,7 @@ def buffer_types() -> type[BufferType]: return BufferType @classmethod - def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, + def create(cls, spec: pl.BlockSpec, dtype_or_type, buffer_type, buffer_count, needs_swap_ref=True, grid_rank=None, use_lookahead=False, @@ -516,7 +550,8 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, Args: spec: pallas blockspec. - dtype: dtype for buffers. + dtype_or_type: dtype or aval for buffers. If an aval, the shape is + ignored. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. needs_swap_ref: whether a swap slots tracker needs to be allocated. @@ -527,9 +562,18 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, Returns: Initialized BufferedRef """ + + # (123, 456) is a dummy shape since we never use ty without + # calling .update(shape=...) first. + ty = ( + dtype_or_type + if isinstance(dtype_or_type, jax_core.AbstractValue) + else jax_core.ShapedArray((123, 456), dtype_or_type) + ) + block_shape = _get_block_shape(spec) if buffer_type is BufferType.ACCUMULATOR: - accum_ref = VMEM(block_shape, dtype) + accum_ref = VMEM.from_type(ty.update(shape=block_shape)) else: accum_ref = None if source_memory_space == VMEM: @@ -541,7 +585,6 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, f"Cannot hold a non-buffered ref in {spec.memory_space=}") return cls( _spec=spec, - dtype=dtype, _buffer_type=buffer_type, window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, @@ -570,11 +613,12 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, raise ValueError( "grid_rank must be specified when use_lookahead is True." ) + + buffer_ty = ty.update(shape=(buffer_count, *block_shape)) return cls( _spec=spec, - dtype=dtype, _buffer_type=buffer_type, - window_ref=buffer_memory_space((buffer_count,) + block_shape, dtype), + window_ref=buffer_memory_space.from_type(buffer_ty), accum_ref=accum_ref, copy_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None, wait_in_slot=SMEM((1,), jnp.uint32) if buffer_type.is_input else None, @@ -601,22 +645,28 @@ def create(cls, spec: pl.BlockSpec, dtype, buffer_type, buffer_count, ) @classmethod - def input(cls, spec, dtype, buffer_count=2, **kwargs): - return cls.create(spec, dtype, BufferType.INPUT, buffer_count, **kwargs) + def input(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.INPUT, buffer_count, **kwargs + ) @classmethod - def output(cls, spec, dtype, buffer_count=2, **kwargs): - return cls.create(spec, dtype, BufferType.OUTPUT, buffer_count, **kwargs) + def output(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.OUTPUT, buffer_count, **kwargs + ) @classmethod - def accumulator(cls, spec, dtype, buffer_count=2, **kwargs): - return cls.create(spec, dtype, BufferType.ACCUMULATOR, buffer_count, - **kwargs) + def accumulator(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.ACCUMULATOR, buffer_count, **kwargs + ) @classmethod - def input_output(cls, spec, dtype, buffer_count=2, **kwargs): - return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, buffer_count, - **kwargs) + def input_output(cls, spec, dtype_or_type, buffer_count=2, **kwargs): + return cls.create( + spec, dtype_or_type, BufferType.INPUT_OUTPUT, buffer_count, **kwargs + ) @property def block_shape(self): @@ -923,7 +973,7 @@ def copy_in(self, src_ref, grid_indices): if self.swap is not None: self.swap[0] = True slot = self.current_copy_in_slot - src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) + src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices) dst_slice = tuple( pl.ds(0, s.size) for s, bd in zip(src_slice, self.block_shape) @@ -944,7 +994,7 @@ def copy_out(self, dst_ref, grid_indices): if self.swap is not None: self.swap[0] = True slot = self.current_copy_out_slot - dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) + dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices) src_slice = tuple( pl.ds(0, s.size) for s, bd in zip(dst_slice, self.block_shape) @@ -962,7 +1012,7 @@ def wait_in(self, src_ref, grid_indices): if not self.is_buffered: return assert not (self.window_ref is None or isinstance(self.window_ref, REF)) assert self.sem_recvs is not None - src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices) + src_slice = self.get_dma_slice(_ref_to_value_aval(src_ref), grid_indices) dst_slice = tuple( pl.ds(0, s.size) for s, bd in zip(src_slice, self.block_shape) @@ -984,7 +1034,7 @@ def wait_out(self, dst_ref, grid_indices): assert not (self.window_ref is None or isinstance(self.window_ref, REF)) assert self.sem_sends is not None wait_slot = self.current_wait_out_slot - dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) + dst_slice = self.get_dma_slice(_ref_to_value_aval(dst_ref), grid_indices) src_slice = tuple( pl.ds(0, s.size) for s, bd in zip(dst_slice, self.block_shape) @@ -1682,7 +1732,9 @@ def make_input_bref(in_spec, in_ref): use_lookahead = in_spec.pipeline_mode.use_lookahead if use_lookahead and grid is None: raise ValueError("Grid must be specified when using lookahead.") - return BufferedRef.input(in_spec, in_ref.dtype, buffer_count, + + in_aval = _ref_to_value_aval(in_ref) + return BufferedRef.input(in_spec, in_aval, buffer_count, needs_swap_ref=needs_swap_ref, grid_rank=len(grid), use_lookahead=use_lookahead, @@ -1695,11 +1747,13 @@ def make_output_bref(out_spec, out_ref, accumulate): if out_spec.pipeline_mode.use_lookahead: raise ValueError("Output buffering does not support lookahead.") + out_aval = _ref_to_value_aval(out_ref) + if accumulate: - return BufferedRef.accumulator(out_spec, out_ref.dtype, buffer_count, + return BufferedRef.accumulator(out_spec, out_aval, buffer_count, needs_swap_ref=needs_swap_ref, source_memory_space=out_ref.memory_space) - return BufferedRef.output(out_spec, out_ref.dtype, buffer_count, + return BufferedRef.output(out_spec, out_aval, buffer_count, needs_swap_ref=needs_swap_ref, source_memory_space=out_ref.memory_space) out_brefs = jax.tree.map( @@ -1817,7 +1871,7 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices): bref = dst hbm_ref = src copy_in = True - hbm_slice = bref.get_dma_slice(hbm_ref.shape, hbm_ref.dtype, indices) + hbm_slice = bref.get_dma_slice(_ref_to_value_aval(hbm_ref), indices) bref_slice = tuple( pl.ds(0, s.size) for s, bd in zip(hbm_slice, bref.block_shape) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index e472253fb231..47a107368f96 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -383,6 +383,52 @@ def _get_dma_effects( dma_start_p = jax_core.Primitive('dma_start') dma_start_p.multiple_results = True +def _dma_is_high(*avals, **params): + return any(aval.is_high for aval in avals) + +dma_start_p.is_high = _dma_is_high # type: ignore[method-assign] + +def _dma_start_to_lojax(*args, tree, device_id_type, priority, add): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) = tree_util.tree_unflatten(tree, args) + src_ref_aval = jax_core.get_aval(src_ref) + dst_ref_aval = jax_core.get_aval(dst_ref) + if not (src_ref_aval.is_high and dst_ref_aval.is_high): + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + dst_sem_aval = jax_core.get_aval(dst_sem) + if dst_sem_aval.is_high: + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + if src_sem is not None: + if jax_core.get_aval(src_sem).is_high: + raise NotImplementedError("dma_start not implemented in LoJAX yet.") + src_transformed_ref = state.TransformedRef(src_ref, src_transforms) + dst_transformed_ref = state.TransformedRef(dst_ref, dst_transforms) + if src_sem is not None: + src_sem = state.TransformedRef(src_sem, src_sem_transforms) + dst_sem = state.TransformedRef(dst_sem, dst_sem_transforms) + + src_ref_aval.inner_aval.dma_start( + src_transformed_ref, + dst_transformed_ref, + src_sem, + dst_sem, + device_id=device_id, + priority=priority, + device_id_type=device_id_type, + add=add + ) + return [] +dma_start_p.to_lojax = _dma_start_to_lojax + @dma_start_p.def_effectful_abstract_eval def _dma_start_abstract_eval(*args, tree, device_id_type, priority, add): if priority < 0: @@ -646,6 +692,46 @@ def do_discharge_src_sem(src_sem=src_sem): dma_wait_p = jax_core.Primitive('dma_wait') dma_wait_p.multiple_results = True +dma_wait_p.is_high = _dma_is_high # type: ignore[method-assign] + +def _dma_wait_to_lojax(*args, tree, device_id_type): + ( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + dst_sem, + dst_sem_transforms, + src_sem, + src_sem_transforms, + device_id, + ) = tree_util.tree_unflatten(tree, args) + src_ref_aval = jax_core.get_aval(src_ref) + dst_ref_aval = jax_core.get_aval(dst_ref) + if not (src_ref_aval.is_high and dst_ref_aval.is_high): + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + dst_sem_aval = jax_core.get_aval(dst_sem) + if dst_sem_aval.is_high: + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + if src_sem is not None: + if jax_core.get_aval(src_sem).is_high: + raise NotImplementedError("dma_wait not implemented in LoJAX yet.") + src_transformed_ref = state.TransformedRef(src_ref, src_transforms) + dst_transformed_ref = state.TransformedRef(dst_ref, dst_transforms) + if src_sem is not None: + src_sem = state.TransformedRef(src_sem, src_sem_transforms) + dst_sem = state.TransformedRef(dst_sem, dst_sem_transforms) + src_ref_aval.inner_aval.dma_wait( + src_transformed_ref, + dst_transformed_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) + return [] +dma_wait_p.to_lojax = _dma_wait_to_lojax + @dma_wait_p.def_effectful_abstract_eval def _dma_wait_abstract_eval(*args, tree, device_id_type): del device_id_type diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index b612c139acda..4ae79d0769e6 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -854,6 +854,17 @@ def wrap_with_transforms(f, transforms, *args): run_scoped_p = jax_core.Primitive("run_scoped") run_scoped_p.multiple_results = True +def _run_scoped_is_high(*avals, jaxpr, **params): + del avals, params + return jaxpr.is_high +run_scoped_p.is_high = _run_scoped_is_high # type: ignore[method-assign] + +def _run_scoped_to_lojax(*args, jaxpr, **params): + closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, args) + closed_lo_jaxpr = pe.lower_jaxpr(closed_hi_jaxpr) + consts = closed_lo_jaxpr.consts + return run_scoped_p.bind(*consts, jaxpr=closed_lo_jaxpr.jaxpr, **params) +run_scoped_p.to_lojax = _run_scoped_to_lojax def run_scoped( f: Callable[..., Any], diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 2644f8392416..e9f589a10f27 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -409,7 +409,10 @@ def is_high(self): return self.inner_aval.is_high def lo_ty(self): - return map(AbstractRef, self.inner_aval.lo_ty()) + return [ + AbstractRef(x, memory_space=self.memory_space) + for x in self.inner_aval.lo_ty() + ] def lower_val(self, ref): if not self.is_high: diff --git a/jax/experimental/hijax.py b/jax/experimental/hijax.py index 5e5bb0512c79..087569ae9234 100644 --- a/jax/experimental/hijax.py +++ b/jax/experimental/hijax.py @@ -36,4 +36,5 @@ ) from jax._src.state import ( AbstractRef as AbstractRef, + TransformedRef as TransformedRef ) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 3bfde20840a3..26780d42252c 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -475,6 +475,15 @@ MemRefType getMemRefType(Value value) { return cast(value.getType()); } +template +bool checkBothOperandsDivisible(Value value, int64_t divisor, int64_t fuel) { + if (auto op = value.getDefiningOp()) { + return isGuaranteedDivisible(op.getLhs(), divisor, fuel / 2) && + isGuaranteedDivisible(op.getRhs(), divisor, (fuel + 1) / 2); + } + return false; +} + bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (fuel <= 0) { return false; @@ -497,9 +506,16 @@ bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel) { if (auto cast_op = value.getDefiningOp()) { return isGuaranteedDivisible(cast_op.getOperand(), divisor, fuel - 1); } - if (auto add_op = value.getDefiningOp()) { - return isGuaranteedDivisible(add_op.getRhs(), divisor, fuel / 2) && - isGuaranteedDivisible(add_op.getLhs(), divisor, (fuel + 1) / 2); + if (checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel) || + checkBothOperandsDivisible(value, divisor, fuel)) { + return true; + } + if (auto select_op = value.getDefiningOp()) { + return isGuaranteedDivisible(select_op.getTrueValue(), divisor, fuel / 2) && + isGuaranteedDivisible(select_op.getFalseValue(), divisor, + (fuel + 1) / 2); } return false; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 16f710836f0b..9d9fcf624a40 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -77,7 +77,7 @@ LogicalResult specializeMemorySpace(TypedValue value, // vector ops. This functions inverts the layout erasure applied to the value. MemRefType getMemRefType(Value value); -bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 8); +bool isGuaranteedDivisible(Value value, int64_t divisor, int64_t fuel = 128); DotDimensionNumbersAttr defaultDimensionNumbers(Builder &builder, bool transpose_lhs, diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 4128e6dffda2..0d03cc9b5444 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -538,6 +538,7 @@ jax_multiplatform_test( "notsan", # Times out. ], deps = [ + "//jax/experimental:hijax", "//jax/experimental:mesh_utils", "//jax/experimental:pallas_tpu", "//jax/experimental:pallas_tpu_ops", diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 42f7674f7adc..257b1a474cb6 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -14,6 +14,7 @@ """Test TPU-specific extensions to pallas_call.""" +import dataclasses import functools from absl.testing import absltest from absl.testing import parameterized @@ -21,10 +22,14 @@ import hypothesis.strategies as hps import jax from jax import lax +from jax._src import hijax +from jax._src import shard_map +from jax._src import state from jax._src import test_util as jtu +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives from jax.experimental import mesh_utils from jax.experimental import pallas as pl -from jax._src import shard_map from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np @@ -2189,5 +2194,253 @@ def f(x, slices): np.testing.assert_allclose(out, x) +class PipelineHijaxTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Only works on TPU v4+.') + + def test_emit_pipeline_hijax(self): + @dataclasses.dataclass(frozen=True) + class ArrayTuple: + x0: jax.Array + x1: jax.Array + + @property + def shape(self): + assert self.x0.shape == self.x1.shape + return self.x0.shape + + @property + def dtype(self): + assert self.x0.dtype == self.x1.dtype + return self.x0.dtype + + @dataclasses.dataclass(frozen=True) + class ShapedArrayTuple(hijax.HiType): + shape: tuple[int, ...] + dtype: jnp.dtype + + update = dataclasses.replace + + def lo_ty(self) -> list[hijax.ShapedArray]: + return [hijax.ShapedArray(self.shape, self.dtype)] * 2 + + def lower_val(self, hi_val: ArrayTuple) -> list[jax.Array]: + return [hi_val.x0, hi_val.x1] + + def raise_val(self, x0, x1) -> ArrayTuple: + return ArrayTuple(x0, x1) + + def ref_get_abstract_eval(self, ref_aval, *args, tree): + arr_aval = hijax.ShapedArray(self.shape, self.dtype) + updated_ref = ref_aval.update(inner_aval=arr_aval) + out, effects = state_primitives.get_p.abstract_eval( + updated_ref, *args, tree=tree + ) + assert isinstance(out, hijax.ShapedArray) + return ShapedArrayTuple(out.shape, out.dtype), effects + + def ref_get_to_lojax( + self, ref: state.TransformedRef | jax.Ref, idx: indexing.NDIndexer + ): + tup_ref, transforms = ref._refs, ref.transforms # pylint: disable=protected-access + assert isinstance(transforms, tuple) + transforms += (idx,) + + flat_transforms, tree = jax.tree.flatten(transforms) + x0_out = state_primitives.get_p.bind( + tup_ref.x0, *flat_transforms, tree=tree + ) + x1_out = state_primitives.get_p.bind( + tup_ref.x1, *flat_transforms, tree=tree + ) + return ShapedArrayTuple(x0_out, x1_out).raise_val(x0_out, x1_out) + + def ref_swap_abstract_eval(self, ref_aval, val_aval, *args, tree): + arr_aval = hijax.ShapedArray(self.shape, self.dtype) + val_arr_aval = hijax.ShapedArray(val_aval.shape, val_aval.dtype) + updated_ref = ref_aval.update(inner_aval=arr_aval) + out_aval, effects = state_primitives.swap_p.abstract_eval( + updated_ref, val_arr_aval, *args, tree=tree + ) + assert isinstance(out_aval, hijax.ShapedArray) + return ShapedArrayTuple(out_aval.shape, out_aval.dtype), effects + + def ref_swap_to_lojax( + self, + ref: state.TransformedRef | jax.Ref, + val: ArrayTuple, + idx: indexing.NDIndexer, + ): + tup_ref, transforms = ref._refs, ref.transforms # pylint: disable=protected-access + assert isinstance(transforms, tuple) + transforms += (idx,) + + flat_transforms, tree = jax.tree.flatten(transforms) + x0_out = state_primitives.swap_p.bind( + tup_ref.x0, val.x0, *flat_transforms, tree=tree + ) + x1_out = state_primitives.swap_p.bind( + tup_ref.x1, val.x1, *flat_transforms, tree=tree + ) + return self.raise_val(x0_out, x1_out) + + def lower_block_spec( + self, block_spec: pl.BlockSpec + ) -> list[pl.BlockSpec]: + return [block_spec, block_spec] + + def dma_start( + self, + src_ref: state.TransformedRef, + dst_ref: state.TransformedRef, + src_sem: state.TransformedRef, + dst_sem: state.TransformedRef, + device_id: jax.Array | int | None, + device_id_type: pl.DeviceIdType, + priority: int, + add: bool, + ) -> None: + del add + src_aval = jax.typeof(src_ref.ref).inner_aval + assert isinstance(src_aval, ShapedArrayTuple) + dst_aval = jax.typeof(dst_ref.ref).inner_aval + assert isinstance(dst_aval, ShapedArrayTuple) + + src_ref, src_transforms = src_ref.ref._refs, src_ref.transforms # pylint: disable=protected-access + dst_ref, dst_transforms = dst_ref.ref._refs, dst_ref.transforms # pylint: disable=protected-access + + def _run_dma( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ): + if src_sem is not None: + desc = pltpu.make_async_remote_copy( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) + else: + assert device_id is None + desc = pltpu.make_async_copy(src_ref, dst_ref, dst_sem) + desc.start(priority=priority) + + src_x0_ref, src_x1_ref = src_ref.x0, src_ref.x1 + dst_x0_ref, dst_x1_ref = dst_ref.x0, dst_ref.x1 + + _run_dma( + state.TransformedRef(src_x0_ref, src_transforms), + state.TransformedRef(dst_x0_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ) + _run_dma( + state.TransformedRef(src_x1_ref, src_transforms), + state.TransformedRef(dst_x1_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + priority, + ) + + def dma_wait( + self, src_ref, dst_ref, src_sem, dst_sem, device_id, device_id_type + ): + assert isinstance(jax.typeof(src_ref.ref).inner_aval, ShapedArrayTuple) + assert isinstance(jax.typeof(dst_ref.ref).inner_aval, ShapedArrayTuple) + + src_ref, src_transforms = src_ref.ref._refs, src_ref.transforms # pylint: disable=protected-access + dst_ref, dst_transforms = dst_ref.ref._refs, dst_ref.transforms # pylint: disable=protected-access + + def _run_dma( + src_ref, dst_ref, src_sem, dst_sem, device_id, device_id_type + ): + if src_sem is not None: + desc = pltpu.make_async_remote_copy( + src_ref, + dst_ref, + src_sem, + dst_sem, + device_id=device_id, + device_id_type=device_id_type, + ) + else: + assert device_id is None + desc = pltpu.make_async_copy(src_ref, dst_ref, dst_sem) + desc.wait() + + src_x0_ref, src_x1_ref = src_ref.x0, src_ref.x1 + dst_x0_ref, dst_x1_ref = dst_ref.x0, dst_ref.x1 + + _run_dma( + state.TransformedRef(src_x0_ref, src_transforms), + state.TransformedRef(dst_x0_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + ) + _run_dma( + state.TransformedRef(src_x1_ref, src_transforms), + state.TransformedRef(dst_x1_ref, dst_transforms), + src_sem, + dst_sem, + device_id, + device_id_type, + ) + + hijax.register_hitype( + ArrayTuple, lambda q: ShapedArrayTuple(q.shape, q.dtype) + ) + + def kernel(x_hbm_ref, o_hbm_ref): + def body(x_ref, o_ref): + o_ref[...] = x_ref[...] + + num_steps = 4 + block_shape = (x_hbm_ref.shape[0] // num_steps, x_hbm_ref.shape[1]) + + pltpu.emit_pipeline( + body, + grid=(num_steps,), + in_specs=(pl.BlockSpec(block_shape, lambda i: (i, 0)),), + out_specs=pl.BlockSpec(block_shape, lambda i: (i, 0)), + )(x_hbm_ref, o_hbm_ref) + + inp = ArrayTuple( + jnp.arange(32 * 128, dtype=jnp.int32).reshape((32, 128)), + jnp.arange(32 * 128, dtype=jnp.int32).reshape((32, 128)), + ) + + out_ty = ShapedArrayTuple( + inp.shape, + inp.dtype, + ) + + out = pl.pallas_call( + kernel, + in_specs=(pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),), + out_shape=out_ty, + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + )(inp) + + np.testing.assert_allclose(out.x0, inp.x0) + np.testing.assert_allclose(out.x1, inp.x1) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 952e0d6dbd316dd898788618d0fc47e4507ba8a7 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 11 Dec 2025 13:53:40 -0800 Subject: [PATCH 169/315] [Pallas TPU] Add dma_granule_size_bytes to SC info PiperOrigin-RevId: 843361419 --- jax/_src/pallas/mosaic/sc_core.py | 2 +- jax/_src/pallas/mosaic/tpu_info.py | 64 +++++++++++++++++++----------- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/jax/_src/pallas/mosaic/sc_core.py b/jax/_src/pallas/mosaic/sc_core.py index 2d5670e8461a..8f3001f25730 100644 --- a/jax/_src/pallas/mosaic/sc_core.py +++ b/jax/_src/pallas/mosaic/sc_core.py @@ -152,7 +152,7 @@ class BlockMapping(pallas_core.BlockMapping): def get_sparse_core_info() -> tpu_info.SparseCoreInfo: """Returns the SparseCore information for the current device.""" return tpu_info.get_tpu_info().sparse_core or tpu_info.SparseCoreInfo( - num_cores=0, num_subcores=0, num_lanes=0 + num_cores=0, num_subcores=0, num_lanes=0, dma_granule_size_bytes=0, ) diff --git a/jax/_src/pallas/mosaic/tpu_info.py b/jax/_src/pallas/mosaic/tpu_info.py index 767b8d1c7fb8..3159fe0d1a98 100644 --- a/jax/_src/pallas/mosaic/tpu_info.py +++ b/jax/_src/pallas/mosaic/tpu_info.py @@ -20,8 +20,8 @@ from jax import numpy as jnp from jax._src import dtypes -from jax._src.pallas.mosaic import core from jax._src import util as jax_util +from jax._src.pallas.mosaic import core class ChipVersionBase: @@ -41,12 +41,15 @@ class ChipVersion(ChipVersionBase, enum.Enum): def __str__(self) -> str: return self.value + @dataclasses.dataclass(frozen=True, kw_only=True) class SparseCoreInfo: """SparseCore-specific information.""" + num_cores: int num_subcores: int num_lanes: int + dma_granule_size_bytes: int @dataclasses.dataclass(frozen=True, kw_only=True) @@ -122,10 +125,7 @@ def is_matmul_supported( or (lhs_dt in {U4, S4} and rhs_dt in {U4, S4}) ) case 7: - return ( - lhs_dt in {F32, BF16} - and rhs_dt in {F32, BF16} - ) or ( + return (lhs_dt in {F32, BF16} and rhs_dt in {F32, BF16}) or ( lhs_dt in {F32, BF16, F8E5M2, F8E4M3FN} and rhs_dt in {F8E5M2, F8E4M3FN} ) @@ -172,6 +172,7 @@ def is_tpu_device() -> bool: registry: dict[str, Callable[[], TpuInfo]] = {} + @jax_util.cache(trace_context_in_key=True) def get_tpu_info() -> TpuInfo: """Returns the TPU hardware information for the current device. @@ -302,7 +303,12 @@ def get_tpu_info() -> TpuInfo: int8_ops_per_second=int(9.18e14 // num_chip_cores), fp8_ops_per_second=0, # Not Available int4_ops_per_second=int(1.84e15 // num_chip_cores), - sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8), + sparse_core=SparseCoreInfo( + num_cores=4, + num_subcores=16, + num_lanes=8, + dma_granule_size_bytes=32, + ), ) case "TPU v6 lite" | "TPU v6e": # 1 TensorCore per chip return TpuInfo( @@ -321,29 +327,39 @@ def get_tpu_info() -> TpuInfo: int8_ops_per_second=int(1.84e15), fp8_ops_per_second=int(9.20e14), int4_ops_per_second=int(3.68e15), - sparse_core=SparseCoreInfo(num_cores=2, num_subcores=16, num_lanes=8), + sparse_core=SparseCoreInfo( + num_cores=2, + num_subcores=16, + num_lanes=8, + dma_granule_size_bytes=32, + ), ) case "TPU7x": num_cores = core.get_num_device_cores() num_chip_cores = 2 return TpuInfo( - chip_version=ChipVersion.TPU_7X, - generation=7, - num_cores=num_cores, - num_lanes=128, - num_sublanes=8, - mxu_column_size=256, - vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core - cmem_capacity_bytes=0, - smem_capacity_bytes=1024 * 1024, # 1 MiB per core - hbm_capacity_bytes=206_000_000_000 // num_chip_cores, - mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores), - bf16_ops_per_second=int(2.31e15 // num_chip_cores), - int8_ops_per_second=0, # Not Available - fp8_ops_per_second=int(4.60e15 // num_chip_cores), - int4_ops_per_second=0, # Not Available - sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=16), - ) + chip_version=ChipVersion.TPU_7X, + generation=7, + num_cores=num_cores, + num_lanes=128, + num_sublanes=8, + mxu_column_size=256, + vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core + cmem_capacity_bytes=0, + smem_capacity_bytes=1024 * 1024, # 1 MiB per core + hbm_capacity_bytes=206_000_000_000 // num_chip_cores, + mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores), + bf16_ops_per_second=int(2.31e15 // num_chip_cores), + int8_ops_per_second=0, # Not Available + fp8_ops_per_second=int(4.60e15 // num_chip_cores), + int4_ops_per_second=0, # Not Available + sparse_core=SparseCoreInfo( + num_cores=4, + num_subcores=16, + num_lanes=16, + dma_granule_size_bytes=64, + ), + ) case _ as d: if d in registry: return registry[d]() From 75e7243359fc71f8540b664f7e43d6f9ed446da7 Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 10 Dec 2025 21:21:31 +0000 Subject: [PATCH 170/315] [no-thunks] Implement FlatTree, an internal representation of pytrees. A FlatTree stores a treedef and a flat list of values. It's meant to be isomorphic to the corresponding pytree but we can map/zip over it more easily. Compared to `tree_map`, FlatTree.map has these benefits: 1. It doesn't touch user flatten/unflatten code (which shouldn't have side effects but sometimes does in practice). 2. It can be faster, because it skips the recursive traversal. 3. It actually obeys the functor rules, which lets us write things like `flat_tree.map(lambda x: (f(x), g(x))).unzip2()` whereas an ordinary `tree_map` would change the tree structure due to the returned tuple. --- jax/_src/checkify.py | 27 +-- jax/_src/interpreters/partial_eval.py | 36 ++- jax/_src/lax/control_flow/conditionals.py | 69 +++--- jax/_src/lax/control_flow/loops.py | 266 +++++++++++----------- jax/_src/lax/control_flow/solves.py | 78 +++---- jax/_src/tree_util.py | 116 ++++++++++ 6 files changed, 347 insertions(+), 245 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 3b76ea8724de..62fa882dd006 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -47,7 +47,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec as P from jax._src.tree_util import tree_flatten -from jax._src.tree_util import tree_map +from jax._src.tree_util import tree_map, FlatTree from jax._src.tree_util import tree_unflatten from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, @@ -770,11 +770,11 @@ def fun_wrapped(*invals): return (error, out), error_effects debug_info = jaxpr.jaxpr.debug_info.with_unknown_names() - checked_jaxpr, full_out_tree = pe.trace_to_jaxpr( - fun_wrapped, None, flat_err_and_in_vals, debug_info) - out_tree, error_effects_treedef = full_out_tree.children() - error_effects = error_effects_treedef.unflatten(()).val - return checked_jaxpr, out_tree, error_effects + args_avals = FlatTree.flatten((flat_err_and_in_vals, {})) + checked_jaxpr, full_out_avals = pe.trace_to_jaxpr(fun_wrapped, args_avals, debug_info) + out_avals, error_effects = full_out_avals.unpack() + error_effects = error_effects.unflatten().val + return checked_jaxpr, out_avals.tree, error_effects def cond_error_check(error: Error, enabled_errors, index, *ops, branches, **params): @@ -856,9 +856,10 @@ def new_body_f(*c_consts_and_vals): lax.dce_sink(cond_f(*c_consts, *out)) return out c_consts_avals = cond_jaxpr.in_avals[:c_consts_num] + jaxpr, _ = pe.trace_to_jaxpr( - new_body_f, None, - (*c_consts_avals, *body_jaxpr.in_avals), + new_body_f, + FlatTree.flatten(((*c_consts_avals, *body_jaxpr.in_avals), {})), debug_info=body_jaxpr.jaxpr.debug_info.with_unknown_names()) err_vals, err_tree = jtu.tree_flatten(error) err_vals = map(core.get_aval, err_vals) @@ -1010,8 +1011,8 @@ def expand_errors_leading_dim(*xs): with core.extend_axis_env_nd(mesh.shape.items()), config._check_vma(check_vma): checked_jaxpr, _ = pe.trace_to_jaxpr( - expand_errors_leading_dim, None, - tuple(checked_jaxpr.in_avals), + expand_errors_leading_dim, + FlatTree.flatten((tuple(checked_jaxpr.in_avals), {})), debug_info=checked_jaxpr.jaxpr.debug_info) # Update shard_map params to account for extra error values. @@ -1238,15 +1239,15 @@ def checkify(f: Callable[..., Out], @traceback_util.api_boundary def checked_fun(*args, **kwargs): # close over all arguments so they're not turned into abstract values. - in_tree = jtu.tree_structure(()) + in_avals = FlatTree.flatten(((), {})) closed_f = lambda: f(*args, **kwargs) # stage: debug_info = api_util.debug_info("checkify", f, args, kwargs).with_unknown_names() - jaxpr_, out_tree = pe.trace_to_jaxpr(closed_f, in_tree, (), debug_info) + jaxpr_, out_avals = pe.trace_to_jaxpr(closed_f, in_avals, debug_info) jaxpr, consts = pe.separate_consts(jaxpr_) # checkify: error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts) - return error, jtu.tree_unflatten(out_tree, out_flat) + return error, out_avals.update_from_list(out_flat).unflatten() return checked_fun def check(pred: Bool, msg: str, diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0dfffb0f3efa..c92010022c27 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -43,8 +43,7 @@ mapped_aval, unmapped_aval, get_referent, JaxprEqnContext, typeof) from jax._src.source_info_util import SourceInfo from jax._src.state.types import AbstractRef, ReadEffect -from jax._src.tree_util import (PyTreeDef, treedef_tuple, - tree_flatten, tree_unflatten) +from jax._src.tree_util import PyTreeDef, treedef_tuple, FlatTree from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, @@ -2293,35 +2292,34 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): @weakref_lru_cache def trace_to_jaxpr( fun: Callable, - in_tree: PyTreeDef | None, - in_avals_flat: tuple[AbstractValue | core.AvalQDD, ...], + in_avals: FlatTree[AbstractValue | core.AvalQDD], # (args, kwargs) pair debug_info: core.DebugInfo ) -> tuple[ClosedJaxpr, PyTreeDef]: - config.enable_checks.value and debug_info.assert_arg_names(len(in_avals_flat)) + config.enable_checks.value and debug_info.assert_arg_names(len(in_avals)) parent_trace = core.trace_ctx.trace trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace) # Name stacks are reset because the name stacks on jaxpr equations should be # rooted at the enclosing jaxpr. with core.ensure_no_leaks(trace), source_info_util.reset_name_stack(): source_info = source_info_util.current() - in_tracers = map(partial(trace.new_arg, source_info=source_info), - in_avals_flat) + in_tracers = in_avals.map(partial(trace.new_arg, source_info=source_info)) with core.set_current_trace(trace): - if in_tree is not None: - in_tracers = tree_unflatten(in_tree, in_tracers) - ans = fun(*in_tracers) - debug_info = debug_info.set_result_paths(ans) - ans_flat, out_tree = tree_flatten(ans) - - _check_returned_jaxtypes(debug_info, ans_flat) - out_tracers = map(partial(trace.to_jaxpr_tracer, source_info=source_info), ans_flat) - _check_no_returned_refs(debug_info, out_tracers) - jaxpr, consts = trace.frame.to_jaxpr(trace, out_tracers, debug_info, + args, kwargs = in_tracers.unflatten() + ans_pytree = fun(*args, **kwargs) + debug_info = debug_info.set_result_paths(ans_pytree) + ans = FlatTree.flatten(ans_pytree) + del ans_pytree, args, kwargs + + _check_returned_jaxtypes(debug_info, list(ans)) + out_tracers = ans.map(partial(trace.to_jaxpr_tracer, source_info=source_info)) + out_avals = out_tracers.map(lambda t: t.aval) + _check_no_returned_refs(debug_info, list(out_tracers)) + jaxpr, consts = trace.frame.to_jaxpr(trace, list(out_tracers), debug_info, source_info) - del trace, fun, in_tracers, out_tracers, ans, ans_flat + del trace, fun, in_tracers, out_tracers, ans config.enable_checks.value and core.check_jaxpr(jaxpr) - return ClosedJaxpr(jaxpr, consts), out_tree + return ClosedJaxpr(jaxpr, consts), out_avals # TODO(dougalm): remove in favor of `trace_to_jaxpr` @profiler.annotate_function diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index f904c08dd2db..1f30876ac742 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -24,7 +24,7 @@ from jax._src.tree_util import ( tree_flatten, tree_unflatten, tree_flatten_with_path, keystr, - equality_errors_pytreedef) + equality_errors_pytreedef, FlatTree) from jax._src import ad_util from jax._src import api_util from jax._src import config @@ -142,23 +142,23 @@ def _switch_internal( dbgs = [api_util.debug_info("switch", branch, operands, {}) for branch in branches] - ops, ops_tree = tree_flatten(operands) - ops_avals = tuple(map(core.get_aval, ops)) + args = FlatTree.flatten((operands, {})) + avals = args.map(core.get_aval) if config.mutable_array_checks.value: - api_util.check_no_aliased_ref_args(lambda: dbgs[0], ops_avals, ops) + api_util.check_no_aliased_ref_args(lambda: dbgs[0], list(avals), list(args)) - jaxprs_, out_trees = zip(*[pe.trace_to_jaxpr( - branch, ops_tree, ops_avals, dbg) for branch, dbg in zip(branches, dbgs)]) + jaxprs_, out_avalss = zip(*[pe.trace_to_jaxpr(branch, avals, dbg) + for branch, dbg in zip(branches, dbgs)]) jaxprs_, all_consts = zip(*[pe.separate_consts(j) for j in jaxprs_]) jaxprs, consts = _merge_common_consts(jaxprs_, all_consts) if config.mutable_array_checks.value: - api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops) - for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])): + api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), list(args)) + for i, (out_avals, jaxpr) in enumerate(zip(out_avalss[1:], jaxprs[1:])): _check_branch_outputs( "switch", "branch 0", f"branch{i+1}", branches[0], branches[i+1], - out_trees[0], out_tree, jaxprs[0].out_avals, jaxpr.out_avals) + out_avalss[0], out_avals) # prune passthrough outputs fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs] in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)] @@ -174,16 +174,16 @@ def _switch_internal( params = dict(branches=tuple(jaxprs)) if branches_platforms is not None: params["branches_platforms"] = branches_platforms - out = cond_p.bind(index, *consts, *ops, **params) + out = cond_p.bind(index, *consts, *args, **params) out_ = iter(out) - all_inputs = [*consts, *ops] + all_inputs = [*consts, *args] out = [ next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) for fwd in in_fwd ] assert next(out_, None) is None - return tree_unflatten(out_trees[0], out) + return out_avalss[0].update_from_list(out).unflatten() @partial(api_boundary, repro_api_name="jax_cond") def cond(pred, true_fun: Callable, false_fun: Callable, *operands, @@ -267,35 +267,33 @@ def cond(pred, true_fun, false_fun, *operands): else: return false_fun(*operands) - ops, ops_tree = tree_flatten(operands) - ops_avals = tuple(map(core.get_aval, ops)) - ops_avals = tuple(core.AvalQDD(a, cur_qdd(x)) if a.has_qdd # type: ignore - else a for a, x in zip(ops_avals, ops)) - - - dbg_true_fun = api_util.debug_info("cond", true_fun, operands, {}) + args = FlatTree.flatten((operands, {})) + avals = args.map(core.get_aval) + avals = avals.map2( + lambda a, x: core.AvalQDD(a, cur_qdd(x)) if a.has_qdd else a, + args) + dbg_true = api_util.debug_info("cond", true_fun, operands, {}) if config.mutable_array_checks.value: - api_util.check_no_aliased_ref_args(lambda: dbg_true_fun, ops_avals, ops) - dbg_false_fun = api_util.debug_info("cond", false_fun, operands, {}) + api_util.check_no_aliased_ref_args(lambda: dbg_true, list(avals), list(args)) + dbg_false = api_util.debug_info("cond", false_fun, operands, {}) - true_jaxpr_, out_tree = pe.trace_to_jaxpr( - true_fun, ops_tree, ops_avals, dbg_true_fun) + true_jaxpr_, out_avals = pe.trace_to_jaxpr(true_fun, avals, dbg_true) true_jaxpr_, true_consts = pe.separate_consts(true_jaxpr_) - false_jaxpr_, false_out_tree = pe.trace_to_jaxpr( - false_fun, ops_tree, ops_avals, dbg_false_fun) + false_jaxpr_, false_out_avals = pe.trace_to_jaxpr(false_fun, avals, dbg_false) false_jaxpr_, false_consts = pe.separate_consts(false_jaxpr_) (true_jaxpr, false_jaxpr), consts = _merge_common_consts( (true_jaxpr_, false_jaxpr_), (true_consts, false_consts)) if config.mutable_array_checks.value: - api_util._check_no_aliased_closed_over_refs(dbg_true_fun, (*true_jaxpr.consts, *consts), ops) + api_util._check_no_aliased_closed_over_refs( + dbg_true, (*true_jaxpr.consts, *consts), list(args)) if any(isinstance(out_aval, AbstractRef) for out_aval in true_jaxpr.out_avals + false_jaxpr.out_avals): raise ValueError("Cannot return `Ref`s from `cond`.") _check_branch_outputs( - 'cond', 'true_fun', 'false_fun', true_fun, false_fun, out_tree, - false_out_tree, true_jaxpr.out_avals, false_jaxpr.out_avals) + 'cond', 'true_fun', 'false_fun', + true_fun, false_fun, out_avals, false_out_avals) # prune passthrough outputs true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr) @@ -315,24 +313,23 @@ def cond(pred, true_fun, false_fun, *operands): false_jaxpr = replace_jaxpr_effects(false_jaxpr, joined_effects) true_jaxpr = replace_jaxpr_effects(true_jaxpr, joined_effects) - out = cond_p.bind(index, *consts, *ops, branches=(false_jaxpr, true_jaxpr)) + out = cond_p.bind(index, *consts, *args, branches=(false_jaxpr, true_jaxpr)) out_ = iter(out) - all_inputs = [*consts, *ops] + all_inputs = [*consts, *args] out = [ next(out_) if fwd is None else lax.asarray(all_inputs[fwd]) for fwd in in_fwd ] assert next(out_, None) is None - return tree_unflatten(out_tree, out) + return out_avals.update_from_list(out).unflatten() def _check_branch_outputs( - api_name, name1, name2, f1, f2, out_tree1, out_tree2, out_avals1, - out_avals2) -> None: + api_name, name1, name2, f1, f2, out_avals1, out_avals2) -> None: info1 = api_util.fun_sourceinfo(f1) info2 = api_util.fun_sourceinfo(f2) try: - outs1 = tree_unflatten(out_tree1, out_avals1) + outs1 = out_avals1.unflatten() except: paths = [None] * len(out_avals1) component = lambda _: '' @@ -341,11 +338,11 @@ def _check_branch_outputs( paths, _ = unzip2(leaves_and_paths) # type: ignore component = lambda p: f' at path {keystr(p)}' if p else '' - if out_tree1 != out_tree2: + if out_avals1.tree != out_avals2.tree: diffs = [f'{name1} output{component(p)} is a {thing1} but ' f'{name2} output{component(p)} is a {thing2}, so {expl}' for p, thing1, thing2, expl - in equality_errors_pytreedef(out_tree1, out_tree2)] + in equality_errors_pytreedef(out_avals1.tree, out_avals2.tree)] if len(diffs) == 0: return # the trees may have different aux data, but structures are same diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index a9a8f5a2b74e..929a0ab679f5 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -38,7 +38,8 @@ from jax._src import util from jax._src.api_util import ( check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) -from jax._src.core import ShapedArray, typeof, cur_qdd, ClosedJaxpr +from jax._src.core import ( + ShapedArray, typeof, cur_qdd, ClosedJaxpr, AbstractValue) from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -66,8 +67,8 @@ split_list_checked, unzip2, weakref_lru_cache, subs_list) from jax._src import xla_bridge as xb from jax._src.tree_util import ( - keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten, - treedef_is_leaf) + keystr, tree_flatten, tree_map, tree_unflatten, + treedef_is_leaf, FlatTree, tree_leaves_with_path) import numpy as np _map = safe_map @@ -81,29 +82,14 @@ def _stack(arrs: Sequence[Array], axis: int=0) -> Array: return lax.concatenate([lax.expand_dims(arr, (axis,)) for arr in arrs], dimension=axis) -def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): - """Promote weakly-typed in_vals to be compatible with out_avals. - - Args: - in_vals : flattened list of input values. - in_avals : corresponding list of avals. - out_avals : list of target output avals. - Returns: - in_vals_new : flattened list of modified in_vals with no weak types. - changed : bool; true if in_vals required modification. - """ - if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals): - # Calling function is responsible for catching this. - return in_vals, False - weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals)) - if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)] - if not weak_mismatches: - return in_vals, False - for i in weak_mismatches: - new_dtype = dtypes.result_type(in_vals[i], out_avals[i]) - in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype) - return in_vals, True - +def _promote_weak_typed_input( + in_val:Any, in_aval:AbstractValue, out_aval:AbstractValue + ) -> tuple[AbstractValue, bool]: + if getattr(in_aval, 'weak_type', False) and not core.typematch(in_aval, out_aval): + new_dtype = dtypes.result_type(in_val, out_aval) + return lax.convert_element_type(in_val, new_dtype), True + else: + return in_val, False ### scan @@ -215,85 +201,44 @@ def scan(f, init, xs, length=None): """ if not callable(f): raise TypeError("lax.scan: f argument should be a callable.") - xs_flat, xs_tree = tree_flatten(xs) - try: - lengths = [x.shape[0] for x in xs_flat] - except AttributeError as err: - msg = "scan got value with no leading axis to scan over: {}." - raise ValueError( - msg.format(', '.join(str(x) for x in xs_flat - if not hasattr(x, 'shape')))) from err + dbg_body = api_util.debug_info("scan", f, (init, xs), {}) + init = FlatTree.flatten(init) + xs = FlatTree.flatten(xs) + args = FlatTree.pack((init, xs)) - xs_avals = [core.get_aval(x) for x in xs_flat] + args_avals = args.map(core.get_aval) + init_avals, xs_avals = args_avals.unpack() - if not all(a.sharding.spec[0] is None for a in xs_avals): - raise ValueError('0th dimension of all xs should be replicated. Got ' - f'{", ".join(str(a.sharding.spec) for a in xs_avals)}') - - if length is not None: - try: - length = int(length) - except core.ConcretizationTypeError as err: - msg = ('The `length` argument to `scan` expects a concrete `int` value.' - ' For scan-like iteration with a dynamic length, use `while_loop`' - ' or `fori_loop`.') - raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] - if not all(length == l for l in lengths): - msg = ("scan got `length` argument of {} which disagrees with " - "leading axis sizes {}.") - raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) - else: - unique_lengths = set(lengths) - if len(unique_lengths) > 1: - msg = "scan got values with different leading axis sizes: {}." - raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) - elif len(unique_lengths) == 0: - msg = "scan got no values to scan over and `length` not provided." - raise ValueError(msg) - else: - length, = unique_lengths + length = _infer_scan_length(list(xs), list(xs_avals), length) if config.disable_jit.value: if length == 0: raise ValueError("zero-length scan is not supported in disable_jit() " "mode because the output type is unknown.") - carry = init + carry = init.unflatten() ys = [] maybe_reversed = reversed if reverse else lambda x: x for i in maybe_reversed(range(length)): - xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat] - carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) + xs_slice = xs.map(lambda x: slicing.index_in_dim(x, i, keepdims=False)) + carry, y = f(carry, xs_slice.unflatten()) ys.append(y) stack = lambda *ys: _stack(ys) stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y - x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] - dbg_body = api_util.debug_info("scan", f, (init, xs), {}) - if config.mutable_array_checks.value: - in_flat, in_tree = tree_flatten((init, xs)) - in_avals = tuple(_map(core.get_aval, in_flat)) - check_no_aliased_ref_args(lambda: dbg_body, in_avals, in_flat) - - def _create_jaxpr(init): - init_flat, init_tree = tree_flatten(init) - in_flat, in_tree = tree_flatten((init, xs)) - carry_avals = tuple(_map(core.get_aval, init_flat)) - jaxpr, out_tree = pe.trace_to_jaxpr( - f, in_tree, (*carry_avals, *x_avals), debug_info=dbg_body) + check_no_aliased_ref_args(lambda: dbg_body, list(args), list(args_avals)) + + x_avals = xs_avals.map(lambda aval: core.mapped_aval(length, 0, aval)) + def _create_jaxpr(carry_avals): + new_arg_avals = FlatTree.pack(((carry_avals, x_avals), {})) + jaxpr, out_avals = pe.trace_to_jaxpr(f, new_arg_avals, dbg_body) jaxpr, consts = pe.separate_consts(jaxpr) - if config.mutable_array_checks.value: - _check_no_aliased_closed_over_refs(dbg_body, (*jaxpr.consts, *consts), in_flat) - out_tree_children = out_tree.children() - if len(out_tree_children) != 2: + if len(out_avals.unpack()) != 2: msg = "scan body output must be a pair, got {}." - raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - - carry_avals_out, _ = split_list(jaxpr.out_avals, [out_tree_children[0].num_leaves]) - return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, - consts, out_tree, out_tree_children) + raise TypeError(msg.format(out_avals.unflatten())) + return jaxpr, out_avals, consts # The carry input and output avals must match exactly. However, we want to account for # the case when init contains weakly-typed values (e.g. Python scalars), with avals that @@ -303,18 +248,21 @@ def _create_jaxpr(init): # TODO(dougalm): this two-pass stuff is expensive (exponential in scan nesting # depth) and incomplete (because in the general case it takes more than two passes). # Let's get rid of it, perhaps after getting rid of weak types altogether. - init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) - new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out) - if changed: - init = tree_unflatten(init_tree, new_init_flat) - init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) - in_flat, jaxpr, consts, out_tree, out_tree_children = rest - num_carry = len(init_flat) - num_xs = len(x_avals) - num_ys = len(jaxpr.out_avals) - num_carry - del init_flat + jaxpr, out_avals, consts = _create_jaxpr(init_avals) + if config.mutable_array_checks.value: + _check_no_aliased_closed_over_refs(dbg_body, consts, list(args)) + carry_out_avals, ys_avals = out_avals.unpack() + init, changed = init.map3( + _promote_weak_typed_input, + init_avals, carry_out_avals).unzip2() + num_carry, num_xs, num_ys = len(init), len(xs), len(ys_avals) + if any(changed): + init_avals = init.map(core.get_aval) + jaxpr, out_avals, consts = _create_jaxpr(init_avals) + carry_out_avals, ys_avals = out_avals.unpack() + + _check_carry_type('scan body', f, init_avals, carry_out_avals) - _check_carry_type('scan body', f, init, out_tree_children[0], carry_avals_out) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( @@ -329,9 +277,12 @@ def _create_jaxpr(init): if unroll < 0: raise ValueError("`unroll` must be a `bool` or a non-negative `int`.") + args_flat = (*init.vals, *xs.vals) + # If the body forwards an input carry to an output carry, that input is # read-only and can be moved to be a const. Doing so can lead to efficiency # wins, e.g. if the scan is inside a cond with a batched predicate. + num_ys = len(jaxpr.out_avals) - num_carry carry_fwd, ext_fwd = split_list(pe._jaxpr_forwarding(jaxpr.jaxpr), [num_carry]) move_to_const = [len(consts) + i == f for i, f in enumerate(carry_fwd)] if any(move_to_const): @@ -339,7 +290,7 @@ def _create_jaxpr(init): jaxpr, [not m for m in move_to_const] + [True] * num_ys) jaxpr = pe.move_binders_to_front( jaxpr, [False] * len(consts) + move_to_const + [False] * num_xs) - in_flat, new_consts = partition_list(move_to_const + [False] * num_xs, in_flat) + args_flat, new_consts = partition_list(move_to_const + [False] * num_xs, args_flat) consts = [*new_consts, *consts] num_carry -= len(new_consts) @@ -356,30 +307,67 @@ def _create_jaxpr(init): jaxpr = pe.prune_closed_jaxpr_outputs( jaxpr, [True] * num_carry + [i is None for i in ext_to_ext_fwd]) - out = scan_p.bind(*consts, *in_flat, + out = scan_p.bind(*consts, *args_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=num_carry, - linear=(False,) * (len(consts) + len(in_flat)), + linear=(False,) * (len(consts) + len(args_flat)), unroll=unroll, _split_transpose=_split_transpose) # Apply input to output forwarding that was computed above. carry_out, out = split_list(out, [num_carry]) out_ = iter(out) - out = [next(out_) if f is None else _maybe_put(in_flat[f]) for f in ext_to_ext_fwd] + out = [next(out_) if f is None else _maybe_put(args_flat[f]) for f in ext_to_ext_fwd] assert next(out_, None) is None out = [*carry_out, *out] if any(move_to_const): out = pe.merge_lists(move_to_const + [False] * num_ys, out, new_consts) - return tree_unflatten(out_tree, out) + return out_avals.update_from_list(out).unflatten() +def _infer_scan_length( + xs_flat: list[Any], xs_avals: list[AbstractValue], + length: int | None) -> int: + try: + lengths = [x.shape[0] for x in xs_flat] + except AttributeError as err: + msg = "scan got value with no leading axis to scan over: {}." + raise ValueError( + msg.format(', '.join(str(x) for x in xs_flat + if not hasattr(x, 'shape')))) from err + + if not all(a.sharding.spec[0] is None for a in xs_avals): + raise ValueError('0th dimension of all xs should be replicated. Got ' + f'{", ".join(str(a.sharding.spec) for a in xs_avals)}') + + if length is not None: + try: + return int(length) + except core.ConcretizationTypeError as err: + msg = ('The `length` argument to `scan` expects a concrete `int` value.' + ' For scan-like iteration with a dynamic length, use `while_loop`' + ' or `fori_loop`.') + raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] + if not all(length == l for l in lengths): + msg = ("scan got `length` argument of {} which disagrees with " + "leading axis sizes {}.") + raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) + else: + unique_lengths = set(lengths) + if len(unique_lengths) > 1: + msg = "scan got values with different leading axis sizes: {}." + raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) + elif len(unique_lengths) == 0: + msg = "scan got no values to scan over and `length` not provided." + raise ValueError(msg) + else: + return list(unique_lengths)[0] def _capitalize(s): # s.capitalize() converts s[1:] to lowercase which we don't want. return s[0].capitalize() + s[1:] -def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): +def _check_carry_type(name, body_fun, in_carry, out_carry): try: sig = inspect.signature(body_fun) except (ValueError, TypeError): @@ -391,18 +379,15 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): else: component = lambda p: (f'the input carry at path {keystr(p)}' if p else 'the input carry') - leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry) - paths, in_carry_flat = unzip2(leaves_and_paths) - in_avals = _map(core.get_aval, in_carry_flat) - if in_carry_tree != out_carry_tree: + if in_carry.tree != out_carry.tree: try: - out_carry = tree_unflatten(out_carry_tree, out_avals) + out_carry = out_carry.unflatten() except: out_carry = None if out_carry is None: - differences = (f'the input tree structure is:\n{in_carry_tree}\n' + - f'the output tree structure is:\n{out_carry_tree}\n') + differences = (f'the input tree structure is:\n{in_carry.tree}\n' + + f'the output tree structure is:\n{out_carry.tree}\n') else: diffs = [f'{component(path)} is a {thing1} but the corresponding component ' f'of the carry output is a {thing2}, so {explanation}' @@ -421,12 +406,14 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f"{differences}\n" "Revise the function so that the carry output has the same pytree " "structure as the carry input.") - if not all(_map(core.typematch, in_avals, out_avals)): + if not all(_map(core.typematch, in_carry, out_carry)): + # TODO(dougalm): add a way to get paths paths without roundtripping + paths, _ = unzip2(tree_leaves_with_path(in_carry.unflatten())) diffs = [f'{component(path)} has type {in_aval.str_short()}' ' but the corresponding output carry component has type ' f'{out_aval.str_short()}' f'{core.aval_mismatch_extra(in_aval, out_aval)}' - for path, in_aval, out_aval in zip(paths, in_avals, out_avals) + for path, in_aval, out_aval in zip(paths, in_carry, out_carry) if not core.typematch(in_aval, out_aval)] if len(diffs) == 0: @@ -441,7 +428,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'applying `jax.lax.pcast(..., {tuple(out_aval.vma - in_aval.vma)},' " to='varying')` to the initial carry value corresponding to" f' {component(path)}' - for path, in_aval, out_aval in zip(paths, in_avals, out_avals) + for path, in_aval, out_aval in zip(paths, in_carry, out_carry) if not core.typematch(in_aval, out_aval) and isinstance(in_aval, ShapedArray) and isinstance(out_aval, ShapedArray) and in_aval.vma != out_aval.vma and out_aval.vma - in_aval.vma] @@ -1704,42 +1691,44 @@ def while_loop(cond_fun, body_fun, init_val): # transformation on it), so we fall back to the primitive version. pass - def _create_jaxpr(init_val): - init_vals, in_tree = tree_flatten((init_val,)) - init_avals = tuple(_map(core.get_aval, init_vals)) - cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) - cond_jaxpr, cond_tree = pe.trace_to_jaxpr(cond_fun, in_tree, init_avals, cond_dbg) - cond_jaxpr, cond_consts = pe.separate_consts(cond_jaxpr) - body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) - body_jaxpr, body_tree = pe.trace_to_jaxpr(body_fun, in_tree, init_avals, body_dbg) - body_jaxpr, body_consts = pe.separate_consts(body_jaxpr) - if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: + def _create_jaxpr(init_avals): + args_avals = FlatTree.pack(((init_avals,), {})) + cond_jaxpr, cond_out_avals = pe.trace_to_jaxpr(cond_fun, args_avals, cond_dbg) + body_jaxpr, body_out_avals = pe.trace_to_jaxpr(body_fun, args_avals, body_dbg) + if not treedef_is_leaf(cond_out_avals.tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." - raise TypeError(msg.format(cond_tree)) + raise TypeError(msg.format(cond_out_avals.tree)) + pred_aval = cond_jaxpr.out_avals[0] if (not isinstance(pred_aval, ShapedArray) or ShapedArray(pred_aval.shape, pred_aval.dtype) != ShapedArray((), np.bool_)): msg = "cond_fun must return a boolean scalar, but got output type(s) {}." raise TypeError(msg.format(cond_jaxpr.out_avals)) - return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree + + return cond_jaxpr, body_jaxpr, body_out_avals + + cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) + body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) + init_val = FlatTree.flatten(init_val) + init_aval = init_val.map(core.get_aval) # The body input and output avals must match exactly. However, we want to account for # the case when init contains weakly-typed values (e.g. Python scalars), with avals that # may not match the output despite being compatible by virtue of their weak type. # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if # necessary, a second time with modified init values. - init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val) - new_init_vals, changed = _promote_weak_typed_inputs( - init_vals, init_avals, body_jaxpr.out_avals) - new_init_val, = tree_unflatten(in_tree, new_init_vals) - if changed: - init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val) - cond_jaxpr, cond_consts, body_consts, body_tree = rest - - in_tree_children = in_tree.children() - assert len(in_tree_children) == 1 - _check_carry_type('while_loop body', body_fun, new_init_val, body_tree, - body_jaxpr.out_avals) + cond_jaxpr, body_jaxpr, body_out_avals = _create_jaxpr(init_aval) + init_val, changed = init_val.map3( + _promote_weak_typed_input, + init_aval, body_out_avals).unzip2() + if any(changed): + init_aval = init_val.map(core.get_aval) + cond_jaxpr, body_jaxpr, body_out_avals = _create_jaxpr(init_aval) + + cond_jaxpr, cond_consts = pe.separate_consts(cond_jaxpr) + body_jaxpr, body_consts = pe.separate_consts(body_jaxpr) + _check_carry_type('while_loop body', body_fun, init_aval, body_out_avals) + if not all(not v.aval.has_qdd or v.initial_qdd == v.final_qdd for v in body_jaxpr.jaxpr.invars): raise TypeError("type-changing mutations not allowed in while_loop body") @@ -1760,6 +1749,7 @@ def _create_jaxpr(init_val): _, keep_cond_carry = split_list(keep_cond, [len(cond_consts)]) move_to_const = _map(operator.not_, keep_cond_carry) + init_vals = list(init_val) if any(move_to_const): cond_jaxpr = pe.close_jaxpr(cond_jaxpr_) body_jaxpr = pe.prune_closed_jaxpr_outputs( @@ -1776,7 +1766,7 @@ def _create_jaxpr(init_val): if any(move_to_const): outs = pe.merge_lists(move_to_const, outs, new_body_consts) - return tree_unflatten(body_tree, outs) + return body_out_avals.update_from_list(outs).unflatten() def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 5a8600f8dcc1..e38f51447c8c 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -30,8 +30,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.traceback_util import api_boundary -from jax._src.tree_util import (tree_flatten, treedef_children, tree_leaves, - tree_unflatten, treedef_tuple) +from jax._src.tree_util import tree_leaves, FlatTree from jax._src.util import split_list, safe_map import numpy as np @@ -92,23 +91,22 @@ def custom_root(f: Callable, The result of calling solve(f, initial_guess) with gradients defined via implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``. """ - guess_flat, in_args_tree = tree_flatten((initial_guess,)) - guess_avals = tuple(_map(core.get_aval, guess_flat)) + guess_flat = FlatTree.flatten(initial_guess) + guess_avals = guess_flat.map(core.get_aval) f_debug = api_util.debug_info("custom_root", f, (initial_guess,), {}) - f_jaxpr, out_tree = pe.trace_to_jaxpr( - f, in_args_tree, guess_avals, f_debug) + args_avals = FlatTree.pack(((guess_avals,),{})) + f_jaxpr, out_avals = pe.trace_to_jaxpr(f, args_avals, f_debug) f_jaxpr, f_consts = pe.separate_consts(f_jaxpr) - in_tree, = treedef_children(in_args_tree) - _check_tree("f", "initial_guess", out_tree, in_tree, False) + _check_tree("f", "initial_guess", out_avals.tree, guess_avals.tree, False) solve_debug = api_util.debug_info("custom_root solve", solve, (f, initial_guess), {}, static_argnums=(0,)) - solve_jaxpr, solution_tree = pe.trace_to_jaxpr( - partial(solve, f), in_args_tree, guess_avals, solve_debug) + solve_jaxpr, solution_avals = pe.trace_to_jaxpr( + partial(solve, f), args_avals, solve_debug) solve_jaxpr, solve_consts = pe.separate_consts(solve_jaxpr) - _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) + _check_tree("solve", "initial_guess", solution_avals.tree, guess_flat.tree, has_aux) def linearize_and_solve(x, b): unchecked_zeros, f_jvp = api.linearize(f, x) @@ -116,19 +114,21 @@ def linearize_and_solve(x, b): linearize_and_solve_dbg = api_util.debug_info("custom_root tangent_solve", tangent_solve, (initial_guess, initial_guess), {}) - l_and_s_jaxpr, out_tree = pe.trace_to_jaxpr( - linearize_and_solve, treedef_tuple((in_tree,) * 2), guess_avals * 2, - linearize_and_solve_dbg) + + + linearize_and_solve_avals = FlatTree.pack(((guess_avals, guess_avals), {})) + l_and_s_jaxpr, out_avals = pe.trace_to_jaxpr( + linearize_and_solve, linearize_and_solve_avals, linearize_and_solve_dbg) l_and_s_jaxpr, l_and_s_consts = pe.separate_consts(l_and_s_jaxpr) - _check_tree("tangent_solve", "x", out_tree, in_tree, False) + _check_tree("tangent_solve", "x", out_avals.tree, guess_flat.tree, False) all_consts = [f_consts, solve_consts, l_and_s_consts] const_lengths = _RootTuple(*_map(len, all_consts)) jaxprs = _RootTuple(f_jaxpr, solve_jaxpr, l_and_s_jaxpr) solution_flat = _custom_root( - const_lengths, jaxprs, *(_flatten(all_consts) + guess_flat)) - return tree_unflatten(solution_tree, solution_flat) + const_lengths, jaxprs, *_flatten(all_consts), *guess_flat) + return solution_avals.update_from_list(solution_flat).unflatten() @partial(custom_derivatives.custom_jvp, nondiff_argnums=(0, 1)) @@ -198,8 +198,8 @@ def _flatten(args): def _check_shapes(func_name, expected_name, actual, expected): - actual_shapes = _map(np.shape, tree_leaves(actual)) - expected_shapes = _map(np.shape, tree_leaves(expected)) + actual_shapes = _map(np.shape, actual) + expected_shapes = _map(np.shape, expected) if actual_shapes != expected_shapes: raise ValueError( f"{func_name}() output shapes must match {expected_name}, " @@ -250,20 +250,19 @@ def custom_linear_solve( if transpose_solve is None and symmetric: transpose_solve = solve - b_flat, in_args_tree = tree_flatten((b,)) - b_avals = tuple(_map(core.get_aval, b_flat)) - - tree, = treedef_children(in_args_tree) + b_flat = FlatTree.flatten(b) + b_avals = b_flat.map(core.get_aval) + tree = b_flat.tree def _shape_checked(fun, name, has_aux): def f(x): y = fun(x) - _check_shapes(name, "b", y, b_flat) + _check_shapes(name, "b", tree_leaves(y), b_flat) return y def f_aux(x): y, aux = fun(x) - _check_shapes(name, "b", y, b_flat) + _check_shapes(name, "b", tree_leaves(y), b_flat) return y, aux return f_aux if has_aux else f @@ -271,20 +270,21 @@ def f_aux(x): matvec_debug = api_util.debug_info("custom_linear_solve", matvec, (b,), {}) # no auxiliary data assumed for matvec - matvec_jaxpr, out_tree = pe.trace_to_jaxpr( - _shape_checked(matvec, "matvec", False), in_args_tree, b_avals, + args_avals = FlatTree.pack(((b_avals,),{})) + matvec_jaxpr, out_avals = pe.trace_to_jaxpr( + _shape_checked(matvec, "matvec", False), args_avals, matvec_debug) matvec_jaxpr, matvec_consts = pe.separate_consts(matvec_jaxpr) - _check_tree("matvec", "b", out_tree, tree, False) + _check_tree("matvec", "b", out_avals.tree, tree, False) solve_debug = api_util.debug_info("custom_linear_solve solve", solve, (matvec, b), {}, static_argnums=(0,)) - solve_jaxpr, out_tree = pe.trace_to_jaxpr( - _shape_checked(partial(solve, matvec), "solve", has_aux), in_args_tree, b_avals, + solve_jaxpr, out_avals = pe.trace_to_jaxpr( + _shape_checked(partial(solve, matvec), "solve", has_aux), args_avals, solve_debug) solve_jaxpr, solve_consts = pe.separate_consts(solve_jaxpr) - _check_tree("solve", "b", out_tree, tree, has_aux) + _check_tree("solve", "b", out_avals.tree, tree, has_aux) if transpose_solve is None: vecmat_jaxpr = tr_solve_jaxpr = None @@ -299,27 +299,27 @@ def f_aux(x): vecmat_consts = matvec_consts else: vecmat = _transpose_one_output(matvec, b) - vecmat_jaxpr, out_tree = pe.trace_to_jaxpr( - vecmat, in_args_tree, b_avals, transpose_solve_debug) + vecmat_jaxpr, out_avals = pe.trace_to_jaxpr( + vecmat, args_avals, transpose_solve_debug) vecmat_jaxpr, vecmat_consts = pe.separate_consts(vecmat_jaxpr) - assert out_tree == tree + assert out_avals.tree == tree - tr_solve_jaxpr, out_tree = pe.trace_to_jaxpr( + tr_solve_jaxpr, out_avals = pe.trace_to_jaxpr( _shape_checked(partial(transpose_solve, vecmat), "transpose_solve", has_aux), - in_args_tree, b_avals, transpose_solve_debug) + args_avals, transpose_solve_debug) tr_solve_jaxpr, tr_solve_consts = pe.separate_consts(tr_solve_jaxpr) - _check_tree("transpose_solve", "b", out_tree, tree, has_aux) + _check_tree("transpose_solve", "b", out_avals.tree, tree, has_aux) all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts] const_lengths = _LinearSolveTuple(*_map(len, all_consts)) jaxprs = _LinearSolveTuple( matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr) - args = _flatten(all_consts) + b_flat + args = _flatten(all_consts) + list(b_flat) args = core.standard_insert_pvary(*args) out_flat = linear_solve_p.bind(*args, const_lengths=const_lengths, jaxprs=jaxprs) - return tree_unflatten(out_tree, out_flat) + return out_avals.update_from_list(out_flat).unflatten() def _linear_solve_abstract_eval(*args, const_lengths, jaxprs): diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 7ddbb3cf55ef..f6501fe0c4a0 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -34,10 +34,15 @@ traceback_util.register_exclusion(__file__) T = TypeVar("T") +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") Typ = TypeVar("Typ", bound=type[Any]) H = TypeVar("H", bound=Hashable) Leaf = Any +PyTree = Any PyTreeDef = pytree.PyTreeDef default_registry = pytree.default_registry() @@ -1336,3 +1341,114 @@ def _prefix_error( f"{prefix_tree_keys} and {full_tree_keys}") for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children): yield from _prefix_error((*key_path, k), t1, t2) + +# === flat tree === + +class FlatTree: + """A FlatTree stores a treedef and a flat list of values. It's meant to be + isomorphic to the corresponding pytree but we can map over it more easily. + Compared to `tree_map`, FlatTree.map has these benefits: + 1. It doesn't touch user flatten/unflatten code (which shouldn't have side + effects but sometimes does in practice). + 2. It can be faster, because it skips the recursive traversal. + 3. It actually obeys the functor rules. For example, + `flat_tree.map(lambda x: (f(x), g(x))).unzip2()[0]` will give + the same result as `flat_tree.map(f)`, whereas in the `tree_map` version + the tuple-returning function would change the tree structure and `unzip` + wouldn't be able to recover it. + """ + def __init__(self, vals:Sequence[T], treedef:PyTreeDef): + assert isinstance(treedef, pytree.PyTreeDef) + self.tree = treedef + self.vals = list(vals) + + def map(self, f:Callable[[T1], T2]) -> FlatTree[T2]: + ans_vals = [] + for x in self.vals: + ans_vals.append(f(x)) + return FlatTree(ans_vals, self.tree) + + def map2( + self:FlatTree[T1], f:Callable[[T1, T2], T3], + t2:FlatTree[T2]) -> FlatTree[T3]: + + n = len(self) + assert len(t2) == n + ans_vals = [] + for x1, x2 in zip(self.vals, t2.vals): + ans_vals.append(f(x1, x2)) + return FlatTree(ans_vals, self.tree) + + def map3( + self:FlatTree[T1], f:Callable[[T1, T2, T3], T4], + t2:FlatTree[T2], t3:FlatTree[T3]) -> FlatTree[T4]: + n = len(self) + assert len(t2) == n and len(t3) == n + ans_vals = [] + for x1, x2, x3 in zip(self.vals, t2.vals, t3.vals): + ans_vals.append(f(x1, x2, x3)) + return FlatTree(ans_vals, self.tree) + + def zip(self, t2:FlatTree[T2]) -> FlatTree[tuple[T1, T2]]: + assert False + + def unzip2(self:FlatTree[tuple[T1, T2]]) -> tuple[FlatTree[T1], FlatTree[T2]]: + ys = [] + zs = [] + for y, z in self.vals: + ys.append(y) + zs.append(z) + return FlatTree(ys, self.tree), FlatTree(zs, self.tree) + + # TODO: add map3, zip3, unzip3 etc. as needed + + @staticmethod + def pack(tree): + # We could generalize this to arbitrary pytrees of FlatTree but tuples/dicts + # are sufficient for now. + if isinstance(tree, FlatTree): + return tree + elif isinstance(tree, tuple): + vals = [] + trees = [] + for child_tree in tree: + child = FlatTree.pack(child_tree) + vals.extend(child.vals) + trees.append(child.tree) + return FlatTree(vals, treedef_tuple(trees)) + elif isinstance(tree, dict): + # only empty case handled for now + if tree == {}: + return FlatTree.flatten({}) + else: + assert False + else: + assert False + + def unpack(self:FlatTree[tuple]) -> tuple[FlatTree]: + # TODO: this is O(N) not O(1) (with N as the number of leaves). If it + # becomes a problem we can fix it with a fancier data tree. + trees = treedef_children(self.tree) + children = [] + offset = 0 + for tree in trees: + new_offset = offset + tree.num_leaves + children.append(FlatTree(self.vals[offset:new_offset], tree)) + offset = new_offset + return tuple(children) + + @staticmethod + def flatten(tree: PyTree) -> FlatTree: + return FlatTree(*tree_flatten(tree)) + + def unflatten(self) -> PyTree: + return tree_unflatten(self.tree, self.vals) + + def update_from_list(self, new_vals:list[T1]) -> FlatTree[T1]: + return FlatTree(new_vals, self.tree) + + def __len__(self): + return self.tree.num_leaves + + def __iter__(self): + return self.vals.__iter__() From a8389508d8cc6d5f13a09158d7a3aae7745644ba Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 11 Dec 2025 14:50:31 -0800 Subject: [PATCH 171/315] Add device dict support to Pallas TPU interpret mode Reverts changelist 843160736 PiperOrigin-RevId: 843384117 --- .../mosaic/interpret/interpret_pallas_call.py | 63 +++++++++++++++++-- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py index e1094a43c19d..8921da01d971 100644 --- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -1184,7 +1184,55 @@ def _compute_transformed_shape_and_dtype(shape, dtype, transforms): dtype = transform.transform_dtype(dtype) return shape, dtype -def _device_coords_to_logical_id(device_coords, axis_sizes): +# TODO(sharadmv): De-dup this w/ the impl in primitives.py. +def _device_id_dict_to_mesh(device_id_dict, axis_sizes, axis_indices): + physical_axis_dict = {} + axis_names = axis_sizes.keys() + for axis, idx in device_id_dict.items(): + if isinstance(axis, tuple) and any(a in axis_names for a in axis): + if not all(a in axis_names for a in axis): + raise NotImplementedError( + f"{axis} mixes JAX mesh and Pallas mesh grid axes" + ) + axes_dimensions = [axis_sizes[name] for name in axis] + for axis_index, axis_name in enumerate(axis): + axis_size = axis_sizes[axis_name] + inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :]) + minor_divisor = inner_mesh_size + + # Fast path for power of 2s + if inner_mesh_size & (inner_mesh_size - 1) == 0: + shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1 + partial_device_idx = idx >> shift_len + else: + partial_device_idx = idx // minor_divisor + + if axis_size & (axis_size - 1) == 0: + device_idx = partial_device_idx & (axis_size - 1) + else: + device_idx = partial_device_idx % axis_size + physical_axis_dict[axis_name] = device_idx + else: + physical_axis_dict[axis] = idx + device_id = [] + for axis in axis_names: + if axis in physical_axis_dict: + device_id.append(physical_axis_dict[axis]) + else: + device_id.append(axis_indices[axis]) + non_mesh_axes = { + k: v + for k, v in physical_axis_dict.items() + if k not in axis_names + } + return tuple(device_id), non_mesh_axes + +def _device_coords_to_logical_id(device_coords, axis_sizes, axis_indices): + if isinstance(device_coords, dict): + device_coords, non_mesh_axes = _device_id_dict_to_mesh( + device_coords, axis_sizes, axis_indices) + if non_mesh_axes: + raise NotImplementedError(non_mesh_axes) if not isinstance(device_coords, tuple): device_coords = (device_coords,) assert len(device_coords) == len(axis_sizes) @@ -1194,11 +1242,12 @@ def _device_coords_to_logical_id(device_coords, axis_sizes): ret += device_coords[i] * math.prod(sizes[i+1:]) return ret -def _device_id_to_logical(device_id, device_id_type, axis_sizes): +def _device_id_to_logical(device_id, device_id_type, axis_sizes, + axis_indices): if device_id is None: return None if device_id_type == primitives.DeviceIdType.MESH: - return _device_coords_to_logical_id(device_id, axis_sizes) + return _device_coords_to_logical_id(device_id, axis_sizes, axis_indices) elif device_id_type == primitives.DeviceIdType.LOGICAL: return device_id else: @@ -1515,7 +1564,8 @@ def f(*args, jaxpr): target_device_id, ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) target_device_id = _device_id_to_logical( - target_device_id, eqn.params['device_id_type'], axis_sizes) + target_device_id, eqn.params['device_id_type'], axis_sizes, + axis_indices) (orig_src_ref, _, orig_dst_ref, *_ ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) src_memory_space = getattr(orig_src_ref.aval, 'memory_space', None) @@ -1580,7 +1630,8 @@ def f(*args, jaxpr): sem, sem_transforms, inc, target_device_id, core_index = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) target_device_id = _device_id_to_logical( - target_device_id, eqn.params['device_id_type'], axis_sizes) + target_device_id, eqn.params['device_id_type'], axis_sizes, + axis_indices) callback.io_callback( semaphore_signal, (), @@ -1984,7 +2035,7 @@ def interpret_pallas_call( jnp.multiply, axis_sizes.values(), jnp.int32(1)) axis_indices = {k: lax.axis_index(k) for k in axis_sizes.keys()} device_id = _device_coords_to_logical_id( - tuple(axis_indices.values()), axis_sizes) + tuple(axis_indices.values()), axis_sizes, axis_indices) callback.io_callback( functools.partial( _initialize_shared_memory, interpret_params=interpret_params From cc17df39c680feddef7f12d5ac7e0cb18d39e8f1 Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Thu, 11 Dec 2025 17:28:05 -0800 Subject: [PATCH 172/315] Simplify adding call location to custom options. Use xla::ifrt::AttributeMap::Set() to add the call location string to options.custom_options, instead of rebuilding the entire map. PiperOrigin-RevId: 843437966 --- jaxlib/BUILD | 1 + jaxlib/call_location.cc | 21 ++++++--------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index c50200aa6da0..ebcd42f05304 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -516,6 +516,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@nanobind", diff --git a/jaxlib/call_location.cc b/jaxlib/call_location.cc index b556f5e6ee62..8a5558d361ff 100644 --- a/jaxlib/call_location.cc +++ b/jaxlib/call_location.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/base/no_destructor.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" @@ -31,10 +32,10 @@ limitations under the License. #include "nanobind/stl/optional.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/traceback.h" #include "jaxlib/py_user_context.h" -#include "xla/python/ifrt/executable.h" +#include "jaxlib/traceback.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/user_context.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" @@ -124,19 +125,9 @@ void PopulateCallLocation(xla::ifrt::ExecuteOptions& options, } if (!call_location_str.empty()) { - // Simplify this to use AttributeMap::Set(). - xla::ifrt::AttributeMap::Map attrs_map; - if (options.custom_options.has_value()) { - options.custom_options->ForEach( - [&](const std::string& key, - const xla::ifrt::AttributeMap::Value& value) { - attrs_map.insert({key, value}); - }); - } - attrs_map.insert( - {std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation), - xla::ifrt::AttributeMap::StringValue(std::move(call_location_str))}); - options.custom_options.emplace(std::move(attrs_map)); + CHECK_OK(options.custom_options->Set( + std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation), + std::move(call_location_str))); } } From 6312a470f4d77829c44b4e5ec4a0305b4c2e8df2 Mon Sep 17 00:00:00 2001 From: partev Date: Mon, 8 Dec 2025 07:29:49 -0500 Subject: [PATCH 173/315] Fix links in contributing.md to use HTTPS Updated links in contributing guidelines to use HTTPS. add .git to github repo --- docs/contributing.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/contributing.md b/docs/contributing.md index 0c85f83d8b80..40334bb9599a 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -6,8 +6,8 @@ Everyone can contribute to JAX, and we value everyone's contributions. There are ways to contribute, including: - Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) -- Improving or expanding JAX's [documentation](http://docs.jax.dev/) -- Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) +- Improving or expanding JAX's [documentation](https://docs.jax.dev) +- Contributing to JAX's [code-base](https://github.com/jax-ml/jax) - Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). @@ -49,7 +49,7 @@ Follow these steps to contribute code: For more information, see the {ref}`pr-checklist` below. 2. Fork the JAX repository by clicking the **Fork** button on the - [repository page](http://www.github.com/jax-ml/jax). This creates + [repository page](https://github.com/jax-ml/jax). This creates a copy of the JAX repository in your own account. 3. Install Python >= 3.11 locally in order to run tests. @@ -68,7 +68,7 @@ Follow these steps to contribute code: changes. ```bash - git remote add upstream https://www.github.com/jax-ml/jax + git remote add upstream https://github.com/jax-ml/jax.git ``` 6. Create a branch where you will develop from: From 356ad2bba9b8cbe937006b5ce5cce31fbd3e08d4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 11 Dec 2025 18:58:57 -0800 Subject: [PATCH 174/315] [test] add test of advanced indexing with empty lists --- tests/lax_numpy_indexing_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 5a1ba1ac870f..742f5b90c5d3 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -237,6 +237,11 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None): IndexSpec(shape=(3,), indexer=np.array([0, 1, 0]), out_shape=(3,)), IndexSpec(shape=(3, 4, 5), indexer=np.array([ 0, -1]), out_shape=(2, 4, 5)), ]), + ("TupleOfEmptyList", [ + IndexSpec(shape=(3, 4), indexer=([],), out_shape=(0, 4)), + IndexSpec(shape=(3, 4), indexer=([], 0), out_shape=(0,)), + IndexSpec(shape=(3, 4), indexer=([], []), out_shape=(0,)), + ]), ("TupleOfListsOfPythonInts", [ IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]]), From ae9a27e4e96578caeac9f550cc7a5f748e62223a Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 12 Dec 2025 00:04:45 -0800 Subject: [PATCH 175/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/6872ce865c8b5880bce6a7f4d4b2e6fbce0704d2 PiperOrigin-RevId: 843551536 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index adfbfd05131d..22ed0fce0d6b 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "5d63679b85e9808398a1bb725365dda4b23594e4" -XLA_SHA256 = "e92ba838cd10126e7580435ff089f22ed6d2a7afc53ae6e5308e39ef9184e847" +XLA_COMMIT = "6872ce865c8b5880bce6a7f4d4b2e6fbce0704d2" +XLA_SHA256 = "d4c4dd44aed887092306f4a76eb85cd35f59d307d8e9f6d737901a9ac3e805d2" From 6cd2bc67b47530f77a2725b75b2c7aac240ce152 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 12 Dec 2025 03:24:46 -0800 Subject: [PATCH 176/315] [Mosaic GPU] Add limited support for reshapes of tiled layouts Reshapes are no-ops for as long as they don't affect any tiled dimension, or if the majormost tiled dimension is divisible by the tiling in both source and target shape. PiperOrigin-RevId: 843612858 --- jax/_src/pallas/mosaic_gpu/lowering.py | 13 ++++++++ .../mosaic/gpu/fragmented_array.py | 30 ++++++++++++++++--- tests/pallas/mosaic_gpu_test.py | 16 ++++++++++ 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 272a9d249eeb..6339cef1b25f 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2399,6 +2399,19 @@ def _log_lowering_rule(ctx: LoweringRuleContext, x, accuracy): return math_dialect.log(_ensure_ir_value(x, x_aval.dtype), fastmath=fastmath) +@register_lowering_rule(lax.reshape_p, mgpu.LoweringSemantics.Lane) +def _reshape_lowering_rule( + ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding +): + if dimensions is not None: + raise NotImplementedError("Not implemented: dimensions") + if sharding is not None: + raise NotImplementedError("Not implemented: sharding") + [x_aval] = ctx.avals_in + x = _ensure_fa(x, x_aval.dtype) + return x.reshape(new_sizes) + + def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs): [x_aval] = ctx.avals_in match x.layout: diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index f2c58a9b354f..bb8a5576558e 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2471,13 +2471,35 @@ def reshape(self, shape) -> FragmentedArray: match self.layout: case WGSplatFragLayout() | WGStridedFragLayout(): new_layout = dataclasses.replace(self.layout, shape=shape) + return FragmentedArray( + _registers=self.registers, + _layout=new_layout, + _is_signed=self.is_signed, + ) + case TiledLayout(): + base_tile_shape = self.layout.base_tile_shape + assert base_tile_shape + old_shape_suffix = self.shape[-len(base_tile_shape):] + new_shape_suffix = shape[-len(base_tile_shape):] + # We already know that old_shape_suffix[0] is divisible by + # base_tile_shape[0]. + if ( + old_shape_suffix[1:] != new_shape_suffix[1:] + or new_shape_suffix[0] % base_tile_shape[0] + ): + raise ValueError( + f"Can't reshape {self.shape} to {shape} with a tiled layout with" + f" base tile of {base_tile_shape}" + ) + new_registers_shape = self.layout.registers_shape(shape) + return FragmentedArray( + _registers=self.registers.reshape(new_registers_shape), + _layout=self.layout, + _is_signed=self.is_signed, + ) case _: raise NotImplementedError(self.layout) - return FragmentedArray( - _registers=self.registers, _layout=new_layout, _is_signed=self.is_signed - ) - def broadcast_minor(self, n) -> FragmentedArray: if len(self.shape) != 1: raise ValueError("Broadcast minor is only supported for 1D arrays") diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6e906c75839a..49e9bea41837 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -322,6 +322,21 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) + def test_reshape_tiled(self): + self.skip_if_wg_semantics() + shape1, shape2 = (6 * 64, 8), (2, 3, 64, 8) + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), + ) + def kernel(x_ref, out_ref): + y = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False).reshape(shape2) + out_ref[...] = y + + x = jnp.arange(math.prod(shape1)).reshape(shape1).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) + def test_add_xy_indexed(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) @@ -2819,6 +2834,7 @@ def test_missing_primitive_lowerings_are_tracked(self): pallas_primitives.semaphore_read_p, pallas_primitives.delay_p, checkify.check_p, + lax.reshape_p, } self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) From a5a333fa6f0c2b9243154d8d01b29df9e2e0c903 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Fri, 12 Dec 2025 03:35:14 -0800 Subject: [PATCH 177/315] [MGPU] Split test_broadcast_major_dim into two tests PiperOrigin-RevId: 843615845 --- tests/mosaic/gpu_test.py | 60 ++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 33 deletions(-) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 5218160a6a77..33f4fed610b8 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3796,46 +3796,40 @@ def kernel(ctx, *args): )(inp) np.testing.assert_array_equal(result, inp) - @parameterized.product( - mns=((128, 128), (128, 64), (64, 128)), - layout=(mtu.RegisterLayout.WG_STRIDED, mtu.RegisterLayout.WGMMA), + @parameterized.parameters( + (128, 128), (64, 128), (64, 256) ) - def test_broadcast_major(self, mns, layout): - m, n = mns + def test_broadcast_in_dim_major_strided(self, m, n): + dtype = jnp.float16 + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_strided( + gmem_input, vec_size=1 + ) + t.broadcast_in_dim((m, n), (1,), + mgpu.WGStridedFragLayout(shape=(m, n), vec_size=1), + ).store_untiled(gmem_output, optimized=False) - if n < 128 and layout == mtu.RegisterLayout.WG_STRIDED: - self.skipTest(f"{n=} < 128 not supported for {layout=}") + inp = self.prng.uniform(-1, 1, (n,)).astype(dtype) + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), out_shape, inp + )(inp) + out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) + np.testing.assert_array_equal(result, out_ref) + @parameterized.parameters( + (128, 128), (128, 64), (64, 128) + ) + def test_broadcast_in_dim_major_wgmma(self, m, n): dtype = jnp.float16 - load_layout = ( - layout.to_mgpu((n,), dtype) - if layout == mtu.RegisterLayout.WG_STRIDED - else mgpu.WGMMA_COL_LAYOUT - ) - broadcast_layout = ( - mgpu.WGStridedFragLayout((m, n), load_layout.vec_size) - if layout == mtu.RegisterLayout.WG_STRIDED - else layout.to_mgpu((m, n), dtype) - ) - - def load(gmem_input): - match layout: - case mtu.RegisterLayout.WG_STRIDED: - return mgpu.FragmentedArray.load_strided( - gmem_input, vec_size=load_layout.vec_size - ) - case mtu.RegisterLayout.WGMMA: - return mgpu.FragmentedArray.load_untiled( - gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False - ) - case _: - raise NotImplementedError(f"Unsupported layout: {layout}") def kernel(ctx, gmem_input, gmem_output, _): - t = load(gmem_input) - t.broadcast_in_dim((m, n), (1,), broadcast_layout).store_untiled( - gmem_output, optimized=False + t = mgpu.FragmentedArray.load_untiled( + gmem_input, layout=mgpu.WGMMA_COL_LAYOUT, optimized=False ) + t.broadcast_in_dim( + (m, n), (1,), mgpu.WGMMA_LAYOUT + ).store_untiled(gmem_output, optimized=False) inp = self.prng.uniform(-1, 1, (n,)).astype(dtype) out_shape = jax.ShapeDtypeStruct((m, n), dtype) From 8572231e45768c6b5e19f2afa28a6aa0f6f52de3 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 12 Dec 2025 04:06:45 -0800 Subject: [PATCH 178/315] [Mosaic GPU] Add support for indexing untiled dims PiperOrigin-RevId: 843624037 --- jax/_src/pallas/mosaic_gpu/lowering.py | 10 ++++++++-- jax/experimental/mosaic/gpu/fragmented_array.py | 4 ++-- tests/pallas/mosaic_gpu_test.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 6339cef1b25f..ef20ee735035 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2408,8 +2408,14 @@ def _reshape_lowering_rule( if sharding is not None: raise NotImplementedError("Not implemented: sharding") [x_aval] = ctx.avals_in - x = _ensure_fa(x, x_aval.dtype) - return x.reshape(new_sizes) + return _ensure_fa(x, x_aval.dtype).reshape(new_sizes) + + +@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Lane) +def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): + [x_aval] = ctx.avals_in + [y_aval] = ctx.avals_out + return _ensure_fa(x, x_aval.dtype).reshape(y_aval.shape) def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs): diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index bb8a5576558e..a81784f1e6d5 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1729,8 +1729,8 @@ def __getitem__(self, idx) -> FragmentedArray: if any(is_squeezed): raise NotImplementedError("Integer indexing not implemented (only slicing allowed)") base_tile_shape = self.layout.base_tile_shape - if len(base_tile_shape) != len(self.shape): - raise NotImplementedError("Tiling has different rank than array") + if untiled_rank := len(self.shape) - len(base_tile_shape): + base_tile_shape = (1,) * untiled_rank + base_tile_shape if any(b % t for b, t in zip(base_idx, base_tile_shape, strict=True)): raise ValueError( "Base indices of array slices must be aligned to the beginning of a" diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 49e9bea41837..158754fc697a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -337,6 +337,21 @@ def kernel(x_ref, out_ref): x = jnp.arange(math.prod(shape1)).reshape(shape1).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) + def test_slice_untiled_dim(self): + self.skip_if_wg_semantics() + shape = (2, 3, 64, 8) + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape[2:], jnp.float32), + ) + def kernel(x_ref, out_ref): + y = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False)[1, 1] + out_ref[...] = y + + x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x[1, 1]) + def test_add_xy_indexed(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) @@ -2835,6 +2850,7 @@ def test_missing_primitive_lowerings_are_tracked(self): pallas_primitives.delay_p, checkify.check_p, lax.reshape_p, + lax.squeeze_p, } self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) From 8ab5e07551f7f84ef7c7d11c3d939830ddebcad9 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 12 Dec 2025 07:44:51 -0800 Subject: [PATCH 179/315] [test] fix old TODO related to scipy v1.17 --- tests/linalg_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index e20242136418..8c5f7b200108 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -85,9 +85,7 @@ def _random_invertible(rng, shape, dtype): def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: """scipy.linalg.toeplitz with v1.17+ batching semantics.""" - # TODO(dfm,jakevdp): Remove dev check after upstream PR is merged: - # https://github.com/scipy/scipy/issues/21466. - if scipy_version >= (1, 17, 0) and "dev0" not in scipy.version.version: + if scipy_version >= (1, 17, 0): return scipy.linalg.toeplitz(c, r) elif r is None: c = np.atleast_1d(c) From dfe1de0443050c0d3d9182c42ef8f0c24ed5a897 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Fri, 12 Dec 2025 08:04:18 -0800 Subject: [PATCH 180/315] [Mosaic] NFC: Move ops in tpu.td to tpu_ops.td. so that they can be reused. Note that we couldn't just include original tpu.td in downstream td files because of redefinitions of enums. PiperOrigin-RevId: 843693198 --- jaxlib/mosaic/BUILD | 70 +- jaxlib/mosaic/dialect/tpu/tpu.td | 1504 +------------------------ jaxlib/mosaic/dialect/tpu/tpu_ops.td | 1516 ++++++++++++++++++++++++++ jaxlib/mosaic/python/BUILD | 2 +- jaxlib/mosaic/python/tpu_python.td | 2 +- 5 files changed, 1582 insertions(+), 1512 deletions(-) create mode 100644 jaxlib/mosaic/dialect/tpu/tpu_ops.td diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 5c1c8b58f1da..cced087e215b 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -148,16 +148,50 @@ gentbl_cc_library( name = "tpu_inc_gen", # compatible with libtpu tbl_outs = { - "dialect/tpu/tpu_ops.h.inc": ["-gen-op-decls"], - "dialect/tpu/tpu_ops.cc.inc": ["-gen-op-defs"], - "dialect/tpu/tpu_dialect.h.inc": ["-gen-dialect-decls"], - "dialect/tpu/tpu_dialect.cc.inc": ["-gen-dialect-defs"], - "dialect/tpu/tpu_enums.h.inc": ["-gen-enum-decls"], - "dialect/tpu/tpu_enums.cc.inc": ["-gen-enum-defs"], - "dialect/tpu/tpu_attr_defs.h.inc": ["-gen-attrdef-decls"], - "dialect/tpu/tpu_attr_defs.cc.inc": ["-gen-attrdef-defs"], - "dialect/tpu/tpu_type_defs.h.inc": ["-gen-typedef-decls"], - "dialect/tpu/tpu_type_defs.cc.inc": ["-gen-typedef-defs"], + "dialect/tpu/tpu_ops.h.inc": [ + "-gen-op-decls", + "-dialect=tpu", + ], + "dialect/tpu/tpu_ops.cc.inc": [ + "-gen-op-defs", + "-dialect=tpu", + ], + "dialect/tpu/tpu_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=tpu", + ], + "dialect/tpu/tpu_dialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=tpu", + ], + "dialect/tpu/tpu_enums.h.inc": [ + "-gen-enum-decls", + "-dialect=tpu", + ], + "dialect/tpu/tpu_enums.cc.inc": [ + "-gen-enum-defs", + "-dialect=tpu", + ], + "dialect/tpu/tpu_attr_defs.h.inc": [ + "-gen-attrdef-decls", + "-dialect=tpu", + "--attrdefs-dialect=tpu", + ], + "dialect/tpu/tpu_attr_defs.cc.inc": [ + "-gen-attrdef-defs", + "-dialect=tpu", + "--attrdefs-dialect=tpu", + ], + "dialect/tpu/tpu_type_defs.h.inc": [ + "-gen-typedef-decls", + "-dialect=tpu", + "--typedefs-dialect=tpu", + ], + "dialect/tpu/tpu_type_defs.cc.inc": [ + "-gen-typedef-defs", + "-dialect=tpu", + "--typedefs-dialect=tpu", + ], "dialect/tpu/tpu_passes.h.inc": [ "-gen-pass-decls", "-name=TPU", @@ -172,8 +206,8 @@ gentbl_cc_library( ], }, tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "dialect/tpu/tpu.td", - deps = [":tpu_td_files"], + td_file = "dialect/tpu/tpu_ops.td", + deps = [":tpu_ops_td_files"], ) td_library( @@ -184,6 +218,18 @@ td_library( # compatible with libtpu deps = [ "@llvm-project//mlir:BuiltinDialectTdFiles", + ], +) + +td_library( + name = "tpu_ops_td_files", + srcs = [ + "dialect/tpu/tpu_ops.td", + ], + # compatible with libtpu + deps = [ + ":tpu_td_files", + "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index f7383756e5d6..b19614bdb3bd 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -1,4 +1,4 @@ -/* Copyright 2023 The JAX Authors. +/* Copyright 2025 The JAX Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TPU_ATTRS -#define TPU_ATTRS +#ifndef TPU_BASE +#define TPU_BASE -include "mlir/IR/OpBase.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/BuiltinTypeInterfaces.td" -include "mlir/IR/EnumAttr.td" -include "mlir/Pass/PassBase.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/InferTypeOpInterface.td" def TPU_Dialect : Dialect { let name = "tpu"; @@ -45,870 +38,6 @@ class TPU_Attr traits = []> let mnemonic = mnemonic_; } -// TODO(b/369418606): Find out the way to verify vreg size. -def TPU_Vreg : Type; - -class TPU_Type traits = [], - string baseCppType = "::mlir::Type"> - : TypeDef { - let mnemonic = mnemonic_; -} - -def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [ - I32EnumAttrCase<"kTc", 0, "tc">, - I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">, - I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_CoreTypeEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_PipelineMode : I32EnumAttr<"PipelineMode", "Pipeline mode", [ - I32EnumAttrCase<"kSynchronous", 1, "synchronous">, - I32EnumAttrCase<"kDoubleBuffered", 2, "double_buffered"> - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_PipelineModeEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>; -def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>; -def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>; - -def TPU_Float8EXMYType : TPU_Type<"Float8EXMY", "float8_exmy", - [DeclareTypeInterfaceMethods]> { - let summary = "EXMY type in a 8 bit container"; - let description = [{ - EXMY type in a 8 bit container. Meaningful bits are aligned to LSB, and - bits higher than the underlying exmy type in the container are considered - as ignored. See https://arxiv.org/abs/2405.13938 for more details. - }]; - - let parameters = (ins - TypeParameter<"::mlir::FloatType", "Underlying EXMY type">:$underlying_type - ); - - let assemblyFormat = [{ - `<` $underlying_type `>` - }]; -} - -def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [ - I32EnumAttrCase<"parallel", 0>, - I32EnumAttrCase<"arbitrary", 1>, - I32EnumAttrCase<"core_parallel", 2>, - I32EnumAttrCase<"subcore_parallel", 3> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_DimensionSemanticsEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -// All indices/sizes are in element-space. -// Note that the implementation will require statically provable tile alignment. -def TPU_ElementWindowAttr : TPU_Attr<"ElementWindow", "element_window"> { - // Including low padding, to avoid backwards-incompatible changes once we add it. - let parameters = (ins - ArrayRefParameter<"int64_t", "">:$pad_low, - ArrayRefParameter<"int64_t", "">:$pad_high - ); - let assemblyFormat = "`<` `[` $pad_low `]` `,` `[` $pad_high `]` `>`"; -} - -def TPU_ContractPrecision : I32EnumAttr<"ContractPrecision", "Contraction precision", [ - I32EnumAttrCase<"kBF16", 0, "bf16">, - I32EnumAttrCase<"kFP32", 1, "fp32"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_ContractPrecisionEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [ - I32EnumAttrCase<"kCompressed", 0, "compressed">, - I32EnumAttrCase<"kInterleaved", 1, "interleaved"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_PackFormatEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_TiledCase : I32EnumAttrCase<"tiled", 0>; -def TPU_LaneCase : I32EnumAttrCase<"lanes", 1>; -def TPU_SublaneCase : I32EnumAttrCase<"sublanes", 2>; -def TPU_VectorLayoutDim : I32EnumAttr< - "VectorLayoutDim", "", [TPU_TiledCase, TPU_LaneCase, TPU_SublaneCase]>; - -def TPU_VectorLayoutAttr : TPU_Attr<"VectorLayout", "vpad"> { - let description = [{TODO}]; - - let parameters = (ins "Layout":$layout); - let hasCustomAssemblyFormat = 1; -} - -def TPU_TiledLayoutAttr - : TPU_Attr<"TiledLayout", "tiled", - [DeclareAttrInterfaceMethods]> { - let description = [{ - This attribute represents tiled layouts in memrefs. - - Multiple levels of tiling are supported with the following restriction: - - Additional levels of tiling may not add any padding. - - Additional levels of tiling may not tile previously untiled dimensions, - that is, they cannot tile across first-level tiles. - - Tile strides encode the stride when moving along a given dimension. They - must have the same rank as the shape and must be decreasing with increasing - dimension number. For tiled dimensions, the stride applies only when moving - across first-level tiles. The strides are in units of the size of the first - tile, or 1 if there are no tiles. - }]; - let parameters = (ins - ArrayRefParameter<"::xla::Tile", "">:$tiles, - ArrayRefParameter<"int64_t", "">:$tile_strides - ); - let extraClassDeclaration = [{ - static ::llvm::SmallVector getDefaultTileStrides(::llvm::ArrayRef<::xla::Tile> tiles, ::llvm::ArrayRef shape); - bool tilesAreKnownContiguous(::llvm::ArrayRef shape) const; - - int64_t getRank() const { - return getTileStrides().size(); - } - int64_t getUntiledRank() const; - - ::llvm::SmallVector getExpandedShape(::llvm::ArrayRef shape) const; - ::llvm::SmallVector getExpandedStrides() const; - }]; - - let hasCustomAssemblyFormat = 1; - let genVerifyDecl = 1; -} - -def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [ - I32EnumAttrCase<"kAny", 4294967295, "any">, - I32EnumAttrCase<"kVmem", 0, "vmem">, - I32EnumAttrCase<"kSmem", 1, "smem">, - I32EnumAttrCase<"kHbm", 2, "hbm">, - I32EnumAttrCase<"kCmem", 3, "cmem">, - I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">, - I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">, - I32EnumAttrCase<"kHost", 6, "host"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_MemorySpaceEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -class TPU_Op traits = []> : - Op { -} - -def DefaultMemWrite : MemoryEffects<[MemWrite]>; -def DefaultMemRead : MemoryEffects<[MemRead]>; - -def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [ - I32EnumAttrCase<"kSum", 0, "sum">, - I32EnumAttrCase<"kMax", 1, "max">, - I32EnumAttrCase<"kMin", 2, "min">, - I32EnumAttrCase<"kArgMax", 3, "arg_max">, - I32EnumAttrCase<"kArgMin", 4, "arg_min">, - I32EnumAttrCase<"kFindFirstSet", 5, "find_first_set"> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_ReductionKindAttr - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure]> { - let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($output) - }]; - let hasVerifier = 1; -} - -def TPU_ReduceIndexOp : TPU_Op<"reduce_index", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$input, - I32Attr:$axis, - TPU_ReductionKindAttr:$kind - ); - let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasVerifier = 1; -} - -// tpu.scan performs a scan across a vector. -// -// If a mask is provided, all output elements before the first unmasked input -// element is undefined. Subsequent masked elements will hold the result -// of the last unmasked element. -// -// For example, a "kSum" reduction over a input vector [1, 2, 3, 4] -// with mask [0, 1, 0, 1] will produce the output vector [X, 2, 2, 6]. -// where X is some undefined value. -// -// output : Result vector. Must have the same shape as source. -// input : Vector to scan. -// kind : Reduction operator. Must be one of "kSum", "kMax", or "kMin". -// Must be "kSum" if input is an I1 vector. -// mask : Elementwise vector mask. The scan operation starts from the -// lowest-indexed non-masked vector element (all previous elements -// have undefined values). Not taken for I1 input vectors. -def TPU_ScanOp : TPU_Op<"scan"> { - let arguments = (ins - VectorOfNonZeroRankOf<[I1, I16, I32, BF16, F32]>:$input, - TPU_ReductionKindAttr:$kind, - Optional>:$mask - ); - let results = (outs VectorOfNonZeroRankOf<[I16, I32, BF16, F32]>:$output); - let assemblyFormat = [{ - $kind `,` $input (`masked` $mask^)? attr-dict `:` type($input) `,` type($mask) `->` type($output) - }]; - let hasVerifier = 1; -} - -def TPU_SortOp : TPU_Op<"sort", [Pure]> { - let summary = "Sorts key/value pairs based on keys."; - let description = [{ - tpu.sort performs a stable sort of key/value pairs in ascending or - descending order based on keys. Masked-out keys and values are placed at the - end of the output vectors. An output mask indicates which outputs - correspond to the valid inputs. - }]; - let arguments = (ins - VectorOfNonZeroRankOf<[I32,F32]>:$keys, - VectorOfNonZeroRankOf<[I32,F32]>:$values, - Optional>:$mask, - DefaultValuedAttr:$descending - ); - let results = (outs - VectorOfNonZeroRankOf<[I1]>:$output_mask, - VectorOfNonZeroRankOf<[I32,F32]>:$sorted_keys, - VectorOfNonZeroRankOf<[I32,F32]>:$sorted_values - ); - let assemblyFormat = [{ - $keys `,` $values (`masked` $mask^)? attr-dict `:` functional-type(operands, results) - }]; - let hasVerifier = 1; -} - -def TPU_StoreOp : TPU_Op<"store", [DefaultMemWrite, AttrSizedOperandSegments]> { - let arguments = (ins - TPU_Vreg:$valueToStore, - AnyType:$base, - Variadic:$indices, - DenseBoolArrayAttr:$sublane_mask, - Optional:$mask, - OptionalAttr:$sublane_stride // In sublane-sized units - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) - }]; -} - -def TPU_LoadOp : TPU_Op<"load", [DefaultMemRead]> { - let arguments = (ins - AnyType:$base, - Variadic:$indices, - DenseBoolArrayAttr:$sublane_mask, - OptionalAttr:$sublane_stride // In sublane-sized units - ); - let results = (outs TPU_Vreg:$result); - let assemblyFormat = [{ - $base `[` $indices `]` `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($result) - }]; - let description = [{ - Similar to `vector::LoadOp` but with `sublane_mask` and `sublane_stride`. - When `indices` are negative, it means loading from negative offset - of `base` address. - }]; -} - -// TODO(jevinjiang): migrate tpu.strided_store to general vector store op. -def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperandSegments]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$valueToStore, - AnyMemRef:$base, - Variadic:$indices, - DenseI32ArrayAttr:$strides, - Optional:$mask, // Elementwise mask. - DefaultValuedAttr:$add - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -// tpu.vector_load loads a vector from memory into a register. -// -// base : Memref to load from. -// indices: Scalar indices into base. indices must be of the same rank as the -// base memref shape. -// strides: The stride to use for calculating the address of subsequent -// elements. If left unspecified, the stride is implicitly 1 along -// each dimension. Otherwise the stride must match the rank of the -// memref shape. -// mask : Elementwise vector mask. Must be broadcastable to the shape of the -// result vector. Depending on the core type, this may be a dynamic -// (lane) mask consumed from a register or a static (sublane) mask -// that must be the result of arith.constant. -def TPU_VectorLoadOp :TPU_Op<"vector_load", [DefaultMemRead, AttrSizedOperandSegments]> { - let arguments = (ins - AnyMemRef:$base, - Variadic:$indices, - DenseI32ArrayAttr:$strides, - Optional:$mask // Elementwise mask. - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `,` type($result) `,` type($mask) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_StridedLoadOp : TPU_Op<"strided_load", [DefaultMemRead]> { - let arguments = (ins - AnyMemRef:$base, - Variadic:$indices, - DenseI32ArrayAttr:$strides - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) - }]; - let hasVerifier = 1; -} - -def TPU_StridedStoreOp : TPU_Op<"strided_store", [DefaultMemWrite]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$valueToStore, - AnyMemRef:$base, - Variadic:$indices, - DenseI32ArrayAttr:$strides - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) - }]; - let hasVerifier = 1; -} - -// TODO: b/435258666 - Merge with tpu.vector_load_idx. -def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load", [DefaultMemRead]> { - let arguments = (ins - AnyMemRef:$base, - Variadic:$indices, - DenseBoolArrayAttr:$sublane_mask, - DenseI32ArrayAttr:$sublane_offsets - ); - let results = (outs TPU_Vreg:$result); - let assemblyFormat = [{ - $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -// TODO: b/435258666 - Merge with tpu.vector_store_idx. -def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store", [DefaultMemWrite]> { - let arguments = (ins - TPU_Vreg:$valueToStore, - AnyMemRef:$base, - Variadic:$indices, - DenseBoolArrayAttr:$sublane_mask, - DenseI32ArrayAttr:$sublane_offsets - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -// tpu.vector_load_idx loads values from arbitrary locations in memory. -// -// Each element in the output vector is loaded from an index in the base memref -// specified by the corresponding elements in the 'indices' vectors. The shape -// of each index vector must match the shape of the output vector. The number -// of index vectors must equal the rank of the base memref. -// -// For example, for a vector of length n with rank 2, the indices will look like: -// indices = [[idx0, idx1, ...], [idxn, idxn+1, ...]] -// where [idx0, idxn] is the offset of the first vector element. -// -// base : Memref specifying the base address. -// indices : Vectors of indices for each dimension of the base memref. -// mask : Optional elementwise vector mask. -def TPU_VectorLoadIdxOp :TPU_Op<"vector_load_idx", [DefaultMemRead, AttrSizedOperandSegments]> { - let arguments = (ins - MemRefOf<[I32, F32]>:$base, - Variadic>:$indices, - Optional>:$mask - ); - let results = (outs VectorOfNonZeroRankOf<[I32, F32]>:$value); - let assemblyFormat = [{ - $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($value) `,` type($mask) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -// tpu.vector_store_idx stores values to arbitrary locations in memory. -// -// Each element in the input vector is stored to an index in the base memref -// specified by the corresponding elements in the 'indices' vectors. The shape -// of each index vector must match the shape of the input vector. The number -// of index vectors must equal the rank of the base memref. -// -// For example, for a vector of length n with rank 2, the indices will look like: -// indices = [[idx0, idx1, ...], [idxn, idxn+1, ...]] -// where [idx0, idxn] is the offset of the first vector element. -// -// When multiple vector elements have the same index to store to, the data from -// the highest lane will be the one stored. If add is true, then the data will -// be added from the lowest lane to the highest lane. -// -// valueToStore: Vector to be stored. -// base : Memref specifying the base address. -// indices : Vectors of indices for each dimension of the base memref. -// mask : Optional elementwise vector mask. -// add : If true, add source values to target values. Otherwise, overwrite. -def TPU_VectorStoreIdxOp :TPU_Op<"vector_store_idx", [DefaultMemWrite, AttrSizedOperandSegments]> { - let arguments = (ins - VectorOfNonZeroRankOf<[I32, F32]>:$valueToStore, - MemRefOf<[I32, F32]>:$base, - Variadic>:$indices, - Optional>:$mask, - DefaultValuedAttr:$add - ); - let results = (outs); - let assemblyFormat = [{ - $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($valueToStore) `,` type($mask) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -// TODO(jevinjiang): deprecate to use dynamic_rotate. -def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { - let description = [{ - Rotates the given vector by the given amount in the given dimension, i.e., - for a 2D vector of shape (m, n), rotating dim 0 by `amount` will shift a row - at index `i` to index `(i + amount) % m` - }]; - let arguments = (ins - AnyVectorOfNonZeroRank:$value, - SI32Attr:$amount, - SI32Attr:$dimension, - // When the stride is specified, the rotation amount for each index on the - // stride dimension will be (amount + stride * index). - OptionalAttr:$stride, - OptionalAttr:$stride_dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value) - }]; - let hasVerifier = 1; -} - -def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$value, - I32:$amount, - SI32Attr:$dimension, - // When the stride is specified, the rotation amount for each index on the - // stride dimension will be (amount + stride * index). - OptionalAttr:$stride, - OptionalAttr:$stride_dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result) - }]; - let hasVerifier = 1; -} - -def TPU_ScanCountOp : TPU_Op<"scan_count", [Pure, InferTypeOpAdaptor, SameOperandsAndResultShape]> { -let summary = [{ - ScanCountOp calculates the running duplicate occurrence count of the elements - in the input vector. Elements eligible for counting are specified by the - input mask vector. The output mask vector indicates one unique occurrence - per duplicate that was counted. - }]; - - let description = [{ - ScanCountOp calculates the running duplicate occurrence count of the elements - in the input vector, %values. The output vector, %counts, contains the running - duplicate occurrence count for the corresponding element in - the input vector, where the count is performed in ascending order of element - indices. For example, if the elements of %values at indices 0, 5, and 7 had - duplicate values, then the elements of %counts at indices 0, 5, and 7 would - be 1, 2, and 3, respectively. - - A mask vector, %in_mask, specifies which of the elements in the input vector - are eligible for counting. An element in %values that has its mask set to 0 - will always have a count of 1 in %counts, regardless of the position in the - vector, or whether there were duplicates or not. - }]; - - let arguments = (ins - VectorOfNonZeroRankOf<[I1]>:$in_mask, - AnyVectorOfNonZeroRank:$values - ); - let results = (outs - VectorOfNonZeroRankOf<[I1]>:$out_mask, - VectorOfNonZeroRankOf<[I32]>:$counts - ); - - let assemblyFormat = [{ - `mask` `(` $in_mask `:` type($in_mask) `)` - `value` `(` $values `:` type($values) `)` - attr-dict `:` type(results) - }]; - -} - -def TPU_IotaOp : TPU_Op<"iota", [Pure]> { - let description = [{ - Creates a vector that with values that start at 0 and increase along a - dimension resulting from collapsing the given `dimensions` together in - row-major order. - - Example: - ``` - tpu.iota {dimensions = array} : vector<4x3x2xi16> - ``` - This produces a vector with the following values: - ``` - [[[0, 4], [0, 4], [0, 4]] - [[1, 5], [1, 5], [1, 5]] - [[2, 6], [2, 6], [2, 6]] - [[3, 7], [3, 7], [3, 7]]] - ``` - }]; - let arguments = (ins DenseI32ArrayAttr:$dimensions); - let results = (outs VectorOfNonZeroRankOf<[AnyInteger, Index]>:$output); - let assemblyFormat = [{ attr-dict `:` type($output) }]; - let hasVerifier = 1; -} - -def TPU_ReshapeOp : TPU_Op<"reshape", [Pure]> { - let arguments = (ins AnyVectorOfNonZeroRank:$source); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ $source attr-dict `:` type($source) `->` type($result) }]; - let hasVerifier = 1; - let hasFolder = 1; -} - -// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically. -// b/376295711 -def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$source, - I32Attr:$dimension, - I32Attr:$times - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }]; -} - -def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { - let description = [{ - For each sublane `i`, broadcasts the value in lane `lane + i` along the - entire sublane. For packed type, imagine the data is compressed unpacked - along sublane dimension, and the sublane count is multiplied by the packing - factor. - For example, for i16 with sublane count 8, `i` above is in [0, 8 * 2). - If `lane + i` is not in [0, lane_count), then the value in sublane `i` is - not defined (can be anything). - }]; - let arguments = (ins - TPU_Vreg:$source, // All sublanes should be equal. - I32Attr:$lane // Coordinates of the first element to take. - ); - // Output shape should be the same, except for position dim which contains - // the newly inserted dimension. - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $source `,` $lane attr-dict `:` type($source) `->` type($output) - }]; -} - -// Integer unpacks are always signed at the moment. -// -// When unpacking integers to integers, setting `sign_extended` to false will -// leave bits higher than source bitwidth as undefined. -// -// Take int4 to int16 interleaved unpacking and `index = 1` as an example: -// -// Source: -// -// Bits 28 24 20 16 12 8 4 0 -// --------abcd------------efgh---- -// -// where "a" and "e" are the sign bits of the values to be unpacked, and "-" are -// bits to be ignored. -// -// Unpacked, sign_extend = true: -// -// Bits 28 24 20 16 12 8 4 0 -// aaaaaaaaaaaaabcdeeeeeeeeeeeeefgh -// -// Unpacked, sign_extend = false: -// -// Bits 28 24 20 16 12 8 4 0 -// ------------abcd------------efgh -def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$source, - I32Attr:$index, - TPU_PackFormatEnum:$pack_format, - DefaultValuedAttr:$sign_extended - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -// Integer packs are always signed at the moment. -// Float to integer packing rounds to nearest even. -// WARNING: pack(pack(a, b), pack(c, d)) == pack(a, b, c, d) only holds for -// compressed packing! -// Below, we use [ ... ] to denote the bounds of the vreg and use regular parens -// ( ... ) to denote packing of multiple subelements into a single 32-bit word. -// -// Interleaved packing -// -// Interleaved packing downcasts to a narrower dtype, and packs multiple elements -// into the same word coordinate from which they originated. If a and b are packed -// values, then interleaved packing first iterates over the operand list and only -// then over the subelements within each word. -// Take 16-bit vregs A, B, C and D: -/// -// [ (A000 A001) (A010 A011) ... ] -// [ (A100 A101) (A110 A111) ... ] -// ... -// -// An interleaved pack(a, b) from 16-bit values produces: -// -// [ (A000 B000 A001 B001) (A010 B010 A011 B011) ...] -// ... -// -// While an interleaved pack(a, b, c, d) produces the following subelements in -// each vreg word: -// -// [ (A000 B000 C000 D000 A001 B001 C001 D001) ... ] -// ... -// -// Compressed packing -// -// Compressed packing downcasts each value and then packs multiple rows together. -// A compressed pack(a, b) from 16-bit values produces: -// -// [ (A000 A001 A100 A101) (A010 A011 A110 A111) ... ] -// [ (A200 A201 A300 A301) (A210 A211 A310 A311) ... ] -// ... # 2 more sublanes -// [ (B000 B001 B100 B101) (B010 B011 B110 B111) ... ] -// [ (B200 B201 B300 B301) (B210 B211 B310 B311) ... ] -// ... -def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> { - let arguments = (ins - Variadic:$sources, - DenseI32ArrayAttr:$positions, - TPU_PackFormatEnum:$pack_format - ); - let results = (outs TPU_Vreg:$output); - let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; - let builders = [ - OpBuilder<(ins "::mlir::VectorType":$output_type, "::mlir::ArrayRef<::mlir::Value>":$padded_sources, "::mlir::tpu::PackFormat":$pack_format)>, - ]; - let extraClassDeclaration = [{ - static ::mlir::SmallVector<::mlir::Value> getPaddedSources(::mlir::ValueRange sources, ::mlir::ArrayRef positions, int packing_factor); - }]; - let hasVerifier = 1; -} - -def TPU_PackElementwiseOp : TPU_Op<"pack_elementwise", [Pure, SameTypeOperands, ElementwiseMappable]> { - let description = [{ - Packs multiple `sources` elementwise into a single vector of a narrower `target_type`. - - The number of `sources` must equal the packing factor, which is the ratio of - the element bitwidth of the `sources` to the element bitwidth of the - `target_type`. Elements from the `sources` are interleaved and packed into - each word of the `output`, ordered from lowest to highest bits, - corresponding to their order in the `sources`. - }]; - let arguments = (ins - Variadic>:$sources, - TypeAttr:$target_type - ); - let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); - let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_UnpackElementwiseOp : TPU_Op<"unpack_elementwise", [Pure, ElementwiseMappable]> { - let description = [{ - Unpacks a single vector from `source`, which contains multiple `source_type` - vectors packed elementwise. - - The `index` selects which packed value to extract from each word of `source`. - An `index` of 0 corresponds to the lowest bits. The extracted values are - cast to the output element type. - }]; - let arguments = (ins - VectorOfNonZeroRankOf<[I32]>:$source, - TypeAttr:$source_type, - I32Attr:$index - ); - let results = (outs VectorOfNonZeroRankOf<[F32, I32]>:$output); - let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_RelayoutOp : TPU_Op<"relayout", [Pure, SameOperandsAndResultType]> { - let arguments = (ins AnyVectorOfAnyRank:$input); - let results = (outs AnyVectorOfAnyRank:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> { - let arguments = (ins - VectorOfNonZeroRankOf<[I1]>: $low, - VectorOfNonZeroRankOf<[I1]>: $high - ); - let results = (outs VectorOfNonZeroRankOf<[I1]>:$output); - let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }]; -} - -def TPU_GatherOp : TPU_Op<"gather", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$source, - DenseI32ArrayAttr:$indices, - I32Attr:$dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $source `[` $indices `]` `in` $dimension attr-dict - `:` type($source) `->` type($output) - }]; -} - -def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, DeclareOpInterfaceMethods, AllShapesMatch<["indices", "output"]>, AllElementTypesMatch<["source", "output"]>]> { - let description = [{ - Gathers elements from `source` using `indices`. - - The specified `dimensions` of `source` are collapsed together and indexed by - `indices`. - - Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by - `collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where - - `collapsed_source` is the result of collapsing `dimensions` of `source` - into a new trailing dimension of size `M`. - - `jk` is the subsequence of `in` for `n` not in `dimensions`. - - When a single dimension is specified, this is similar to - `np.take_along_axis`. - }]; - let arguments = (ins - AnyVectorOfNonZeroRank:$source, - VectorOfNonZeroRankOf<[AnyInteger]>:$indices, - DenseI32ArrayAttr:$dimensions - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $source `[` $indices `]` `in` $dimensions attr-dict - `:` type($source) `,` type($indices) `->` type($output) - }]; - let hasVerifier = 1; -} - -def TPU_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode", [ - I32EnumAttrCase<"kTowardsZero", 0, "towards_zero">, - I32EnumAttrCase<"kToNearestEven", 1, "to_nearest_even">, -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::tpu"; -} - -def TPU_RoundingModeEnum : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -// Internal operation. All arith.fptosi operations that change the bitwidth -// must be canonicalized to this operation. -def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> { - let arguments = (ins AnyVectorOfAnyRank:$input, TPU_RoundingModeEnum:$rounding_mode); - let results = (outs AnyVectorOfAnyRank:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasCanonicalizeMethod = 1; -} - -// Internal operation. All arith.sitofp operations that change the bitwidth -// must be canonicalized to this operation. -def TPU_SIToFPOp : TPU_Op<"sitofp", [Pure, ElementwiseMappable]> { - let arguments = (ins AnyType:$in, TPU_RoundingModeEnum:$rounding_mode); - let results = (outs AnyType:$output); - let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($output) }]; -} - -// Internal operation. -def TPU_ExtFOp : TPU_Op<"extf", [Pure, ElementwiseMappable]> { - let arguments = (ins AnyType:$in); - let results = (outs AnyType:$out); - let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }]; - let hasFolder = 1; -} - -// Internal operation. -def TPU_TruncFOp : TPU_Op<"truncf", [Pure, ElementwiseMappable]> { - let arguments = ( - ins AnyType:$in, - TPU_RoundingModeEnum:$rounding_mode - ); - let results = (outs AnyType:$out); - let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }]; - let hasFolder = 1; -} - def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension_numbers"> { let parameters = (ins ArrayRefParameter<"int64_t", "">:$lhs_contracting_dims, @@ -928,628 +57,7 @@ def TPU_DotDimensionNumbersAttr : TPU_Attr<"DotDimensionNumbers", "dot_dimension "`[` $output_dim_order `]` `,` " "`[` (`]`):($lhs_batch_dims^ `]`)? `,` " "`[` (`]`):($rhs_batch_dims^ `]`)? `>`"; + let constBuilderCall = "::mlir::tpu::DotDimensionNumbersAttr::get($_builder.getContext(), $0)"; } -// TODO(apaszke): Think hard about precision -def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$lhs, - AnyVectorOfNonZeroRank:$rhs, - AnyVectorOfNonZeroRank:$acc, - // These flags are deprecated - if dimension_numbers are defined, - // these flags are ignored. They will always be false after canonicalize. - DefaultValuedAttr:$transpose_lhs, - DefaultValuedAttr:$transpose_rhs, - OptionalAttr:$precision, - // NOTE: User-level optional, once canonicalized, always present. - OptionalAttr:$dimension_numbers - ); - let results = (outs AnyVectorOfNonZeroRank:$result); - let assemblyFormat = [{ - $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result) - }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure, DeclareOpInterfaceMethods]> { - let arguments = (ins - Variadic:$sources, - I32Attr:$dimension - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ - $sources `in` $dimension attr-dict `:` type($sources) `->` type($output) - }]; - let hasVerifier = 1; -} - -def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> { - let arguments = (ins AnyVectorOfNonZeroRank:$input); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> { - let arguments = (ins TPU_Vreg:$input); - let results = (outs TPU_Vreg:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasFolder = 1; -} - -def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> { - let arguments = (ins AnyType:$input); // F32 vector or scalar - let results = (outs AnyType:$output); // I1 vector or scalar - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_ReciprocalOp : TPU_Op<"reciprocal", [Pure, SameOperandsAndResultType, ElementwiseMappable]> { - let arguments = (ins - AnyVectorOfNonZeroRank:$input, - DefaultValuedAttr:$approx - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_StochasticConvertOp : TPU_Op<"stochastic_convert", [Pure, SameOperandsAndResultShape]> { - let arguments = (ins - VectorOfNonZeroRankOf<[F32]>:$input, - VectorOfNonZeroRankOf<[I32]>:$random - ); - let results = (outs AnyVectorOfNonZeroRank:$output); - let assemblyFormat = [{ $input `,` $random attr-dict `:` type($input) `,` type($random) `->` type($output) }]; -} - -def TPU_StochasticConvertElementwiseOp : TPU_Op<"stochastic_convert_elementwise", [Pure, ElementwiseMappable]> { - // Stochastically converts the input to the target dtype based on the mode. - // When the target dtype is less than 32 bits, the result occupies the lowest {bitwidth} bits in the I32 output. - let arguments = (ins - VectorOfNonZeroRankOf<[F32]>:$input, - VectorOfNonZeroRankOf<[I32]>:$random, - TypeAttr:$dst_type - ); - let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); - let assemblyFormat = [{ $input `,` $random attr-dict `:` type($input) `,` type($random) `->` type($output) }]; - let hasVerifier = 1; -} - -def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { - let arguments = (ins Variadic:$input); - let results = (outs AnyVectorOfAnyRank:$output); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($output) - }]; -} - -def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> { - let arguments = (ins AnyVectorOfAnyRank:$input); - let results = (outs Variadic:$output); - let hasCanonicalizeMethod = 1; - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($output) - }]; -} - -def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> { - // high is exclusive - let arguments = (ins Variadic:$low, Variadic:$high); - let results = (outs AnyType:$output); - let assemblyFormat = [{ - `[` $low `]``[` $high `]` attr-dict `:` type($output) - }]; -} - -def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> { - let summary = "Create a mask masking contiguous rows of subelements."; - let description = [{ - The "half-sublanes", "quarter-sublanes", etc. (unit is determined by - the type of `output`) of the mask are masked in the range specified by - `from` and `to`. - - - If `from <= to`, the range `[from, to)` is set and the rest is unset. - - If `to <= from`, the range `[to, from)` is unset and the rest is set. - - All lanes are set identically. - - Example: - - ```mlir - %msk = tpu.create_subelement_mask 3, 9 : vector<8x128x2xi1> - ``` - - This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is: - - ``` - [[0, 0], [0, 1], [1, 1], [1, 1], [1, 0], [0, 0], [0, 0], [0, 0]] - ``` - - It is currently only supported: - - In TPU v4, for `num_subelems` of 1 and 2. - - In TPU v5, for `num_subelems` of 1, 2, and 4. - }]; - let arguments = (ins - I32Attr:$from, // inclusive - I32Attr:$to // exclusive - ); - let results = (outs AnyType:$output); // Verify this is a vmsk with num_subelems - let assemblyFormat = [{ - $from `,` $to attr-dict `:` type($output) - }]; -} - -def TPU_AssumeMultipleOp : TPU_Op<"assume_multiple", [Pure, SameOperandsAndResultType]> { - let summary = "Assumes that a value is a multiple of a given integer."; - let description = [{ - This operation is a hint to the compiler that the input `value` is guaranteed - to be a multiple of `multiple`. This can be used to satisfy divisibility checks - in some compiler passes. - - The result is the same as the input `value`. - - Example: - - ```mlir - %val = tpu.assume_multiple %arg0, 16 : index - ``` - }]; - let arguments = (ins - AnyTypeOf<[Index, AnyInteger]>:$value, - I32Attr:$multiple - ); - let results = (outs AnyTypeOf<[Index, AnyInteger]>:$result); - let assemblyFormat = [{$value `,` $multiple attr-dict `:` type($result)}]; - let hasVerifier = 1; -} - -def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure, AttrSizedOperandSegments]> { - let arguments = (ins - AnyMemRef:$mem_ref, - Variadic:$base_idx, - Variadic:$dynamic_sizes - ); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $mem_ref `[` $base_idx `]` (`<` $dynamic_sizes^ `>`)? - attr-dict `:` type($mem_ref) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizer = 1; -} - -def TPU_MemRefSqueezeOp : TPU_Op<"memref_squeeze", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_MemRefBitcastOp : TPU_Op<"memref_bitcast", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> { - let arguments = (ins AnyMemRef:$input); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -def TPU_AssumeLayoutOp : TPU_Op<"assume_layout", [Pure]> { - let arguments = (ins AnyType:$input); - let results = (outs AnyType:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; -} - -// Erases the layout attribute from the memref. -// -// The resulting memref is identical to the input, except that it has an -// identity layout. -def TPU_EraseLayoutOp : TPU_Op<"erase_memref_layout", [Pure, InferTypeOpAdaptor]> { - let arguments = (ins AnyMemRef:$operand); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $operand attr-dict `:` type($operand) `->` type($result) - }]; - let hasFolder = 1; -} - -// Returns the ID of the current device. -// -// On the input to the compiler the return value is a logical ID in the XLA -// device assignment. It changes to a physical ID after the -// logical-to-physical-device-id pass. -def TPU_DeviceIdOp : TPU_Op<"device_id", [Pure]> { - let arguments = (ins); - let results = (outs I32:$result); - let assemblyFormat = [{ attr-dict `:` type($result) }]; -} - -def TPU_SemaphoreReadOp : TPU_Op<"sem_read"> { - let arguments = (ins MemRefOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>:$semaphore); - let results = (outs I32:$result); - let assemblyFormat = [{ $semaphore attr-dict `:` type($semaphore) `->` type($result)}]; -} - -def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> { - let arguments = (ins - MemRefOf<[TPU_SemaphoreType]>:$semaphore, - I32:$amount - ); - let results = (outs); - let assemblyFormat = [{ $semaphore `,` $amount attr-dict `:` type($semaphore)}]; - let hasVerifier = 1; -} - -def TPU_AllocaSemaphoreOp : TPU_Op<"sem_alloc"> { - let arguments = (ins); - let results = (outs MemRefOf<[TPU_SomeSemaphoreType]>:$result); - let assemblyFormat = [{ attr-dict `:` type($result) }]; -} - -def TPU_GetBarrierSemaphoreOp : TPU_Op<"sem_barrier"> { - let arguments = (ins); - let results = (outs MemRefOf<[TPU_SemaphoreType]>:$semaphore); - let assemblyFormat = [{ attr-dict `:` type($semaphore) }]; - let hasVerifier = 1; -} - -def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { - let arguments = (ins - MemRefOf<[TPU_SemaphoreType]>:$semaphore, - I32:$amount, - Optional:$device_id, // For remote DMAs - Optional:$core_id, // For megacore - OptionalAttr:$core_type - ); -let assemblyFormat = [{ - $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) - }]; - let hasVerifier = 1; - let builders = [ - // A backward-compatible builder that sets `core_type` to nullptr. - OpBuilder<(ins "Value":$semaphore, "Value":$amount, - "Value":$device_id, "Value":$core_id)>, - ]; -} - -def TPU_BarrierOp : TPU_Op<"barrier"> { - let summary = [{Barrier synchronization across SC vector subcores.}]; - let description = [{ - Performs barrier synchronization across all SC vector subcores at the - specified barrier id. - }]; - let arguments = (ins Index:$barrier_id); - let results = (outs); - let assemblyFormat = [{ `barrier_id` `(` $barrier_id `)` attr-dict }]; -} - -// tpu.enqueue_dma enqueues a DMA operation. -// -// source : Memref to copy from. -// source_semaphore : Semaphore to signal after the DMA completes. -// target : Memref to copy to. -// target_semaphore : Semaphore to wait on before the DMA completes. -// device_id : The id of the device to copy to for remote DMAs. -// core_id : The id of the core to copy to for remote and cross-core -// DMAs. -// priority : The priority of the DMA. -// strict_ordering : True if the DMA requires strict ordering. If false, the -// ordering is either strict or relaxed depending on the -// source and destination. -def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { - let arguments = (ins - AnyMemRef:$source, - Optional>:$source_semaphore, // For remote DMAs - AnyMemRef:$target, - MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, - Optional:$device_id, // For remote DMAs - Optional:$core_id, // For megacore - // Smaller number means higher priority. 0 is the highest and the default. - DefaultValuedAttr:$priority, - DefaultValuedAttr:$strict_ordering - ); - let assemblyFormat = [{ - `source` `(` $source `:` type($source) `)` - `target` `(` $target `:` type($target) `)` - (`source_semaphore` `(` $source_semaphore^ `:` type($source_semaphore) `)`)? - `target_semaphore` `(` $target_semaphore `:` type($target_semaphore) `)` - (`device_id` `(` $device_id^ `)`)? - (`core_id` `(` $core_id^ `)`)? - attr-dict - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; -} - -// A base class for all ops that need to differentiate between gather and -// scatter. -class IndirectDMAOp { - code extraBaseClassDeclaration = [{ - // Return true if this op performs a gather. Returns false if it performs a - // scatter. - FailureOr isGather(); - }]; -} - -// tpu.enqueue_indirect_dma copies data between HBM and VMEM, or between -// VMEM_SHARED and VMEM using indirect HBM offsets. -// -// If the source is in HBM or VMEM_SHARED and the target is in VMEM, performs a -// gather from the source (operand) at the offsets to the target (gather -// result). -// If the source is in VMEM and the target is in HBM or VMEM_SHARED, performs a -// scatter of the source (updates) to the target (operand) at the offsets. -// -// source : Memref to copy from. -// target : Memref to copy to. -// offsets : Gather or scatter offsets. -// semaphore : Semaphore to wait on; receive semaphore for scatter, send semaphore for gather. -// add : If true, add source values to target values. Otherwise, overwrite. -// offset_filter : If set, don't write values at offsets whose value is equal to -// the filter value. -def TPU_EnqueueIndirectDMAOp : TPU_Op<"enqueue_indirect_dma">, IndirectDMAOp { - let arguments = (ins - AnyMemRef:$source, - AnyMemRef:$target, - AnyTypeOf<[MemRefOf<[I32]>, VectorOfRankAndType<[1], [I32]>]>:$offsets, - MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, - Optional:$offset_filter, - DefaultValuedAttr:$add - ); - let assemblyFormat = [{ - `source` `(` $source `:` type($source) `)` - `target` `(` $target `:` type($target) `)` - `offsets` `(` $offsets `:` type($offsets) `)` - (`offset_filter` `(` $offset_filter^ `)`)? - `semaphore` `(` $semaphore `:` type($semaphore) `)` - attr-dict - }]; - let hasVerifier = 1; - let extraClassDeclaration = extraBaseClassDeclaration # [{ - LogicalResult verifyGather(MemRefType operand_ty, - ArrayRef offsets_shape, - MemRefType result_ty); - LogicalResult verifyScatter(MemRefType updates_ty, - ArrayRef offsets_shape, - MemRefType operand_ty); - }]; - let hasCanonicalizeMethod = 1; -} - -// tpu.wait_dma2 waits for a DMA to complete. -// -// The number of bytes to wait for is determined based on the size of the -// destination memref. -def TPU_WaitDMA2Op : TPU_Op<"wait_dma2", [AttrSizedOperandSegments]> { - let arguments = (ins - MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, - AnyMemRef:$src, - AnyMemRef:$dst, - Optional:$device_id, // For remote DMAs - Optional:$core_id, // For megacore - DefaultValuedAttr:$strict_ordering - ); - let assemblyFormat = [{ - `semaphore` `(` $semaphore `:` type($semaphore) `)` - `src` `(` $src `:` type($src) `)` - `dst` `(` $dst `:` type($dst) `)` - (`device_id` `(` $device_id^ `)`)? - (`core_id` `(` $core_id^ `)`)? - attr-dict - }]; - let hasVerifier = 1; - // A backward-compatible builder that sets `device_id` and `core_id` to nullptr. - let builders = [ - OpBuilder<(ins "Value":$semaphore, "Value":$src, "Value":$dst)> - ]; - let hasCanonicalizeMethod = 1; -} - -// TODO(b/395630795): Remove after 2025-08-10. -def TPU_WaitDMAOp : TPU_Op<"wait_dma"> { - let arguments = (ins - MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, - AnyMemRef:$ref - ); - let hasVerifier = 1; -} - -// Like tpu.wait_dma2, but for indirect DMAs. -// -// The number of bytes to wait for is determined based on the size of the -// destination memref in a gather, and the size of the source memref in a -// scatter. The op differentiates between gather and scatter based on the memory -// spaces of the source and destination memrefs. -def TPU_WaitIndirectDMAOp : TPU_Op<"wait_indirect_dma">, IndirectDMAOp { - let arguments = (ins - MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, - AnyMemRef:$src, - AnyMemRef:$dst - ); - let assemblyFormat = [{ - `semaphore` `(` $semaphore `:` type($semaphore) `)` - `src` `(` $src `:` type($src) `)` - `dst` `(` $dst `:` type($dst) `)` - attr-dict - }]; - let hasVerifier = 1; - let hasCanonicalizeMethod = 1; - let extraClassDeclaration = extraBaseClassDeclaration; -} - -def TPU_RegionOp : TPU_Op<"region", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> { - let arguments = (ins); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); - let hasVerifier = 1; -} - -def TPU_TraceOp : TPU_Op<"trace", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> { - let arguments = (ins StrAttr:$message, I32Attr:$level); - let results = (outs Variadic:$results); - let regions = (region AnyRegion:$region); -} - -def TPU_TraceStartOp : TPU_Op<"trace_start", []> { - let arguments = (ins StrAttr:$message, I32Attr:$level); - let results = (outs); -} - -def TPU_TraceStopOp : TPU_Op<"trace_stop", []> { - let arguments = (ins); - let results = (outs); -} - -def TPU_YieldOp : TPU_Op<"yield", [Pure, ReturnLike, Terminator]> { - let arguments = (ins Variadic:$results); - let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }]; -} - -def TPU_DelayOp : TPU_Op<"delay"> { - let arguments = (ins I32:$nanos); - let results = (outs); -} - -// Expands the granularity of mask to subelements. -def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> { - let description = [{ - Cast a mask register into a different packing. - - If casting to a type with smaller packing, then values being packed together - must be identical. For example, for 8x128x4xi1 -> 8x128x2xi1, - input[i, j, 0] == input[i, j, 1] and input[i, j, 2] == input[i, j, 3] must - hold for all i, j. Otherwise, the result is undefined. - }]; - let arguments = (ins VectorOfNonZeroRankOf<[I1]>:$input); - let results = (outs VectorOfNonZeroRankOf<[I1]>:$result); - let assemblyFormat = [{ - $input attr-dict `:` type($input) `->` type($result) - }]; - let hasVerifier = 1; -} - -def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> { - let arguments = (ins I32Attr:$dim); - let results = (outs I32:$result); - let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; -} - -def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> { - let arguments = (ins); - let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ attr-dict `:` type($result) }]; -} - -def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { - let arguments = (ins Variadic:$seeds); - let results = (outs); -} - -def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> { - let arguments = (ins); - let results = (outs AnyVectorOfNonZeroRank:$output); -} - -def TPU_SublaneShuffleOp : TPU_Op<"sublane_shuffle", [SameOperandsAndResultType]> { - // This op takes 2 physical vregs and a pattern, applies the pattern, - // and returns the result as 1 vreg. - // - // The pattern is a list of integers, where the integer value is the - // index of the sublane in the *combined input* [lhs, rhs], and the - // position of the integer in the list is the index of the sublane - // in the *output* vreg. - // - // The pattern size must match the operand/result sublane count. - // - // Example: - // %0 = tpu.single_output_sublane_shuffle %a, %b, - // [0, 1, 2, 3, 4, 5, 6, 7] // Result is %a - // %1 = tpu.single_output_sublane_shuffle %a, %b, - // [8, 9, 10, 11, 12, 13, 14, 15] // Result is %b - // %2 = tpu.single_output_sublane_shuffle %a, %b, - // [7, 6, 5, 4, 11, 10, 9, 8] // Result uses high half of a - // // and low half of b, reversed. - let arguments = (ins - TPU_Vreg:$lhs, - TPU_Vreg:$rhs, - DenseI32ArrayAttr:$pattern - ); - let results = (outs TPU_Vreg:$result); - let assemblyFormat = [{ - $lhs `,` $rhs `,` $pattern attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) - }]; - - let hasVerifier = 1; -} - -def TPU_TransposeOp : TPU_Op<"transpose", [Pure]> { - let summary = "tpu transpose operation"; - let arguments = (ins AnyVectorOfAnyRank:$vector, - DenseI64ArrayAttr:$permutation); - let results = (outs AnyVectorOfAnyRank:$result); - - let assemblyFormat = [{ - $vector `,` $permutation attr-dict `:` type($vector) `->` type($result) - }]; - let extraClassDeclaration = [{ - VectorType getSourceVectorType() { - return ::llvm::cast(getVector().getType()); - } - VectorType getResultVectorType() { - return ::llvm::cast(getResult().getType()); - } - }]; - let hasVerifier = 1; -} - -def TPU_LogOp : TPU_Op<"log"> { - let arguments = (ins - Variadic:$inputs, - StrAttr:$tag, - DefaultValuedAttr:$formatted - ); - let results = (outs); - let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; - let hasVerifier = 1; -} - -def TPU_LogBufferOp : TPU_Op<"log_buffer"> { - let arguments = (ins - AnyMemRef:$input, - DenseI64ArrayAttr:$shape, - StrAttr:$tag - ); - let results = (outs); - let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }]; - let hasVerifier = 1; -} - -#endif // TPU_ATTRS +#endif // TPU_BASE diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.td b/jaxlib/mosaic/dialect/tpu/tpu_ops.td new file mode 100644 index 000000000000..3329220cdd91 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.td @@ -0,0 +1,1516 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TPU_OPS +#define TPU_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Pass/PassBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "jaxlib/mosaic/dialect/tpu/tpu.td" + +// TODO(b/369418606): Find out the way to verify vreg size. +def TPU_Vreg : Type; + +class TPU_Type traits = [], + string baseCppType = "::mlir::Type"> + : TypeDef { + let mnemonic = mnemonic_; +} + +def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [ + I32EnumAttrCase<"kTc", 0, "tc">, + I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">, + I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_CoreTypeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_PipelineMode : I32EnumAttr<"PipelineMode", "Pipeline mode", [ + I32EnumAttrCase<"kSynchronous", 1, "synchronous">, + I32EnumAttrCase<"kDoubleBuffered", 2, "double_buffered"> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_PipelineModeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>; +def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>; +def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>; + +def TPU_Float8EXMYType : TPU_Type<"Float8EXMY", "float8_exmy", + [DeclareTypeInterfaceMethods]> { + let summary = "EXMY type in a 8 bit container"; + let description = [{ + EXMY type in a 8 bit container. Meaningful bits are aligned to LSB, and + bits higher than the underlying exmy type in the container are considered + as ignored. See https://arxiv.org/abs/2405.13938 for more details. + }]; + + let parameters = (ins + TypeParameter<"::mlir::FloatType", "Underlying EXMY type">:$underlying_type + ); + + let assemblyFormat = [{ + `<` $underlying_type `>` + }]; +} + +def TPU_DimensionSemantics : I32EnumAttr<"DimensionSemantics", "Dimension semantics", [ + I32EnumAttrCase<"parallel", 0>, + I32EnumAttrCase<"arbitrary", 1>, + I32EnumAttrCase<"core_parallel", 2>, + I32EnumAttrCase<"subcore_parallel", 3> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_DimensionSemanticsEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +// All indices/sizes are in element-space. +// Note that the implementation will require statically provable tile alignment. +def TPU_ElementWindowAttr : TPU_Attr<"ElementWindow", "element_window"> { + // Including low padding, to avoid backwards-incompatible changes once we add it. + let parameters = (ins + ArrayRefParameter<"int64_t", "">:$pad_low, + ArrayRefParameter<"int64_t", "">:$pad_high + ); + let assemblyFormat = "`<` `[` $pad_low `]` `,` `[` $pad_high `]` `>`"; +} + +def TPU_ContractPrecision : I32EnumAttr<"ContractPrecision", "Contraction precision", [ + I32EnumAttrCase<"kBF16", 0, "bf16">, + I32EnumAttrCase<"kFP32", 1, "fp32"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_ContractPrecisionEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_PackFormat : I32EnumAttr<"PackFormat", "Pack format", [ + I32EnumAttrCase<"kCompressed", 0, "compressed">, + I32EnumAttrCase<"kInterleaved", 1, "interleaved"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_PackFormatEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_TiledCase : I32EnumAttrCase<"tiled", 0>; +def TPU_LaneCase : I32EnumAttrCase<"lanes", 1>; +def TPU_SublaneCase : I32EnumAttrCase<"sublanes", 2>; +def TPU_VectorLayoutDim : I32EnumAttr< + "VectorLayoutDim", "", [TPU_TiledCase, TPU_LaneCase, TPU_SublaneCase]>; + +def TPU_VectorLayoutAttr : TPU_Attr<"VectorLayout", "vpad"> { + let description = [{TODO}]; + + let parameters = (ins "Layout":$layout); + let hasCustomAssemblyFormat = 1; +} + +def TPU_TiledLayoutAttr + : TPU_Attr<"TiledLayout", "tiled", + [DeclareAttrInterfaceMethods]> { + let description = [{ + This attribute represents tiled layouts in memrefs. + + Multiple levels of tiling are supported with the following restriction: + - Additional levels of tiling may not add any padding. + - Additional levels of tiling may not tile previously untiled dimensions, + that is, they cannot tile across first-level tiles. + + Tile strides encode the stride when moving along a given dimension. They + must have the same rank as the shape and must be decreasing with increasing + dimension number. For tiled dimensions, the stride applies only when moving + across first-level tiles. The strides are in units of the size of the first + tile, or 1 if there are no tiles. + }]; + let parameters = (ins + ArrayRefParameter<"::xla::Tile", "">:$tiles, + ArrayRefParameter<"int64_t", "">:$tile_strides + ); + let extraClassDeclaration = [{ + static ::llvm::SmallVector getDefaultTileStrides(::llvm::ArrayRef<::xla::Tile> tiles, ::llvm::ArrayRef shape); + bool tilesAreKnownContiguous(::llvm::ArrayRef shape) const; + + int64_t getRank() const { + return getTileStrides().size(); + } + int64_t getUntiledRank() const; + + ::llvm::SmallVector getExpandedShape(::llvm::ArrayRef shape) const; + ::llvm::SmallVector getExpandedStrides() const; + }]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + +def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [ + I32EnumAttrCase<"kAny", 4294967295, "any">, + I32EnumAttrCase<"kVmem", 0, "vmem">, + I32EnumAttrCase<"kSmem", 1, "smem">, + I32EnumAttrCase<"kHbm", 2, "hbm">, + I32EnumAttrCase<"kCmem", 3, "cmem">, + I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">, + I32EnumAttrCase<"kVmemShared", 5, "vmem_shared">, + I32EnumAttrCase<"kHost", 6, "host"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_MemorySpaceEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +class TPU_Op traits = []> : + Op { +} + +def DefaultMemWrite : MemoryEffects<[MemWrite]>; +def DefaultMemRead : MemoryEffects<[MemRead]>; + +def TPU_ReductionKind : I32EnumAttr<"ReductionKind", "Reduction kind", [ + I32EnumAttrCase<"kSum", 0, "sum">, + I32EnumAttrCase<"kMax", 1, "max">, + I32EnumAttrCase<"kMin", 2, "min">, + I32EnumAttrCase<"kArgMax", 3, "arg_max">, + I32EnumAttrCase<"kArgMin", 4, "arg_min">, + I32EnumAttrCase<"kFindFirstSet", 5, "find_first_set"> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_ReductionKindAttr + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure]> { + let arguments = (ins AnyVectorOfNonZeroRank:$input, I64Attr:$dim, TPU_ReductionKindAttr:$kind); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($output) + }]; + let hasVerifier = 1; +} + +def TPU_ReduceIndexOp : TPU_Op<"reduce_index", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$input, + I32Attr:$axis, + TPU_ReductionKindAttr:$kind + ); + let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +// tpu.scan performs a scan across a vector. +// +// If a mask is provided, all output elements before the first unmasked input +// element is undefined. Subsequent masked elements will hold the result +// of the last unmasked element. +// +// For example, a "kSum" reduction over a input vector [1, 2, 3, 4] +// with mask [0, 1, 0, 1] will produce the output vector [X, 2, 2, 6]. +// where X is some undefined value. +// +// output : Result vector. Must have the same shape as source. +// input : Vector to scan. +// kind : Reduction operator. Must be one of "kSum", "kMax", or "kMin". +// Must be "kSum" if input is an I1 vector. +// mask : Elementwise vector mask. The scan operation starts from the +// lowest-indexed non-masked vector element (all previous elements +// have undefined values). Not taken for I1 input vectors. +def TPU_ScanOp : TPU_Op<"scan"> { + let arguments = (ins + VectorOfNonZeroRankOf<[I1, I16, I32, BF16, F32]>:$input, + TPU_ReductionKindAttr:$kind, + Optional>:$mask + ); + let results = (outs VectorOfNonZeroRankOf<[I16, I32, BF16, F32]>:$output); + let assemblyFormat = [{ + $kind `,` $input (`masked` $mask^)? attr-dict `:` type($input) `,` type($mask) `->` type($output) + }]; + let hasVerifier = 1; +} + +def TPU_SortOp : TPU_Op<"sort", [Pure]> { + let summary = "Sorts key/value pairs based on keys."; + let description = [{ + tpu.sort performs a stable sort of key/value pairs in ascending or + descending order based on keys. Masked-out keys and values are placed at the + end of the output vectors. An output mask indicates which outputs + correspond to the valid inputs. + }]; + let arguments = (ins + VectorOfNonZeroRankOf<[I32,F32]>:$keys, + VectorOfNonZeroRankOf<[I32,F32]>:$values, + Optional>:$mask, + DefaultValuedAttr:$descending + ); + let results = (outs + VectorOfNonZeroRankOf<[I1]>:$output_mask, + VectorOfNonZeroRankOf<[I32,F32]>:$sorted_keys, + VectorOfNonZeroRankOf<[I32,F32]>:$sorted_values + ); + let assemblyFormat = [{ + $keys `,` $values (`masked` $mask^)? attr-dict `:` functional-type(operands, results) + }]; + let hasVerifier = 1; +} + +def TPU_StoreOp : TPU_Op<"store", [DefaultMemWrite, AttrSizedOperandSegments]> { + let arguments = (ins + TPU_Vreg:$valueToStore, + AnyType:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + Optional:$mask, + OptionalAttr:$sublane_stride // In sublane-sized units + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) + }]; +} + +def TPU_LoadOp : TPU_Op<"load", [DefaultMemRead]> { + let arguments = (ins + AnyType:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + OptionalAttr:$sublane_stride // In sublane-sized units + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $base `[` $indices `]` `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($result) + }]; + let description = [{ + Similar to `vector::LoadOp` but with `sublane_mask` and `sublane_stride`. + When `indices` are negative, it means loading from negative offset + of `base` address. + }]; +} + +// TODO(jevinjiang): migrate tpu.strided_store to general vector store op. +def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperandSegments]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask, // Elementwise mask. + DefaultValuedAttr:$add + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// tpu.vector_load loads a vector from memory into a register. +// +// base : Memref to load from. +// indices: Scalar indices into base. indices must be of the same rank as the +// base memref shape. +// strides: The stride to use for calculating the address of subsequent +// elements. If left unspecified, the stride is implicitly 1 along +// each dimension. Otherwise the stride must match the rank of the +// memref shape. +// mask : Elementwise vector mask. Must be broadcastable to the shape of the +// result vector. Depending on the core type, this may be a dynamic +// (lane) mask consumed from a register or a static (sublane) mask +// that must be the result of arith.constant. +def TPU_VectorLoadOp :TPU_Op<"vector_load", [DefaultMemRead, AttrSizedOperandSegments]> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides, + Optional:$mask // Elementwise mask. + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `,` type($result) `,` type($mask) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_StridedLoadOp : TPU_Op<"strided_load", [DefaultMemRead]> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) + }]; + let hasVerifier = 1; +} + +def TPU_StridedStoreOp : TPU_Op<"strided_store", [DefaultMemWrite]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseI32ArrayAttr:$strides + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) + }]; + let hasVerifier = 1; +} + +// TODO: b/435258666 - Merge with tpu.vector_load_idx. +def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load", [DefaultMemRead]> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + DenseI32ArrayAttr:$sublane_offsets + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// TODO: b/435258666 - Merge with tpu.vector_store_idx. +def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store", [DefaultMemWrite]> { + let arguments = (ins + TPU_Vreg:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + DenseI32ArrayAttr:$sublane_offsets + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// tpu.vector_load_idx loads values from arbitrary locations in memory. +// +// Each element in the output vector is loaded from an index in the base memref +// specified by the corresponding elements in the 'indices' vectors. The shape +// of each index vector must match the shape of the output vector. The number +// of index vectors must equal the rank of the base memref. +// +// For example, for a vector of length n with rank 2, the indices will look like: +// indices = [[idx0, idx1, ...], [idxn, idxn+1, ...]] +// where [idx0, idxn] is the offset of the first vector element. +// +// base : Memref specifying the base address. +// indices : Vectors of indices for each dimension of the base memref. +// mask : Optional elementwise vector mask. +def TPU_VectorLoadIdxOp :TPU_Op<"vector_load_idx", [DefaultMemRead, AttrSizedOperandSegments]> { + let arguments = (ins + MemRefOf<[I32, F32]>:$base, + Variadic>:$indices, + Optional>:$mask + ); + let results = (outs VectorOfNonZeroRankOf<[I32, F32]>:$value); + let assemblyFormat = [{ + $base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($value) `,` type($mask) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// tpu.vector_store_idx stores values to arbitrary locations in memory. +// +// Each element in the input vector is stored to an index in the base memref +// specified by the corresponding elements in the 'indices' vectors. The shape +// of each index vector must match the shape of the input vector. The number +// of index vectors must equal the rank of the base memref. +// +// For example, for a vector of length n with rank 2, the indices will look like: +// indices = [[idx0, idx1, ...], [idxn, idxn+1, ...]] +// where [idx0, idxn] is the offset of the first vector element. +// +// When multiple vector elements have the same index to store to, the data from +// the highest lane will be the one stored. If add is true, then the data will +// be added from the lowest lane to the highest lane. +// +// valueToStore: Vector to be stored. +// base : Memref specifying the base address. +// indices : Vectors of indices for each dimension of the base memref. +// mask : Optional elementwise vector mask. +// add : If true, add source values to target values. Otherwise, overwrite. +def TPU_VectorStoreIdxOp :TPU_Op<"vector_store_idx", [DefaultMemWrite, AttrSizedOperandSegments]> { + let arguments = (ins + VectorOfNonZeroRankOf<[I32, F32]>:$valueToStore, + MemRefOf<[I32, F32]>:$base, + Variadic>:$indices, + Optional>:$mask, + DefaultValuedAttr:$add + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($valueToStore) `,` type($mask) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// TODO(jevinjiang): deprecate to use dynamic_rotate. +def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { + let description = [{ + Rotates the given vector by the given amount in the given dimension, i.e., + for a 2D vector of shape (m, n), rotating dim 0 by `amount` will shift a row + at index `i` to index `(i + amount) % m` + }]; + let arguments = (ins + AnyVectorOfNonZeroRank:$value, + SI32Attr:$amount, + SI32Attr:$dimension, + // When the stride is specified, the rotation amount for each index on the + // stride dimension will be (amount + stride * index). + OptionalAttr:$stride, + OptionalAttr:$stride_dimension + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $value `by` $amount `dim` $dimension (`stride` $stride `stride_dim` $stride_dimension^)? attr-dict `:` type($value) + }]; + let hasVerifier = 1; +} + +def TPU_DynamicRotateOp : TPU_Op<"dynamic_rotate", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$value, + I32:$amount, + SI32Attr:$dimension, + // When the stride is specified, the rotation amount for each index on the + // stride dimension will be (amount + stride * index). + OptionalAttr:$stride, + OptionalAttr:$stride_dimension + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $value `by` $amount `dim` $dimension attr-dict `:` type($value) `,` type($amount) `->` type($result) + }]; + let hasVerifier = 1; +} + +def TPU_ScanCountOp : TPU_Op<"scan_count", [Pure, InferTypeOpAdaptor, SameOperandsAndResultShape]> { +let summary = [{ + ScanCountOp calculates the running duplicate occurrence count of the elements + in the input vector. Elements eligible for counting are specified by the + input mask vector. The output mask vector indicates one unique occurrence + per duplicate that was counted. + }]; + + let description = [{ + ScanCountOp calculates the running duplicate occurrence count of the elements + in the input vector, %values. The output vector, %counts, contains the running + duplicate occurrence count for the corresponding element in + the input vector, where the count is performed in ascending order of element + indices. For example, if the elements of %values at indices 0, 5, and 7 had + duplicate values, then the elements of %counts at indices 0, 5, and 7 would + be 1, 2, and 3, respectively. + + A mask vector, %in_mask, specifies which of the elements in the input vector + are eligible for counting. An element in %values that has its mask set to 0 + will always have a count of 1 in %counts, regardless of the position in the + vector, or whether there were duplicates or not. + }]; + + let arguments = (ins + VectorOfNonZeroRankOf<[I1]>:$in_mask, + AnyVectorOfNonZeroRank:$values + ); + let results = (outs + VectorOfNonZeroRankOf<[I1]>:$out_mask, + VectorOfNonZeroRankOf<[I32]>:$counts + ); + + let assemblyFormat = [{ + `mask` `(` $in_mask `:` type($in_mask) `)` + `value` `(` $values `:` type($values) `)` + attr-dict `:` type(results) + }]; + +} + +def TPU_IotaOp : TPU_Op<"iota", [Pure]> { + let description = [{ + Creates a vector that with values that start at 0 and increase along a + dimension resulting from collapsing the given `dimensions` together in + row-major order. + + Example: + ``` + tpu.iota {dimensions = array} : vector<4x3x2xi16> + ``` + This produces a vector with the following values: + ``` + [[[0, 4], [0, 4], [0, 4]] + [[1, 5], [1, 5], [1, 5]] + [[2, 6], [2, 6], [2, 6]] + [[3, 7], [3, 7], [3, 7]]] + ``` + }]; + let arguments = (ins DenseI32ArrayAttr:$dimensions); + let results = (outs VectorOfNonZeroRankOf<[AnyInteger, Index]>:$output); + let assemblyFormat = [{ attr-dict `:` type($output) }]; + let hasVerifier = 1; +} + +def TPU_ReshapeOp : TPU_Op<"reshape", [Pure]> { + let arguments = (ins AnyVectorOfNonZeroRank:$source); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ $source attr-dict `:` type($source) `->` type($result) }]; + let hasVerifier = 1; + let hasFolder = 1; +} + +// TODO(mvoz): deprecated - use concat. Canonicalization will do so automatically. +// b/376295711 +def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$source, + I32Attr:$dimension, + I32Attr:$times + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $source `,` $dimension `x` $times attr-dict `:` type($source) `->` type($output) }]; +} + +def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { + let description = [{ + For each sublane `i`, broadcasts the value in lane `lane + i` along the + entire sublane. For packed type, imagine the data is compressed unpacked + along sublane dimension, and the sublane count is multiplied by the packing + factor. + For example, for i16 with sublane count 8, `i` above is in [0, 8 * 2). + If `lane + i` is not in [0, lane_count), then the value in sublane `i` is + not defined (can be anything). + }]; + let arguments = (ins + TPU_Vreg:$source, // All sublanes should be equal. + I32Attr:$lane // Coordinates of the first element to take. + ); + // Output shape should be the same, except for position dim which contains + // the newly inserted dimension. + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $source `,` $lane attr-dict `:` type($source) `->` type($output) + }]; +} + +// Integer unpacks are always signed at the moment. +// +// When unpacking integers to integers, setting `sign_extended` to false will +// leave bits higher than source bitwidth as undefined. +// +// Take int4 to int16 interleaved unpacking and `index = 1` as an example: +// +// Source: +// +// Bits 28 24 20 16 12 8 4 0 +// --------abcd------------efgh---- +// +// where "a" and "e" are the sign bits of the values to be unpacked, and "-" are +// bits to be ignored. +// +// Unpacked, sign_extend = true: +// +// Bits 28 24 20 16 12 8 4 0 +// aaaaaaaaaaaaabcdeeeeeeeeeeeeefgh +// +// Unpacked, sign_extend = false: +// +// Bits 28 24 20 16 12 8 4 0 +// ------------abcd------------efgh +def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$source, + I32Attr:$index, + TPU_PackFormatEnum:$pack_format, + DefaultValuedAttr:$sign_extended + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// Integer packs are always signed at the moment. +// Float to integer packing rounds to nearest even. +// WARNING: pack(pack(a, b), pack(c, d)) == pack(a, b, c, d) only holds for +// compressed packing! +// Below, we use [ ... ] to denote the bounds of the vreg and use regular parens +// ( ... ) to denote packing of multiple subelements into a single 32-bit word. +// +// Interleaved packing +// +// Interleaved packing downcasts to a narrower dtype, and packs multiple elements +// into the same word coordinate from which they originated. If a and b are packed +// values, then interleaved packing first iterates over the operand list and only +// then over the subelements within each word. +// Take 16-bit vregs A, B, C and D: +/// +// [ (A000 A001) (A010 A011) ... ] +// [ (A100 A101) (A110 A111) ... ] +// ... +// +// An interleaved pack(a, b) from 16-bit values produces: +// +// [ (A000 B000 A001 B001) (A010 B010 A011 B011) ...] +// ... +// +// While an interleaved pack(a, b, c, d) produces the following subelements in +// each vreg word: +// +// [ (A000 B000 C000 D000 A001 B001 C001 D001) ... ] +// ... +// +// Compressed packing +// +// Compressed packing downcasts each value and then packs multiple rows together. +// A compressed pack(a, b) from 16-bit values produces: +// +// [ (A000 A001 A100 A101) (A010 A011 A110 A111) ... ] +// [ (A200 A201 A300 A301) (A210 A211 A310 A311) ... ] +// ... # 2 more sublanes +// [ (B000 B001 B100 B101) (B010 B011 B110 B111) ... ] +// [ (B200 B201 B300 B301) (B210 B211 B310 B311) ... ] +// ... +def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> { + let arguments = (ins + Variadic:$sources, + DenseI32ArrayAttr:$positions, + TPU_PackFormatEnum:$pack_format + ); + let results = (outs TPU_Vreg:$output); + let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; + let builders = [ + OpBuilder<(ins "::mlir::VectorType":$output_type, "::mlir::ArrayRef<::mlir::Value>":$padded_sources, "::mlir::tpu::PackFormat":$pack_format)>, + ]; + let extraClassDeclaration = [{ + static ::mlir::SmallVector<::mlir::Value> getPaddedSources(::mlir::ValueRange sources, ::mlir::ArrayRef positions, int packing_factor); + }]; + let hasVerifier = 1; +} + +def TPU_PackElementwiseOp : TPU_Op<"pack_elementwise", [Pure, SameTypeOperands, ElementwiseMappable]> { + let description = [{ + Packs multiple `sources` elementwise into a single vector of a narrower `target_type`. + + The number of `sources` must equal the packing factor, which is the ratio of + the element bitwidth of the `sources` to the element bitwidth of the + `target_type`. Elements from the `sources` are interleaved and packed into + each word of the `output`, ordered from lowest to highest bits, + corresponding to their order in the `sources`. + }]; + let arguments = (ins + Variadic>:$sources, + TypeAttr:$target_type + ); + let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); + let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_UnpackElementwiseOp : TPU_Op<"unpack_elementwise", [Pure, ElementwiseMappable]> { + let description = [{ + Unpacks a single vector from `source`, which contains multiple `source_type` + vectors packed elementwise. + + The `index` selects which packed value to extract from each word of `source`. + An `index` of 0 corresponds to the lowest bits. The extracted values are + cast to the output element type. + }]; + let arguments = (ins + VectorOfNonZeroRankOf<[I32]>:$source, + TypeAttr:$source_type, + I32Attr:$index + ); + let results = (outs VectorOfNonZeroRankOf<[F32, I32]>:$output); + let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_RelayoutOp : TPU_Op<"relayout", [Pure, SameOperandsAndResultType]> { + let arguments = (ins AnyVectorOfAnyRank:$input); + let results = (outs AnyVectorOfAnyRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> { + let arguments = (ins + VectorOfNonZeroRankOf<[I1]>: $low, + VectorOfNonZeroRankOf<[I1]>: $high + ); + let results = (outs VectorOfNonZeroRankOf<[I1]>:$output); + let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }]; +} + +def TPU_GatherOp : TPU_Op<"gather", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$source, + DenseI32ArrayAttr:$indices, + I32Attr:$dimension + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $source `[` $indices `]` `in` $dimension attr-dict + `:` type($source) `->` type($output) + }]; +} + +def TPU_DynamicGatherOp : TPU_Op<"dynamic_gather", [Pure, DeclareOpInterfaceMethods, AllShapesMatch<["indices", "output"]>, AllElementTypesMatch<["source", "output"]>]> { + let description = [{ + Gathers elements from `source` using `indices`. + + The specified `dimensions` of `source` are collapsed together and indexed by + `indices`. + + Given a shape `N0 x N1 x ...`, the `output[i0, i1, ...]` is given by + `collapsed_source[j0, j1, ..., indices[i0, i1, ...] mod M]` where + - `collapsed_source` is the result of collapsing `dimensions` of `source` + into a new trailing dimension of size `M`. + - `jk` is the subsequence of `in` for `n` not in `dimensions`. + + When a single dimension is specified, this is similar to + `np.take_along_axis`. + }]; + let arguments = (ins + AnyVectorOfNonZeroRank:$source, + VectorOfNonZeroRankOf<[AnyInteger]>:$indices, + DenseI32ArrayAttr:$dimensions + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $source `[` $indices `]` `in` $dimensions attr-dict + `:` type($source) `,` type($indices) `->` type($output) + }]; + let hasVerifier = 1; +} + +def TPU_RoundingMode : I32EnumAttr<"RoundingMode", "Rounding mode", [ + I32EnumAttrCase<"kTowardsZero", 0, "towards_zero">, + I32EnumAttrCase<"kToNearestEven", 1, "to_nearest_even">, +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tpu"; +} + +def TPU_RoundingModeEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +// Internal operation. All arith.fptosi operations that change the bitwidth +// must be canonicalized to this operation. +def TPU_FPToSIOp : TPU_Op<"fptosi", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyVectorOfAnyRank:$input, TPU_RoundingModeEnum:$rounding_mode); + let results = (outs AnyVectorOfAnyRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasCanonicalizeMethod = 1; +} + +// Internal operation. All arith.sitofp operations that change the bitwidth +// must be canonicalized to this operation. +def TPU_SIToFPOp : TPU_Op<"sitofp", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyType:$in, TPU_RoundingModeEnum:$rounding_mode); + let results = (outs AnyType:$output); + let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($output) }]; +} + +// Internal operation. +def TPU_ExtFOp : TPU_Op<"extf", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyType:$in); + let results = (outs AnyType:$out); + let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }]; + let hasFolder = 1; +} + +// Internal operation. +def TPU_TruncFOp : TPU_Op<"truncf", [Pure, ElementwiseMappable]> { + let arguments = ( + ins AnyType:$in, + TPU_RoundingModeEnum:$rounding_mode + ); + let results = (outs AnyType:$out); + let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }]; + let hasFolder = 1; +} + +// TODO(apaszke): Think hard about precision +def TPU_MatmulOp : TPU_Op<"matmul", [Pure]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$lhs, + AnyVectorOfNonZeroRank:$rhs, + AnyVectorOfNonZeroRank:$acc, + // These flags are deprecated - if dimension_numbers are defined, + // these flags are ignored. They will always be false after canonicalize. + DefaultValuedAttr:$transpose_lhs, + DefaultValuedAttr:$transpose_rhs, + OptionalAttr:$precision, + // NOTE: User-level optional, once canonicalized, always present. + OptionalAttr:$dimension_numbers + ); + let results = (outs AnyVectorOfNonZeroRank:$result); + let assemblyFormat = [{ + $lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `,` type($acc) `->` type($result) + }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +def TPU_ConcatenateOp : TPU_Op<"concatenate", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins + Variadic:$sources, + I32Attr:$dimension + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ + $sources `in` $dimension attr-dict `:` type($sources) `->` type($output) + }]; + let hasVerifier = 1; +} + +def TPU_BitcastOp : TPU_Op<"bitcast", [Pure]> { + let arguments = (ins AnyVectorOfNonZeroRank:$input); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> { + let arguments = (ins TPU_Vreg:$input); + let results = (outs TPU_Vreg:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasFolder = 1; +} + +def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> { + let arguments = (ins AnyType:$input); // F32 vector or scalar + let results = (outs AnyType:$output); // I1 vector or scalar + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_ReciprocalOp : TPU_Op<"reciprocal", [Pure, SameOperandsAndResultType, ElementwiseMappable]> { + let arguments = (ins + AnyVectorOfNonZeroRank:$input, + DefaultValuedAttr:$approx + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_StochasticConvertOp : TPU_Op<"stochastic_convert", [Pure, SameOperandsAndResultShape]> { + let arguments = (ins + VectorOfNonZeroRankOf<[F32]>:$input, + VectorOfNonZeroRankOf<[I32]>:$random + ); + let results = (outs AnyVectorOfNonZeroRank:$output); + let assemblyFormat = [{ $input `,` $random attr-dict `:` type($input) `,` type($random) `->` type($output) }]; +} + +def TPU_StochasticConvertElementwiseOp : TPU_Op<"stochastic_convert_elementwise", [Pure, ElementwiseMappable]> { + // Stochastically converts the input to the target dtype based on the mode. + // When the target dtype is less than 32 bits, the result occupies the lowest {bitwidth} bits in the I32 output. + let arguments = (ins + VectorOfNonZeroRankOf<[F32]>:$input, + VectorOfNonZeroRankOf<[I32]>:$random, + TypeAttr:$dst_type + ); + let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); + let assemblyFormat = [{ $input `,` $random attr-dict `:` type($input) `,` type($random) `->` type($output) }]; + let hasVerifier = 1; +} + +def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> { + let arguments = (ins Variadic:$input); + let results = (outs AnyVectorOfAnyRank:$output); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($output) + }]; +} + +def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> { + let arguments = (ins AnyVectorOfAnyRank:$input); + let results = (outs Variadic:$output); + let hasCanonicalizeMethod = 1; + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($output) + }]; +} + +def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> { + // high is exclusive + let arguments = (ins Variadic:$low, Variadic:$high); + let results = (outs AnyType:$output); + let assemblyFormat = [{ + `[` $low `]``[` $high `]` attr-dict `:` type($output) + }]; +} + +def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> { + let summary = "Create a mask masking contiguous rows of subelements."; + let description = [{ + The "half-sublanes", "quarter-sublanes", etc. (unit is determined by + the type of `output`) of the mask are masked in the range specified by + `from` and `to`. + + - If `from <= to`, the range `[from, to)` is set and the rest is unset. + - If `to <= from`, the range `[to, from)` is unset and the rest is set. + + All lanes are set identically. + + Example: + + ```mlir + %msk = tpu.create_subelement_mask 3, 9 : vector<8x128x2xi1> + ``` + + This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is: + + ``` + [[0, 0], [0, 1], [1, 1], [1, 1], [1, 0], [0, 0], [0, 0], [0, 0]] + ``` + + It is currently only supported: + - In TPU v4, for `num_subelems` of 1 and 2. + - In TPU v5, for `num_subelems` of 1, 2, and 4. + }]; + let arguments = (ins + I32Attr:$from, // inclusive + I32Attr:$to // exclusive + ); + let results = (outs AnyType:$output); // Verify this is a vmsk with num_subelems + let assemblyFormat = [{ + $from `,` $to attr-dict `:` type($output) + }]; +} + +def TPU_AssumeMultipleOp : TPU_Op<"assume_multiple", [Pure, SameOperandsAndResultType]> { + let summary = "Assumes that a value is a multiple of a given integer."; + let description = [{ + This operation is a hint to the compiler that the input `value` is guaranteed + to be a multiple of `multiple`. This can be used to satisfy divisibility checks + in some compiler passes. + + The result is the same as the input `value`. + + Example: + + ```mlir + %val = tpu.assume_multiple %arg0, 16 : index + ``` + }]; + let arguments = (ins + AnyTypeOf<[Index, AnyInteger]>:$value, + I32Attr:$multiple + ); + let results = (outs AnyTypeOf<[Index, AnyInteger]>:$result); + let assemblyFormat = [{$value `,` $multiple attr-dict `:` type($result)}]; + let hasVerifier = 1; +} + +def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure, AttrSizedOperandSegments]> { + let arguments = (ins + AnyMemRef:$mem_ref, + Variadic:$base_idx, + Variadic:$dynamic_sizes + ); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $mem_ref `[` $base_idx `]` (`<` $dynamic_sizes^ `>`)? + attr-dict `:` type($mem_ref) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizer = 1; +} + +def TPU_MemRefSqueezeOp : TPU_Op<"memref_squeeze", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_MemRefBitcastOp : TPU_Op<"memref_bitcast", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_AssumeLayoutOp : TPU_Op<"assume_layout", [Pure]> { + let arguments = (ins AnyType:$input); + let results = (outs AnyType:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; +} + +// Erases the layout attribute from the memref. +// +// The resulting memref is identical to the input, except that it has an +// identity layout. +def TPU_EraseLayoutOp : TPU_Op<"erase_memref_layout", [Pure, InferTypeOpAdaptor]> { + let arguments = (ins AnyMemRef:$operand); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; + let hasFolder = 1; +} + +// Returns the ID of the current device. +// +// On the input to the compiler the return value is a logical ID in the XLA +// device assignment. It changes to a physical ID after the +// logical-to-physical-device-id pass. +def TPU_DeviceIdOp : TPU_Op<"device_id", [Pure]> { + let arguments = (ins); + let results = (outs I32:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + +def TPU_SemaphoreReadOp : TPU_Op<"sem_read"> { + let arguments = (ins MemRefOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>:$semaphore); + let results = (outs I32:$result); + let assemblyFormat = [{ $semaphore attr-dict `:` type($semaphore) `->` type($result)}]; +} + +def TPU_SemaphoreWaitOp : TPU_Op<"sem_wait"> { + let arguments = (ins + MemRefOf<[TPU_SemaphoreType]>:$semaphore, + I32:$amount + ); + let results = (outs); + let assemblyFormat = [{ $semaphore `,` $amount attr-dict `:` type($semaphore)}]; + let hasVerifier = 1; +} + +def TPU_AllocaSemaphoreOp : TPU_Op<"sem_alloc"> { + let arguments = (ins); + let results = (outs MemRefOf<[TPU_SomeSemaphoreType]>:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + +def TPU_GetBarrierSemaphoreOp : TPU_Op<"sem_barrier"> { + let arguments = (ins); + let results = (outs MemRefOf<[TPU_SemaphoreType]>:$semaphore); + let assemblyFormat = [{ attr-dict `:` type($semaphore) }]; + let hasVerifier = 1; +} + +def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { + let arguments = (ins + MemRefOf<[TPU_SemaphoreType]>:$semaphore, + I32:$amount, + Optional:$device_id, // For remote DMAs + Optional:$core_id, // For megacore + OptionalAttr:$core_type + ); +let assemblyFormat = [{ + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore) + }]; + let hasVerifier = 1; + let builders = [ + // A backward-compatible builder that sets `core_type` to nullptr. + OpBuilder<(ins "Value":$semaphore, "Value":$amount, + "Value":$device_id, "Value":$core_id)>, + ]; +} + +def TPU_BarrierOp : TPU_Op<"barrier"> { + let summary = [{Barrier synchronization across SC vector subcores.}]; + let description = [{ + Performs barrier synchronization across all SC vector subcores at the + specified barrier id. + }]; + let arguments = (ins Index:$barrier_id); + let results = (outs); + let assemblyFormat = [{ `barrier_id` `(` $barrier_id `)` attr-dict }]; +} + +// tpu.enqueue_dma enqueues a DMA operation. +// +// source : Memref to copy from. +// source_semaphore : Semaphore to signal after the DMA completes. +// target : Memref to copy to. +// target_semaphore : Semaphore to wait on before the DMA completes. +// device_id : The id of the device to copy to for remote DMAs. +// core_id : The id of the core to copy to for remote and cross-core +// DMAs. +// priority : The priority of the DMA. +// strict_ordering : True if the DMA requires strict ordering. If false, the +// ordering is either strict or relaxed depending on the +// source and destination. +def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> { + let arguments = (ins + AnyMemRef:$source, + Optional>:$source_semaphore, // For remote DMAs + AnyMemRef:$target, + MemRefOf<[TPU_DMASemaphoreType]>:$target_semaphore, + Optional:$device_id, // For remote DMAs + Optional:$core_id, // For megacore + // Smaller number means higher priority. 0 is the highest and the default. + DefaultValuedAttr:$priority, + DefaultValuedAttr:$strict_ordering + ); + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` + `target` `(` $target `:` type($target) `)` + (`source_semaphore` `(` $source_semaphore^ `:` type($source_semaphore) `)`)? + `target_semaphore` `(` $target_semaphore `:` type($target_semaphore) `)` + (`device_id` `(` $device_id^ `)`)? + (`core_id` `(` $core_id^ `)`)? + attr-dict + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +// A base class for all ops that need to differentiate between gather and +// scatter. +class IndirectDMAOp { + code extraBaseClassDeclaration = [{ + // Return true if this op performs a gather. Returns false if it performs a + // scatter. + FailureOr isGather(); + }]; +} + +// tpu.enqueue_indirect_dma copies data between HBM and VMEM, or between +// VMEM_SHARED and VMEM using indirect HBM offsets. +// +// If the source is in HBM or VMEM_SHARED and the target is in VMEM, performs a +// gather from the source (operand) at the offsets to the target (gather +// result). +// If the source is in VMEM and the target is in HBM or VMEM_SHARED, performs a +// scatter of the source (updates) to the target (operand) at the offsets. +// +// source : Memref to copy from. +// target : Memref to copy to. +// offsets : Gather or scatter offsets. +// semaphore : Semaphore to wait on; receive semaphore for scatter, send semaphore for gather. +// add : If true, add source values to target values. Otherwise, overwrite. +// offset_filter : If set, don't write values at offsets whose value is equal to +// the filter value. +def TPU_EnqueueIndirectDMAOp : TPU_Op<"enqueue_indirect_dma">, IndirectDMAOp { + let arguments = (ins + AnyMemRef:$source, + AnyMemRef:$target, + AnyTypeOf<[MemRefOf<[I32]>, VectorOfRankAndType<[1], [I32]>]>:$offsets, + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + Optional:$offset_filter, + DefaultValuedAttr:$add + ); + let assemblyFormat = [{ + `source` `(` $source `:` type($source) `)` + `target` `(` $target `:` type($target) `)` + `offsets` `(` $offsets `:` type($offsets) `)` + (`offset_filter` `(` $offset_filter^ `)`)? + `semaphore` `(` $semaphore `:` type($semaphore) `)` + attr-dict + }]; + let hasVerifier = 1; + let extraClassDeclaration = extraBaseClassDeclaration # [{ + LogicalResult verifyGather(MemRefType operand_ty, + ArrayRef offsets_shape, + MemRefType result_ty); + LogicalResult verifyScatter(MemRefType updates_ty, + ArrayRef offsets_shape, + MemRefType operand_ty); + }]; + let hasCanonicalizeMethod = 1; +} + +// tpu.wait_dma2 waits for a DMA to complete. +// +// The number of bytes to wait for is determined based on the size of the +// destination memref. +def TPU_WaitDMA2Op : TPU_Op<"wait_dma2", [AttrSizedOperandSegments]> { + let arguments = (ins + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + AnyMemRef:$src, + AnyMemRef:$dst, + Optional:$device_id, // For remote DMAs + Optional:$core_id, // For megacore + DefaultValuedAttr:$strict_ordering + ); + let assemblyFormat = [{ + `semaphore` `(` $semaphore `:` type($semaphore) `)` + `src` `(` $src `:` type($src) `)` + `dst` `(` $dst `:` type($dst) `)` + (`device_id` `(` $device_id^ `)`)? + (`core_id` `(` $core_id^ `)`)? + attr-dict + }]; + let hasVerifier = 1; + // A backward-compatible builder that sets `device_id` and `core_id` to nullptr. + let builders = [ + OpBuilder<(ins "Value":$semaphore, "Value":$src, "Value":$dst)> + ]; + let hasCanonicalizeMethod = 1; +} + +// TODO(b/395630795): Remove after 2025-08-10. +def TPU_WaitDMAOp : TPU_Op<"wait_dma"> { + let arguments = (ins + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + AnyMemRef:$ref + ); + let hasVerifier = 1; +} + +// Like tpu.wait_dma2, but for indirect DMAs. +// +// The number of bytes to wait for is determined based on the size of the +// destination memref in a gather, and the size of the source memref in a +// scatter. The op differentiates between gather and scatter based on the memory +// spaces of the source and destination memrefs. +def TPU_WaitIndirectDMAOp : TPU_Op<"wait_indirect_dma">, IndirectDMAOp { + let arguments = (ins + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + AnyMemRef:$src, + AnyMemRef:$dst + ); + let assemblyFormat = [{ + `semaphore` `(` $semaphore `:` type($semaphore) `)` + `src` `(` $src `:` type($src) `)` + `dst` `(` $dst `:` type($dst) `)` + attr-dict + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; + let extraClassDeclaration = extraBaseClassDeclaration; +} + +def TPU_RegionOp : TPU_Op<"region", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> { + let arguments = (ins); + let results = (outs Variadic:$results); + let regions = (region AnyRegion:$region); + let hasVerifier = 1; +} + +def TPU_TraceOp : TPU_Op<"trace", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"tpu::YieldOp">]> { + let arguments = (ins StrAttr:$message, I32Attr:$level); + let results = (outs Variadic:$results); + let regions = (region AnyRegion:$region); +} + +def TPU_TraceStartOp : TPU_Op<"trace_start", []> { + let arguments = (ins StrAttr:$message, I32Attr:$level); + let results = (outs); +} + +def TPU_TraceStopOp : TPU_Op<"trace_stop", []> { + let arguments = (ins); + let results = (outs); +} + +def TPU_YieldOp : TPU_Op<"yield", [Pure, ReturnLike, Terminator]> { + let arguments = (ins Variadic:$results); + let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }]; +} + +def TPU_DelayOp : TPU_Op<"delay"> { + let arguments = (ins I32:$nanos); + let results = (outs); +} + +// Expands the granularity of mask to subelements. +def TPU_MaskCastOp : TPU_Op<"mask_cast", [Pure]> { + let description = [{ + Cast a mask register into a different packing. + + If casting to a type with smaller packing, then values being packed together + must be identical. For example, for 8x128x4xi1 -> 8x128x2xi1, + input[i, j, 0] == input[i, j, 1] and input[i, j, 2] == input[i, j, 3] must + hold for all i, j. Otherwise, the result is undefined. + }]; + let arguments = (ins VectorOfNonZeroRankOf<[I1]>:$input); + let results = (outs VectorOfNonZeroRankOf<[I1]>:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; +} + +def TPU_GetIterationBoundOp : TPU_Op<"iteration_bound"> { + let arguments = (ins I32Attr:$dim); + let results = (outs I32:$result); + let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; +} + +def TPU_GetInternalScratchOp : TPU_Op<"internal_scratch"> { + let arguments = (ins); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ attr-dict `:` type($result) }]; +} + +def TPU_PRNGSeed32Op : TPU_Op<"prng_set_seed_32"> { + let arguments = (ins Variadic:$seeds); + let results = (outs); +} + +def TPU_PRNGRandomBitsOp : TPU_Op<"prng_random_bits"> { + let arguments = (ins); + let results = (outs AnyVectorOfNonZeroRank:$output); +} + +def TPU_SublaneShuffleOp : TPU_Op<"sublane_shuffle", [SameOperandsAndResultType]> { + // This op takes 2 physical vregs and a pattern, applies the pattern, + // and returns the result as 1 vreg. + // + // The pattern is a list of integers, where the integer value is the + // index of the sublane in the *combined input* [lhs, rhs], and the + // position of the integer in the list is the index of the sublane + // in the *output* vreg. + // + // The pattern size must match the operand/result sublane count. + // + // Example: + // %0 = tpu.single_output_sublane_shuffle %a, %b, + // [0, 1, 2, 3, 4, 5, 6, 7] // Result is %a + // %1 = tpu.single_output_sublane_shuffle %a, %b, + // [8, 9, 10, 11, 12, 13, 14, 15] // Result is %b + // %2 = tpu.single_output_sublane_shuffle %a, %b, + // [7, 6, 5, 4, 11, 10, 9, 8] // Result uses high half of a + // // and low half of b, reversed. + let arguments = (ins + TPU_Vreg:$lhs, + TPU_Vreg:$rhs, + DenseI32ArrayAttr:$pattern + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $lhs `,` $rhs `,` $pattern attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) + }]; + + let hasVerifier = 1; +} + +def TPU_TransposeOp : TPU_Op<"transpose", [Pure]> { + let summary = "tpu transpose operation"; + let arguments = (ins AnyVectorOfAnyRank:$vector, + DenseI64ArrayAttr:$permutation); + let results = (outs AnyVectorOfAnyRank:$result); + + let assemblyFormat = [{ + $vector `,` $permutation attr-dict `:` type($vector) `->` type($result) + }]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return ::llvm::cast(getVector().getType()); + } + VectorType getResultVectorType() { + return ::llvm::cast(getResult().getType()); + } + }]; + let hasVerifier = 1; +} + +def TPU_LogOp : TPU_Op<"log"> { + let arguments = (ins + Variadic:$inputs, + StrAttr:$tag, + DefaultValuedAttr:$formatted + ); + let results = (outs); + let assemblyFormat = [{ $tag attr-dict (`:` `[` $inputs^ `]` `:` type($inputs))? }]; + let hasVerifier = 1; +} + +def TPU_LogBufferOp : TPU_Op<"log_buffer"> { + let arguments = (ins + AnyMemRef:$input, + DenseI64ArrayAttr:$shape, + StrAttr:$tag + ); + let results = (outs); + let assemblyFormat = [{ $tag attr-dict `:` $input `:` type($input) }]; + let hasVerifier = 1; +} + +#endif // TPU_OPS diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 6e575fb3092a..d13b5efe5fbe 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -45,7 +45,7 @@ gentbl_filegroup( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = ":tpu_python.td", deps = [ - "//jaxlib/mosaic:tpu_td_files", + "//jaxlib/mosaic:tpu_ops_td_files", "@llvm-project//mlir:OpBaseTdFiles", ], ) diff --git a/jaxlib/mosaic/python/tpu_python.td b/jaxlib/mosaic/python/tpu_python.td index 56abaadd7f36..a6abf92116b3 100644 --- a/jaxlib/mosaic/python/tpu_python.td +++ b/jaxlib/mosaic/python/tpu_python.td @@ -13,4 +13,4 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -include "jaxlib/mosaic/dialect/tpu/tpu.td" +include "jaxlib/mosaic/dialect/tpu/tpu_ops.td" From 32eb215034ad9d01992d14d5e467470bf6ef4eab Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Fri, 12 Dec 2025 09:35:48 -0800 Subject: [PATCH 181/315] Reverts cc17df39c680feddef7f12d5ac7e0cb18d39e8f1 PiperOrigin-RevId: 843723663 --- jaxlib/BUILD | 1 - jaxlib/call_location.cc | 21 +++++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index ebcd42f05304..c50200aa6da0 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -516,7 +516,6 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@nanobind", diff --git a/jaxlib/call_location.cc b/jaxlib/call_location.cc index 8a5558d361ff..b556f5e6ee62 100644 --- a/jaxlib/call_location.cc +++ b/jaxlib/call_location.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/base/no_destructor.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" @@ -32,10 +31,10 @@ limitations under the License. #include "nanobind/stl/optional.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/py_user_context.h" #include "jaxlib/traceback.h" -#include "xla/python/ifrt/attribute_map.h" +#include "jaxlib/py_user_context.h" #include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/user_context.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" @@ -125,9 +124,19 @@ void PopulateCallLocation(xla::ifrt::ExecuteOptions& options, } if (!call_location_str.empty()) { - CHECK_OK(options.custom_options->Set( - std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation), - std::move(call_location_str))); + // Simplify this to use AttributeMap::Set(). + xla::ifrt::AttributeMap::Map attrs_map; + if (options.custom_options.has_value()) { + options.custom_options->ForEach( + [&](const std::string& key, + const xla::ifrt::AttributeMap::Value& value) { + attrs_map.insert({key, value}); + }); + } + attrs_map.insert( + {std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation), + xla::ifrt::AttributeMap::StringValue(std::move(call_location_str))}); + options.custom_options.emplace(std::move(attrs_map)); } } From a33d8378966e2101f8dd393214f642eee12ae5c0 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 11 Dec 2025 16:11:05 -0500 Subject: [PATCH 182/315] Remove type parameters. Type checker doesn't like them --- jax/_src/interpreters/partial_eval.py | 4 ++-- jax/_src/lax/control_flow/loops.py | 24 +++++++++++-------- jax/_src/tree_util.py | 33 ++++++++++++++------------- 3 files changed, 34 insertions(+), 27 deletions(-) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index c92010022c27..8c87ec7e7191 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2292,9 +2292,9 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): @weakref_lru_cache def trace_to_jaxpr( fun: Callable, - in_avals: FlatTree[AbstractValue | core.AvalQDD], # (args, kwargs) pair + in_avals: FlatTree, # (args, kwargs) pair debug_info: core.DebugInfo -) -> tuple[ClosedJaxpr, PyTreeDef]: +) -> tuple[ClosedJaxpr, FlatTree]: config.enable_checks.value and debug_info.assert_arg_names(len(in_avals)) parent_trace = core.trace_ctx.trace trace = DynamicJaxprTrace(debug_info, parent_trace=parent_trace) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 929a0ab679f5..d46d100da0b2 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -84,7 +84,7 @@ def _stack(arrs: Sequence[Array], axis: int=0) -> Array: def _promote_weak_typed_input( in_val:Any, in_aval:AbstractValue, out_aval:AbstractValue - ) -> tuple[AbstractValue, bool]: + ) -> tuple[Any, bool]: if getattr(in_aval, 'weak_type', False) and not core.typematch(in_aval, out_aval): new_dtype = dtypes.result_type(in_val, out_aval) return lax.convert_element_type(in_val, new_dtype), True @@ -228,7 +228,7 @@ def scan(f, init, xs, length=None): return carry, stacked_y if config.mutable_array_checks.value: - check_no_aliased_ref_args(lambda: dbg_body, list(args), list(args_avals)) + check_no_aliased_ref_args(lambda: dbg_body, list(args_avals), list(args)) x_avals = xs_avals.map(lambda aval: core.mapped_aval(length, 0, aval)) def _create_jaxpr(carry_avals): @@ -252,6 +252,8 @@ def _create_jaxpr(carry_avals): if config.mutable_array_checks.value: _check_no_aliased_closed_over_refs(dbg_body, consts, list(args)) carry_out_avals, ys_avals = out_avals.unpack() + if len(carry_out_avals) != len(init_avals): + _check_carry_type('scan body', f, init_avals, carry_out_avals) init, changed = init.map3( _promote_weak_typed_input, init_avals, carry_out_avals).unzip2() @@ -277,7 +279,7 @@ def _create_jaxpr(carry_avals): if unroll < 0: raise ValueError("`unroll` must be a `bool` or a non-negative `int`.") - args_flat = (*init.vals, *xs.vals) + args_flat = [*init.vals, *xs.vals] # If the body forwards an input carry to an output carry, that input is # read-only and can be moved to be a const. Doing so can lead to efficiency @@ -381,18 +383,18 @@ def _check_carry_type(name, body_fun, in_carry, out_carry): if p else 'the input carry') if in_carry.tree != out_carry.tree: try: - out_carry = out_carry.unflatten() + out_carry_unflat = out_carry.unflatten() except: - out_carry = None + out_carry_unflat = None - if out_carry is None: + if out_carry_unflat is None: differences = (f'the input tree structure is:\n{in_carry.tree}\n' + f'the output tree structure is:\n{out_carry.tree}\n') else: diffs = [f'{component(path)} is a {thing1} but the corresponding component ' f'of the carry output is a {thing2}, so {explanation}' for path, thing1, thing2, explanation - in equality_errors(in_carry, out_carry)] + in equality_errors(in_carry.unflatten(), out_carry.unflatten())] if len(diffs) == 0: return # the trees may have different aux data, but structures are same elif len(diffs) == 1: @@ -1709,7 +1711,7 @@ def _create_jaxpr(init_avals): cond_dbg = api_util.debug_info("while_cond", cond_fun, (init_val,), {}) body_dbg = api_util.debug_info("while_body", body_fun, (init_val,), {}) - init_val = FlatTree.flatten(init_val) + init_val = FlatTree.flatten(init_val) # type: ignore init_aval = init_val.map(core.get_aval) # The body input and output avals must match exactly. However, we want to account for @@ -1718,6 +1720,10 @@ def _create_jaxpr(init_avals): # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if # necessary, a second time with modified init values. cond_jaxpr, body_jaxpr, body_out_avals = _create_jaxpr(init_aval) + if len(body_out_avals) != len(init_aval): + _check_carry_type('while_loop body', body_fun, init_aval, body_out_avals) + assert False, "shouldn't get here" + init_val, changed = init_val.map3( _promote_weak_typed_input, init_aval, body_out_avals).unzip2() @@ -1749,7 +1755,7 @@ def _create_jaxpr(init_avals): _, keep_cond_carry = split_list(keep_cond, [len(cond_consts)]) move_to_const = _map(operator.not_, keep_cond_carry) - init_vals = list(init_val) + init_vals = list(init_val) # type: ignore if any(move_to_const): cond_jaxpr = pe.close_jaxpr(cond_jaxpr_) body_jaxpr = pe.prune_closed_jaxpr_outputs( diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index f6501fe0c4a0..5fc4c4018b9f 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -34,10 +34,6 @@ traceback_util.register_exclusion(__file__) T = TypeVar("T") -T1 = TypeVar("T1") -T2 = TypeVar("T2") -T3 = TypeVar("T3") -T4 = TypeVar("T4") Typ = TypeVar("Typ", bound=type[Any]) H = TypeVar("H", bound=Hashable) @@ -1357,20 +1353,18 @@ class FlatTree: the tuple-returning function would change the tree structure and `unzip` wouldn't be able to recover it. """ - def __init__(self, vals:Sequence[T], treedef:PyTreeDef): + def __init__(self, vals:Sequence, treedef:PyTreeDef): assert isinstance(treedef, pytree.PyTreeDef) self.tree = treedef - self.vals = list(vals) + self.vals = tuple(vals) - def map(self, f:Callable[[T1], T2]) -> FlatTree[T2]: + def map(self, f:Callable) -> FlatTree: ans_vals = [] for x in self.vals: ans_vals.append(f(x)) return FlatTree(ans_vals, self.tree) - def map2( - self:FlatTree[T1], f:Callable[[T1, T2], T3], - t2:FlatTree[T2]) -> FlatTree[T3]: + def map2(self:FlatTree, f:Callable, t2:FlatTree) -> FlatTree: n = len(self) assert len(t2) == n @@ -1380,8 +1374,7 @@ def map2( return FlatTree(ans_vals, self.tree) def map3( - self:FlatTree[T1], f:Callable[[T1, T2, T3], T4], - t2:FlatTree[T2], t3:FlatTree[T3]) -> FlatTree[T4]: + self:FlatTree, f:Callable, t2:FlatTree, t3:FlatTree) -> FlatTree: n = len(self) assert len(t2) == n and len(t3) == n ans_vals = [] @@ -1389,10 +1382,10 @@ def map3( ans_vals.append(f(x1, x2, x3)) return FlatTree(ans_vals, self.tree) - def zip(self, t2:FlatTree[T2]) -> FlatTree[tuple[T1, T2]]: + def zip(self, t2:FlatTree) -> FlatTree: assert False - def unzip2(self:FlatTree[tuple[T1, T2]]) -> tuple[FlatTree[T1], FlatTree[T2]]: + def unzip2(self:FlatTree) -> tuple[FlatTree, FlatTree]: ys = [] zs = [] for y, z in self.vals: @@ -1425,7 +1418,7 @@ def pack(tree): else: assert False - def unpack(self:FlatTree[tuple]) -> tuple[FlatTree]: + def unpack(self:FlatTree) -> tuple[FlatTree, ...]: # TODO: this is O(N) not O(1) (with N as the number of leaves). If it # becomes a problem we can fix it with a fancier data tree. trees = treedef_children(self.tree) @@ -1444,7 +1437,7 @@ def flatten(tree: PyTree) -> FlatTree: def unflatten(self) -> PyTree: return tree_unflatten(self.tree, self.vals) - def update_from_list(self, new_vals:list[T1]) -> FlatTree[T1]: + def update_from_list(self, new_vals:list) -> FlatTree: return FlatTree(new_vals, self.tree) def __len__(self): @@ -1452,3 +1445,11 @@ def __len__(self): def __iter__(self): return self.vals.__iter__() + + def __eq__(self, other): + return (isinstance(other, FlatTree) + and self.vals == other.vals + and self.tree == other.tree) + + def __hash__(self): + return hash((self.vals, self.tree)) From 60a0f0f499f4ae4d301da2ba1588d7144060789a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 12 Dec 2025 10:12:01 -0800 Subject: [PATCH 183/315] [sparse] make caveat at top more prominent --- docs/jax.experimental.sparse.rst | 6 ------ jax/experimental/sparse/__init__.py | 13 ++++++++++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/jax.experimental.sparse.rst b/docs/jax.experimental.sparse.rst index 37ef8ae43d67..f021cce3452f 100644 --- a/docs/jax.experimental.sparse.rst +++ b/docs/jax.experimental.sparse.rst @@ -1,12 +1,6 @@ ``jax.experimental.sparse`` module ================================== -.. note:: - - The methods in ``jax.experimental.sparse`` are experimental reference - implementations, and not recommended for use in performance-critical - applications. - .. automodule:: jax.experimental.sparse .. currentmodule:: jax.experimental.sparse diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py index f388cd527cf9..dbd21e343bb7 100644 --- a/jax/experimental/sparse/__init__.py +++ b/jax/experimental/sparse/__init__.py @@ -15,10 +15,17 @@ """ .. currentmodule:: jax.experimental.sparse +.. note:: + + The methods in ``jax.experimental.sparse`` are experimental reference implementations, + and not recommended for use in performance-critical applications. The submodule is no + longer being actively developed, but the team will continue supporting existing features + as best we can. + The :mod:`jax.experimental.sparse` module includes experimental support for sparse matrix -operations in JAX. It is under active development, and the API is subject to change. The -primary interfaces made available are the :class:`BCOO` sparse array type, and the -:func:`sparsify` transform. +operations in JAX. The primary interfaces made available are the :class:`BCOO` sparse array +type, and the :func:`sparsify` transform. + Batched-coordinate (BCOO) sparse matrices ----------------------------------------- From 3d9edef6f663847aeb354449f36978861afdee34 Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Fri, 12 Dec 2025 11:08:04 -0800 Subject: [PATCH 184/315] Simplify adding call location to custom options. Use xla::ifrt::AttributeMap::Set() to add the call location string to options.custom_options, instead of rebuilding the entire map. Reverts 32eb215034ad9d01992d14d5e467470bf6ef4eab PiperOrigin-RevId: 843761265 --- jaxlib/BUILD | 1 + jaxlib/call_location.cc | 22 ++++++++-------------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index c50200aa6da0..ebcd42f05304 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -516,6 +516,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@nanobind", diff --git a/jaxlib/call_location.cc b/jaxlib/call_location.cc index b556f5e6ee62..96855b114fb2 100644 --- a/jaxlib/call_location.cc +++ b/jaxlib/call_location.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/base/no_destructor.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" @@ -31,10 +32,10 @@ limitations under the License. #include "nanobind/stl/optional.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "jaxlib/traceback.h" #include "jaxlib/py_user_context.h" -#include "xla/python/ifrt/executable.h" +#include "jaxlib/traceback.h" #include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/user_context.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" @@ -124,19 +125,12 @@ void PopulateCallLocation(xla::ifrt::ExecuteOptions& options, } if (!call_location_str.empty()) { - // Simplify this to use AttributeMap::Set(). - xla::ifrt::AttributeMap::Map attrs_map; - if (options.custom_options.has_value()) { - options.custom_options->ForEach( - [&](const std::string& key, - const xla::ifrt::AttributeMap::Value& value) { - attrs_map.insert({key, value}); - }); + if (!options.custom_options.has_value()) { + options.custom_options.emplace(xla::ifrt::AttributeMap({})); } - attrs_map.insert( - {std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation), - xla::ifrt::AttributeMap::StringValue(std::move(call_location_str))}); - options.custom_options.emplace(std::move(attrs_map)); + CHECK_OK(options.custom_options->Set( + std::string(xla::ifrt::PjRtCompatibleLoadedExecutable::kCallLocation), + std::move(call_location_str))); } } From 3ef63cec873456f047e25d3ebe5387d288ba723b Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 12 Dec 2025 19:13:24 +0000 Subject: [PATCH 185/315] [mutable-arrays] ignroe InternalMutableArrayEffect in partial_eval has_effects It's not really an effect from the point of view of partial eval; like NamedAxisEffect, it's just a bit of info we're sneaking into the effect system as a handy way of tracking it. But it should be ignored for the purpose of deciding "is this code effectful". (We're hoping that InternalMutableArrayEffect, and the effect system more generally, and partial eval, are not long for this world...) This is a follow-up on #33906 Co-authored-by: Sharad Vikram --- jax/_src/ad_checkpoint.py | 3 ++- jax/_src/interpreters/partial_eval.py | 3 ++- tests/mutable_array_test.py | 28 +++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index a74bb7b7828e..88bdabc558bc 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -845,7 +845,8 @@ def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn pe.dce_rules[remat_p] = remat_dce def _has_effects(effects) -> bool: - return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) + not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect) + return any(not isinstance(e, not_really_effects) for e in effects) def remat_expansion( diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 0dfffb0f3efa..5afe2fb33c7d 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1051,7 +1051,8 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom: return x def has_effects(effects) -> bool: - return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)}) + not_really_effects = (core.NamedAxisEffect, core.InternalMutableArrayEffect) + return any(not isinstance(e, not_really_effects) for e in effects) known_eqns, staged_eqns = [], [] foreach(write, in_unknowns, in_inst, jaxpr.invars) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index ed23bbe8604e..332466321603 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -1053,6 +1053,34 @@ def f(x): f(3.) # don't crash + def test_remat_while_loop_residuals(self): + @jax.custom_vjp + def ra2a(x): + return jax.freeze(jax.new_ref(x)) + + def ra2a_fwd(x): + o = ra2a(x) + return o, () + + def ra2a_bwd(res, g): + return (ra2a(g),) + + ra2a.defvjp(ra2a_fwd, ra2a_bwd) + + @jax.jit + @jax.remat + def f(x): + + def g(x): + def body(carry): + i, x = carry + x = ra2a(x) + return i + 1, x + return jax.lax.while_loop(lambda x: x[0] < 5, body, (0, x))[1] + return g(x) + + jax.linearize(f, 5.) # don't crash + @jtu.with_config(jax_mutable_array_checks=True) class MutableArrayErrorsTest(jtu.JaxTestCase): From 94ae97fbb0acc983eb448636179c0b366b306543 Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Fri, 12 Dec 2025 12:51:34 -0800 Subject: [PATCH 186/315] Remove use of deprecated XLA CPU flags. These flags now have no effect. PiperOrigin-RevId: 843799361 --- tests/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/BUILD b/tests/BUILD index e2487d2deaca..6c1fd8915878 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1238,9 +1238,6 @@ jax_multiplatform_test( "notsan", # Times out ], }, - env = { - "XLA_FLAGS": "--xla_cpu_use_xnnpack=false", # TODO(b/454581761): Reenable once we switch to YNNPACK. - }, shard_count = 4, deps = py_deps([ "absl/testing", From 6fccabe4c5e2c6535d0c2e92cf90a8c1502187e6 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Dec 2025 14:15:40 -0800 Subject: [PATCH 187/315] Reverts 890ccd23c3728a152ad551e187cea3777af9e435 PiperOrigin-RevId: 843829527 --- CHANGELOG.md | 2 -- jax/_src/mesh.py | 5 ----- jax/_src/pjit.py | 2 +- jax/experimental/jax2tf/tests/tf_test_util.py | 2 -- tests/cache_key_test.py | 2 -- tests/export_back_compat_test.py | 4 ---- tests/fused_attention_stablehlo_test.py | 4 ---- tests/multiprocess/multihost_utils_test.py | 2 -- tests/multiprocess/pjit_test.py | 2 -- tests/pjit_test.py | 7 +------ tests/python_callback_test.py | 4 +--- 11 files changed, 3 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a4896846a144..74ed57bb879c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,8 +19,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Deprecations * `jax.lax.pvary` has been deprecated. Please use `jax.lax.pcast(..., to='varying')` as the replacement. - * `with mesh:` context manager has been deprecated. - Please use `with jax.set_mesh(mesh):` instead. * Complex arguments passed to {func}`jax.numpy.arange` now result in a deprecation warning, because the output is poorly-defined. diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 6ab9f7703324..f0210d5d9b82 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -24,7 +24,6 @@ import math import threading from typing import Any, NamedTuple -import warnings import numpy as np @@ -323,10 +322,6 @@ def __setattr__(self, name, value): def __enter__(self): if jax_config.disallow_mesh_context_manager.value: raise RuntimeError("Mesh context manager is disabled.") - warnings.warn( - "`with mesh:` context manager has been deprecated. Please use `with" - " jax.set_mesh(mesh):` instead.", - category=DeprecationWarning, stacklevel=2) new_env = thread_resources.stack[-1].with_mesh(self) thread_resources.stack.append(new_env) thread_resources.env = new_env diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 43671aef48d7..da0e7ccbbb9e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -361,7 +361,7 @@ def _parse_jit_arguments(fun: Callable, *, in_shardings: Any, 'backend and device argument on jit is deprecated. You can use' ' `jax.device_put(..., jax.local_devices(backend="cpu")[0])` on the' ' inputs to the jitted function to get the same behavior.', - category=DeprecationWarning, stacklevel=2 + DeprecationWarning, ) if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 0514ae0530c9..df7e59a0d8ce 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -188,8 +188,6 @@ def setUp(self): self.assertGreaterEqual(version, export.minimum_supported_calling_convention_version) self.enter_context(config.jax_export_calling_convention_version(version)) - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message='`with mesh:` context manager')) logging.info( "Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)", version, diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 96faa47be7e2..35ac03011a97 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -165,8 +165,6 @@ def test_different_computations(self): # TODO(phawkins): this test flakes if test concurrency is enabled. @jtu.thread_unsafe_test() - @jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') def test_custom_partitioning_ptr_removal(self): def _partition(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index a7f9fa2ea73f..41a4b99ed944 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -784,8 +784,6 @@ def func(x): data = self.load_testdata(cuda_threefry2x32.data_2024_07_30) self.run_one_test(func, data) - @jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') def test_tpu_sharding(self): # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2: @@ -1010,8 +1008,6 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol): class ShardyCompatTest(bctu.CompatTestBase): - @jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') def test_shardy_sharding_ops_with_different_meshes(self): # Tests whether we can save and load a module with meshes that have the # same axis sizes (and same order) but different axis names. diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 1228b2160ab1..8df402e6e4ff 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -264,8 +264,6 @@ def dot_product_attention_fp8(query, key, value, fp8_metas): return out[0], (query_grad, key_grad, value_grad) -@jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') class DotProductAttentionTest(jtu.JaxTestCase): def setUp(self): super().setUp() @@ -755,8 +753,6 @@ def generate_segment_mask(segment_ids, dtype): self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-2, atol=1e-2) @jtu.run_on_devices("cuda") - @jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') def test_sdpa_residual(self): k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5) query = jax.random.normal( diff --git a/tests/multiprocess/multihost_utils_test.py b/tests/multiprocess/multihost_utils_test.py index fe7aec5c630d..d3ce2d5d6393 100644 --- a/tests/multiprocess/multihost_utils_test.py +++ b/tests/multiprocess/multihost_utils_test.py @@ -171,8 +171,6 @@ def test_sync_global_devices_error(self): else: multihost_utils.sync_global_devices('test message2') - @jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') def test_sync_global_devices_mesh_context_manager(self): global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) with global_mesh: diff --git a/tests/multiprocess/pjit_test.py b/tests/multiprocess/pjit_test.py index 4ef9e1c4c2cf..79c0721ab66b 100644 --- a/tests/multiprocess/pjit_test.py +++ b/tests/multiprocess/pjit_test.py @@ -381,8 +381,6 @@ def _lower_compile(inp): for out in list(result): np.testing.assert_array_equal(out(x), expected_out) - @jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') def test_fully_sharded_on_all_devices(self): if jax.local_device_count() > 1: self.skipTest("This test only works with 1 process per device.") diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3517ea41ba14..6d4b3b31ab51 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -95,9 +95,8 @@ def check_1d_2d_mesh(f, set_mesh): ))(jtu.with_mesh_from_kwargs(f) if set_mesh else f) +# TODO(skye): make the buffer donation utils part of JaxTestCase @jtu.pytest_mark_if_available('multiaccelerator') -@jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') class PJitTest(jtu.BufferDonationTestCase): @jtu.with_mesh([('x', 1)]) @@ -1492,8 +1491,6 @@ def test_pjit_array_error(self): @jtu.pytest_mark_if_available('multiaccelerator') -@jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') class ArrayPjitTest(jtu.JaxTestCase): @parameterized.named_parameters( @@ -9860,8 +9857,6 @@ def f(): @jtu.pytest_mark_if_available('multiaccelerator') -@jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') class PJitErrorTest(jtu.JaxTestCase): @check_1d_2d_mesh(set_mesh=True) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 1423a8bcf268..9633bd49d444 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1,4 +1,4 @@ -# Copyright 2022 The JAX Authors. +"# Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1138,8 +1138,6 @@ def f(x): self.assertEqual(count(), 1) -@jtu.ignore_warning(category=DeprecationWarning, - message='`with mesh:` context manager') class IOCallbackTest(jtu.JaxTestCase): def setUp(self): From b5cb67e3be98d5e1d3776cbf129d10c53ef13f98 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Dec 2025 14:50:59 -0800 Subject: [PATCH 188/315] Update Shardy to e8435cb5c0b852b0e249b3fbf5f42dd51988afc9. Fix jax typo PiperOrigin-RevId: 843841248 --- tests/python_callback_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 9633bd49d444..3a70b08ea912 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1,4 +1,4 @@ -"# Copyright 2022 The JAX Authors. +# Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From bc639963703516d9e754c2f43fa486712d90ee8f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Dec 2025 15:46:58 -0800 Subject: [PATCH 189/315] Add support and tests for sharded -> unreduced operation. **But why do we need such an operation?** You might want to use it directly: it's a kind of a lazy (i.e. no-comms) reduce_sum over shards, without changing the logical shape: ```python reshard(f32[4@x], P(Unreduced={x})) : f32[4]{U:x} ``` Physically, we would zero-pad: ``` # f32[4@x] i.e. per-device shape is (2,) Device-0 Device-1 [0, 1] [2, 3] # f32[4]{U:x} i.e. per-device shape is (4,) Device-0 Device-1 [0, 1, 0, 0] [0, 0, 2, 3] ``` (There are other valid physical possibilities because Unreduced is flexible. For example, `Device-0: [0/2, 1/2, 2/2, 3/2]` and `Device-1: [0/2, 1/2, 2/2, 3/2]` is valid, but would require comms and would have weird numeric effects. Terrible.) The inverse operation (not the transpose, since those change the types) is `reshard(f32[4]{U:x}, P('x'))` and physically is a reduce-scatter, which naturally has the right effect on the physical buffers. **But as another motivation**, this operation naturally arises from autodiff, if we allow other reasonable expressions. For example, if we want to allow elementwise multiplication of sharded and Reduced values at the user level (because everything that works with Replicated should work with Reduced): ```python a: f32[4@x] b: f32[4]{R: x} c: f32[4@x] = a * b ``` we would desugar that as ```python b_: f32[4@x] = reshard(b, P('x')) # Reduced -> Sharded c: f32[4@x] = mul(a, b_) ``` Then the backward pass would require a Sharded -> Unreduced operation: ```python db_: f32[4@x] = mul(a, dc) db: f32[4]{U:x} = reshard(db_, P(Unreduced={x})) # Sharded -> Unreduced ``` **Before this change**, we actually had buggy behavior in that autodiff example where we multiply Reduced with Sharded. We would get incorrect gradients because our lowering of the backward pass's Sharded->Unreduced operation used to all-gather instead of zero-pad. One very very interesting thing is comparing to varying -> unreduced support inside shard_map, which works via shape-changing rather than zero-padding! How? The varying -> unreduced pcast is shape-preserving operation inside shmap, but when returning shard_map naturally concats so as to increase shapes. If we want exactly the same to be expressible outside shard_map, we might additionally need shape-changing operations like `f32[4@x] -> f32[2]{U:x}` and its transpose. But we'll leave that to future work. Co-authored-by: Matthew Johnson PiperOrigin-RevId: 843859866 --- tests/pjit_test.py | 93 +++++++++++++++++++++++++++++++++++++++++ tests/shard_map_test.py | 69 +++++++++++++++++++++++++++++- 2 files changed, 161 insertions(+), 1 deletion(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6d4b3b31ab51..17e4ac8901bf 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -63,6 +63,7 @@ from jax._src.mesh import AxisType from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc +from jax._src.lib import ifrt_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -9718,6 +9719,33 @@ def f(x, y): "reduced on"): jax.jit(jax.shard_map(f, out_specs=P()))(arr1, arr2) + @parameterized.parameters( + ((8,), P('x'), P(None, unreduced={'x'})), + ((4, 2), P('x', 'y'), P(None, None, unreduced={'x', 'y'})), + ((4, 2), P('x', 'y'), P('x', None, unreduced={'y'})), + ((4, 2), P('x', 'y'), P(None, 'y', unreduced={'x'})), + ((4, 2), P('x', None), P(None, None, unreduced={'x'})), + ((4, 2), P('y', None), P(None, None, unreduced={'y'})), + ((4, 2), P(('x', 'y'), None), P(None, None, unreduced={'x', 'y'})), + ((4, 4), P(None, ('x', 'y')), P(None, None, unreduced={'x', 'y'})), + # TODO(yashkatariya): Enable this after collectives + S->U cast is enabled. + # ((4, 2), P('x', 'y'), P(None, 'x', unreduced={'y'})), + # ((4, 2), P('x', 'y'), P('y', None, unreduced={'x'})), + ) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_sharded_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): + if ifrt_version < 40: + self.skipTest('Requires ifrt_version >= 40') + np1 = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np1, orig_spec) + + arr2 = reshard(arr, un_spec) + self.assertEqual(arr2.sharding, NamedSharding(mesh, un_spec)) + + arr3 = reshard(arr2, orig_spec) + self.assertArraysEqual(arr, arr3) + self.assertEqual(arr.sharding, arr3.sharding) + @parameterized.named_parameters( ('mul', jax.lax.mul), ('add', jax.lax.add), @@ -9747,6 +9775,71 @@ def g(x, y): self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) + if ifrt_version >= 40: + arr3 = jax.device_put(np1, P(None)) + ex_out1, ex_out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr3) + self.assertArraysEqual(reshard(out1, P()), ex_out1) + self.assertArraysEqual(reshard(out2, P()), ex_out2) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_reduced_reshard_unreduced_bwd(self, mesh): + if ifrt_version < 40: + self.skipTest('Requires ifrt_version >= 40') + np1 = np.arange(4.) + arr = jax.device_put(np1, P(None, reduced={'x'})) + + @jax.jit + def f(x): + return jax.reshard(x, P('x')) + + out = f(arr) + self.assertArraysEqual(out, np1) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + @jax.jit + def g(x): + return f(x).sum() + + out = jax.jit(jax.grad(g))(arr) + ex_data = [np.array([1., 1., 0., 0.]), np.array([1., 1., 0., 0.]), + np.array([0., 0., 1., 1.]), np.array([0., 0., 1., 1.])] + for s, d in zip(out.addressable_shards, ex_data): + self.assertArraysEqual(s.data, d) + + arr2 = jax.device_put(np1, P(None)) + expected_out = jax.jit(jax.grad(g))(arr2) + self.assertArraysEqual(reshard(out, P()), expected_out) + + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_reduced_reshard_unreduced_bwd_sharded(self, mesh): + if ifrt_version < 40: + self.skipTest('Requires ifrt_version >= 40') + np1 = np.arange(8.).reshape(4, 2) + arr = jax.device_put(np1, P('x', None, reduced={'y'})) + + @jax.jit + def f(x): + return jax.reshard(x, P('x', 'y')) + + out = f(arr) + self.assertArraysEqual(out, np1) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + + @jax.jit + def g(x): + return f(x).sum() + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, + NamedSharding(mesh, P('x', None, unreduced={'y'}))) + + arr2 = jax.device_put(np1, P('x', None)) + expected_out = jax.jit(jax.grad(g))(arr2) + self.assertEqual(expected_out.sharding, NamedSharding(mesh, P('x', None))) + + self.assertArraysEqual(reshard(out, P('x', None)), expected_out) + self.assertArraysEqual(reshard(out, P()), reshard(expected_out, P())) + @jtu.with_explicit_mesh((2,), 'x') def test_reduced_at_get_out_sharding(self, mesh): np1 = np.ones((2048, 64), dtype=jnp.float32) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 742c976f3ca6..7419efb3394e 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -4664,7 +4664,6 @@ def f(x, y): out = f(arr1, arr2) self.assertEqual(out.sharding, NamedSharding(mesh, P(None))) - with jax.set_mesh(empty_concrete_mesh): ex_out = np.sum([func(s.data, np2) for s in arr1.addressable_shards], axis=0) @@ -4679,6 +4678,74 @@ def g(x, y): self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) + arr3 = jax.device_put(np2, P(None)) + ex_out1, ex_out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr3) + self.assertArraysEqual(out1, ex_out1) + self.assertArraysEqual(jax.reshard(out2, P()), ex_out2) + + @jtu.with_explicit_mesh((2,), 'x') + def test_reduced_pcast_fwd_unreduced_bwd(self, mesh): + np1 = np.arange(8.) + arr = jax.device_put(np1, P(None, reduced={'x'})) + + @jax.jit + @jax.shard_map(out_specs=P('x')) + def f(x): + return jax.lax.pcast(x, 'x', to='varying') + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, np.concat([np1, np1], axis=0)) + + @jax.jit + def g(x): + return f(x).sum() + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) + + arr2 = jax.device_put(np1, P(None)) + ex_out = jax.jit(jax.grad(g))(arr2) + self.assertArraysEqual(jax.reshard(out, P()), ex_out) + + @parameterized.named_parameters( + ('mul', jax.lax.mul), + ('add', jax.lax.add), + ) + @jtu.with_explicit_mesh((2,), 'x') + def test_one_input_sharded_another_reduced_shmap_no_psum(self, func, mesh): + np1 = np.arange(16.) + np2 = np.arange(8.) + arr1 = jax.device_put(np1, P('x')) + arr2 = jax.device_put(np2, P(None, reduced={'x'})) + + @jax.jit + @jax.shard_map(out_specs=P('x')) + def f(x, y): + z = func(x, y) + return z + + out = f(arr1, arr2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + with jax.set_mesh(empty_concrete_mesh): + ex_out = [func(s.data, np2) for s in arr1.addressable_shards] + for s, e in zip(out.addressable_shards, ex_out): + self.assertArraysEqual(s.data, e) + + @jax.jit + def g(x, y): + return f(x, y).sum() + + out1, out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr2) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x'))) + self.assertEqual(out2.sharding, + NamedSharding(mesh, P(None, unreduced={'x'}))) + + arr3 = jax.device_put(np2, P(None)) + ex_out1, ex_out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr3) + self.assertArraysEqual(out1, ex_out1) + self.assertArraysEqual(jax.reshard(out2, P()), ex_out2) + @jtu.with_explicit_mesh((2,), 'x') def test_split_with_unused_result_in_shardmap(self, mesh): arr = jax.device_put(jnp.ones(8), P('x')) From a99f76f6a6556223d71a2a63cd03019634085cc5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Dec 2025 16:05:31 -0800 Subject: [PATCH 190/315] Error out if ShapeDtypeStruct gets a sharding that's not compatible with the shape. PiperOrigin-RevId: 843866725 --- jax/_src/core.py | 9 +++++++++ jax/_src/partition_spec.py | 9 +++++++++ jax/_src/pjit.py | 35 +++++++++++++++-------------------- jax/_src/sharding_impls.py | 6 ++++++ tests/pjit_test.py | 15 +++++++++++++++ 5 files changed, 54 insertions(+), 20 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 48a8db1c5e5f..88c3c402250d 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3428,6 +3428,14 @@ def eqn_effects(jaxpr): # ------------------- ShapeDtypeStruct ------------------- +def _check_sharding(sharding, shape): + if sharding is None: + return + if isinstance(sharding, P): + sharding._check_compatible_wrt_shape(shape) + else: + sharding.check_compatible_aval(shape) + @set_module("jax") class ShapeDtypeStruct: """A container for the shape, dtype, and other static attributes of an array. @@ -3461,6 +3469,7 @@ def __init__(self, shape, dtype, *, sharding=None, weak_type=False, f" layout in a `ShapeDtypeStruct`. Got {sharding}") self._sharding = (sharding.sharding if isinstance(sharding, Format) else sharding) + _check_sharding(self._sharding, self.shape) self._dll = sharding.layout if isinstance(sharding, Format) else None self.weak_type = weak_type if vma is not None and not isinstance(vma, (set, frozenset)): diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 7720adc2d072..d1ee8d6a40fc 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -163,6 +163,15 @@ def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec: out.extend([None] * (ndim - len(out))) return self.update(partitions=out) + def _check_compatible_wrt_shape(self, shape): + if len(shape) < len(self._partitions): + extra_msg = (' For scalars the PartitionSpec should be P()' + if len(shape) == 0 else '') + raise ValueError( + f"PartitionSpec {self} is only valid for values of rank at least " + f"{len(self._partitions)}, but was applied to a value of rank " + f"{len(shape)}.{extra_msg}") + PartitionSpec.__module__ = 'jax.sharding' P = PartitionSpec diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index da0e7ccbbb9e..00738dc31194 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1206,30 +1206,25 @@ def pjit_check_aval_sharding( name_str = f' with pytree key path {name}' if name else '' shape = aval.shape try: - # Sharding interfaces can implement `check_compatible_aval` as an optional - # method to raise a more meaningful error. - if hasattr(s, 'check_compatible_aval'): - s.check_compatible_aval(shape) - else: - s._to_xla_hlo_sharding(len(shape)) + s.check_compatible_aval(shape) except ValueError as e: raise ValueError( f'One of {what_aval}{name_str} is incompatible with its sharding ' f'annotation {s}: {e}') - # Use the `OpSharding` proto to find out how many ways each dimension of - # the aval is sharded. This approach will work across all - # Sharding. - hlo_sharding = s._to_xla_hlo_sharding(len(shape)) - assert hlo_sharding is not None - num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded( - hlo_sharding, allow_partial_manual) - for i, size in enumerate(num_ways_dim_sharded): - if not allow_uneven_sharding and shape[i] % size != 0: - raise ValueError(f"One of {what_aval}{name_str} was given the sharding " - f"of {s}, which implies that " - f"the global size of its dimension {i} should be " - f"divisible by {size}, but it is equal to {shape[i]} " - f"(full shape: {shape})") + + if not allow_uneven_sharding: + hlo_sharding = s._to_xla_hlo_sharding(len(shape)) + assert hlo_sharding is not None + num_ways_dim_sharded, _ = op_shardings.get_num_ways_dim_sharded( + hlo_sharding, allow_partial_manual) + for i, size in enumerate(num_ways_dim_sharded): + if shape[i] % size != 0: + raise ValueError( + f'One of {what_aval}{name_str} was given the sharding ' + f'of {s}, which implies that ' + f'the global size of its dimension {i} should be ' + f'divisible by {size}, but it is equal to {shape[i]} ' + f'(full shape: {shape})') def check_aval_layout_compatibility( diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index b658a4d13966..2cdb0ec2fe56 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -181,6 +181,9 @@ def is_fully_replicated(self) -> bool: def is_fully_addressable(self) -> bool: return xb.process_index(self._device.client) == self._device.process_index + def check_compatible_aval(self, aval_shape: Shape) -> None: + return + SingleDeviceSharding.__module__ = 'jax.sharding' @util.cache(max_size=4096, trace_context_in_key=False) @@ -323,6 +326,9 @@ def is_fully_replicated(self) -> bool: def is_fully_addressable(self) -> bool: return self._internal_device_list.is_fully_addressable + def check_compatible_aval(self, aval_shape: Shape) -> None: + return + def shard_shape(self, global_shape: Shape) -> Shape: sharded_dim = None sharded_dim_size = None diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 17e4ac8901bf..db67152647cd 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4967,6 +4967,21 @@ def test_sds_input_to_zeros_like_propagates_sharding(self, mesh): out = jnp.zeros_like(val) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + def test_sds_incompatible_sharding(self): + mesh = jtu.create_mesh((2,), 'x') + with self.assertRaisesRegex( + ValueError, + "only valid for values of rank at least 3, but was applied to a value " + "of rank 2"): + jax.ShapeDtypeStruct((128, 128), jnp.float32, + sharding=NamedSharding(mesh, P(None, 'x', None))) + + with self.assertRaisesRegex( + ValueError, + "only valid for values of rank at least 3, but was applied to a value " + "of rank 2"): + jax.ShapeDtypeStruct((128, 128), jnp.float32, sharding=P(None, 'x', None)) + class ShardingInTypesTest(jtu.JaxTestCase): From 481a569a2abebb23836c59b7aaa4dee99c5f8f4a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Dec 2025 18:55:18 -0800 Subject: [PATCH 191/315] Add sharding/shape/dtype checks in si_vjp Co-authored-by: Dougal Maclaurin PiperOrigin-RevId: 843911692 --- jax/_src/api.py | 22 +++++++++++++++++++--- tests/api_test.py | 15 +++++++++++++++ tests/pjit_test.py | 19 ++++++++++++++----- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 5dde45b9e004..ce62293f6518 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2257,8 +2257,10 @@ def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore for r in residuals] + out_primal_avals = map(shaped_abstractify, out_primals_flat) f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, - out_tree(), out_known, jaxpr), opaque_residuals) + out_tree(), out_known, jaxpr, out_primal_avals), + opaque_residuals) if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) @@ -2284,7 +2286,8 @@ def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, return out_primals, f_vjp def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known, - jaxpr, opaque_residuals, ct, *saved_primals): + jaxpr, out_primal_avals, opaque_residuals, ct, + *saved_primals): primals_filtered, filtered_tree_ = tree_flatten(saved_primals) if filtered_tree != filtered_tree_: raise ValueError( @@ -2303,7 +2306,20 @@ def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known, for i in res_spec] dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] cts_flat, out_tree_ = tree_flatten(ct) - assert out_tree_ == out_tree + if out_tree_ != out_tree: + raise ValueError(f"unexpected tree structure of argument to vjp function: " + f"got {out_tree}, but expected to match {out_tree_}") + for arg, aval in zip(cts_flat, out_primal_avals): + ct_aval = shaped_abstractify(arg) + ct_aval_expected = aval.to_cotangent_aval() + if (not core.typecompat(ct_aval, ct_aval_expected) and + not _temporary_dtype_exception(ct_aval, ct_aval_expected)): + raise ValueError( + "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: " + f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} " + f"because the corresponding output of the function had JAX type " + f"{aval.str_short()}") + cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k] arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat) return tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) diff --git a/tests/api_test.py b/tests/api_test.py index 4f0c85cdc36e..3f7f1b417794 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7865,6 +7865,21 @@ def f2(x, w): x_grad, w_grad = f2_sivjp(y_grad, w) self.assertAllClose(x_grad, 2. * y_grad @ w.T) + def test_fsdp_error(self): + # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" + def f2(x, w): + x = 1. * x + x = x @ w + x = 2. * x + return x + + x = jnp.ones((3, 4)) + w = jnp.ones((4, 4)) + y, f2_sivjp = api.si_vjp(f2, [False, True], x, w) + y_grad = jnp.ones((2, 4)) + with self.assertRaisesRegex(ValueError, "unexpected JAX type"): + f2_sivjp(y_grad, w) + def test_fsdp_vjp3(self): # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" def f2(x, w): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index db67152647cd..a941a9c098b2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -9751,6 +9751,8 @@ def f(x, y): def test_sharded_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): if ifrt_version < 40: self.skipTest('Requires ifrt_version >= 40') + if not jtu.is_cloud_tpu_at_least(2025, 12, 15): + self.skipTest('Requires libtpu built after 2025-12-15') np1 = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np1, orig_spec) @@ -9767,6 +9769,10 @@ def test_sharded_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): ) @jtu.with_explicit_mesh((2,), 'x') def test_one_input_sharded_another_reduced(self, func, mesh): + if ifrt_version < 40: + self.skipTest('Requires ifrt_version >= 40') + if not jtu.is_cloud_tpu_at_least(2025, 12, 15): + self.skipTest('Requires libtpu built after 2025-12-15') np1 = np.arange(8.) arr1 = jax.device_put(np1, P('x')) arr2 = jax.device_put(np1, P(None, reduced={'x'})) @@ -9790,16 +9796,17 @@ def g(x, y): self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) - if ifrt_version >= 40: - arr3 = jax.device_put(np1, P(None)) - ex_out1, ex_out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr3) - self.assertArraysEqual(reshard(out1, P()), ex_out1) - self.assertArraysEqual(reshard(out2, P()), ex_out2) + arr3 = jax.device_put(np1, P(None)) + ex_out1, ex_out2 = jax.jit(jax.grad(g, argnums=(0, 1)))(arr1, arr3) + self.assertArraysEqual(reshard(out1, P()), ex_out1) + self.assertArraysEqual(reshard(out2, P()), ex_out2) @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduced_reshard_unreduced_bwd(self, mesh): if ifrt_version < 40: self.skipTest('Requires ifrt_version >= 40') + if not jtu.is_cloud_tpu_at_least(2025, 12, 15): + self.skipTest('Requires libtpu built after 2025-12-15') np1 = np.arange(4.) arr = jax.device_put(np1, P(None, reduced={'x'})) @@ -9829,6 +9836,8 @@ def g(x): def test_reduced_reshard_unreduced_bwd_sharded(self, mesh): if ifrt_version < 40: self.skipTest('Requires ifrt_version >= 40') + if not jtu.is_cloud_tpu_at_least(2025, 12, 15): + self.skipTest('Requires libtpu built after 2025-12-15') np1 = np.arange(8.).reshape(4, 2) arr = jax.device_put(np1, P('x', None, reduced={'y'})) From 0d58415ad3c3f11dadbec52ec09803b0938273b0 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Dec 2025 20:05:06 -0800 Subject: [PATCH 192/315] Make sure correct mesh is maintained on the out_avals of pallas_call PiperOrigin-RevId: 843928732 --- jax/_src/pallas/pallas_call.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 57884b5c67e2..8f00b492057d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -115,7 +115,8 @@ def _pallas_call_abstract_eval( raise ValueError(f"input pinned buffers without input_output_aliases:" f"{missing}") outin_aliases = {out_idx: in_idx for in_idx, out_idx in inout_aliases.items()} - out_avals = [jax_core.ShapedArray(a.shape, a.dtype, a.weak_type) + out_avals = [jax_core.ShapedArray(a.shape, a.dtype, a.weak_type, + sharding=a.sharding) if isinstance(a, pallas_core.ShapedArrayWithMemorySpace) else avals[outin_aliases[out_idx]] if out_idx in outin_aliases else a for out_idx, a in enumerate(out_avals)] @@ -1155,7 +1156,8 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: return jax_core.ShapedArray( shape=out_shape.shape, dtype=out_shape.dtype, sharding=jax_core.get_cur_mesh_sharding(), vma=out_shape.vma) - return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) + return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype, + sharding=jax_core.get_cur_mesh_sharding()) case pallas_core.MemoryRef(): return out_shape.get_array_aval() case hijax.HiType(): From 8ce351244abae9cd6c88c673b4fa1fa92e45377c Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Dec 2025 21:26:27 -0800 Subject: [PATCH 193/315] Fix tree equality error message in si_vjp PiperOrigin-RevId: 843950969 --- jax/_src/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index ce62293f6518..431740824606 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2308,7 +2308,7 @@ def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known, cts_flat, out_tree_ = tree_flatten(ct) if out_tree_ != out_tree: raise ValueError(f"unexpected tree structure of argument to vjp function: " - f"got {out_tree}, but expected to match {out_tree_}") + f"got {out_tree_}, but expected to match {out_tree}") for arg, aval in zip(cts_flat, out_primal_avals): ct_aval = shaped_abstractify(arg) ct_aval_expected = aval.to_cotangent_aval() From 404d644f1b1a6e55d239e1c4614d5aa4f14978b4 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 12 Dec 2025 22:56:26 -0800 Subject: [PATCH 194/315] Remove config.vjp3 and vjp3 API since it's now replaced by jax.vjp PiperOrigin-RevId: 843971635 --- jax/_src/api.py | 32 +------------------------------- jax/_src/config.py | 6 ------ jax/_src/interpreters/ad.py | 19 +------------------ jax/interpreters/ad.py | 1 - tests/api_test.py | 26 ++++++++++++++------------ tests/custom_api_test.py | 2 +- tests/lax_control_flow_test.py | 15 ++++----------- tests/mutable_array_test.py | 17 ++++++++--------- tests/pjit_test.py | 7 ------- 9 files changed, 29 insertions(+), 96 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 431740824606..e91e3e8a3e18 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2212,31 +2212,6 @@ def vjp( fun, debug_info=debug_info("vjp", fun, primals, {})) return _vjp(wrapped_fun, *primals, has_aux=has_aux) -def _vjp(fun: lu.WrappedFun, *primals, has_aux=False): - """Variant of vjp() that takes an lu.WrappedFun.""" - if config.vjp3.value: - return _vjp3(fun, *primals, has_aux=has_aux) - primals_flat, in_tree = tree_flatten(primals) - primals_flat = [canonicalize_value(v) if not isinstance(v, core.Tracer) else v - for v in primals_flat] - for arg in primals_flat: dispatch.check_arg(arg) - if not has_aux: - flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) - out_primals, vjp = ad.vjp(flat_fun, primals_flat) - out_tree = out_tree() - else: - flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree) - out_primals, vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True) - out_tree, aux_tree = out_aux_trees() - out_primal_avals = map(shaped_abstractify, out_primals) - out_primal_py = tree_unflatten(out_tree, out_primals) - vjp_py = Partial(partial(_vjp_pullback_wrapper, fun.__name__, - out_primal_avals, (out_tree, in_tree)), vjp) - if not has_aux: - return out_primal_py, vjp_py - else: - return out_primal_py, vjp_py, tree_unflatten(aux_tree, aux) - @partial(api_boundary, repro_api_name="jax.experimental.saved_input_vjp") def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, allow_unused: bool = True, allow_opaque: bool = True): @@ -2332,12 +2307,7 @@ class RSpec: si_vjp = saved_input_vjp -def vjp3(f, *primals, has_aux=False): - dbg = debug_info("vjp", f, primals, {}) - fun = lu.wrap_init(f, debug_info=dbg) - return _vjp3(fun, *primals, has_aux=has_aux) - -def _vjp3(fun, *primals, has_aux=False): +def _vjp(fun, *primals, has_aux=False): canon = lambda x: x if isinstance(x, core.Tracer) else canonicalize_value(x) primals = tree_map(canon, primals) primals_flat, in_tree = tree_flatten(primals) diff --git a/jax/_src/config.py b/jax/_src/config.py index e7cf1c772cf3..469eeaf7a873 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1833,12 +1833,6 @@ def _validate_default_device(val): help='Enable error checks for mutable arrays that rule out aliasing.', include_in_trace_context=True) -vjp3 = bool_state( - name='jax_vjp3', - default=True, - upgrade=True, - help='Use new backward-pass code in jax.vjp') - refs_to_pins = bool_state( name='jax_refs_to_pins', default=False, diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 75d1862b9809..459bf5397f92 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -26,7 +26,7 @@ from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_flatten, tree_unflatten, - register_pytree_node, Partial, PyTreeDef) + register_pytree_node, PyTreeDef) from jax._src import mesh as mesh_lib from jax._src import core from jax._src import source_info_util @@ -302,23 +302,6 @@ def linearize(traceable: lu.WrappedFun, *primals, **kwargs): else: return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux() -def vjp(traceable: lu.WrappedFun, primals, has_aux=False): - if not has_aux: - out_primals, pvals, jaxpr, consts = linearize(traceable, *primals) - else: - out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True) - - def unbound_vjp(pvals, jaxpr, consts, *cts): - cts = tuple(ct for ct, pval in zip(cts, pvals) if not pval.is_known()) - dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars] - arg_cts = backward_pass(jaxpr, True, consts, dummy_args, cts) - return map(instantiate_zeros, arg_cts) - - vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts) - if not has_aux: - return out_primals, vjp_ - else: - return out_primals, vjp_, aux # NOTE: The FIXMEs below are caused by primal/tangent mixups (type # errors if you will) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 6af63b3e7310..64945034a7a9 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -41,7 +41,6 @@ primitive_jvps as primitive_jvps, primitive_transposes as primitive_transposes, reducing_transposes as reducing_transposes, - vjp as vjp, zeros_like_aval as zeros_like_aval, ) diff --git a/tests/api_test.py b/tests/api_test.py index 3f7f1b417794..547ef9c22493 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -3101,7 +3101,7 @@ def cond(pred): def test_grad_of_bool_vjp3(self): def cond(pred): return lax.cond(pred, lambda _: 1., lambda _: 2., 1.) - value, f_vjp = api.vjp3(cond, True) + value, f_vjp = api.vjp(cond, True) grd, = f_vjp(1.) self.assertEqual(value, 1.) self.assertEqual(grd, np.zeros(shape=(), dtype=float0)) @@ -3213,6 +3213,12 @@ def f(): self.assertNotRegex(str(j_module), f"stablehlo.constant dense.*tensor<{const_size}x") + def test_basic_vjp3(self): + f = jax.jit(lambda x: jnp.sin(jnp.sin(x))) + _, f_vjp = jax.vjp(f, 1.) + g, = f_vjp(1.0) + self.assertAllClose(g, jnp.cos(jnp.sin(1.)) * jnp.cos(1.), check_dtypes=False) + def test_constants_not_in_lowering_scan(self): if not config.use_simplified_jaxpr_constants.value: self.skipTest("Works only with simplified Jaxpr consts") @@ -6737,11 +6743,7 @@ def f(x): return lax.cond(x.sum() > 0, f, lambda x: x, x) _, f_vjp = api.vjp(f, jnp.ones((5, 5))) - if config.vjp3.value: - jaxpr_text = str(f_vjp.jaxpr) - else: - jaxpr_text = str(f_vjp.jaxpr) - + jaxpr_text = str(f_vjp.jaxpr) self.assertEqual(jaxpr_text.count(' sin '), 2) self.assertEqual(jaxpr_text.count(' cos '), 3) # Five calls to dot_general in the backward pass because we have two for @@ -7806,7 +7808,7 @@ def test_basic_unused(self): def test_basic_unused_vjp3(self): f = jnp.sin primals = 3., - y, f_vjp = api.vjp3(f, *primals) + y, f_vjp = api.vjp(f, *primals) x_ct, = f_vjp(1.) self.assertAllClose(y, jnp.sin(3.)) self.assertAllClose(x_ct, jnp.cos(3.)) @@ -7821,7 +7823,7 @@ def test_basic_opaque(self): def test_basic_opaque_vjp3(self): f = jnp.sin primals = 3., - _, f_vjp = api.vjp3(f, *primals) + _, f_vjp = api.vjp(f, *primals) assert f_vjp.opaque_residuals # can detect if opaque res are used def test_basic_pytree_error(self): @@ -7841,7 +7843,7 @@ def f(x): # def f(x): # return [x['hi'] * x['bye']] - # y, f_vjp = api.vjp3(f, {'hi': 2., 'bye': 3.}) + # y, f_vjp = api.vjp(f, {'hi': 2., 'bye': 3.}) # arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.}) # self.assertAllClose(y, [6.]) # self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.}) @@ -7890,7 +7892,7 @@ def f2(x, w): x = jnp.ones((3, 4)) w = jnp.ones((4, 4)) - y, f2_vjp = api.vjp3(f2, x, w) + y, f2_vjp = api.vjp(f2, x, w) f2_vjp.args_res[1] = None y_grad = jnp.ones_like(y) f2_vjp.args_res[1] = w @@ -7964,7 +7966,7 @@ def foo(x): def test_grad_traceback(self): # TODO(dougalm): improve this - expected_depth = 12 + expected_depth = 11 init_depth = self.cur_depth() def foo(x): @@ -7987,7 +7989,7 @@ def foo(x): def test_custom_vjp_traceback(self): # TODO(dougalm): improve this expected_depth_f = 10 - expected_depth_f_fwd = 20 + expected_depth_f_fwd = 19 expected_depth_f_rev = 12 init_depth = self.cur_depth() @jax.custom_vjp diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 786409de5a6e..f219cefc2fa4 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -1472,7 +1472,7 @@ def sin_jvp(primals, tangents): (x, y), (x_dot, y_dot) = primals, tangents del y_dot # ignore lol return div(x, y), div(x_dot, y) - _, f_vjp = api.vjp3(lambda x: div(x, 2.), 1.) + _, f_vjp = api.vjp(lambda x: div(x, 2.), 1.) ans, = f_vjp(1.) self.assertAllClose(ans, 1./2, check_dtypes=False) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 7a3a45a485d9..be799f64a2ad 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -41,7 +41,6 @@ import jax.numpy as jnp # scan tests use numpy import jax.scipy as jsp from jax._src import dispatch -from jax._src.api import vjp3 from jax._src.lax import control_flow as lax_control_flow from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -2777,10 +2776,7 @@ def cumprod(x): # ==> Yes, we don't want to change autodiff const behavior. We must make # these tessts pass under use_simplified_jaxpr_constants. if not config.use_simplified_jaxpr_constants.value: - if config.vjp3.value: - ext_res, = vjp_fun.args_res - else: - *_, ext_res = vjp_fun.args[0].args[0] + ext_res, = vjp_fun.args_res self.assertIs(ext_res, x) if remat is not None: @@ -2790,10 +2786,7 @@ def cumprod(x): x = rng.randn(32, 2, 32).astype('float32') # numpy.ndarray, not Array _, vjp_fun = jax.vjp(cumprod, x) if not config.use_simplified_jaxpr_constants.value: - if config.vjp3.value: - ext_res, *_ = vjp_fun.opaque_residuals - else: - *_, ext_res = vjp_fun.args[0].args[0] + ext_res, *_ = vjp_fun.opaque_residuals self.assertIsInstance(ext_res, jax.Array) def test_scan_vmap_collectives(self): @@ -3498,14 +3491,14 @@ def test_cond_basic_vjp3(self): def f(x): return jax.lax.cond(True, jnp.sin, lambda x: x, x) - _, f_vjp = vjp3(f, 1.) + _, f_vjp = jax.vjp(f, 1.) g, = f_vjp(1.0) self.assertAllClose(g, jnp.cos(1.), check_dtypes=False) def h(x): return jax.lax.cond(True, jnp.sin, lambda x: 1., x) - _, h_vjp = vjp3(h, 1.) + _, h_vjp = jax.vjp(h, 1.) g, = h_vjp(1.0) self.assertAllClose(g, jnp.cos(1.), check_dtypes=False) diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 332466321603..9234bf57f28f 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -25,7 +25,6 @@ from jax._src import core from jax._src import config from jax._src import test_util as jtu -from jax._src.api import vjp3 from jax._src.util import safe_map, safe_zip from jax._src.interpreters import mlir from jax.sharding import NamedSharding, PartitionSpec as P, AxisType @@ -524,8 +523,8 @@ def stash_grads_bwd(grads_ref, g): grads_ref = core.new_ref(jnp.float32(0.)) x = jnp.float32(1.) - _, f_vjp, *maybe_aux = vjp3(lambda x: primal(grads_ref, x), x, - has_aux=has_aux) + _, f_vjp, *maybe_aux = jax.vjp( + lambda x: primal(grads_ref, x), x, has_aux=has_aux) _ = f_vjp(jnp.float32(1.)) self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) if has_aux: @@ -553,7 +552,7 @@ def stash_grads_bwd(stash_ref, g): stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd) stash_ref = core.new_ref(jnp.float32(0.)) - _, f_vjp = vjp3(lambda x: primal(stash_ref, x), jnp.float32(1.)) + _, f_vjp = jax.vjp(lambda x: primal(stash_ref, x), jnp.float32(1.)) grads_val, = f_vjp(jnp.float32(1.)) self.assertAllClose(stash_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) self.assertAllClose(grads_val, jnp.cos(jnp.sin(1.)) * jnp.cos(1.), @@ -561,7 +560,7 @@ def stash_grads_bwd(stash_ref, g): stash_ref = core.new_ref(jnp.float32(0.)) grads_ref = core.new_ref(jnp.float32(0.)) - _, f_vjp = vjp3(lambda x: primal(stash_ref, x), jnp.float32(1.)) + _, f_vjp = jax.vjp(lambda x: primal(stash_ref, x), jnp.float32(1.)) _ = f_vjp.with_refs(grads_ref)(jnp.float32(1.)) self.assertAllClose(stash_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False) self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)) * jnp.cos(1.), @@ -849,7 +848,7 @@ def process_batch(Ws, xs_batch): grad_acc = jax.new_ref(jnp.zeros_like(Ws)) # CHANGED def process_mubatch(_, xs): - loss, f_vjp = vjp3(lambda Ws: mubatch_loss(Ws, xs), Ws) # CHANGED + loss, f_vjp = jax.vjp(lambda Ws: mubatch_loss(Ws, xs), Ws) # CHANGED f_vjp.with_refs(grad_acc)(jnp.ones_like(loss)) # CHANGED return (), loss @@ -924,7 +923,7 @@ def f_bwd(_, g): self.assertAllClose(y, 3.14, check_dtypes=False) # this exercises the fallback path, not a fancy transpose - _, f_vjp = vjp3(lambda x: f(jax.new_ref(x)), 3.14) + _, f_vjp = jax.vjp(lambda x: f(jax.new_ref(x)), 3.14) g, = f_vjp(1.) self.assertAllClose(g, 1., check_dtypes=False) @@ -969,7 +968,7 @@ def body(_, xy): return z.sum() grad_accum = jax.new_ref(jnp.zeros(5)) - _, f_vjp = vjp3(f, jnp.ones(5)) + _, f_vjp = jax.vjp(f, jnp.ones(5)) _, = f_vjp.with_refs(grad_accum)(1.) self.assertAllClose(grad_accum[...], jnp.arange(5.)) @@ -978,7 +977,7 @@ def test_vmap_with_vjp3(self): def grad_via_ref(f): def wrapper(*args): grad_accum = jax.tree.map(lambda x: jax.new_ref(jnp.zeros_like(x)), args) - out, f_vjp = vjp3(f, *args) + out, f_vjp = jax.vjp(f, *args) f_vjp.with_refs(*grad_accum)(jnp.ones_like(out)) return jax.tree.map(lambda x: jax.freeze(x), grad_accum) return wrapper diff --git a/tests/pjit_test.py b/tests/pjit_test.py index a941a9c098b2..cfdfce42a511 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -38,7 +38,6 @@ from jax._src import dtypes from jax import stages from jax import lax -from jax._src.api import vjp3 from jax._src.lax import lax as lax_internal from jax.lax import with_sharding_constraint from jax._src import prng @@ -1309,12 +1308,6 @@ def test_device_put_copy_donate(self): self.assertNotDeleted(z) self.assertArraysEqual(a, x * 2) - def test_basic_vjp3(self): - f = jax.jit(lambda x: jnp.sin(jnp.sin(x))) - _, f_vjp = vjp3(f, 1.) - g, = f_vjp(1.0) - self.assertAllClose(g, jnp.cos(jnp.sin(1.)) * jnp.cos(1.), check_dtypes=False) - @jtu.pytest_mark_if_available('multiaccelerator') class AutoShardingPjitTest(jtu.JaxTestCase): From 6ea96b54c08ec66a3206fbbfa5b9945cec3f45c0 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 11 Dec 2025 20:48:41 +0000 Subject: [PATCH 195/315] [vjp3] remove ad.backward_pass and non-fancy HOP transposes --- jax/_src/ad_checkpoint.py | 22 ++- jax/_src/custom_derivatives.py | 8 - jax/_src/interpreters/ad.py | 190 +++++-------------- jax/_src/interpreters/partial_eval.py | 4 +- jax/_src/interpreters/pxla.py | 1 + jax/_src/lax/control_flow/conditionals.py | 44 ----- jax/_src/lax/control_flow/loops.py | 220 ---------------------- jax/_src/lax/lax.py | 2 +- jax/_src/pjit.py | 81 +------- jax/interpreters/ad.py | 19 -- 10 files changed, 61 insertions(+), 530 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 88bdabc558bc..5976aa192e9c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -731,27 +731,29 @@ def remat_partial_eval_custom_params_updater(*args): partial(pe.call_partial_eval_custom_rule, 'jaxpr', remat_partial_eval_custom_params_updater) -def remat_transpose(out_cts, *in_primals, jaxpr, prevent_cse, **params): +def remat_transpose(out_cts, *args, jaxpr, prevent_cse, **params): + # TODO(mattjj): avoid round-tripping into UndefinedPrimals + args_ = [ad.UndefinedPrimal(x.aval) if isinstance(x, ad.ValAccum) else x + for x in args] + assert not jaxpr.constvars - in_linear = [ad.is_undefined_primal(x) for x in in_primals] + in_linear = [ad.is_undefined_primal(x) for x in args_] out_zeros = [type(ct) is ad_util.Zero for ct in out_cts] transposed_jaxpr_, in_zeros = transpose_jaxpr( pe.close_jaxpr(jaxpr), in_linear, out_zeros) transposed_jaxpr, consts = transposed_jaxpr_.jaxpr, transposed_jaxpr_.consts transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr) - args, _ = tree_flatten((in_primals, out_cts)) + flat_args, _ = tree_flatten((args_, out_cts)) if isinstance(prevent_cse, tuple): prevent_cse_, _ = partition_list(in_linear, prevent_cse) prevent_cse = tuple(prevent_cse_) + (True,) * (len(out_zeros) - sum(out_zeros)) - in_cts_nz = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, + in_cts_nz = remat_p.bind(*consts, *flat_args, jaxpr=transposed_jaxpr, prevent_cse=prevent_cse, **params) in_cts_nz_, in_zeros_ = iter(in_cts_nz), iter(in_zeros) - in_cts = [None if not ad.is_undefined_primal(x) else - ad_util.Zero(x.aval) if next(in_zeros_) else next(in_cts_nz_) - for x in in_primals] - assert next(in_cts_nz_, None) is next(in_zeros_, None) is None - return in_cts -ad.primitive_transposes[remat_p] = remat_transpose + for x in args: + if isinstance(x, ad.ValAccum) and not next(in_zeros_): + x.accum(next(in_cts_nz_)) +ad.fancy_transposes[remat_p] = remat_transpose # TODO(mattjj): move this to ad.py def transpose_jaxpr(jaxpr: core.ClosedJaxpr, in_linear: bool | Sequence[bool], diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index ed9b051d5777..e3d8be38a3c7 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -449,14 +449,6 @@ def _custom_jvp_vjp_call_lowering(ctx: mlir.LoweringRuleContext, *args, return out mlir.register_lowering(custom_jvp_call_p, _custom_jvp_vjp_call_lowering) -# If a (multi)linear function is defined with a custom jvp, then -# custom_jvp_call_ can appear in jaxprs to be transposed. Since it's already -# been linearized, we can drop the jvp rule. -def _custom_jvp_call_transpose(params, jaxpr, args, ct, _): - del params - return ad.backward_pass(jaxpr.jaxpr, False, jaxpr.consts, args, ct) -ad.primitive_transposes[custom_jvp_call_p] = _custom_jvp_call_transpose - def _custom_jvp_call_transpose_fancy(params, jaxpr, args, ct, _): del params return ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, ct) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 75d1862b9809..406b3a97e614 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -320,132 +320,6 @@ def unbound_vjp(pvals, jaxpr, consts, *cts): else: return out_primals, vjp_, aux -# NOTE: The FIXMEs below are caused by primal/tangent mixups (type -# errors if you will) -def backward_pass(jaxpr: core.Jaxpr, transform_stack, - consts, primals_in, cotangents_in): - if all(type(ct) is Zero for ct in cotangents_in) and not jaxpr.effects: - return map(lambda v: Zero(v.aval), jaxpr.invars) - - def write_cotangent(prim, v, ct): - # assert v not in primal_env - assert ct is not Zero, (prim, v.aval) # check for an old harmless type error - if ct is None or type(v) is Literal: - return - if type(ct) is Zero: - # FIXME: This triggers a lot of failures! - # assert v.aval == ct.aval, (prim, v.aval, ct.aval) - return - ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct - - def read_cotangent(v): - return ct_env.pop(v, Zero(v.aval.to_tangent_aval())) - - def read_primal(v): - if type(v) is Literal: - return v.val - else: - a = v.aval - return primal_env.get(v, UndefinedPrimal(a)) - - def write_primal(v, val): - if not is_undefined_primal(val): - primal_env[v] = val - - primal_env: dict[Any, Any] = {} - foreach(write_primal, jaxpr.constvars, consts) - foreach(write_primal, jaxpr.invars, primals_in) - - # Start with a forward pass to evaluate any side-effect-free JaxprEqns that - # only operate on primals. This is required to support primitives with - # linearization rules that include computations on the residuals. - lin_eqns = [] - dangling_refs = set() - for eqn in jaxpr.eqns: - if eqn.primitive is core.ref_p: - dangling_refs.add(eqn.outvars[0]) - if eqn.primitive is core.freeze_p: - dangling_refs.remove(eqn.invars[0]) # type: ignore - # TODO (dfm): The effects check is probably stricter than necessary. - # Consider adding an allowlist of effects here. - if jaxpr.effects or any( - type(x) is not Literal and x not in primal_env for x in eqn.invars): - lin_eqns.append(eqn) - continue - subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack - with source_info_util.user_context( - eqn.source_info.traceback, name_stack=name_stack), eqn.ctx.manager: - ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params) - if eqn.primitive.multiple_results: - foreach(write_primal, eqn.outvars, ans) - else: - write_primal(eqn.outvars[0], ans) - - for v in dangling_refs: - write_primal(v, core.new_ref(zeros_like_aval(v.aval.inner_aval))) # type: ignore - - ct_env: dict[Any, Any] = {} - ctx = (source_info_util.transform_name_stack('transpose') if transform_stack - else contextlib.nullcontext()) - with ctx: - foreach(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) - for eqn in lin_eqns[::-1]: - if eqn.primitive.ref_primitive: - if eqn.primitive is core.ref_p: - val_var, = eqn.invars - ref_var, = eqn.outvars - ref = read_primal(ref_var) - ct_out = core.freeze(ref) - write_cotangent(eqn.primitive, val_var, ct_out) - elif eqn.primitive is core.freeze_p: - val_var, = eqn.outvars - ref_var, = eqn.invars # type: ignore - ct_in = instantiate_zeros(read_cotangent(val_var)) - write_primal(ref_var, core.new_ref(ct_in)) - continue - - invals = map(read_primal, eqn.invars) - if eqn.primitive.multiple_results: - cts_in = map(read_cotangent, eqn.outvars) - else: - cts_in, = map(read_cotangent, eqn.outvars) - name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack - with source_info_util.user_context( - eqn.source_info.traceback, name_stack=name_stack), eqn.ctx.manager: - if eqn.primitive.call_primitive or eqn.primitive.map_primitive: - cts_in_avals = [v.aval for v in eqn.outvars] - params = dict(eqn.params) - call_jaxpr = params.pop('call_jaxpr') - cts_out = get_primitive_transpose(eqn.primitive)( - params, call_jaxpr, invals, cts_in, cts_in_avals) - else: - try: - cts_out = get_primitive_transpose(eqn.primitive)( - cts_in, *invals, **eqn.params) - except core.ShardingTypeError as e: - extra_msg = ("This is a potential JAX bug. Please file an issue at" - " https://github.com/jax-ml/jax/issues") - if extra_msg in str(e): - raise - raise core.ShardingTypeError(f"{str(e)}\n{extra_msg}") from e - except (FloatingPointError, ZeroDivisionError) as e: - msg = "When differentiating the code at the top of the callstack:" - if msg not in e.args[0]: - e.args = e.args[0] + f'\n{msg}', - e.args = e.args[0] + f'\n{source_info_util.summarize(eqn.source_info)}', - raise e from None - cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out - # FIXME: Some invars correspond to primals! - foreach(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out) - - cotangents_out = map(read_cotangent, jaxpr.invars) - return cotangents_out - -def closed_backward_pass(jaxpr: core.ClosedJaxpr, transform_stack, - primals_in, cotangents_in): - return backward_pass(jaxpr.jaxpr, transform_stack, jaxpr.consts, - primals_in, cotangents_in) class UndefinedPrimal: __slots__ = ['aval'] @@ -638,6 +512,14 @@ def accum_typeof(x): else: return core.typeof(x) +# TOOD(mattjj): this is for for backward (get it?) compatibility. Remove, maybe. +def backward_pass(jaxpr, transform_stack: bool, consts, primals_in, cts_in): + primals_in = [ValAccum(x.aval) if isinstance(x, UndefinedPrimal) else x + for x in primals_in] + backward_pass3(jaxpr, transform_stack, consts, primals_in, cts_in) + return [x.freeze() if isinstance(x, ValAccum) else None + for x in primals_in] + @lu.transformation_with_aux2 def nonzero_tangent_outputs(f, store, *args, **kwargs): @@ -1312,41 +1194,53 @@ def traceable(f, store, in_tree, *primals_and_tangents): return out_flat -def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): - if isinstance(call_jaxpr, core.ClosedJaxpr): - call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts - else: - consts = () - all_args, in_treedef = tree_flatten((consts, args, ct)) - fun = lu.hashable_partial( - lu.wrap_init(backward_pass, debug_info=call_jaxpr.debug_info), - call_jaxpr, False) - fun, out_tree = flatten_fun_nokwargs(fun, in_treedef) +def call_transpose(primitive, cts, *args, call_jaxpr, **params): + if call_jaxpr.constvars: raise NotImplementedError + primals_ctrefs, specs = project_accums(args) + flat_args, treedef = tree_flatten((primals_ctrefs, cts)) + cell = lambda: None + + @partial(lu.wrap_init, debug_info=call_jaxpr.debug_info.with_unknown_names()) + def transposed(*flat_args): + primals_ctrefs, cts = tree_unflatten(treedef, flat_args) + args = unproject_accums(specs, primals_ctrefs) + backward_pass3(call_jaxpr, False, (), args, cts) + cts_out = [x.freeze() if isinstance(x, ValAccum) else None for x in args] + cts_out, cell.out_tree = tree_flatten(cts_out) # type: ignore + return cts_out + update_params = call_transpose_param_updaters.get(primitive) if update_params: - params = update_params(params, map(is_undefined_primal, args), - [type(x) is not Zero for x in ct]) - out_flat = primitive.bind(fun, *all_args, **params) - return tree_unflatten(out_tree(), out_flat) -primitive_transposes[core.call_p] = partial(call_transpose, call_p) + params = update_params(params, [isinstance(x, GradAccum) for x in args], + [type(x) is not Zero for x in cts]) + out_flat = primitive.bind(transposed, *flat_args, **params) + for x, ct in zip(args, tree_unflatten(cell.out_tree, out_flat)): # type: ignore + if isinstance(x, ValAccum): x.accum(ct) +fancy_transposes[core.call_p] = partial(call_transpose, call_p) -def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals): - jaxpr_, consts = jaxpr.jaxpr, jaxpr.consts +def _closed_call_transpose(ct, *args, call_jaxpr, **params): + jaxpr_, consts = call_jaxpr.jaxpr, call_jaxpr.consts jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_) - return call_transpose(core.closed_call_p, params, jaxpr_, (*consts, *args), - ct, cts_in_avals) -primitive_transposes[core.closed_call_p] = _closed_call_transpose + call_transpose(core.closed_call_p, ct, *consts, *args, call_jaxpr=jaxpr_, + **params) +fancy_transposes[core.closed_call_p] = _closed_call_transpose @lu.transformation_with_aux2 def nonzero_outputs(f, store, *args, **kwargs): results = f(*args, **kwargs) - store.store([type(r) is not Zero for r in results]) + store.store([not isinstance(r, (Zero, type(None))) for r in results]) return results +# TODO(mattjj): delete this when the original pmap implementation is removed def map_transpose(primitive: core.Primitive, params, call_jaxpr: core.Jaxpr, args, ct, _): + # TODO(mattjj): we should unmap any Zeros in ct according to out_axes, but + # this code path is not long for this world... + args = [x if type(x) is not UndefinedPrimal else + UndefinedPrimal(core.mapped_aval(params['axis_size'], ax, x.aval)) + for x, ax in zip(args, params['in_axes'])] all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts # TODO(necula): use the right debug_info for the backwards pass fun = lu.hashable_partial(lu.wrap_init( @@ -1383,7 +1277,7 @@ def out_axes_thunk(): print("Invalid nan value encountered in the backward pass of a jax.jit " "function. Calling the de-optimized backward pass.") try: - _ = backward_pass(call_jaxpr, False, {}, args, ct) + _ = backward_pass(call_jaxpr, False, (), args, ct) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 843a1730b94f..ef6b803de368 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2075,8 +2075,8 @@ def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, params): source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) - in_type = (tuple(get_aval(t) for t in in_tracers) - if f.in_type is None else f.in_type) + in_type = (tuple(get_aval(t) for t in in_tracers) if f.in_type is None + else f.in_type) f.in_type = None assert in_type is not None in_tracers = map(to_jaxpr_tracer, in_tracers) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e2d8fce03b6b..06b49956fc22 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -413,6 +413,7 @@ def _emap_impl(fun: lu.WrappedFun, *args, platform = xb.get_backend(backend).platform donate_argnums = (1,) if platform in {"cuda", "rocm", "tpu"} else () new_outvals = [] + assert len(out_axes_src) == len(out_axes), breakpoint() for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals): with api.disable_jit(False): donate_argnums_ = donate_argnums diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 1f30876ac742..a5fc5e0e6232 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -836,49 +836,6 @@ def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, return [True, *used_inputs], new_eqn -def _transpose_cond_jaxpr(jaxpr: core.ClosedJaxpr, - num_res: int): - res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res]) - - def transposed(*args): - res, cts_out = split_list(args, [num_res]) - primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals] - cts_in = ad.backward_pass( - jaxpr.jaxpr, False, jaxpr.consts, primals, cts_out) - _, cts_in = split_list(cts_in, [num_res]) - return map(ad.instantiate_zeros, cts_in) - - return _make_closed_jaxpr(lu.wrap_init(transposed, - debug_info=jaxpr.jaxpr.debug_info), - res_avals + jaxpr.out_avals) - -def _cond_transpose(cts, *args, branches, **params): - index, *ops = args - assert type(index) is not ad.UndefinedPrimal - linear = [type(x) is ad.UndefinedPrimal for x in ops] - in_avals = branches[0].in_avals - num_res = len(ops) - sum(linear) - if any(isinstance(eff, RefEffect) for branch in branches for eff in - branch.jaxpr.effects): - raise NotImplementedError("State effect not supported in cond transpose.") - - branches_trans = [_transpose_cond_jaxpr(jaxpr, num_res) for jaxpr in branches] - lin_in_avals = [a.strip_weak_type() for a, l in zip(in_avals, linear) if l] - assert all(core.typematch(out_aval, lin_in_aval) - for jaxpr in branches_trans - for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals)) - - res = ops[:num_res] - cts = map(ad.instantiate_zeros, cts) - - out = cond_p.bind(index, *res, *cts, branches=tuple(branches_trans), **params) - assert all(map(core.typecheck, lin_in_avals, out)) - - out_iter = iter(out) - out = [next(out_iter) if l else None for l in linear] - assert next(out_iter, None) is None - return [None] + out - def _cond_transpose_fancy(cts_in, index, *args, branches, **params): assert not isinstance(index, ad.GradAccum) primals_ctrefs, specs = ad.project_accums(args) @@ -987,7 +944,6 @@ def _cond_typecheck(bind_time, *in_atoms, branches, **params): cond_p.def_impl(partial(dispatch.apply_primitive, cond_p)) cond_p.def_effectful_abstract_eval(_cond_abstract_eval) ad.primitive_jvps[cond_p] = _cond_jvp -ad.primitive_transposes[cond_p] = _cond_transpose ad.primitive_linearizations[cond_p] = _cond_linearize ad.fancy_transposes[cond_p] = _cond_transpose_fancy pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d46d100da0b2..aee073d83c46 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -923,225 +923,6 @@ def _rearrange_mutable_binders( if config.enable_checks.value: core.check_jaxpr(new_jaxpr) return ClosedJaxpr(new_jaxpr, jaxpr.consts) -def _scan_transpose(cts, *args, reverse, length, num_consts, - num_carry, jaxpr, linear, unroll, _split_transpose): - # we've only implemented transposing scans with specific lin/nonlin patterns - consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) - num_ires = len(consts_lin) - sum(consts_lin) - num_eres = len(xs_lin) - sum(xs_lin) - if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires): - raise NotImplementedError - if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: - raise NotImplementedError - if not all(init_lin): - pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963 - - # We follow a funny convention of passing cotangent refs like primals, so they - # appear in `args` mixed in with the UndefinedPrimals of `T d` and `T a`. - # Rearrange jaxpr binders and arguments to put cotangent mutable arrays first: - # Before: [ires, T d, T c, T a, eres] -> [T c, T b] - # After: [ires, T d_mut, T d_pure, T c, T a_mut, T a_pure, eres] -> [T c, T b] - # where - # * `ires` means intensive (not scanned over / const) residuals - # * `T d` means the intensive tangents - # * `T c` means the tangent carry - # * `T a` means the extensive (scanned over) tangent inputs - # * `eres` means the extensive residuals - # * `T b` means the extensive tangent outputs - ires, consts_dot, carry_dot, xs_dot, eres = split_list( - args, [num_ires, num_consts - num_ires, num_carry, sum(xs_lin)]) - _, const_avals, _, xs_avals, _ = split_list( - jaxpr.in_avals, [num_ires, num_consts - num_ires, num_carry, sum(xs_lin)]) - is_mutable = [isinstance(a, AbstractRef) for a in const_avals] - immut_consts_dot, mut_consts_bar = partition_list(is_mutable, consts_dot) - jaxpr = _rearrange_mutable_binders(jaxpr, num_ires, num_consts - num_ires) - del const_avals, consts_dot - is_mutable_ = [isinstance(a, AbstractRef) for a in xs_avals] - immut_xs_dot, mut_xs_bar = partition_list(is_mutable_, xs_dot) - jaxpr = _rearrange_mutable_binders(jaxpr, num_consts + num_carry, sum(xs_lin)) - del xs_avals, xs_dot - # Check that pure tangent values are all UndefinedPrimals, and mutable - # 'tangent values' are not (since we actually put cotangent refs there). - assert not any(ad.is_undefined_primal(r) for r in ires) - assert not any(ad.is_undefined_primal(x) for x in mut_consts_bar) - # TODO(mattjj): re-enable these asserts - # assert all(ad.is_undefined_primal(x) for x in immut_consts_dot) - # assert all(ad.is_undefined_primal(x) for x in carry_dot) - # assert all(ad.is_undefined_primal(x) for x in immut_xs_dot) - assert not any(ad.is_undefined_primal(r) for r in eres) - del args - - # Take apart passed-in cotangents to identify which are sym zeros. - ct_carry, ct_ys = split_list(cts, [num_carry]) - ct_carry = _map(ad.instantiate_zeros, ct_carry) - ct_ys_is_zeros = [type(ct_y) is ad.Zero for ct_y in ct_ys] - ct_ys_nz = [x for x in ct_ys if type(x) is not ad.Zero] - ct_immut_consts = _map(ad_util.zeros_like_aval, - jaxpr.in_avals[num_ires+len(mut_consts_bar):num_consts]) - - jaxpr_trans = _transpose_scan_jaxpr( - jaxpr, num_ires, len(mut_consts_bar), len(immut_consts_dot), - len(mut_xs_bar), len(immut_xs_dot), num_eres, tuple(ct_ys_is_zeros)) - - linear_trans = ([False] * num_ires + - [True] * (len(mut_consts_bar) + len(immut_consts_dot) + - len(carry_dot) + len(mut_xs_bar) + len(ct_ys_nz)) + - [False] * num_eres) - transpose_inputs = [*ires, *mut_consts_bar, *ct_immut_consts, *ct_carry, - *mut_xs_bar, *ct_ys_nz, *eres] - - if not _split_transpose: - outs = scan_p.bind( - *transpose_inputs, - reverse=not reverse, length=length, jaxpr=jaxpr_trans, - num_consts=num_ires + len(mut_consts_bar), - num_carry=len(immut_consts_dot) + len(carry_dot), - linear=tuple(linear_trans), unroll=unroll, - _split_transpose=False) - else: - if len(mut_consts_bar): raise NotImplementedError - transpose_num_out_carry = num_consts-num_ires+num_carry - inst_mask = [False] * transpose_num_out_carry + [True] * ( - len(jaxpr_trans.out_avals) - transpose_num_out_carry) - - unknowns_mask = [False] * (len(transpose_inputs) - len(eres)) + [ - True - ] * len(eres) - - # The residuals may contain original parameters (e.g. forwarded extensive - # array arguments) and residuals from the primal. Hence we iterate and - # update all values of the mask that we've set to True (i.e. 'unknown') to - # see if we should actually push them to the known computation in order to - # perform the scan (known) - map (unknown) split. The test effectively is - # done by comparing the output masks. - # - # TODO(dvytin): improve performance by doing backwards abstract eval. - # - # For example, a mask arising from a relu() is an extensive residual, yet - # only really used in the backpropagation scan, not in the unknown map. But - # an intermediate activation of a matmul will be used only in the map part. - # If we were to erroneously push the relu mask to the unknown part, then, - # in the output, the partial evaluator will also pull the loop-carried state - # to the unknown, and that is something we can test by comparing the output - # mask of pe against our intended inst mask. - for index in range(len(jaxpr_trans.in_avals)): - if unknowns_mask[index]: - mask_for_dependence = [False]*len(jaxpr_trans.in_avals) - mask_for_dependence[index] = True # try moving this to unknown - _, _, outs_for_dependence, _ = pe.partial_eval_jaxpr_nounits( - jaxpr_trans, mask_for_dependence, inst_mask) - if inst_mask != outs_for_dependence: - unknowns_mask[index] = False - - jaxpr_known_body, jaxpr_unknown_body, outs_mask, res_avals = ( - pe.partial_eval_jaxpr_nounits(jaxpr_trans, unknowns_mask, inst_mask) - ) - - num_knowns = len(outs_mask) - sum(outs_mask) - - linear_list = list(linear_trans) - known_linear = [ - l for mask, l in zip(unknowns_mask, linear_list) if not mask - ] - unknown_linear = [l for mask, l in zip(unknowns_mask, linear_list) if mask] - unknown_linear = [False] * len(res_avals) + unknown_linear - - known_args = [ - arg for mask, arg in zip(unknowns_mask, transpose_inputs) if not mask - ] - unknown_args = [ - arg for mask, arg in zip(unknowns_mask, transpose_inputs) if mask - ] - # 1. Apply the known scan. - knowns_and_residual = scan_p.bind( - *known_args, - reverse=not reverse, - length=length, - num_consts=num_ires, - num_carry=transpose_num_out_carry, - jaxpr=jaxpr_known_body, - linear=tuple(known_linear), - unroll=unroll, - _split_transpose=False, # Just generate the loop now. - ) - known_results, residuals = split_list(knowns_and_residual, [num_knowns]) - - # 2. Apply the unknown map to residuals and unknown arguments. - unknown_results = scan_p.bind( - *residuals, *unknown_args, - reverse=reverse, # Keep reverse as is for better scheduling. - length=length, - num_consts=0, - num_carry=0, - jaxpr=jaxpr_unknown_body, - linear=tuple(unknown_linear), - unroll=unroll, - _split_transpose=False, # Just generate the loop now. - ) - known_results_iter = iter(known_results) - unknown_results_iter = iter(unknown_results) - outs = [ - next(known_results_iter) if not mask else next(unknown_results_iter) - for mask in outs_mask - ] - - ct_immut_consts, ct_init, ct_immut_xs = split_list(outs, [len(immut_consts_dot), len(carry_dot)]) - ct_consts = merge_lists(is_mutable, ct_immut_consts, [None] * len(mut_consts_bar)) - ct_xs = merge_lists(is_mutable_, ct_immut_xs, [None] * len(mut_xs_bar)) - return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres - -# transpose_scan_jaxpr converts the jaxpr signature: -# Before: [(ires, T d_mut T d_pure), T c, (CT a_mut, T a, eres)] -> [T c, T b] -# ---------- consts ----------- --------- ext ------- -# -# After: [(ires, CT d_mut), (CT d_pure, CT c), (CT a_mut, CT b, eres)] -> [(CT d_pure, CT c), CT a] -# --- consts ---- ----- carry ------ --------- ext -------- -@weakref_lru_cache -def _transpose_scan_jaxpr( - jaxpr: ClosedJaxpr, - num_ires: int, - num_d_mut: int, - num_d_pure: int, - num_a_mut: int, - num_a_pure: int, - num_eres: int, - ct_b_is_zeros: Sequence[bool]): - num_d = num_d_mut + num_d_pure - num_a = num_a_mut + num_a_pure - num_b_nz = len(ct_b_is_zeros) - sum(ct_b_is_zeros) - num_c = len(jaxpr.out_avals) - len(ct_b_is_zeros) - assert num_a == len(jaxpr.in_avals) - num_ires - num_d - num_c - num_eres - - ires_avals, d_mut_avals, d_pure_avals, c_avals, a_mut_avals, a_pure_avals, eres_avals = split_list( - jaxpr.in_avals, [num_ires, num_d_mut, num_d_pure, num_c, num_a_mut, num_a_pure]) - _, b_avals = split_list(jaxpr.out_avals, [num_c]) - b_avals_nz = [a for a, z in zip(b_avals, ct_b_is_zeros) if not z] - - # TODO(mattjj,dougalm): map to cotangent types... - def transposed(*ct_args): - ires, d_mut_bar, d_pure, c_bar, a_mut_bar, b_bar, eres = split_list( - ct_args, [num_ires, num_d_mut, num_d_pure, num_c, num_a_mut, num_b_nz]) - b_bar_ = iter(b_bar) - b_bar = [ad.Zero(a) if z else next(b_bar_) for a, z in zip(b_avals, ct_b_is_zeros)] - assert next(b_bar_, None) is None - primals = ( - ires + d_mut_bar + - [ad.UndefinedPrimal(aval) for aval in [*d_pure_avals, *c_avals]] + - a_mut_bar + [ad.UndefinedPrimal(aval) for aval in a_pure_avals] + eres) - cts_out = ad.backward_pass( - jaxpr.jaxpr, False, jaxpr.consts, primals, c_bar + b_bar) - _, new_d_pure, new_c_bar, _, a_bar, _ = split_list( - cts_out, [num_ires + num_d_mut, num_d_pure, num_c, num_a_mut, num_a_pure]) - d_pure = _map(ad.instantiate_zeros, _map(ad.add_tangents, d_pure, new_d_pure)) - new_c_bar = _map(ad.instantiate_zeros, new_c_bar) - a_bar = _map(ad.instantiate_zeros, a_bar) - return [*d_pure, *new_c_bar, *a_bar] - - transposed_wrapped = lu.wrap_init(transposed, debug_info=jaxpr.jaxpr.debug_info) - trans_avals = *ires_avals, *d_mut_avals, *d_pure_avals, *c_avals, *a_mut_avals, *b_avals_nz, *eres_avals - trans_jaxpr = _make_closed_jaxpr(transposed_wrapped, trans_avals) - return trans_jaxpr - def _scan_transpose_fancy(cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, _split_transpose): consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) @@ -1574,7 +1355,6 @@ def rearrange(lst): scan_p.def_impl(partial(dispatch.apply_primitive, scan_p)) scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp -ad.primitive_transposes[scan_p] = _scan_transpose ad.fancy_transposes[scan_p] = _scan_transpose_fancy ad.primitive_linearizations[scan_p] = _scan_linearize pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ca0e2b1f0a8f..283a774667e3 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4065,7 +4065,7 @@ def _unbroadcast(aval, x): if (core.definitely_equal_shape(aval.shape, x_shape) and aval.sharding == core.typeof(x).sharding): return x - assert not aval.shape or len(x_shape) == len(aval.shape) + assert not aval.shape or len(x_shape) == len(aval.shape), breakpoint() if not aval.shape: return reduce_sum(x, list(range(len(x_shape)))) else: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index da0e7ccbbb9e..86ffef7a4599 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -46,9 +46,9 @@ from jax._src import xla_bridge as xb from jax._src.core import typeof, cur_qdd from jax._src.api_util import ( - argnums_partial_except, flatten_axes, flatten_fun3, flatten_fun_nokwargs, - donation_vector, check_callable, resolve_argnums, argnames_partial_except, - debug_info, check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) + argnums_partial_except, flatten_axes, flatten_fun3, donation_vector, + check_callable, resolve_argnums, argnames_partial_except, debug_info, + check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) from jax._src.interpreters import partial_eval as pe from jax._src.partition_spec import PartitionSpec from jax._src.interpreters import ad @@ -2228,81 +2228,6 @@ def _pjit_partial_eval_custom_params_updater( _pjit_partial_eval_custom_params_updater) -@lu.cache -def _pjit_transpose_trace(fun: lu.WrappedFun, - in_avals: Sequence[core.AbstractValue]): - transpose_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) - transpose_jaxpr = core.ClosedJaxpr(transpose_jaxpr, consts) - return transpose_jaxpr - - -def _pjit_transpose(cts_in, *primals_in, - jaxpr: core.ClosedJaxpr, - in_shardings, out_shardings, in_layouts, out_layouts, - donated_invars, ctx_mesh, name, keep_unused, inline, - compiler_options_kvs): - def prune_type(ty, xs, maybe_zeros): - return tuple(x for x, mz in zip(xs, maybe_zeros) if type(mz) is not ty) - - dbg = jaxpr.jaxpr.debug_info.with_unknown_names() - body = lu.wrap_init(ad.closed_backward_pass, debug_info=dbg) - body = lu.hashable_partial(body, jaxpr, False) - primals_and_nz_cts_in, in_treedef = tree_flatten((primals_in, cts_in)) - body, cts_out_treedef_thunk = flatten_fun_nokwargs(body, in_treedef) - - transpose_in_shardings = ( - *prune_type(ad.UndefinedPrimal, in_shardings, primals_in), - *prune_type(ad.Zero, out_shardings, cts_in) - ) - transpose_in_layouts = ( - *prune_type(ad.UndefinedPrimal, in_layouts, primals_in), - *prune_type(ad.Zero, out_layouts, cts_in) - ) - global_cts_in_avals = tuple( - core.AvalQDD(a, cur_qdd(x)) if (a := typeof(x)).has_qdd else a - for x in primals_and_nz_cts_in) - - transpose_jaxpr = _pjit_transpose_trace(body, global_cts_in_avals) - cts_out_treedef = cts_out_treedef_thunk() - transpose_out_shardings = prune_type( - ad.Zero, - in_shardings, - tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) - transpose_out_layouts = prune_type( - ad.Zero, - in_layouts, - tree_unflatten(cts_out_treedef, [object()] * cts_out_treedef.num_leaves)) - - try: - nz_cts_out = jit_p.bind( - *primals_and_nz_cts_in, - jaxpr=transpose_jaxpr, - in_shardings=transpose_in_shardings, - out_shardings=transpose_out_shardings, - in_layouts=transpose_in_layouts, - out_layouts=transpose_out_layouts, - donated_invars=(False,) * len(primals_and_nz_cts_in), - ctx_mesh=ctx_mesh, - name=name, - keep_unused=keep_unused, - inline=inline, - compiler_options_kvs=compiler_options_kvs) - except api_util.InternalFloatingPointError as e: - print("Invalid nan value encountered in the backward pass of a jax.jit " - "function. Calling the de-optimized backward pass.") - try: - _ = ad.closed_backward_pass(jaxpr, None, primals_in, cts_in) - except (FloatingPointError, ZeroDivisionError) as e2: - raise e2 from None # great - else: - # If control reaches this line, we got a NaN on the output of `compiled` - # but not `fun.call_wrapped` on the same arguments. Let's tell the user. - api_util._raise_no_nan_in_deoptimized(e) - - return tree_unflatten(cts_out_treedef, nz_cts_out) -ad.primitive_transposes[jit_p] = _pjit_transpose - - def _pjit_transpose_fancy( cts_in, *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 6af63b3e7310..b461b06fe37b 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -46,25 +46,12 @@ ) -def _deprecated_backward_pass(jaxpr, reduce_axes, transform_stack, - consts, primals_in, cotangents_in): - if reduce_axes: - raise NotImplementedError("reduce_axes on ad.backward_pass is deprecated") - del reduce_axes - return _src_ad.backward_pass( - jaxpr, transform_stack, consts, primals_in, cotangents_in) - - _deprecations = { # Deprecated for JAX v0.7.1; finalize in JAX v0.9.0. "zeros_like_p": ( "jax.interpreters.ad.zeros_like_p is deprecated in JAX v0.7.1. It has been unused since v0.4.24.", _src_ad_util.zeros_like_p, ), - "backward_pass": ( - "jax.interpreters.ad.backward_pass is deprecated.", - _deprecated_backward_pass - ), "bilinear_transpose": ( "jax.interpreters.ad.bilinear_transpose is deprecated.", _src_ad.bilinear_transpose, @@ -81,10 +68,6 @@ def _deprecated_backward_pass(jaxpr, reduce_axes, transform_stack, "jax.interpreters.ad.call_transpose_param_updaters is deprecated.", _src_ad.call_transpose_param_updaters, ), - "closed_backward_pass": ( - "jax.interpreters.ad.closed_backward_pass is deprecated.", - _src_ad.closed_backward_pass, - ), "custom_lin_p": ( "jax.interpreters.ad.custom_lin_p is deprecated.", _src_ad.custom_lin_p, @@ -161,12 +144,10 @@ def _deprecated_backward_pass(jaxpr, reduce_axes, transform_stack, import typing if typing.TYPE_CHECKING: - backward_pass = _deprecated_backward_pass bilinear_transpose = _src_ad.bilinear_transpose call_param_updaters = _src_ad.call_param_updaters call_transpose = _src_ad.call_transpose call_transpose_param_updaters = _src_ad.call_transpose_param_updaters - closed_backward_pass = _src_ad.closed_backward_pass custom_lin_p = _src_ad.custom_lin_p defjvp_zero = _src_ad.defjvp_zero f_jvp_traceable = _src_ad.f_jvp_traceable From cd50850a7613ffc2770f31aea66ff23a73065645 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 13 Dec 2025 00:04:26 -0800 Subject: [PATCH 196/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ba2ef9892875a41eb9f30efb2582d8728dc6b9d8 PiperOrigin-RevId: 843987755 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 22ed0fce0d6b..882e9d1f6c38 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "6872ce865c8b5880bce6a7f4d4b2e6fbce0704d2" -XLA_SHA256 = "d4c4dd44aed887092306f4a76eb85cd35f59d307d8e9f6d737901a9ac3e805d2" +XLA_COMMIT = "ba2ef9892875a41eb9f30efb2582d8728dc6b9d8" +XLA_SHA256 = "3e663ebc9af0aa4b27146a76720cb748cae4a5301a8473645a194c08e55ec227" From 925512ebaa5aca0d0d78b2aca4405b66295e7b11 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 14 Dec 2025 00:05:38 -0800 Subject: [PATCH 197/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/a49f6ba83fc1f912df4126f4ce8d31b2f66d6ca4 PiperOrigin-RevId: 844293227 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 882e9d1f6c38..b6d783b44c45 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "ba2ef9892875a41eb9f30efb2582d8728dc6b9d8" -XLA_SHA256 = "3e663ebc9af0aa4b27146a76720cb748cae4a5301a8473645a194c08e55ec227" +XLA_COMMIT = "a49f6ba83fc1f912df4126f4ce8d31b2f66d6ca4" +XLA_SHA256 = "2cec23a3d8bef63ddba9a52b46bade108e1c13c7974829c32f0699168f2c7ba1" From eb535dde3f8ac07741b576e56c7b20ae4595ca17 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 14 Dec 2025 11:30:25 -0800 Subject: [PATCH 198/315] Automated Code Change PiperOrigin-RevId: 844430584 --- jaxlib/mosaic/dialect/tpu/tpu_dialect.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc index 26780d42252c..1fbc59746921 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include "absl/hash/hash.h" -#include "absl/log/log.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" From 164d8bb3f2e43153e5841609eff32cf4c4d92afb Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 15 Dec 2025 00:05:32 -0800 Subject: [PATCH 199/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b9f8bd1a637329cfe8eecdf8ac42b5b96445b563 PiperOrigin-RevId: 844626310 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index b6d783b44c45..79bdbc60e232 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "a49f6ba83fc1f912df4126f4ce8d31b2f66d6ca4" -XLA_SHA256 = "2cec23a3d8bef63ddba9a52b46bade108e1c13c7974829c32f0699168f2c7ba1" +XLA_COMMIT = "b9f8bd1a637329cfe8eecdf8ac42b5b96445b563" +XLA_SHA256 = "a6efa4a48f737155e41498d5cea6ce1a30828418bdda1c36504ab0f385ac36a5" From 238f05d130aef9588bd33cecf93581fa350065ec Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 15 Dec 2025 04:36:39 -0800 Subject: [PATCH 200/315] [Pallas:MGPU] Support not tiled 3D+ transposed in `swap` LANE lowering rule. PiperOrigin-RevId: 844713064 --- jax/_src/pallas/mosaic_gpu/lowering.py | 8 ++++++-- tests/pallas/mosaic_gpu_test.py | 28 +++++++++++++------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ef20ee735035..e279afaa7904 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1753,12 +1753,16 @@ def _swap_lowering_rule( layout=value.layout, ) value.store_tiled(x_smem, swizzle=swizzle) - case () | (gpu_core.TransposeRef((1, 0)),): + case () | (gpu_core.TransposeRef(),): transposed = bool(transforms) match value.layout: case mgpu.TiledLayout(): if transposed: - x_smem = mgpu.memref_transpose(x_smem, (1, 0)) + assert isinstance( + transforms[0], gpu_core.TransposeRef + ) # silence pytype + permutation = transforms[0].permutation + x_smem = mgpu.memref_transpose(x_smem, permutation) old_value = mgpu.FragmentedArray.load_untiled( x_smem, layout=value.layout, diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 158754fc697a..e0c241b27408 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1078,31 +1078,31 @@ def kernel(x_ref_gmem, idx_ref, o_ref, barrier_ref): idx = jax.random.permutation(jax.random.key(1234), out_shape[0]).astype(jnp.uint32) np.testing.assert_array_equal(kernel(x, idx), x[idx, 64:]) - @parameterized.parameters( - (plgpu.Layout.WGMMA, plgpu.Layout.WGMMA_TRANSPOSED), - (plgpu.Layout.WGMMA_TRANSPOSED, plgpu.Layout.WGMMA), + @parameterized.product( + src_transposed=(False, True), shape=((128, 128), (1, 128, 128)) ) - def test_transposed_load_store(self, src_layout, dst_layout): - def is_transposed(layout): - return layout == plgpu.Layout.WGMMA_TRANSPOSED - - shape, dtype = (128, 128), jnp.float32 - + def test_transposed_load_store(self, src_transposed, shape): + dtype = jnp.float32 + permutation = (0, 2, 1) if len(shape) == 3 else (1, 0) @functools.partial( self.kernel, out_shape=jax.ShapeDtypeStruct(shape, dtype), ) def kernel(src_ref, dst_ref): - if is_transposed(src_layout): - src_ref = src_ref.T - if is_transposed(dst_layout): - dst_ref = dst_ref.T + if src_transposed: + src_ref = plgpu.transpose_ref(src_ref, permutation) + src_layout = plgpu.Layout.WGMMA_TRANSPOSED + dst_layout = plgpu.Layout.WGMMA + else: + dst_ref = plgpu.transpose_ref(dst_ref, permutation) + src_layout = plgpu.Layout.WGMMA + dst_layout = plgpu.Layout.WGMMA_TRANSPOSED src = plgpu.load(src_ref, (), layout=src_layout, optimized=False) dst = plgpu.layout_cast(src, dst_layout) dst_ref[...] = dst x = jnp.arange(math.prod(shape), dtype=dtype).reshape(shape) - np.testing.assert_array_equal(kernel(x), x.T) + np.testing.assert_array_equal(kernel(x), jnp.transpose(x, permutation)) @parameterized.product( src_memory_space=[plgpu.SMEM, plgpu.GMEM], From 8daf02a18e035b4742249ab78d33b5156fa3c119 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 15 Dec 2025 08:19:21 -0800 Subject: [PATCH 201/315] [Pallas:MGPU] Add Pallas lowering for `lax.reshape_p` under WG semantic. We lower `lax.reshape_p` to `vector.shape_cast` and `vector.broadcast` for scalars. PiperOrigin-RevId: 844779809 --- jax/_src/pallas/mosaic_gpu/lowering.py | 18 +++++++++++++ .../mosaic/gpu/dialect_lowering.py | 4 ++- .../mosaic/gpu/fragmented_array.py | 2 +- tests/pallas/mosaic_gpu_test.py | 25 ++++++++++++++++--- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e279afaa7904..24e3639c5d13 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2415,6 +2415,24 @@ def _reshape_lowering_rule( return _ensure_fa(x, x_aval.dtype).reshape(new_sizes) +@register_lowering_rule(lax.reshape_p, mgpu.LoweringSemantics.Warpgroup) +def _reshape_lowering_rule_wg( + ctx: LoweringRuleContext, x, new_sizes, dimensions, sharding +): + if dimensions is not None: + raise NotImplementedError("Not implemented: dimensions") + if sharding is not None: + raise NotImplementedError("Not implemented: sharding") + [x_aval] = ctx.avals_in + x = _ensure_ir_value(x, x_aval.dtype) + if x_aval.ndim == 0: # scalar + res_ty = ir.VectorType.get(new_sizes, x.type) + return vector_dialect.broadcast(res_ty, x) + else: + res_ty = ir.VectorType.get(new_sizes, ir.VectorType(x.type).element_type) + return vector_dialect.shape_cast(res_ty, x) + + @register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Lane) def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): [x_aval] = ctx.avals_in diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index d7a9141fa318..f442b9064376 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -646,7 +646,9 @@ def _vector_shape_cast_op_lowering_rule( out_vec_ty = ir.VectorType(op.result.type) assert out_vec_ty.has_static_shape a = _fragmented_array_from_ir(op.source, layout) - return [fragmented_array_to_ir(a.reshape(out_vec_ty.shape), out_vec_ty)] + return [ + fragmented_array_to_ir(a.reshape(tuple(out_vec_ty.shape)), out_vec_ty) + ] @_register_lowering(vector.ExtractStridedSliceOp) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index a81784f1e6d5..ac0a10199943 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2462,7 +2462,7 @@ def broadcast(self, shape) -> FragmentedArray: _is_signed=self.is_signed, ) - def reshape(self, shape) -> FragmentedArray: + def reshape(self, shape: tuple[int, ...]) -> FragmentedArray: if self.shape == shape: return self if math.prod(shape) != math.prod(self.shape): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index e0c241b27408..2da4c1327c60 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -323,7 +323,6 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) def test_reshape_tiled(self): - self.skip_if_wg_semantics() shape1, shape2 = (6 * 64, 8), (2, 3, 64, 8) @functools.partial( @@ -331,12 +330,31 @@ def test_reshape_tiled(self): out_shape=jax.ShapeDtypeStruct(shape2, jnp.float32), ) def kernel(x_ref, out_ref): - y = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False).reshape(shape2) - out_ref[...] = y + x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False) + out_ref[...] = x.reshape(shape2) x = jnp.arange(math.prod(shape1)).reshape(shape1).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x.reshape(shape2)) + def test_reshape_splat(self): + shape = (1, 1, 1) + + # TODO(allanrenucci): Fix swap_p lowering for scalars under Lane semantics. + if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: + self.skipTest("Not supported under Lane semantics") + + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), + ) + def kernel(out_ref): + x = jnp.array(42, dtype=jnp.float32) + out_ref[...] = x.reshape(shape) + + np.testing.assert_array_equal( + kernel(), jnp.array(42, dtype=jnp.float32).reshape(shape) + ) + def test_slice_untiled_dim(self): self.skip_if_wg_semantics() shape = (2, 3, 64, 8) @@ -2849,7 +2867,6 @@ def test_missing_primitive_lowerings_are_tracked(self): pallas_primitives.semaphore_read_p, pallas_primitives.delay_p, checkify.check_p, - lax.reshape_p, lax.squeeze_p, } From 253ff54b0db539332f273016b740df76cc8b8fc7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 12 Dec 2025 16:45:03 -0800 Subject: [PATCH 202/315] Deprecate a number of rarely-used jax.core APIs --- CHANGELOG.md | 4 +++ jax/core.py | 72 +++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 58 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74ed57bb879c..6d3c417b81ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. Please use `jax.lax.pcast(..., to='varying')` as the replacement. * Complex arguments passed to {func}`jax.numpy.arange` now result in a deprecation warning, because the output is poorly-defined. + * From {mod}`jax.core` a number of symbols are newly deprecated including: + `call_impl`, `get_aval`, `mapped_aval`, `subjaxprs`, `set_current_trace`, + `take_current_trace`, `traverse_jaxpr_params`, `unmapped_aval`, + `AbstractToken`, and `TraceTag`. * Changes: * jax's `Tracer` no longer inherits from `jax.Array` at runtime. However, diff --git a/jax/core.py b/jax/core.py index 7fca3e07dd99..ff8a29779596 100644 --- a/jax/core.py +++ b/jax/core.py @@ -15,8 +15,8 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 +import jax._src.core as _src_core from jax._src.core import ( - AbstractToken as AbstractToken, AbstractValue as AbstractValue, Atom as Atom, CallPrimitive as CallPrimitive, @@ -24,25 +24,19 @@ DropVar as DropVar, Effect as Effect, Effects as Effects, - get_opaque_trace_state as get_opaque_trace_state, InconclusiveDimensionOperation as InconclusiveDimensionOperation, JaxprPpContext as JaxprPpContext, JaxprPpSettings as JaxprPpSettings, JaxprTypeError as JaxprTypeError, - nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 OutputType as OutputType, ParamDict as ParamDict, ShapedArray as ShapedArray, Trace as Trace, Tracer as Tracer, - unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 - unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401 - unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401 Value as Value, abstract_token as abstract_token, aval_mapping_handlers as aval_mapping_handlers, call as call, - call_impl as call_impl, check_jaxpr as check_jaxpr, concrete_or_error as concrete_or_error, concretization_function_error as concretization_function_error, @@ -52,44 +46,86 @@ eval_jaxpr as eval_jaxpr, find_top_trace as find_top_trace, gensym as gensym, - get_aval as _deprecated_get_aval, + get_opaque_trace_state as get_opaque_trace_state, is_concrete as is_concrete, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, jaxprs_in_params as jaxprs_in_params, literalable_types as literalable_types, - mapped_aval as mapped_aval, max_dim as max_dim, min_dim as min_dim, new_jaxpr_eqn as new_jaxpr_eqn, no_axis_name as no_axis_name, no_effects as no_effects, + nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401 primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, pytype_aval_mappings as pytype_aval_mappings, - set_current_trace as set_current_trace, - subjaxprs as subjaxprs, - take_current_trace as take_current_trace, trace_ctx as trace_ctx, - TraceTag as TraceTag, - traverse_jaxpr_params as traverse_jaxpr_params, - unmapped_aval as unmapped_aval, + unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401 + unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401 + unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401 valid_jaxtype as valid_jaxtype, ) _deprecations = { # Added for v0.8.2 + "call_impl": ( + "jax.core.call_impl is deprecated.", + _src_core.call_impl, + ), "get_aval": ( "jax.core.get_aval is deprecated; use jax.typeof instead.", - _deprecated_get_aval + _src_core.get_aval, + ), + "mapped_aval": ( + "jax.core.mapped_aval is deprecated.", + _src_core.mapped_aval, + ), + "set_current_trace": ( + "jax.core.set_current_trace is deprecated.", + _src_core.set_current_trace, + ), + "subjaxprs": ( + "jax.core.subjaxprs is deprecated.", + _src_core.subjaxprs, + ), + "take_current_trace": ( + "jax.core.take_current_trace is deprecated.", + _src_core.take_current_trace, + ), + "traverse_jaxpr_params": ( + "jax.core.traverse_jaxpr_params is deprecated.", + _src_core.traverse_jaxpr_params, + ), + "unmapped_aval": ( + "jax.core.unmapped_aval is deprecated.", + _src_core.unmapped_aval, + ), + "AbstractToken": ( + "jax.core.AbstractToken is deprecated.", + _src_core.AbstractToken, + ), + "TraceTag": ( + "jax.core.TraceTag is deprecated.", + _src_core.TraceTag, ), } import typing as _typing if _typing.TYPE_CHECKING: - get_aval = _deprecated_get_aval + call_impl = _src_core.call_impl + get_aval = _src_core.get_aval + mapped_aval = _src_core.mapped_aval + subjaxprs = _src_core.subjaxprs + set_current_trace = _src_core.set_current_trace + take_current_trace = _src_core.take_current_trace + traverse_jaxpr_params = _src_core.traverse_jaxpr_params + unmapped_aval = _src_core.unmapped_aval + AbstractToken = _src_core.AbstractToken + TraceTag = _src_core.TraceTag else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing -del _deprecated_get_aval +del _src_core From 9cfdc4601087bd44f1a50cee48088753c20153bf Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 15 Dec 2025 08:37:30 -0800 Subject: [PATCH 203/315] [test] fix toeplitz test under scipy 1.17 --- tests/linalg_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 8c5f7b200108..ad02379dffad 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -85,7 +85,9 @@ def _random_invertible(rng, shape, dtype): def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: """scipy.linalg.toeplitz with v1.17+ batching semantics.""" - if scipy_version >= (1, 17, 0): + # scipy 1.17 doesn't support zero batch size: https://github.com/scipy/scipy/pull/24151 + zero_batch = (0 in c.shape[:-1]) or (r is not None and 0 in r.shape[:-1]) + if scipy_version >= (1, 17, 0) and not zero_batch: return scipy.linalg.toeplitz(c, r) elif r is None: c = np.atleast_1d(c) From a2e14e41a5554e875434f7b7f0562317aee11804 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 15 Dec 2025 08:43:16 -0800 Subject: [PATCH 204/315] [Mosaic GPU] Implement GPU module unloading for Mosaic GPU custom calls. We introduce reference counting in the global kernel cache. A kernel handle is created in the global cache the first time it is compiled. Subsequent calls to "compile" will return a shared pointer to the kernel handle. We keep references to the compiled kernels in the custom call state. We unload the module in the kernel handle destructor. I.e. when there is no more reference to a kernel handle, the module is unloaded. We retain the old caching mechanism for the legacy custom call. GPU module are never unloaded when using the legacy custom call. PiperOrigin-RevId: 844787620 --- jaxlib/mosaic/gpu/BUILD | 4 + jaxlib/mosaic/gpu/custom_call.cc | 283 +++++++++++++++++++------- jaxlib/mosaic/gpu/custom_call_test.cc | 155 +++++++++++++- jaxlib/mosaic/gpu/nvshmem.h | 7 + 4 files changed, 366 insertions(+), 83 deletions(-) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index b7359495f184..b0e93a39b3a4 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -235,6 +235,7 @@ cc_library( "-Wl,--export-dynamic-symbol='nvshmemx_mc_ptr'", "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'", "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'", + "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_finalize'", "-Wl,--export-dynamic-symbol='nvshmemx_init_status'", ], deps = [ @@ -351,6 +352,9 @@ cc_test( deps = [ ":mosaic_gpu_support", "//testing/base/public:gunit_main", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/log:globals", + "@com_google_absl//absl/log:scoped_mock_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 29931d56e139..f1d7e1199ee6 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -50,8 +50,10 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/driver_types.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Debug.h" @@ -130,6 +132,8 @@ limitations under the License. namespace { +using ::mosaic::gpu::NvshmemApi; + namespace ffi = xla::ffi; namespace se = stream_executor; @@ -507,19 +511,25 @@ absl::StatusOr, bool>> Compile( class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - MosaicHostFunc* host_launch, bool is_comm_used) + CUmodule module, MosaicHostFunc* host_launch, + bool is_comm_used) : engine_(std::move(engine)), ctx_(ctx), + module_(module), host_launch_(host_launch), is_comm_used_(is_comm_used) {} - std::tuple GetHostLaunch() { + std::tuple GetHostLaunch() const { return std::make_tuple(ctx_, host_launch_, is_comm_used_); } + CUmodule module() const { return module_; } + bool is_comm_used() const { return is_comm_used_; } + private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly + CUmodule module_; MosaicHostFunc* host_launch_; bool is_comm_used_; }; @@ -559,7 +569,7 @@ absl::StatusOr> GetHostAndInitFuncNames( return std::make_pair(host_func_name, init_func_name); } -absl::StatusOr CompileAndInit(llvm::StringRef module) { +absl::StatusOr CompileAndInit(absl::string_view module) { mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); context.allowUnregisteredDialects(true); InitContext(&context); @@ -600,13 +610,30 @@ absl::StatusOr CompileAndInit(llvm::StringRef module) { void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); return CompiledKernel(std::move(maybe_engine.value().first), kernel_ptr, + reinterpret_cast(module_ptr), reinterpret_cast(*host), is_comm_used); } +absl::Status Unload(const CompiledKernel& kernel, CUcontext ctx) { + CUDA_RETURN_IF_ERROR(cuCtxPushCurrent(ctx)); + if (kernel.is_comm_used()) { + if (NvshmemApi::Default().cumodule_finalize(kernel.module()) != + NVSHMEM_SUCCESS) { + return absl::InternalError("nvshmemx_cumodule_finalize failed"); + } + } + CUDA_RETURN_IF_ERROR(cuModuleUnload(kernel.module())); + CUcontext unused; + CUDA_RETURN_IF_ERROR(cuCtxPopCurrent(&unused)); + return absl::OkStatus(); +} + using KernelHash = std::array; -using CacheKey = std::pair; -struct KernelCache { +// A reference counted cache of compiled and loaded kernels. +class KernelCache { + public: + // A global cache of compiled and loaded kernels. static KernelCache& Global() { static absl::NoDestructor cache; return *cache; @@ -617,80 +644,89 @@ struct KernelCache { KernelCache(const KernelCache&) = delete; KernelCache(KernelCache&&) = delete; - absl::Mutex mutex; - absl::flat_hash_map kernels ABSL_GUARDED_BY(mutex); -}; - -// Each compiled kernel has a unique init func, and each kernel is used from -// a single HLO module. So it should be safe to not include the CUDA context -// in the key. -absl::StatusOr CachedCompileAndInit(CacheKey key, - llvm::StringRef module) { - KernelCache& cache = KernelCache::Global(); + // Holds a reference to a compiled and loaded kernel. + // Unload the kernel when the handle is destroyed. + class KernelHandle { + public: + KernelHandle(CompiledKernel kernel, CUcontext ctx) + : kernel_(std::move(kernel)), ctx_(ctx) {} + ~KernelHandle() { + CHECK_OK(Unload(kernel_, ctx_)); + VLOG(5) << "Successfully unloaded GPU module"; + } + const CompiledKernel* kernel() const { return &kernel_; } - { - // Fast path uses reader lock (as hash map look-up is relatively slow). - absl::ReaderMutexLock lock(cache.mutex); - auto it = cache.kernels.find(key); - if (ABSL_PREDICT_TRUE(it != cache.kernels.end())) return &it->second; - } + private: + CompiledKernel kernel_; + CUcontext ctx_; // The CUDA context in which the kernel was loaded. + }; - absl::MutexLock lock(cache.mutex); - // We released the reader lock, another thread might have initialized it. - if (cache.kernels.find(key) == cache.kernels.end()) { - tsl::profiler::TraceMe trace("Compilation cache miss"); - auto compiled = CompileAndInit(module); - if (!compiled.ok()) { - return compiled.status(); + // Compile and load the given module in the current CUDA context. + absl::StatusOr> CompileAndInit( + const KernelHash& kernel_hash, absl::string_view module) { + CUcontext ctx; + CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); + CacheKey key(kernel_hash, reinterpret_cast(ctx)); + absl::MutexLock lock(mutex_); + if (auto it = kernels_.find(key); it != kernels_.end()) { + std::shared_ptr handle = it->second.lock(); + if (handle) { + return handle; + } } - cache.kernels.insert_or_assign(key, std::move(*compiled)); + // Kernel not found or has expired, create a new value. + tsl::profiler::TraceMe trace("Compilation cache miss"); + TF_ASSIGN_OR_RETURN(CompiledKernel compiled, ::CompileAndInit(module)); + VLOG(5) << "Successfully compiled and initialized Mosaic GPU kernel"; + auto handle = std::make_shared(std::move(compiled), ctx); + kernels_[key] = handle; + return handle; } - return &cache.kernels.at(key); -} -// TODO(b/464203195): Backward-compatible version using the legacy FFI -// API. Remove once backward compatibility window has passed. -void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - if (reinterpret_cast(opaque) % alignof(KernelHash)) { - fprintf(stderr, "Misaligned opaque pointer\n"); - abort(); - } - auto hash = *reinterpret_cast(opaque); - CUcontext ctx; - if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { - fprintf(stderr, "Failed to get current CUDA context\n"); - abort(); - } - CacheKey key(hash, reinterpret_cast(ctx)); - auto compiled_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); - if (!compiled_kernel.ok()) { - XlaCustomCallStatusSetFailure(status, - compiled_kernel.status().message().data(), - compiled_kernel.status().message().size()); - return; + private: + using CacheKey = std::pair; + absl::Mutex mutex_; + absl::flat_hash_map> kernels_ + ABSL_GUARDED_BY(mutex_); +}; + +// Tracks the compiled and loaded kernels for a given custom call. +// There is a single global cache in the process and a process can have +// multiple devices, each of which must load/unload the module. We expect each +// device/module pair to have a unique cache key. +class CustomCallResources { + public: + CustomCallResources() = default; + + const CompiledKernel* KernelForDevice(int32_t device_ordinal) const { + absl::MutexLock lock(mutex_); + return kernels_.at(device_ordinal)->kernel(); } - auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); - bool is_comm_used = std::get<2>(ctx_kernel_comm); - void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; - if (is_comm_used) { - mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( - reinterpret_cast(stream)); + + void AddKernel(int32_t device_ordinal, + std::shared_ptr kernel) { + absl::MutexLock lock(mutex_); + kernels_[device_ordinal] = std::move(kernel); } - std::get<1>(ctx_kernel_comm)(args); -} -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, - "CUDA"); + private: + mutable absl::Mutex mutex_; + absl::flat_hash_map> + kernels_ ABSL_GUARDED_BY(mutex_); +}; -absl::Status MosaicGpuExecute(cudaStream_t stream, ffi::RemainingArgs inputs, - ffi::RemainingRets results, - std::string_view kernel_hash, - std::string_view module, - bool use_custom_barrier) { - if (use_custom_barrier) { - return absl::UnimplementedError("Custom barrier is not supported on GPUs."); - } +absl::StatusOr> InstantiateResources() { + // TODO(b/466097203): Ideally we would compile the module here. + // Sadly we need to acquire a lock on LLVM command line options which is + // already held by XLA causing a deadlock. + // See `GpuCompiler::CompileToBackendResult`. + return std::make_unique(); +} + +absl::Status InitializeResources(int32_t device_ordinal, + CustomCallResources* resources, + std::string_view kernel_hash, + std::string_view module, bool) { if (kernel_hash.size() != sizeof(KernelHash)) { return absl::InvalidArgumentError( absl::StrFormat("Kernel hash size is %d bytes, expected %d bytes", @@ -698,11 +734,23 @@ absl::Status MosaicGpuExecute(cudaStream_t stream, ffi::RemainingArgs inputs, } KernelHash hash; std::memcpy(hash.data(), kernel_hash.data(), sizeof(KernelHash)); - CUcontext ctx; - CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); - CacheKey key(hash, reinterpret_cast(ctx)); - TF_ASSIGN_OR_RETURN(auto compiled_kernel, - CachedCompileAndInit(key, module)); + TF_ASSIGN_OR_RETURN( + std::shared_ptr handle, + KernelCache::Global().CompileAndInit(hash, module)); + resources->AddKernel(device_ordinal, std::move(handle)); + return absl::OkStatus(); +} + +absl::Status MosaicGpuExecute(cudaStream_t stream, int32_t device_ordinal, + ffi::RemainingArgs inputs, + ffi::RemainingRets results, + CustomCallResources* resources, std::string_view, + std::string_view, bool use_custom_barrier) { + if (use_custom_barrier) { + return absl::UnimplementedError("Custom barrier is not supported on GPUs."); + } + const CompiledKernel* compiled_kernel = + resources->KernelForDevice(device_ordinal); auto ctx_kernel_comm = compiled_kernel->GetHostLaunch(); bool is_comm_used = std::get<2>(ctx_kernel_comm); @@ -730,17 +778,30 @@ absl::Status MosaicGpuExecute(cudaStream_t stream, ffi::RemainingArgs inputs, void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers_ptr}; if (is_comm_used) { - mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(stream); + NvshmemApi::Default().barrier_all_on_stream(stream); } std::get<1>(ctx_kernel_comm)(args); return absl::OkStatus(); } +XLA_FFI_DEFINE_HANDLER(kInstantiateResources, InstantiateResources, + ffi::Ffi::BindInstantiate()); + +XLA_FFI_DEFINE_HANDLER(kInitializeResources, InitializeResources, + ffi::Ffi::BindInitialize() + .Ctx() + .Ctx>() + .Attr("kernel_hash") + .Attr("module") + .Attr("use_custom_barrier")); + XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, ffi::Ffi::Bind() .Ctx>() + .Ctx() .RemainingArgs() .RemainingRets() + .Ctx>() .Attr("kernel_hash") .Attr("module") .Attr("use_custom_barrier"), @@ -748,12 +809,78 @@ XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu_v2", "CUDA", { - /*instantiate=*/nullptr, + /*instantiate=*/kInstantiateResources, /*prepare=*/nullptr, - /*initialize=*/nullptr, + /*initialize=*/kInitializeResources, /*execute=*/kMosaicGpuExecute, }); +// Cache compiled and loaded kernels in the current CUDA context. +// Loaded kernels are never unloaded. +absl::StatusOr LegacyCachedCompileAndInit( + const KernelHash& kernel_hash, absl::string_view module) { + using CacheKey = std::pair; + struct LegacyCache { + absl::Mutex mutex; + absl::flat_hash_map kernels + ABSL_GUARDED_BY(mutex); + }; + static absl::NoDestructor cache; + + CUcontext ctx; + CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); + + CacheKey key(kernel_hash, reinterpret_cast(ctx)); + { + // Fast path uses reader lock (as hash map look-up is relatively slow). + absl::ReaderMutexLock lock(cache->mutex); + auto it = cache->kernels.find(key); + if (ABSL_PREDICT_TRUE(it != cache->kernels.end())) return &it->second; + } + + absl::MutexLock lock(cache->mutex); + // We released the reader lock, another thread might have initialized it. + if (cache->kernels.find(key) == cache->kernels.end()) { + tsl::profiler::TraceMe trace("Compilation cache miss"); + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return compiled.status(); + } + cache->kernels.insert_or_assign(key, std::move(*compiled)); + } + return &cache->kernels.at(key); +} + +// TODO(b/464203195): Backward-compatible version using the legacy FFI +// API. Remove once backward compatibility window has passed. +void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + if (reinterpret_cast(opaque) % alignof(KernelHash)) { + fprintf(stderr, "Misaligned opaque pointer\n"); + abort(); + } + auto hash = *reinterpret_cast(opaque); + auto compiled_kernel = + LegacyCachedCompileAndInit(hash, opaque + sizeof(KernelHash)); + if (!compiled_kernel.ok()) { + XlaCustomCallStatusSetFailure(status, + compiled_kernel.status().message().data(), + compiled_kernel.status().message().size()); + return; + } + auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); +} + +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, + "CUDA"); + } // namespace extern "C" { diff --git a/jaxlib/mosaic/gpu/custom_call_test.cc b/jaxlib/mosaic/gpu/custom_call_test.cc index e4756a394325..d3426c0fd71a 100644 --- a/jaxlib/mosaic/gpu/custom_call_test.cc +++ b/jaxlib/mosaic/gpu/custom_call_test.cc @@ -19,6 +19,9 @@ limitations under the License. #include #include +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" +#include "absl/log/scoped_mock_log.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" @@ -36,6 +39,7 @@ limitations under the License. namespace { using ::absl_testing::IsOk; +using ::testing::_; absl::Status ExecuteSync(xla::PjRtLoadedExecutable* executable) { std::vector no_buffers; @@ -65,16 +69,16 @@ ENTRY main { custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - xla::ParseAndReturnUnverifiedModule(kHloModule)); + ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseAndReturnUnverifiedModule(kHloModule)); std::string tmp_path = testing::TempDir(); tsl::setenv("XLA_FLAGS", absl::StrCat("--xla_dump_to=", tmp_path).c_str(), /*overwrite=*/true); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - TF_ASSERT_OK_AND_ASSIGN( + ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + ASSERT_OK_AND_ASSIGN( std::unique_ptr executable, client->CompileAndLoad(xla::XlaComputation(module->ToProto()), /*options=*/{})); @@ -134,4 +138,145 @@ TEST(CustomCallTest, LegacyCustomCall) { EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); } +absl::string_view TestMGPUHloModule() { + // Dumped from the following JAX program: + // + // ``` + // @functools.partial( + // plgpu.pallas_call, + // out_shape=jax.ShapeDtypeStruct((), jnp.int32), + // out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + // ) + // def kernel(o_ref): + // o_ref[...] = jnp.array(42) + // ``` + return R"hlo( + HloModule test + + ENTRY main { + ROOT result = s32[] custom-call(), custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI, backend_config={kernel_hash = "\90\C7\1F$\92=c\9D\E4\A8\15\B1Y\9B.\02\B4\B0\0B\16\C5Ol\D4\ED\CDdA-\C9\D77", module = "ML\EFR\01MLIR\00\01O\0D\01\03\05\07\09\0B\01\03\0D\037\0F\11\13\15\17\19\1B\1D\1F!#%')+-/13579;=?AC\03\12\02\C9\1D\01\BB\0F\13\0B\0B\0F\13\13\13\13\0B\07\0B\0B\13\13\0B\0F\13\13\13e\1B\0B\0F\0B\0B#\0B\0B\0B\0B;\0B\0B\0B\0B\0B\0B\0B#\0B\0B\07\0B\13\0F\0F\13\13\13\0F\13\13\0B\133\133\133U\1B\0B\C3\0B\13\13\13\13\13\13\13\13\13\17\17\17\0B\0F\1F\0F\0B\0B\13\13\0B\0B\0F\0B\0F\0B\17\0B\05\03a\07\09y111\09\03Y\0B\03U\01\15\0F\07\0F\0B\0B\1B/\17\13;\05\07)yQ\07\03E\02\AE\0A\1D3\15\03\03\9B\C5\05E\05G\11\05\01\03\03\07]\03\03\19\BF\03\03\19\C1\03\03\19\C3\05I\1F\05K\05M\03\03\07\9D\03\03\A5\09\05O\11\01\11\03\03\07\9F\03\03\07\A1\03\03\A3\C7affine_map<(d0) -> (d0)>\00\03\05-/\131\05Q\11\05\19\05S\05U\03\07\1F7\139;=\0D\0D\05W\05Y\05[\03\0DA!CEG\BB\13IK\09M\09\05]\05_\0D\19\05a\05c\05e\05g\03\07\1FQSU\13W\0D\0F\05i\0F\05k\03\03\07[\11\01\A9\11\01\01\03\03\07a\11\03\02\04\03\03\07e\11\03\05\03\03\07\09\03\03k\09\05m\03\03\17o#\05\03\11\00\00\00\00\00\00\00\00\03\03\17s#\05\03\11\01\00\00\00\00\00\00\00\03\03\17w#\05\03\11\02\00\00\00\00\00\00\00affine_map<() -> ()>\00\03\05}\7F\81\09\05o#\01\17Y\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\05q\17\05%O\17\05%]\17\05%k\17\05%\E1\17\05%\EF\17\05%\FD\17\05%\81\17\05%\9B\17\05%\B5\17\05%&\02\17\05%f\02\17\05%\9E\02\05s\11\01\15\11\01\D0\FF\FF\FF?\11\01}\05u\05w\03\03\07!\03\03\AB\AD\05y\01\01\1D\B1\B3\05{\1D\B5\B7\05}\17\B9\06\03\0D\05\7F#llvm.linkage\00#gpu.address_space\00#gpu\00#gpu\00#gpu\00#arith.overflow\00#nvvm\00\01\02\02\03\01\02\04\01\09\01A\17\BD\03\01\09)\05\11\15\15\05\05\15\15\05\15\01\05\05\15\15\01\15\01\01y\17\BD\03\00\FF\FF\FF\FF\FF\FF\FF\FF\09)!llvm.ptr\00!llvm.struct<(ptr, ptr, i64)>\00!llvm.array<0 x i8>\00!gpu.async.token\00\04Z\0C\05\01\11\01+\07\03\01\0D\17\11\015\07\01\1F\11\01?\07\01\17\11\01O\07\03\1F;\05\15\01\15\01\05\03\15Y\03\01\05\03\15\0B\03\01\05\03\01_\03\03\05\03\15c\03\03!\03\01g\03\05#\02\01\03\17\0F\06\01\03\1B\03\01%\07\01i\03\15\03\03\11\07\01m\03\17\05\0F\13\11\07\01q\03\17\05\15\13\11\07\01u\03\17\05\17\0D\0F\06\01\03\11\03\19'\17\01{\03\1B\11\11\0B\0B\0B\09\0B\0B\07\05\03\C1\C6\02\19\03\83\03\85\03\87\03\89\03\8B\03\8D\03\8F\03\91\03\93\03\95\03\97\03\99\19\02\01\03\07\09\03\01\0D\03\03\03\06\01\03\01\039\0B\03\01\0D\03\03\03\06\01\03\01\03=\09\03\01\0F\03\03\03\06\01\03\01\03A\07\07\01\03\03\01\05C?\0D\07\01\03\03\01\05;E\0B\03\01\0F\03\03\03\06\01\03\01\03I\07\07\01\03\03\01\05?K\09\03\01\11\03\03\03\06\01\03\01\03O\07\07\01\03\03\01\05QM\0D\07\01\03\03\01\05GS\0B\03\01\11\03\03\03\06\01\03\01\03W\07\07\01\03\03\01\05MY\05\03\01\1B\03\01\13\06\01\03\01\05U]\05\03\01#\03\01\05\03\01\0B\03\01\05\03\01%\03\01\1B\07\01'\03\01\09a_ce\05\03\01\0B\03\01\15\07\01\1D\03\07\05gi\1D\06\01\03\07\05k7\19\02\01\03\07\09\03\01\0D\03\03\03\06\01\03\01\03q\0B\03\01\0D\03\03\03\06\01\03\01\03u\09\03\01\0F\03\03\03\06\01\03\01\03y\07\07\01\03\03\01\05{w\0D\07\01\03\03\01\05s}\0B\03\01\0F\03\03\03\06\01\03\01\03\81\07\07\01\03\03\01\05w\83\09\03\01\11\03\03\03\06\01\03\01\03\87\07\07\01\03\03\01\05\89\85\0D\07\01\03\03\01\05\7F\8B\0B\03\01\11\03\03\03\06\01\03\01\03\8F\07\07\01\03\03\01\05\85\91\05\03\01\1B\03\01\13\06\01\03\01\05\8D\95\05\03\01#\03\01\05\03\01\0B\03\01\05\03\01%\03\01\1B\07\01'\03\01\09\99\97\9B\9D\05\03\01\A7\03\01+\06\01\03\01\05\9F\A1\05\03\01\0B\03\01\15\07\01\1D\03\07\05\A3\A5\1D\06\01\03\07\05\A7o\09\03\01\0D\03\03\03\06\01\03\01\03\AB\0B\03\01\0D\03\03\03\06\01\03\01\03\AF\09\03\01\0F\03\03\03\06\01\03\01\03\B3\07\07\01\03\03\01\05\B5\B1\0D\07\01\03\03\01\05\AD\B7\0B\03\01\0F\03\03\03\06\01\03\01\03\BB\07\07\01\03\03\01\05\B1\BD\09\03\01\11\03\03\03\06\01\03\01\03\C1\07\07\01\03\03\01\05\C3\BF\0D\07\01\03\03\01\05\B9\C5\0B\03\01\11\03\03\03\06\01\03\01\03\C9\07\07\01\03\03\01\05\BF\CB\05\03\01\1B\03\01\13\06\01\03\01\05\C7\CF\05\03\01\0B\03\01\15\07\01\1D\03\07\05\D1\D3-\02\01\03\13\03\06\01\03\03\03\07/\06\01\03\0B\05\D7\D9\0F\07\01\A9\03\0B\03\DB1\00\013\00\015\04\AF\05\05\1B7\00\01)\00\01\06\03\01\05\01\00\9E\0E\81g\0B\0D\17\15\0B\1D/)\13%-\19\1B\1F\11\19\17\11\1F3\19\0F5\1D\15\13\13\0D\05\1F\1B\193\195\19\19\17\15!'#\17\1F!\15\17\19#G\17\1D\1D\17\1F#\0F\0B\0D\09\0B%\11builtin\00stable_mosaic_gpu\00llvm\00gpu\00arith\00nvvm\00module\00arith.index_cast\00arith.constant\00arith.muli\00gpu.thread_id\00gpu.block_dim\00arith.addi\00builtin.unrealized_conversion_cast\00llvm.insertvalue\00arith.shrui\00arith.cmpi\00func.func\00nvvm.elect.sync\00nvvm.shfl.sync\00arith.andi\00llvm.mlir.global\00llvm.mlir.constant\00llvm.mlir.undef\00llvm.load\00gpu.launch\00func.return\00arith.remui\00gpu.dynamic_shared_memory\00memref.view\00nvvm.fence.mbarrier.init\00gpu.barrier\00memref.store\00gpu.terminator\00-\00value\00sym_name\00position\00dimension\00function_type\00stable_mosaic_gpu.version\00kernel\00pallas_call\00mosaic_gpu_init_tma_desc\00sym_visibility\00private\00addr_space\00global_type\00linkage\00global_scratch\00unnamed_addr\00visibility_\00llvm.emit_c_interface\00kernel_mosaic_gpu\00ordering\00operandSegmentSizes\00workgroup_attributions\00overflowFlags\00kind\00predicate\00transforms\00swap:\00swap\00third_party/py/jax/tests/pallas/mosaic_gpu_test.py\00", use_custom_barrier = false} + } + )hlo"; +} + +TEST(CustomCallTest, UnloadGPUModule) { + ASSERT_OK_AND_ASSIGN( + auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + + absl::SetVLogLevel("custom_call", 5); + { + absl::ScopedMockLog log; + EXPECT_CALL(log, + Log(absl::LogSeverity::kInfo, _, + "Successfully compiled and initialized Mosaic GPU kernel")) + .Times(1); + log.StartCapturingLogs(); + EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); + } + + { + // The second execution the compilation should be cached. + absl::ScopedMockLog log; + EXPECT_CALL(log, + Log(absl::LogSeverity::kInfo, _, + "Successfully compiled and initialized Mosaic GPU kernel")) + .Times(0); + log.StartCapturingLogs(); + EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); + } + + { + // GPU module should be unloaded when the executable is destroyed. + absl::ScopedMockLog log; + EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, + "Successfully unloaded GPU module")) + .Times(1); + log.StartCapturingLogs(); + executable.reset(); + } +} + +TEST(CustomCallTest, GPUModuleIsOnlyUnloadedWhenAllExecutablesAreDestroyed) { + ASSERT_OK_AND_ASSIGN( + auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); + ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable1, + client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable2, + client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + + EXPECT_THAT(ExecuteSync(executable1.get()), IsOk()); + EXPECT_THAT(ExecuteSync(executable2.get()), IsOk()); + + absl::SetVLogLevel("custom_call", 5); + { + // executable2 still holds a reference to the GPU module. + absl::ScopedMockLog log; + EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, + "Successfully unloaded GPU module")) + .Times(0); + log.StartCapturingLogs(); + executable1.reset(); + } + EXPECT_THAT(ExecuteSync(executable2.get()), IsOk()); + { + absl::ScopedMockLog log; + EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, + "Successfully unloaded GPU module")) + .Times(1); + log.StartCapturingLogs(); + executable2.reset(); + } +} + +TEST(CustomCallTest, GPUModuleIsRecompiledAfterExpiration) { + ASSERT_OK_AND_ASSIGN( + auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); + ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + + EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); + + { + absl::ScopedMockLog log; + EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, + "Successfully unloaded GPU module")) + .Times(1); + log.StartCapturingLogs(); + executable.reset(); + } + + ASSERT_OK_AND_ASSIGN( + executable, client->CompileAndLoad(xla::XlaComputation(module->ToProto()), + /*options=*/{})); + + { + // executable was destroyed and the module was unloaded. We re-compile the + // kernel. + absl::ScopedMockLog log; + EXPECT_CALL(log, + Log(absl::LogSeverity::kInfo, _, + "Successfully compiled and initialized Mosaic GPU kernel")) + .Times(1); + log.StartCapturingLogs(); + EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); + } +} + } // namespace diff --git a/jaxlib/mosaic/gpu/nvshmem.h b/jaxlib/mosaic/gpu/nvshmem.h index dbd11aa1d373..267f17de8324 100644 --- a/jaxlib/mosaic/gpu/nvshmem.h +++ b/jaxlib/mosaic/gpu/nvshmem.h @@ -54,6 +54,11 @@ class NvshmemApi { return nvshmemx_cumodule_init(module); } + int cumodule_finalize(CUmodule module) { + std::lock_guard lock(mutex_); + return nvshmemx_cumodule_finalize(module); + } + void barrier_all_on_stream(cudaStream_t stream) { nvshmemx_barrier_all_on_stream(stream); } @@ -78,11 +83,13 @@ class NvshmemApi { NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) NVSHMEM_SET_FN(nvshmemx_cumodule_init) + NVSHMEM_SET_FN(nvshmemx_cumodule_finalize) NVSHMEM_SET_FN(nvshmemx_init_status) } int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); int (*nvshmemx_cumodule_init)(CUmodule); + int (*nvshmemx_cumodule_finalize)(CUmodule); int (*nvshmemx_init_status)(); std::mutex mutex_; From 3863a14db0ec758f756c27c96c45cc6c9603b782 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 15 Dec 2025 08:50:57 -0800 Subject: [PATCH 205/315] [Pallas:MGPU] Fix `swap_p` lowering rule for splat under LANE semantics. PiperOrigin-RevId: 844790136 --- jax/_src/pallas/mosaic_gpu/lowering.py | 11 ++++++----- tests/pallas/mosaic_gpu_test.py | 4 ---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 24e3639c5d13..481d038178d5 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -1708,16 +1708,17 @@ def _swap_lowering_rule( if ctx.module_ctx.auto_barriers: barrier() # Make sure reads have completed before we write. + match transforms: - case _ if not ctx.avals_out[0].shape: # Scalar case. + case _ if math.prod(ctx.avals_out[0].shape) == 1: # Scalar case. + zero_idx = _ir_constant(0, ir.IndexType.get()) + indices = [zero_idx] * len(ctx.avals_out[0].shape) old_value = mgpu.FragmentedArray.splat( - memref_dialect.load(x_smem, []), + memref_dialect.load(x_smem, indices), shape=(), is_signed=mgpu_utils.is_signed(v_aval.dtype), ) - memref_dialect.store( - _ensure_ir_value(value, ctx.avals_out[0].dtype), x_smem, [] - ) + value.store_untiled(x_smem) case ( gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling), diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 2da4c1327c60..508b9473fe29 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -339,10 +339,6 @@ def kernel(x_ref, out_ref): def test_reshape_splat(self): shape = (1, 1, 1) - # TODO(allanrenucci): Fix swap_p lowering for scalars under Lane semantics. - if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane: - self.skipTest("Not supported under Lane semantics") - @functools.partial( self.kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), From d337b27c64a6418ea51cbd7c81dbb02d087d262c Mon Sep 17 00:00:00 2001 From: Michael Whittaker Date: Mon, 15 Dec 2025 10:09:34 -0800 Subject: [PATCH 206/315] Add a JAX config to disable the preemption service. PiperOrigin-RevId: 844820622 --- jax/_src/distributed.py | 15 +++++++++++++++ jax/experimental/multihost_utils.py | 5 ++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 715d831e4e9c..b6f1ab7443d6 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -44,6 +44,15 @@ ), ) +_ENABLE_PREEMPTION_SERVICE = config.bool_state( + name='jax_enable_preemption_service', + default=True, + help=( + "Enables the preemption service. See" + " multihost_utils.reached_preemption_sync_point for details." + ), +) + class State: process_id: int = 0 num_processes: int = 1 @@ -188,6 +197,12 @@ def shutdown(self): self.service = None def initialize_preemption_sync_manager(self): + if not _ENABLE_PREEMPTION_SERVICE.value: + logger.info( + 'The JAX preemption service is disabled. You can enable it using the' + ' jax_enable_preemption_service configuration option.' + ) + return if self.preemption_sync_manager is not None: raise RuntimeError( 'Preemption sync manager should only be initialized once.') diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index ee7e9509ea3d..f3026502abc6 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -226,7 +226,10 @@ def should_save(step_id: int) -> bool: return False sync_manager = distributed.global_state.preemption_sync_manager if sync_manager is None: - raise RuntimeError("Preemption sync manager has not been initialized.") + raise RuntimeError( + "Preemption sync manager has not been initialized. Make sure the" + " 'jax_enable_preemption_service' config is enabled." + ) return sync_manager.reached_sync_point(step_id) From 90007e00786b925a79f8511a0e3ea6c09d55ef82 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 15 Dec 2025 10:15:23 -0800 Subject: [PATCH 207/315] [pallas:mosaic] Allowed specifying tiling for SC_*_SUBCORE kernels We need to know the tiling to use `pltpu.emit_pipeline` in the SC lowering. Previously tiling was controlled via an XLA flag, which could not be easily inspected at tracing/lowering time. PiperOrigin-RevId: 844823041 --- jax/_src/pallas/mosaic/core.py | 5 + .../pallas/mosaic/pallas_call_registration.py | 18 ++- jax/_src/tpu_custom_call.py | 33 ++++- tests/pallas/tpu_pallas_test.py | 11 +- tests/pallas/tpu_sparsecore_pallas_test.py | 128 ++++++++---------- 5 files changed, 108 insertions(+), 87 deletions(-) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 4547af632d4f..e05c69f537fc 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -102,6 +102,8 @@ class CompilerParams(pallas_core.CompilerParams): skip_device_barrier: Skip the default device barrier for the kernel. allow_collective_id_without_custom_barrier: Allow the use of collective_id without a custom barrier. + use_tc_tiling_on_sc: Use TensorCore tiling for SparseCore. This flag is + only used for ``SC_*_SUBCORE`` kernels. """ BACKEND: ClassVar[pallas_core.Backend] = "mosaic_tpu" dimension_semantics: tuple[DimensionSemantics, ...] | None = None @@ -117,6 +119,7 @@ class CompilerParams(pallas_core.CompilerParams): skip_device_barrier: bool = False allow_collective_id_without_custom_barrier: bool = False shape_invariant_numerics: bool = True + use_tc_tiling_on_sc: bool | None = None def __init__( self, @@ -133,6 +136,7 @@ def __init__( skip_device_barrier: bool = False, allow_collective_id_without_custom_barrier: bool = False, shape_invariant_numerics: bool = True, + use_tc_tiling_on_sc: bool | None = None, ): object.__setattr__( self, @@ -165,6 +169,7 @@ def __init__( object.__setattr__( self, "shape_invariant_numerics", shape_invariant_numerics ) + object.__setattr__(self, "use_tc_tiling_on_sc", use_tc_tiling_on_sc) # Replace is a method, not a field. replace = dataclasses.replace diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 85d3d4cf88a9..1ab2c9f40459 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -164,7 +164,7 @@ def pallas_call_tpu_lowering_rule( grid_mapping, jaxpr, dimension_semantics=mosaic_params.dimension_semantics, - kernel_type=mosaic_params.kernel_type, + kernel_type=kernel_type, mesh=jax_mesh, dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(), ) @@ -258,6 +258,21 @@ def _maybe_cast_inputs(*args): has_side_effects = tpu_custom_call.TpuSideEffectType.SIDE_EFFECTING case _: raise ValueError(f"Invalid side effect type: {mosaic_params.has_side_effects}") + tiling: tpu_custom_call.Tiling | None = None + if mosaic_params.use_tc_tiling_on_sc is not None: + if kernel_type not in ( + tpu_core.KernelType.SC_SCALAR_SUBCORE, + tpu_core.KernelType.SC_VECTOR_SUBCORE, + ): + raise ValueError( + "use_tc_tiling_on_sc= is only supported for SC_*_SUBCORE kernels" + ) + + tiling = ( + tpu_custom_call.Tiling.COMPACT + if mosaic_params.use_tc_tiling_on_sc + else tpu_custom_call.Tiling.SPARSE_CORE + ) out_nodes = mosaic.lower_module_to_custom_call( kernel_ctx, *dynamic_grid_args, @@ -281,6 +296,7 @@ def _maybe_cast_inputs(*args): skip_device_barrier=mosaic_params.skip_device_barrier, allow_collective_id_without_custom_barrier=mosaic_params.allow_collective_id_without_custom_barrier, shape_invariant_numerics=mosaic_params.shape_invariant_numerics, + tiling=tiling, ) _maybe_cast_to_bool = ( lambda x, aval: x.astype(jax.numpy.bool_) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 81c63f94cce4..2a6773f28c80 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -145,6 +145,11 @@ class TpuSideEffectType(enum.Enum): SIDE_EFFECTING = "side_effecting" +class Tiling(enum.Enum): + COMPACT = "TILING_COMPACT" + SPARSE_CORE = "TILING_SPARSE_CORE" + + @dataclasses.dataclass(frozen=True) class CustomCallBackendConfig: """Represents an unserialized backend config for custom calls.""" @@ -166,6 +171,7 @@ class CustomCallBackendConfig: input_memory_spaces: tuple[MemorySpace | None, ...] | None skip_device_barrier: bool shape_invariant_numerics: bool + tiling: Tiling | None = None # Only used for SparseCore. def __post_init__(self): if self.allow_input_fusion is not None: @@ -195,9 +201,7 @@ def to_json(self) -> bytes: config.write(str(self.collective_id).encode("ascii")) if self.cost_estimate is not None: config.write(b', "cost_estimate": ') - config.write( - json.dumps(dict(self.cost_estimate), sort_keys=True).encode("ascii") - ) + config.write(_compact_json_object(**self.cost_estimate)) if self.needs_hlo_passes: config.write(b', "needs_hlo_passes": ') config.write(str(self.needs_hlo_passes).lower().encode("ascii")) @@ -214,7 +218,6 @@ def to_json(self) -> bytes: config.write(b', "allow_input_fusion": [') for i, value in enumerate(self.allow_input_fusion): config.write(b"true" if value else b"false") - # config.write(str(value).lower().encode("ascii")) if i + 1 != len(self.allow_input_fusion): config.write(b",") config.write(b"]") @@ -264,6 +267,9 @@ def to_json(self) -> bytes: config.write(b', "skip_device_barrier": ') config.write(str(self.skip_device_barrier).lower().encode("ascii")) config.write(b"}") # End of custom_call_config. + if self.tiling is not None: + config.write(b', "sparse_core_config": ') + config.write(_compact_json_object(tiling=self.tiling.value)) if self.device_type is not None: config.write(b', "device_type": ') config.write( @@ -303,6 +309,12 @@ def to_json(self) -> bytes: return config.getvalue() +def _compact_json_object(**kwargs: Any) -> bytes: + return json.dumps( + kwargs, sort_keys=True, indent=0, separators=(",", ":") + ).encode("ascii") + + @tpu_custom_call_p.def_abstract_eval def _tpu_custom_call_abstract_eval(*_, out_avals, **__): return out_avals @@ -368,7 +380,9 @@ def _tpu_custom_call_lowering( ) metadata_dict = {} if metadata is not None: - metadata_dict["kernel_metadata"] = ir.StringAttr.get(json.dumps(metadata)) + metadata_dict["kernel_metadata"] = ir.StringAttr.get( + _compact_json_object(**metadata) + ) assert isinstance(has_side_effects, TpuSideEffectType) if has_side_effects == TpuSideEffectType.DATAFLOW_SIDE_EFFECTING: metadata_dict["xla_allow_dce_side_effecting_op"] = ir.StringAttr.get("true") @@ -544,6 +558,7 @@ def _lower_to_custom_call_config( allow_collective_id_without_custom_barrier: bool = False, shape_invariant_numerics: bool = False, needs_layout_passes: bool | None = None, + tiling: Tiling | None = None, ) -> CustomCallBackendConfig: device_type = _get_device_type(module) needs_hlo_passes = _MOSAIC_ALLOW_HLO.value @@ -578,6 +593,7 @@ def _lower_to_custom_call_config( skip_device_barrier=skip_device_barrier, allow_collective_id_without_custom_barrier=allow_collective_id_without_custom_barrier, shape_invariant_numerics=shape_invariant_numerics, + tiling=tiling, ) @@ -603,6 +619,7 @@ def _lowered_to_custom_call_config( skip_device_barrier: bool = False, allow_collective_id_without_custom_barrier: bool = False, shape_invariant_numerics: bool = False, + tiling: Tiling | None = None, ): if has_custom_barrier: if collective_id is None: @@ -619,7 +636,7 @@ def _lowered_to_custom_call_config( "vmem_limit_bytes must be an int: provided with a" f" {type(vmem_limit_bytes)}." ) - config = CustomCallBackendConfig( + return CustomCallBackendConfig( lowered_module_asm, has_communication, collective_id, @@ -638,8 +655,8 @@ def _lowered_to_custom_call_config( input_memory_spaces=input_memory_spaces, skip_device_barrier=skip_device_barrier, shape_invariant_numerics=shape_invariant_numerics, + tiling=tiling, ) - return config def lower_module_to_custom_call( @@ -665,6 +682,7 @@ def lower_module_to_custom_call( allow_collective_id_without_custom_barrier: bool = False, shape_invariant_numerics: bool = False, needs_layout_passes: bool | None = None, + tiling: Tiling | None = None, ) -> Sequence[ir.Value]: if isinstance(has_side_effects, bool): has_side_effects = ( @@ -689,6 +707,7 @@ def lower_module_to_custom_call( allow_collective_id_without_custom_barrier=allow_collective_id_without_custom_barrier, shape_invariant_numerics=shape_invariant_numerics, needs_layout_passes=needs_layout_passes, + tiling=tiling, ) return _tpu_custom_call_lowering( ctx, diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 551ee97db9d5..cddfb9905ea9 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2097,13 +2097,13 @@ def kernel(x, y): self.assertIn('tpu_custom_call', str(exported_module)) self.assertIn('cost_estimate', str(exported_module)) # The exported module string encodes " as \22. - self.assertIn(f'flops\\22: {batch_size * flops}', str(exported_module)) + self.assertIn(f'flops\\22:{batch_size * flops}', str(exported_module)) self.assertIn( - f'transcendentals\\22: {batch_size * transcendentals}', + f'transcendentals\\22:{batch_size * transcendentals}', str(exported_module), ) self.assertIn( - f'bytes_accessed\\22: {batch_size * bytes_accessed}', + f'bytes_accessed\\22:{batch_size * bytes_accessed}', str(exported_module), ) @@ -4201,7 +4201,10 @@ def f(x, y): )(x, y) hlo = f.lower(x, y).compile().as_text() - self.assertIn(json.dumps(metadata), hlo) + self.assertIn( + json.dumps(metadata, sort_keys=True, indent=0, separators=(',', ':')), + hlo, + ) if __name__ == '__main__': diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index 7157018ff8ba..5f5c1726ffba 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -39,12 +39,16 @@ class PallasSCTest(jtu.JaxTestCase): - COMPILER_OPTIONS = {"xla_tpu_use_tc_device_shape_on_sc": "false"} + USE_TC_TILING = False def setUp(self): if not jtu.is_device_tpu(5, "p") and not jtu.is_device_tpu_at_least(6): self.skipTest("SparseCore only supported on TPU v5p+") + if self.USE_TC_TILING and jtu.is_cloud_tpu(): + # TODO(apaszke,slebedev): Fix those. + self.skipTest("Many tests are failing on Cloud TPUs") + super().setUp() @property @@ -53,53 +57,31 @@ def sc_info(self): def vector_subcore_kernel(self, **kwargs): assert "compiler_params" not in kwargs - def wrapper(f): - f = pl.pallas_call( - f, - compiler_params=pltpu.CompilerParams( - kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE - ), - **kwargs, - ) - return jax.jit(f, compiler_options=self.COMPILER_OPTIONS) - return wrapper - - def kernel(self, *args, jax_compiler_options=None, **kwargs): - if jax_compiler_options is None: - jax_compiler_options = self.COMPILER_OPTIONS - # We only implement the decorator version of pl.kernel for now. - def wrapper(f): - f = pl.kernel(f, *args, **kwargs) - return jax.jit(f, compiler_options=jax_compiler_options) - return wrapper + return functools.partial( + pl.pallas_call, + **kwargs, + compiler_params=pltpu.CompilerParams( + kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, + use_tc_tiling_on_sc=self.USE_TC_TILING, + ), + ) - @property - def uses_tc_tiling(self): - return self.COMPILER_OPTIONS.get( - "xla_tpu_use_tc_device_shape_on_sc", "false" - ) == "true" + def kernel(self, **kwargs): + assert "compiler_params" not in kwargs + return functools.partial( + pl.kernel, + compiler_params=pltpu.CompilerParams( + use_tc_tiling_on_sc=self.USE_TC_TILING + ), + **kwargs, + ) def skip_if_tc_tiling(self, reason: str = ""): - use_tc_tiling = self.COMPILER_OPTIONS.get( - "xla_tpu_use_tc_device_shape_on_sc", "false" - ) - if use_tc_tiling == "true": + if self.USE_TC_TILING: self.skipTest(f"TC tiling is not supported. {reason}") -class TCTilingMixin(): - COMPILER_OPTIONS = {"xla_tpu_use_tc_device_shape_on_sc": "true"} - - def setUp(self): - super().setUp() - if jtu.is_cloud_tpu(): - # TODO(apaszke,slebedev): Fix those. - self.skipTest("Many tests are failing on Cloud TPUs") - - class DebugPrintTest(PallasSCTest): - # We are passing compiler options from jax.jit explicitly. - COMPILER_OPTIONS = {} def setUp(self): if jtu.is_cloud_tpu(): @@ -313,14 +295,14 @@ def kernel(x_ref, o_ref): @jtu.thread_unsafe_test(condition=not jtu.hypothesis_is_thread_safe()) @hp.given(hps.data()) def test_block_spec_untiled_slicing(self, data): - if not self.uses_tc_tiling: + if not self.USE_TC_TILING: self.skipTest("Test uncoveres a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')") slice_shape = data.draw( hps.lists( - hps.integers(1, 3), min_size=(1 + self.uses_tc_tiling), max_size=4 + hps.integers(1, 3), min_size=(1 + self.USE_TC_TILING), max_size=4 ) ) - if self.uses_tc_tiling: + if self.USE_TC_TILING: slice_shape[-2] *= 8 slice_shape[-1] *= 128 else: @@ -1429,6 +1411,8 @@ def test_barrier_via_pallas_call(self): if not jtu.is_cloud_tpu_at_least(2025, 11, 22): self.skipTest("Test requires a newer libtpu") + self.skip_if_tc_tiling() + mesh = plsc.VectorSubcoreMesh( core_axis_name="core", subcore_axis_name="subcore", num_cores=1 ) @@ -1439,15 +1423,14 @@ def test_barrier_via_pallas_call(self): compiler_params=pltpu.CompilerParams( kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, dimension_semantics=["subcore_parallel"], + use_tc_tiling_on_sc=self.USE_TC_TILING, ), out_shape=jax.ShapeDtypeStruct( shape=(mesh.num_subcores, vec_dim), dtype=jnp.uint32 ), out_specs=pl.BlockSpec((1, vec_dim), lambda i: (i, 0)), scratch_shapes=( - pltpu.VMEM_SHARED( - (mesh.num_subcores, vec_dim), jnp.uint32 - ), + pltpu.VMEM_SHARED((mesh.num_subcores, vec_dim), jnp.uint32), pltpu.VMEM((vec_dim,), jnp.uint32), ), ) @@ -1517,13 +1500,16 @@ def test_scatter_add(self, dtype): compiler_params=pltpu.CompilerParams( kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, dimension_semantics=["subcore_parallel"], + use_tc_tiling_on_sc=self.USE_TC_TILING, ), out_shape=jax.ShapeDtypeStruct(shape[1:], dtype), - out_specs=pl.BlockSpec(shape[1:], lambda i: (0,), - memory_space=pltpu.HBM), - in_specs=[pl.BlockSpec(shape, lambda *_: (0, 0), - memory_space=pltpu.HBM), - pl.BlockSpec(shape[1:], lambda _: (0,))], + out_specs=pl.BlockSpec( + shape[1:], lambda i: (0,), memory_space=pltpu.HBM + ), + in_specs=[ + pl.BlockSpec(shape, lambda *_: (0, 0), memory_space=pltpu.HBM), + pl.BlockSpec(shape[1:], lambda _: (0,)), + ], scratch_shapes=[ pltpu.VMEM_SHARED(shape[1:], dtype), pltpu.VMEM(shape[1:], dtype), @@ -1606,8 +1592,6 @@ def f(x): core_axis_name="core", subcore_axis_name="subcore", num_cores=1 ), scratch_shapes=(pltpu.VMEM(x.shape, x.dtype),), - # compiler_options don't compose well with shard_map... - jax_compiler_options={}, ) def kernel(in_ref, o_ref, scratch_ref): pltpu.sync_copy(in_ref, scratch_ref) @@ -1618,24 +1602,19 @@ def kernel(in_ref, o_ref, scratch_ref): np.testing.assert_array_equal(f(x), x) @parameterized.named_parameters( - ("exp", jnp.exp), ("neg", lambda x: -x), ("abs", jnp.abs)) + ("exp", jnp.exp), ("neg", lambda x: -x), ("abs", jnp.abs) + ) def test_unary_ops(self, op): if not jtu.is_cloud_tpu_at_least(2025, 11, 30): self.skipTest("Test requires a newer libtpu") x = jnp.arange(8, dtype=jnp.float32) - def sc_exp_kernel(x_hbm_ref, out_ref): - out_ref[...] = op(x_hbm_ref[...]) + @self.vector_subcore_kernel(out_shape=x) + def kernel(x_ref, o_ref): + o_ref[...] = op(x_ref[...]) - result = pl.pallas_call( - sc_exp_kernel, - compiler_params=pltpu.CompilerParams( - kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE - ), - out_shape=x, - )(x) - np.testing.assert_array_equal(result, op(x)) + np.testing.assert_array_equal(kernel(x), op(x)) @parameterized.product(dtype=[np.int32, np.float32]) def test_vector_gather(self, dtype): @@ -1757,8 +1736,8 @@ def kernel(*args): np.testing.assert_array_equal(values_result, values_in[perm]) -class VectorSubcoreTestWithTCTiling(TCTilingMixin, VectorSubcoreTest): - pass +class VectorSubcoreTestWithTCTiling(VectorSubcoreTest): + USE_TC_TILING = True class ScalarSubcoreTest(PallasSCTest): @@ -1888,8 +1867,8 @@ def _(j): np.testing.assert_array_equal(kernel(x), x + jnp.arange(1, 9)[:, None]) -class ScalarSubcoreTestWithTCTiling(TCTilingMixin, ScalarSubcoreTest): - pass +class ScalarSubcoreTestWithTCTiling(ScalarSubcoreTest): + USE_TC_TILING = True class PipelineTest(PallasSCTest): @@ -1919,8 +1898,8 @@ def pipeline(x_ref, o_ref): np.testing.assert_array_equal(kernel(x), x + 1) -class PipelineTestWithTCTiling(TCTilingMixin, PipelineTest): - pass +class PipelineTestWithTCTiling(PipelineTest): + USE_TC_TILING = True class PallasSparsecoreAsyncTest(PallasSCTest): @@ -1961,6 +1940,7 @@ def foo(x): compiler_params=pltpu.CompilerParams( dimension_semantics=["core_parallel"], kernel_type=pltpu.KernelType.SC_SCALAR_SUBCORE, + use_tc_tiling_on_sc=self.USE_TC_TILING, ), )() @@ -1984,10 +1964,8 @@ def _(): np.testing.assert_array_equal(o, x) -class PallasSparsecoreAsyncTestWithTCTiling( - TCTilingMixin, PallasSparsecoreAsyncTest -): - pass +class PallasSparsecoreAsyncTestWithTCTiling(PallasSparsecoreAsyncTest): + USE_TC_TILING = True if __name__ == "__main__": From f88f6ecae77d1da1cac4dd71ab543d9bf6623419 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Mon, 15 Dec 2025 10:30:28 -0800 Subject: [PATCH 208/315] Reverts 64a8e0de42681bdc26b89dba30fc2c979869ea82 PiperOrigin-RevId: 844828995 --- .github/workflows/bazel_test_tpu.yml | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/.github/workflows/bazel_test_tpu.yml b/.github/workflows/bazel_test_tpu.yml index 15895f995fd3..6459c45475e0 100644 --- a/.github/workflows/bazel_test_tpu.yml +++ b/.github/workflows/bazel_test_tpu.yml @@ -122,21 +122,8 @@ jobs: mkdir -p $(pwd)/dist $JAXCI_PYTHON -m pip install --upgrade pip echo "Download the wheel into a local directory" - # TODO(ybaturina): Remove this once the libtpu wheel is updated. if [[ "${{ inputs.libtpu-version-type }}" == "nightly" ]]; then - version="" - suffix="" - full_python_version="${{ inputs.python }}" - - if [[ "$full_python_version" == *-* ]]; then - version="${full_python_version%%-*}" - suffix="t" - else - version="$full_python_version" - suffix="" - fi - version_no_dots="${version//./}" - wget -P $(pwd)/dist https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu/libtpu-0.0.31.dev20251209+nightly-cp${version_no_dots}-cp${version_no_dots}${suffix}-manylinux_2_31_x86_64.whl + $JAXCI_PYTHON -m pip download -d $(pwd)/dist --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [[ "${{ inputs.libtpu-version-type }}" == "pypi_latest" ]]; then echo "Using latest libtpu from PyPI" $JAXCI_PYTHON -m pip download -d $(pwd)/dist libtpu From 01c6db463bf73a1592f82cfb2a55039e336f9d89 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 15 Dec 2025 10:53:26 -0800 Subject: [PATCH 209/315] [pxla] Deprecate `jax.interpreters.pxla` symbols. PiperOrigin-RevId: 844838950 --- CHANGELOG.md | 2 + jax/interpreters/pxla.py | 183 +++++++++++++++++++++++------ tests/debugging_primitives_test.py | 5 +- tests/pickle_test.py | 5 +- 4 files changed, 155 insertions(+), 40 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d3c417b81ce..d91111b68482 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `call_impl`, `get_aval`, `mapped_aval`, `subjaxprs`, `set_current_trace`, `take_current_trace`, `traverse_jaxpr_params`, `unmapped_aval`, `AbstractToken`, and `TraceTag`. + * All symbols in {mod}`jax.interpreters.pxla` are deprecated. These are + primarily JAX internal APIs, and users should not rely on them. * Changes: * jax's `Tracer` no longer inherits from `jax.Array` at runtime. However, diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 06c771169443..af179b1905a5 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -12,42 +12,153 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.interpreters.pxla import ( - Index as Index, - MapTracer as MapTracer, - MeshAxisName as MeshAxisName, - MeshComputation as MeshComputation, - MeshExecutable as MeshExecutable, - PmapExecutable as PmapExecutable, - global_aval_to_result_handler as global_aval_to_result_handler, - global_avals_to_results_handler as global_avals_to_results_handler, - global_result_handlers as global_result_handlers, - parallel_callable as parallel_callable, - shard_args as shard_args, - xla_pmap_p as xla_pmap_p, -) -from jax._src.mesh import ( - thread_resources as thread_resources, -) +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 -from jax._src.op_shardings import ( - are_hlo_shardings_equal as are_hlo_shardings_equal, - is_hlo_sharding_replicated as is_hlo_sharding_replicated, - op_sharding_to_indices as op_sharding_to_indices, -) +from jax._src.interpreters import pxla as _deprecated_pxla +from jax._src import mesh as _deprecated_mesh +from jax._src import op_shardings as _deprecated_op_shardings +from jax._src import sharding_impls as _deprecated_sharding_impls +from jax._src import sharding_specs as _deprecated_sharding_specs -from jax._src.sharding_impls import ( - ArrayMapping as ArrayMapping, - UNSPECIFIED as _UNSPECIFIED, # noqa: F401 - array_mapping_to_axis_resources as array_mapping_to_axis_resources, -) +_deprecations = { + # deprecated as of JAX v0.8.2 (Dec 2025) + "Index": ( + "jax.interpreters.pxla.Index is deprecated as of JAX v0.8.2.", + _deprecated_pxla.Index, + ), + "MapTracer": ( + "jax.interpreters.pxla.MapTracer is deprecated as of JAX v0.8.2.", + _deprecated_pxla.MapTracer, + ), + "MeshAxisName": ( + "jax.interpreters.pxla.MeshAxisName is deprecated as of JAX v0.8.2. Use jax.sharding.Mesh axis names directly.", + _deprecated_pxla.MeshAxisName, + ), + "MeshComputation": ( + "jax.interpreters.pxla.MeshComputation is deprecated as of JAX v0.8.2.", + _deprecated_pxla.MeshComputation, + ), + "MeshExecutable": ( + "jax.interpreters.pxla.MeshExecutable is deprecated as of JAX v0.8.2.", + _deprecated_pxla.MeshExecutable, + ), + "PmapExecutable": ( + "jax.interpreters.pxla.PmapExecutable is deprecated as of JAX v0.8.2.", + _deprecated_pxla.PmapExecutable, + ), + "global_aval_to_result_handler": ( + "jax.interpreters.pxla.global_aval_to_result_handler is deprecated as of JAX v0.8.2.", + _deprecated_pxla.global_aval_to_result_handler, + ), + "global_avals_to_results_handler": ( + "jax.interpreters.pxla.global_avals_to_results_handler is deprecated as of JAX v0.8.2.", + _deprecated_pxla.global_avals_to_results_handler, + ), + "global_result_handlers": ( + "jax.interpreters.pxla.global_result_handlers is deprecated as of JAX v0.8.2.", + _deprecated_pxla.global_result_handlers, + ), + "parallel_callable": ( + "jax.interpreters.pxla.parallel_callable is deprecated as of JAX v0.8.2.", + _deprecated_pxla.parallel_callable, + ), + "shard_args": ( + "jax.interpreters.pxla.shard_args is deprecated as of JAX v0.8.2.", + _deprecated_pxla.shard_args, + ), + "xla_pmap_p": ( + "jax.interpreters.pxla.xla_pmap_p is deprecated as of JAX v0.8.2.", + _deprecated_pxla.xla_pmap_p, + ), + "thread_resources": ( + "jax.interpreters.pxla.thread_resources is deprecated as of JAX v0.8.2.", + _deprecated_mesh.thread_resources, + ), + "are_hlo_shardings_equal": ( + "jax.interpreters.pxla.are_hlo_shardings_equal is deprecated as of JAX v0.8.2.", + _deprecated_op_shardings.are_hlo_shardings_equal, + ), + "is_hlo_sharding_replicated": ( + "jax.interpreters.pxla.is_hlo_sharding_replicated is deprecated as of JAX v0.8.2.", + _deprecated_op_shardings.is_hlo_sharding_replicated, + ), + "op_sharding_to_indices": ( + "jax.interpreters.pxla.op_sharding_to_indices is deprecated as of JAX v0.8.2.", + _deprecated_op_shardings.op_sharding_to_indices, + ), + "ArrayMapping": ( + "jax.interpreters.pxla.ArrayMapping is deprecated as of JAX v0.8.2.", + _deprecated_sharding_impls.ArrayMapping, + ), + "_UNSPECIFIED": ( + "jax.interpreters.pxla._UNSPECIFIED is deprecated as of JAX v0.8.2.", + _deprecated_sharding_impls.UNSPECIFIED, + ), + "array_mapping_to_axis_resources": ( + "jax.interpreters.pxla.array_mapping_to_axis_resources is deprecated as of JAX v0.8.2.", + _deprecated_sharding_impls.array_mapping_to_axis_resources, + ), + "Chunked": ( + "jax.interpreters.pxla.Chunked is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.Chunked, + ), + "NoSharding": ( + "jax.interpreters.pxla.NoSharding is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.NoSharding, + ), + "Replicated": ( + "jax.interpreters.pxla.Replicated is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.Replicated, + ), + "ShardedAxis": ( + "jax.interpreters.pxla.ShardedAxis is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.ShardedAxis, + ), + "ShardingSpec": ( + "jax.interpreters.pxla.ShardingSpec is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.ShardingSpec, + ), + "Unstacked": ( + "jax.interpreters.pxla.Unstacked is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.Unstacked, + ), + "spec_to_indices": ( + "jax.interpreters.pxla.spec_to_indices is deprecated as of JAX v0.8.2.", + _deprecated_sharding_specs.spec_to_indices, + ), +} -from jax._src.sharding_specs import ( - Chunked as Chunked, - NoSharding as NoSharding, - Replicated as Replicated, - ShardedAxis as ShardedAxis, - ShardingSpec as ShardingSpec, - Unstacked as Unstacked, - spec_to_indices as spec_to_indices, -) +import typing as _typing +if _typing.TYPE_CHECKING: + Index = _deprecated_pxla.Index + MapTracer = _deprecated_pxla.MapTracer + MeshAxisName = _deprecated_pxla.MeshAxisName + MeshComputation = _deprecated_pxla.MeshComputation + MeshExecutable = _deprecated_pxla.MeshExecutable + PmapExecutable = _deprecated_pxla.PmapExecutable + global_aval_to_result_handler = _deprecated_pxla.global_aval_to_result_handler + global_avals_to_results_handler = _deprecated_pxla.global_avals_to_results_handler + global_result_handlers = _deprecated_pxla.global_result_handlers + parallel_callable = _deprecated_pxla.parallel_callable + shard_args = _deprecated_pxla.shard_args + xla_pmap_p = _deprecated_pxla.xla_pmap_p + thread_resources = _deprecated_mesh.thread_resources + are_hlo_shardings_equal = _deprecated_op_shardings.are_hlo_shardings_equal + is_hlo_sharding_replicated = _deprecated_op_shardings.is_hlo_sharding_replicated + op_sharding_to_indices = _deprecated_op_shardings.op_sharding_to_indices + ArrayMapping = _deprecated_sharding_impls.ArrayMapping + _UNSPECIFIED = _deprecated_sharding_impls.UNSPECIFIED + array_mapping_to_axis_resources = _deprecated_sharding_impls.array_mapping_to_axis_resources + Chunked = _deprecated_sharding_specs.Chunked + NoSharding = _deprecated_sharding_specs.NoSharding + Replicated = _deprecated_sharding_specs.Replicated + ShardedAxis = _deprecated_sharding_specs.ShardedAxis + ShardingSpec = _deprecated_sharding_specs.ShardingSpec + Unstacked = _deprecated_sharding_specs.Unstacked + spec_to_indices = _deprecated_sharding_specs.spec_to_indices +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del _typing diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index e2cb37e4cb66..c05e2d42a9b7 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -20,12 +20,12 @@ from absl.testing import absltest, parameterized import jax from jax import lax -from jax.interpreters import pxla from jax._src import ad_checkpoint from jax._src import config from jax._src import debugging from jax._src import dispatch from jax._src import test_util as jtu +from jax._src.interpreters import pxla from jax.sharding import PartitionSpec as P import jax.numpy as jnp import numpy as np @@ -1106,7 +1106,8 @@ def test_visualize_wide_array(self): """) self.assertEqual(output(), expected) - @jtu.ignore_warning(category=DeprecationWarning) + @jtu.ignore_warning(category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated') def test_visualize_pmap_sharding(self): ss = pxla.ShardingSpec( sharding=(pxla.Unstacked(8),), diff --git a/tests/pickle_test.py b/tests/pickle_test.py index 2f4d677ee0b8..c1466c09058f 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -26,10 +26,10 @@ import jax from jax import numpy as jnp -from jax.interpreters import pxla from jax._src import config from jax._src import literals from jax._src import test_util as jtu +from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc from jax._src.sharding_impls import GSPMDSharding @@ -196,7 +196,8 @@ def test_pickle_single_device_sharding_with_memory_kind(self): ) self.assertEqual(s, pickle.loads(pickle.dumps(s))) - @jtu.ignore_warning(category=DeprecationWarning) + @jtu.ignore_warning(category=DeprecationWarning, + message='jax.sharding.PmapSharding is deprecated') def test_pickle_pmap_sharding(self): ss = pxla.ShardingSpec( sharding=(pxla.Unstacked(8),), From a8983f89dc9dcb873da67bafafd352039d5af6ba Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Mon, 15 Dec 2025 11:57:30 -0800 Subject: [PATCH 210/315] Expose profiler advanced configuration as a Python dict. In profiler.cc, the advanced_configuration property of tensorflow::ProfileOptions is now exposed as a Python dictionary. The getter converts the proto map to a nb::dict, handling different value types (bool, int64, string). Example error: ``` ProfileOptions().advanced_configuration ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: Unable to convert function return value to a Python type! The signature was (self) -> proto2::Map, std::__u::allocator>, tensorflow::ProfileOptions_AdvancedConfigValue> ``` PiperOrigin-RevId: 844865140 --- tests/profiler_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 3088ced9872a..0db0e2dffab7 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -32,6 +32,7 @@ import jax._src.test_util as jtu from jax._src import profiler +from jax._src.lib import ifrt_version from jax import jit @@ -508,5 +509,20 @@ def on_profile(): unittest.mock.ANY, ) + def test_advanced_configuration_getter(self): + if ifrt_version < 41: + self.skipTest("advanced_configuration getter is newly added") + + options = jax.profiler.ProfileOptions() + advanced_config = { + "tpu_trace_mode": "TRACE_COMPUTE", + "tpu_num_sparse_cores_to_trace": 1, + "enableFwThrottleEvent": True, + } + options.advanced_configuration = advanced_config + returned_config = options.advanced_configuration + self.assertDictEqual(returned_config, advanced_config) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From b4a3e160febd21050b6921724b99a4d4e7212b13 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 15 Dec 2025 12:00:02 -0800 Subject: [PATCH 211/315] Delete si_vjp from JAX. Dead weight at this point. It is now replaced with jax.vjp. PiperOrigin-RevId: 844865984 --- CHANGELOG.md | 2 + jax/_src/api.py | 133 +++++++---------------------------- jax/experimental/__init__.py | 4 -- tests/api_test.py | 85 +++++++--------------- 4 files changed, 54 insertions(+), 170 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d91111b68482..097a98bc1dcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. For the moment, during Python type checking, we continue to declare `Tracer` as a subclass of `Array`, however we expect to remove this in a future release. + * `jax.experimental.si_vjp` has been deleted. + `jax.vjp` subsumes it's functionality. ## JAX 0.8.1 (November 18, 2025) diff --git a/jax/_src/api.py b/jax/_src/api.py index e91e3e8a3e18..635e1d09195b 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2212,106 +2212,12 @@ def vjp( fun, debug_info=debug_info("vjp", fun, primals, {})) return _vjp(wrapped_fun, *primals, has_aux=has_aux) -@partial(api_boundary, repro_api_name="jax.experimental.saved_input_vjp") -def saved_input_vjp(f: Callable, which: Sequence[bool], *primals, - allow_unused: bool = True, allow_opaque: bool = True): - if len(which) != len(primals): - raise ValueError( - "length of 'which' argument must equal the number of primal input values, " - f"but got {len(which)=} and {len(primals)=}") - - dbg = debug_info("saved_input_vjp", f, primals, {}) - fun = lu.wrap_init(f, debug_info=dbg) - primals_flat, in_tree = tree_flatten(primals) - fun, out_tree = flatten_fun_nokwargs(fun, in_tree) - out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(fun, *primals_flat) - out_known = [pval.is_known() for pval in out_pvals] - primals_filt, filt_tree = tree_flatten(tuple(p for w, p in zip(which, primals) if w)) - id_map = {id(x): i for i, x in enumerate(primals_filt)} - opaque_residuals = [] - res_spec = [RSpec(id_map[id(r)], True) if id(r) in id_map else - RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore - for r in residuals] - out_primal_avals = map(shaped_abstractify, out_primals_flat) - f_vjp = Partial(partial(_saved_input_vjpfun, res_spec, filt_tree, in_tree, - out_tree(), out_known, jaxpr, out_primal_avals), - opaque_residuals) - - if not allow_unused and not set(id_map).issubset(res_ids := {id(r) for r in residuals}): - unused = [(i, core.get_aval(x)) for i, (x, w) in enumerate(zip(primals, which)) - if w and id(x) not in res_ids] - assert unused - if len(unused) == 1: - (i, a), = unused - start, was = "an input value", "was" - msg = f" {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'} of type {a.str_short()}" - else: - start, was = "multiple input values", "were" - msg = "\n" + "\n".join(f" * {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'} of type {a.str_short()}" - for i, a in unused) - raise Exception(f"with {allow_unused=}, {start} marked to be saved {was} " - f"not used by the backward pass:{msg}") - - if not allow_opaque and opaque_residuals: - msg = ", ".join(core.get_aval(x).str_short() for x in opaque_residuals) - raise Exception(f"with {allow_opaque=}, the backward pass requires opaque " - f"(non-input) residuals: {msg}") - - out_primals = tree_unflatten(out_tree(), out_primals_flat) - return out_primals, f_vjp - -def _saved_input_vjpfun(res_spec, filtered_tree, in_tree, out_tree, out_known, - jaxpr, out_primal_avals, opaque_residuals, ct, - *saved_primals): - primals_filtered, filtered_tree_ = tree_flatten(saved_primals) - if filtered_tree != filtered_tree_: - raise ValueError( - "inputs passed to f_vjp must be a tuple of (pytrees of) " - "arrays with the same structure as\n" - " tuple(x for x, w in zip(inputs, which) if w)\n" - "given the original call\n" - " _, f_vjp = saved_input_vjp(f, which, *inputs, ...)\n" - "but the structures differ:\n" + - "\n".join(f" * inputs{keystr(path)} was a {thing1} in the original " - f"call, but a {thing2} here, so {explanation}" - for path, thing1, thing2, explanation - in equality_errors_pytreedef(filtered_tree, filtered_tree_))) - - residuals = [primals_filtered[i.idx] if i.primal else opaque_residuals[i.idx] - for i in res_spec] - dummy_args = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars] - cts_flat, out_tree_ = tree_flatten(ct) - if out_tree_ != out_tree: - raise ValueError(f"unexpected tree structure of argument to vjp function: " - f"got {out_tree_}, but expected to match {out_tree}") - for arg, aval in zip(cts_flat, out_primal_avals): - ct_aval = shaped_abstractify(arg) - ct_aval_expected = aval.to_cotangent_aval() - if (not core.typecompat(ct_aval, ct_aval_expected) and - not _temporary_dtype_exception(ct_aval, ct_aval_expected)): - raise ValueError( - "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: " - f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} " - f"because the corresponding output of the function had JAX type " - f"{aval.str_short()}") - - cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k] - arg_cts = ad.backward_pass(jaxpr, True, residuals, dummy_args, cts_flat) - return tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts)) - -@dataclasses.dataclass(frozen=True) -class RSpec: - idx: int - primal: bool - -si_vjp = saved_input_vjp - - def _vjp(fun, *primals, has_aux=False): canon = lambda x: x if isinstance(x, core.Tracer) else canonicalize_value(x) primals = tree_map(canon, primals) primals_flat, in_tree = tree_flatten(primals) - for arg in primals_flat: dispatch.check_arg(arg) + for arg in primals_flat: + dispatch.check_arg(arg) if not has_aux: flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize( @@ -2340,22 +2246,14 @@ def _vjp(fun, *primals, has_aux=False): else: return out_primals, f_vjp, tree_unflatten(aux_tree, aux) -def tuptree_map(f, treedef, x): - return treedef.walk(lambda xs, _: tuple(xs), f, x) - - -def _is_ref(x): - from jax._src.state.types import AbstractRef - try: return isinstance(typeof(x), AbstractRef) - except: return False - def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree, args_res, opaque_res, *maybe_ct_refs): if not maybe_ct_refs: maybe_ct_refs_flat = [GradValue()] * in_tree.num_leaves else: maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs) - if in_tree != in_tree_: raise Exception # TODO accept isomorph tuple tree + if in_tree != in_tree_: + raise Exception # TODO accept isomorph tuple tree args_res_ = tree_leaves(args_res, is_leaf=lambda x: isinstance(x, NotNeeded)) residuals = [args_res_[i.idx] if i.primal else opaque_res[i.idx] for i in spec] maybe_refs = [ad.RefAccum(v.aval, x) if _is_ref(x) else ad.ValAccum(v.aval) @@ -2366,7 +2264,8 @@ def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree, def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals, maybe_refs, out_ct): cts_flat, out_tree_ = tree_flatten(out_ct) - if out_tree != out_tree_: _vjp_ct_tree_error(jaxpr, out_tree, out_tree_) + if out_tree != out_tree_: + _vjp_ct_tree_error(jaxpr, out_tree, out_tree_) _vjp_check_ct_avals(cts_flat, out_primal_avals) cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k] ad.backward_pass3(jaxpr, True, residuals, maybe_refs, cts_flat) @@ -2375,6 +2274,23 @@ def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals, arg_cts = map(ad.instantiate_zeros, arg_cts) return tree_unflatten(in_tree, arg_cts) + +@dataclasses.dataclass(frozen=True) +class RSpec: + idx: int + primal: bool + +def tuptree_map(f, treedef, x): + return treedef.walk(lambda xs, _: tuple(xs), f, x) + +def _is_ref(x): + from jax._src.state.types import AbstractRef + try: + return isinstance(typeof(x), AbstractRef) + except: + return False + + _vjp_too_many_args = """ The function returned by `jax.vjp` applied to {} was called with {} arguments, but functions returned by `jax.vjp` must be called with a single argument @@ -2396,6 +2312,7 @@ def f(x): arguments rather than in a tuple, this error can arise. """.format + def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree): msg = f"""unexpected tree structure. @@ -2410,6 +2327,7 @@ def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree): in equality_errors_pytreedef(out_tree, ct_tree)) raise ValueError(msg) + def _vjp_check_ct_avals(cts, primal_avals): # TODO(mattjj): improve this error by flattening with keys in the first place for ct, aval in zip(cts, primal_avals): @@ -2425,6 +2343,7 @@ def _vjp_check_ct_avals(cts, primal_avals): "because the corresponding output of the differentiated function had JAX type " f"{aval.str_short()}") + @register_dataclass @dataclasses.dataclass(frozen=True) class NotNeeded: diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index 474f9ce5f675..606033c43ac7 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -21,10 +21,6 @@ # experimental features and as a result, more flexibility to manage their status # and lifetimes. -from jax._src.api import ( - saved_input_vjp as saved_input_vjp, - si_vjp as si_vjp -) from jax._src.callback import ( io_callback as io_callback ) diff --git a/tests/api_test.py b/tests/api_test.py index 547ef9c22493..020b32ec192c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -7769,9 +7769,12 @@ def test_basic(self): def f(x, y): return x * y - primals = 2., 3. - y, f_vjp = api.si_vjp(f, [True, True], *primals) - arg_cts = f_vjp(1., *primals) + primals = [2., 3.] + y, f_vjp = jax.vjp(f, *primals) + f_vjp.args_res = [None, None] + y_grad = 1. + f_vjp.args_res = primals + arg_cts = f_vjp(1.) self.assertAllClose(y, 6.) self.assertAllClose(arg_cts, (3., 2.)) @@ -7782,29 +7785,20 @@ def f(x, y): @jax.jit def g(): primals = 2., 3. - y, f_vjp = api.si_vjp(f, [True, True], *primals) + y, f_vjp = jax.vjp(f, *primals) + f_vjp.args_res = [None, None] return y, f_vjp @jax.jit def h(f_vjp): - return f_vjp(1., 2., 3.) + f_vjp.args_res = [2., 3.] + return f_vjp(1.) y, f_vjp = g() arg_cts = h(f_vjp) self.assertAllClose(y, 6.) self.assertAllClose(arg_cts, (3., 2.)) - def test_basic_unused(self): - f = jnp.sin - primals = 3., - y, f_vjp = api.si_vjp(f, [True], *primals) - x_ct, = f_vjp(1., *primals) - self.assertAllClose(y, jnp.sin(3.)) - self.assertAllClose(x_ct, jnp.cos(3.)) - - with self.assertRaisesRegex(Exception, "not used by the backward pass: x"): - _ = api.si_vjp(f, [True], *primals, allow_unused=False) - def test_basic_unused_vjp3(self): f = jnp.sin primals = 3., @@ -7814,58 +7808,28 @@ def test_basic_unused_vjp3(self): self.assertAllClose(x_ct, jnp.cos(3.)) self.assertIsInstance(f_vjp.args_res[0], api.NotNeeded) # can check if unused - def test_basic_opaque(self): - f = jnp.sin - primals = 3., - with self.assertRaisesRegex(Exception, "the backward pass requires opaque"): - _ = api.si_vjp(f, [True], *primals, allow_opaque=False) - def test_basic_opaque_vjp3(self): f = jnp.sin primals = 3., _, f_vjp = api.vjp(f, *primals) - assert f_vjp.opaque_residuals # can detect if opaque res are used + self.assertTrue(f_vjp.opaque_residuals) # can detect if opaque res are used def test_basic_pytree_error(self): def f(x): return [x['hi'] * x['bye']] - y, f_vjp = api.si_vjp(f, [True], {'hi': 2., 'bye': 3.}) - arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.}) + y, f_vjp = jax.vjp(f, {'hi': 2., 'bye': 3.}) + f_vjp.args_res = [None] + y_grad = [1.] + f_vjp.args_res = [{'hi': 2., 'bye': 3.}] + arg_ct, = f_vjp(y_grad) self.assertAllClose(y, [6.]) self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.}) - with self.assertRaisesRegex(ValueError, "but the structures differ"): - f_vjp(1., {'hi': 2.}) - - # TODO(mattjj): improve this vjp3 error message - # def test_basic_pytree_error_vjp3(self): - # def f(x): - # return [x['hi'] * x['bye']] - - # y, f_vjp = api.vjp(f, {'hi': 2., 'bye': 3.}) - # arg_ct, = f_vjp([1.], {'hi': 2., 'bye': 3.}) - # self.assertAllClose(y, [6.]) - # self.assertAllClose(arg_ct, {'hi': 3., 'bye': 2.}) - - # f_vjp.args_res[0] = {'hi': 2.} - # with self.assertRaisesRegex(ValueError, "but the structures differ"): - # f_vjp(1.) - - def test_fsdp(self): - # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" - def f2(x, w): - x = 1. * x - x = x @ w - x = 2. * x - return x - - x = jnp.ones((3, 4)) - w = jnp.ones((4, 4)) - y, f2_sivjp = api.si_vjp(f2, [False, True], x, w) - y_grad = jnp.ones_like(y) - x_grad, w_grad = f2_sivjp(y_grad, w) - self.assertAllClose(x_grad, 2. * y_grad @ w.T) + # TODO(mattjj): Raise an error message. + # with self.assertRaisesRegex(ValueError, "but the structures differ"): + # f_vjp.args_res = [{'hi': 2.}] + # f_vjp([1.]) def test_fsdp_error(self): # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" @@ -7877,10 +7841,12 @@ def f2(x, w): x = jnp.ones((3, 4)) w = jnp.ones((4, 4)) - y, f2_sivjp = api.si_vjp(f2, [False, True], x, w) + y, f2_vjp = jax.vjp(f2, x, w) + f2_vjp.args_res[1] = None y_grad = jnp.ones((2, 4)) + f2_vjp.args_res[1] = w with self.assertRaisesRegex(ValueError, "unexpected JAX type"): - f2_sivjp(y_grad, w) + f2_vjp(y_grad) def test_fsdp_vjp3(self): # see https://github.com/jax-ml/jax/pull/27017 for why this is called "fsdp" @@ -7902,10 +7868,11 @@ def f2(x, w): self.assertAllClose(w_grad, 2. * x.T @ y_grad) def test_doesnt_leak_symbolic_zeros(self): - _, vjp = api.si_vjp(lambda x: 1., [False], 3.14) + _, vjp = jax.vjp(lambda x: 1., 3.14) ans, = vjp(1.0) self.assertIsInstance(ans, jax.Array) + class TracebackTest(jtu.JaxTestCase): # These tests are to catch regressions in Python traceback sizes. Our # second-order APIs can be nested arbitrarily and if each one adds a dozen From a4bd50b0cc51af301868191548db0f88fcefba7e Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Mon, 15 Dec 2025 17:27:43 -0500 Subject: [PATCH 212/315] Add batch_size=0 support to jax.lax.map. --- jax/_src/lax/control_flow/loops.py | 7 ++++++- tests/custom_api_test.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index d46d100da0b2..ed996db7e43e 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2708,7 +2708,10 @@ def _batch_and_remainder(x, batch_size: int): leaves, treedef = tree_flatten(x) if not leaves: return x, None - num_batches, remainder = divmod(leaves[0].shape[0], batch_size) + if batch_size == 0: + num_batches, remainder = 0, leaves[0].shape[0] + else: + num_batches, remainder = divmod(leaves[0].shape[0], batch_size) batch_elems = num_batches * batch_size if num_batches == 0: remainder_leaves = [_remainder_leaf(leaf, batch_elems) for leaf in leaves] @@ -2749,6 +2752,8 @@ def map(f, xs): divisible by the batch size, the remainder is processed in a separate ``vmap`` and concatenated to the result. + ``batch_size=0`` is equivalent to applying a ``vmap``. That is, it uses a full batch. + >>> x = jnp.ones((10, 3, 4)) >>> def f(x): ... print('inner shape:', x.shape) diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index f219cefc2fa4..40b7113a55a9 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -4591,6 +4591,7 @@ def vmap_ref(xs, y): self.assertEqual(str(jaxpr), str(jaxpr_ref)) @parameterized.named_parameters( + ("0", 0), ("1", 1), ("8", 4), ("12", 8), @@ -4607,6 +4608,7 @@ def f(x): np.testing.assert_array_equal(y, x**2) @parameterized.named_parameters( + ("0", 0), ("1", 1), ("8", 4), ("12", 8), From 52ee821e3342eba21058a0a52230e78c703e30f5 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 15 Dec 2025 16:40:35 -0800 Subject: [PATCH 213/315] [Pallas TPU] Add mesh axis info to pallas call metadata Note that this only works if you pass in device_id as a dict. PiperOrigin-RevId: 844969870 --- .../pallas/mosaic/pallas_call_registration.py | 21 +++++- jax/_src/pallas/mosaic/primitives.py | 16 ++++- jax/_src/pallas/primitives.py | 18 +++-- tests/pallas/tpu_pallas_distributed_test.py | 70 +++++++++++++++++++ 4 files changed, 117 insertions(+), 8 deletions(-) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 1ab2c9f40459..125eafa4eaf1 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -18,6 +18,7 @@ from collections.abc import Sequence import dataclasses +import json from typing import cast import jax @@ -273,6 +274,24 @@ def _maybe_cast_inputs(*args): if mosaic_params.use_tc_tiling_on_sc else tpu_custom_call.Tiling.SPARSE_CORE ) + dict_metadata = dict(metadata) if metadata is not None else {} + del metadata + if jax_mesh is not None: + mesh_axes = { + e.name + for e in jaxpr.effects + if isinstance(e, jax_core.NamedAxisEffect) + # Filter for only device mesh axis name effects + and e.name in jax_mesh.axis_names + } + # Only put mesh axes in metadata if there are any. + if mesh_axes: + if "mesh_axes" in dict_metadata: + raise ValueError("Metadata already contains mesh axes.") + mesh_axes_list = list(mesh_axes) + if all(isinstance(a, str) for a in mesh_axes): + mesh_axes_list = sorted(mesh_axes) # type: ignore + dict_metadata["mesh_axes"] = json.dumps(mesh_axes_list) out_nodes = mosaic.lower_module_to_custom_call( kernel_ctx, *dynamic_grid_args, @@ -292,7 +311,7 @@ def _maybe_cast_inputs(*args): output_memory_spaces=output_memory_spaces, disable_bounds_checks=mosaic_params.disable_bounds_checks, input_memory_spaces=input_memory_spaces, - metadata=dict(metadata) if metadata is not None else None, + metadata=dict_metadata, skip_device_barrier=mosaic_params.skip_device_barrier, allow_collective_id_without_custom_barrier=mosaic_params.allow_collective_id_without_custom_barrier, shape_invariant_numerics=mosaic_params.shape_invariant_numerics, diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 47a107368f96..5be4028441db 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -362,6 +362,8 @@ def _get_dma_effects( dst_transforms_avals, dst_sem_transforms_avals, src_sem_aval, + device_id_aval, + device_id_type, ): n_src_transforms = len(tree_util.tree_leaves(src_transforms_avals)) n_dst_transforms = len(tree_util.tree_leaves(dst_transforms_avals)) @@ -377,6 +379,15 @@ def _get_dma_effects( 1 + n_src_transforms + 1 + n_dst_transforms + 1 + n_dst_sem_transforms ) effs.add(state.WriteEffect(src_sem_index)) + if device_id_aval is not None: + if device_id_type is primitives.DeviceIdType.MESH and isinstance( + device_id_aval, dict + ): + for k in device_id_aval: + if not isinstance(k, tuple): + k = (k,) + for k_ in k: + effs.add(jax_core.NamedAxisEffect(k_)) return effs @@ -471,6 +482,8 @@ def _dma_start_abstract_eval(*args, tree, device_id_type, priority, add): dst_transforms_avals, dst_sem_transforms_avals, src_sem_aval, + device_id_aval, + device_id_type, ) def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, @@ -734,7 +747,6 @@ def _dma_wait_to_lojax(*args, tree, device_id_type): @dma_wait_p.def_effectful_abstract_eval def _dma_wait_abstract_eval(*args, tree, device_id_type): - del device_id_type ( src_ref_aval, src_transforms_avals, @@ -751,6 +763,8 @@ def _dma_wait_abstract_eval(*args, tree, device_id_type): dst_transforms_avals, dst_sem_transforms_avals, src_sem_aval, + device_id_aval, + device_id_type, ) def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 4ae79d0769e6..8cf362d0bc3b 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -1084,7 +1084,7 @@ def check_sem_avals( ): raise ValueError( f"Must {name} semaphores of the following types:" - f" {allowed_semaphore_types}." + f" {allowed_semaphore_types}. Got {sem_dtype}." ) @@ -1191,26 +1191,32 @@ def _semaphore_signal_abstract_eval( args_tree, device_id_type: DeviceIdType, ): - del device_id_type ( sem_aval, sem_transforms_avals, value_aval, - device_id_avals, + device_id_aval, core_index_aval, ) = tree_util.tree_unflatten(args_tree, avals) check_sem_avals(sem_aval, sem_transforms_avals, "signal") if value_aval.dtype != jnp.dtype("int32"): raise ValueError(f"Must signal an int32 value, but got {value_aval.dtype}") effs : set[effects.Effect] = set() - if device_id_avals is not None: - device_id_flat_avals = tree_util.tree_leaves(device_id_avals) + if device_id_aval is not None: + device_id_flat_avals = tree_util.tree_leaves(device_id_aval) for aval in device_id_flat_avals: if aval.dtype != jnp.dtype("int32"): raise ValueError( f"`device_id`s must be an int32 value, but got {aval.dtype}" ) - effs.add(pallas_core.comms_effect) + if device_id_type is DeviceIdType.MESH and isinstance(device_id_aval, dict): + for k in device_id_aval: + if not isinstance(k, tuple): + k = (k,) + for k_ in k: + effs.add(jax_core.NamedAxisEffect(k_)) + else: + effs.add(pallas_core.comms_effect) return [], effs def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index bfc4e31f08cc..029ef28de555 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import json import os import tempfile from absl.testing import absltest @@ -810,5 +811,74 @@ def _(i, _): self.assertNotEmpty(os.listdir(tmpdir)) +class PallasKernelMetadataDistributedTest(parameterized.TestCase): + + @parameterized.product( + axis_names=[['x', 'y'], [('x', 'y')], ['x'], ['y']], + op=['copy', 'signal'], + ) + def test_mesh_axes_metadata_is_preserved(self, axis_names, op): + if not jtu.is_device_tpu_at_least(4): + self.skipTest('Remote async copy only supported on TPU v4+') + if len(jax.devices()) < 4: + self.skipTest('Not enough devices') + devices = np.array(jax.devices()[:4]).reshape((2, 2)) + mesh = jax.sharding.Mesh(devices, ('x', 'y')) + + def kernel(x_ref, out_ref): + def body(send_sem, recv_sem, sem): + if len(jax.tree.leaves(axis_names)) > 0: + device_id = {a: 0 for a in axis_names} + if op == 'copy': + pltpu.async_remote_copy( + x_ref, + out_ref, + send_sem, + recv_sem, + device_id=device_id, + ).wait() + else: + pl.semaphore_signal(sem, device_id=device_id) + else: + out_ref[...] = x_ref[...] + pl.run_scoped( + body, + send_sem=pltpu.SemaphoreType.DMA, + recv_sem=pltpu.SemaphoreType.DMA, + sem=pltpu.SemaphoreType.REGULAR, + ) + + @functools.partial( + jax.jit, + out_shardings=jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('x', 'y') + ), + ) + @functools.partial( + jax.shard_map, + mesh=mesh, + in_specs=jax.sharding.PartitionSpec('x', 'y'), + out_specs=jax.sharding.PartitionSpec('x', 'y'), + check_vma=False, + ) + def f(x): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1, 1, 1, 128), jnp.float32), + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + )(x) + + x = jnp.zeros((2, 2, 1, 128), dtype=jnp.float32) + hlo = f.lower(x).compile().as_text() + axis_names_text = json.dumps( + json.dumps(sorted(jax.tree.leaves(axis_names))) + ) + self.assertIn( + f'"mesh_axes":{axis_names_text}', + hlo, + ) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From e914ced652ebc17012eb29ad2ee5c4ca5524f2e1 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Mon, 15 Dec 2025 17:13:17 -0800 Subject: [PATCH 214/315] [PjRt-IFRT] Create `ifrt::PjRtExecutable` only from `ifrt::PjRtCompiler` and `CompileOnlyIfrtCompiler` This change migrates direct calls to `ifrt::PjRtExecutable::Create()` outside to use a public IFRT API `ifrt::PjRtCompiler::Compile()` instead. This change should be no-op in practice. For PjRt-IFRT, it now performs IFRT device ID -> PjRt device ID conversion in `xla::CompileOptions::executable_build_options` (which was missing before) and thus can handle a client using a different device ID mapping. PiperOrigin-RevId: 844980890 --- jaxlib/py_client.cc | 18 ++++++++++++------ jaxlib/py_compile_only_client.cc | 9 ++++++++- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/jaxlib/py_client.cc b/jaxlib/py_client.cc index 9be3e8598a36..6e7bbfddbb5e 100644 --- a/jaxlib/py_client.cc +++ b/jaxlib/py_client.cc @@ -499,27 +499,33 @@ PyClient::CompileAndLoadIfrtProgram( ifrt::DeviceListRef executable_devices, xla::CompileOptions options) { mlir::OwningOpRef clone(module.clone()); module = *clone; - ifrt::ExecutableRef executable_ref; + ifrt::ExecutableRef ifrt_executable; { TF_ASSIGN_OR_RETURN( auto topology, client->ifrt_client()->GetTopologyForDevices(executable_devices)); auto xla_options = std::make_unique( options, std::move(executable_devices)); -#if JAX_IFRT_VERSION_NUMBER >= 38 +#if JAX_IFRT_VERSION_NUMBER >= 42 TF_ASSIGN_OR_RETURN( - executable_ref, + ifrt_executable, + client->ifrt_client()->GetDefaultCompiler()->Compile( + std::make_unique(std::move(module)), + *topology, std::move(xla_options))); +#elif JAX_IFRT_VERSION_NUMBER >= 38 + TF_ASSIGN_OR_RETURN( + ifrt_executable, ifrt::PjRtExecutable::Create(std::move(module), std::move(options), *topology->description())); #else TF_ASSIGN_OR_RETURN( auto pjrt_executable, PjRtCompile(std::move(options), module, *topology->description())); - TF_ASSIGN_OR_RETURN(executable_ref, ifrt::PjRtExecutable::Create( - std::move(pjrt_executable))); + TF_ASSIGN_OR_RETURN(ifrt_executable, ifrt::PjRtExecutable::Create( + std::move(pjrt_executable))); #endif } - return make_nb_class(executable_ref); + return make_nb_class(ifrt_executable); } /* static */ absl::StatusOr> diff --git a/jaxlib/py_compile_only_client.cc b/jaxlib/py_compile_only_client.cc index d38f19e35c83..ca7029c7857b 100644 --- a/jaxlib/py_compile_only_client.cc +++ b/jaxlib/py_compile_only_client.cc @@ -40,6 +40,7 @@ limitations under the License. #include "xla/python/compile_only_ifrt/client.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" @@ -79,7 +80,13 @@ absl::StatusOr> CompileOnlyPyClient::CompileUnloaded( auto xla_options = std::make_unique( options, std::move(executable_devices)); -#if JAX_IFRT_VERSION_NUMBER >= 38 +#if JAX_IFRT_VERSION_NUMBER >= 42 + TF_ASSIGN_OR_RETURN( + ifrt_executable, + ifrt_client->GetDefaultCompiler()->Compile( + std::make_unique(std::move(module)), + ifrt_client->topology(), std::move(xla_options))); +#elif JAX_IFRT_VERSION_NUMBER >= 38 TF_ASSIGN_OR_RETURN( ifrt_executable, ifrt::PjRtExecutable::Create(std::move(module), std::move(options), From 6395b5f859b229bb31e146bc3a5a67f78ae5bb2e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 15 Dec 2025 18:29:15 -0800 Subject: [PATCH 215/315] Remove jax_custom_vjp_disable_shape_check config option PiperOrigin-RevId: 845005784 --- jax/_src/config.py | 7 ------- jax/_src/custom_derivatives.py | 6 +----- tests/custom_api_test.py | 17 ----------------- 3 files changed, 1 insertion(+), 29 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 469eeaf7a873..85d563157939 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1819,13 +1819,6 @@ def _validate_default_device(val): upgrade=False, help='Temporary workaround to disable an error check in vmap-of-shmap.') -# TODO(mattjj): remove once we land mutable array plumbing, or face great shame -custom_vjp_disable_shape_check = bool_state( - name='jax_custom_vjp_disable_shape_check', - default=False, - upgrade=True, - help='Disable the check from #19009 to enable some custom_vjp hacks.') - mutable_array_checks = bool_state( name='jax_mutable_array_checks', default=True, diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index ed9b051d5777..4f5ba126ea75 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -966,8 +966,7 @@ def append(x, d): else: if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct)) and not _ref_typecompat(a.to_tangent_aval(), a_) and - not (_temporary_dtype_exception(a, a_) or - _temporary_shape_exception(a, a_))): + not _temporary_dtype_exception(a, a_)): msg = ("Custom VJP bwd rule must produce an output with the same " "shape/dtypes as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " @@ -990,9 +989,6 @@ def _temporary_dtype_exception(a, a_) -> bool: dtypes.issubdtype(a.dtype, dtypes.np.inexact))) return False -# TODO(mattjj): remove both these exceptions to cotangent compatibility check -def _temporary_shape_exception(a, a_) -> bool: - return config.custom_vjp_disable_shape_check.value class CustomVJPCallPrimitive(core.Primitive): multiple_results = True diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 40b7113a55a9..9a5b762cefe7 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -2971,23 +2971,6 @@ def foo_bwd(_, g): r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'): jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) - def test_bwd_rule_shape_mismatch_disable(self): - # TODO(mattjj): remove this test when the config option is removed - @jax.custom_vjp - def foo(x, y): - return x - - def foo_fwd(x, y): - return x, None - - def foo_bwd(_, g): - return jnp.zeros(3), jnp.zeros(3) - - foo.defvjp(foo_fwd, foo_bwd) - - with config.custom_vjp_disable_shape_check(True): - jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) - def test_bwd_rule_can_produce_list_or_tuple(self): @jax.custom_vjp def f(x, y): From d5df1927053ce427ea26560a727a8729e684584e Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 15 Dec 2025 22:10:37 -0800 Subject: [PATCH 216/315] [pmap] Add more detailed documentation about `int` array indexing in JAX. PiperOrigin-RevId: 845080355 --- docs/migrate_pmap.md | 167 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 150 insertions(+), 17 deletions(-) diff --git a/docs/migrate_pmap.md b/docs/migrate_pmap.md index d48aa8fb28cc..be0080577c7e 100644 --- a/docs/migrate_pmap.md +++ b/docs/migrate_pmap.md @@ -92,6 +92,49 @@ Mesh('y': 4, axis_types=(Auto,)) ## Performance implications +### `int` indexing into sharded arrays + +The new implementation of `jax.pmap` uses `NamedSharding` instead of the legacy +`PmapSharding`. We've observe a common pattern with the old `jax.pmap` where +users shard stacked copies of an array to replicate (e.g., via +`jax.device_put_replicated`). These "sharded-but-really-replicated" arrays +suffer unnecessary communication overhead when `int` indexing (e.g., `x[0]`) +because JAX does not know the arrays are actually replicated. For a more +thorough discussion, please see [Appendix A](#appendix-a). + +#### Option 1: Prevent unintended sharding (recommended) +Avoid creating the leading sharded dimension entirely. + +- Use `jax.pmap`'s `out_axes=None` for arguments that should remain replicated. +The output will be fully replicated (e.g., `P(None, None)`), making access +cheap. +- For inputs: When using `jax.device_put`, specify `jax.P()` (fully replicated) +in the partition spec rather than relying on utilities that stack and shard. +(Note: `jax.device_put_replicated` and `jax.device_put_sharded` are deprecated +because they confusingly produce sharded arrays rather than replicated ones). + +#### Option 2: Access local data directly +If you must work with a sharded array (or want potentially fewer changes to +code), you can access the local data shard directly without triggering JAX's +distributed consistency checks. Note that this is only recommended when bringing +data back to host (e.g., for logging, checkpointing). Instead of `x[0]`, use +`addressable_shards`: + +```python +# Old slow way: +# result = x[0] + +# New fast way: +# x.addressable_shards is a list of shards on the current process. +# We grab the first one, extract the data, and remove the leading dimension. +result = x.addressable_shards[0].data.squeeze(0) +``` + +In the example of `x` with shape `(8, 3, 4)`, `x.addressable_shards[0].data` +returns the local chunk of shape `(1, 3, 4)`. Calling `.squeeze(0)` results in +the desired `(3, 4)` shape without any cross-device communication. Both +solutions will eliminate the `_gather` operations seen in profiling. + ### Host local array to global array round-trip conversion In multi-process JAX programs (i.e., `jax.process_count() > 1`), arrays might be @@ -104,23 +147,6 @@ host-local array when returning to user code. This round-trip conversion cannot be avoided, so if the performance penalty is too great, we recommend migrating your code to `jax.shard_map`. -### `int` array indexing - -Indexing into a sharded array with an int (e.g., `arr[0]`) may now execute a -rank reduction computation. Depending on your use case, there may be -workarounds: - -1. In a typical training loop, we might use a `jax.pmap`ed update function to - operate on / carry training state and grab resulting metrics from the first - `jax.pmap`'ed device for logging. In this case, it may be possible to - use `None` for the relevant `in_axes` and `out_axes` passed to `jax.pmap`. - This lets `jax.pmap` handle replication and will return an - appropriately-shaped result that looks like it's from a single device for, - say, logging metrics. -2. More generally, you can get the first shard of data without a reshape via - `arr[0:1]` or `arr.addressable_shards[0].data`. Note that this will have a - leading `(1,)` dimension that your code will need to handle. - ## Migrating to `jax.shard_map` In many cases, users can migrate from `jax.pmap` to `jax.jit(jax.shard_map)` by @@ -132,4 +158,111 @@ dispatch path as in the `jax.shard_map` implementation of `jax.pmap` and can often be overlapped with compute or be called infrequently (i.e., before a train loop and for occasionally grabbing metrics). +(appendix-a)= +## Appendix A: More details about `int` indexing into sharded arrays. + +### What should `x[0]` return? + +In **NumPy**, `x[0]` returns a rank-reduced array representing the first slice +along the first dimension. For example, if `x = np.ones((8, 3, 4))`, then `x[0]` +returns an array of shape `(3, 4)`. + +In **JAX** (`jax.numpy`), `x[0]` semantically works the same way: it returns the +rank-reduced slice of the logical array `x`. However, performance depends on how +`x` is sharded or replicated across devices. Consider an array `x` with shape +`(8, 3, 4)` distributed across 8 devices (using `jax.P` as the short name for +`jax.sharding.PartitionSpec`P): + +1. **Fully Replicated:** `jax.P(None, None, None)` + If `x` is fully replicated, every device holds a complete copy of the `(8, + 3, 4)` array. `x[0]` will have the shape `(3, 4)` and a partition spec + `jax.P(None, None)`. Since every device already has `x`, this operation will + slice on each device independently and requires **no communication**. + +2. **Sharded on Non-Leading Dimension:** `jax.P(None, 'x', None)` + If `x` is sharded along the second dimension, `x[0]` results in shape `(3, + 4)` with partition spec `jax.P('x', None)`. Since the first dimension (the + one being sliced) is unsharded, this operation also requires **no + communication**. + +3. **Sharded on Leading Dimension:** `jax.P('x', None, None)` + If `x` is sharded along the first dimension, `x[0]` results in shape `(3, + 4)` with partition spec `jax.P(None, None)`. + * **The Issue:** Because the first dimension is sharded, the data for + `x[0]` physically resides *only* on the first device. To satisfy the + output sharding `jax.P(None, None)` (which implies replication), JAX + must broadcast the data from the first device to all other devices. This + requires **communication**; JAX will gather the *entire* array of shape + `(8, 3, 4)` to each device and then take a slice. + +### The Common Performance Pitfall + +A common pattern among `jax.pmap` users involves arrays that are **semantically +replicated** (the user intends for them to be identical everywhere) but are +**physically sharded** (stacked along the leading dimension). + +This happens implicitly (e.g., via `jax.pmap(..., out_axes=0)`) or explicitly +(e.g., via `jax.device_put_replicated`). Users often try to retrieve metrics or +checkpoints by calling `unreplicate` or `x[0]`, assuming it is a cheap +operation. + +#### Example: The "Unreplicate" Anti-Pattern + +```python +from flax import jax_utils +import jax.numpy as jnp +import jax + +# jax_utils.replicate calls jax.device_put_replicated. +# This stacks num_devices copies and SHARDS them over the stacked dimension. +# Logical Shape: (8, 3, 4) | Sharding: P('x', None, None) +train_state = jax_utils.replicate({'params': jnp.zeros((3, 4))}) + +# out_axes=0 by default, so the output remains sharded along dim 0. +train_step_pmapped = jax.pmap(lambda x: x) + +# jax_utils.unreplicate performs a jax.tree_map(lambda x: x[0], tree). +# Users do this to grab metrics, log param statistics, checkpoint, etc. +train_state = jax_utils.unreplicate(train_step_pmapped(train_state)) +``` + +#### The Consequence +Even though the user knows `train_state` contains identical data on every +device, JAX sees an array with `shape (8, 3, 4)` and spec `jax.P('x', None, +None)` i.e., an array that is sharded along its leading dimension. JAX cannot +safely assume the data is identical on each device. Therefore, `x[0]` triggers a +gather of the entire array to all devices before slicing to ensure correctness. +This unnecessary communication causes performance degradation (visible as +_gather operations in a stack trace). + +``` +train + └─ jax_utils.py:48 unreplicate + └─ tree_util.py:354 tree_map + └─ jax_utils.py:50 (performing x[0]) + └─ array.py:335 __getitem__ + └─ indexing.py:734 rewriting_take + │ + ▼ + └─ indexing.py:784 _gather + └─ slicing.py:324 gather + └─ PjitFunction(gather) +``` + +### Why was "Old Pmap" Fast? +Historically, `pmap` used `PmapSharding`, which had a fast-path optimization in +`jax.Array`'s `__getitem__` allowing it to return an array with a +`SingleDeviceSharding` (data residing on only one device). + +However, current JAX uses `NamedSharding`. We do not strictly replicate the +legacy behavior because it breaks the semantics of array indexing. If we allowed +`x[0]` to return a `SingleDeviceSharding` array in a general context (e.g., in +the middle of a train step instead of when trying to bring data back to host for +reporting), only one device would have data while others would have nothing. +This is computationally problematic for subsequent operations. + +The slowdown users experience now is JAX enforcing correct semantics: if you ask +for `x[0]` from an array sharded along its leading dimension, you get a fully +replicated result available on all devices, which requires communication. + From dd62dd76642220562eb73ca5e14391d713513d56 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 16 Dec 2025 00:06:44 -0800 Subject: [PATCH 217/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d20fe8e99b411a6b61e91ad3aeeadd6e26f9fc7d PiperOrigin-RevId: 845116836 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 79bdbc60e232..3c027d1816cd 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "b9f8bd1a637329cfe8eecdf8ac42b5b96445b563" -XLA_SHA256 = "a6efa4a48f737155e41498d5cea6ce1a30828418bdda1c36504ab0f385ac36a5" +XLA_COMMIT = "d20fe8e99b411a6b61e91ad3aeeadd6e26f9fc7d" +XLA_SHA256 = "0bab08c8933e282e4437532418fcfc3b1b2648749a8546526542d35127f1d982" From dda9095ad0192a031e2ab78b05fae1cbfe9b3fff Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Tue, 16 Dec 2025 06:55:24 -0800 Subject: [PATCH 218/315] Suppress "healthcheck too slow" for OpsTest.test_select_n under ASAN. PiperOrigin-RevId: 845251115 --- tests/pallas/ops_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index 2d87545e86ef..fdd1c9900788 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -550,6 +550,8 @@ def kernel(x_ref, o_ref): # TODO(sharadmv): test rank < 2, size < 2 @hp.given(select_n_strategy(max_cases=2, min_rank=2, max_rank=4, min_size_exp=1)) + @hp.settings(suppress_health_check=([hp.HealthCheck.too_slow] + if jtu.is_asan() else [])) def test_select_n(self, args): if jtu.test_device_matches(["gpu"]): self.skipTest("TODO: error on GPU, lowering bug for select_n") From b8747f94672322b578e699a5842e35a13e63976e Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Tue, 16 Dec 2025 06:59:18 -0800 Subject: [PATCH 219/315] Disable the linalg test target under TSAN on CPU due to timeouts. PiperOrigin-RevId: 845252480 --- tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/BUILD b/tests/BUILD index 6c1fd8915878..3f12aea68415 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -955,6 +955,7 @@ jax_multiplatform_test( "cpu": [ "noasan", "nomsan", + "notsan", # Times out. ], # TODO(phawkins): Latest SciPy leaks memory. }, shard_count = { From 6860386637a966234242a0c4fd5c6550d36a1071 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 16 Dec 2025 07:03:02 -0800 Subject: [PATCH 220/315] [Mosaic GPU][NFC] Remove duplicate lowering rule for `vector.BroadcastOp`. PiperOrigin-RevId: 845253825 --- .../mosaic/gpu/dialect_lowering.py | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index f442b9064376..7e3cc39783da 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -606,32 +606,16 @@ def _broadcasted_iota_op_lowering_rule( return [fragmented_array_to_ir(a, result_type)] -@_register_lowering(vector.BroadcastOp) -def _vector_splat_op_lowering_rule( - _: LoweringContext, vector_splat_op: vector.BroadcastOp -) -> Sequence[ir.Value]: - - out_vec_ty = ir.VectorType(vector_splat_op.aggregate.type) - fragmented_array = fa.FragmentedArray.splat( - vector_splat_op.input, - tuple(out_vec_ty.shape), - layouts.from_layout_attr(vector_splat_op.attributes["out_layouts"][0]), - is_signed=_default_is_signed(out_vec_ty.element_type), - ) - return [fragmented_array_to_ir(fragmented_array, out_vec_ty)] - - @_register_lowering(vector.BroadcastOp) def _vector_broadcast_op_lowering_rule( - _: LoweringContext, vector_broadcast_op: vector.BroadcastOp + _: LoweringContext, op: vector.BroadcastOp ) -> Sequence[ir.Value]: - - out_vec_ty = ir.VectorType(vector_broadcast_op.vector.type) + out_vec_ty = ir.VectorType(op.vector.type) fragmented_array = fa.FragmentedArray.splat( - vector_broadcast_op.source, + op.source, tuple(out_vec_ty.shape), layouts.from_layout_attr( - vector_broadcast_op.attributes["out_layouts"][0] + op.attributes["out_layouts"][0] ), is_signed=_default_is_signed(out_vec_ty.element_type), ) From 3cbe1377a972999f285696268d345109dcbd654e Mon Sep 17 00:00:00 2001 From: George Necula Date: Thu, 11 Dec 2025 11:42:19 +0000 Subject: [PATCH 221/315] [export] Fix the "with mesh" deprecation warning --- tests/export_back_compat_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 41a4b99ed944..e41d29a53cf8 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -813,7 +813,7 @@ def func(x): # b: f32[2, 4] # the expected custom call targets for old test data that was serialized # with custom calls. for data, custom_call_targets_override in data: - with mesh: + with jax.set_mesh(mesh): if jax.config.jax_use_shardy_partitioner: self.run_one_test( func, self.load_testdata(data["shardy"]), @@ -1040,7 +1040,7 @@ def shard_map_func(x): # b: f32[2, 4] # the expected custom call targets for old test data that was serialized # with custom calls. for data, custom_call_targets_override in data: - with Mesh(devices, axis_names=('x')): + with jax.set_mesh(Mesh(devices, axis_names=('x'))): self.run_one_test( func, self.load_testdata(data), expect_current_custom_calls=custom_call_targets_override) From f5a09b8ff76b3ca2e19175aba3a8f3ff65b0919d Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 16 Dec 2025 07:43:34 -0800 Subject: [PATCH 222/315] [Mosaic GPU] Add support for all kinds of TMA reductions. I had to change the tma descriptor cache key, since there are cases where we currently need two different descriptors based on the reduction op. We could in principle go back to a single TMA descriptor in those cases if we pass sign information to async_copy. PiperOrigin-RevId: 845269465 --- jax/experimental/mosaic/gpu/launch_context.py | 150 ++++++++++++------ jaxlib/mosaic/dialect/gpu/mosaic_gpu.td | 6 +- jaxlib/mosaic/gpu/runtime.cc | 10 +- tests/mosaic/gpu_test.py | 85 ++++++++-- 4 files changed, 184 insertions(+), 67 deletions(-) diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 88b44978897c..ff8d94814c22 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -39,7 +39,24 @@ TMA_DESCRIPTOR_BYTES = 128 TMA_DESCRIPTOR_ALIGNMENT = 64 -TMAReductionOp = Literal["add", "min", "max", "inc", "dec", "and", "or", "xor"] +TMAReductionOp = Literal[ + "add", + "min", + "max", + "inc", + "dec", + "and", + "or", + "xor", + "umin", + "umax", + "smin", + "smax", +] + +def _reduction_op_to_ptx(reduction_op: TMAReductionOp) -> str: + # convert [s|u]min|max to min|max + return reduction_op[-3:] c = utils.c # This is too common to fully qualify. @@ -426,6 +443,81 @@ def _find_kernel_argument_for_gmem_ref( return gmem_ref +def _is_tma_reduction_op_supported( + reduction_op: TMAReductionOp | None, dtype: ir.Type, +) -> bool: + """Returns whether the given TMA reduction op supports the given dtype. + + This function essentially implements the table at: + https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor + with the following differences: + - For `add` reductions, we also support int64, treating it as uint64. + - For `and`, `or`, and `xor` reductions, we support signed integer types. + - For `inc` and `dec` reductions, we support both signed and unsigned i32 + treating both as unsigned. + """ + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + bf16 = ir.BF16Type.get() + + match reduction_op: + case None: + return True + case "add": + return dtype in (f16, f32, bf16, i32, i64) + case "max" | "min": + return dtype in (f16, bf16) + case "umax" | "umin" | "smax" | "smin": + return dtype in (i32, i64) + case "inc" | "dec": + return dtype == i32 + case "and" | "or" | "xor": + return dtype in (i32, i64) + + +def _tma_dma_type( + element_type: ir.Type, + reduction_op: TMAReductionOp | None, +) -> int: + """Returns the TMA DMA type for the given element type and signedness.""" + if ir.IntegerType.isinstance(element_type): + bitwidth = utils.bitwidth_impl(element_type) + if bitwidth == 2: + tma_dtype = 8 + elif bitwidth == 4: + tma_dtype = 0 + elif bitwidth == 8: + tma_dtype = 1 + elif bitwidth == 16: + tma_dtype = 2 + elif bitwidth == 32: + tma_dtype = 9 if reduction_op in ("smin", "smax") else 3 + elif bitwidth == 64: + tma_dtype = 10 if reduction_op in ("smin", "smax") else 4 + else: + raise ValueError(f"Unsupported integer bitwidth: {bitwidth}") + elif ir.F16Type.isinstance(element_type): + tma_dtype = 5 + elif ir.F32Type.isinstance(element_type): + tma_dtype = 6 + elif ir.BF16Type.isinstance(element_type): + tma_dtype = 7 + # We treat narrow floats as integers + elif ir.Float8E5M2Type.isinstance(element_type): + tma_dtype = 1 + elif ir.Float8E4M3FNType.isinstance(element_type): + tma_dtype = 1 + elif ir.Float8E8M0FNUType.isinstance(element_type): + tma_dtype = 1 + elif ir.Float4E2M1FNType.isinstance(element_type): + tma_dtype = 0 + else: + raise ValueError(f"unsupported TMA dtype {element_type}") + return tma_dtype + + class AsyncCopyImplementation(enum.Enum): TMA = enum.auto() CP_ASYNC = enum.auto() @@ -438,7 +530,7 @@ class LaunchContext: cluster_size: tuple[int, int, int] profiler: OnDeviceProfiler | None = None tma_descriptors: dict[ - tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any], + tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...], Any, int], ir.Value, ] = dataclasses.field(default_factory=dict, init=False) is_device_collective: bool = False @@ -512,10 +604,11 @@ def _get_tma_desc( reduction_op: TMAReductionOp | None, ): gmem_ref = _find_kernel_argument_for_gmem_ref(gmem_ref) + tma_dtype = _tma_dma_type(ir.MemRefType(gmem_ref.type).element_type, reduction_op) # Using ir.Values in cache keys is a little sketchy, but I think it should # be fine. Having it in the key will keep it alive, and if comparison and # hashing is by identity then it should work out. - tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id) + tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform, gmem_peer_id, tma_dtype) if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) @@ -580,43 +673,6 @@ def init_tma_desc(host_ptr): ) # TODO(apaszke): Better verification (e.g. slice is non-zero) # TODO(apaszke): We always know strides statically. - if isinstance(ref_ty.element_type, ir.IntegerType): - if reduction_op is not None: - raise ValueError( - f"TMA with reduction_op={reduction_op} is not supported with Integers" - ) - bitwidth = utils.bitwidth_impl(ref_ty.element_type) - if bitwidth == 2: - tma_dtype = 8 - elif bitwidth == 4: - tma_dtype = 0 - elif bitwidth == 8: - tma_dtype = 1 - elif bitwidth == 16: - tma_dtype = 2 - elif bitwidth == 32: - tma_dtype = 3 - elif bitwidth == 64: - tma_dtype = 4 - else: - raise ValueError(f"Unsupported integer bitwidth: {bitwidth}") - elif ir.F16Type.isinstance(ref_ty.element_type): - tma_dtype = 5 - elif ir.F32Type.isinstance(ref_ty.element_type): - tma_dtype = 6 - elif ir.BF16Type.isinstance(ref_ty.element_type): - tma_dtype = 7 - # We treat narrow floats as integers - elif ir.Float8E5M2Type.isinstance(ref_ty.element_type): - tma_dtype = 1 - elif ir.Float8E4M3FNType.isinstance(ref_ty.element_type): - tma_dtype = 1 - elif ir.Float8E8M0FNUType.isinstance(ref_ty.element_type): - tma_dtype = 1 - elif ir.Float4E2M1FNType.isinstance(ref_ty.element_type): - tma_dtype = 0 - else: - raise ValueError(f"unsupported TMA dtype {ref_ty.element_type}") dtype_or_bitwidth = c(tma_dtype, i64) args = [ host_ptr, @@ -953,16 +1009,10 @@ def async_copy( if reduction_op is not None: if implementation != AsyncCopyImplementation.TMA: raise ValueError("Only the TMA implementation supports reductions") - if not any( - t.isinstance(element_type) - for t in (ir.F32Type, ir.BF16Type, ir.F16Type) - ): - raise ValueError( - "TMA with reduction is only supported with f32, f16 and bf16" - ) - if reduction_op != "add": + if not _is_tma_reduction_op_supported(reduction_op, element_type): raise ValueError( - "TMA with reduction is only supported with add operation" + f"Reduction op {reduction_op} not supported by the TMA" + f" implementation for element type {element_type}" ) if src_ref_ty.memory_space is None and utils.is_smem_ref(dst_ref_ty): @@ -1329,7 +1379,7 @@ def async_copy( llvm.inline_asm( ir.Type.parse("!llvm.void"), [predicate,smem_ptr,tma_desc,*rev_dyn_base_indices], - f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{reduction_op}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];", + f"@$0 cp.reduce.async.bulk.tensor.{rank}d.global.shared::cta.{_reduction_op_to_ptx(reduction_op)}.tile.bulk_group [$2,{{{idx_operands}}}], [$1];", "b,r,l" + ",r" * rank, has_side_effects=True, ) diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index 9fc855486ad4..cc733d28fbc8 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -205,7 +205,11 @@ def MosaicGPU_TMAReduction : I32EnumAttr<"TMAReduction", I32EnumAttrCase<"Dec", 4, "dec">, I32EnumAttrCase<"And", 5, "and">, I32EnumAttrCase<"Or", 6, "or">, - I32EnumAttrCase<"Xor", 7, "xor"> + I32EnumAttrCase<"Xor", 7, "xor">, + I32EnumAttrCase<"Umin", 8, "umin">, + I32EnumAttrCase<"Umax", 9, "umax">, + I32EnumAttrCase<"Smin", 10, "smin">, + I32EnumAttrCase<"Smax", 11, "smax"> ]>{ let cppNamespace = "::mosaic_gpu"; } diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index 7c12c8e2748e..4d94120aa8c0 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -47,7 +47,7 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, CUtensorMapDataType data_type; int64_t elem_bitwidth; - // types are defined in: LaunchContext._get_tma_desc() + // types are defined in: launch_context._tma_dma_type() if (elem_type == 8){ // this is for int2s data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8; @@ -77,7 +77,13 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, } else if (elem_type == 7){ data_type = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; elem_bitwidth = 16; - } else{ + } else if (elem_type == 9){ + data_type = CU_TENSOR_MAP_DATA_TYPE_INT32; + elem_bitwidth = 32; + } else if (elem_type == 10){ + data_type = CU_TENSOR_MAP_DATA_TYPE_INT64; + elem_bitwidth = 64; + } else{ fprintf(stderr, "Unsupported element type: %ld \n", elem_type); abort(); } diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 33f4fed610b8..423d0d1e7c3b 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -5397,29 +5397,57 @@ def body( x = self.prng.uniform(0, 10, input_shape).astype(el_type) self.assertArraysEqual(kernel(x), x.reshape(output_shape)) - @parameterized.parameters(jnp.float32, jnp.bfloat16, jnp.float16) - def test_async_store_add_reduction(self, dtype): - # TODO(b/415721295):Remove after the minimal jaxlib version is 0.8.2. + @parameterized.product( + dtype=(jnp.int32, jnp.int64, jnp.uint32, jnp.uint64, jnp.float32, jnp.float16, jnp.bfloat16), + reduction_op=("add", "min", "max", "inc", "dec", "and", "or", "xor"), + ) + def test_async_store_reduction(self, dtype, reduction_op): + # TODO(b/415721295):Clean up after the minimal jaxlib version is 0.8.2. if not hasattr(mgpu_dialect, "TMAReduction"): - self.skipTest("TMAReduction op is required.") + self.skipTest("The mgpu_dialect.TMAReduction attribute is required.") + + if reduction_op in ("min", "max"): + if dtype in (jnp.int32, jnp.int64): + reduction_op = "s" + reduction_op + elif dtype in (jnp.uint32, jnp.uint64): + reduction_op = "u" + reduction_op + + if reduction_op in ("smin", "smax", "umin", "umax") and not hasattr(mgpu_dialect.TMAReduction, "Smin"): + self.skipTest("The Smin/Smax/Umin/Umax reduction types are required.") + + if ( + not launch_context._is_tma_reduction_op_supported( + reduction_op, + utils.dtype_to_ir_type(dtype), + ) + or ( + dtype in (jnp.uint32, jnp.uint64) + and reduction_op in ("smin", "smax") + ) + or ( + dtype in (jnp.int32, jnp.int64) and reduction_op in ("umin", "umax") + ) + or dtype == jnp.int32 and reduction_op in ("inc", "dec") + ): + self.skipTest("TMA does not support this reduction op for this dtype") shape = (8, 128) def body(ctx, src, dst, smem): del ctx - smem_ref, tma_barrier = smem + src_smem_ref, tma_barrier = smem i32 = ir.IntegerType.get_signless(32) zero = arith.constant(i32, 0) indices = [zero, zero] - slice_lengths = smem_ref.type.shape + slice_lengths = src_smem_ref.type.shape tma_barrier.arrive_expect_tx( - utils.bitwidth(smem_ref.type.element_type) * math.prod(shape) // 8 + utils.bitwidth(src_smem_ref.type.element_type) * math.prod(shape) // 8 ) mgpu_dialect.async_load( source=src, - destination=smem_ref, + destination=src_smem_ref, barrier=tma_barrier.as_barrier_memref(), indices=indices, slice_lengths=slice_lengths, @@ -5428,31 +5456,60 @@ def body(ctx, src, dst, smem): tma_barrier.wait() + reduction_attr = getattr( + mgpu_dialect.TMAReduction, reduction_op.capitalize() + ) + mgpu_dialect.async_store( - source=smem_ref, + source=src_smem_ref, destination=dst, indices=indices, slice_lengths=slice_lengths, - reduction_op=mgpu_dialect.TMAReduction.Add, + reduction_op=reduction_attr, ) nvvm.cp_async_bulk_wait_group(0) - src = jnp.ones(shape, dtype=dtype) - dst = jnp.ones(shape, dtype=dtype) + prng_key = jax.random.key(1234) + k0, k1 = jax.random.split(prng_key, 2) + if dtype in (jnp.bfloat16, jnp.float16, jnp.float32): + src = jax.random.uniform(k0, shape, dtype, -10, 10) + dst = jax.random.uniform(k1, shape, dtype, -10, 10) + else: + src = jax.random.randint(k0, shape, -10, 10).astype(dtype) + dst = jax.random.randint(k1, shape, -10, 10).astype(dtype) + + if reduction_op == "add": + expected = src + dst + elif reduction_op in ("min", "smin", "umin"): + expected = jnp.minimum(src, dst) + elif reduction_op in ("max", "smax", "umax"): + expected = jnp.maximum(src, dst) + elif reduction_op == "and": + expected = src & dst + elif reduction_op == "or": + expected = src | dst + elif reduction_op == "xor": + expected = src ^ dst + elif reduction_op == "inc": + expected = jnp.where(dst >= src, 0, dst + 1) + elif reduction_op == "dec": + expected = jnp.where((dst == 0) | (dst > src), src, dst - 1) + else: + raise ValueError(f"Unsupported reduction op: {reduction_op}") jax_shape = jax.ShapeDtypeStruct(shape, dtype) kernel = mgpu.as_gpu_kernel( body, grid=(1, 1, 1), block=(128, 1, 1), - in_shape=(jax_shape,), + in_shape=(jax_shape), out_shape=(), inout_shape=(jax_shape,), smem_scratch_shape=[jax_shape, core.TMABarrier(1)], thread_semantics=mgpu.LoweringSemantics.Warpgroup, ) - np.testing.assert_array_equal(kernel(src, dst)[0], src + dst) + np.testing.assert_array_equal(kernel(src, dst)[0], expected) class MosaicGpuDialectSm90ATest(Sm90ATestCase, jtu.JaxTestCase): From 21b8652685482c2ae95afe1a97e4d3a20eacbd43 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Tue, 16 Dec 2025 07:43:57 -0800 Subject: [PATCH 223/315] Update `rules_ml_toolchain` version. PiperOrigin-RevId: 845269589 --- WORKSPACE | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 87d1e6830a9f..67474c08749d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,6 +1,7 @@ +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + # The XLA commit is determined by third_party/xla/revision.bzl. load("//third_party/xla:workspace.bzl", jax_xla_workspace = "repo") -load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") jax_xla_workspace() @@ -12,15 +13,15 @@ load("@xla//:workspace3.bzl", "xla_workspace3") xla_workspace3() -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - # Initialize Hermetic toolchains # Details: https://github.com/google-ml-infra/rules_ml_toolchain tf_http_archive( name = "rules_ml_toolchain", - sha256 = "7f00b3e94bbca1a4737ded6b9ed5358f6d1c86430c2ec97c90081343c0482f18", - strip_prefix = "rules_ml_toolchain-29d54c875da37e74b8548924ed30e78cb28126b9", - urls = tf_mirror_urls("https://github.com/google-ml-infra/rules_ml_toolchain/archive/29d54c875da37e74b8548924ed30e78cb28126b9.tar.gz"), + sha256 = "e9842de3fefb5a120d3b1647d3a09e6e7071e8df8d1cd2dfe6f66ee31fd2595e", + strip_prefix = "rules_ml_toolchain-cb79a8fc8dcf3f75743dcd9b3418a70c884a7269", + urls = tf_mirror_urls( + "https://github.com/google-ml-infra/rules_ml_toolchain/archive/cb79a8fc8dcf3f75743dcd9b3418a70c884a7269.tar.gz", + ), ) load( From 06d576be8d0a126422dbbc45abceb44984e017a1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 16 Dec 2025 09:33:10 -0800 Subject: [PATCH 224/315] Remove more deprecated BUILD aliases. PiperOrigin-RevId: 845311597 --- jax/BUILD | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 16b88eef390f..c5ab5c1189cd 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -288,27 +288,3 @@ pytype_strict_library( "//jax/experimental:xla_metadata", ], ) - -alias( - name = "mesh_utils", - actual = "//jax/experimental:mesh_utils", - visibility = jax_visibility("mesh_utils_deprecated_alias"), -) - -alias( - name = "pallas", - actual = "//jax/experimental:pallas", - visibility = jax_visibility("pallas_deprecated_alias"), -) - -alias( - name = "pallas_tpu", - actual = "//jax/experimental:pallas_tpu", - visibility = jax_visibility("pallas_tpu_deprecated_alias"), -) - -alias( - name = "optimizers", - actual = "//jax/example_libraries:optimizers", - visibility = jax_visibility("optimizers_deprecated_alias"), -) From bb339675b3835fbe10c0e2762c2139711b521dc7 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 16 Dec 2025 09:57:08 -0800 Subject: [PATCH 225/315] Add Guidance for GCS Fuse for Compilation Cache in JAX PiperOrigin-RevId: 845321019 --- docs/persistent_compilation_cache.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index ae3d6ddfcad0..6e82d995b782 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -132,6 +132,34 @@ Cloud Storage (GCS) bucket. We recommend the following configuration: * All encryption policies are supported. +It is **recommended** to use +[Google Cloud Storage Fuse](https://cloud.google.com/storage/docs/cloud-storage-fuse) +to mount the GCS bucket as a local directory. This is because when running JAX +in a multi-node setup, multiple nodes might try to write to the cache +simultaneously, leading to GCS rate-limit errors. GCSFuse handles this by +ensuring that only one process can write to a file at a time, preventing these +errors. + +To set up GCSFuse, follow instructions for +[GCE](https://cloud.google.com/storage/docs/cloud-storage-fuse/mount-bucket) or +[GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-setup). +For better performance, enable file caching +([GCE](https://cloud.google.com/storage/docs/cloud-storage-fuse/file-caching) and +[GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/cloud-storage-fuse-csi-driver-perf#enable-and-use-file-caching)). + +Once GCSFuse is configured, set the JAX cache directory to the GCSFuse mount +point: + +```python +# Example assuming the GCS bucket is mounted at /gcs/my-bucket +jax.config.update("jax_compilation_cache_dir", "/gcs/my-bucket/jax-cache") +``` + +**Direct GCS access :** + +If you choose not to use GCSFuse, you can point the cache directly to a GCS +bucket. + Assuming that `gs://jax-cache` is the GCS bucket, set cache location as follows: From b841f5b4a7167c8a6e31b101f027175f1a29cec2 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 16 Dec 2025 10:43:33 -0800 Subject: [PATCH 226/315] Refactor indexing code --- jax/_src/numpy/indexing.py | 747 ++++++++++++++++++------------- jax/_src/ops/scatter.py | 14 +- tests/lax_numpy_indexing_test.py | 16 +- 3 files changed, 456 insertions(+), 321 deletions(-) diff --git a/jax/_src/numpy/indexing.py b/jax/_src/numpy/indexing.py index cf58764c6293..9f808540d85f 100644 --- a/jax/_src/numpy/indexing.py +++ b/jax/_src/numpy/indexing.py @@ -15,12 +15,14 @@ # pytype: skip-file """Indexing code for jax.numpy.""" +from __future__ import annotations + +import dataclasses import enum from functools import partial import operator import string -from typing import Any, NamedTuple, cast -from types import EllipsisType +from typing import Any, NamedTuple from collections.abc import Sequence import numpy as np @@ -36,6 +38,7 @@ from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lax import utils as lax_utils +from jax._src.numpy import array_constructors from jax._src.numpy import einsum from jax._src.numpy import error as jnp_error from jax._src.numpy import lax_numpy @@ -44,13 +47,384 @@ from jax._src.partition_spec import PartitionSpec from jax._src.pjit import auto_axes from jax._src.sharding_impls import canonicalize_sharding, NamedSharding -from jax._src.tree_util import tree_flatten -from jax._src.typing import Array, ArrayLike, Index, StaticIndex, StaticScalar -from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update +from jax._src.tree_util import tree_flatten, tree_unflatten, register_pytree_node_class +from jax._src.typing import Array, ArrayLike, Index, StaticScalar +from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update, unzip3 export = set_module('jax.numpy') +# Internal utilities for parsing and validating NumPy-style indices. + +class IndexType(enum.Enum): + """Enum for tracking the type of an index.""" + NONE = "none" + SLICE = "slice" + ELLIPSIS = "ellipsis" + INTEGER = "integer" + BOOLEAN = "boolean" + ARRAY = "array" + + @classmethod + def from_index(cls, idx: Index) -> IndexType: + """Create an IndexType enum from a supported JAX array index.""" + if idx is None: + return cls.NONE + elif idx is Ellipsis: + return cls.ELLIPSIS + elif isinstance(idx, slice): + return cls.SLICE + elif _is_integer_index(idx): + return cls.INTEGER + elif _is_boolean_index(idx): + return cls.BOOLEAN + elif isinstance(idx, (Array, np.ndarray, literals.TypedNdArray)): + if dtypes.issubdtype(idx.dtype, np.integer): + return cls.ARRAY + else: + raise TypeError( + f"Indexer must have integer or boolean type, got indexer with type {idx.dtype}") + elif isinstance(idx, str): + # TODO(jakevdp): this TypeError is for backward compatibility. + # We should switch to IndexError for consistency. + raise TypeError(f"JAX does not support string indexing; got {idx=}") + elif isinstance(idx, Sequence): + if not idx: # empty indices default to float, so special-case this. + return cls.ARRAY + idx_aval = api.eval_shape(array_constructors.asarray, idx) + if idx_aval.dtype == bool: + return cls.BOOLEAN + elif dtypes.issubdtype(idx_aval.dtype, np.integer): + return cls.ARRAY + else: + raise TypeError( + f"Indexer must have integer or boolean type, got indexer with type {idx_aval.dtype}") + elif isinstance(idx, (float, complex, np.generic)): + raise TypeError( + f"Indexer must have integer or boolean type, got indexer with type {np.dtype(type(idx))}") + else: + raise IndexError("only integers, slices (`:`), ellipsis (`...`), newaxis (`None`)" + f" and integer or boolean arrays are valid indices. Got {idx}") + + +class ParsedIndex(NamedTuple): + """Structure for tracking an indexer parsed within the context of an array shape.""" + index: Index # type: ignore[assignment] # seems to be a strange misfire by mypy. + typ: IndexType + consumed_axes: tuple[int, ...] + + +def _parse_indices( + indices: tuple[Index, ...], + shape: tuple[int, ...], +) -> list[ParsedIndex]: + """Parse indices in the context of an array shape. + + Args: + indices: a tuple of user-supplied indices to be parsed. + shape: the shape of the array being indexed. + + Returns: + The list of parsed indices stored in :class:`ParsedIndex` objects. + This list will have the same length as ``indices``. + + Raises: + IndexError: if any unrecognized index types are present or if there + are too many indices, or too many ellipses. + """ + # 1. go through indices to count the number of consumed dimensions. + # This is required to determine the effect of any ellipses. + dimensions_consumed: list[int] = [] + ellipses_indices: list[int] = [] + index_types: list[IndexType] = [] + for i, idx in enumerate(indices): + typ = IndexType.from_index(idx) + index_types.append(typ) + + if typ == IndexType.NONE: + dimensions_consumed.append(0) + elif typ == IndexType.ELLIPSIS: + # We don't yet know how many dimensions are consumed, so set to zero + # for now and update later. + dimensions_consumed.append(0) + ellipses_indices.append(i) + elif typ == IndexType.BOOLEAN: + dimensions_consumed.append(np.ndim(idx)) # type: ignore[arg-type] + elif typ in [IndexType.INTEGER, IndexType.ARRAY, IndexType.SLICE]: + dimensions_consumed.append(1) + else: + raise IndexError(f"Unrecognized index type: {typ}") + + # 2. Validate the consumed dimensions and ellipses. + if len(ellipses_indices) > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + total_consumed = sum(dimensions_consumed) + if total_consumed > len(shape): + raise IndexError(f"Too many indices: array is {len(shape)}-dimensional," + f" but {total_consumed} were indexed") + if ellipses_indices: + dimensions_consumed[ellipses_indices[0]] = len(shape) - total_consumed + + # 3. Generate the final sequence of parsed indices. + result: list[ParsedIndex] = [] + current_dim = 0 + for index, typ, n_consumed in safe_zip(indices, index_types, dimensions_consumed): + consumed_axes = tuple(range(current_dim, current_dim + n_consumed)) + current_dim += len(consumed_axes) + result.append(ParsedIndex(index=index, typ=typ, consumed_axes=consumed_axes)) + return result + + +@register_pytree_node_class +@dataclasses.dataclass(frozen=True, kw_only=True) +class NDIndexer: + """Object that implements NumPy-style indexing operations on top of JAX. + + Generally this will be constructed via the :meth:`NDIndexer.from_raw_indices` + method. + + Attributes: + shape: the shape of the array being indexed. + indices: a list of :class:`ParsedIndex` objects. + """ + shape: tuple[int, ...] + indices: list[ParsedIndex] + + @classmethod + def from_raw_indices(cls, indices: Index | tuple[Index, ...], shape: tuple[int, ...]) -> NDIndexer: + """Create an NDIndexer object from raw user-supplied indices.""" + indices = eliminate_deprecated_list_indexing(indices) + indices = _parse_indices(indices, shape) + return cls(shape=shape, indices=indices) + + def validate_static_indices(self, normalize_indices: bool = True) -> None: + """Check that all static integer indices are in-bounds. + + Raises an IndexError in case of out-of-bound indices + """ + for position, idx in enumerate(self.indices): + if idx.typ == IndexType.INTEGER: + assert isinstance(idx.index, (int, np.integer)) + i = operator.index(idx.index) + axis, = idx.consumed_axes + size = self.shape[axis] + normed_idx = i + size if normalize_indices and i < 0 else i + if not 0 <= normed_idx < size: + raise IndexError(f"index {i} out of bounds for axis {axis} with size {size}" + f" ({normalize_indices=})") + + def validate_slices(self) -> None: + """Check that all slices have static start/stop/step values. + + Raises an IndexError in case of non-static entries. + """ + for position, idx in enumerate(self.indices): + if idx.typ == IndexType.SLICE: + assert isinstance(idx.index, slice) + if not all(_is_slice_element_none_or_constant_or_symbolic(val) + for val in [idx.index.start, idx.index.stop, idx.index.step]): + raise IndexError("Slice entries must be static integers." + f" Got {idx.index} at position {position}") + + def expand_bool_indices(self) -> NDIndexer: + """Returns a new NDIndexer with boolean indices replaced by array indices. + + The only exception are scalar boolean indices, which are left in-place. + """ + expanded_indices: list[ParsedIndex] = [] + + for position, idx in enumerate(self.indices): + if idx.typ != IndexType.BOOLEAN: + expanded_indices.append(idx) + continue + if not core.is_concrete(idx.index): + # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete + raise errors.NonConcreteBooleanIndexError(core.get_aval(idx.index)) + assert isinstance(idx.index, (bool, np.ndarray, Array, literals.TypedNdArray, list)) + if np.ndim(idx.index) == 0: + # Scalar booleans + assert idx.consumed_axes == () + expanded_indices.append(ParsedIndex(index=bool(idx.index), typ=idx.typ, consumed_axes=())) + continue + idx_shape = np.shape(idx.index) + expected_shape = [self.shape[i] for i in idx.consumed_axes] + if not all(s1 in (0, s2) for s1, s2 in zip(idx_shape, expected_shape)): + raise IndexError("boolean index did not match shape of indexed array in index" + f" {position}: got {idx_shape}, expected {expected_shape}") + expanded_indices_raw = np.where(np.asarray(idx.index)) + expanded_indices.extend(ParsedIndex(index=i, typ=IndexType.ARRAY, consumed_axes=(axis,)) + for i, axis in safe_zip(expanded_indices_raw, idx.consumed_axes)) + return NDIndexer(shape=self.shape, indices=expanded_indices) + + def expand_scalar_bool_indices(self, sharding_spec: Any = None) -> tuple[NDIndexer, Any]: + new_shape = list(self.shape) + new_sharding_spec = list((None for _ in self.shape) if sharding_spec is None else sharding_spec) + new_indices = list(self.indices) + current_dim = 0 + for i, idx in enumerate(self.indices): + if idx.typ == IndexType.BOOLEAN and np.ndim(idx.index) == 0: # type: ignore[arg-type] + new_shape.insert(i, 1) + new_sharding_spec.insert(i, None) + new_indices[i] = ParsedIndex( + np.arange(int(idx.index)), typ=IndexType.ARRAY, consumed_axes=(current_dim,)) # type: ignore[arg-type] + current_dim += 1 + else: + n_consumed = len(idx.consumed_axes) + new_indices[i] = ParsedIndex( + index=idx.index, + typ=idx.typ, + consumed_axes = tuple(range(current_dim, current_dim + n_consumed)) + ) + current_dim += n_consumed + new_sharding_spec = None if sharding_spec is None else tuple(new_sharding_spec) + return NDIndexer(indices=new_indices, shape=tuple(new_shape)), new_sharding_spec + + def convert_sequences_to_arrays(self) -> NDIndexer: + new_indices = [ParsedIndex(lax_numpy.asarray(idx.index), typ=idx.typ, consumed_axes=idx.consumed_axes) + if isinstance(idx.index, Sequence) else idx for idx in self.indices] + return NDIndexer(indices=new_indices, shape=self.shape) + + def expand_ellipses(self) -> NDIndexer: + """ + Returns a new indexer with ellipsis and implicit trailing slices + replaced by explicit empty slices. + """ + expanded: list[ParsedIndex] = [] + consumed = 0 + for idx in self.indices: + consumed += len(idx.consumed_axes) + if idx.typ == IndexType.ELLIPSIS: + for axis in idx.consumed_axes: + expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,))) + else: + expanded.append(idx) + for axis in range(consumed, len(self.shape)): + expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,))) + return NDIndexer(shape=self.shape, indices=expanded) + + def normalize_indices(self) -> NDIndexer: + new_indices: list[ParsedIndex] = [] + for idx in self.indices: + if idx.typ == IndexType.INTEGER: + axis, = idx.consumed_axes + size: ArrayLike = self.shape[axis] + if isinstance(idx.index, np.unsignedinteger): + normed_index: Index = idx.index + else: + normed_index = idx.index + size if idx.index < 0 else idx.index # type: ignore[assignment,operator] + new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes)) + elif idx.typ == IndexType.ARRAY: + assert isinstance(idx.index, (Array, np.ndarray, literals.TypedNdArray)) + axis, = idx.consumed_axes + if dtypes.issubdtype(idx.index.dtype, np.unsignedinteger): + normed_index = idx.index + else: + size = self.shape[axis] + if core.is_constant_dim(size): + size = lax._const(idx.index, size) + else: + size = lax.convert_element_type(core.dimension_as_value(size), + idx.index.dtype) + normed_index = lax.select(idx.index < 0, lax.add(idx.index, size), idx.index) + new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes)) + else: + new_indices.append(idx) + return NDIndexer(indices=new_indices, shape=self.shape) + + def compute_via_static_slice(self, arr: Array) -> Array: + """Equivalent of arr[idx] implemented in terms of static :func:`lax.slice` operations. + + This supports only INTEGER, ELLIPSIS, and SLICE indices, and will raise a TypeError + if other indices are present. + """ + # Validation of the unmodified user indices. + self.validate_static_indices(normalize_indices=True) + self.validate_slices() + + for position, pidx in enumerate(self.indices): + if pidx.typ in [IndexType.INTEGER, IndexType.ELLIPSIS, IndexType.SLICE]: + pass + elif pidx.typ == IndexType.NONE: + raise TypeError(f"static_slice: got {pidx.index} at position {position}") + elif pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN]: + raise TypeError("static_slice: indices must be static scalars or slices." + f" Got {pidx.index} at position {position}") + else: + raise TypeError(f"static_slice: unrecognized index {pidx.index} at position {position}.") + + # Now re-iterate to generate static slices. + start_indices: list[int] = [] + limit_indices: list[int] = [] + strides: list[int] = [] + rev_axes: list[int] = [] + squeeze_axes: list[int] = [] + + expanded = self.expand_ellipses() + for pidx in expanded.indices: + if pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.NONE, IndexType.ELLIPSIS]: + raise RuntimeError(f"Internal: unexpected index encountered: {pidx}") + elif pidx.typ == IndexType.INTEGER: + assert isinstance(pidx.index, (int, np.integer)) + axis, = pidx.consumed_axes + start_index = int(pidx.index + arr.shape[axis] if pidx.index < 0 else pidx.index) + start_indices.append(start_index) + limit_indices.append(start_index + 1) + strides.append(1) + squeeze_axes.append(axis) + elif pidx.typ == IndexType.SLICE: + assert isinstance(pidx.index, slice) + axis, = pidx.consumed_axes + size = arr.shape[axis] + start, stop, stride = pidx.index.indices(size) + if stride < 0: + new_start = stop + 1 + abs(start - stop - 1) % abs(stride) + start_indices.append(new_start) + limit_indices.append(max(new_start, start + 1)) + strides.append(abs(stride)) + rev_axes.append(axis) + else: + start_indices.append(start) + limit_indices.append(stop) + strides.append(stride) + else: + raise TypeError(f"static_slice: unrecognized index {pidx.index}") + result = arr + if start_indices: + result = slicing.slice(result, start_indices, limit_indices, strides) + if rev_axes: + result = lax.rev(result, rev_axes) + if squeeze_axes: + result = lax.squeeze(result, squeeze_axes) + return result + + def is_advanced_int_indexer(self): + """Returns True if idx should trigger int array indexing, False otherwise.""" + # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing + return any(idx.typ in [IndexType.ARRAY, IndexType.BOOLEAN] and np.ndim(idx.index) > 0 + for idx in self.indices) + + def to_gather(self, x_sharding: NamedSharding | Any, + normalize_indices: bool = True) -> _GatherIndexer: + return _index_to_gather(self, x_sharding=x_sharding, normalize_indices=normalize_indices) + + def tree_flatten(self): + # split dynamic and static indices + def is_dynamic(i: ParsedIndex): + return i.typ in [IndexType.INTEGER, IndexType.ARRAY, IndexType.BOOLEAN] + raw_dynamic_indices = [i.index if is_dynamic(i) else None for i in self.indices] + static_metadata = [ + ParsedIndex(index=None, typ=i.typ, consumed_axes=i.consumed_axes) if is_dynamic(i) else i + for i in self.indices] + return raw_dynamic_indices, (self.shape, static_metadata) + + @classmethod + def tree_unflatten(cls, aux_data, children): + shape, static_metadata = aux_data + indices = [idx if dyn_index is None else ParsedIndex(dyn_index, idx.typ, idx.consumed_axes) + for dyn_index, idx in safe_zip(children, static_metadata)] + return cls(indices=indices, shape=shape) + + @export def take( a: ArrayLike, @@ -532,12 +906,14 @@ def _is_contiguous_slice(idx): (idx.step is None or (_is_integer_index(idx.step) and idx.step == 1))) def _attempt_rewriting_take_via_slice( - arr: Array, - idx: Index | tuple[Index, ...], *, + arr: Array, indexer: NDIndexer, *, mode: str | slicing.GatherScatterMode | None, out_sharding: NamedSharding | PartitionSpec | None = None) -> Array | None: # attempt to compute _rewriting_take via lax.slice(); return None if not possible. - idx = idx if isinstance(idx, tuple) else (idx,) + + # TODO(jakevdp): update implementation to use indexer directly, and to reuse code + # from compute_via_static_slice + idx = tuple(i.index for i in indexer.indices) if not all(isinstance(i, int) for i in arr.shape): return None @@ -598,7 +974,7 @@ def _attempt_rewriting_take_via_slice( allow_negative_indices.append(start < 0 or stop < 0) else: assert np.issubdtype(dtypes.dtype(ind), np.integer) # checked above - assert np.shape(ind) == () # checked above + assert np.shape(ind) == () # type: ignore[arg-type] # checked above start_indices.append(ind) slice_sizes.append(1) allow_negative_indices.append( @@ -635,95 +1011,6 @@ def _attempt_rewriting_take_via_slice( return arr -def static_slice(arr: Array, idx: StaticIndex | tuple[StaticIndex, ...]): - """Compute NumPy-style indexing for static slices only.""" - idx = idx if isinstance(idx, tuple) else (idx,) - - # First validate the types of entries before expanding ellipses: this allows - # error messages to point to particular positions supplied by the user. - # Valid index types here are integers, ellipses, and slices. - for position, ind in enumerate(idx): - if isinstance(ind, (int, np.integer, EllipsisType)): - pass - elif isinstance(ind, slice): - if not all(val is None or isinstance(val, (int, np.integer)) - for val in [ind.start, ind.stop, ind.step]): - raise ValueError("Slice entries must be static integers." - f" Got {ind} at position {position}") - elif ind is None: - raise TypeError(f"static_slice: got {ind} at position {position}") - elif isinstance(ind, (np.ndarray, Array, tuple, list, Sequence)): - raise TypeError("static_slice: indices must be static scalars or slices." - f" Got {ind} at position {position}") - else: - raise TypeError("static_slice: unrecognized index {ind} at position {position}.") - - # Now expand ellipses and validate the index values. This allows error messages - # to point to relevant array dimensions. - idx = _canonicalize_tuple_index(arr.ndim, idx) - start_indices: list[int] = [] - limit_indices: list[int] = [] - strides: list[int] = [] - rev_axes: list[int] = [] - squeeze_axes: list[int] = [] - - for axis, (ind, size) in enumerate(safe_zip(idx, arr.shape)): - if isinstance(ind, (int, np.integer)): - if not (-size <= ind < size): - raise IndexError(f"index {ind} out of bounds for axis {axis} with size {size}") - if ind < 0: - ind += size - start_indices.append(ind) - limit_indices.append(ind + 1) - strides.append(1) - squeeze_axes.append(axis) - elif isinstance(ind, slice): - start, stop, stride = ind.indices(size) - if stride < 0: - new_start = stop + 1 + abs(start - stop - 1) % abs(stride) - start_indices.append(new_start) - limit_indices.append(max(new_start, start + 1)) - strides.append(abs(stride)) - rev_axes.append(axis) - else: - start_indices.append(start) - limit_indices.append(stop) - strides.append(stride) - else: - raise ValueError(f"Unexpected index: {ind} at axis {axis}") - - if start_indices: - result = slicing.slice(arr, start_indices, limit_indices, strides) - if rev_axes: - result = lax.rev(result, rev_axes) - if squeeze_axes: - result = lax.squeeze(result, squeeze_axes) - return result - - -def validate_static_indices( - arr: Array, - idx: Index | tuple[Index, ...], *, - normalize_indices: bool) -> None: - """Perform bounds-checks for static indices. - - Raises an IndexError if any static indices are out-of-bounds. - """ - # TODO(jakevdp): expand_bool_indices is expensive; do this more efficiently. - idx = idx if isinstance(idx, tuple) else (idx,) - idx = _expand_bool_indices(idx, arr.shape) - idx_tup = tuple(i for i in _canonicalize_tuple_index(arr.ndim, idx) - if i is not None and not isinstance(i, bool)) - def norm_index(i, size): - return i + size if normalize_indices and i < 0 else i - if len(idx_tup) != arr.ndim: - raise RuntimeError(f"Error for {idx=} and {arr.shape=}: processed {idx_tup=}") - for axis, (i, size) in enumerate(safe_zip(idx_tup, arr.shape)): - if isinstance(i, (int, np.integer)) and (norm_index(i, size) < 0 or i >= size): - raise IndexError(f"index {i} out of bounds for axis {axis} with size {size}" - f" ({normalize_indices=})") - - class IndexingStrategy(enum.Enum): AUTO = 'auto' GATHER = 'gather' @@ -745,29 +1032,31 @@ def rewriting_take( # Computes arr[idx]. # All supported cases of indexing can be implemented as an XLA gather, # followed by an optional reverse and broadcast_in_dim. + indexer = NDIndexer.from_raw_indices(idx, arr.shape) if not isinstance(strategy, IndexingStrategy): raise TypeError(f"Expected strategy to be IndexingStrategy; got {strategy}") if config.check_static_indices.value and (mode is None or slicing.GatherScatterMode.from_any(mode) == slicing.GatherScatterMode.PROMISE_IN_BOUNDS): - validate_static_indices(arr, idx, normalize_indices=normalize_indices) + indexer.validate_static_indices(normalize_indices=normalize_indices) if strategy == IndexingStrategy.STATIC_SLICE: if not normalize_indices: raise ValueError("strategy=STATIC_SLICE is only supported when normalize_indices=True.") - return static_slice(arr, cast(StaticIndex | tuple[StaticIndex, ...], idx)) + return indexer.compute_via_static_slice(arr) # For simplicity of generated primitives, we call lax.slice or lax.dynamic_slice # in the simplest cases: i.e. non-dynamic arrays indexed with integers and slices. # TODO(jakevdp): lower to slice even when normalize_indices is False if strategy == IndexingStrategy.AUTO and normalize_indices: - result = _attempt_rewriting_take_via_slice(arr, idx, mode=mode, out_sharding=out_sharding) + result = _attempt_rewriting_take_via_slice(arr, indexer, mode=mode, out_sharding=out_sharding) if result is not None: return result - treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape) + indexer = indexer.expand_bool_indices() + dynamic_idx, treedef = tree_flatten(indexer) internal_gather = partial( - _gather, treedef=treedef, static_idx=static_idx, + _gather, treedef=treedef, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, fill_value=fill_value, normalize_indices=normalize_indices) if out_sharding is not None: @@ -781,12 +1070,11 @@ def rewriting_take( # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). # @api.jit(static_argnums=(1, 2)) -def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted, +def _gather(arr, dynamic_idx, *, treedef, indices_are_sorted, unique_indices, mode, fill_value, normalize_indices): - idx = merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = index_to_gather( - np.shape(arr), idx, core.typeof(arr).sharding, - normalize_indices=normalize_indices) # shared with _scatter_update + parsed_idx = tree_unflatten(treedef, dynamic_idx) + indexer = parsed_idx.to_gather(core.typeof(arr).sharding, + normalize_indices=normalize_indices) jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr @@ -821,7 +1109,7 @@ def _gather(arr, dynamic_idx, *, treedef, static_idx, indices_are_sorted, return lax.expand_dims(y, indexer.newaxis_dims) -class _Indexer(NamedTuple): +class _GatherIndexer(NamedTuple): # The expected shape of the slice output. slice_shape: Sequence[int] # The slice shape to pass to lax.gather(). @@ -853,123 +1141,43 @@ class _Indexer(NamedTuple): slice_sharding: NamedSharding | None = None -def split_index_for_jit(idx, shape): - """Splits indices into necessarily-static and dynamic parts. - - Used to pass indices into `jit`-ted function. - """ - # Convert list indices to tuples in cases (deprecated by NumPy.) - idx = eliminate_deprecated_list_indexing(idx) - if any(isinstance(i, str) for i in idx): - raise TypeError(f"JAX does not support string indexing; got {idx=}") - - # Expand any (concrete) boolean indices. We can then use advanced integer - # indexing logic to handle them. - idx = _expand_bool_indices(idx, shape) - - leaves, treedef = tree_flatten(idx) - dynamic = [None] * len(leaves) - static = [None] * len(leaves) - for i, x in enumerate(leaves): - if x is Ellipsis: - static[i] = x - elif isinstance(x, slice): - # slice objects aren't hashable. - static[i] = (x.start, x.stop, x.step) - else: - dynamic[i] = x - return treedef, tuple(static), dynamic - -def merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx): - """Recombines indices that were split by split_index_for_jit.""" - idx = [] - for s, d in zip(static_idx, dynamic_idx): - if d is not None: - idx.append(d) - elif isinstance(s, tuple): - idx.append(slice(s[0], s[1], s[2])) - else: - idx.append(s) - return treedef.unflatten(idx) +def _index_to_gather(indexer: NDIndexer, *, x_sharding: NamedSharding | Any, + normalize_indices: bool = True) -> _GatherIndexer: + indexer.validate_slices() + indexer = indexer.convert_sequences_to_arrays() -def _int(aval): - return not aval.shape and dtypes.issubdtype(aval.dtype, np.integer) - -def _aval_or_none(x): - try: - return core.get_aval(x) - except: - return None - -def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], - x_sharding, normalize_indices: bool = True) -> _Indexer: - # Convert sequences to arrays - idx = tuple(lax_numpy.asarray(i, dtype=None if i else int) - if isinstance(i, Sequence) else i for i in idx) - abstract_idx = [_aval_or_none(i) for i in idx] - float_indices = [(i, val, aval) for i, (val, aval) in enumerate(zip(idx, abstract_idx)) - if aval is not None and dtypes.issubdtype(aval, np.inexact)] - - # Check for float or complex indices: - if float_indices: - i, val, aval = float_indices[0] - msg = ("Indexer must have integer or boolean type, got indexer " - "with type {} at position {}, indexer value {}") - raise TypeError(msg.format(aval.dtype.name, i, val)) - - # Check whether advanced indices are contiguous. We must do this before - # removing ellipses (https://github.com/jax-ml/jax/issues/25109) - # If advanced idexing axes do not appear contiguously, NumPy semantics - # move the advanced axes to the front. - (is_advanced,) = np.nonzero([ - isinstance(e, (int, np.integer, Array, np.ndarray, - literals.TypedNdArray)) - or lax_numpy.isscalar(e) - for e in idx - ]) + is_advanced = np.nonzero([idx.typ in {IndexType.ARRAY, IndexType.INTEGER} for idx in indexer.indices]) advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1) - # Remove ellipses and add trailing slice(None)s. - idx = _canonicalize_tuple_index(len(x_shape), idx) + indexer = indexer.expand_ellipses() - x_spec = x_sharding.spec + scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(indexer.indices) if i.typ == IndexType.BOOLEAN] + indexer, x_spec = indexer.expand_scalar_bool_indices(x_sharding.spec) - # Check for scalar boolean indexing: this requires inserting extra dimensions - # before performing the rest of the logic. - scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)] - if scalar_bool_dims: - idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx) - x_shape = list(x_shape) - x_spec = list(x_spec) - for i in sorted(scalar_bool_dims): - x_shape.insert(i, 1) - x_spec.insert(i, None) - x_shape = tuple(x_shape) - x_spec = tuple(x_spec) + if normalize_indices: + indexer = indexer.normalize_indices() # Check for advanced indexing: # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing - advanced_indexes: Sequence[Array | np.ndarray] | None = None + # The advanced indices. + advanced_indexes: Sequence[Array] = [] # The positions of the advanced indexing axes in `idx`. idx_advanced_axes: Sequence[int] = [] # The positions of the advanced indexes in x's shape. # collapsed, after None axes have been removed. See below. - x_advanced_axes: Sequence[int] | None = None + x_advanced_axes: Sequence[int] = [] - if _is_advanced_int_indexer(idx): - idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None] + if indexer.is_advanced_int_indexer(): + idx_without_none = [(i, d) for i, d in enumerate(indexer.indices) if d.typ != IndexType.NONE] advanced_pairs = ( - (lax_numpy.asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones) - if lax_numpy.isscalar(e) - or isinstance(e, (Sequence, Array, np.ndarray, - literals.TypedNdArray))) - if normalize_indices: - advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j) - for e, i, j in advanced_pairs) - advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs) + (lax_numpy.asarray(e.index), i, j) + for j, (i, e) in enumerate(idx_without_none) + if e.typ in [IndexType.ARRAY, IndexType.INTEGER] + ) + advanced_indexes, idx_advanced_axes, x_advanced_axes = unzip3(advanced_pairs) x_axis = 0 # Current axis in x. y_axis = 0 # Current axis in y, before collapsing. See below. @@ -980,7 +1188,7 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], collapsed_slice_dims: list[int] = [] start_index_map: list[int] = [] - index_dtype = lax_utils.int_dtype_for_shape(x_shape, signed=True) + index_dtype = lax_utils.int_dtype_for_shape(indexer.shape, signed=True) # Gather indices. # Pairs of (array, start_dim) values. These will be broadcast into @@ -1002,11 +1210,11 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], gather_slice_shape: list[int] = [] slice_spec = [] - for idx_pos, i in enumerate(idx): + for idx_pos, index in enumerate(indexer.indices): # Handle the advanced indices here if: # * the advanced indices were not contiguous and we are the start. # * we are at the position of the first advanced index. - if (advanced_indexes is not None and + if (advanced_indexes and (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or not advanced_axes_are_contiguous and idx_pos == 0)): advanced_index_arrs = util._broadcast_arrays(*advanced_indexes) @@ -1035,46 +1243,35 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], gather_slice_shape.append(1) continue - # Handle basic int indexes. - abstract_i = _aval_or_none(i) - if isinstance(abstract_i, core.ShapedArray) and _int(abstract_i): - if core.definitely_equal(x_shape[x_axis], 0): + if index.typ in [IndexType.INTEGER, IndexType.ARRAY] and np.ndim(index.index) == 0: # type: ignore[arg-type] + # Basic scalar int indices + if core.definitely_equal(indexer.shape[x_axis], 0): # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {x_axis} with size 0") - i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i - i_converted = lax.convert_element_type(i, index_dtype) + i_converted = lax.convert_element_type(index.index, index_dtype) # type: ignore[arg-type] gather_indices.append((i_converted, len(gather_indices_shape))) collapsed_slice_dims.append(x_axis) gather_slice_shape.append(1) start_index_map.append(x_axis) x_axis += 1 - # Handle np.newaxis (None) - elif i is None: + + elif index.typ == IndexType.NONE: + # None indexing: add a dimension. slice_shape.append(1) slice_spec.append(None) newaxis_dims.append(y_axis) y_axis += 1 - elif isinstance(i, slice): - # Handle slice index (only static, otherwise an error is raised) - if not all(_is_slice_element_none_or_constant_or_symbolic(elt) - for elt in (i.start, i.stop, i.step)): - msg = ("Array slice indices must have static start/stop/step to be used " - "with NumPy indexing syntax. " - f"Found slice({i.start}, {i.stop}, {i.step}). " - "To index a statically sized " - "array at a dynamic position, try lax.dynamic_slice/" - "dynamic_update_slice (JAX does not support dynamically sized " - "arrays within JIT compiled functions).") - raise IndexError(msg) - - start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis]) + elif index.typ == IndexType.SLICE: + # Handle static slice index. + assert isinstance(index.index, slice) + start, step, slice_size = core.canonicalize_slice(index.index, indexer.shape[x_axis]) slice_shape.append(slice_size) slice_spec.append(x_spec[x_axis]) if core.definitely_equal(step, 1): - # Avoid generating trivial gather (an optimization) - if not core.definitely_equal(slice_size, x_shape[x_axis]): + # Optimization: avoid generating trivial gather. + if not core.definitely_equal(slice_size, indexer.shape[x_axis]): gather_indices.append((lax.convert_element_type(start, index_dtype), len(gather_indices_shape))) start_index_map.append(x_axis) @@ -1097,14 +1294,7 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], y_axis += 1 x_axis += 1 else: - if (abstract_i is not None and - not (dtypes.issubdtype(abstract_i.dtype, np.integer) or dtypes.issubdtype(abstract_i.dtype, np.bool_))): - msg = ("Indexer must have integer or boolean type, got indexer " - "with type {} at position {}, indexer value {}") - raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i)) - - raise IndexError("Indexing mode not yet supported. Got unsupported indexer " - f"at position {idx_pos}: {i!r}") + raise IndexError(f"Got unsupported indexer at position {idx_pos}: {index!r}") if len(gather_indices) == 0: gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype) @@ -1125,15 +1315,15 @@ def index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], start_index_map = tuple(start_index_map) ) slice_sharding = x_sharding.update(spec=slice_spec) - return _Indexer( + return _GatherIndexer( slice_shape=slice_shape, newaxis_dims=tuple(newaxis_dims), gather_slice_shape=gather_slice_shape, reversed_y_dims=reversed_y_dims, dnums=dnums, gather_indices=gather_indices_array, - unique_indices=advanced_indexes is None, - indices_are_sorted=advanced_indexes is None, + unique_indices=not advanced_indexes, + indices_are_sorted=not advanced_indexes, scalar_bool_dims=scalar_bool_dims, slice_sharding=slice_sharding) @@ -1178,52 +1368,6 @@ def _is_boolean_index(i): or isinstance(i, list) and i and all(_is_scalar(e) and dtypes.issubdtype(dtypes.dtype(e), np.bool_) for e in i)) -def _expand_bool_indices(idx, shape): - """Converts concrete bool indexes into advanced integer indexes.""" - out = [] - total_dims = len(shape) - num_ellipsis = sum(e is Ellipsis for e in idx) - if num_ellipsis > 1: - raise IndexError("an index can only have a single ellipsis ('...')") - elif num_ellipsis == 1: - total_dims = sum(np.ndim(e) if _is_boolean_index(e) else 1 for e in idx - if e is not None and e is not Ellipsis) - ellipsis_offset = 0 - newaxis_offset = 0 - for dim_number, i in enumerate(idx): - try: - abstract_i = core.get_aval(i) - except TypeError: - abstract_i = None - if _is_boolean_index(i): - if isinstance(i, list): - i = lax_numpy.array(i) - abstract_i = core.get_aval(i) - - if not core.is_concrete(i): - # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete - raise errors.NonConcreteBooleanIndexError(abstract_i) - elif np.ndim(i) == 0: - out.append(bool(i)) - else: - i_shape = np.shape(i) - start = len(out) + ellipsis_offset - newaxis_offset - expected_shape = shape[start: start + np.ndim(i)] - if len(i_shape) != len(expected_shape): - raise IndexError(f"too many boolean indices at index {dim_number}: got mask of shape " - f"{i_shape}, but only {len(expected_shape)} dimensions remain.") - if not all(s1 in (0, s2) for s1, s2 in zip(i_shape, expected_shape)): - raise IndexError("boolean index did not match shape of indexed array in index " - f"{dim_number}: got {i_shape}, expected {expected_shape}") - out.extend(np.where(i)) - else: - out.append(i) - if i is Ellipsis: - ellipsis_offset = len(shape) - total_dims - 1 - if i is None: - newaxis_offset += 1 - return tuple(out) - def _is_slice_element_none_or_constant_or_symbolic(elt): """Return True if elt is a constant or None.""" @@ -1234,23 +1378,6 @@ def _is_slice_element_none_or_constant_or_symbolic(elt): except TypeError: return False -# TODO(mattjj): clean up this logic -def _is_advanced_int_indexer(idx): - """Returns True if idx should trigger int array indexing, False otherwise.""" - # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing - assert isinstance(idx, tuple) - if all(e is None or e is Ellipsis or isinstance(e, slice) - or _is_scalar(e) and dtypes.issubdtype(dtypes.dtype(e), np.integer) for e in idx): - return False - return all(e is None or e is Ellipsis or isinstance(e, slice) - or _is_int_arraylike(e) for e in idx) - -def _is_int_arraylike(x): - """Returns True if x is array-like with integer dtype, False otherwise.""" - return (isinstance(x, int) and not isinstance(x, bool) - or dtypes.issubdtype(getattr(x, "dtype", None), np.integer) - or isinstance(x, (list, tuple)) and all(_is_int_arraylike(e) for e in x)) - def _is_scalar(x): """Checks if a Python or NumPy scalar.""" return np.isscalar(x) or ( diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 0cb35e310e25..f120d386b5fd 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -73,18 +73,19 @@ def _scatter_update(x: ArrayLike, idx: Index | tuple[Index, ...], # XLA gathers and scatters are very similar in structure; the scatter logic # is more or less a transpose of the gather equivalent. - treedef, static_idx, dynamic_idx = indexing.split_index_for_jit(idx, x.shape) + indexer = indexing.NDIndexer.from_raw_indices(idx, x.shape).expand_bool_indices() + dynamic_idx, treedef = tree_util.tree_flatten(indexer) internal_scatter = partial( _scatter_impl, scatter_op=scatter_op, treedef=treedef, - static_idx=static_idx, indices_are_sorted=indices_are_sorted, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, normalize_indices=normalize_indices) if out_sharding is not None: return auto_axes(internal_scatter, out_sharding=out_sharding, axes=out_sharding.mesh.explicit_axes # type: ignore )(x, y, dynamic_idx) - return internal_scatter(x, y, dynamic_idx) + return internal_scatter(x, y, tuple(dynamic_idx)) # TODO(phawkins): re-enable jit after fixing excessive recompilation for @@ -92,7 +93,7 @@ def _scatter_update(x: ArrayLike, idx: Index | tuple[Index, ...], # @jit(static_argnums=(2, 3, 4)) def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *, scatter_op: Callable[..., Array], - treedef: tree_util.PyTreeDef, static_idx: tuple[Any, ...], + treedef: tree_util.PyTreeDef, indices_are_sorted: bool, unique_indices: bool, mode: slicing.GatherScatterMode | str | None, normalize_indices: bool): dtype = lax.dtype(x) @@ -107,9 +108,8 @@ def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *, "In future JAX releases this will result in an error.", FutureWarning) - idx = indexing.merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) - indexer = indexing.index_to_gather(np.shape(x), idx, core.typeof(x).sharding, - normalize_indices=normalize_indices) + general_indexer = tree_util.tree_unflatten(treedef, dynamic_idx) + indexer = general_indexer.to_gather(core.typeof(x).sharding, normalize_indices=normalize_indices) # Avoid calling scatter if the slice shape is empty, both as a fast path and # to handle cases like zeros(0)[array([], int32)]. diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 742f5b90c5d3..c8f1a2824731 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -473,7 +473,7 @@ def test_simple_indexing(self, name, shape, dtype, indexer, strategy): ((2, 3), ([1, 2], 0), TypeError, "static_slice: indices must be static scalars or slices."), ((2, 3), (np.arange(2), 0), TypeError, "static_slice: indices must be static scalars or slices."), ((2, 3), (None, 0), TypeError, "static_slice: got None at position 0"), - ((2, 3), (1, 2, 3), IndexError, "Too many indices: 2-dimensional array indexed with 3 regular indices"), + ((2, 3), (1, 2, 3), IndexError, "Too many indices: array is 2-dimensional, but 3 were indexed"), ) def test_slice_oob_indexing_fails(self, shape, idx, err, msg): arr = jnp.zeros(shape) @@ -1335,11 +1335,11 @@ def _check_raises(x_type, y_type, msg): def testWrongNumberOfIndices(self): with self.assertRaisesRegex( IndexError, - "Too many indices: 0-dimensional array indexed with 1 regular index."): + "Too many indices: array is 0-dimensional, but 1 were indexed"): jnp.array(1)[0] with self.assertRaisesRegex( IndexError, - "Too many indices: 1-dimensional array indexed with 2 regular indices."): + "Too many indices: array is 1-dimensional, but 2 were indexed"): jnp.zeros(3)[:, 5] @jtu.sample_product(shape=[(), (1,)]) @@ -1350,6 +1350,13 @@ def testIndexDtypePromotion(self, shape): expected = np.array(999).reshape(shape) self.assertArraysEqual(numbers[999, idx], expected) + def testIndexingTypedNdArray(self): + x = jnp.arange(4) + i = dtypes.canonicalize_value(np.array([2, 0, 1])) + result = x[i] + expected = x[jnp.asarray(i)] + self.assertArraysEqual(result, expected) + def _broadcastable_shapes(shape): """Returns all shapes that broadcast to `shape`.""" @@ -1863,8 +1870,9 @@ class ValidateIndicesTest(jtu.JaxTestCase): ((2, 3), np.index_exp[..., -4], IndexError, "index -4 out of bounds for axis 1 with size 3"), ((2, 3, 5), np.index_exp[3, :, 0], IndexError, "index 3 out of bounds for axis 0 with size 2"), ((2, 3, 5), np.index_exp[:5, :, 6], IndexError, "index 6 out of bounds for axis 2 with size 5"), + ((2, 3, 5), np.index_exp[:, [1, 2], 6], IndexError, "index 6 out of bounds for axis 2 with size 5"), ((2, 3, 5), np.index_exp[np.arange(3), 6, None], IndexError, "index 6 out of bounds for axis 1 with size 3"), - ((2, 3), (1, 2, 3), IndexError, "Too many indices: 2-dimensional array indexed with 3 regular indices"), + ((2, 3), (1, 2, 3), IndexError, "Too many indices: array is 2-dimensional, but 3 were indexed"), ) def test_out_of_bound_indices(self, shape, idx, err, msg): """Test that out-of-bound indexing """ From ed4d8252b64cd6e1c614721e915e6670bcb48928 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 16 Dec 2025 10:59:51 -0800 Subject: [PATCH 227/315] Fix spmd_axis_name assert with explicit_mesh_axis in presence of multi-character mesh axis name PiperOrigin-RevId: 845351158 --- jax/_src/api.py | 6 +++--- tests/pjit_test.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 635e1d09195b..5c000437adc4 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1191,9 +1191,9 @@ def vmap_f(*args, **kwargs): _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap")) explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat) if spmd_axis_name is not None and explicit_mesh_axis is not None: - spmd_axis_name = ( - tuple(*core.remove_size_one_mesh_axis(P(spmd_axis_name), get_abstract_mesh())) - if config.remove_size_one_mesh_axis_from_type.value else spmd_axis_name) + if config.remove_size_one_mesh_axis_from_type.value: + mesh = get_abstract_mesh() + spmd_axis_name = tuple(i for i in spmd_axis_name if mesh.shape[i] != 1) if spmd_axis_name == explicit_mesh_axis: spmd_axis_name = None else: diff --git a/tests/pjit_test.py b/tests/pjit_test.py index cfdfce42a511..210cb940c5b3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -7275,11 +7275,11 @@ def f(x): f(arr) @parameterized.parameters( - (('x', 'y', 'z'), ('x', 'y')), - (('x', 'z'), 'x') + (('data', 'model', 'stage'), ('data', 'model')), + (('data', 'stage'), 'data') ) @config.remove_size_one_mesh_axis_from_type(True) - @jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z')) + @jtu.with_explicit_mesh((2, 2, 1), ('data', 'model', 'stage')) def test_spmd_axis_name_explicit_mode_assert_remove_one_size( self, in_spec, out_spec, mesh): np_inp = np.arange(16).reshape(4, 2, 2) From ba024e35222f9bebadfe7f67395246e11e1eefc5 Mon Sep 17 00:00:00 2001 From: Subhankar Shah Date: Tue, 16 Dec 2025 11:43:52 -0800 Subject: [PATCH 228/315] Remove dynamic grid bounds restriction in Pallas Mosaic with memory spaces. PiperOrigin-RevId: 845370892 --- .../pallas/mosaic/pallas_call_registration.py | 5 ---- tests/pallas/tpu_pallas_test.py | 27 +++++++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 125eafa4eaf1..8cebd5965547 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -208,11 +208,6 @@ def _maybe_cast_inputs(*args): isinstance(aval, pallas_core.ShapedArrayWithMemorySpace) for aval in ctx.avals_in ): - # TODO(sharadmv): Support dynamic grid bounds. - if num_dyn_bounds != 0: - raise NotImplementedError( - "Dynamic grid bounds are not supported when specifying memory spaces for inputs." - ) input_memory_spaces = _get_memory_spaces_from_avals( ctx.avals_in, kernel_type=kernel_type ) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index cddfb9905ea9..36a21127a798 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -684,6 +684,33 @@ def dynamic_kernel(steps): dynamic_kernel(jnp.int32(4)), np.full(shape, 42.0, np.float32) ) + def test_dynamic_grid_scalar_input_with_input_memory_space(self): + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Needs a newer TPU') + shape = (8, 128) + result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) + + def kernel(scalar_input_ref, output_ref): + output_ref[...] = jnp.full_like(output_ref, scalar_input_ref[0, 0]) + + @jax.jit + def dynamic_kernel(steps): + scalar_input = jnp.array([[42]], dtype=jnp.int32) + scalar_input = pltpu.with_memory_space_constraint( + scalar_input, pltpu.VMEM + ) + return self.pallas_call( + kernel, + out_shape=result_ty, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(shape, lambda i: (0, 0)), + grid=(steps * 2,), + )(scalar_input) + + np.testing.assert_array_equal( + dynamic_kernel(jnp.int32(4)), np.full(shape, 42.0, np.float32) + ) + def test_vmap_trivial_dynamic_grid(self): shape = (8, 128) result_ty = jax.ShapeDtypeStruct(shape, jnp.float32) From ad2b91459d365e4b024c7ac2eb5a9815a223a601 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 16 Dec 2025 14:29:31 -0800 Subject: [PATCH 229/315] [Pallas] Allow no mesh context if just signaling on core axis PiperOrigin-RevId: 845435562 --- jax/_src/pallas/primitives.py | 44 ++++++++++++++++----------- tests/pallas/tpu_pallas_state_test.py | 18 +++++++++++ 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 8cf362d0bc3b..2706f2ecc9b2 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -1351,20 +1351,26 @@ def _semaphore_wait_discharge_rule(in_avals, ) -def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict, get_axis_index): +def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo | None, device_id_dict, get_axis_index): i32 = ir.IntegerType.get_signless(32) - assert mesh_context is not None - mesh_axis_sizes = dict(zip(mesh_context.axis_names, mesh_context.mesh_shape)) + if mesh_context is None: + mesh_axis_sizes = {} + else: + mesh_axis_sizes = dict( + zip(mesh_context.axis_names, mesh_context.mesh_shape) + ) physical_axis_dict = {} # Handle joint axes (i.e., one logical axis over >1 physical axes) - for axis, idx in device_id_dict.items(): - if isinstance(axis, tuple) and any(a in mesh_context.axis_names for a in axis): - if not all(a in mesh_context.axis_names for a in axis): + for axis_name, idx in device_id_dict.items(): + if isinstance(axis_name, tuple) and any( + a in mesh_axis_sizes for a in axis_name + ): + if not all(a in mesh_axis_sizes for a in axis_name): raise NotImplementedError( - f"{axis} mixes JAX mesh and Pallas mesh grid axes" + f"{axis_name} mixes JAX mesh and Pallas mesh grid axes" ) - axes_dimensions = [mesh_axis_sizes[name] for name in axis] - for axis_index, axis_name in enumerate(axis): + axes_dimensions = [mesh_axis_sizes[name] for name in axis_name] + for axis_index, axis_name in enumerate(axis_name): axis_size = mesh_axis_sizes[axis_name] inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :]) minor_divisor = arith.constant(i32, inner_mesh_size) @@ -1387,17 +1393,17 @@ def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo, device_id_dict, ) physical_axis_dict[axis_name] = device_idx else: - physical_axis_dict[axis] = idx + physical_axis_dict[axis_name] = idx device_id = [] - for axis in mesh_context.axis_names: - if axis in physical_axis_dict: - device_id.append(physical_axis_dict[axis]) + for axis_name in mesh_axis_sizes: + if axis_name in physical_axis_dict: + device_id.append(physical_axis_dict[axis_name]) else: - device_id.append(get_axis_index(axis)) + device_id.append(get_axis_index(axis_name)) non_mesh_axes = { k: v for k, v in physical_axis_dict.items() - if k not in mesh_context.axis_names + if k not in mesh_axis_sizes } return tuple(device_id), non_mesh_axes @@ -1419,13 +1425,15 @@ def device_id_to_logical( "`device_id_type` must be MESH if `device_id` is a dict," f" got: {device_id_type = }." ) - assert mesh_context is not None device_id, non_mesh_axes = _device_id_dict_to_mesh(mesh_context, device_id, get_axis_index) if device_id_type is DeviceIdType.MESH: - assert mesh_context is not None # Mesh means we are passed the mesh coordinates for the device device_ids = tree_util.tree_leaves(device_id) - mesh_strides = mesh_context.mesh_strides + mesh_strides: tuple[int, ...] + if mesh_context is None: + mesh_strides = () + else: + mesh_strides = mesh_context.mesh_strides if len(device_ids) != len(mesh_strides): raise ValueError( "Number of device ids must match the number of mesh axes, but got" diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index fa5ce0778ecc..f0707280ff55 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -273,6 +273,24 @@ def _(): "Attempted to lower core_map without discharging."): f(x) + def test_can_signal_cores(self): + @jax.jit + def f(x): + x_ref = jax.new_ref(x) + y_ref = jax.new_ref(jnp.empty_like(x)) + @pl.core_map(pltpu.create_tensorcore_mesh("x")) + def _(): + @functools.partial(pl.run_scoped, sem=pltpu.SemaphoreType.REGULAR) + def inner(sem): + s = jax.lax.axis_size("x") + for i in range(s): + pl.semaphore_signal(sem, device_id={"x": i}) + pl.semaphore_wait(sem, s) + pltpu.sync_copy(x_ref, y_ref) + return jax.freeze(y_ref) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) + np.testing.assert_array_equal(f(x), x) + def test_can_query_core_index(self): mesh = pltpu.create_tensorcore_mesh("x") slc_size = 16 // mesh.shape["x"] From 9a391d74408f969057f6051d65727831a2eb0610 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 16 Dec 2025 16:31:11 -0800 Subject: [PATCH 230/315] Add `to_cotangent_aval` to HiType and use it by default in bwd pass PiperOrigin-RevId: 845479579 --- jax/_src/api.py | 4 +--- jax/_src/core.py | 4 ++++ jax/_src/hijax.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 5c000437adc4..94dc483fd805 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2332,9 +2332,7 @@ def _vjp_check_ct_avals(cts, primal_avals): # TODO(mattjj): improve this error by flattening with keys in the first place for ct, aval in zip(cts, primal_avals): ct_aval = typeof(ct) - ct_aval_expected = ( - aval.to_cotangent_aval() if hasattr(aval, 'to_cotangent_aval') else - aval.to_tangent_aval()) + ct_aval_expected = aval.to_cotangent_aval() if (not core.typecompat(ct_aval, ct_aval_expected) and not _temporary_dtype_exception(ct_aval, ct_aval_expected)): raise ValueError( diff --git a/jax/_src/core.py b/jax/_src/core.py index 88c3c402250d..2d909156ad11 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1654,6 +1654,9 @@ class AbstractValue: def to_tangent_aval(self): raise NotImplementedError("must override") + def to_cotangent_aval(self): + raise NotImplementedError("must override") + # TODO(dougalm): deprecate this alias def at_least_vspace(self): return self.to_tangent_aval() @@ -2619,6 +2622,7 @@ def accum_grad_in_ref(x): class AbstractToken(AbstractValue): def str_short(self, short_dtypes=False, mesh_axis_types=False): return 'Tok' def to_tangent_aval(self): return self + def to_cotangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() # Singleton shaped array used by all abstract tokens when shape/dtype is needed. diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py index d35da8f88752..fcb9b9b8f9f3 100644 --- a/jax/_src/hijax.py +++ b/jax/_src/hijax.py @@ -92,9 +92,16 @@ def raise_val(self, *lo_vals: LoVal) -> HiVal: # autodiff interface def to_tangent_aval(self) -> HiType: assert False, "must override" + + # Subclasses should override if the cotangent type is a function of primal + # type. For example, CT unreduced = reduced and vice-versa. + def to_cotangent_aval(self) -> HiType: + return self.to_tangent_aval() + # the next two are required if this type is itself a tangent type def vspace_zero(self) -> HiVal: assert False, "must override" + def vspace_add(self, x: HiVal, y: HiVal) -> HiVal: assert False, "must override" @@ -127,6 +134,11 @@ def update_from_loval(self, state: QDD, val: HiVal, *lo_vals: LoVal) -> None: def to_tangent_aval(self) -> HiType: assert False, "must override" + # Subclasses should override if the cotangent type is a function of primal + # type. For example, CT unreduced = reduced and vice-versa. + def to_cotangent_aval(self) -> HiType: + return self.to_tangent_aval() + def register_hitype(val_cls, typeof_fn) -> None: core.pytype_aval_mappings[val_cls] = typeof_fn dtypes.canonicalize_value_handlers[val_cls] = lambda x: x From 691bf927f80ce103aee816a1537fc055522578f4 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 16 Dec 2025 19:12:33 -0800 Subject: [PATCH 231/315] [pmap] Created `NamedSharding` arrays when `jax_pmap_shmap_merge=True` in `jax.device_put_replicated` and `jax.device_put_sharded`. PiperOrigin-RevId: 845530346 --- jax/_src/api.py | 18 +++++++++++++----- tests/pmap_test.py | 5 ++++- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 94dc483fd805..01926d919860 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -68,7 +68,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib from jax._src.sharding import Sharding -from jax._src.mesh import get_concrete_mesh, get_abstract_mesh +from jax._src.mesh import get_concrete_mesh, get_abstract_mesh, Mesh from jax._src.sharding_impls import (PmapSharding, PartitionSpec as P, NamedSharding) from jax._src.layout import Format @@ -2800,8 +2800,12 @@ def _device_put_sharded(*xs): raise ValueError("the shards passed to device_put_sharded must have " f"consistent shape and dtype, but got {a1} and {a2}.") stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape) - sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape) - sharding = PmapSharding(np.array(devices), sharding_spec) + if config.pmap_shmap_merge.value: + mesh = Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = NamedSharding(mesh, P('_device_put_sharded')) + else: + sharding_spec = sharding_specs.create_pmap_sharding_spec(stacked_aval.shape) + sharding = PmapSharding(np.array(devices), sharding_spec) if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended): return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices) if config.pmap_no_rank_reduction.value: @@ -2856,7 +2860,6 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 def _device_put_replicated(x): aval = core.unmapped_aval(len(devices), 0, core.get_aval(x)) assert isinstance(aval, ShapedArray) - sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape) if config.pmap_no_rank_reduction.value: if isinstance(x, (np.ndarray, basearray.Array)): buf = device_put(x[None], devices[0]) @@ -2864,7 +2867,12 @@ def _device_put_replicated(x): buf = device_put(x, devices[0])[None] else: buf = device_put(x, devices[0]) - sharding = PmapSharding(np.array(devices), sharding_spec) + if config.pmap_shmap_merge.value: + mesh = Mesh(np.array(devices), ('_device_put_replicated',)) + sharding = NamedSharding(mesh, P('_device_put_replicated')) + else: + sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape) + sharding = PmapSharding(np.array(devices), sharding_spec) if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices) return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index e396f7df0151..6508e2a458d9 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -2951,7 +2951,10 @@ def test_device_put_sharded(self): x = [np.arange(i, i + 4) for i in range(n_devices)] y = jax.device_put_sharded(x, devices) self.assertIsInstance(y, array.ArrayImpl) - self.assertIsInstance(y.sharding, jax.sharding.PmapSharding) + if config.pmap_shmap_merge.value: + self.assertIsInstance(y.sharding, jax.NamedSharding) + else: + self.assertIsInstance(y.sharding, jax.sharding.PmapSharding) for s in y.addressable_shards: self.assertArraysEqual(s.data, y[s.index]) self.assertEqual(s.replica_id, 0) From 4826ca7dfb1d3cce5ec471b971f1c891096e1717 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 16 Dec 2025 19:50:38 -0800 Subject: [PATCH 232/315] Make c-api topology and PjRtClient versions produce identical platform_verison strings to improve cache reuse between aot and actual runtime uses. Remove runtime_type as a cache key (compilation shouldn't depend on the runtime). PiperOrigin-RevId: 845541086 --- jax/_src/cache_key.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 493e49a2e086..296f65b5ed3f 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -334,7 +334,6 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, def _hash_platform(hash_obj, backend): _hash_string(hash_obj, backend.platform) _hash_string(hash_obj, backend.platform_version) - _hash_string(hash_obj, backend.runtime_type) def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]): From 27d29b9d205a5bbea5d38d877956091fdade5a30 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 16 Dec 2025 20:29:15 -0800 Subject: [PATCH 233/315] Add `reshard` to shard_map under full explicit mode if the input aval's sharding differs from `in_specs` passed to shard_map. Since `in_specs` of shard_map signal a reshard historically, we follow the same semantics here. Without this, the primal and cotangent types don't match. Also add checks in bwd pass to make sure input primal type matches the cotangent type. Some TODOs: * Figure out `normalize` i.e. weak_types with HiType because typematch strips away weak_types. * Support partial-manual in shard_map's reshard code PiperOrigin-RevId: 845555625 --- jax/_src/core.py | 31 ++++++++++++++++++------------- jax/_src/custom_derivatives.py | 15 ++++++++------- jax/_src/interpreters/ad.py | 33 +++++++++++++++++++++++++++------ jax/_src/pjit.py | 34 +++++++++++++++++++++++----------- jax/_src/shard_map.py | 10 ++++++++++ jax/_src/stages.py | 2 +- jax/_src/state/types.py | 7 ++++++- tests/custom_api_test.py | 2 +- tests/hijax_test.py | 10 ++++++++++ tests/pjit_test.py | 32 ++++++++++++++++---------------- tests/shard_map_test.py | 27 +++++++++++++++++++++------ 11 files changed, 141 insertions(+), 62 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 2d909156ad11..1b64ae57c79e 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3084,7 +3084,8 @@ def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool: except TypeError: return False -def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: +def typematch(t1: AbstractValue, t2: AbstractValue, + only_sharding_check: bool = False) -> bool: """Determine whether `t1` and `t2` are equivalent. Ignores weak_type.""" t1 = t1.normalize() t2 = t2.normalize() @@ -3092,25 +3093,29 @@ def typematch(t1: AbstractValue, t2: AbstractValue) -> bool: if t1 == t2: return True elif isinstance(t1, ShapedArray) and isinstance(t2, ShapedArray): - cmp = (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) - and t1.vma == t2.vma and t1.memory_space == t2.memory_space) # type: ignore - # TODO(yashkatariya): Expand this to Manual and Auto mode. - # See https://github.com/jax-ml/jax/issues/26474 - if (not t1.sharding.mesh.empty and not t2.sharding.mesh.empty and - (t1.sharding.mesh._any_axis_explicit or - t2.sharding.mesh._any_axis_explicit)): - sh_eq = t1.sharding == t2.sharding - else: - sh_eq = True - return cmp and sh_eq + if only_sharding_check: + return cmp_sharding_vma(t1, t2) + return (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) + and cmp_sharding_vma(t1, t2) and t1.memory_space == t2.memory_space) elif isinstance(t1, AbstractRef) and isinstance(t2, AbstractRef): # We want to use the regular typecheck for ShapedArray here. - return (typematch(t1.inner_aval, t2.inner_aval) and # type: ignore + return (typematch(t1.inner_aval, t2.inner_aval, only_sharding_check) and # type: ignore (t1.memory_space is None or t2.memory_space is None or # type: ignore t1.memory_space == t2.memory_space)) # type: ignore else: return False +def cmp_sharding_vma(t1, t2): + # TODO(yashkatariya): Expand this to Manual and Auto mode. + # See https://github.com/jax-ml/jax/issues/26474 + if (not t1.sharding.mesh.empty and not t2.sharding.mesh.empty and + (t1.sharding.mesh._any_axis_explicit or + t2.sharding.mesh._any_axis_explicit)): + sh_eq = t1.sharding == t2.sharding + else: + sh_eq = True + return sh_eq and t1.vma == t2.vma + def aval_mismatch_extra(a1: AbstractValue, a2: AbstractValue) -> str: assert not typematch(a1, a2) if isinstance(a1, ShapedArray) and isinstance(a2, ShapedArray): diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 4f5ba126ea75..4c7c87c39817 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -964,14 +964,14 @@ def append(x, d): raise ValueError(msg) results.append(Zero(ct.aval)) else: - if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct)) and - not _ref_typecompat(a.to_tangent_aval(), a_) and - not _temporary_dtype_exception(a, a_)): + if (not core.typecompat(a.to_cotangent_aval(), a_ := core.get_aval(ct)) + and not _ref_typecompat(a.to_cotangent_aval(), a_) + and not _temporary_dtype_exception(a.to_cotangent_aval(), a_)): msg = ("Custom VJP bwd rule must produce an output with the same " - "shape/dtypes as the args tuple of the primal function, but at " + "type as the args tuple of the primal function, but at " f"output{keystr(kp)} the bwd rule produced an output of " - f"shape/dtype {a_.str_short()} corresponding " - f"to an input of shape/dtype {a.str_short()}" + f"type {a_.str_short()} corresponding " + f"to an input of type {a.str_short()}" f"{core.aval_mismatch_extra(a, a_)}") raise ValueError(msg) results.append(ct) @@ -979,12 +979,13 @@ def append(x, d): def _ref_typecompat(a, a_): return (isinstance(a, AbstractRef) and - core.typecompat(a.to_tangent_aval().inner_aval, a_)) + core.typecompat(a.to_cotangent_aval().inner_aval, a_)) # TODO(mattjj): remove both these exceptions to cotangent compatibility check def _temporary_dtype_exception(a, a_) -> bool: if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): return (a.shape == a_.shape and + core.typematch(a, a_, only_sharding_check=True) and (dtypes.issubdtype(a_.dtype, dtypes.extended) or dtypes.issubdtype(a.dtype, dtypes.np.inexact))) return False diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 459bf5397f92..5696fc3ac11d 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -71,7 +71,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) - and isinstance(core.typeof(t), core.ShapedArray) + and isinstance(typeof(t), core.ShapedArray) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) @@ -237,7 +237,7 @@ def direct_linearize(traceable: lu.WrappedFun, primals, kwargs, *, tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True) tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals] tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) - and isinstance(core.typeof(t), core.ShapedArray) + and isinstance(typeof(t), core.ShapedArray) and dtype(t) == float0 else t for t in tangents] linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag) tangent_trace.tag = linearize_trace.tag @@ -319,6 +319,13 @@ def write_cotangent(prim, v, ct): # FIXME: This triggers a lot of failures! # assert v.aval == ct.aval, (prim, v.aval, ct.aval) return + ct_aval = typeof(ct) + ct_aval_expected = v.aval.to_cotangent_aval() # type: ignore + if not core.typematch(ct_aval, ct_aval_expected, only_sharding_check=True): + raise ValueError( + f"Input primal JAX type to {prim.name} is {v.aval.str_short()}. Hence" + f" the expected cotangent type is {ct_aval_expected.str_short()} but" + f" got {ct_aval.str_short()}") ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct def read_cotangent(v): @@ -548,6 +555,7 @@ def __init__(self, aval, ref=None): def accum(self, x): assert x is not Zero + ct_check(self, x) if isinstance(x, Zero) or x is None: return elif self.ref is None: @@ -575,12 +583,25 @@ def __init__(self, aval, val=None): self.val = Zero(aval) if val is None else val def accum(self, x): + ct_check(self, x) if x is not None: self.val = add_tangents(self.val, x) def freeze(self): return self.val +def ct_check(primal, ct): + ct_aval = ct.aval if type(ct) is Zero else typeof(ct) + ct_aval_expected = primal.aval.to_cotangent_aval() # type: ignore + if not core.typematch(ct_aval, ct_aval_expected, only_sharding_check=True): + # TODO(yashkatariya, mattjj): Add primitive name here for + # better error message? + raise ValueError( + f"Input primal JAX type to VJP function is" + f" {primal.aval.str_short()}. Hence the expected" + f" cotangent type is {ct_aval_expected.str_short()} but" + f" got {ct_aval.str_short()}") + class NullAccum(GradAccum): def __init__(self): pass def accum(self, x): return @@ -619,7 +640,7 @@ def accum_typeof(x): if isinstance(x, GradAccum): return x.aval else: - return core.typeof(x) + return typeof(x) @lu.transformation_with_aux2 @@ -647,7 +668,7 @@ def process_primitive(self, primitive, tracers, params): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if (all(type(t) is Zero for t in tangents_in) and primitive is not core.ref_p and - not any(isinstance(core.typeof(x), AbstractRef) for x in primals_in)): + not any(isinstance(typeof(x), AbstractRef) for x in primals_in)): return primitive.bind_with_trace(self.parent_trace, primals_in, params) jvp = primitive_jvps.get(primitive) if not jvp: @@ -779,7 +800,7 @@ def process_custom_transpose(self, prim, call, tracers, **params): def maybe_jvp_tracer(trace, primal, tangent): if (type(tangent) is Zero or - isinstance(core.typeof(tangent), core.ShapedArray) + isinstance(typeof(tangent), core.ShapedArray) and dtype(tangent) == float0): return primal else: @@ -871,7 +892,7 @@ def process_primitive(self, primitive, args, params): tangent_nzs = [type(t) is not Zero for t in tangents_in] if (all(type(t) is Zero for t in tangents_in) and primitive is not core.ref_p and - not any(isinstance(core.typeof(x), AbstractRef) for x in primals_in)): + not any(isinstance(typeof(x), AbstractRef) for x in primals_in)): return primitive.bind_with_trace(self.parent_trace, primals_in, params) fallback = partial(fallback_linearize_rule, primitive) lin = primitive_linearizations.get(primitive, fallback) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 00738dc31194..26efa7b85bd9 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2694,40 +2694,51 @@ def reshard(xs, out_shardings): f'and have a nonempty mesh. Got sharding {s}.' ) ds = ds.update(spec=ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error - out_flat.append(reshard_p.bind(x, dst_sharding=ds)) + cmesh = (s.mesh if (isinstance(s, NamedSharding) and + isinstance(s.mesh, mesh_lib.Mesh)) + else None) + out_flat.append(reshard_p.bind(x, dst_sharding=ds, concrete_mesh=cmesh)) return tree_unflatten(treedef, out_flat) reshard_p = core.Primitive('reshard') reshard_p.skip_canonicalization = True -def _reshard_abstract_eval(aval, dst_sharding): +def _reshard_abstract_eval(aval, *, dst_sharding, concrete_mesh): assert isinstance(aval, core.ShapedArray) if aval.sharding == dst_sharding: return aval return aval.update(sharding=dst_sharding) reshard_p.def_abstract_eval(_reshard_abstract_eval) -def _reshard_impl(x, dst_sharding): - return dispatch.apply_primitive(reshard_p, x, dst_sharding=dst_sharding) +def _reshard_impl(x, *, dst_sharding, concrete_mesh): + thunk = lambda: dispatch.apply_primitive( + reshard_p, x, dst_sharding=dst_sharding, concrete_mesh=concrete_mesh) + if concrete_mesh is None: + return thunk() + else: + with sharding_impls.set_mesh(concrete_mesh): + return thunk() reshard_p.def_impl(_reshard_impl) -def _reshard_transpose_rule(ct, x, dst_sharding): +def _reshard_transpose_rule(ct, x, *, dst_sharding, concrete_mesh): assert ad.is_undefined_primal(x) out_sharding = x.aval.to_cotangent_aval().sharding with mesh_lib.use_abstract_mesh(out_sharding.mesh): - x_bar = reshard_p.bind(ct, dst_sharding=out_sharding) + x_bar = reshard_p.bind(ct, dst_sharding=out_sharding, + concrete_mesh=concrete_mesh) return [x_bar] ad.deflinear2(reshard_p, _reshard_transpose_rule) -def _reshard_transpose_fancy(ct, x, dst_sharding): +def _reshard_transpose_fancy(ct, x, *, dst_sharding, concrete_mesh): assert isinstance(x, ad.GradAccum) out_sharding = x.aval.to_cotangent_aval().sharding with mesh_lib.use_abstract_mesh(out_sharding.mesh): - x_bar = reshard_p.bind(ct, dst_sharding=out_sharding) + x_bar = reshard_p.bind(ct, dst_sharding=out_sharding, + concrete_mesh=concrete_mesh) x.accum(x_bar) ad.fancy_transposes[reshard_p] = _reshard_transpose_fancy -def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding): +def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding, concrete_mesh): aval_in, = ctx.avals_in aval_out, = ctx.avals_out if dtypes.issubdtype(aval_in.dtype, dtypes.extended): @@ -2738,12 +2749,13 @@ def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding): return [mlir.lower_with_sharding_in_types(ctx, x_node, aval_out, proto)] mlir.register_lowering(reshard_p, _reshard_hlo_lowering) -def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding): +def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding, concrete_mesh): x, = vals_in d, = dims_in vmapped_dst_sharding = batching.get_sharding_for_vmap( axis_data, dst_sharding, d) - y = reshard_p.bind(x, dst_sharding=vmapped_dst_sharding) + y = reshard_p.bind(x, dst_sharding=vmapped_dst_sharding, + concrete_mesh=concrete_mesh) return y, d batching.fancy_primitive_batchers[reshard_p] = _reshard_batcher batching.skippable_batchers[reshard_p] = lambda _: () diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index d2682253b6be..d77c1299fcf0 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -40,6 +40,7 @@ from jax._src.mesh import (AbstractMesh, Mesh, BaseMesh, AxisType, use_abstract_mesh, get_abstract_mesh, get_concrete_mesh) +from jax._src.pjit import reshard from jax._src.lax import lax, parallel as lax_parallel from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo, sdy @@ -248,6 +249,15 @@ def wrapped(*args): fun = _implicit_pvary_on_output(fun, out_specs_thunk) fun = _implicit_unreduced_on_output(fun, out_specs_thunk) + # TODO(yashkatariya): Add support for partial manual + mesh_axis_names_wo_vmap = ( + frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names) + if (mesh_axis_names_wo_vmap == axis_names and + all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): + args_flat = [a if typeof(a).sharding.spec == s + else reshard(a, NamedSharding(mesh, s)) + for a, s in zip(args_flat, in_specs_flat)] + try: out_flat = shard_map_p.bind( fun, *args_flat, mesh=mesh, in_specs=in_specs_flat, diff --git a/jax/_src/stages.py b/jax/_src/stages.py index af87fc736210..387fca3ee460 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -1038,7 +1038,7 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name, first, second = mismatched_args_msg # pytype: disable=bad-unpacking extra_msg = f" Got {first} and {second}" elif len(mismatched_args_msg) == 1: - first, second = fails + first, second = fails # Choose the failure left which is not already covered by ARG_SHARDING. left = second if first.m_type == MismatchType.ARG_SHARDING else first extra_msg = f" Got {mismatched_args_msg[0]} and{left._str(api_name)}" diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index e9f589a10f27..3c7c09684f99 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -555,7 +555,12 @@ def __repr__(self) -> str: __str__ = __repr__ def to_tangent_aval(self): - return AbstractRef(self.inner_aval.to_tangent_aval(), self.memory_space, kind=self.kind) + return AbstractRef(self.inner_aval.to_tangent_aval(), self.memory_space, + kind=self.kind) + + def to_cotangent_aval(self): + return AbstractRef(self.inner_aval.to_cotangent_aval(), self.memory_space, + kind=self.kind) def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 9a5b762cefe7..143fe6ce01a1 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -2968,7 +2968,7 @@ def foo_bwd(_, g): with self.assertRaisesRegex( ValueError, - r'output\[1\] the bwd rule produced an output of shape/dtype float..\[3\]'): + r'output\[1\] the bwd rule produced an output of type float..\[3\]'): jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) def test_bwd_rule_can_produce_list_or_tuple(self): diff --git a/tests/hijax_test.py b/tests/hijax_test.py index 2f665ce76aed..bd1e199864d2 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -171,9 +171,16 @@ def __repr__(self): @dataclass(frozen=True) class TupTy(HiType): tys: tuple[Ty] + def __repr__(self): return 'Tup{' + ','.join(a.str_short() for a in self.tys) + '}' + def __hash__(self): + return hash(self.tys) + + def __eq__(self, other): + return self.tys == other.tys + def lo_ty(self): return list(self.tys) @@ -189,6 +196,9 @@ def raise_val(self, *elts_flat): def to_tangent_aval(self): return TupTy(tuple(ty.to_tangent_aval() for ty in self.tys)) + def normalize(self): + return TupTy(tuple(ty.normalize() for ty in self.tys)) + register_hitype(HiTup, lambda t: TupTy(tuple(map(typeof, t.elts)))) class MakeTup(HiPrimitive): diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 210cb940c5b3..243772588c91 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -9439,13 +9439,8 @@ def mubatch_loop_body(grad_acc, xs_mubatch): ws = reshard(ws, P()) return jax.tree.map(lambda W, g: W - g * 0.01, ws, grad_acc) - if use_custom_vjp: - ws = tuple(jax.device_put(jnp.ones((4, 4)), P()) for _ in range(4)) - else: - # Mark `w` with `reduced={'x'}` so that on the bwd pass we will induce - # an `unreduced={'x'}`. - ws = tuple(jax.device_put(jnp.ones((4, 4)), P(reduced={'x'})) - for _ in range(4)) + ws = tuple(jax.device_put(jnp.ones((4, 4)), P(reduced={'x'})) + for _ in range(4)) xs = jax.device_put(jnp.ones((2, 2, 4)), P(None, 'x', None)) step(ws, xs) # doesn't crash @@ -9529,15 +9524,9 @@ def mubatch_loop_body(stacked_grad_acc, xs_mubatch): return jax.tree.map( lambda W, g: W - g * 0.01, stacked_ws, stacked_grad_acc) - if use_custom_vjp: - ws = tuple(jax.device_put(jnp.ones((4, 4), dtype=jnp.float32), P()) - for _ in range(4)) - else: - # Mark `w` with `reduced={'x'}` so that on the bwd pass we will induce - # an `unreduced={'x'}`. - ws = tuple(jax.device_put(jnp.ones((4, 4), dtype=jnp.float32), - P(reduced={'x'})) - for _ in range(4)) + ws = tuple(jax.device_put(jnp.ones((4, 4), dtype=jnp.float32), + P(reduced={'x'})) + for _ in range(4)) xs = jax.device_put(jnp.ones((2, 4, 4), dtype=jnp.bfloat16), P(None, 'x', None)) @@ -9965,6 +9954,17 @@ def f(): self.assertEqual(f().sharding, NamedSharding(mesh, P(None))) + def test_reshard_no_mesh_ctx(self): + mesh = jtu.create_mesh((2,), 'x') + with self.assertRaisesRegex( + ValueError, "cannot contain axis names that are of type Auto"): + reshard(np.arange(8), NamedSharding(mesh, P('x'))) + + mesh = jtu.create_mesh((2,), 'x', axis_types=(AxisType.Explicit,)) + out = reshard(np.arange(8), NamedSharding(mesh, P('x'))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out, np.arange(8)) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 7419efb3394e..249898952e85 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2492,11 +2492,10 @@ def f_bwd(_, g): def g(x): return f(f(x)) - y, grad = jax.value_and_grad(lambda x: g(x).sum())(jnp.ones(4)) - # first psum sums, second psum multiplies by 4 - self.assertAllClose(y, (jnp.ones(4) * 4).sum(), check_dtypes=False) - # two psums on the backward pass, each one multiplies by 4 - self.assertAllClose(grad, jnp.ones(4) * 4 * 4, check_dtypes=False) + with self.assertRaisesRegex( + ValueError, + "Custom VJP bwd rule must produce an output with the same type"): + jax.value_and_grad(lambda x: g(x).sum())(jnp.ones(4)) def test_repeated_psum_allowed(self): # https://github.com/jax-ml/jax/issues/19175 @@ -3809,6 +3808,7 @@ def test_explicit_vmap_grad_shmap(self, use_axis_name, mesh): def g(x): self.assertEqual(x.aval.vma, frozenset()) + self.assertEqual(x.aval.sharding.spec, P(None)) if use_axis_name: out = jax.shard_map(jnp.cos, in_specs=P('y'), out_specs=P('y'), axis_names={'y'})(x) @@ -3818,7 +3818,7 @@ def g(x): return out.sum() out = jax.jit(jax.vmap(jax.grad(g)))(arr) - self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) def test_get_check_rep(self): mesh = jtu.create_mesh((2, 2), ('x', 'y')) @@ -4762,6 +4762,21 @@ def g(x): jax.jit(jax.grad(g))(arr) # doesn't crash + @jtu.with_explicit_mesh((2,), 'x') + def test_shmap_primal_type_match_ct_type(self, mesh): + arr = jax.device_put(np.arange(8.), P('x')) + + @jax.jit + @jax.shard_map(in_specs=P(), out_specs=P('x')) + def f(x): + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + out_g = jax.jit(jax.grad(lambda x: f(x).sum()))(arr) + self.assertEqual(out_g.sharding, NamedSharding(mesh, P('x'))) + class FunSpec(NamedTuple): name: str From 871245d1f53be5e437b83600e11229aa138c1434 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 16 Dec 2025 21:08:09 -0800 Subject: [PATCH 234/315] Add shape checks to `ct_check` function too PiperOrigin-RevId: 845568297 --- jax/_src/core.py | 21 +++++++++++---------- jax/_src/custom_derivatives.py | 2 +- jax/_src/interpreters/ad.py | 4 ++-- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 1b64ae57c79e..4d3e3435e583 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3085,7 +3085,7 @@ def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool: return False def typematch(t1: AbstractValue, t2: AbstractValue, - only_sharding_check: bool = False) -> bool: + only_shape_shd_check: bool = False) -> bool: """Determine whether `t1` and `t2` are equivalent. Ignores weak_type.""" t1 = t1.normalize() t2 = t2.normalize() @@ -3093,28 +3093,29 @@ def typematch(t1: AbstractValue, t2: AbstractValue, if t1 == t2: return True elif isinstance(t1, ShapedArray) and isinstance(t2, ShapedArray): - if only_sharding_check: - return cmp_sharding_vma(t1, t2) - return (t1.dtype == t2.dtype and definitely_equal_shape(t1.shape, t2.shape) - and cmp_sharding_vma(t1, t2) and t1.memory_space == t2.memory_space) + if only_shape_shd_check: + return cmp_shape_sharding_vma(t1, t2) + return (t1.dtype == t2.dtype and cmp_shape_sharding_vma(t1, t2) and + t1.memory_space == t2.memory_space) elif isinstance(t1, AbstractRef) and isinstance(t2, AbstractRef): # We want to use the regular typecheck for ShapedArray here. - return (typematch(t1.inner_aval, t2.inner_aval, only_sharding_check) and # type: ignore + return (typematch(t1.inner_aval, t2.inner_aval, only_shape_shd_check) and # type: ignore (t1.memory_space is None or t2.memory_space is None or # type: ignore t1.memory_space == t2.memory_space)) # type: ignore else: return False -def cmp_sharding_vma(t1, t2): +def cmp_shape_sharding_vma(t1, t2): # TODO(yashkatariya): Expand this to Manual and Auto mode. # See https://github.com/jax-ml/jax/issues/26474 if (not t1.sharding.mesh.empty and not t2.sharding.mesh.empty and (t1.sharding.mesh._any_axis_explicit or t2.sharding.mesh._any_axis_explicit)): - sh_eq = t1.sharding == t2.sharding + shd_eq = t1.sharding == t2.sharding else: - sh_eq = True - return sh_eq and t1.vma == t2.vma + shd_eq = True + return (shd_eq and definitely_equal_shape(t1.shape, t2.shape) and + t1.vma == t2.vma) def aval_mismatch_extra(a1: AbstractValue, a2: AbstractValue) -> str: assert not typematch(a1, a2) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 4c7c87c39817..0726ec67c32d 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -985,7 +985,7 @@ def _ref_typecompat(a, a_): def _temporary_dtype_exception(a, a_) -> bool: if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray): return (a.shape == a_.shape and - core.typematch(a, a_, only_sharding_check=True) and + core.typematch(a, a_, only_shape_shd_check=True) and (dtypes.issubdtype(a_.dtype, dtypes.extended) or dtypes.issubdtype(a.dtype, dtypes.np.inexact))) return False diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 5696fc3ac11d..fcc8175cdc7b 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -321,7 +321,7 @@ def write_cotangent(prim, v, ct): return ct_aval = typeof(ct) ct_aval_expected = v.aval.to_cotangent_aval() # type: ignore - if not core.typematch(ct_aval, ct_aval_expected, only_sharding_check=True): + if not core.typematch(ct_aval, ct_aval_expected, only_shape_shd_check=True): raise ValueError( f"Input primal JAX type to {prim.name} is {v.aval.str_short()}. Hence" f" the expected cotangent type is {ct_aval_expected.str_short()} but" @@ -593,7 +593,7 @@ def freeze(self): def ct_check(primal, ct): ct_aval = ct.aval if type(ct) is Zero else typeof(ct) ct_aval_expected = primal.aval.to_cotangent_aval() # type: ignore - if not core.typematch(ct_aval, ct_aval_expected, only_sharding_check=True): + if not core.typematch(ct_aval, ct_aval_expected, only_shape_shd_check=True): # TODO(yashkatariya, mattjj): Add primitive name here for # better error message? raise ValueError( From 6940903c006b9ea57af6e0043b4f86d4c9487045 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 16 Dec 2025 22:18:55 -0800 Subject: [PATCH 235/315] Dispatch to `reshard` in `broadcast_in_dim` if only sharding is changing and everything else is the same as operand. PiperOrigin-RevId: 845590697 --- jax/_src/lax/lax.py | 11 +++++++---- tests/pjit_test.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ca0e2b1f0a8f..71f063135001 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2696,7 +2696,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, operand: an array shape: the shape of the target array broadcast_dimensions: to which dimension in the target shape each dimension - of the operand shape corresponds to. That is, dimension i of the operand + of the operand shape corresponds to. That is, dimension i of the operand becomes dimension broadcast_dimensions[i] of the result. Returns: @@ -2705,15 +2705,18 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. """ - # TODO(dfm): Re-write this as a "reshard" when only the sharding changes. out_sharding = canonicalize_sharding(out_sharding, 'broadcast_in_dim') if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array) and out_sharding is None): return operand + operand_aval = core.typeof(operand) + if (operand_aval.shape == shape and + list(broadcast_dimensions) == list(range(operand_aval.ndim)) and + out_sharding is not None and operand_aval.sharding != out_sharding): + return pjit.reshard(operand, out_sharding) return broadcast_in_dim_p.bind( operand, shape=tuple(shape), - broadcast_dimensions=tuple(broadcast_dimensions), - sharding=out_sharding) + broadcast_dimensions=tuple(broadcast_dimensions), sharding=out_sharding) def broadcast_to_rank(x: ArrayLike, rank: int) -> Array: """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``.""" diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 243772588c91..8c71d5132667 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5399,6 +5399,23 @@ def f(x, y): ValueError, "For primitive.*context mesh.*aval mesh"): f(arr1, arr2) + @jtu.with_explicit_mesh((2,), 'x') + def test_no_op_broadcast_except_for_sharding_change(self, mesh): + arr = jnp.arange(8.).reshape(4, 2) + + @jax.jit + def f(x): + out = jax.lax.broadcast_in_dim(x, (4, 2), [0, 1], out_sharding=P('x')) + self.assertEqual(out.aval.sharding.spec, P('x', None)) + return out + + out = f(arr) + self.assertArraysEqual(out, arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + + out_g = jax.jit(jax.grad(lambda x: f(x).sum()))(arr) + self.assertEqual(out_g.sharding, NamedSharding(mesh, P(None, None))) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_sin_unop(self, mesh): np_inp = np.arange(16.).reshape(8, 2) From 9630a2b0af1fa82bdc76a0493710e909f31d5b3e Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 16 Dec 2025 22:57:57 -0800 Subject: [PATCH 236/315] Use `to_cotangent_aval()` in `SymbolicZero` check in _flatten_bwd PiperOrigin-RevId: 845603376 --- jax/_src/custom_derivatives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 3d7167f4de0e..c7e1ae016d07 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -945,7 +945,7 @@ def append(x, d): if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0: results.append(Zero(a.to_tangent_aval())) elif type(ct) is SymbolicZero: - if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval): + if not core.typecompat(a.to_cotangent_aval(), a_ := ct.aval): msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype " "that does not match the corresponding input tangent shape/dtype: " f"at output{keystr(kp)} the SymbolicZero had shape/dtype " From e63d2a499f2f0e8cc21daa7d09a254ea3174508a Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 17 Dec 2025 00:06:16 -0800 Subject: [PATCH 237/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/8bc1190b0e51ec5f01143d255826143ad7975ee9 PiperOrigin-RevId: 845625223 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 3c027d1816cd..677ff16cfdf0 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "d20fe8e99b411a6b61e91ad3aeeadd6e26f9fc7d" -XLA_SHA256 = "0bab08c8933e282e4437532418fcfc3b1b2648749a8546526542d35127f1d982" +XLA_COMMIT = "8bc1190b0e51ec5f01143d255826143ad7975ee9" +XLA_SHA256 = "2011c87aa50750037e78ebb771d1bc76f93f68b174f4354f64eff8eaa70c413d" From b8d37576ce90659c5e97b1b91c4a2bf1590163d2 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 17 Dec 2025 05:54:46 -0800 Subject: [PATCH 238/315] [Pallas/interpreter] Refactor re-useable functionality out of the main source file for the TPU kernel interpreter. PiperOrigin-RevId: 845731033 --- jax/_src/pallas/mosaic/interpret/BUILD | 22 + .../mosaic/interpret/interpret_pallas_call.py | 417 +++--------------- .../mosaic/interpret/race_detection_state.py | 12 +- .../pallas/mosaic/interpret/shared_memory.py | 17 + .../pallas/mosaic/interpret/thread_map.py | 80 ++++ jax/_src/pallas/mosaic/interpret/utils.py | 336 ++++++++++++++ tests/pallas/tpu_pallas_interpret_test.py | 10 +- .../tpu_pallas_interpret_thread_map_test.py | 7 +- 8 files changed, 527 insertions(+), 374 deletions(-) create mode 100644 jax/_src/pallas/mosaic/interpret/thread_map.py create mode 100644 jax/_src/pallas/mosaic/interpret/utils.py diff --git a/jax/_src/pallas/mosaic/interpret/BUILD b/jax/_src/pallas/mosaic/interpret/BUILD index 97300e4d0a3d..2a86f2258032 100644 --- a/jax/_src/pallas/mosaic/interpret/BUILD +++ b/jax/_src/pallas/mosaic/interpret/BUILD @@ -33,6 +33,8 @@ py_library( deps = [ ":race_detection_state", ":shared_memory", + ":thread_map", + ":utils", ":vector_clock", "//jax", "//jax/_src:api", @@ -79,3 +81,23 @@ pytype_strict_library( "//jax/_src:source_info_util", ], ) + +pytype_strict_library( + name = "thread_map", + srcs = ["thread_map.py"], + deps = [ + "//jax", + "//jax/_src:callback", + ], +) + +pytype_strict_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + "//jax", + "//jax/_src:core", + "//jax/_src:util", + "//jax/_src/pallas", + ] + py_deps("numpy"), +) diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py index 8921da01d971..577df33a9f58 100644 --- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -39,6 +39,8 @@ from jax._src.pallas.mosaic.interpret import shared_memory as memory from jax._src.pallas.mosaic.interpret import vector_clock as vc from jax._src.pallas.mosaic.interpret.race_detection_state import RaceDetectionState +from jax._src.pallas.mosaic.interpret.thread_map import thread_map +import jax._src.pallas.mosaic.interpret.utils as interpret_utils from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives @@ -58,7 +60,7 @@ @dataclasses.dataclass(frozen=True, kw_only=True) -class InterpretParams: +class InterpretParams(interpret_utils.InterpretParams): """Parameters for TPU interpret mode. TPU interpret mode is a way run Pallas TPU kernels on CPU, while simulating @@ -71,35 +73,16 @@ class InterpretParams: :func:`jax.experimental.pallas.pallas_call` or :func:`jax.experimental.pallas.core_map`. + NOTE: If an exception is raised while interpreting a kernel, you must call + :func:`reset_tpu_interpret_mode_state` before using TPU interpret mode + again in the same process. + Attributes: dma_execution_mode: If "eager", DMAs are executed as soon as they are issued. If "on_wait", DMA reads or writes are only executed when a device is waiting on a DMA semaphore that will be signaled when the read or write is complete. Default: "on_wait". - detect_races: If True, a dynamic, happens-before race detector will be used - to detect data races during kernel interpretation. If any races are - detected, a message will be printed and `races.races_found` will be set to - True. - Default: False. - out_of_bounds_reads: If "raise", an exception will be raised on any - out-of-bounds read of a buffer. If "uninitialized_value", any parts of - the read that are out-of-bounds will return the value used to fill - uninitialized memory, which can be configured via the - "uninitialized_memory". NOTE: If an exception is raised while - interpreting a kernel, you must call - :func:`reset_tpu_interpret_mode_state` before using TPU interpret mode - again in the same process. - Default: "raise". - skip_floating_point_ops: If True, operations that produce only floating - point values will not be interpreted; instead, their results will be - replaced with arrays all of `jnp.inf`. Additionally any floating point - operands to any operation will be replaced with (arrays of) `jnp.inf`. - Default: False. - uninitialized_memory: If "nan", allocated buffers are initialized to contain - all NaNs (or to their maximum possible value for integers). If "zero", - allocated buffers are initialized to all zeros. - Default: "nan". random_seed: Seed for random number generator used during interpretation. Currently random numbers are used to randomize the grid coordinates along dimensions with 'parallel' semantics. @@ -112,33 +95,25 @@ class InterpretParams: along grid dimensions with 'parallel' semantics and - the mapping of grid points to local (i.e. per-device) cores. Default: None. - num_cores_per_device: The number of cores per device. - Default: 1. allow_hbm_allocation_in_run_scoped: If `True`, allows the allocation of HBM buffers (which are then shared across the cores in a device) in `run_scoped`. While this behavior can be enabled in the interpreter, allocating HBM buffers with `run_scoped` is not supported when executing Pallas kernels on a real TPU. Default: `False`. - vector_clock_size: The number of entries in the vector clocks. This should - be an integer bigger then the total number of cores, i.e. bigger than - `number of devices * num_cores_per_device`. If `None`, the vector clock - size that is used in the interpreter will default to twice the total - number of cores. - Default: None. """ + dma_execution_mode: Literal["eager", "on_wait"] = "on_wait" - detect_races: bool = False - out_of_bounds_reads: Literal["raise", "uninitialized"] = "raise" - skip_floating_point_ops: bool = False - uninitialized_memory: Literal["nan", "zero"] = "nan" random_seed: int | None = None grid_point_recorder: ( Callable[[tuple[np.int32, ...], np.int32], None] | None ) = None - num_cores_per_device: int = 1 allow_hbm_allocation_in_run_scoped: bool = False - vector_clock_size: int | None = None + + @property + def num_cores_per_device(self) -> int: + return self.num_cores_or_threads_per_device + @contextlib.contextmanager def force_tpu_interpret_mode(params: InterpretParams = InterpretParams()): @@ -167,26 +142,12 @@ def set_tpu_interpret_mode(params: InterpretParams = InterpretParams()): config.pallas_tpu_interpret_mode_context_manager.set_global(params) # type: ignore[arg-type] -class Counter: - """A simple counter that is thread-safe.""" - - def __init__(self, initial_value: int): - self.value = initial_value - self.lock = threading.Lock() - - def get_next(self): - with self.lock: - result = self.value - self.value += 1 - return result - - # TODO(jburnim): Do we want to support multiple instances of SharedMemory? # Maybe for running multiple distinct interpreted computations in parallel? _shared_memory: memory.SharedMemory | None = None _shared_memory_init_lock = threading.Lock() races: RaceDetectionState | None = None -dma_id_counter: Counter | None = None +dma_id_counter: interpret_utils.Counter | None = None def reset_tpu_interpret_mode_state(): """Resets all global, shared state used by TPU interpret mode. @@ -218,23 +179,6 @@ def _clear_shared_memory(): _shared_memory = None -def _get_vector_clock_size( - num_devices, num_cores_per_device, *, interpret_params -) -> int: - """Returns the number of vector clocks to use.`""" - num_cores = num_devices * num_cores_per_device - if interpret_params.vector_clock_size is not None: - if num_cores >= interpret_params.vector_clock_size: - raise ValueError( - f'Vector clock size ({interpret_params.vector_clock_size}) must be ' - f'greater than the total number of cores ({num_cores}).' - ) - return interpret_params.vector_clock_size - else: - # Default the vector clock size to twice the total number of cores. - return 2 * num_cores - - def _initialize_shared_memory( device_id, num_devices, num_cores_per_device, *, interpret_params ): @@ -247,11 +191,9 @@ def _initialize_shared_memory( with _shared_memory_init_lock: if _shared_memory is None: - vector_clock_size = _get_vector_clock_size( - num_devices, num_cores_per_device, interpret_params=interpret_params - ) + vector_clock_size = interpret_params.get_vector_clock_size(num_devices) races = RaceDetectionState(num_cores=num_cores) - dma_id_counter = Counter(100) + dma_id_counter = interpret_utils.Counter(100) _shared_memory = memory.SharedMemory( num_devices=num_devices, num_cores_per_device=num_cores_per_device, @@ -273,31 +215,16 @@ def _initialize_shared_memory( assert _shared_memory.num_cores == num_cores -def _update_clocks(low_global_core_id, high_global_core_id): - """Synchronizes the vector clocks for the cores with ids in the range between the two arguments.""" - shared_memory = _get_shared_memory() - # Despite only updating the vector clocks for some cores, we still need to - # hold the global lock to ensure that no other devices are concurrently - # accessing the same vector clocks. - with shared_memory.lock: - for c in shared_memory.clocks[low_global_core_id + 1 : high_global_core_id]: - vc.update_vector_clock(shared_memory.clocks[low_global_core_id], c) - for c in shared_memory.clocks[low_global_core_id + 1 : high_global_core_id]: - vc.update_vector_clock(c, shared_memory.clocks[low_global_core_id]) - - def _update_clocks_for_device_barrier(device_id): """Synchronizes the vector clocks for the cores on the given device.""" shared_memory = _get_shared_memory() - low_core_id = device_id * shared_memory.num_cores_per_device - high_core_id = (device_id + 1) * shared_memory.num_cores_per_device - _update_clocks(low_core_id, high_core_id) + shared_memory.update_clocks_for_device_barrier(device_id) def _update_clocks_for_global_barrier(): """Synchronizes all vector clocks.""" shared_memory = _get_shared_memory() - _update_clocks(0, shared_memory.num_cores) + shared_memory.update_clocks(0, shared_memory.num_cores) def _barrier(device_id): @@ -322,7 +249,8 @@ def _check_for_revisiting(device_id, local_core_id, loop_idx, output_blocks): except: raise ValueError('Advanced indexers are not supported on TPU') output_ranges = [ - _to_range(b) if b is not None else None for b in output_blocks + interpret_utils.to_range(b) if b is not None else None + for b in output_blocks ] shared_memory = _get_shared_memory() @@ -548,60 +476,6 @@ def get_barrier_semaphore(device_id, collective_id): return np.int16(collective_id) -def _transform_slice_or_index(slice_or_idx): - if isinstance(slice_or_idx, int): - return slice_or_idx - else: - start = int(slice_or_idx.start) - size = int(slice_or_idx.size) - stride = int(slice_or_idx.stride) - return slice(start, start + size * stride, stride) - - -def _compose_slice_or_index(slice_or_idx1, slice_or_idx2): - ret = [] - i = 0 - j = 0 - while True: - if i == len(slice_or_idx1): - ret.extend(slice_or_idx2[j:]) - return tuple(ret) - elif j == len(slice_or_idx2): - ret.extend(slice_or_idx1[i:]) - return tuple(ret) - elif isinstance(slice_or_idx1[i], int): - ret.append(slice_or_idx1[i]) - i += 1 - elif isinstance(slice_or_idx2[j], int): - ret.append( - slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step - ) - i += 1 - j += 1 - else: - ret.append( - slice( - slice_or_idx1[i].start - + slice_or_idx2[j].start * slice_or_idx1[i].step, - slice_or_idx1[i].start - + slice_or_idx2[j].stop * slice_or_idx1[i].step, - slice_or_idx1[i].step * slice_or_idx2[j].step, - ) - ) - i += 1 - j += 1 - - -def _to_range(transforms) -> tuple[slice | int, ...]: - ret = () - for transform in transforms: - # For now, assume only NDIndexer transforms. - ret = _compose_slice_or_index( - ret, tuple(_transform_slice_or_index(i) for i in transform.indices) - ) - return ret - - def _to_int(x: int | Array | None) -> int | None: """Converts a value to an integer, or returns None if the value is None.""" if x is None: @@ -649,7 +523,7 @@ def get( global_core_id = shared_memory.get_global_core_id(device_id, local_core_id) key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) - read_range = _to_range(transforms) + read_range = interpret_utils.to_range(transforms) ret, (shape, dtype), clock_ = shared_memory.get_buffer_content( key, read_range, global_core_id ) @@ -702,7 +576,9 @@ def get( # out_of_bounds_reads == "uninitialized" uninit_array = np.full( full_read_shape, - _uninitialized_value(dtype, shared_memory.uninitialized_memory), + interpret_utils.get_uninitialized_value( + dtype, shared_memory.uninitialized_memory + ), dtype=dtype, ) if ret is None: @@ -771,7 +647,7 @@ def store( global_core_id = shared_memory.get_global_core_id(device_id, local_core_id) key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) - write_range = _to_range(transforms) + write_range = interpret_utils.to_range(transforms) in_bounds, (shape, _), clock_ = shared_memory.store_buffer_content( key, write_range, val, global_core_id ) @@ -842,7 +718,7 @@ def swap( global_core_id = shared_memory.get_global_core_id(device_id, local_core_id) key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) - read_write_range = _to_range(transforms) + read_write_range = interpret_utils.to_range(transforms) ret, (shape, _), clock = shared_memory.swap_buffer_content( key, read_write_range, val, mask, global_core_id ) @@ -1184,74 +1060,6 @@ def _compute_transformed_shape_and_dtype(shape, dtype, transforms): dtype = transform.transform_dtype(dtype) return shape, dtype -# TODO(sharadmv): De-dup this w/ the impl in primitives.py. -def _device_id_dict_to_mesh(device_id_dict, axis_sizes, axis_indices): - physical_axis_dict = {} - axis_names = axis_sizes.keys() - for axis, idx in device_id_dict.items(): - if isinstance(axis, tuple) and any(a in axis_names for a in axis): - if not all(a in axis_names for a in axis): - raise NotImplementedError( - f"{axis} mixes JAX mesh and Pallas mesh grid axes" - ) - axes_dimensions = [axis_sizes[name] for name in axis] - for axis_index, axis_name in enumerate(axis): - axis_size = axis_sizes[axis_name] - inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :]) - minor_divisor = inner_mesh_size - - # Fast path for power of 2s - if inner_mesh_size & (inner_mesh_size - 1) == 0: - shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1 - partial_device_idx = idx >> shift_len - else: - partial_device_idx = idx // minor_divisor - - if axis_size & (axis_size - 1) == 0: - device_idx = partial_device_idx & (axis_size - 1) - else: - device_idx = partial_device_idx % axis_size - physical_axis_dict[axis_name] = device_idx - else: - physical_axis_dict[axis] = idx - device_id = [] - for axis in axis_names: - if axis in physical_axis_dict: - device_id.append(physical_axis_dict[axis]) - else: - device_id.append(axis_indices[axis]) - non_mesh_axes = { - k: v - for k, v in physical_axis_dict.items() - if k not in axis_names - } - return tuple(device_id), non_mesh_axes - -def _device_coords_to_logical_id(device_coords, axis_sizes, axis_indices): - if isinstance(device_coords, dict): - device_coords, non_mesh_axes = _device_id_dict_to_mesh( - device_coords, axis_sizes, axis_indices) - if non_mesh_axes: - raise NotImplementedError(non_mesh_axes) - if not isinstance(device_coords, tuple): - device_coords = (device_coords,) - assert len(device_coords) == len(axis_sizes) - sizes = list(axis_sizes.values()) - ret = 0 - for i in range(len(device_coords)): - ret += device_coords[i] * math.prod(sizes[i+1:]) - return ret - -def _device_id_to_logical(device_id, device_id_type, axis_sizes, - axis_indices): - if device_id is None: - return None - if device_id_type == primitives.DeviceIdType.MESH: - return _device_coords_to_logical_id(device_id, axis_sizes, axis_indices) - elif device_id_type == primitives.DeviceIdType.LOGICAL: - return device_id - else: - raise ValueError(f'Unsupported device ID type: {device_id_type}') @lu.cache def _to_jaxpr(flat_fun, in_avals): @@ -1263,17 +1071,9 @@ def _is_any(memory_space): return ((memory_space == mosaic_core.MemorySpace.ANY) or (memory_space == pallas_core.MemorySpace.ANY)) -def _is_float(dtype): - return jnp.issubdtype(dtype, jnp.floating) _SENTINEL = jnp.inf -@dataclasses.dataclass(frozen=True) -class Placeholder: - """Placeholder for use in `_interpret_jaxpr` below instead of putting a concrete value into `env`.""" - shape: tuple[int, ...] - dtype: jnp.dtype - def _get_memory_space_and_raise_if_hbm(aval, primitive_name, message=None): memory_space = aval.memory_space @@ -1299,23 +1099,14 @@ def _interpret_jaxpr( compiler_params, interpret_params ): - env = {} - - def read(var): - if isinstance(var, jax_core.Literal): - result = var.val - else: - result = env[var] - if isinstance(result, Placeholder): - result = jax.lax.full(result.shape, _SENTINEL, result.dtype) - return result - - def write(var, value): - if interpret_params.skip_floating_point_ops and _is_float(value.dtype): - value = Placeholder(value.shape, value.dtype) - env[var] = value - - jax._src.util.safe_map(write, jaxpr.constvars + jaxpr.invars, args) + sentinel_for_floating_point_values = ( + _SENTINEL if interpret_params.skip_floating_point_ops else None + ) + env = interpret_utils.JaxprEnv( + vars=jaxpr.constvars + jaxpr.invars, + values=args, + sentinel_for_floating_point_values=sentinel_for_floating_point_values, + ) # TODO(jburnim): Clean up and finish this evaluation loop. For example: # - Replace the big if-statement with a dictionary of rules. @@ -1339,9 +1130,7 @@ def write(var, value): # not need to do any reads if `interpret_params.skip_floating_point_ops` # is True. If this is the case, we want to avoid materializing the read # array into the jaxpr when this function is traced. - deferred_invals = functools.partial( - jax._src.util.safe_map, read, eqn.invars - ) + deferred_invals = functools.partial(env.read_many, eqn.invars) if prim is primitives.load_p: (ref, transforms, mask, _) = jax.tree.unflatten( @@ -1487,8 +1276,8 @@ def f(*args, jaxpr): device_id, local_core_id, TPU_MEMORY_SPACE_IDXS[memory_space], - _uninitialized_array( - v.aval.shape, v.aval.dtype, interpret_params + interpret_params.get_uninitialized_array( + v.aval.shape, v.aval.dtype ), ordered=True, ) @@ -1563,7 +1352,7 @@ def f(*args, jaxpr): src_sem_transforms, target_device_id, ) = jax.tree.unflatten(eqn.params['tree'], deferred_invals()) - target_device_id = _device_id_to_logical( + target_device_id = interpret_utils._device_id_to_logical( target_device_id, eqn.params['device_id_type'], axis_sizes, axis_indices) (orig_src_ref, _, orig_dst_ref, *_ @@ -1629,7 +1418,7 @@ def f(*args, jaxpr): elif prim is primitives.semaphore_signal_p: sem, sem_transforms, inc, target_device_id, core_index = ( jax.tree.unflatten(eqn.params['args_tree'], deferred_invals())) - target_device_id = _device_id_to_logical( + target_device_id = interpret_utils._device_id_to_logical( target_device_id, eqn.params['device_id_type'], axis_sizes, axis_indices) callback.io_callback( @@ -1669,7 +1458,7 @@ def f(*args, jaxpr): else: if interpret_params.skip_floating_point_ops and all( - _is_float(ovar.aval.dtype) for ovar in eqn.outvars + interpret_utils.is_float(ovar.aval.dtype) for ovar in eqn.outvars ): # Skip `prim.bind` since `prim` only produces floating-point values. # It is safe to populate `out` with avals since mapping `write` over @@ -1683,9 +1472,9 @@ def f(*args, jaxpr): out = prim.bind(*subfuns, *deferred_invals(), **bind_params) out = out if prim.multiple_results else [out] - jax._src.util.safe_map(write, eqn.outvars, out) + env.write_many(eqn.outvars, out) - return jax._src.util.safe_map(read, jaxpr.outvars) + return env.read_many(jaxpr.outvars) def _compute_start_indices( block_mapping, loop_idx, *args, @@ -1890,106 +1679,10 @@ def _get_grid_point( grid_point.append(li if jnp.size(coords) == 0 else coords[li]) return jnp.array(grid_point, dtype=np.int32) -def _uninitialized_value(dtype, uninitialized_memory: Literal['nan', 'zero']): - if uninitialized_memory == 'nan': - if jnp.issubdtype(dtype, jnp.floating): - return np.nan - elif jnp.issubdtype(dtype, jnp.integer): - return jnp.iinfo(dtype).max - elif jnp.issubdtype(dtype, jnp.bool): - return True - if uninitialized_memory == 'zero': - return 0 - raise NotImplementedError( - uninitialized_memory + ' + ' + str(dtype)) - -def _uninitialized_array(shape, dtype, interpret_params): - return jnp.full( - shape, - _uninitialized_value(dtype, interpret_params.uninitialized_memory), - dtype, - ) - -def _pad_to_block_dimension(value, block_shape, interpret_params): - """Pads values so the shape evenly divides into block dimensions. - - For example, if values has a shape of (33, 2, 5) with a block_shape of - (32, 2, 4), this function will pad the value of shape to (64, 2, 8). - - Args: - value: Array to be padded. - block_shape: Block shapes to use for padding. If None, no padding will - be performed. - - Returns: - A padded array. - """ - padded_shape = tuple( - ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape) - ) - if padded_shape != value.shape: - pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) - pad_value = _uninitialized_array((), value.dtype, interpret_params) - value = jnp.pad(value, pad_width, constant_values=pad_value) - return value def get_interpret_effects(): return {callback._OrderedIOEffect} -def _thread_map(f, num_threads): - if num_threads == 1: - f(jnp.int32(0)) - return - - def _f(core_index): - f(core_index) - return () - jaxpr = jax.make_jaxpr(_f)(jnp.int32(0)) - - _call_threadmap_callback(jaxpr.jaxpr, num_threads, *jaxpr.consts) - -def _run_jaxpr(jaxpr, consts, *args): - def _run(jaxpr, consts, *args): - jax_core.eval_jaxpr(jaxpr, consts, *args) - traced = jax.jit(_run, static_argnums=(0,)).trace(jaxpr, consts, *args) - traced.lower().compile()(consts, *args) - return - -import concurrent.futures - -def _thread_map_callback(jaxpr, num_threads, consts): - num_threads = int(num_threads) - threads = [] - with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: - for i in range(num_threads): - threads.append( - executor.submit(_run_jaxpr, jaxpr, consts, jnp.int32(i))) - exceptions = [] - for i in range(num_threads): - try: - threads[i].result() - except Exception as e: - exceptions.append(e) - if exceptions: - # TODO(jburnim): Use ExceptionGroup once JAX requires Python 3.11. - # raise ExceptionGroup('Exceptions raised during _thread_map', exceptions) - raise exceptions[0] - -def _call_threadmap_callback(jaxpr, num_threads, *consts): - # NOTE: At runtime, _thread_map_callback will lower and compile the - # given jaxpr. (JAX's caches should ensure the jaxpr is only lowered and - # compiled once.) - # - # TODO(jburnim): Would it be worth trying to lower/compile the jaxpr at - # lowering/compilation time? E.g., by using a custom primitive here, could - # we lower/compile jaxpr at lowering time, and then pass the compiled - # function to the callback? - return callback.io_callback( - functools.partial(_thread_map_callback, jaxpr), - (), - num_threads, - consts, - ordered=True) def interpret_pallas_call( *args, @@ -2014,7 +1707,8 @@ def interpret_pallas_call( # that users don't have to specify it in the InterpretParams. assert len(mesh.shape) == 1 interpret_params = dataclasses.replace( - interpret_params, num_cores_per_device=mesh.devices.shape[0]) + interpret_params, num_cores_or_threads_per_device=mesh.devices.shape[0] + ) args = [remove_memory_space_p.bind(a) for a in args] # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?) @@ -2034,8 +1728,9 @@ def interpret_pallas_call( num_devices = functools.reduce( jnp.multiply, axis_sizes.values(), jnp.int32(1)) axis_indices = {k: lax.axis_index(k) for k in axis_sizes.keys()} - device_id = _device_coords_to_logical_id( - tuple(axis_indices.values()), axis_sizes, axis_indices) + device_id = interpret_utils.device_coords_to_logical_id( + tuple(axis_indices.values()), axis_sizes, axis_indices + ) callback.io_callback( functools.partial( _initialize_shared_memory, interpret_params=interpret_params @@ -2058,7 +1753,7 @@ def interpret_pallas_call( ] num_inputs = grid_mapping.num_inputs input_args = [ - _pad_to_block_dimension(a, bs, interpret_params) + interpret_params.pad_to_block_dimension(a, bs) for a, bs in zip(input_args, block_shapes[:num_inputs]) ] @@ -2099,11 +1794,11 @@ def interpret_pallas_call( output_buffer_shapes.append(input_args[oi_alias_map[i]].shape) output_vals.append(input_args[oi_alias_map[i]]) else: - out_val = _uninitialized_array(bm.array_aval.shape, - bm.array_aval.dtype, - interpret_params) - padded_val = _pad_to_block_dimension( - out_val, output_block_shapes[i], interpret_params + out_val = interpret_params.get_uninitialized_array( + bm.array_aval.shape, bm.array_aval.dtype + ) + padded_val = interpret_params.pad_to_block_dimension( + out_val, output_block_shapes[i] ) output_buffer_ids.append( callback.io_callback( @@ -2172,8 +1867,8 @@ def interpret_pallas_call( device_id, None, # local_core_id, TPU_MEMORY_SPACE_IDXS[var.aval.memory_space], - _uninitialized_array( - var.aval.shape, var.aval.dtype, interpret_params + interpret_params.get_uninitialized_array( + var.aval.shape, var.aval.dtype ), ordered=True, ) @@ -2522,7 +2217,7 @@ def _store_to_output_buffer(index, output_var, transform): _update_clocks_for_device_barrier, (), device_id, ordered=True ) - _thread_map(_execute_grid_for_core, interpret_params.num_cores_per_device) + thread_map(_execute_grid_for_core, interpret_params.num_cores_per_device) # TODO(jburnim): Should we only create happens-before here from the other # # cores to core 0? diff --git a/jax/_src/pallas/mosaic/interpret/race_detection_state.py b/jax/_src/pallas/mosaic/interpret/race_detection_state.py index 64f6568c3754..ff76778119d3 100644 --- a/jax/_src/pallas/mosaic/interpret/race_detection_state.py +++ b/jax/_src/pallas/mosaic/interpret/race_detection_state.py @@ -116,9 +116,9 @@ def check_read( # between real device IDs vs. DMA IDs. print( f'RACE DETECTED\n read of {buffer_key}[{rnge}] from {device_id},' - f' {local_core_id}, {user_frame}\n write of' + f' {local_core_id}, {user_frame}\n clock: {clock}\n write of' f' {buffer_key}[{write_range}] from {write_device_id},' - f' {write_local_core_id} {write_frame}' + f' {write_local_core_id} {write_frame}\n clock: {write_clock}\n' ) with self.lock: self.races_found = True @@ -158,9 +158,9 @@ def check_write( # between real device IDs vs. DMA IDs. print( f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},' - f' {local_core_id}, {user_frame}\n write of' + f' {local_core_id}, {user_frame}\n clock: {clock}\n write of' f' {buffer_key}[{write_range}] from {write_device_id},' - f' {write_local_core_id}, {write_frame}' + f' {write_local_core_id}, {write_frame}\n clock: {write_clock}\n' ) with self.lock: self.races_found = True @@ -178,9 +178,9 @@ def check_write( # between real device IDs vs. DMA IDs. print( f'RACE DETECTED\n write of {buffer_key}[{rnge}] from {device_id},' - f' {local_core_id}, {user_frame}\n read of' + f' {local_core_id}, {user_frame}\n clock: {clock}\n read of' f' {buffer_key}[{read_range}] from {read_device_id},' - f' {read_local_core_id}, {read_frame}' + f' {read_local_core_id}, {read_frame}\n clock: {read_clock}\n' ) with self.lock: self.races_found = True diff --git a/jax/_src/pallas/mosaic/interpret/shared_memory.py b/jax/_src/pallas/mosaic/interpret/shared_memory.py index 74a078bd9fa2..21fd6600928a 100644 --- a/jax/_src/pallas/mosaic/interpret/shared_memory.py +++ b/jax/_src/pallas/mosaic/interpret/shared_memory.py @@ -572,3 +572,20 @@ def swap_buffer_content( mask[in_bounds_idx], value[in_bounds_idx], raw_result ) return result.copy(), shape_and_dtype, clock + + def update_clocks(self, low_global_core_id, high_global_core_id): + """Synchronizes the vector clocks for the cores with ids in the range between the two arguments.""" + # Despite only updating the vector clocks for some cores, we still need to + # hold the global lock to ensure that no other devices are concurrently + # accessing the same vector clocks. + with self.lock: + for c in self.clocks[low_global_core_id + 1 : high_global_core_id]: + vc.update_vector_clock(self.clocks[low_global_core_id], c) + for c in self.clocks[low_global_core_id + 1 : high_global_core_id]: + vc.update_vector_clock(c, self.clocks[low_global_core_id]) + + def update_clocks_for_device_barrier(self, device_id): + """Synchronizes the vector clocks for the cores on the given device.""" + low_core_id = device_id * self.num_cores_per_device + high_core_id = (device_id + 1) * self.num_cores_per_device + self.update_clocks(low_core_id, high_core_id) diff --git a/jax/_src/pallas/mosaic/interpret/thread_map.py b/jax/_src/pallas/mosaic/interpret/thread_map.py new file mode 100644 index 000000000000..3b162a70daeb --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/thread_map.py @@ -0,0 +1,80 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent import futures +import functools + +import jax +from jax._src import callback +import jax.core as jax_core +import jax.numpy as jnp + + +def _run_jaxpr(jaxpr, consts, *args): + def _run(jaxpr, consts, *args): + jax_core.eval_jaxpr(jaxpr, consts, *args) + + traced = jax.jit(_run, static_argnums=(0,)).trace(jaxpr, consts, *args) + traced.lower().compile()(consts, *args) + return + + +def _thread_map_callback(jaxpr, num_threads, consts): + num_threads = int(num_threads) + threads = [] + with futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + for i in range(num_threads): + threads.append(executor.submit(_run_jaxpr, jaxpr, consts, jnp.int32(i))) + exceptions = [] + for i in range(num_threads): + try: + threads[i].result() + except Exception as e: + exceptions.append(e) + if exceptions: + # TODO(jburnim): Use ExceptionGroup once JAX requires Python 3.11. + # raise ExceptionGroup('Exceptions raised during _thread_map', exceptions) + raise exceptions[0] + + +def _call_threadmap_callback(jaxpr, num_threads, *consts): + # NOTE: At runtime, _thread_map_callback will lower and compile the + # given jaxpr. (JAX's caches should ensure the jaxpr is only lowered and + # compiled once.) + # + # TODO(jburnim): Would it be worth trying to lower/compile the jaxpr at + # lowering/compilation time? E.g., by using a custom primitive here, could + # we lower/compile jaxpr at lowering time, and then pass the compiled + # function to the callback? + return callback.io_callback( + functools.partial(_thread_map_callback, jaxpr), + (), + num_threads, + consts, + ordered=True, + ) + + +def thread_map(f, num_threads): + if num_threads == 1: + f(jnp.int32(0)) + return + + def _f(core_index): + f(core_index) + return () + + jaxpr = jax.make_jaxpr(_f)(jnp.int32(0)) + + _call_threadmap_callback(jaxpr.jaxpr, num_threads, *jaxpr.consts) diff --git a/jax/_src/pallas/mosaic/interpret/utils.py b/jax/_src/pallas/mosaic/interpret/utils.py new file mode 100644 index 000000000000..460e5031e065 --- /dev/null +++ b/jax/_src/pallas/mosaic/interpret/utils.py @@ -0,0 +1,336 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +import dataclasses +import math +import threading +from typing import Any, Literal + +from jax import lax +from jax._src import core as jax_core +from jax._src.pallas import primitives +from jax._src.util import safe_map +import jax.numpy as jnp +import numpy as np + + +def get_uninitialized_value( + dtype, uninitialized_memory: Literal["nan", "zero"] +): + if uninitialized_memory == "nan": + if jnp.issubdtype(dtype, jnp.floating): + return np.nan + elif jnp.issubdtype(dtype, jnp.integer): + return jnp.iinfo(dtype).max + elif jnp.issubdtype(dtype, jnp.bool): + return True + if uninitialized_memory == "zero": + return 0 + raise NotImplementedError(uninitialized_memory + " + " + str(dtype)) + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class InterpretParams: + """Parameters for kernel interpret mode. + + Interpret mode is a way to run Pallas kernels on CPU, while simulating TPU/GPU + shared memory, communication, and synchronization operations. + + Attributes: + detect_races: If True, a dynamic, happens-before race detector will be used + to detect data races during kernel interpretation. If any races are + detected, a message will be printed and `races.races_found` will be set to + True. + Default: False. + out_of_bounds_reads: If "raise", an exception will be raised on any + out-of-bounds read of a buffer. If "uninitialized_value", any parts of + the read that are out-of-bounds will return the value used to fill + uninitialized memory, which can be configured via the + "uninitialized_memory". + Default: "raise". + skip_floating_point_ops: If True, operations that produce only floating + point values will not be interpreted; instead, their results will be + replaced with arrays all of `jnp.inf`. Additionally any floating point + operands to any operation will be replaced with (arrays of) `jnp.inf`. + Default: False. + uninitialized_memory: If "nan", allocated buffers are initialized to contain + all NaNs (or to their maximum possible value for integers). If "zero", + allocated buffers are initialized to all zeros. + Default: "nan". + num_cores_or_threads_per_device: The number of cores (TPU) or threads (GPU) + per device. + Default: 1. + vector_clock_size: The number of entries in the vector clocks. This should + be an integer bigger then the total number of cores, i.e. bigger than + `number of devices * num_cores_per_device`. If `None`, the vector clock + size that is used in the interpreter will default to twice the total + number of cores. + Default: None. + """ + + detect_races: bool = False + out_of_bounds_reads: Literal["raise", "uninitialized"] = "raise" + skip_floating_point_ops: bool = False + uninitialized_memory: Literal["nan", "zero"] = "nan" + num_cores_or_threads_per_device: int = 1 + vector_clock_size: int | None = None + + def get_vector_clock_size(self, num_devices) -> int: + """Returns the number of vector clocks to use.`""" + num_cores_or_threads = num_devices * self.num_cores_or_threads_per_device + if self.vector_clock_size is not None: + if num_cores_or_threads >= self.vector_clock_size: + raise ValueError( + f"Vector clock size ({self.vector_clock_size}) must be greater than" + f" the total number of cores/threads ({num_cores_or_threads})." + ) + return self.vector_clock_size + else: + # Default to twice the total number of cores/threads. + return 2 * num_cores_or_threads + + def get_uninitialized_array(self, shape, dtype): + return jnp.full( + shape, + get_uninitialized_value(dtype, self.uninitialized_memory), + dtype, + ) + + def pad_to_block_dimension(self, value, block_shape): + """Pads values so the shape evenly divides into block dimensions. + + For example, if values has a shape of (33, 2, 5) with a block_shape of + (32, 2, 4), this function will pad the value of shape to (64, 2, 8). + + Args: + value: Array to be padded. + block_shape: Block shapes to use for padding. If None, no padding will be + performed. + + Returns: + A padded array. + """ + padded_shape = tuple( + ((v - 1) // b + 1) * b for v, b in zip(value.shape, block_shape) + ) + if padded_shape != value.shape: + pad_width = tuple((0, a - b) for a, b in zip(padded_shape, value.shape)) + pad_value = self.get_uninitialized_array((), value.dtype) + value = jnp.pad(value, pad_width, constant_values=pad_value) + return value + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class InterpretGPUParams(InterpretParams): + ... + + +class Counter: + """A simple counter that is thread-safe.""" + + def __init__(self, initial_value: int): + self.value = initial_value + self.lock = threading.Lock() + + def get_next(self): + with self.lock: + result = self.value + self.value += 1 + return result + + +# TODO(sharadmv): De-dup this w/ the impl in primitives.py. +def _device_id_dict_to_mesh(device_id_dict, axis_sizes, axis_indices): + physical_axis_dict = {} + axis_names = axis_sizes.keys() + for axis, idx in device_id_dict.items(): + if isinstance(axis, tuple) and any(a in axis_names for a in axis): + if not all(a in axis_names for a in axis): + raise NotImplementedError( + f"{axis} mixes JAX mesh and Pallas mesh grid axes" + ) + axes_dimensions = [axis_sizes[name] for name in axis] + for axis_index, axis_name in enumerate(axis): + axis_size = axis_sizes[axis_name] + inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :]) + minor_divisor = inner_mesh_size + + # Fast path for power of 2s + if inner_mesh_size & (inner_mesh_size - 1) == 0: + shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1 + partial_device_idx = idx >> shift_len + else: + partial_device_idx = idx // minor_divisor + + if axis_size & (axis_size - 1) == 0: + device_idx = partial_device_idx & (axis_size - 1) + else: + device_idx = partial_device_idx % axis_size + physical_axis_dict[axis_name] = device_idx + else: + physical_axis_dict[axis] = idx + device_id = [] + for axis in axis_names: + if axis in physical_axis_dict: + device_id.append(physical_axis_dict[axis]) + else: + device_id.append(axis_indices[axis]) + non_mesh_axes = { + k: v for k, v in physical_axis_dict.items() if k not in axis_names + } + return tuple(device_id), non_mesh_axes + + +def device_coords_to_logical_id(device_coords, axis_sizes, axis_indices): + if isinstance(device_coords, dict): + device_coords, non_mesh_axes = _device_id_dict_to_mesh( + device_coords, axis_sizes, axis_indices + ) + if non_mesh_axes: + raise NotImplementedError(non_mesh_axes) + if not isinstance(device_coords, tuple): + device_coords = (device_coords,) + assert len(device_coords) == len(axis_sizes) + sizes = list(axis_sizes.values()) + ret = 0 + for i in range(len(device_coords)): + ret += device_coords[i] * math.prod(sizes[i + 1 :]) + return ret + + +def _device_id_to_logical(device_id, device_id_type, axis_sizes, axis_indices): + if device_id is None: + return None + if device_id_type == primitives.DeviceIdType.MESH: + return device_coords_to_logical_id(device_id, axis_sizes, axis_indices) + elif device_id_type == primitives.DeviceIdType.LOGICAL: + return device_id + else: + raise ValueError(f"Unsupported device ID type: {device_id_type}") + + +def is_int(dtype): + return jnp.issubdtype(dtype, jnp.integer) + + +def is_float(dtype): + return jnp.issubdtype(dtype, jnp.floating) + + +@dataclasses.dataclass(frozen=True) +class Placeholder: + """Placeholder for use in `JaxprEnv` below instead of storing a concrete value.""" + + shape: tuple[int, ...] + dtype: jnp.dtype + + +class JaxprEnv: + """An environment for interpreting jaxprs, mapping variables to values.""" + + def __init__( + self, + *, + vars: Sequence[jax_core.Var] | None = None, + values: Sequence[Any] | None = None, + sentinel_for_floating_point_values: Any = None, + ): + self._sentinel_for_floating_point_values = ( + sentinel_for_floating_point_values + ) + self._env: dict[jax_core.Var, Any] = {} + + if vars is None and values is None: + return + + vars = vars or [] + values = values or [] + self.write_many(vars, values) + + def read(self, var): + if isinstance(var, jax_core.Literal): + result = var.val + else: + result = self._env[var] + if isinstance(result, Placeholder): + result = lax.full( + result.shape, self._sentinel_for_floating_point_values, result.dtype + ) + return result + + def read_many(self, vars): + return safe_map(self.read, vars) + + def write(self, var, value): + if self._sentinel_for_floating_point_values and is_float(value.dtype): + value = Placeholder(value.shape, value.dtype) + self._env[var] = value + + def write_many(self, vars, values): + safe_map(self.write, vars, values) + + +def _transform_slice_or_index(slice_or_idx): + if isinstance(slice_or_idx, int): + return slice_or_idx + else: + start = int(slice_or_idx.start) + size = int(slice_or_idx.size) + stride = int(slice_or_idx.stride) + return slice(start, start + size * stride, stride) + + +def _compose_slice_or_index(slice_or_idx1, slice_or_idx2): + ret = [] + i = 0 + j = 0 + while True: + if i == len(slice_or_idx1): + ret.extend(slice_or_idx2[j:]) + return tuple(ret) + elif j == len(slice_or_idx2): + ret.extend(slice_or_idx1[i:]) + return tuple(ret) + elif isinstance(slice_or_idx1[i], int): + ret.append(slice_or_idx1[i]) + i += 1 + elif isinstance(slice_or_idx2[j], int): + ret.append( + slice_or_idx1[i].start + slice_or_idx2[j] * slice_or_idx1[i].step + ) + i += 1 + j += 1 + else: + ret.append( + slice( + slice_or_idx1[i].start + + slice_or_idx2[j].start * slice_or_idx1[i].step, + slice_or_idx1[i].start + + slice_or_idx2[j].stop * slice_or_idx1[i].step, + slice_or_idx1[i].step * slice_or_idx2[j].step, + ) + ) + i += 1 + j += 1 + + +def to_range(transforms) -> tuple[slice | int, ...]: + ret = () + for transform in transforms: + # For now, assume only NDIndexer transforms. + ret = _compose_slice_or_index( + ret, tuple(_transform_slice_or_index(i) for i in transform.indices) + ) + return ret diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index 01e19e8c9b75..b6826677d874 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -862,7 +862,7 @@ def f(x): dimension_semantics=('parallel',), ), interpret=pltpu.InterpretParams( - num_cores_per_device=2, + num_cores_or_threads_per_device=2, detect_races=False, ), )(x) @@ -872,7 +872,7 @@ def f(x): np.testing.assert_allclose(y, 2.0 * x) with pltpu.force_tpu_interpret_mode(pltpu.InterpretParams( - num_cores_per_device=1, + num_cores_or_threads_per_device=1, detect_races=True, )): y = f(x).block_until_ready() @@ -881,7 +881,7 @@ def f(x): self.assertEqual(trace_count[0], 2) with pltpu.force_tpu_interpret_mode(pltpu.InterpretParams( - num_cores_per_device=2, + num_cores_or_threads_per_device=2, detect_races=True, )): y = f(x).block_until_ready() @@ -913,7 +913,7 @@ def kernel(x_ref, o_ref, vmem_ref): pltpu.VMEM((8, 128), x.dtype), ], interpret=pltpu.InterpretParams( - num_cores_per_device=2, + num_cores_or_threads_per_device=2, detect_races=True, ), compiler_params=pltpu.CompilerParams( @@ -948,7 +948,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): out_specs=pl.BlockSpec((8, 128), lambda i, j: (i, j)), interpret=pltpu.InterpretParams( random_seed=12345, - num_cores_per_device=num_cores_per_device, + num_cores_or_threads_per_device=num_cores_per_device, grid_point_recorder=grid_point_recorder, detect_races=True, ), diff --git a/tests/pallas/tpu_pallas_interpret_thread_map_test.py b/tests/pallas/tpu_pallas_interpret_thread_map_test.py index 6ca34248c1dc..c0a6b05f3574 100644 --- a/tests/pallas/tpu_pallas_interpret_thread_map_test.py +++ b/tests/pallas/tpu_pallas_interpret_thread_map_test.py @@ -19,7 +19,7 @@ from absl.testing import absltest import jax from jax._src import test_util as jtu -from jax._src.pallas.mosaic.interpret import interpret_pallas_call as mosaic_interpret +from jax._src.pallas.mosaic.interpret.thread_map import thread_map jax.config.parse_flags_with_absl() @@ -61,8 +61,11 @@ def f(core_index): del core_index jax.experimental.io_callback(_barrier, (), ordered=True) - mosaic_interpret._thread_map(f, 8) + thread_map(f, 8) self.assertEqual(max_concurrent_calls[0], 8) + # `thread_map` returns only after all threads have completed, so the final + # value of `concurrent_calls` should be zero. + self.assertEqual(concurrent_calls[0], 0) if __name__ == '__main__': From b44a511b8f1da685531341dfdb40954de8bec780 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 17 Dec 2025 06:05:31 -0800 Subject: [PATCH 239/315] Reverts 58fbb8a3b7d4ecb860c4cf678cae1dd6f1079846 PiperOrigin-RevId: 845733949 --- tests/export_back_compat_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index e41d29a53cf8..41a4b99ed944 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -813,7 +813,7 @@ def func(x): # b: f32[2, 4] # the expected custom call targets for old test data that was serialized # with custom calls. for data, custom_call_targets_override in data: - with jax.set_mesh(mesh): + with mesh: if jax.config.jax_use_shardy_partitioner: self.run_one_test( func, self.load_testdata(data["shardy"]), @@ -1040,7 +1040,7 @@ def shard_map_func(x): # b: f32[2, 4] # the expected custom call targets for old test data that was serialized # with custom calls. for data, custom_call_targets_override in data: - with jax.set_mesh(Mesh(devices, axis_names=('x'))): + with Mesh(devices, axis_names=('x')): self.run_one_test( func, self.load_testdata(data), expect_current_custom_calls=custom_call_targets_override) From 7f7b35e59debc5a784e0f7058e6d93fbc557fb7f Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Wed, 17 Dec 2025 06:49:27 -0800 Subject: [PATCH 240/315] Skip test_scalar_debug_check in tpu_sparsecore_pallas_debug_check_test. Causes pytest testsuite to hang in its entirety, and fails on Bazel, while running on v6e-8. PiperOrigin-RevId: 845747262 --- tests/pallas/tpu_sparsecore_pallas_debug_check_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py b/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py index fb95c997c636..55cb1ce37c01 100644 --- a/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_debug_check_test.py @@ -68,9 +68,9 @@ def setUp(self): super().setUp() def test_scalar_debug_check(self): - if not jtu.is_device_tpu_at_least(6): - # TODO: b/436509694 - Figure out why the test gets stuck on v5p. - self.skipTest("") + if not jtu.is_device_tpu_at_least(7): + # TODO: b/469486032 - Figure out why the test gets stuck on v5p, v6e. + self.skipTest("Fails on v5p and v6e.") x = jnp.arange(8) From 2040612567367bff6fb6e509a87781337efae021 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 17 Dec 2025 06:57:32 -0800 Subject: [PATCH 241/315] Propagate effects in the abstract eval rule for custom_vmap_p. PiperOrigin-RevId: 845749688 --- jax/_src/custom_batching.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 8c66d3da942c..e394e8273260 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -260,7 +260,8 @@ def custom_vmap_batching(args_flat, dims, *, call, rule, in_tree, out_tree): def custom_vmap_abstract_eval(*in_avals, call, **_): - return call.out_avals + del in_avals + return call.out_avals, call.effects def custom_vmap_jvp(primals, tangents, *, @@ -347,7 +348,7 @@ def to_vmap_over_extra_batched_dims(primals, tangents): custom_vmap_p = core.Primitive('custom_vmap_call') custom_vmap_p.multiple_results = True custom_vmap_p.def_impl(custom_vmap_impl) -custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval) +custom_vmap_p.def_effectful_abstract_eval(custom_vmap_abstract_eval) batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp pxla.register_initial_style_primitive(custom_vmap_p) From 39e6ee1e47da972b272427a1712a77514ad2e753 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 17 Dec 2025 06:58:21 -0800 Subject: [PATCH 242/315] Export pyproject.toml in the BUILD file PiperOrigin-RevId: 845749924 --- jax/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/BUILD b/jax/BUILD index c5ab5c1189cd..0aaed7b61d2c 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -112,6 +112,7 @@ exports_files([ "LICENSE", "version.py", "py.typed", + "oss/pyproject.toml", ]) # Packages that have access to JAX-internal implementation details. From 86116f8f2d47d5e7ad6ce21cf9f9ce831f7919ef Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Wed, 17 Dec 2025 07:02:58 -0800 Subject: [PATCH 243/315] Causing segfaults in python_callback_test Reverts a2e14e41a5554e875434f7b7f0562317aee11804 PiperOrigin-RevId: 845751395 --- jaxlib/mosaic/gpu/BUILD | 4 - jaxlib/mosaic/gpu/custom_call.cc | 283 +++++++------------------- jaxlib/mosaic/gpu/custom_call_test.cc | 155 +------------- jaxlib/mosaic/gpu/nvshmem.h | 7 - 4 files changed, 83 insertions(+), 366 deletions(-) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index b0e93a39b3a4..b7359495f184 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -235,7 +235,6 @@ cc_library( "-Wl,--export-dynamic-symbol='nvshmemx_mc_ptr'", "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'", "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'", - "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_finalize'", "-Wl,--export-dynamic-symbol='nvshmemx_init_status'", ], deps = [ @@ -352,9 +351,6 @@ cc_test( deps = [ ":mosaic_gpu_support", "//testing/base/public:gunit_main", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/log:globals", - "@com_google_absl//absl/log:scoped_mock_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index f1d7e1199ee6..29931d56e139 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -50,10 +50,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "third_party/gpus/cuda/include/driver_types.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Debug.h" @@ -132,8 +130,6 @@ limitations under the License. namespace { -using ::mosaic::gpu::NvshmemApi; - namespace ffi = xla::ffi; namespace se = stream_executor; @@ -511,25 +507,19 @@ absl::StatusOr, bool>> Compile( class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - CUmodule module, MosaicHostFunc* host_launch, - bool is_comm_used) + MosaicHostFunc* host_launch, bool is_comm_used) : engine_(std::move(engine)), ctx_(ctx), - module_(module), host_launch_(host_launch), is_comm_used_(is_comm_used) {} - std::tuple GetHostLaunch() const { + std::tuple GetHostLaunch() { return std::make_tuple(ctx_, host_launch_, is_comm_used_); } - CUmodule module() const { return module_; } - bool is_comm_used() const { return is_comm_used_; } - private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly - CUmodule module_; MosaicHostFunc* host_launch_; bool is_comm_used_; }; @@ -569,7 +559,7 @@ absl::StatusOr> GetHostAndInitFuncNames( return std::make_pair(host_func_name, init_func_name); } -absl::StatusOr CompileAndInit(absl::string_view module) { +absl::StatusOr CompileAndInit(llvm::StringRef module) { mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); context.allowUnregisteredDialects(true); InitContext(&context); @@ -610,30 +600,13 @@ absl::StatusOr CompileAndInit(absl::string_view module) { void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; reinterpret_cast(*init)(init_args); return CompiledKernel(std::move(maybe_engine.value().first), kernel_ptr, - reinterpret_cast(module_ptr), reinterpret_cast(*host), is_comm_used); } -absl::Status Unload(const CompiledKernel& kernel, CUcontext ctx) { - CUDA_RETURN_IF_ERROR(cuCtxPushCurrent(ctx)); - if (kernel.is_comm_used()) { - if (NvshmemApi::Default().cumodule_finalize(kernel.module()) != - NVSHMEM_SUCCESS) { - return absl::InternalError("nvshmemx_cumodule_finalize failed"); - } - } - CUDA_RETURN_IF_ERROR(cuModuleUnload(kernel.module())); - CUcontext unused; - CUDA_RETURN_IF_ERROR(cuCtxPopCurrent(&unused)); - return absl::OkStatus(); -} - using KernelHash = std::array; +using CacheKey = std::pair; -// A reference counted cache of compiled and loaded kernels. -class KernelCache { - public: - // A global cache of compiled and loaded kernels. +struct KernelCache { static KernelCache& Global() { static absl::NoDestructor cache; return *cache; @@ -644,89 +617,80 @@ class KernelCache { KernelCache(const KernelCache&) = delete; KernelCache(KernelCache&&) = delete; - // Holds a reference to a compiled and loaded kernel. - // Unload the kernel when the handle is destroyed. - class KernelHandle { - public: - KernelHandle(CompiledKernel kernel, CUcontext ctx) - : kernel_(std::move(kernel)), ctx_(ctx) {} - ~KernelHandle() { - CHECK_OK(Unload(kernel_, ctx_)); - VLOG(5) << "Successfully unloaded GPU module"; - } - const CompiledKernel* kernel() const { return &kernel_; } - - private: - CompiledKernel kernel_; - CUcontext ctx_; // The CUDA context in which the kernel was loaded. - }; - - // Compile and load the given module in the current CUDA context. - absl::StatusOr> CompileAndInit( - const KernelHash& kernel_hash, absl::string_view module) { - CUcontext ctx; - CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); - CacheKey key(kernel_hash, reinterpret_cast(ctx)); - absl::MutexLock lock(mutex_); - if (auto it = kernels_.find(key); it != kernels_.end()) { - std::shared_ptr handle = it->second.lock(); - if (handle) { - return handle; - } - } - // Kernel not found or has expired, create a new value. - tsl::profiler::TraceMe trace("Compilation cache miss"); - TF_ASSIGN_OR_RETURN(CompiledKernel compiled, ::CompileAndInit(module)); - VLOG(5) << "Successfully compiled and initialized Mosaic GPU kernel"; - auto handle = std::make_shared(std::move(compiled), ctx); - kernels_[key] = handle; - return handle; - } - - private: - using CacheKey = std::pair; - absl::Mutex mutex_; - absl::flat_hash_map> kernels_ - ABSL_GUARDED_BY(mutex_); + absl::Mutex mutex; + absl::flat_hash_map kernels ABSL_GUARDED_BY(mutex); }; -// Tracks the compiled and loaded kernels for a given custom call. -// There is a single global cache in the process and a process can have -// multiple devices, each of which must load/unload the module. We expect each -// device/module pair to have a unique cache key. -class CustomCallResources { - public: - CustomCallResources() = default; +// Each compiled kernel has a unique init func, and each kernel is used from +// a single HLO module. So it should be safe to not include the CUDA context +// in the key. +absl::StatusOr CachedCompileAndInit(CacheKey key, + llvm::StringRef module) { + KernelCache& cache = KernelCache::Global(); - const CompiledKernel* KernelForDevice(int32_t device_ordinal) const { - absl::MutexLock lock(mutex_); - return kernels_.at(device_ordinal)->kernel(); + { + // Fast path uses reader lock (as hash map look-up is relatively slow). + absl::ReaderMutexLock lock(cache.mutex); + auto it = cache.kernels.find(key); + if (ABSL_PREDICT_TRUE(it != cache.kernels.end())) return &it->second; } - void AddKernel(int32_t device_ordinal, - std::shared_ptr kernel) { - absl::MutexLock lock(mutex_); - kernels_[device_ordinal] = std::move(kernel); + absl::MutexLock lock(cache.mutex); + // We released the reader lock, another thread might have initialized it. + if (cache.kernels.find(key) == cache.kernels.end()) { + tsl::profiler::TraceMe trace("Compilation cache miss"); + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return compiled.status(); + } + cache.kernels.insert_or_assign(key, std::move(*compiled)); } + return &cache.kernels.at(key); +} - private: - mutable absl::Mutex mutex_; - absl::flat_hash_map> - kernels_ ABSL_GUARDED_BY(mutex_); -}; - -absl::StatusOr> InstantiateResources() { - // TODO(b/466097203): Ideally we would compile the module here. - // Sadly we need to acquire a lock on LLVM command line options which is - // already held by XLA causing a deadlock. - // See `GpuCompiler::CompileToBackendResult`. - return std::make_unique(); +// TODO(b/464203195): Backward-compatible version using the legacy FFI +// API. Remove once backward compatibility window has passed. +void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, + size_t opaque_len, XlaCustomCallStatus* status) { + if (reinterpret_cast(opaque) % alignof(KernelHash)) { + fprintf(stderr, "Misaligned opaque pointer\n"); + abort(); + } + auto hash = *reinterpret_cast(opaque); + CUcontext ctx; + if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { + fprintf(stderr, "Failed to get current CUDA context\n"); + abort(); + } + CacheKey key(hash, reinterpret_cast(ctx)); + auto compiled_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); + if (!compiled_kernel.ok()) { + XlaCustomCallStatusSetFailure(status, + compiled_kernel.status().message().data(), + compiled_kernel.status().message().size()); + return; + } + auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); + bool is_comm_used = std::get<2>(ctx_kernel_comm); + void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; + if (is_comm_used) { + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( + reinterpret_cast(stream)); + } + std::get<1>(ctx_kernel_comm)(args); } -absl::Status InitializeResources(int32_t device_ordinal, - CustomCallResources* resources, - std::string_view kernel_hash, - std::string_view module, bool) { +XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, + "CUDA"); + +absl::Status MosaicGpuExecute(cudaStream_t stream, ffi::RemainingArgs inputs, + ffi::RemainingRets results, + std::string_view kernel_hash, + std::string_view module, + bool use_custom_barrier) { + if (use_custom_barrier) { + return absl::UnimplementedError("Custom barrier is not supported on GPUs."); + } if (kernel_hash.size() != sizeof(KernelHash)) { return absl::InvalidArgumentError( absl::StrFormat("Kernel hash size is %d bytes, expected %d bytes", @@ -734,23 +698,11 @@ absl::Status InitializeResources(int32_t device_ordinal, } KernelHash hash; std::memcpy(hash.data(), kernel_hash.data(), sizeof(KernelHash)); - TF_ASSIGN_OR_RETURN( - std::shared_ptr handle, - KernelCache::Global().CompileAndInit(hash, module)); - resources->AddKernel(device_ordinal, std::move(handle)); - return absl::OkStatus(); -} - -absl::Status MosaicGpuExecute(cudaStream_t stream, int32_t device_ordinal, - ffi::RemainingArgs inputs, - ffi::RemainingRets results, - CustomCallResources* resources, std::string_view, - std::string_view, bool use_custom_barrier) { - if (use_custom_barrier) { - return absl::UnimplementedError("Custom barrier is not supported on GPUs."); - } - const CompiledKernel* compiled_kernel = - resources->KernelForDevice(device_ordinal); + CUcontext ctx; + CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); + CacheKey key(hash, reinterpret_cast(ctx)); + TF_ASSIGN_OR_RETURN(auto compiled_kernel, + CachedCompileAndInit(key, module)); auto ctx_kernel_comm = compiled_kernel->GetHostLaunch(); bool is_comm_used = std::get<2>(ctx_kernel_comm); @@ -778,30 +730,17 @@ absl::Status MosaicGpuExecute(cudaStream_t stream, int32_t device_ordinal, void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers_ptr}; if (is_comm_used) { - NvshmemApi::Default().barrier_all_on_stream(stream); + mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream(stream); } std::get<1>(ctx_kernel_comm)(args); return absl::OkStatus(); } -XLA_FFI_DEFINE_HANDLER(kInstantiateResources, InstantiateResources, - ffi::Ffi::BindInstantiate()); - -XLA_FFI_DEFINE_HANDLER(kInitializeResources, InitializeResources, - ffi::Ffi::BindInitialize() - .Ctx() - .Ctx>() - .Attr("kernel_hash") - .Attr("module") - .Attr("use_custom_barrier")); - XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, ffi::Ffi::Bind() .Ctx>() - .Ctx() .RemainingArgs() .RemainingRets() - .Ctx>() .Attr("kernel_hash") .Attr("module") .Attr("use_custom_barrier"), @@ -809,78 +748,12 @@ XLA_FFI_DEFINE_HANDLER(kMosaicGpuExecute, MosaicGpuExecute, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "mosaic_gpu_v2", "CUDA", { - /*instantiate=*/kInstantiateResources, + /*instantiate=*/nullptr, /*prepare=*/nullptr, - /*initialize=*/kInitializeResources, + /*initialize=*/nullptr, /*execute=*/kMosaicGpuExecute, }); -// Cache compiled and loaded kernels in the current CUDA context. -// Loaded kernels are never unloaded. -absl::StatusOr LegacyCachedCompileAndInit( - const KernelHash& kernel_hash, absl::string_view module) { - using CacheKey = std::pair; - struct LegacyCache { - absl::Mutex mutex; - absl::flat_hash_map kernels - ABSL_GUARDED_BY(mutex); - }; - static absl::NoDestructor cache; - - CUcontext ctx; - CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); - - CacheKey key(kernel_hash, reinterpret_cast(ctx)); - { - // Fast path uses reader lock (as hash map look-up is relatively slow). - absl::ReaderMutexLock lock(cache->mutex); - auto it = cache->kernels.find(key); - if (ABSL_PREDICT_TRUE(it != cache->kernels.end())) return &it->second; - } - - absl::MutexLock lock(cache->mutex); - // We released the reader lock, another thread might have initialized it. - if (cache->kernels.find(key) == cache->kernels.end()) { - tsl::profiler::TraceMe trace("Compilation cache miss"); - auto compiled = CompileAndInit(module); - if (!compiled.ok()) { - return compiled.status(); - } - cache->kernels.insert_or_assign(key, std::move(*compiled)); - } - return &cache->kernels.at(key); -} - -// TODO(b/464203195): Backward-compatible version using the legacy FFI -// API. Remove once backward compatibility window has passed. -void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - if (reinterpret_cast(opaque) % alignof(KernelHash)) { - fprintf(stderr, "Misaligned opaque pointer\n"); - abort(); - } - auto hash = *reinterpret_cast(opaque); - auto compiled_kernel = - LegacyCachedCompileAndInit(hash, opaque + sizeof(KernelHash)); - if (!compiled_kernel.ok()) { - XlaCustomCallStatusSetFailure(status, - compiled_kernel.status().message().data(), - compiled_kernel.status().message().size()); - return; - } - auto ctx_kernel_comm = (*compiled_kernel)->GetHostLaunch(); - bool is_comm_used = std::get<2>(ctx_kernel_comm); - void* args[4] = {&std::get<0>(ctx_kernel_comm), &stream, &buffers}; - if (is_comm_used) { - mosaic::gpu::NvshmemApi::Default().barrier_all_on_stream( - reinterpret_cast(stream)); - } - std::get<1>(ctx_kernel_comm)(args); -} - -XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, - "CUDA"); - } // namespace extern "C" { diff --git a/jaxlib/mosaic/gpu/custom_call_test.cc b/jaxlib/mosaic/gpu/custom_call_test.cc index d3426c0fd71a..e4756a394325 100644 --- a/jaxlib/mosaic/gpu/custom_call_test.cc +++ b/jaxlib/mosaic/gpu/custom_call_test.cc @@ -19,9 +19,6 @@ limitations under the License. #include #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" -#include "absl/log/scoped_mock_log.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/strings/str_cat.h" @@ -39,7 +36,6 @@ limitations under the License. namespace { using ::absl_testing::IsOk; -using ::testing::_; absl::Status ExecuteSync(xla::PjRtLoadedExecutable* executable) { std::vector no_buffers; @@ -69,16 +65,16 @@ ENTRY main { custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI })"; - ASSERT_OK_AND_ASSIGN(auto module, - xla::ParseAndReturnUnverifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + xla::ParseAndReturnUnverifiedModule(kHloModule)); std::string tmp_path = testing::TempDir(); tsl::setenv("XLA_FLAGS", absl::StrCat("--xla_dump_to=", tmp_path).c_str(), /*overwrite=*/true); - ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - ASSERT_OK_AND_ASSIGN( + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetXlaPjrtGpuClient(/*options=*/{})); + TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr executable, client->CompileAndLoad(xla::XlaComputation(module->ToProto()), /*options=*/{})); @@ -138,145 +134,4 @@ TEST(CustomCallTest, LegacyCustomCall) { EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); } -absl::string_view TestMGPUHloModule() { - // Dumped from the following JAX program: - // - // ``` - // @functools.partial( - // plgpu.pallas_call, - // out_shape=jax.ShapeDtypeStruct((), jnp.int32), - // out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), - // ) - // def kernel(o_ref): - // o_ref[...] = jnp.array(42) - // ``` - return R"hlo( - HloModule test - - ENTRY main { - ROOT result = s32[] custom-call(), custom_call_target="mosaic_gpu_v2", api_version=API_VERSION_TYPED_FFI, backend_config={kernel_hash = "\90\C7\1F$\92=c\9D\E4\A8\15\B1Y\9B.\02\B4\B0\0B\16\C5Ol\D4\ED\CDdA-\C9\D77", module = "ML\EFR\01MLIR\00\01O\0D\01\03\05\07\09\0B\01\03\0D\037\0F\11\13\15\17\19\1B\1D\1F!#%')+-/13579;=?AC\03\12\02\C9\1D\01\BB\0F\13\0B\0B\0F\13\13\13\13\0B\07\0B\0B\13\13\0B\0F\13\13\13e\1B\0B\0F\0B\0B#\0B\0B\0B\0B;\0B\0B\0B\0B\0B\0B\0B#\0B\0B\07\0B\13\0F\0F\13\13\13\0F\13\13\0B\133\133\133U\1B\0B\C3\0B\13\13\13\13\13\13\13\13\13\17\17\17\0B\0F\1F\0F\0B\0B\13\13\0B\0B\0F\0B\0F\0B\17\0B\05\03a\07\09y111\09\03Y\0B\03U\01\15\0F\07\0F\0B\0B\1B/\17\13;\05\07)yQ\07\03E\02\AE\0A\1D3\15\03\03\9B\C5\05E\05G\11\05\01\03\03\07]\03\03\19\BF\03\03\19\C1\03\03\19\C3\05I\1F\05K\05M\03\03\07\9D\03\03\A5\09\05O\11\01\11\03\03\07\9F\03\03\07\A1\03\03\A3\C7affine_map<(d0) -> (d0)>\00\03\05-/\131\05Q\11\05\19\05S\05U\03\07\1F7\139;=\0D\0D\05W\05Y\05[\03\0DA!CEG\BB\13IK\09M\09\05]\05_\0D\19\05a\05c\05e\05g\03\07\1FQSU\13W\0D\0F\05i\0F\05k\03\03\07[\11\01\A9\11\01\01\03\03\07a\11\03\02\04\03\03\07e\11\03\05\03\03\07\09\03\03k\09\05m\03\03\17o#\05\03\11\00\00\00\00\00\00\00\00\03\03\17s#\05\03\11\01\00\00\00\00\00\00\00\03\03\17w#\05\03\11\02\00\00\00\00\00\00\00affine_map<() -> ()>\00\03\05}\7F\81\09\05o#\01\17Y\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\05q\17\05%O\17\05%]\17\05%k\17\05%\E1\17\05%\EF\17\05%\FD\17\05%\81\17\05%\9B\17\05%\B5\17\05%&\02\17\05%f\02\17\05%\9E\02\05s\11\01\15\11\01\D0\FF\FF\FF?\11\01}\05u\05w\03\03\07!\03\03\AB\AD\05y\01\01\1D\B1\B3\05{\1D\B5\B7\05}\17\B9\06\03\0D\05\7F#llvm.linkage\00#gpu.address_space\00#gpu\00#gpu\00#gpu\00#arith.overflow\00#nvvm\00\01\02\02\03\01\02\04\01\09\01A\17\BD\03\01\09)\05\11\15\15\05\05\15\15\05\15\01\05\05\15\15\01\15\01\01y\17\BD\03\00\FF\FF\FF\FF\FF\FF\FF\FF\09)!llvm.ptr\00!llvm.struct<(ptr, ptr, i64)>\00!llvm.array<0 x i8>\00!gpu.async.token\00\04Z\0C\05\01\11\01+\07\03\01\0D\17\11\015\07\01\1F\11\01?\07\01\17\11\01O\07\03\1F;\05\15\01\15\01\05\03\15Y\03\01\05\03\15\0B\03\01\05\03\01_\03\03\05\03\15c\03\03!\03\01g\03\05#\02\01\03\17\0F\06\01\03\1B\03\01%\07\01i\03\15\03\03\11\07\01m\03\17\05\0F\13\11\07\01q\03\17\05\15\13\11\07\01u\03\17\05\17\0D\0F\06\01\03\11\03\19'\17\01{\03\1B\11\11\0B\0B\0B\09\0B\0B\07\05\03\C1\C6\02\19\03\83\03\85\03\87\03\89\03\8B\03\8D\03\8F\03\91\03\93\03\95\03\97\03\99\19\02\01\03\07\09\03\01\0D\03\03\03\06\01\03\01\039\0B\03\01\0D\03\03\03\06\01\03\01\03=\09\03\01\0F\03\03\03\06\01\03\01\03A\07\07\01\03\03\01\05C?\0D\07\01\03\03\01\05;E\0B\03\01\0F\03\03\03\06\01\03\01\03I\07\07\01\03\03\01\05?K\09\03\01\11\03\03\03\06\01\03\01\03O\07\07\01\03\03\01\05QM\0D\07\01\03\03\01\05GS\0B\03\01\11\03\03\03\06\01\03\01\03W\07\07\01\03\03\01\05MY\05\03\01\1B\03\01\13\06\01\03\01\05U]\05\03\01#\03\01\05\03\01\0B\03\01\05\03\01%\03\01\1B\07\01'\03\01\09a_ce\05\03\01\0B\03\01\15\07\01\1D\03\07\05gi\1D\06\01\03\07\05k7\19\02\01\03\07\09\03\01\0D\03\03\03\06\01\03\01\03q\0B\03\01\0D\03\03\03\06\01\03\01\03u\09\03\01\0F\03\03\03\06\01\03\01\03y\07\07\01\03\03\01\05{w\0D\07\01\03\03\01\05s}\0B\03\01\0F\03\03\03\06\01\03\01\03\81\07\07\01\03\03\01\05w\83\09\03\01\11\03\03\03\06\01\03\01\03\87\07\07\01\03\03\01\05\89\85\0D\07\01\03\03\01\05\7F\8B\0B\03\01\11\03\03\03\06\01\03\01\03\8F\07\07\01\03\03\01\05\85\91\05\03\01\1B\03\01\13\06\01\03\01\05\8D\95\05\03\01#\03\01\05\03\01\0B\03\01\05\03\01%\03\01\1B\07\01'\03\01\09\99\97\9B\9D\05\03\01\A7\03\01+\06\01\03\01\05\9F\A1\05\03\01\0B\03\01\15\07\01\1D\03\07\05\A3\A5\1D\06\01\03\07\05\A7o\09\03\01\0D\03\03\03\06\01\03\01\03\AB\0B\03\01\0D\03\03\03\06\01\03\01\03\AF\09\03\01\0F\03\03\03\06\01\03\01\03\B3\07\07\01\03\03\01\05\B5\B1\0D\07\01\03\03\01\05\AD\B7\0B\03\01\0F\03\03\03\06\01\03\01\03\BB\07\07\01\03\03\01\05\B1\BD\09\03\01\11\03\03\03\06\01\03\01\03\C1\07\07\01\03\03\01\05\C3\BF\0D\07\01\03\03\01\05\B9\C5\0B\03\01\11\03\03\03\06\01\03\01\03\C9\07\07\01\03\03\01\05\BF\CB\05\03\01\1B\03\01\13\06\01\03\01\05\C7\CF\05\03\01\0B\03\01\15\07\01\1D\03\07\05\D1\D3-\02\01\03\13\03\06\01\03\03\03\07/\06\01\03\0B\05\D7\D9\0F\07\01\A9\03\0B\03\DB1\00\013\00\015\04\AF\05\05\1B7\00\01)\00\01\06\03\01\05\01\00\9E\0E\81g\0B\0D\17\15\0B\1D/)\13%-\19\1B\1F\11\19\17\11\1F3\19\0F5\1D\15\13\13\0D\05\1F\1B\193\195\19\19\17\15!'#\17\1F!\15\17\19#G\17\1D\1D\17\1F#\0F\0B\0D\09\0B%\11builtin\00stable_mosaic_gpu\00llvm\00gpu\00arith\00nvvm\00module\00arith.index_cast\00arith.constant\00arith.muli\00gpu.thread_id\00gpu.block_dim\00arith.addi\00builtin.unrealized_conversion_cast\00llvm.insertvalue\00arith.shrui\00arith.cmpi\00func.func\00nvvm.elect.sync\00nvvm.shfl.sync\00arith.andi\00llvm.mlir.global\00llvm.mlir.constant\00llvm.mlir.undef\00llvm.load\00gpu.launch\00func.return\00arith.remui\00gpu.dynamic_shared_memory\00memref.view\00nvvm.fence.mbarrier.init\00gpu.barrier\00memref.store\00gpu.terminator\00-\00value\00sym_name\00position\00dimension\00function_type\00stable_mosaic_gpu.version\00kernel\00pallas_call\00mosaic_gpu_init_tma_desc\00sym_visibility\00private\00addr_space\00global_type\00linkage\00global_scratch\00unnamed_addr\00visibility_\00llvm.emit_c_interface\00kernel_mosaic_gpu\00ordering\00operandSegmentSizes\00workgroup_attributions\00overflowFlags\00kind\00predicate\00transforms\00swap:\00swap\00third_party/py/jax/tests/pallas/mosaic_gpu_test.py\00", use_custom_barrier = false} - } - )hlo"; -} - -TEST(CustomCallTest, UnloadGPUModule) { - ASSERT_OK_AND_ASSIGN( - auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); - - ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - - absl::SetVLogLevel("custom_call", 5); - { - absl::ScopedMockLog log; - EXPECT_CALL(log, - Log(absl::LogSeverity::kInfo, _, - "Successfully compiled and initialized Mosaic GPU kernel")) - .Times(1); - log.StartCapturingLogs(); - EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); - } - - { - // The second execution the compilation should be cached. - absl::ScopedMockLog log; - EXPECT_CALL(log, - Log(absl::LogSeverity::kInfo, _, - "Successfully compiled and initialized Mosaic GPU kernel")) - .Times(0); - log.StartCapturingLogs(); - EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); - } - - { - // GPU module should be unloaded when the executable is destroyed. - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, - "Successfully unloaded GPU module")) - .Times(1); - log.StartCapturingLogs(); - executable.reset(); - } -} - -TEST(CustomCallTest, GPUModuleIsOnlyUnloadedWhenAllExecutablesAreDestroyed) { - ASSERT_OK_AND_ASSIGN( - auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); - ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable1, - client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable2, - client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - - EXPECT_THAT(ExecuteSync(executable1.get()), IsOk()); - EXPECT_THAT(ExecuteSync(executable2.get()), IsOk()); - - absl::SetVLogLevel("custom_call", 5); - { - // executable2 still holds a reference to the GPU module. - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, - "Successfully unloaded GPU module")) - .Times(0); - log.StartCapturingLogs(); - executable1.reset(); - } - EXPECT_THAT(ExecuteSync(executable2.get()), IsOk()); - { - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, - "Successfully unloaded GPU module")) - .Times(1); - log.StartCapturingLogs(); - executable2.reset(); - } -} - -TEST(CustomCallTest, GPUModuleIsRecompiledAfterExpiration) { - ASSERT_OK_AND_ASSIGN( - auto module, xla::ParseAndReturnUnverifiedModule(TestMGPUHloModule())); - ASSERT_OK_AND_ASSIGN(std::unique_ptr client, - xla::GetXlaPjrtGpuClient(/*options=*/{})); - ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - - EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); - - { - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, _, - "Successfully unloaded GPU module")) - .Times(1); - log.StartCapturingLogs(); - executable.reset(); - } - - ASSERT_OK_AND_ASSIGN( - executable, client->CompileAndLoad(xla::XlaComputation(module->ToProto()), - /*options=*/{})); - - { - // executable was destroyed and the module was unloaded. We re-compile the - // kernel. - absl::ScopedMockLog log; - EXPECT_CALL(log, - Log(absl::LogSeverity::kInfo, _, - "Successfully compiled and initialized Mosaic GPU kernel")) - .Times(1); - log.StartCapturingLogs(); - EXPECT_THAT(ExecuteSync(executable.get()), IsOk()); - } -} - } // namespace diff --git a/jaxlib/mosaic/gpu/nvshmem.h b/jaxlib/mosaic/gpu/nvshmem.h index 267f17de8324..dbd11aa1d373 100644 --- a/jaxlib/mosaic/gpu/nvshmem.h +++ b/jaxlib/mosaic/gpu/nvshmem.h @@ -54,11 +54,6 @@ class NvshmemApi { return nvshmemx_cumodule_init(module); } - int cumodule_finalize(CUmodule module) { - std::lock_guard lock(mutex_); - return nvshmemx_cumodule_finalize(module); - } - void barrier_all_on_stream(cudaStream_t stream) { nvshmemx_barrier_all_on_stream(stream); } @@ -83,13 +78,11 @@ class NvshmemApi { NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) NVSHMEM_SET_FN(nvshmemx_cumodule_init) - NVSHMEM_SET_FN(nvshmemx_cumodule_finalize) NVSHMEM_SET_FN(nvshmemx_init_status) } int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); int (*nvshmemx_cumodule_init)(CUmodule); - int (*nvshmemx_cumodule_finalize)(CUmodule); int (*nvshmemx_init_status)(); std::mutex mutex_; From d3a80521e61857deaa338d4f2378d79b5163c772 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Wed, 17 Dec 2025 07:20:55 -0800 Subject: [PATCH 244/315] Skip async store reduction tests for int64/uint64 when x64 is disabled. PiperOrigin-RevId: 845757351 --- tests/mosaic/gpu_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 423d0d1e7c3b..c9c919daedf6 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -5402,6 +5402,10 @@ def body( reduction_op=("add", "min", "max", "inc", "dec", "and", "or", "xor"), ) def test_async_store_reduction(self, dtype, reduction_op): + + if not config.enable_x64.value and dtype in (jnp.int64, jnp.uint64): + self.skipTest("x64 support is disabled") + # TODO(b/415721295):Clean up after the minimal jaxlib version is 0.8.2. if not hasattr(mgpu_dialect, "TMAReduction"): self.skipTest("The mgpu_dialect.TMAReduction attribute is required.") From 44f67b03efabf48de14ffb1f23c5ea79294707be Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 17 Dec 2025 07:49:36 -0800 Subject: [PATCH 245/315] Add `concrete_mesh` to reshard_p in pallas lowering rules which was missing PiperOrigin-RevId: 845766226 --- jax/_src/pallas/mosaic/lowering.py | 3 ++- jax/_src/pallas/triton/lowering.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index d6c30be8946f..d17b1aeb0c8a 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3393,7 +3393,8 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): @register_lowering_rule(pjit.reshard_p) -def _reshard_lowering_rule(ctx: LoweringRuleContext, x, dst_sharding): +def _reshard_lowering_rule(ctx: LoweringRuleContext, x, *, dst_sharding, + concrete_mesh): return x diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 2caa9f860bcb..62411e7f8659 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -2537,7 +2537,7 @@ def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): @register_lowering(pjit.reshard_p) -def _reshard_lowering_rule(ctx, x, dst_sharding): +def _reshard_lowering_rule(ctx, x, *, dst_sharding, concrete_mesh): return x From 4c671ca77c95719fd401c42bf7c69d7e718ed685 Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Wed, 17 Dec 2025 08:37:05 -0800 Subject: [PATCH 246/315] [Pallas MGPU] Adding support for squeezed block dims in the pipeline BlockSpecs. They can be identified with a `None` or `pl.Squeezed`. PiperOrigin-RevId: 845782837 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 55 +++++++++++++++++------ tests/pallas/mosaic_gpu_test.py | 61 ++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 13 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 61d75d5297b7..b2286e3dfc18 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -59,8 +59,15 @@ def _get_block_size( raise NotImplementedError(f"Unsupported block size type: {type(bd)}") def _get_block_shape(spec: pallas_core.BlockSpec): - assert spec.block_shape is not None - return tuple(_get_block_size(bd) for bd in spec.block_shape) + if spec.block_shape is None: + raise ValueError("Block shape must be specified.") + + block_shape = tuple( + _get_block_size(bd) + for bd in spec.block_shape + if not (bd is None or isinstance(bd, pl.Squeezed)) + ) + return block_shape map_brefs = functools.partial( @@ -84,18 +91,27 @@ def get_ref_for_slot( return self.gmem_ref return self.smem_ref.at[slot] - def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: + def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice | jax.Array, ...]: index_map = self.spec.index_map assert index_map is not None + assert self.spec.block_shape is not None # We don't allow Python scalars here, because they are interpreted # differently depending on the x32/x64 mode. assert all(i.dtype == jnp.dtype(jnp.int32) for i in grid_indices) - sizes = _get_block_shape(self.spec) + + def _make_block_slice(block_index: jax.Array, bd: pl.BlockDim | int | None): + match bd: + case int(): + return pl.Slice(block_index * bd, bd) + case pl.Blocked(block_size): + return pl.Slice(block_index * block_size, block_size) + case None | pl.Squeezed(): + return block_index + case _: + raise ValueError(f"Unsupported block dimension type: {bd}") + return tuple( - pl.Slice(idx * size, size) # type: ignore[arg-type] - for idx, size in zip( - index_map(*grid_indices), sizes # type: ignore[arg-type] - ) + map(_make_block_slice, index_map(*grid_indices), self.spec.block_shape) ) def copy_in(self, slot, grid_indices, barrier_ref, barrier_slot=None): @@ -372,7 +388,8 @@ def loop_body(step, carry): continue assert last_store_slices[idx] is not None new_store_slices[idx] = tuple( - _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices) + _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s + for s in bref.compute_gmem_slice(indices) ) are_same_slices = map( lambda old, new: old == new, @@ -430,11 +447,16 @@ def do_fetch(): fetch_indices = _inc_grid_by_1(fetch_indices, grid) fetch_index_levels.append(fetch_indices) + def _init_store_slice(bd): + if bd is None or isinstance(bd, pl.Squeezed): + return jnp.array(-1, dtype=jnp.int32) + return _Slice(-1, -1) + # TODO(justinfu): Only store base pointer instead of all indices. last_store_slices = [ None if bref.is_index_invariant - else (_Slice(-1, -1),) * len(bref.spec.block_shape) + else tuple(map(_init_store_slice, bref.spec.block_shape)) for bref in out_brefs ] last_indices, _, _, final_carry = lax.fori_loop( @@ -690,7 +712,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree): slots = max_concurrent_steps if has_seq_dim else 1 smem_allocs.append( gpu_core.SMEM( - (slots, *spec.block_shape), # type: ignore + (slots, *_get_block_shape(spec)), # type: ignore gmem_ref.dtype, transforms=getattr(spec, "transforms", ()), ) @@ -880,7 +902,8 @@ def compute_loop_body(step, carry): continue assert last_store_slices[idx] is not None new_store_slices[idx] = tuple( - _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices) + _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s + for s in bref.compute_gmem_slice(indices) ) are_same_slices = map( lambda old, new: old == new, @@ -895,11 +918,17 @@ def compute_loop_body(step, carry): next_indices = _inc_grid_by_1(indices, grid) return (next_indices, new_store_slices, next_body_carry) init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) + + def _init_store_slice(bd): + if bd is None or isinstance(bd, pl.Squeezed): + return jnp.array(-1, dtype=jnp.int32) + return _Slice(-1, -1) + # TODO(justinfu): Only store base pointer instead of all indices. last_store_slices = [ None if bref.is_index_invariant - else (_Slice(-1, -1),) * len(bref.spec.block_shape) + else tuple(map(_init_store_slice, bref.spec.block_shape)) for bref in flat_out_brefs ] diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 508b9473fe29..5b65671f91c7 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -4972,6 +4972,35 @@ def kernel_body(_, o_smem, carry): kernel_fn(), jnp.tile(jnp.repeat(jnp.arange(num_steps), 64), (64, 1)) ) + @parameterized.parameters((pl.Squeezed(),), (None,)) + def test_emit_with_squeezed_dim(self, squeezed_dim): + + shape = (16, 256) + num_steps = shape[0] + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))], + out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(_, in_smem, o_smem): + assert in_smem.shape == (shape[1],) + assert o_smem.shape == (shape[1],) + o_smem[...] = in_smem[...] + 1 + + kernel_fn = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32), + ) + x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256) + np.testing.assert_array_equal(kernel_fn(x), x + 1) + class PipelineWGTest( PipelineTest, lowering_semantics=plgpu.LoweringSemantics.Warpgroup @@ -5661,6 +5690,38 @@ def pipeline_body(_, x_smem, o_smem): ) np.testing.assert_array_equal(y, np.stack([x + 1.0, x + 1.0])) + @parameterized.parameters((pl.Squeezed(),), (None,)) + def test_emit_with_squeezed_dim(self, squeezed_dim): + self.skip_if_wg_semantics() + + shape = (16, 256) + num_steps = shape[0] + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline_warp_specialized( + kernel_body, + in_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))], + out_specs=[pl.BlockSpec((squeezed_dim, shape[1]), lambda i: (i, 0))], + grid=(num_steps,), + max_concurrent_steps=2, + num_compute_wgs=1, + memory_registers=40, + wg_axis="wg", + )(x_gmem, o_gmem) + + def kernel_body(_, in_smem, o_smem): + o_smem[...] = in_smem[...] + 1 + + kernel_fn = self.kernel( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32), + num_threads=2, + thread_name="wg", + ) + x = jnp.arange(16 * 256, dtype=jnp.int32).reshape(16, 256) + np.testing.assert_array_equal(kernel_fn(x), x + 1) + + class WarpSpecializedPipelineWGTest( WarpSpecializedPipelineTest, From 44085743e710edb36f39ada5de7bccd906549573 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Wed, 17 Dec 2025 09:13:51 -0800 Subject: [PATCH 247/315] Add --dist=loadfile to pytest command in run_pytest_tpu.sh. Tests from a single file are sent to the same worker. This should reduce resource-related issues that cause hangs/crashes. PiperOrigin-RevId: 845796254 --- ci/run_pytest_tpu.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/run_pytest_tpu.sh b/ci/run_pytest_tpu.sh index abb45cbe10e8..95140d7e5133 100755 --- a/ci/run_pytest_tpu.sh +++ b/ci/run_pytest_tpu.sh @@ -73,7 +73,7 @@ if [[ "$JAXCI_RUN_FULL_TPU_TEST_SUITE" == "1" ]]; then --deselect=tests/pallas/tpu_pallas_call_print_test.py::PallasCallPrintTest \ --deselect=tests/pallas/tpu_sparsecore_pallas_test.py::DebugPrintTest \ --deselect=tests/pallas/tpu_pallas_interpret_thread_map_test.py::InterpretThreadMapTest::test_thread_map \ - --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples + --dist=loadfile --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples # Store the return value of the first command. first_cmd_retval=$? From 579009c41403e7caa2f40fecd9156e1b389c65ed Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Wed, 17 Dec 2025 09:35:21 -0800 Subject: [PATCH 248/315] [CI] Modify TPU v7x jobs to include bazel, and exclude some python versions PiperOrigin-RevId: 845803404 --- .github/workflows/wheel_tests_continuous.yml | 1 + .github/workflows/wheel_tests_nightly_release.yml | 14 ++++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index 07e2eb4fd978..aee1f3311d98 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -268,6 +268,7 @@ jobs: tpu-specs: [ {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"}, ] libtpu-version-type: ["nightly"] name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})" diff --git a/.github/workflows/wheel_tests_nightly_release.yml b/.github/workflows/wheel_tests_nightly_release.yml index 8a5703412936..07592cd0e70c 100644 --- a/.github/workflows/wheel_tests_nightly_release.yml +++ b/.github/workflows/wheel_tests_nightly_release.yml @@ -193,13 +193,17 @@ jobs: - tpu-specs: type: "v6e-8" python: "3.13-nogil" - # Run min and max Python versions for v5e-8 + # Run max Python versions for v5e-8 - tpu-specs: type: "v5e-8" python: "3.11" - tpu-specs: type: "v5e-8" python: "3.12" + # Run min and max Python versions for v7x-8 + - tpu-specs: + type: "v7x-8" + python: "3.12" name: "Pytest TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" with: @@ -223,6 +227,7 @@ jobs: # {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, + {type: "v7x-8", cores: "8", runner: "linux-x86-tpu7x-224-4tpu"}, ] libtpu-version-type: ["pypi_latest", "nightly"] exclude: @@ -240,12 +245,13 @@ jobs: - tpu-specs: type: "v4-8" python: "3.13-nogil" - # Run min and max Python versions for v5e-8 + # Run max Python versions for v5e-8 - tpu-specs: type: "v5e-8" - python: "3.11" + python: "3.12" + # Run min and max Python versions for v7x-8 - tpu-specs: - type: "v5e-8" + type: "v7x-8" python: "3.12" name: "Bazel tests TPU (JAX artifacts version = ${{ startsWith(github.ref_name, 'release/') && 'latest release' || 'nightly' }})" From 254918c89f98a23eb0065ff9d8e1169112befd35 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 17 Dec 2025 10:27:57 -0800 Subject: [PATCH 249/315] Fix ct_check to account for `None` cotangents too PiperOrigin-RevId: 845824336 --- jax/_src/interpreters/ad.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 35a7608bb145..d0f625ca5784 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -421,12 +421,12 @@ def __init__(self, aval, ref=None): def accum(self, x): assert x is not Zero - ct_check(self, x) if isinstance(x, Zero) or x is None: return - elif self.ref is None: + if self.ref is None: self.ref = core.new_ref(x) else: + ct_check(self, x) self.ref.addupdate(x) def freeze(self): @@ -449,8 +449,8 @@ def __init__(self, aval, val=None): self.val = Zero(aval) if val is None else val def accum(self, x): - ct_check(self, x) if x is not None: + ct_check(self, x) self.val = add_tangents(self.val, x) def freeze(self): From 7b63d9d3f893cbd009a2a83a294eed6cc79961e3 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 17 Dec 2025 11:47:31 -0800 Subject: [PATCH 250/315] [pallas] Deprecated `pltpu.ANY` in favor of `pl.ANY` PiperOrigin-RevId: 845857167 --- docs/pallas/design/async_note.md | 40 +++---- docs/pallas/tpu/distributed.ipynb | 18 +-- docs/pallas/tpu/distributed.md | 18 +-- docs/pallas/tpu/pipelining.ipynb | 6 +- docs/pallas/tpu/pipelining.md | 6 +- jax/_src/deprecations.py | 1 + jax/_src/pallas/mosaic/core.py | 24 +++- .../mosaic/interpret/interpret_pallas_call.py | 29 ++--- jax/_src/pallas/mosaic/lowering.py | 26 ++--- .../pallas/mosaic/pallas_call_registration.py | 4 - jax/_src/pallas/mosaic/pipeline.py | 20 ++-- jax/_src/pallas/mosaic/primitives.py | 2 +- .../paged_attention/paged_attention_kernel.py | 12 +- .../ops/tpu/ragged_paged_attention/kernel.py | 2 +- jax/experimental/pallas/tpu.py | 11 +- tests/pallas/tpu_all_gather_test.py | 5 +- tests/pallas/tpu_pallas_async_test.py | 66 +++++------ tests/pallas/tpu_pallas_distributed_test.py | 8 +- .../tpu_pallas_interpret_distributed_test.py | 24 ++-- tests/pallas/tpu_pallas_interpret_test.py | 10 +- tests/pallas/tpu_pallas_memory_space_test.py | 10 +- tests/pallas/tpu_pallas_pipeline_test.py | 104 +++++++++--------- tests/pallas/tpu_pallas_state_test.py | 4 +- tests/pallas/tpu_pallas_test.py | 18 +-- 24 files changed, 244 insertions(+), 224 deletions(-) diff --git a/docs/pallas/design/async_note.md b/docs/pallas/design/async_note.md index b255a91d3ec8..b21725f7a29e 100644 --- a/docs/pallas/design/async_note.md +++ b/docs/pallas/design/async_note.md @@ -18,7 +18,7 @@ def f(x): In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by: -1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future. +1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future. 2. scheduling the `x + 1` between the `ppermute_start` and `ppermute_done`, resulting in the following program: @@ -107,12 +107,12 @@ def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]: ), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ), )(x) return send_sem, recv_sem, out @@ -139,11 +139,11 @@ def ppermute_done(send_sem, recv_sem, out) ->Array: ), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0:0} )(out, send_sem, recv_sem) return out @@ -167,9 +167,9 @@ def f(x): There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level. -1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX. -2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory. -3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness. +1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX. +2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory. +3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness. We will go over these issues one by one and suggest fixes. @@ -292,13 +292,13 @@ def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array] ), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ), input_output_aliases={0:2} )(x) @@ -322,12 +322,12 @@ def ppermute_done(send_sem, recv_sem, x, out) ->Array: ), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={1:0} )(x, out, send_sem, recv_sem) return out @@ -485,7 +485,7 @@ def f(x): def body(i, x): *sems, x, x2 = ppermute_start(x) x2 = ppermute_done((*sems, x, x2)) - + *sems, x2, y = ppermute_start(x2) y = ppermute_done((*sems, x2, y)) return y @@ -574,10 +574,10 @@ our program should now be correct. So we’ve come up with some rules of thumb: -1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value. +1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value. 2. Use `unroll >= 2` when doing `ppermute`s in a loop body. -Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result. +Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result. ```py def f(x): @@ -641,7 +641,7 @@ def f(x): return y_ref[...] ``` -Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved wrt to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO. +Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved wrt to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO. The final key difference is evident when we try our loop examples. @@ -665,8 +665,8 @@ To handle this without the manual unrolling, we’d create a scratch buffer with The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away\! -1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` wrt to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints. -2. Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops. +1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` wrt to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints. +2. Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops. 3. Defensive copies \- Using `Ref`s forces us to handle buffer assignment “manually” and the lowering can ensure the aliasing works out to avoid any copies. Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about buffer lifetimes or which memory spaces they live in for these intermediate values. *Therefore, it is possible XLA can copy array values even if they are actively being copied into by Pallas kernels.* This is easy to verify in HLO but it is a sharp edge of using custom calls to represent asynchronous operations in HLO. diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index 434f610a0a79..ec13e7c7b5f1 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -273,9 +273,9 @@ " num_scalar_prefetch=0,\n", " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " out_specs=pl.BlockSpec(memory_space=pl.ANY),\n", " scratch_shapes=(\n", " # We allocate DMA semaphores in scratch memory.\n", " [pltpu.SemaphoreType.DMA] * 2\n", @@ -421,9 +421,9 @@ " num_scalar_prefetch=0,\n", " in_specs=[\n", " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " out_specs=pl.BlockSpec(memory_space=pl.ANY),\n", " scratch_shapes=(\n", " # DMA semaphores are allocated in scratch memory.\n", " # We allocated one semaphore for a local HBM-VMEM copy,\n", @@ -815,7 +815,7 @@ " # Our output lives in VMEM\n", " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", " # Our double-buffer lives in HBM\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " grid=(num_devices,),\n", " scratch_shapes=(\n", @@ -1150,7 +1150,7 @@ " ],\n", " out_specs=[\n", " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", " scratch_shapes=(\n", @@ -1576,11 +1576,11 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " out_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", + " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", " scratch_shapes=(\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index 678bd98f4470..da6e36bf672d 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -235,9 +235,9 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, # MemorySpace.ANY will (usually) place the tensor in HBM. in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=( # We allocate DMA semaphores in scratch memory. [pltpu.SemaphoreType.DMA] * 2 @@ -357,9 +357,9 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ # MemorySpace.ANY will (usually) place the tensor in HBM. - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=( # DMA semaphores are allocated in scratch memory. # We allocated one semaphore for a local HBM-VMEM copy, @@ -709,7 +709,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec( # Our output lives in VMEM pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), # Our double-buffer lives in HBM - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices,), scratch_shapes=( @@ -1023,7 +1023,7 @@ grid_spec = pltpu.PrefetchScalarGridSpec( ], out_specs=[ pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1410,11 +1410,11 @@ out_shape = ( grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), scratch_shapes=( diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 64f095a4cb89..a78d794140ad 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -123,7 +123,7 @@ "\n", "| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |\n", "| --- | --- | --- |\n", - "| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM |\n", + "| `pl.ANY` | HBM (usually) or VMEM | DRAM |\n", "| `pltpu.MemorySpace.VMEM` | VMEM | SRAM |\n", "| `pltpu.MemorySpace.SMEM` | SMEM | SRAM |\n", "| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM |\n", @@ -164,7 +164,7 @@ "\n", "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n", "out = pl.pallas_call(hbm_vmem_kernel,\n", - " in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],\n", + " in_specs=[pl.BlockSpec(memory_space=pl.ANY)],\n", " out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),\n", " scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)\n", ")(x)\n", @@ -283,7 +283,7 @@ "x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)\n", "slices = jnp.array([[0, 2], [2, 3], [3, 5], [5, 8]], dtype=jnp.int32)\n", "\n", - "hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)\n", + "hbm_block_spec = pl.BlockSpec(memory_space=pl.ANY)\n", "out = pl.pallas_call(dynamic_block_example_kernel,\n", " in_specs=[hbm_block_spec, hbm_block_spec],\n", " out_specs=hbm_block_spec,\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index d91ebdd63f65..38747d17915a 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -95,7 +95,7 @@ Pallas exposes all levels of the TPU memory hierarchy to users. The following ta | Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) | | --- | --- | --- | -| `pltpu.MemorySpace.ANY` | HBM (usually) or VMEM | DRAM | +| `pl.ANY` | HBM (usually) or VMEM | DRAM | | `pltpu.MemorySpace.VMEM` | VMEM | SRAM | | `pltpu.MemorySpace.SMEM` | SMEM | SRAM | | `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM | @@ -129,7 +129,7 @@ def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref): x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32) out = pl.pallas_call(hbm_vmem_kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32), scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),) )(x) @@ -229,7 +229,7 @@ def dynamic_block_example_kernel(x_hbm, slices_hbm, o_hbm, slices_smem): x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32) slices = jnp.array([[0, 2], [2, 3], [3, 5], [5, 8]], dtype=jnp.int32) -hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY) +hbm_block_spec = pl.BlockSpec(memory_space=pl.ANY) out = pl.pallas_call(dynamic_block_example_kernel, in_specs=[hbm_block_spec, hbm_block_spec], out_specs=hbm_block_spec, diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 7cdbcb64791f..b650e0fd5a2f 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -135,3 +135,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('safer-randint-config') register('jax-pmap-no-rank-reduction') register('jax-make-mesh-default-explicit') +register('pltpu-memory-space-any') diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index e05c69f537fc..f9fe91fe14b1 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -16,22 +16,23 @@ from __future__ import annotations import collections +from collections.abc import Mapping from collections.abc import Sequence import dataclasses import enum from typing import Any, ClassVar, Literal -from collections.abc import Mapping import jax -import jax.numpy as jnp -from jax.extend import backend as jex_backend from jax._src import core as jax_core +from jax._src import deprecations from jax._src import linear_util as lu from jax._src import state from jax._src import util from jax._src.frozen_dict import FrozenDict from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core +from jax.extend import backend as jex_backend +import jax.numpy as jnp import numpy as np @@ -174,8 +175,8 @@ def __init__( # Replace is a method, not a field. replace = dataclasses.replace + class MemorySpace(enum.Enum): - ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY. VMEM = "vmem" VMEM_SHARED = "vmem_shared" SMEM = "smem" @@ -194,6 +195,21 @@ def __call__(self, shape: Sequence[int], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types of ShapedArrays. return self.from_type(jax_core.ShapedArray(tuple(shape), dtype)) + def __getattr__(self, name): + if name == "ANY": + # Deprecated on Dec 10, 2025. + deprecations.warn( + "pltpu-memory-space-any", + "pltpu.MemorySpace.ANY is deprecated. Use pl.ANY instead.", + stacklevel=2, + ) + return pallas_core.MemorySpace.ANY + return super().__getattr__(name) # type: ignore + + +# TODO(slebedev): Remove this after +MemorySpace.ANY = pallas_core.MemorySpace.ANY + class dma_semaphore(pallas_core.semaphore_dtype): pass class DMASemaphore(pallas_core.AbstractSemaphoreTy): diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py index 577df33a9f58..9b815860ac47 100644 --- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -455,13 +455,16 @@ def _allocate_semaphores( TPU_MEMORY_SPACE_IDXS: dict[ mosaic_core.MemorySpace | pallas_core.MemorySpace | None, int ] = {v: i for i, v in enumerate(mosaic_core.MemorySpace)} -TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY] = TPU_MEMORY_SPACE_IDXS[ - mosaic_core.MemorySpace.ANY -] TPU_MEMORY_SPACE_NAMES = { i: v.value for i, v in enumerate(mosaic_core.MemorySpace) } +# Inject ANY as the last memory space. +TPU_MEMORY_SPACE_NAMES[len(TPU_MEMORY_SPACE_IDXS)] = ( + pallas_core.MemorySpace.ANY.value +) +TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY] = len(TPU_MEMORY_SPACE_IDXS) + # Default to VMEM when no memory space is specified. TPU_MEMORY_SPACE_IDXS[None] = TPU_MEMORY_SPACE_IDXS[ mosaic_core.MemorySpace.VMEM @@ -1068,8 +1071,7 @@ def _to_jaxpr(flat_fun, in_avals): return new_jaxpr def _is_any(memory_space): - return ((memory_space == mosaic_core.MemorySpace.ANY) or - (memory_space == pallas_core.MemorySpace.ANY)) + return memory_space is pallas_core.MemorySpace.ANY _SENTINEL = jnp.inf @@ -1077,7 +1079,7 @@ def _is_any(memory_space): def _get_memory_space_and_raise_if_hbm(aval, primitive_name, message=None): memory_space = aval.memory_space - if memory_space in [mosaic_core.MemorySpace.HBM, mosaic_core.MemorySpace.ANY]: + if memory_space in [mosaic_core.MemorySpace.HBM, pallas_core.MemorySpace.ANY]: if message is None: message = ( f'{primitive_name}: Buffers with a memory space of HBM or ANY cannot' @@ -1359,10 +1361,10 @@ def f(*args, jaxpr): ) = jax.tree.unflatten(eqn.params['tree'], eqn.invars) src_memory_space = getattr(orig_src_ref.aval, 'memory_space', None) if src_memory_space is None: - src_memory_space = mosaic_core.MemorySpace.ANY + src_memory_space = pallas_core.MemorySpace.ANY dst_memory_space = getattr(orig_dst_ref.aval, 'memory_space', None) if dst_memory_space is None: - dst_memory_space = mosaic_core.MemorySpace.ANY + dst_memory_space = pallas_core.MemorySpace.ANY callback.io_callback( functools.partial(dma_start, source_info=eqn.source_info), (), @@ -1636,7 +1638,6 @@ def _remove_memory_space_abstract_eval(x): if ( x.memory_space is None or x.memory_space is pallas_core.MemorySpace.ANY - or x.memory_space is mosaic_core.MemorySpace.ANY or x.memory_space is mosaic_core.MemorySpace.HBM ): return jax_core.ShapedArray(x.shape, x.dtype) @@ -1771,7 +1772,7 @@ def interpret_pallas_call( jax.ShapeDtypeStruct((), jnp.int16), device_id, None, # local_core_id - TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], input_args[i], ordered=True, ) @@ -1806,7 +1807,7 @@ def interpret_pallas_call( jax.ShapeDtypeStruct((), jnp.int16), device_id, None, # local_core_id - TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], padded_val, ordered=True, ) @@ -2046,7 +2047,7 @@ def _store_slice_to_kernel_input(index, input_var): jax.ShapeDtypeStruct(input_var.aval.shape, input_var.aval.dtype), device_id, core_index, - TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], input_buffer_ids[index], (transform,), cur_block_indices[index], @@ -2115,7 +2116,7 @@ def _store_to_output_buffer(index, output_var, transform): (), device_id, core_index, - TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], output_buffer_ids[index], (transform,), kernel_output_val, @@ -2234,7 +2235,7 @@ def _store_to_output_buffer(index, output_var, transform): val, device_id, 0, # local_core_id - TPU_MEMORY_SPACE_IDXS[mosaic_core.MemorySpace.ANY], + TPU_MEMORY_SPACE_IDXS[pallas_core.MemorySpace.ANY], output_buffer_id, ( indexing.NDIndexer.from_indices_shape( diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index d17b1aeb0c8a..70ae27890779 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -21,7 +21,7 @@ import functools import operator import string -from typing import Any, Protocol, Self, TypeVar, cast +from typing import Any, Literal, Protocol, Self, TypeVar, cast import jax from jax import api_util @@ -34,8 +34,8 @@ from jax._src import custom_derivatives from jax._src import debugging from jax._src import dtypes -from jax._src import literals from jax._src import linear_util as lu +from jax._src import literals from jax._src import mesh as mesh_lib from jax._src import pjit from jax._src import prng @@ -51,7 +51,6 @@ from jax._src.lax import control_flow from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import BranchesPlatforms - from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith @@ -89,6 +88,7 @@ AnyMemorySpace = pallas_core.MemorySpace | TPUMemorySpace VMEM = TPUMemorySpace.VMEM SMEM = TPUMemorySpace.SMEM +ANY = pallas_core.MemorySpace.ANY # Booleans are stored as the following type in memrefs. BOOL_MEMREF_TYPE = np.dtype('int32') @@ -249,10 +249,11 @@ def is_cloud_tpu_older_than(self, year: int, month: int, day: int): return is_cloud_tpu_older_than(year, month, day, backend) -def _memory_space_to_tpu_memory_space(memory_space: AnyMemorySpace | None - ) -> TPUMemorySpace: +def _memory_space_to_tpu_memory_space( + memory_space: AnyMemorySpace | None, +) -> TPUMemorySpace | Literal[ANY]: if memory_space == jax_core.MemorySpace.Device: - return TPUMemorySpace.ANY + return ANY match memory_space: case None: @@ -261,7 +262,7 @@ def _memory_space_to_tpu_memory_space(memory_space: AnyMemorySpace | None return TPUMemorySpace.VMEM case pallas_core.MemorySpace.ANY: # Map the general ANY memory space to TPU ANY memory space - return TPUMemorySpace.ANY + return ANY case pallas_core.MemorySpace.HOST: return TPUMemorySpace.HOST case ( @@ -415,9 +416,6 @@ def _get_arg_type( memory_space = None if isinstance(aval, state.AbstractRef): memory_space = _memory_space_to_tpu_memory_space(aval.memory_space) - # We assume unannotated memory refs are in VMEM - if memory_space is None: - memory_space = TPUMemorySpace.VMEM return aval_to_ir_type( dynamic_shape_replacement_fn, aval, shape=shape, memory_space=memory_space ) @@ -663,10 +661,8 @@ def err_details(): "rank >= 1. " + err_details()) if ( - (memory_space == tpu_core.MemorySpace.ANY - or memory_space == tpu_core.MemorySpace.HBM) - and not bm.has_trivial_window() - ): + memory_space is ANY or memory_space == tpu_core.MemorySpace.HBM + ) and not bm.has_trivial_window(): raise ValueError( "The Pallas TPU lowering currently supports in memory space ANY " "only blocks having the same block shape as the array shape " @@ -804,7 +800,7 @@ def dynamic_shape_replacement_fn( tpu_memory_space = _memory_space_to_tpu_memory_space( bm.block_aval.memory_space) if ( - tpu_memory_space == tpu_core.MemorySpace.ANY + tpu_memory_space is ANY or tpu_memory_space == tpu_core.MemorySpace.HBM or tpu_memory_space == tpu_core.MemorySpace.SEMAPHORE ): diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 8cebd5965547..dd8bb00dcc78 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -73,10 +73,6 @@ def _get_memory_space_from_aval( # If we are passed an aval with an explicit memory space tag, we use it # to constrain the memory space. match out_aval.memory_space: - case None: - return None - case tpu_core.MemorySpace.ANY: - return None case tpu_core.MemorySpace.HBM: return tpu_custom_call.MemorySpace.HBM case tpu_core.MemorySpace.VMEM: diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index e519a323ca20..11e67f747be9 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -20,7 +20,7 @@ import dataclasses import enum import functools -from typing import Any, Union +from typing import Any, Literal, Union import jax from jax import core as jax_core @@ -40,7 +40,7 @@ SMEM = tpu_core.MemorySpace.SMEM VMEM = tpu_core.MemorySpace.VMEM -ANY = tpu_core.MemorySpace.ANY +ANY = pallas_core.MemorySpace.ANY REF = pallas_core.MemoryRef GridDimensionSemantics = tpu_core.GridDimensionSemantics PARALLEL = tpu_core.PARALLEL @@ -541,11 +541,17 @@ def buffer_types() -> type[BufferType]: return BufferType @classmethod - def create(cls, spec: pl.BlockSpec, dtype_or_type, buffer_type, buffer_count, - needs_swap_ref=True, - grid_rank=None, - use_lookahead=False, - source_memory_space: tpu_core.MemorySpace = ANY) -> BufferedRef: + def create( + cls, + spec: pl.BlockSpec, + dtype_or_type, + buffer_type, + buffer_count, + needs_swap_ref=True, + grid_rank=None, + use_lookahead=False, + source_memory_space: tpu_core.MemorySpace | Literal[ANY] = ANY, # type: ignore[valid-type] + ) -> BufferedRef: """Create a BufferedRef. Args: diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 5be4028441db..1b94e5e75512 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -1154,7 +1154,7 @@ def with_memory_space_constraint( Returns: The array ``x`` with the memory space constraint. """ - if memory_space in {tpu_core.MemorySpace.ANY, pl_core.MemorySpace.ANY}: + if memory_space is pl_core.MemorySpace.ANY: return x if memory_space not in { tpu_core.MemorySpace.HBM, diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 5e3dcb271a4b..b35d6d6dc8d5 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -547,10 +547,10 @@ def paged_attention( if k_scales_pages is not None and v_scales_pages is not None: in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ] scratch_shapes = ( pltpu.VMEM( @@ -595,9 +595,9 @@ def paged_attention( else: in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), None, # type: ignore[list-item] - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), None, # type: ignore[list-item] ] scratch_shapes = ( diff --git a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py index eea2f60c26f1..5ddeab270657 100644 --- a/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py +++ b/jax/experimental/pallas/ops/tpu/ragged_paged_attention/kernel.py @@ -831,7 +831,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): ) in_specs = [ q_block_spec, - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ] out_specs = q_block_spec lm_scratch = pltpu.VMEM( diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 0fbece6f5e42..43be2d80840b 100644 --- a/jax/experimental/pallas/tpu.py +++ b/jax/experimental/pallas/tpu.py @@ -89,10 +89,6 @@ HBM = MemorySpace.HBM HOST = MemorySpace.HOST SEMAPHORE = MemorySpace.SEMAPHORE -# Expose ANY for backward compatibility. -ANY = GeneralMemorySpace.ANY -del GeneralMemorySpace - _deprecations = { # Added Oct 31, 2025 @@ -100,13 +96,20 @@ "pltpu.delay is deprecated, use pl.delay instead.", pl_primitives.delay ), + # Added Dec 10, 2025 + "ANY": ( + "pltpu.ANY is deprecated, use pl.ANY instead.", + GeneralMemorySpace.ANY + ), } if typing.TYPE_CHECKING: delay = pl_primitives.delay + ANY = GeneralMemorySpace.ANY else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del typing del pl_primitives +del GeneralMemorySpace diff --git a/tests/pallas/tpu_all_gather_test.py b/tests/pallas/tpu_all_gather_test.py index 47168e1c35b4..72477e7230ce 100644 --- a/tests/pallas/tpu_all_gather_test.py +++ b/tests/pallas/tpu_all_gather_test.py @@ -20,6 +20,7 @@ from jax import random from jax._src import test_util as jtu from jax.experimental import mesh_utils +from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu from jax.experimental.pallas.ops.tpu import all_gather import jax.numpy as jnp @@ -91,7 +92,7 @@ def setUp(self): def test_all_gather_1d_mesh(self, is_vmem, shape, dtype): if jax.device_count() < 2: self.skipTest("Need more devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY + memory_space = pltpu.VMEM if is_vmem else pl.ANY mesh_shape = (jax.device_count(),) mesh = jax.sharding.Mesh( mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x"] @@ -112,7 +113,7 @@ def test_all_gather_2d_mesh(self, is_vmem, shape, dtype, self.skipTest("Need more devices") if jax.device_count() % 2: self.skipTest("Need an even number of devices") - memory_space = pltpu.VMEM if is_vmem else pltpu.ANY + memory_space = pltpu.VMEM if is_vmem else pl.ANY mesh_shape = (2, jax.device_count() // 2) mesh = jax.sharding.Mesh( mesh_utils.create_device_mesh(mesh_shape, jax.devices()), ["x", "y"] diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py index 3227aa218128..e9f2cd45e5ad 100644 --- a/tests/pallas/tpu_pallas_async_test.py +++ b/tests/pallas/tpu_pallas_async_test.py @@ -37,7 +37,7 @@ def make_async_copy(target_memory_space=None): if target_memory_space is None: - target_memory_space = pltpu.ANY + target_memory_space = pl.ANY @jax.named_call def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: @@ -53,10 +53,10 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, sem): pltpu.SemaphoreType.DMA(()), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=target_memory_space), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), @@ -76,7 +76,7 @@ def copy_done_kernel(x_ref, o_ref, sem, aliased_o_ref): copy_done_kernel, out_shape=target_memory_space(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=target_memory_space), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], @@ -109,11 +109,11 @@ def async_slice_start(x: jax.Array) -> tuple[jax.Array, Future]: pltpu.SemaphoreType.DMA(()), ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), input_output_aliases={0: 0}, @@ -129,11 +129,11 @@ def async_slice_done( async_slice_done_kernel, out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)), + out_specs=(pl.BlockSpec(memory_space=pl.ANY)), input_output_aliases={1: 0}, )(x, out, sem) return out @@ -164,11 +164,11 @@ def async_dslice_start(x: jax.Array) -> tuple[jax.Array, Future]: grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ), ), @@ -185,11 +185,11 @@ def async_dslice_done( async_dslice_done_kernel, out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)), + out_specs=(pl.BlockSpec(memory_space=pl.ANY)), input_output_aliases={1: 0}, )(x, out, sem) return out @@ -441,8 +441,8 @@ def _(): copy = pl.pallas_call( copy_kernel, out_shape=jax.ShapeDtypeStruct((xlocal, ylocal), jnp.float32), - in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + in_specs=[pl.BlockSpec(memory_space=pl.ANY),], + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA] * 3, ) @@ -544,7 +544,7 @@ def run_core_kernel(input): def make_async_remote_copy(axis_name: str, direction: str = 'right', target_memory_space=None): if target_memory_space is None: - target_memory_space = pltpu.ANY + target_memory_space = pl.ANY @jax.named_call def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: @@ -579,10 +579,10 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): pltpu.SemaphoreType.DMA(()), # recv_sem ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=target_memory_space), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), @@ -606,10 +606,10 @@ def send_done_kernel(x_ref, send_sem, aliased_o_ref): send_done_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0: 0}, )(x, send_sem) return x @@ -626,7 +626,7 @@ def send_done_kernel(x_ref, o_ref, send_sem, aliased_o_ref): send_done_kernel, out_shape=target_memory_space(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=target_memory_space), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], @@ -675,16 +675,16 @@ def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): copy_start_kernel, out_shape=( jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x - pltpu.ANY(x.shape, x.dtype), # out + pl.ANY(x.shape, x.dtype), # out (pltpu.SemaphoreType.DMA(()),) * 2, # left_sems (pltpu.SemaphoreType.DMA(()),) * 2, # right_sems ), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=( - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, ), @@ -716,11 +716,11 @@ def send_done_kernel(x_ref, send_left_sem, send_right_sem, aliased_o_ref): send_done_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0: 0}, )(x, send_left_sem, send_right_sem) return x @@ -747,12 +747,12 @@ def recv_done_kernel(o_ref, x_ref, recv_left_sem, recv_right_sem, recv_done_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), pl.BlockSpec(memory_space=pltpu.SEMAPHORE), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0: 0}, )(out, x, recv_left_sem, recv_right_sem) return out diff --git a/tests/pallas/tpu_pallas_distributed_test.py b/tests/pallas/tpu_pallas_distributed_test.py index 029ef28de555..670c2d7a1764 100644 --- a/tests/pallas/tpu_pallas_distributed_test.py +++ b/tests/pallas/tpu_pallas_distributed_test.py @@ -46,7 +46,7 @@ def setUp(self): @parameterized.named_parameters( ('vmem', pltpu.VMEM), - ('hbm', pltpu.ANY), + ('hbm', pl.ANY), ) def test_basic_remote_vmem_dma(self, mem): # Implements very simple collective permute @@ -119,8 +119,8 @@ def body(ready_sem, send_sem, recv_sem): def body(x): return pl.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) @@ -513,7 +513,7 @@ def test_kernel(x_ref, grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], scratch_shapes=( [pltpu.SemaphoreType.DMA] * 2 diff --git a/tests/pallas/tpu_pallas_interpret_distributed_test.py b/tests/pallas/tpu_pallas_interpret_distributed_test.py index c44d56230762..d7bccfd6475d 100644 --- a/tests/pallas/tpu_pallas_interpret_distributed_test.py +++ b/tests/pallas/tpu_pallas_interpret_distributed_test.py @@ -96,9 +96,9 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem): num_scalar_prefetch=0, # MemorySpace.ANY will (usually) place the tensor in HBM. in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=( # We allocate DMA semaphores in scratch memory. [pltpu.SemaphoreType.DMA] * 2 @@ -209,9 +209,9 @@ def _(): num_scalar_prefetch=0, in_specs=[ # MemorySpace.ANY will (usually) place the tensor in HBM. - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=( # DMA semaphores are allocated in scratch memory. # We allocated one semaphore for a local HBM-VMEM copy, @@ -376,7 +376,7 @@ def _(): # Our output lives in VMEM pl.BlockSpec(memory_space=pltpu.VMEM), # Our double-buffer lives in HBM - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices,), scratch_shapes=( @@ -656,7 +656,7 @@ def _(): ], out_specs=[ pl.BlockSpec(memory_space=pltpu.VMEM), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -745,7 +745,7 @@ def test_reduce_scatter_sum_with_emit_pipeline_example( inner_block_spec = pl.BlockSpec( index_map=lambda i, j: (i, j), block_shape=inner_block_size, - memory_space=pltpu.ANY, + memory_space=pl.ANY, ) LEFT = 0 @@ -957,11 +957,11 @@ def _(): grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), scratch_shapes=( @@ -1062,9 +1062,9 @@ def run(src_dst_ids): out_shape=jax.ShapeDtypeStruct((8, 128), input_arr.dtype), in_specs=[ pl.BlockSpec(memory_space=pltpu.SMEM), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA], interpret=pltpu.InterpretParams( dma_execution_mode='eager', diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index b6826677d874..be1fe689c40d 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -123,7 +123,7 @@ def run(): jax.ShapeDtypeStruct((16, 256), jnp.float32), ], grid=(4, 4), - in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_specs=[ pl.BlockSpec((4, 128), lambda i, j: (i, j // 2)), pl.BlockSpec((4, 128), lambda i, j: (j // 2, i % 2)), @@ -253,7 +253,7 @@ def run(): kernel, out_shape=jax.ShapeDtypeStruct((8, 128,), jnp.float32), out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), - in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], scratch_shapes=[pltpu.SemaphoreType.DMA], interpret=pltpu.InterpretParams( out_of_bounds_reads=out_of_bounds_reads), @@ -398,7 +398,7 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): y = pl.pallas_call( kernel_without_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, @@ -413,7 +413,7 @@ def kernel_with_race(x_ref, o_ref, t_ref, sem): pl.pallas_call( kernel_with_race, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), pltpu.SemaphoreType.DMA, @@ -1191,7 +1191,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): ], ) - @parameterized.parameters(pltpu.MemorySpace.HBM, pltpu.MemorySpace.ANY) + @parameterized.parameters(pltpu.MemorySpace.HBM, pl.ANY) def test_referencing_hbm_raises(self, disallowed_memory_space): def jax_load_and_store(in_ref, o_ref): o_ref[...] = in_ref[...] diff --git a/tests/pallas/tpu_pallas_memory_space_test.py b/tests/pallas/tpu_pallas_memory_space_test.py index ec9195a7f3f1..2e9b7a1c1d9f 100644 --- a/tests/pallas/tpu_pallas_memory_space_test.py +++ b/tests/pallas/tpu_pallas_memory_space_test.py @@ -41,7 +41,7 @@ def setUp(self): (pltpu.VMEM, 1), (pltpu.SMEM, 4), (pltpu.HBM, 0), - (pltpu.ANY, None), + (pl.ANY, None), ) def test_basic_input_memory_space_constraint(self, memory_space, color): def kernel(x_ref, y_ref): @@ -52,7 +52,7 @@ def g(x): kernel, out_shape=x, in_specs=[pl.BlockSpec(memory_space=memory_space)], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), )(x) @jax.jit @@ -79,7 +79,7 @@ def f(x): (pltpu.VMEM, 1), (pltpu.SMEM, 4), (pltpu.HBM, 0), - (pltpu.ANY, None), + (pl.ANY, None), (pltpu.HOST, 5), ) def test_basic_output_memory_space_constraint(self, memory_space, color): @@ -94,7 +94,7 @@ def g(x): return pl.pallas_call( kernel, out_shape=out_shape_ctor(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_specs=pl.BlockSpec(memory_space=memory_space), )(x) @@ -138,7 +138,7 @@ def setUp(self): (pltpu.VMEM, 1), (pltpu.SMEM, 4), (pltpu.HBM, 0), - (pltpu.ANY, None), + (pl.ANY, None), ) def test_basic_ref_memory_space_constraint(self, memory_space, color): @jax.jit diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 257b1a474cb6..df3e56e0ab97 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -144,7 +144,7 @@ def body(o_ref): out = pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 512), jnp.int32), - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), )() np.testing.assert_allclose(out, jnp.full_like(out, 42)) @@ -179,10 +179,10 @@ def matmul_kernel(x_ref, y_ref, z_ref): matmul_kernel, out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), ) jax.block_until_ready(z(x, y)) @@ -195,7 +195,7 @@ def matmul_kernel(x_ref, y_ref, z_ref): @parameterized.named_parameters( ('vmem', pltpu.VMEM), - ('hbm', pltpu.ANY), + ('hbm', pl.ANY), ) def test_double_pipeline_matmul(self, memory_space): # TODO(b/358121809): Re-enable this test once the bug is fixed. @@ -273,9 +273,9 @@ def inner_kernel(x_ref, o_ref): copy_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), ) result = fn(x) np.testing.assert_allclose(result, x) @@ -324,10 +324,10 @@ def _(): matmul_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), ) result = fn(x, y) np.testing.assert_allclose(result, x @ y, atol=5e-5) @@ -367,10 +367,10 @@ def _(): grid=(2,), out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), compiler_params=pltpu.CompilerParams( dimension_semantics=(pltpu.PARALLEL,) ), @@ -421,10 +421,10 @@ def inner_kernel(x_ref, o_ref): copy_kernel, out_shape=jax.ShapeDtypeStruct((len(in_block_indices) * 128, 128), jnp.int32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), ) result = fn(x, jnp.array(in_block_indices)) @@ -474,10 +474,10 @@ def inner_kernel(x_ref, o_ref): copy_kernel, out_shape=jax.ShapeDtypeStruct((1024, 128), jnp.int32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), ) result = fn(x, jnp.array(out_block_indices)) @@ -535,9 +535,9 @@ def _(allocations): copy_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), ) result = fn(x) np.testing.assert_allclose(result, x) @@ -613,10 +613,10 @@ def _(allocations): copy_kernel, out_shape=jax.ShapeDtypeStruct((blk_len * 2 * 128, 128), jnp.int32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes = [pltpu.SMEM((1,), dtype=jnp.int32)] ) result = jax.block_until_ready(fn(x, jnp.array(in_block_indices))) @@ -685,9 +685,9 @@ def _(allocations): copy_kernel, out_shape=jax.ShapeDtypeStruct((128, 128), jnp.int32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), ) result = fn(x) expected = 0 @@ -730,10 +730,10 @@ def _(): matmul_kernel, out_shape=jax.ShapeDtypeStruct((M, N), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), - pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=inner_allocs, ) result = fn(x, y) @@ -756,10 +756,10 @@ def setUp(self): @parameterized.named_parameters( ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.ANY, jnp.float32, 2, 2, 2), - ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 1, 2), - ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 1, 1), + ('hbm', pl.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pl.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_112', pl.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pl.ANY, jnp.float32, 1, 1, 1), ) def test_pipeline_latency_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1046,10 +1046,10 @@ def reference(x, y): @parameterized.named_parameters( ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.ANY, jnp.float32, 2, 2, 2), - ('hbm_float32_122', pltpu.ANY, jnp.float32, 1, 2, 2), - ('hbm_float32_121', pltpu.ANY, jnp.float32, 1, 2, 1), + ('hbm', pl.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pl.ANY, jnp.float32, 2, 2, 2), + ('hbm_float32_122', pl.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_121', pl.ANY, jnp.float32, 1, 2, 1), ) def test_pipeline_throughput_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1289,10 +1289,10 @@ def reference(x, y): @parameterized.named_parameters( ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.ANY, jnp.float32, 2, 4, 2), - ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 1, 2), - ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 1, 1), + ('hbm', pl.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pl.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pl.ANY, jnp.float32, 1, 1, 2), + ('hbm_float32_111', pl.ANY, jnp.float32, 1, 1, 1), ) def test_pipeline_latency_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1576,10 +1576,10 @@ def reference(x, y): @parameterized.named_parameters( ('vmem', pltpu.VMEM, jnp.bfloat16, 2, 2, 2), - ('hbm', pltpu.ANY, jnp.bfloat16, 2, 2, 2), - ('hbm_float32', pltpu.ANY, jnp.float32, 2, 4, 2), - ('hbm_float32_112', pltpu.ANY, jnp.float32, 1, 2, 2), - ('hbm_float32_111', pltpu.ANY, jnp.float32, 1, 2, 1), + ('hbm', pl.ANY, jnp.bfloat16, 2, 2, 2), + ('hbm_float32', pl.ANY, jnp.float32, 2, 4, 2), + ('hbm_float32_112', pl.ANY, jnp.float32, 1, 2, 2), + ('hbm_float32_111', pl.ANY, jnp.float32, 1, 2, 1), ) def test_pipeline_throughput_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): @@ -1868,9 +1868,9 @@ def mul_kernel(iters_ref, x_ref, y_ref): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), grid=(num_cores,), ), compiler_params=pltpu.CompilerParams( @@ -1905,9 +1905,9 @@ def matmul_kernel(x_ref, y_ref): matmul_kernel, out_shape=jax.ShapeDtypeStruct((512, 512), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), grid=(num_cores,), compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',) @@ -1955,10 +1955,10 @@ def matmul_kernel(x_ref, y_ref, z_ref, *, bm, bk, bn): functools.partial(matmul_kernel, bm=bm, bk=bk, bn=bn), out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), grid=(num_cores,), compiler_params=pltpu.CompilerParams( dimension_semantics=('parallel',) @@ -2004,10 +2004,10 @@ def run(acc_scratch_ref): kernel, out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), grid=(num_cores,), )(x, y) diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index f0707280ff55..013b2e31fc7c 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -119,8 +119,8 @@ def f_stateful(refs): x = pl.pallas_call( functools.partial(copy_kernel, x_ref, y_ref), - in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA], out_shape=jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype), input_output_aliases={0: 0}, diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 36a21127a798..1aa8fc290a9d 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1753,7 +1753,7 @@ def run(array, data, index, size): kernel, out_shape=array, in_specs=[ - pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.VMEM), pl.BlockSpec(memory_space=pltpu.SMEM), pl.BlockSpec(memory_space=pltpu.SMEM), @@ -1761,7 +1761,7 @@ def run(array, data, index, size): scratch_shapes=[ pltpu.SemaphoreType.DMA, ], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), input_output_aliases={0: 0}, )(array, data, index, size) @@ -1960,7 +1960,7 @@ def run(src): return pl.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(src.shape, jnp.float32), - in_specs=[pl.BlockSpec(memory_space=pltpu.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], scratch_shapes=[pltpu.SemaphoreType.DMA], out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), )(src) @@ -2954,9 +2954,9 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA] ), out_shape=o, @@ -2982,9 +2982,9 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM), - pl.BlockSpec(memory_space=pltpu.ANY), - pl.BlockSpec(memory_space=pltpu.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pl.ANY), + pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA] ), out_shape=o, From c593739e7aaae70dfbb2d931c73dafb0c07d594f Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Wed, 17 Dec 2025 12:31:58 -0800 Subject: [PATCH 251/315] Update `rules_ml_toolchain` version to remove redundant `fake_nvshmem_bootstrap_uid` library from hermetic CUDA deps. PiperOrigin-RevId: 845873778 --- WORKSPACE | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 67474c08749d..6b3d0e2aa010 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -17,10 +17,10 @@ xla_workspace3() # Details: https://github.com/google-ml-infra/rules_ml_toolchain tf_http_archive( name = "rules_ml_toolchain", - sha256 = "e9842de3fefb5a120d3b1647d3a09e6e7071e8df8d1cd2dfe6f66ee31fd2595e", - strip_prefix = "rules_ml_toolchain-cb79a8fc8dcf3f75743dcd9b3418a70c884a7269", + sha256 = "53905ede50e3eebc782266e20e9b9ac1d7166ef68b877bea593d3600dcfe03e6", + strip_prefix = "rules_ml_toolchain-a1ff84835e407b41eef5fd1a865a23748c294db6", urls = tf_mirror_urls( - "https://github.com/google-ml-infra/rules_ml_toolchain/archive/cb79a8fc8dcf3f75743dcd9b3418a70c884a7269.tar.gz", + "https://github.com/google-ml-infra/rules_ml_toolchain/archive/a1ff84835e407b41eef5fd1a865a23748c294db6.tar.gz", ), ) From 019335652dc52e69a912b7d8fcbb7b965c8e6c58 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 17 Dec 2025 21:31:35 +0000 Subject: [PATCH 252/315] remove batching.primitive_batchers dead weight at this point Co-authored-by: Yash Katariya --- jax/_src/interpreters/batching.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index abdcddfd41f5..2f729c5774ed 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -255,14 +255,12 @@ def process_primitive(self, p, tracers, params): and p in skippable_batchers and not any(self.axis_data.name == axis_name for axis_name in skippable_batchers[p](params))): - # no-op shortcut return p.bind_with_trace(self.parent_trace, vals_in, params) else: with core.set_current_trace(self.parent_trace): val_out, dim_out = fancy_primitive_batchers[p]( self.axis_data, vals_in, dims_in, **params) elif args_not_mapped: - # no-op shortcut return p.bind_with_trace(self.parent_trace, vals_in, params) elif p in primitive_batchers: with core.set_current_trace(self.parent_trace): @@ -657,8 +655,6 @@ def _matchaxis_symzeros(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): ..., tuple[Any, Union[int, None, tuple[Union[int, None], ...]]] ] -primitive_batchers : dict[core.Primitive, BatchingRule] = {} -# "fancy" primitive batchers just take a extra leading `AxisData` and "trace type" args fancy_primitive_batchers: dict[core.Primitive, Callable] = {} # backwards compat shim. TODO: delete @@ -667,9 +663,21 @@ def __setitem__(self, prim, batcher): def wrapped(axis_data, vals, dims, **params): return batcher(axis_data.size, axis_data.name, None, vals, dims, **params) fancy_primitive_batchers[prim] = wrapped - axis_primitive_batchers = AxisPrimitiveBatchersProxy() +class PrimitiveBatchersProxy: + def __setitem__(self, prim, batcher): + def wrapped(axis_data, vals, dims, **params): + if all(d is None for d in dims): + o = prim.bind(*vals, **params) + return (o, [None] * len(o)) if prim.multiple_results else (o, None) + return batcher(vals, dims, **params) + fancy_primitive_batchers[prim] = wrapped + + def __delitem__(self, prim): + del fancy_primitive_batchers[prim] +primitive_batchers = PrimitiveBatchersProxy() + # Presence in this table allows fancy batchers to be skipped by batch traces for # irrelevant axes. The Callable takes the params and returns a list of relevant From fdef5952e8dc9349ac95b4a9dd03b7f505c42283 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Wed, 17 Dec 2025 13:44:21 -0800 Subject: [PATCH 253/315] Add platform name to xla::ifrt::Device PiperOrigin-RevId: 845901966 --- jaxlib/py_array.cc | 17 +++++++++++++---- jaxlib/py_device.cc | 12 +++++++++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index c8984d7581a8..0ca463058bf2 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -47,6 +47,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" @@ -2304,13 +2305,21 @@ absl::Status PyArray::Register(nb::module_& m) { nb::is_method()); type.attr("platform") = nb::cpp_function( [](PyArray self) { - if (self.ifrt_array()->client()->platform_name() == "cuda" || - self.ifrt_array()->client()->platform_name() == "rocm") { +#if JAX_IFRT_VERSION_NUMBER >= 44 + const xla::ifrt::DeviceListRef& devices = + self.ifrt_array()->sharding().devices(); + absl::string_view platform_name = + devices->devices().front()->PlatformName(); +#else + absl::string_view platform_name = + self.ifrt_array()->client()->platform_name(); +#endif + if (platform_name == "cuda" || platform_name == "rocm") { return std::string_view("gpu"); } else { - return self.ifrt_array()->client()->platform_name(); + return platform_name; } - }, + }, nb::is_method()); type.attr("is_ready") = nb::cpp_function( [](PyArray self) { return xla::ValueOrThrow(self.IsReady()); }, diff --git a/jaxlib/py_device.cc b/jaxlib/py_device.cc index 8d4bcb216dfb..1410112e277c 100644 --- a/jaxlib/py_device.cc +++ b/jaxlib/py_device.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -45,6 +46,7 @@ limitations under the License. #include "xla/python/nb_helpers.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/version.h" #include "xla/tsl/framework/allocator.h" #include "xla/tsl/platform/statusor.h" #include "xla/util.h" @@ -68,11 +70,15 @@ std::string_view PyDevice::platform() const { // but we haven't yet updated JAX clients that // expect "gpu". Migrate users and remove this // code. - if (client_->platform_name() == "cuda" || - client_->platform_name() == "rocm") { +#if JAX_IFRT_VERSION_NUMBER >= 44 + absl::string_view platform_name = device_->PlatformName(); +#else + absl::string_view platform_name = client_->platform_name(); +#endif + if (platform_name == "cuda" || platform_name == "rocm") { return std::string_view("gpu"); } else { - return client_->platform_name(); + return platform_name; } } From b16514bd607f4d810c75d7b4dd71b1ca52e1c1ef Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Wed, 17 Dec 2025 21:45:58 +0000 Subject: [PATCH 254/315] dont use primitive_batchers in batching.py --- jax/_src/interpreters/batching.py | 39 +++++++++++++++---------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 2f729c5774ed..69f26c14eeac 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -262,9 +262,6 @@ def process_primitive(self, p, tracers, params): self.axis_data, vals_in, dims_in, **params) elif args_not_mapped: return p.bind_with_trace(self.parent_trace, vals_in, params) - elif p in primitive_batchers: - with core.set_current_trace(self.parent_trace): - val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params) else: raise NotImplementedError(f"Batching rule for '{p}' not implemented") src = source_info_util.current() @@ -665,9 +662,11 @@ def wrapped(axis_data, vals, dims, **params): fancy_primitive_batchers[prim] = wrapped axis_primitive_batchers = AxisPrimitiveBatchersProxy() +# backwards compat shim. TODO: delete class PrimitiveBatchersProxy: def __setitem__(self, prim, batcher): def wrapped(axis_data, vals, dims, **params): + del axis_data if all(d is None for d in dims): o = prim.bind(*vals, **params) return (o, [None] * len(o)) if prim.multiple_results else (o, None) @@ -680,31 +679,28 @@ def __delitem__(self, prim): # Presence in this table allows fancy batchers to be skipped by batch traces for -# irrelevant axes. The Callable takes the params and returns a list of relevant -# axes. +# irrelevant axes. The Callable takes params and returns a list of relevant axes +# TODO(yashkatariya): remove this skippable_batchers : dict[core.Primitive, Callable] = {} def defvectorized(prim): - primitive_batchers[prim] = partial(vectorized_batcher, prim) + fancy_primitive_batchers[prim] = partial(vectorized_batcher, prim) -def vectorized_batcher(prim, batched_args, batch_dims, **params): +def vectorized_batcher(prim, axis_data, batched_args, batch_dims, **params): + assert not prim.multiple_results + if all(d is None for d in batch_dims): + return prim.bind(*batched_args, **params), None assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims return prim.bind(*batched_args, **params), batch_dims[0] def defbroadcasting(prim): - primitive_batchers[prim] = partial(broadcast_batcher, prim) - -def broadcast_batcher(prim, args, dims, **params): - """Process a primitive with built-in broadcasting. + fancy_primitive_batchers[prim] = partial(broadcast_batcher, prim) - Args: - args: the possibly-batched arguments - dims: list or tuple of the same length as `args`, where each - entry indicates the batching state of the corresponding entry to `args`: - either an int indicating the batch dimension, or else `not_mapped` - indicating no batching. - """ +def broadcast_batcher(prim, axis_data, args, dims, **params): assert len(args) > 1 + if all(d is None for d in dims): + o = prim.bind(*args, **params) + return (o, [None] * len(o)) if prim.multiple_results else (o, None) shape, dim = next((x.shape, d) for x, d in zip(args, dims) if d is not not_mapped) if all(core.definitely_equal_shape(shape, x.shape) and d == dim @@ -733,9 +729,12 @@ def _handle_scalar_broadcasting(nd, x, d): return lax.expand_dims(x, tuple(range(np.ndim(x), nd))) def defreducer(prim, ident): - primitive_batchers[prim] = partial(reducer_batcher, prim, ident) + fancy_primitive_batchers[prim] = partial(reducer_batcher, prim, ident) -def reducer_batcher(prim, ident, batched_args, batch_dims, axes, **params): +def reducer_batcher(prim, ident, axis_data, batched_args, batch_dims, axes, + **params): + if all(d is None for d in batch_dims): + return prim.bind(*batched_args, axes=axes, **params), None def out_axis(axes, axis): return int(list(np.delete(np.arange(operand.ndim), axes)).index(axis)) operand, = batched_args From 3c4dd57f0f8bcb6913905a9b549d58722e80a598 Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Wed, 17 Dec 2025 15:40:25 -0800 Subject: [PATCH 255/315] [JAX] Track the layout defaultness more precisely for arrays created using `from_dlpack` This change checks if a JAX array was created using an input buffer copy in `from_dlpack`. The copying would make the newly created `PjRtBuffer` to use a default layout. This defaultness would be tracked correctly. This change is expected to be a no-op today (and does not cause a problem before this change, either). Previously, the array with the new `PjRtBuffer` with copying would take the concrete default layout as if it were a custom layout, and now the same concrete layout will be available through `array.layout`. Internally in the runtime, it would now know that the defaultness, and knowing this aspect will be required in subsequent IFRT layout support changes. PiperOrigin-RevId: 845946765 --- jaxlib/dlpack.cc | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/jaxlib/dlpack.cc b/jaxlib/dlpack.cc index 03388de494b4..14ccb51a85c1 100644 --- a/jaxlib/dlpack.cc +++ b/jaxlib/dlpack.cc @@ -178,11 +178,14 @@ absl::StatusOr> GetByteStrides(const DLTensor& dl_tensor) { return strides; } -absl::StatusOr> MakePjrtBuffer( - xla::PjRtDevice& device, ::DLManagedTensor* dlmt, const xla::Shape& shape, - xla::PrimitiveType element_type, absl::Span dimensions, - std::optional copy = std::nullopt, - std::optional stream = std::nullopt) { +// Makes a PjRtBuffer from a DLPack tensor. Returns a pair where the second +// element is true if a copy actually happened. +absl::StatusOr, bool>> +MakePjrtBuffer(xla::PjRtDevice& device, ::DLManagedTensor* dlmt, + const xla::Shape& shape, xla::PrimitiveType element_type, + absl::Span dimensions, + std::optional copy = std::nullopt, + std::optional stream = std::nullopt) { std::function on_delete_callback; if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; @@ -204,7 +207,8 @@ absl::StatusOr> MakePjrtBuffer( stream); if (!(result.status().code() == absl::StatusCode::kInvalidArgument && fallback_to_copy)) { - return result; + TF_RETURN_IF_ERROR(result.status()); + return std::make_pair(*std::move(result), false); } } @@ -217,10 +221,13 @@ absl::StatusOr> MakePjrtBuffer( TF_ASSIGN_OR_RETURN(auto* memory_space, device.default_memory_space()); // Create a copy. - return device.client()->BufferFromHostBuffer( - data, element_type, dimensions, byte_strides, - xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy, - on_delete_callback, memory_space, /*device_layout=*/nullptr); + TF_ASSIGN_OR_RETURN( + auto buffer, + device.client()->BufferFromHostBuffer( + data, element_type, dimensions, byte_strides, + xla::PjRtClient::HostBufferSemantics::kMutableZeroCopy, + on_delete_callback, memory_space, /*device_layout=*/nullptr)); + return std::make_pair(std::move(buffer), true); } } // namespace @@ -365,9 +372,13 @@ absl::StatusOr DLPackManagedTensorToBuffer( xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout( element_type, dimensions, minor_to_major); - TF_ASSIGN_OR_RETURN(auto pjrt_buffer, + TF_ASSIGN_OR_RETURN(auto pjrt_buffer_and_copied, MakePjrtBuffer(*device->pjrt_device(), dlmt, shape, element_type, dimensions, copy, stream)); + if (pjrt_buffer_and_copied.second) { + // A PjRtBuffer uses a default layout if it has been created using copy. + has_custom_layout = false; + } // We have taken ownership of the array inside the capsule; make sure the // capsule it cannot be used again. @@ -383,7 +394,8 @@ absl::StatusOr DLPackManagedTensorToBuffer( PyUserContextScope user_context_scope; TF_ASSIGN_OR_RETURN( auto ifrt_array, - ifrt_client->CreatePjRtArray(std::move(pjrt_buffer), has_custom_layout)); + ifrt_client->CreatePjRtArray(std::move(pjrt_buffer_and_copied.first), + has_custom_layout)); return PyArray::MakeFromSingleDeviceArray(std::move(client), std::move(ifrt_array), false, true); } From 4751be5c101b8c0812d29023a67aca7c241d48a7 Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Wed, 17 Dec 2025 16:33:54 -0800 Subject: [PATCH 256/315] [Mosaic TPU] Support reshape which folds last two dims when the last dim is not divisible by 128. PiperOrigin-RevId: 845964739 --- tests/pallas/tpu_pallas_test.py | 73 +++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 1aa8fc290a9d..229938289c65 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -3853,6 +3853,79 @@ def kernel(x_ref, y_ref): )(x) np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + # (q, m, n) -> (q, m * n) where n % 128 != 0 + @parameterized.parameters( + (q, m, n, dtype) + for (q, m, n), dtype in itertools.product( + [ + (32, 16, 500), + (20, 19, 500), + (5, 3, 200), + (9, 15, 200), + (3, 2, 200), + (5, 1, 300), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_two_minor_dims_to_R2_padded_last_dim(self, q, m, n, dtype): + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Needs a newer libTPU') + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1] * x_ref.shape[2] + ) + + x = np.arange(q * m * n, dtype=dtype).reshape(q, m, n) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m * n), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m * n])) + + # (q, m, n, k) -> (q, m, n * k) where k % 128 != 0 + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + (3, 8, 17, 500), + (1, 8, 9, 200), + (1, 8, 3, 200), + (10, 1, 4, 200), + (1, 2, 2, 200), + (1, 9, 3, 200), + (4, 7, 1, 300), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_two_minor_dims_to_R3_padded_last_dim( + self, q, m, n, k, dtype + ): + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Needs a newer libTPU') + if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( + dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) + ): + self.skipTest('Operation not supported on this TPU version.') + + def kernel(x_ref, y_ref): + y_ref[...] = x_ref[...].reshape( + x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] + ) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n * k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n * k])) + # (p, q, m, n, k) -> (p, q * m * n * k) where k % 128 == 0 @parameterized.parameters( (p, q, m, n, k, dtype) From f81c20150c39a103c93d3a9e69ac17cca5a18dd9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 17 Dec 2025 16:55:12 -0800 Subject: [PATCH 257/315] [pallas:mosaic] Access memory spaces via `pltpu` directly The longer form, e.g. `pltpu.MemorySpace.HBM`, is unnecessary. PiperOrigin-RevId: 845971593 --- docs/pallas/quickstart.ipynb | 2 +- docs/pallas/quickstart.md | 2 +- docs/pallas/tpu/distributed.ipynb | 8 ++++---- docs/pallas/tpu/distributed.md | 8 ++++---- docs/pallas/tpu/pipelining.ipynb | 10 +++++----- docs/pallas/tpu/pipelining.md | 10 +++++----- tests/pallas/tpu_pallas_interpret_test.py | 14 +++++++------- tests/pallas/tpu_pallas_pipeline_test.py | 10 +++++----- tests/pallas/tpu_pallas_test.py | 4 ++-- 9 files changed, 34 insertions(+), 34 deletions(-) diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 8ed5cac076d3..3fefc2cbc157 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -343,7 +343,7 @@ "\n", "def iota(size: int):\n", " return pl.pallas_call(iota_kernel,\n", - " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),\n", " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", " grid=(size,))()\n", "iota(8)" diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 3ff0801db965..f18225a589d5 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -230,7 +230,7 @@ from jax.experimental.pallas import tpu as pltpu def iota(size: int): return pl.pallas_call(iota_kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), grid=(size,))() iota(8) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index ec13e7c7b5f1..25968585ee81 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -809,11 +809,11 @@ " num_scalar_prefetch=0,\n", " in_specs=[\n", " # Our input lives in VMEM\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.VMEM),\n", " ],\n", " out_specs=[\n", " # Our output lives in VMEM\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.VMEM),\n", " # Our double-buffer lives in HBM\n", " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", @@ -1146,10 +1146,10 @@ "grid_spec = pltpu.PrefetchScalarGridSpec(\n", " num_scalar_prefetch=0,\n", " in_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.VMEM),\n", " ],\n", " out_specs=[\n", - " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.VMEM),\n", " pl.BlockSpec(memory_space=pl.ANY),\n", " ],\n", " grid=(num_devices, 2),\n", diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index da6e36bf672d..2fd2131c49ff 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -703,11 +703,11 @@ grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ # Our input lives in VMEM - pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ # Our output lives in VMEM - pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), # Our double-buffer lives in HBM pl.BlockSpec(memory_space=pl.ANY), ], @@ -1019,10 +1019,10 @@ out_shape = ( grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), ], out_specs=[ - pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), pl.BlockSpec(memory_space=pl.ANY), ], grid=(num_devices, 2), diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index a78d794140ad..12a4b852e84a 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -124,9 +124,9 @@ "| Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) |\n", "| --- | --- | --- |\n", "| `pl.ANY` | HBM (usually) or VMEM | DRAM |\n", - "| `pltpu.MemorySpace.VMEM` | VMEM | SRAM |\n", - "| `pltpu.MemorySpace.SMEM` | SMEM | SRAM |\n", - "| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM |\n", + "| `pltpu.VMEM` | VMEM | SRAM |\n", + "| `pltpu.SMEM` | SMEM | SRAM |\n", + "| `pltpu.SEMAPHORE` | Semaphore | SRAM |\n", "\n", "- `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified.\n", "- `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM.\n", @@ -166,7 +166,7 @@ "out = pl.pallas_call(hbm_vmem_kernel,\n", " in_specs=[pl.BlockSpec(memory_space=pl.ANY)],\n", " out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),\n", - " scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),)\n", + " scratch_shapes=(pltpu.VMEM(shape=(1, 128), dtype=jnp.float32),)\n", ")(x)\n", "\n", "np.testing.assert_allclose(out, x[0:1] + 1)" @@ -288,7 +288,7 @@ " in_specs=[hbm_block_spec, hbm_block_spec],\n", " out_specs=hbm_block_spec,\n", " out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),\n", - " scratch_shapes=(pltpu.MemorySpace.SMEM(slices.shape, jnp.int32),)\n", + " scratch_shapes=(pltpu.SMEM(slices.shape, jnp.int32),)\n", " )(x, slices)\n", "\n", "np.testing.assert_allclose(x, out)" diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index 38747d17915a..02c9187edd2e 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -96,9 +96,9 @@ Pallas exposes all levels of the TPU memory hierarchy to users. The following ta | Pallas Enum | TPU Memory Space | Type (DRAM/SRAM) | | --- | --- | --- | | `pl.ANY` | HBM (usually) or VMEM | DRAM | -| `pltpu.MemorySpace.VMEM` | VMEM | SRAM | -| `pltpu.MemorySpace.SMEM` | SMEM | SRAM | -| `pltpu.MemorySpace.SEMAPHORE` | Semaphore | SRAM | +| `pltpu.VMEM` | VMEM | SRAM | +| `pltpu.SMEM` | SMEM | SRAM | +| `pltpu.SEMAPHORE` | Semaphore | SRAM | - `MemorySpace.VMEM` denotes vector SRAM. It is the default memory space if nothing is specified. - `MemorySpace.SMEM` denotes scalar SRAM. Only scalar loads and stores can be performed to/from SMEM. @@ -131,7 +131,7 @@ x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32) out = pl.pallas_call(hbm_vmem_kernel, in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32), - scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(1, 128), dtype=jnp.float32),) + scratch_shapes=(pltpu.VMEM(shape=(1, 128), dtype=jnp.float32),) )(x) np.testing.assert_allclose(out, x[0:1] + 1) @@ -234,7 +234,7 @@ out = pl.pallas_call(dynamic_block_example_kernel, in_specs=[hbm_block_spec, hbm_block_spec], out_specs=hbm_block_spec, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - scratch_shapes=(pltpu.MemorySpace.SMEM(slices.shape, jnp.int32),) + scratch_shapes=(pltpu.SMEM(slices.shape, jnp.int32),) )(x, slices) np.testing.assert_allclose(x, out) diff --git a/tests/pallas/tpu_pallas_interpret_test.py b/tests/pallas/tpu_pallas_interpret_test.py index be1fe689c40d..00da918bca5e 100644 --- a/tests/pallas/tpu_pallas_interpret_test.py +++ b/tests/pallas/tpu_pallas_interpret_test.py @@ -854,7 +854,7 @@ def f(x): kernel, grid=(2,), out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), - in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM)], + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], scratch_shapes=[ pltpu.VMEM(x.shape, x.dtype), ], @@ -1191,7 +1191,7 @@ def kernel_call(s, num_cores_per_device, grid_point_recorder): ], ) - @parameterized.parameters(pltpu.MemorySpace.HBM, pl.ANY) + @parameterized.parameters(pltpu.HBM, pl.ANY) def test_referencing_hbm_raises(self, disallowed_memory_space): def jax_load_and_store(in_ref, o_ref): o_ref[...] = in_ref[...] @@ -1220,7 +1220,7 @@ def kernel_call(kernel, x, *, in_memory_space, out_memory_space): jax_load_and_store, jnp.zeros((8, 128), jnp.float32), in_memory_space=disallowed_memory_space, - out_memory_space=pltpu.MemorySpace.VMEM, + out_memory_space=pltpu.VMEM, ) pltpu.reset_tpu_interpret_mode_state() @@ -1234,7 +1234,7 @@ def kernel_call(kernel, x, *, in_memory_space, out_memory_space): pallas_load_and_store, jnp.zeros((8, 128), jnp.float32), in_memory_space=disallowed_memory_space, - out_memory_space=pltpu.MemorySpace.VMEM, + out_memory_space=pltpu.VMEM, ) pltpu.reset_tpu_interpret_mode_state() @@ -1247,7 +1247,7 @@ def kernel_call(kernel, x, *, in_memory_space, out_memory_space): kernel_call( jax_load_and_store, jnp.zeros((8, 128), jnp.float32), - in_memory_space=pltpu.MemorySpace.VMEM, + in_memory_space=pltpu.VMEM, out_memory_space=disallowed_memory_space, ) pltpu.reset_tpu_interpret_mode_state() @@ -1261,8 +1261,8 @@ def kernel_call(kernel, x, *, in_memory_space, out_memory_space): kernel_call( pallas_load_and_store, jnp.zeros((8, 128), jnp.float32), - in_memory_space=pltpu.MemorySpace.VMEM, - out_memory_space=pltpu.MemorySpace.HBM, + in_memory_space=pltpu.VMEM, + out_memory_space=pltpu.HBM, ) pltpu.reset_tpu_interpret_mode_state() diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index df3e56e0ab97..61d3107da53e 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -422,7 +422,7 @@ def inner_kernel(x_ref, o_ref): out_shape=jax.ShapeDtypeStruct((len(in_block_indices) * 128, 128), jnp.int32), in_specs=[ pl.BlockSpec(memory_space=pl.ANY), - pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=pl.BlockSpec(memory_space=pl.ANY), ) @@ -475,7 +475,7 @@ def inner_kernel(x_ref, o_ref): out_shape=jax.ShapeDtypeStruct((1024, 128), jnp.int32), in_specs=[ pl.BlockSpec(memory_space=pl.ANY), - pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=pl.BlockSpec(memory_space=pl.ANY), ) @@ -614,7 +614,7 @@ def _(allocations): out_shape=jax.ShapeDtypeStruct((blk_len * 2 * 128, 128), jnp.int32), in_specs=[ pl.BlockSpec(memory_space=pl.ANY), - pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), ], out_specs=pl.BlockSpec(memory_space=pl.ANY), scratch_shapes = [pltpu.SMEM((1,), dtype=jnp.int32)] @@ -2433,9 +2433,9 @@ def body(x_ref, o_ref): out = pl.pallas_call( kernel, - in_specs=(pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),), + in_specs=(pl.BlockSpec(memory_space=pl.ANY),), out_shape=out_ty, - out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), )(inp) np.testing.assert_allclose(out.x0, inp.x0) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 229938289c65..b292cd3d409a 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -4263,11 +4263,11 @@ def test_async_copy_slice(self): def kernel(o): @functools.partial(pl.run_scoped, sem=pltpu.SemaphoreType.DMA, - x=pltpu.MemorySpace.VMEM((1,), jnp.float32)) + x=pltpu.VMEM((1,), jnp.float32)) def _(sem, x): x[...] = jnp.ones_like(x) @functools.partial(pl.run_scoped, - y=pltpu.MemorySpace.VMEM((1, 1,), jnp.float32)) + y=pltpu.VMEM((1, 1,), jnp.float32)) def _(y): pltpu.async_copy(x, y.at[0], sem).wait() o[...] = y[0] From 286e3583988992d9f9c2239f9e5d427a273d755e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 17 Dec 2025 19:46:49 -0800 Subject: [PATCH 258/315] Automated Code Change PiperOrigin-RevId: 846022819 --- examples/jax_cpp/main.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 27681e41fdad..b911711ad53f 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -106,7 +106,7 @@ int main(int argc, char** argv) { // Get result. std::shared_ptr result_literal = - results[0][0]->ToLiteralSync().value(); + results[0][0]->ToLiteral().Await().value(); LOG(INFO) << "result = " << *result_literal; return 0; } From 97c557de776308c4c03e7592dfc9d95be4e5d265 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 18 Dec 2025 00:05:24 -0800 Subject: [PATCH 259/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/913ae2eaa3cb88971003592a90959685a78c9e30 PiperOrigin-RevId: 846110589 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 677ff16cfdf0..02a6712d7e3d 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "8bc1190b0e51ec5f01143d255826143ad7975ee9" -XLA_SHA256 = "2011c87aa50750037e78ebb771d1bc76f93f68b174f4354f64eff8eaa70c413d" +XLA_COMMIT = "913ae2eaa3cb88971003592a90959685a78c9e30" +XLA_SHA256 = "b0420fdca3789e659e314cae7ee38d1f13c613c458c00376a9d44dde51740d7f" From 3349e7a14684516aad97a9f23f315e9363dfb033 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 18 Dec 2025 01:45:45 -0800 Subject: [PATCH 260/315] [Pallas] Allow inferring the backend from the provided compiler params. This is a convenience on GPU, where two backends are available. PiperOrigin-RevId: 846147150 --- jax/_src/pallas/pallas_call.py | 3 +++ tests/pallas/BUILD | 1 + tests/pallas/pallas_test.py | 46 ++++++++++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 8f00b492057d..480bcb650c4d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1406,6 +1406,9 @@ def pallas_call( "If `grid_spec` is specified, then `scratch_shapes` must " f"be `()`. It is {scratch_shapes}") del grid, in_specs, out_specs + # We can infer a backend from compiler_params if it is not specified. + if backend is None and isinstance(compiler_params, pallas_core.CompilerParams): + backend = compiler_params.BACKEND return _pallas_call( kernel, out_shape, diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 0d03cc9b5444..d04c318d64f2 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -60,6 +60,7 @@ jax_multiplatform_test( "//jax/experimental:pallas", "//jax/experimental:pallas_gpu", "//jax/experimental:pallas_gpu_ops", + "//jax/experimental:pallas_mosaic_gpu", "//jax/experimental:pallas_tpu", "//jax/experimental:pallas_tpu_ops", ] + py_deps([ diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 4400a5b7ac14..53990260eaf3 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -43,10 +43,12 @@ if sys.platform != "win32": from jax.experimental.pallas import tpu as pltpu - from jax.experimental.pallas import triton as plgpu + from jax.experimental.pallas import triton as pltriton + from jax.experimental.pallas import mosaic_gpu as plmgpu else: pltpu = None - plgpu = None + pltriton = None + plmgpu = None # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs. @@ -95,6 +97,46 @@ def body(i, acc): class PallasCallTest(ptu.PallasTest): + def test_pallas_call_infers_backend_from_compiler_params(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Only works on GPU.") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Only works on a GPU with capability >= sm90") + + triton_params = pltriton.CompilerParams( + num_warps=2, + num_stages=1, + ) + mosaic_gpu_params = plmgpu.CompilerParams() + + pallas_call = functools.partial( + pl.pallas_call, + grid=(1,), + out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32), + ) + def add_one(x_ref, o_ref): + x = x_ref[:] + # Use a Pallas/Mosaic GPU-specific primitive to trigger a failure when + # using a different backend. + plmgpu.print_layout("x: {}", x) + o_ref[:] = x + 1 + + add_one_mgpu = pallas_call(add_one, compiler_params=mosaic_gpu_params) + add_one_triton = pallas_call(add_one, compiler_params=triton_params) + + x = jnp.ones((128, 128), jnp.float32) + + # Running on the Mosaic GPU backend should be fine. + self.assertArraysEqual(add_one_mgpu(x), x + 1) + + # But Triton doesn't have the required primitive, so it should fail to + # lower. + with self.assertRaisesRegex( + NotImplementedError, + "Unimplemented primitive in Pallas GPU lowering: print_layout." + ): + add_one_triton(x) + def test_add_one(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: self.skipTest("On TPU the test works only in interpret mode") From b9b2d4f997d3adc3e5097e27d0d136dfc91999a7 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 18 Dec 2025 03:04:33 -0800 Subject: [PATCH 261/315] [pallas:mosaic] Recover `memory_space` from the aval in `aval_to_ir_type` Prior to this change `aval_to_ir_type` sometimes produced nonsensical memrefs, e.g. `memref>`, because it failed to inspect the `memory_space` associated with the aval. PiperOrigin-RevId: 846173986 --- jax/_src/pallas/mosaic/lowering.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 70ae27890779..229b29a1a130 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -342,6 +342,8 @@ def aval_to_ir_type( if isinstance(aval, state.AbstractRef): if shape is None: shape = aval.shape + if memory_space is None: + memory_space = aval.memory_space memspace = _memory_space_to_mosaic_attribute(memory_space) shape = dynamic_shape_replacement_fn(shape) return ir.MemRefType.get(shape, From c1188e740cc019e85fa8a393c3b07eca18c8d7de Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 18 Dec 2025 04:15:19 -0800 Subject: [PATCH 262/315] Add __getitem__ to backwards compatible shims. PiperOrigin-RevId: 846196686 --- jax/_src/interpreters/batching.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 69f26c14eeac..7fff2d31a371 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -660,6 +660,10 @@ def __setitem__(self, prim, batcher): def wrapped(axis_data, vals, dims, **params): return batcher(axis_data.size, axis_data.name, None, vals, dims, **params) fancy_primitive_batchers[prim] = wrapped + + def __getitem__(self, prim): + return fancy_primitive_batchers[prim] + axis_primitive_batchers = AxisPrimitiveBatchersProxy() # backwards compat shim. TODO: delete @@ -675,6 +679,10 @@ def wrapped(axis_data, vals, dims, **params): def __delitem__(self, prim): del fancy_primitive_batchers[prim] + + def __getitem__(self, prim): + return fancy_primitive_batchers[prim] + primitive_batchers = PrimitiveBatchersProxy() From 315bb935b369a46b777be9f68a245da51dfd654f Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Thu, 18 Dec 2025 04:34:55 -0800 Subject: [PATCH 263/315] [MGPU] Doc fix and shape size fix for broadcast in WGSplatFragLayout. We enforce len(source_shape) <= len(target_shape) PiperOrigin-RevId: 846203495 --- jax/experimental/mosaic/gpu/fragmented_array.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index ac0a10199943..795a3c3062c4 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -616,10 +616,14 @@ class WGSplatFragLayout: def can_broadcast_to(self, shape) -> bool: """Check that the shape can be broadcast. - Only dimensions of size 1 can be broadcast. All other dimensions - must be the same as the argument shape. + All source dimensions must match the target's trailing dimensions by + equality or being set to 1 (i.e. we can broadcast 1-sized dimensions or + create new leading dimensions). """ - return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1])) + return len(self.shape) <= len(shape) and all( + dim1 == dim2 or dim1 == 1 + for dim1, dim2 in zip(self.shape[::-1], shape[::-1]) + ) def registers_element_type(self, t: ir.Type) -> ir.Type: return t From 69a6cca9a656ba11831a8ca691e471b39caa2c05 Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 18 Dec 2025 07:10:47 -0800 Subject: [PATCH 264/315] [mpmd] Fix stage_id field nanobind compatability .none() annotations needed for optional parameters to allow accepting None from Python. PiperOrigin-RevId: 846253556 --- jaxlib/sdy_mpmd.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/sdy_mpmd.cc b/jaxlib/sdy_mpmd.cc index 515c29d995f9..d21729dbe350 100644 --- a/jaxlib/sdy_mpmd.cc +++ b/jaxlib/sdy_mpmd.cc @@ -121,7 +121,7 @@ NB_MODULE(_sdy_mpmd, m) { std::optional, std::optional, const std::string&>(), - nb::arg("origins"), nb::arg("stage_id"), + nb::arg("origins"), nb::arg("stage_id").none() = std::nullopt, nb::arg("call_counter").none() = std::nullopt, nb::arg("split_type").none() = std::nullopt, nb::arg("mesh_name")) .def_ro("origins", &FragmentInfo::origins) From e2815d51486489ace1699b0e0152fc4feacc8ec4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 18 Dec 2025 07:16:36 -0800 Subject: [PATCH 265/315] [pallas:mosaic] `pltpu.emit_pipeline` now accepts block specs in HBM This makes it possible to use it to implement pipelining in the `pallas_call` lowering on SparseCore. PiperOrigin-RevId: 846255621 --- jax/_src/pallas/mosaic/pipeline.py | 23 +++++++++------------- tests/pallas/tpu_pallas_pipeline_test.py | 25 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 11e67f747be9..80adc8bac314 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -40,6 +40,7 @@ SMEM = tpu_core.MemorySpace.SMEM VMEM = tpu_core.MemorySpace.VMEM +HBM = tpu_core.MemorySpace.HBM ANY = pallas_core.MemorySpace.ANY REF = pallas_core.MemoryRef GridDimensionSemantics = tpu_core.GridDimensionSemantics @@ -582,13 +583,13 @@ def create( accum_ref = VMEM.from_type(ty.update(shape=block_shape)) else: accum_ref = None - if source_memory_space == VMEM: - # We don't need to do any double-buffering in the case that our pipeline - # reference is already in VMEM, we just need allocate the accumulation - # buffer and we will refer to the original reference slices directly. - if spec.memory_space not in (VMEM, None): - raise ValueError( - f"Cannot hold a non-buffered ref in {spec.memory_space=}") + buffer_memory_space = ( + VMEM if spec.memory_space is None else spec.memory_space) + if buffer_memory_space not in (SMEM, VMEM, HBM): + raise ValueError( + f"Unsupported buffer memory space: {buffer_memory_space}" + ) + if source_memory_space is buffer_memory_space: return cls( _spec=spec, _buffer_type=buffer_type, @@ -609,12 +610,6 @@ def create( swap=None, ) else: - buffer_memory_space = ( - VMEM if spec.memory_space is None else spec.memory_space) - if buffer_memory_space not in (SMEM, VMEM): - raise ValueError( - f"Unsupported buffer memory space: {buffer_memory_space}" - ) if use_lookahead and grid_rank is None: raise ValueError( "grid_rank must be specified when use_lookahead is True." @@ -1335,7 +1330,7 @@ def out_of_fetch(self, buffered_ref): # Currently this is based on the iteration, but if we want to support # lookahead this will depend on whether the lookahead reached the end. if not buffered_ref.is_buffered: - return False + return jnp.bool(False) return self.step >= (self.num_steps - buffered_ref.buffer_count + 1) def has_changed(self, buffered_ref): diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index 61d3107da53e..62149556edc5 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -148,6 +148,31 @@ def body(o_ref): )() np.testing.assert_allclose(out, jnp.full_like(out, 42)) + def test_hbm_output(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 512), jnp.int32), + in_specs=[pl.BlockSpec(memory_space=pltpu.HBM)], + out_specs=pl.BlockSpec(memory_space=pltpu.HBM), + ) + def kernel(x_hbm_ref, o_hbm_ref): + @functools.partial( + pltpu.emit_pipeline, + grid=(4,), + in_specs=pl.BlockSpec((8, 128), lambda i: (0, i)), + out_specs=pl.BlockSpec( + (8, 512), lambda i: (0, 0), memory_space=pltpu.HBM + ), + ) + def pipeline(x_ref, o_ref): + i = pl.program_id(0) + pltpu.sync_copy(x_ref, o_ref.at[:, pl.ds(i * 128, 128)]) + + pipeline(x_hbm_ref, o_hbm_ref) + + x = jnp.arange(8 * 512).reshape(8, 512) + np.testing.assert_allclose(kernel(x), x) + @parameterized.product( no_pipelining=[False, True], ) From 987a02584069adcb0a8d213cb15a9d29b7e72002 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 18 Dec 2025 07:32:42 -0800 Subject: [PATCH 266/315] Update JAX tests for separate input_striding and input_tiling on transposes. This updates JAX after an PJRT API change. PiperOrigin-RevId: 846260872 --- jaxlib/callback.cc | 2 +- jaxlib/gpu/py_client_gpu.cc | 2 +- jaxlib/py_client_cpu.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxlib/callback.cc b/jaxlib/callback.cc index 83d46ed45c07..81263d149c64 100644 --- a/jaxlib/callback.cc +++ b/jaxlib/callback.cc @@ -104,7 +104,7 @@ absl::Status CpuCallback::PrepareAndCall(void** result, void** arg_ptrs) { xla::primitive_util::ByteWidth(results_[i].type); options.dims = dims; options.permutation = results_[i].reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; + options.input_striding = xla::TransposePlan::Striding{strides}; absl::StatusOr> plan = transpose_cache_.GetOrCreate(options); if (!plan.ok()) { diff --git a/jaxlib/gpu/py_client_gpu.cc b/jaxlib/gpu/py_client_gpu.cc index b3c091403c0c..f6678e023c6e 100644 --- a/jaxlib/gpu/py_client_gpu.cc +++ b/jaxlib/gpu/py_client_gpu.cc @@ -217,7 +217,7 @@ xla::ffi::Error XlaFfiPythonGpuCallback(gpuStream_t stream, absl::c_reverse_copy(expected_shape.layout().minor_to_major(), reversed_layout.begin()); options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; + options.input_striding = xla::TransposePlan::Striding{strides}; auto maybe_plan = transpose_cache->cache.GetOrCreate(options); if (!maybe_plan.ok()) { return xla::ffi::Error::Internal(maybe_plan.status().ToString()); diff --git a/jaxlib/py_client_cpu.cc b/jaxlib/py_client_cpu.cc index dc493832647a..aaf7fd2faab0 100644 --- a/jaxlib/py_client_cpu.cc +++ b/jaxlib/py_client_cpu.cc @@ -173,7 +173,7 @@ ffi::Error XlaFfiPythonCpuCallback(xla::FfiLoadedHostCallbacks* callbacks, absl::c_reverse_copy(expected_shape.layout().minor_to_major(), reversed_layout.begin()); options.permutation = reversed_layout; - options.input_layout = xla::TransposePlan::Striding{strides}; + options.input_striding = xla::TransposePlan::Striding{strides}; auto maybe_plan = transpose_cache->cache.GetOrCreate(options); if (!maybe_plan.ok()) { return ffi::Error::Internal(maybe_plan.status().ToString()); From 3f125024ecd17a1d90789a69c59962de4ff6af53 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Thu, 18 Dec 2025 15:54:41 +0000 Subject: [PATCH 267/315] Prepare for JAX release 0.8.2 --- jax/version.py | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/version.py b/jax/version.py index 24fccad06e8e..adde4c95168c 100644 --- a/jax/version.py +++ b/jax/version.py @@ -152,7 +152,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = '0.8.1' +_minimum_jaxlib_version = '0.8.2' def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/setup.py b/setup.py index 671578e791c1..11db6ffa92f2 100644 --- a/setup.py +++ b/setup.py @@ -19,11 +19,11 @@ project_name = 'jax' -_current_jaxlib_version = '0.8.1' +_current_jaxlib_version = '0.8.2' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.8.1' -_libtpu_version = '0.0.30.*' +_libtpu_version = '0.0.32.*' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( From b8cd917515f776649aee436ebd155949c68245e8 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 18 Dec 2025 07:58:22 -0800 Subject: [PATCH 268/315] Reverts c1188e740cc019e85fa8a393c3b07eca18c8d7de PiperOrigin-RevId: 846269616 --- jax/_src/interpreters/batching.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 7fff2d31a371..69f26c14eeac 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -660,10 +660,6 @@ def __setitem__(self, prim, batcher): def wrapped(axis_data, vals, dims, **params): return batcher(axis_data.size, axis_data.name, None, vals, dims, **params) fancy_primitive_batchers[prim] = wrapped - - def __getitem__(self, prim): - return fancy_primitive_batchers[prim] - axis_primitive_batchers = AxisPrimitiveBatchersProxy() # backwards compat shim. TODO: delete @@ -679,10 +675,6 @@ def wrapped(axis_data, vals, dims, **params): def __delitem__(self, prim): del fancy_primitive_batchers[prim] - - def __getitem__(self, prim): - return fancy_primitive_batchers[prim] - primitive_batchers = PrimitiveBatchersProxy() From 9af721622fa57a5740730669692c0896bde6e50e Mon Sep 17 00:00:00 2001 From: Gleb Pobudzey Date: Thu, 18 Dec 2025 08:07:49 -0800 Subject: [PATCH 269/315] =?UTF-8?q?[Pallas=20MGPU]=20Simplify=20how=20we?= =?UTF-8?q?=20keep=20track=20of=20the=20current=20output=20slices.=20Keepi?= =?UTF-8?q?ng=20track=20of=20the=20full=20store=20slices=20is=20unnecessar?= =?UTF-8?q?y=20because=20the=20slice=20size=20doesn=E2=80=99t=20change,=20?= =?UTF-8?q?we=20really=20only=20care=20about=20the=20the=20start=20of=20th?= =?UTF-8?q?e=20slices.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PiperOrigin-RevId: 846273629 --- jax/_src/pallas/mosaic_gpu/pipeline.py | 82 ++++++++------------------ 1 file changed, 23 insertions(+), 59 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index b2286e3dfc18..4c8b49cbce60 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -182,26 +182,6 @@ def _inc_grid_by_1( def _in_smem(spec: pallas_core.BlockSpec) -> bool: return spec.memory_space in (None, gpu_core.SMEM) - -# ``pl.Slice`` uses a different pytree encoding, depending on whether the -# start/size are static or dynamic. This leads to pytree structure mismatch -# in the pipeline body. So, we define a different ``Slice`` class below. - - -@dataclasses.dataclass(frozen=True) -class _Slice: - start: int | jax.Array - size: int | jax.Array - - def __eq__(self, other: _Slice) -> jax.Array: # type: ignore - return lax.bitwise_and(self.start == other.start, self.size == other.size) - - -jax.tree_util.register_dataclass( - _Slice, data_fields=["start", "size"], meta_fields=[] -) - - def _downcast_spec( spec: gpu_core.BlockSpec | pallas_core.BlockSpec, ) -> gpu_core.BlockSpec: @@ -357,7 +337,7 @@ def prologue(step, fetch_indices): # need to fetch more data anyway. def loop_body(step, carry): slot = lax.rem(step, max_concurrent_steps) - indices, fetch_index_levels, last_store_slices, prev_body_carry = carry + indices, fetch_index_levels, last_store_indices, prev_body_carry = carry if barrier_ref is not None: # Wait for the current GMEM->SMEM copy to complete, if any. @@ -381,20 +361,17 @@ def loop_body(step, carry): gpu_primitives.commit_smem() # Copy the output from SMEM to GMEM. - new_store_slices = last_store_slices[:] + new_store_indices = last_store_indices[:] for idx, bref in enumerate(out_brefs): if bref.is_index_invariant: - assert last_store_slices[idx] is None + assert last_store_indices[idx] is None continue - assert last_store_slices[idx] is not None - new_store_slices[idx] = tuple( - _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s - for s in bref.compute_gmem_slice(indices) - ) + assert last_store_indices[idx] is not None + new_store_indices[idx] = bref.spec.index_map(*indices) are_same_slices = map( lambda old, new: old == new, - last_store_slices[idx], - new_store_slices[idx], + last_store_indices[idx], + new_store_indices[idx], ) slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) is_last_step = step == num_steps - 1 @@ -436,7 +413,7 @@ def do_fetch(): return ( _inc_grid_by_1(indices, grid), next_fetch_indices_levels, - new_store_slices, + new_store_indices, next_body_carry if init_carry is not None else None, ) @@ -447,23 +424,18 @@ def do_fetch(): fetch_indices = _inc_grid_by_1(fetch_indices, grid) fetch_index_levels.append(fetch_indices) - def _init_store_slice(bd): - if bd is None or isinstance(bd, pl.Squeezed): - return jnp.array(-1, dtype=jnp.int32) - return _Slice(-1, -1) - # TODO(justinfu): Only store base pointer instead of all indices. - last_store_slices = [ + last_store_indices = [ None if bref.is_index_invariant - else tuple(map(_init_store_slice, bref.spec.block_shape)) + else (jnp.array(-1),) * len(bref.spec.block_shape) for bref in out_brefs ] last_indices, _, _, final_carry = lax.fori_loop( 0, num_steps, loop_body, - (indices, fetch_index_levels, last_store_slices, init_carry), + (indices, fetch_index_levels, last_store_indices, init_carry), ) # Outputs invariant to the sequential axis are never written from inside the @@ -848,7 +820,7 @@ def compute_block(): needs_epilogue = any(bref.is_index_invariant for bref in smem_out_brefs) def compute_loop_body(step, carry): - indices, last_store_slices, prev_body_carry = carry + indices, last_store_indices, prev_body_carry = carry slot = lax.rem(step, max_concurrent_steps) consumed_slot = lax.rem(step - delay_release, max_concurrent_steps) # Wait for the current GMEM->SMEM copies to complete. @@ -895,20 +867,17 @@ def compute_loop_body(step, carry): if copies_out_in_loop: gpu_primitives.commit_smem() - new_store_slices = last_store_slices[:] + new_store_indices = last_store_indices[:] for idx, bref in enumerate(flat_out_brefs): if bref.is_index_invariant: - assert last_store_slices[idx] is None + assert last_store_indices[idx] is None continue - assert last_store_slices[idx] is not None - new_store_slices[idx] = tuple( - _Slice(s.start, s.size) if isinstance(s, pl.Slice) else s - for s in bref.compute_gmem_slice(indices) - ) + assert last_store_indices[idx] is not None + new_store_indices[idx] = bref.spec.index_map(*indices) are_same_slices = map( lambda old, new: old == new, - last_store_slices[idx], - new_store_slices[idx], + last_store_indices[idx], + new_store_indices[idx], ) slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) bref.copy_out(_get_slot(slot, not bref.is_index_invariant), @@ -916,19 +885,14 @@ def compute_loop_body(step, carry): predicate=slices_changed) gpu_primitives.commit_smem_to_gmem_group() next_indices = _inc_grid_by_1(indices, grid) - return (next_indices, new_store_slices, next_body_carry) + return (next_indices, new_store_indices, next_body_carry) init_indices = (jnp.asarray(0, dtype=jnp.int32),) * len(grid) - def _init_store_slice(bd): - if bd is None or isinstance(bd, pl.Squeezed): - return jnp.array(-1, dtype=jnp.int32) - return _Slice(-1, -1) - # TODO(justinfu): Only store base pointer instead of all indices. - last_store_slices = [ + last_store_indices = [ None if bref.is_index_invariant - else tuple(map(_init_store_slice, bref.spec.block_shape)) + else (jnp.array(-1),) * len(bref.spec.block_shape) for bref in flat_out_brefs ] @@ -939,7 +903,7 @@ def pipeline_callback(user_init_carry): if last_indices is not None: raise ValueError( "Cannot call pipeline more than once in `compute_context`") - init_loop_carry = (init_indices, last_store_slices, user_init_carry) + init_loop_carry = (init_indices, last_store_indices, user_init_carry) last_indices, _, final_body_carry = lax.fori_loop(0, num_steps, compute_loop_body, @@ -952,7 +916,7 @@ def pipeline_callback(user_init_carry): assert compute_context is None last_indices, _, _ = lax.fori_loop( 0, num_steps, compute_loop_body, - (init_indices, last_store_slices, None) + (init_indices, last_store_indices, None) ) # Handle index_invariant outputs after the loop. They are not From cd1d05741cb862bd5bbc7c535caa4d0aea7d21a3 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 18 Dec 2025 09:15:23 -0800 Subject: [PATCH 270/315] [pallas:mosaic] Use `pl.delay` instead of the deprecated `pltpu.delay` PiperOrigin-RevId: 846297423 --- docs/pallas/tpu/distributed.ipynb | 2 +- docs/pallas/tpu/distributed.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb index 25968585ee81..feebb7c2f8e7 100644 --- a/docs/pallas/tpu/distributed.ipynb +++ b/docs/pallas/tpu/distributed.ipynb @@ -644,7 +644,7 @@ "\n", "The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).\n", "\n", - "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artificially hang a device.\n", + "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pl.delay` instruction to artificially hang a device.\n", "\n", "Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections." ] diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md index 2fd2131c49ff..e8bbdb3089cc 100644 --- a/docs/pallas/tpu/distributed.md +++ b/docs/pallas/tpu/distributed.md @@ -556,7 +556,7 @@ The prologue (executed when `outer_step==0`) first initiates a barrier with both The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`). -A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artificially hang a device. +A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pl.delay` instruction to artificially hang a device. Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections. From eac6699d16ff098fd6096356f9ce881e28a1414d Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 18 Dec 2025 09:39:28 -0800 Subject: [PATCH 271/315] Fix `sub` jvp rule to broadcast shardings correctly so that tangent shardings match the primal shardings PiperOrigin-RevId: 846306224 --- jax/_src/lax/lax.py | 54 ++++++++++++++++++++++++++----------------- jax/_src/lax/other.py | 5 ++-- tests/pjit_test.py | 25 ++++++++++++++++++++ 3 files changed, 61 insertions(+), 23 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 71f063135001..4aaa0c88539f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -62,6 +62,7 @@ from jax._src.lax.utils import ( input_dtype, dtype_to_string, standard_multi_result_abstract_eval, standard_primitive) +from jax._src.core import typeof from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -2709,7 +2710,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array) and out_sharding is None): return operand - operand_aval = core.typeof(operand) + operand_aval = typeof(operand) if (operand_aval.shape == shape and list(broadcast_dimensions) == list(range(operand_aval.ndim)) and out_sharding is not None and operand_aval.sharding != out_sharding): @@ -2776,7 +2777,7 @@ def reshape(operand: ArrayLike, new_sizes: Shape, same_dims = tuple(dims) == tuple(range(np.ndim(operand))) out_sharding = canonicalize_sharding(out_sharding, 'reshape') same_sharding = (out_sharding is None or - core.typeof(operand).sharding == out_sharding) + typeof(operand).sharding == out_sharding) if (np.shape(operand) and same_shape and same_dims and same_sharding and isinstance(operand, Array)): @@ -3603,7 +3604,7 @@ def full_like(x: ArrayLike | DuckTypedArray, # TODO(yashkatariya): Maybe use `shaped_abstractify` here instead of # `typeof` because `x` can be anything that implements the # `DuckTypedArray` protocol. - val = core.pvary(val, tuple(core.typeof(x).vma)) + val = core.pvary(val, tuple(typeof(x).vma)) return val @@ -4066,7 +4067,7 @@ def _unbroadcast(aval, x): raise TypeError("transpose with implicit broadcasting of unshaped values") x_shape = np.shape(x) if (core.definitely_equal_shape(aval.shape, x_shape) and - aval.sharding == core.typeof(x).sharding): + aval.sharding == typeof(x).sharding): return x assert not aval.shape or len(x_shape) == len(aval.shape) if not aval.shape: @@ -4079,17 +4080,20 @@ def _unbroadcast(aval, x): x = reduce_sum(x, dims) if dims else x return reshape(x, aval.shape, out_sharding=aval.to_cotangent_aval().sharding) -def _maybe_broadcast(target_shape, x): +def _maybe_broadcast(target_shape, x, target_sharding): x_shape = np.shape(x) - if core.definitely_equal_shape(x_shape, target_shape): + x_sharding = typeof(x).sharding + if (core.definitely_equal_shape(x_shape, target_shape) and + x_sharding == target_sharding): return x elif not x_shape: - return broadcast_in_dim(x, target_shape, ()) + return broadcast_in_dim(x, target_shape, (), out_sharding=target_sharding) else: dims = [i for i, (a, b) in enumerate(zip(x_shape, target_shape)) if core.definitely_equal(a, b)] squeeze_shape = [x_shape[i] for i in dims] - return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims) + return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims, + out_sharding=target_sharding) def broadcast_hlo( aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray], @@ -4509,8 +4513,9 @@ def _pow_jvp_lhs(g, ans, x, y): if dtypes.issubdtype(y_dtype, np.integer): if x.shape != y.shape: shape = broadcast_shapes(x.shape, y.shape) - x = _maybe_broadcast(shape, x) - y = _maybe_broadcast(shape, y) + sharding = broadcast_shardings(typeof(x), typeof(y)) + x = _maybe_broadcast(shape, x, sharding) + y = _maybe_broadcast(shape, y, sharding) jac = select(eq(y, _const(y, 0)), _zeros(y), mul(_replace_zero(y), pow(x, sub(y, _ones(y))))) else: @@ -4617,9 +4622,11 @@ def _add_jvp(primals, tangents): if type(xdot) is type(ydot) is ad_util.Zero: return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: - return primal_out, _maybe_broadcast(primal_out.shape, ydot) + return (primal_out, _maybe_broadcast(primal_out.shape, ydot, + typeof(primal_out).sharding)) elif type(ydot) is ad_util.Zero: - return primal_out, _maybe_broadcast(primal_out.shape, xdot) + return (primal_out, _maybe_broadcast(primal_out.shape, xdot, + typeof(primal_out).sharding)) else: return primal_out, add(xdot, ydot) @@ -4670,9 +4677,11 @@ def _sub_jvp(primals, tangents): if type(xdot) is type(ydot) is ad_util.Zero: return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: - return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot)) + return (primal_out, _maybe_broadcast(primal_out.shape, neg(ydot), + typeof(primal_out).sharding)) elif type(ydot) is ad_util.Zero: - return primal_out, _maybe_broadcast(primal_out.shape, xdot) + return (primal_out, _maybe_broadcast(primal_out.shape, xdot, + typeof(primal_out).sharding)) else: return primal_out, sub(xdot, ydot) @@ -4745,14 +4754,17 @@ def _div_transpose_rule(cotangent, x, y): rem_p = standard_naryop([_int | _float, _int | _float], 'rem') ad.defjvp( rem_p, - lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g), + lambda g, x, y: _maybe_broadcast( + broadcast_shapes(np.shape(x), np.shape(y)), g, + broadcast_shardings(typeof(x), typeof(y))), lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y)))))) mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.remainder)) def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): result_shape = broadcast_shapes(np.shape(x), np.shape(y)) - x = _maybe_broadcast(result_shape, x) - y = _maybe_broadcast(result_shape, y) + result_sharding = broadcast_shardings(typeof(x), typeof(y)) + x = _maybe_broadcast(result_shape, x, result_sharding) + y = _maybe_broadcast(result_shape, y, result_sharding) rx = real(x) ry = real(y) pick_x = select(eq(rx, ry), lax_cmp_pick_x(imag(x), imag(y)), @@ -7325,7 +7337,7 @@ def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs): # vmapped function had a scalar which with nonscalar args assert np.ndim(which) == 1 which = broadcast_in_dim(which, cases[0].shape, [which_bdim], - out_sharding=core.typeof(cases[0]).sharding) + out_sharding=typeof(cases[0]).sharding) return select_n(which, *cases), which_bdim elif np.ndim(which) == 0 and all(bdim is not None for bdim in case_bdims): if all(case_bdims[0] == bdim for bdim in case_bdims[1:]): @@ -7347,7 +7359,7 @@ def _select_batch_rule(axis_data, batched_args, batch_dims, **unused_kwargs): # vmapped function had a scalar which with nonscalar args assert np.ndim(which) == 1 which = broadcast_in_dim(which, cases[0].shape, [0], - out_sharding=core.typeof(cases[0]).sharding) + out_sharding=typeof(cases[0]).sharding) if np.ndim(which) > np.ndim(cases[0]): assert np.ndim(cases[0]) == 0 cases = [broadcast(c, which.shape) for c in cases] @@ -7794,7 +7806,7 @@ def _reduce_logical_sharding_rule(operand, *, axes): def _reduce_or_lin(nzs, x, *, axes): nz, = nzs y = reduce_or_p.bind(x, axes=axes) - aval = core.typeof(y).to_tangent_aval() + aval = typeof(y).to_tangent_aval() return y, False, (), lambda _, t: ad_util.Zero(aval) reduce_or_p = standard_primitive( @@ -8001,7 +8013,7 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys dims = np.delete(np.arange(prototype_arg.ndim), new_bdim) new_args.append(broadcast_in_dim( arg, prototype_arg.shape, dims, - out_sharding=core.typeof(prototype_arg).sharding)) + out_sharding=typeof(prototype_arg).sharding)) else: new_args.append(batching.moveaxis(arg, bdim, new_bdim)) new_dimension = dimension + (new_bdim <= dimension) diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index f67f64a40133..b3d54064f9b4 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -284,8 +284,9 @@ def _logaddexp_jvp(primals, tangents): x1, x2 = primals t1, t2 = tangents primal_out = logaddexp(x1, x2) - tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + tangent_out = lax.add( + lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) return primal_out, tangent_out diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8c71d5132667..03655c64903e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -9982,6 +9982,31 @@ def test_reshard_no_mesh_ctx(self): self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) self.assertArraysEqual(out, np.arange(8)) + @jtu.with_explicit_mesh((2,), 'x') + def test_sub_custom_jvp(self, mesh): + np1 = np.arange(2 * 1024, dtype=np.float32).reshape(4, 16, 16, -1) + arr1 = jax.device_put(np1, P("x", None)) + arr2 = jax.device_put(np1, P("x", None)) + + def f(logits, labels): + labels = jnp.astype(labels, logits.dtype) + log_p = jax.nn.log_sigmoid(logits) + log_not_p = jax.nn.log_sigmoid(-logits) + return -labels * log_p - (1.0 - labels) * log_not_p + + @jax.jit + def g(pl, tl): + x = pl - jnp.mean(tl) + y = tl - jnp.mean(pl) + loss_on_fake = jnp.mean(f(x, jnp.zeros_like(pl))) + loss_on_real = jnp.mean(f(y, jnp.ones_like(tl))) + disc_loss = loss_on_fake + loss_on_real + gen_loss = jnp.mean(f(x, jnp.ones_like(pl))) + gen_loss += jnp.mean(f(y, jnp.zeros_like(tl))) + return disc_loss, gen_loss + + jax.jit(jax.grad(lambda t1, t2: g(t1, t2)[0]))(arr1, arr2) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): From c617efcdbb40281a260488e0e3a00f12651da021 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 18 Dec 2025 11:22:32 -0800 Subject: [PATCH 272/315] Drop into full manual mode via shard_map in pallas batching rule when pallas_call is vmapped over an Explicit mesh axis. This is because pallas can only work in full manual mode. Also, allow mixing Manual and Explicit mesh axis in the same tuple in PartitionSpec. For example: ``` mesh = [('x', 2, Explicit), ('y', 2, Explicit)] P((x, y), None) # this is valid now. # Under a shard_map, if you go manual only over `x` i.e. @shard_map(in_specs=P('x', None), out_specs=P('x', None), axis_names={'x'}) def f(x): print(typeof(x)) # f32[4@y, 2]{V:x} return x f(arr: f32[8@(x,y), 2] ``` Co-authored-by: Matthew Johnson PiperOrigin-RevId: 846347531 --- jax/_src/named_sharding.py | 14 ------ jax/_src/pallas/pallas_call.py | 87 +++++++++++++++++++++++++--------- jax/_src/tpu_custom_call.py | 3 +- tests/array_test.py | 7 +-- tests/shard_map_test.py | 16 +++++++ 5 files changed, 85 insertions(+), 42 deletions(-) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index a08e0b51b093..e3ffd1538322 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -515,19 +515,6 @@ def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None f' for {mesh_lib.show_axes(multiple_uses)}'), mesh=mesh, pspec=pspec) -def check_pspec_mix_axis_type(mesh, pspec): - for spec in pspec: - if isinstance(spec, tuple): - if all(mesh._name_to_type[spec[0]] == mesh._name_to_type[p] - for p in spec): - continue - if any(mesh._name_to_type[p] == AxisType.Manual for p in spec): - raise ValueError( - 'Tuple subset of `PartitionSpec` cannot contain `Manual` mixed' - f' with `Auto` or `Explicit`. Got pspec {pspec} and subset' - f' {spec} with axis types:' - f' ({", ".join(str(mesh._name_to_type[p]) for p in spec)})') - def _check_mesh_resource_axis(mesh, pspec): for p in pspec: if p is PartitionSpec.UNCONSTRAINED or p is None: @@ -538,7 +525,6 @@ def _check_mesh_resource_axis(mesh, pspec): raise ValueError( f"Resource axis: {r} of {pspec} " f"is not found in mesh: {tuple(mesh.shape.keys())}.") - check_pspec_mix_axis_type(mesh, pspec) if (AxisType.Auto not in mesh.axis_types and PartitionSpec.UNCONSTRAINED in pspec): raise ValueError( diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 480bcb650c4d..bf81467e4669 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -18,6 +18,7 @@ from collections.abc import Callable, Mapping, Sequence import contextlib import enum +import math from functools import partial, reduce import types from typing import Any @@ -37,6 +38,7 @@ from jax._src.traceback_util import api_boundary from jax._src import tree_util from jax._src import typing as jax_typing +from jax._src.mesh import get_abstract_mesh from jax._src.frozen_dict import FrozenDict from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -46,6 +48,7 @@ from jax._src.pallas import hlo_interpreter from jax._src.pallas import primitives from jax._src.state import discharge as state_discharge +from jax._src.shard_map import shard_map, P, _as_manual_mesh from jax._src.state import types as state_types from jax._src.util import ( safe_map, @@ -115,13 +118,26 @@ def _pallas_call_abstract_eval( raise ValueError(f"input pinned buffers without input_output_aliases:" f"{missing}") outin_aliases = {out_idx: in_idx for in_idx, out_idx in inout_aliases.items()} + # Make sure we don't return ShapedArrayWithMemorySpace to the outside world. out_avals = [jax_core.ShapedArray(a.shape, a.dtype, a.weak_type, sharding=a.sharding) if isinstance(a, pallas_core.ShapedArrayWithMemorySpace) else avals[outin_aliases[out_idx]] if out_idx in outin_aliases else a for out_idx, a in enumerate(out_avals)] - - # Make sure we don't return ShapedArrayWithMemorySpace to the outside world. + # TODO(mattjj,yashkatariya): if we hide vmapped away mesh axes, use this: + # if not (all(a.sharding.mesh.are_all_axes_manual for a in avals) and + # all(a.sharding.mesh.are_all_axes_manual for a in out_avals) and + # get_abstract_mesh().are_all_axes_manual): + # raise ValueError("pallas_call requires all mesh axes to be Manual, " + # f"got {get_abstract_mesh().axis_types}") + + # NOTE(mattjj,yashkatariya): this doesn't catch auto-mode non-manual axes + if not (all(p is None for a in avals if isinstance(a, jax_core.ShapedArray) + for p in a.sharding.spec) and + all(p is None for a in out_avals if isinstance(a, jax_core.ShapedArray) + for p in a.sharding.spec)): + raise ValueError("pallas_call requires all mesh axes to be Manual, " + f"got {get_abstract_mesh().axis_types}") return out_avals, effs @@ -516,6 +532,7 @@ def body(batch_index: jax_typing.Array, state: list[jax_typing.Array]) -> list[j def _pallas_call_batching_rule( + axis_data, args, dims, *, @@ -544,11 +561,21 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) - axis_size, = {x.shape[d] for i, (x, d) in enumerate(zip(args, dims)) - if d is not batching.not_mapped} + # this is the _global_ axis size if axis_data.explicit_mesh_axis is not None + # we want to convert it to the local axis size + axis_size = axis_data.size + ema = axis_data.explicit_mesh_axis + abs_mesh = get_abstract_mesh() + if ema: + mesh_size = math.prod(abs_mesh.shape[i] for i in ema) + axis_size, ragged = divmod(axis_size, mesh_size) + assert not ragged + if axis_size == 1: # Why are we even vmapping? args = map(_maybe_squeeze_out_bdim, args, dims) + if ema: + raise NotImplementedError() out = pallas_call_p.bind( *args, jaxpr=jaxpr, @@ -584,6 +611,8 @@ def _maybe_squeeze_out_bdim( elif any(bdim is not batching.not_mapped for bdim in dynamic_grid_dims): # TODO(amagni, sharadmv): Explore possibility of batching dynamic grid # bounds. + if ema: + raise NotImplementedError() return _batch_with_explicit_loop( args=dynamic_grid_args + args, dims=dynamic_grid_dims + dims, @@ -621,6 +650,8 @@ def _maybe_squeeze_out_bdim( else: # TODO(amagni,sharadmv,apaszke): enable efficient batching over # prefetched scalar args. + if ema: + raise NotImplementedError() return _batch_with_explicit_loop( args=scalar_args + args, dims=scalar_bdims + bdims, @@ -705,31 +736,43 @@ def _maybe_squeeze_out_bdim( batched_out_avals = [] for aval in out_avals: - sharding = aval.sharding.update(spec=tuple_insert(aval.sharding.spec, 0, None)) + manual_mesh = (_as_manual_mesh(aval.sharding.mesh, ema) if ema else + aval.sharding.mesh) + sharding = aval.sharding.update( + mesh=manual_mesh, spec=tuple_insert(aval.sharding.spec, 0, None)) shape = tuple_insert(aval.shape, 0, axis_size) batched_out_avals.append(aval.update(shape=shape, sharding=sharding)) batched_out_avals = tuple(batched_out_avals) - out = pallas_call_p.bind( - *dynamic_grid_args, - *args, - jaxpr=jaxpr, - grid_mapping=batched_grid_mapping, - mesh=mesh, - input_output_aliases=input_output_aliases, - debug=debug, - interpret=interpret, - compiler_params=compiler_params, - cost_estimate=batched_cost_estimate, - out_avals=batched_out_avals, - backend=backend, - metadata=metadata, - name=name, - ) + bind = partial( + pallas_call_p.bind, jaxpr=jaxpr, grid_mapping=batched_grid_mapping, + mesh=mesh, input_output_aliases=input_output_aliases, debug=debug, + interpret=interpret, compiler_params=compiler_params, + cost_estimate=batched_cost_estimate, out_avals=batched_out_avals, + backend=backend, metadata=metadata, name=name) + + if ema: + # TODO all batching rules should probably be in outer mesh ctx + bind = remove_explicit(ema)(shard_map( + bind, out_specs=P(ema), axis_names=set(ema))) + + out = bind(*dynamic_grid_args, *args) return out, (0,) * len(out) +batching.fancy_primitive_batchers[pallas_call_p] = _pallas_call_batching_rule +batching.skippable_batchers[pallas_call_p] = lambda _: () + -batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule +@contextlib.contextmanager +def remove_explicit(ema): + prev = jax_core.trace_ctx.axis_env + # assert set(prev.explicit_mesh_axis_names) == set(ema) + new = jax_core.AxisEnv(prev.axis_sizes, prev.spmd_axis_names, set()) + try: + jax_core.trace_ctx.set_axis_env(new) + yield + finally: + jax_core.trace_ctx.set_axis_env(prev) def checkify_pallas_kernel_body_jaxpr( diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 2a6773f28c80..c7c2ceba1f0e 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -337,8 +337,9 @@ def _tpu_custom_call_lowering( result_types = [mlir.aval_to_ir_type(aval) for aval in out_avals] axis_context = ctx.module_context.axis_context if isinstance(axis_context, sharding_impls.SPMDAxisContext): + manual_axes = axis_context.manual_axes | set(axis_context.mesh.manual_axes) if (axis_context.manual_axes and - axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)): + manual_axes != frozenset(axis_context.mesh.axis_names)): raise NotImplementedError( "Mosaic kernels cannot be automatically partitioned. Please wrap the" " call in a shard_map." diff --git a/tests/array_test.py b/tests/array_test.py index 50447cf7a3d4..1404b9416321 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1400,11 +1400,8 @@ def test_pspec_mix_axis_types(self): out = aval.update(sharding=NamedSharding(mesh, P(('a', 'b'), 'c', 'd'))) self.assertEqual(out.sharding.spec, P(('a', 'b'), None, None, None)) - with self.assertRaisesRegex( - ValueError, - 'Tuple subset of `PartitionSpec` cannot contain `Manual` mixed with' - ' `Auto` or `Explicit`'): - aval.update(sharding=NamedSharding(mesh, P(('a', 'd'), 'b', 'c'))) + out = aval.update(sharding=NamedSharding(mesh, P(('a', 'd'), 'b', 'c'))) + self.assertEqual(out.sharding.spec, P('a', 'b', None, None)) def test_aval_str_short(self): mesh = AbstractMesh( diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 249898952e85..a698c73afd2b 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -4777,6 +4777,22 @@ def f(x): out_g = jax.jit(jax.grad(lambda x: f(x).sum()))(arr) self.assertEqual(out_g.sharding, NamedSharding(mesh, P('x'))) + @jtu.with_explicit_mesh((2, 2), ('x', 'y')) + def test_mix_manual_explicit_partial(self, mesh): + arr = jax.device_put(np.arange(16).reshape(8, 2), P(('x', 'y'), None)) + + @jax.jit + @jax.shard_map(out_specs=P('x'), axis_names={'x'}) + def f(x): + self.assertEqual(x.shape, (4, 2)) + self.assertEqual(x.aval.sharding.spec, P('y', None)) + self.assertEqual(x.aval.vma, {'x'}) + return x * 2 + + out = f(arr) + self.assertEqual(out.sharding, NamedSharding(mesh, P(('x', 'y'), None))) + self.assertArraysEqual(out, arr * 2) + class FunSpec(NamedTuple): name: str From 2c96bd30588f564a884b7525b6adaf38c425feeb Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 18 Dec 2025 11:50:47 -0800 Subject: [PATCH 273/315] Add jax_disable_bwd_checks to disable bwd pass checks Co-authored-by: Matthew Johnson PiperOrigin-RevId: 846358307 --- jax/_src/config.py | 7 +++++++ jax/_src/custom_derivatives.py | 3 ++- jax/_src/interpreters/ad.py | 2 ++ tests/custom_api_test.py | 17 +++++++++++++++++ 4 files changed, 28 insertions(+), 1 deletion(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 85d563157939..ad4fec0f3c91 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1832,6 +1832,13 @@ def _validate_default_device(val): upgrade=True, help='Lower refs to pinned buffers in HLO.') +# TODO(mattjj, yashkatariya): remove once we land box plumbing +disable_bwd_checks = bool_state( + name='jax_disable_bwd_checks', + default=False, + upgrade=True, + help='Disables all bwd pass checks') + xla_runtime_errors = bool_state( name='jax_experimental_unsafe_xla_runtime_errors', default=False, diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index c7e1ae016d07..94528250d67d 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -956,7 +956,8 @@ def append(x, d): raise ValueError(msg) results.append(Zero(ct.aval)) else: - if (not core.typecompat(a.to_cotangent_aval(), a_ := core.get_aval(ct)) + if (not config.disable_bwd_checks.value and + not core.typecompat(a.to_cotangent_aval(), a_ := core.get_aval(ct)) and not _ref_typecompat(a.to_cotangent_aval(), a_) and not _temporary_dtype_exception(a.to_cotangent_aval(), a_)): msg = ("Custom VJP bwd rule must produce an output with the same " diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index d0f625ca5784..5d3af4d9ced8 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -457,6 +457,8 @@ def freeze(self): return self.val def ct_check(primal, ct): + if config.disable_bwd_checks.value: + return ct_aval = ct.aval if type(ct) is Zero else typeof(ct) ct_aval_expected = primal.aval.to_cotangent_aval() # type: ignore if not core.typematch(ct_aval, ct_aval_expected, only_shape_shd_check=True): diff --git a/tests/custom_api_test.py b/tests/custom_api_test.py index 143fe6ce01a1..d2be4a4762f5 100644 --- a/tests/custom_api_test.py +++ b/tests/custom_api_test.py @@ -2971,6 +2971,23 @@ def foo_bwd(_, g): r'output\[1\] the bwd rule produced an output of type float..\[3\]'): jax.grad(lambda x, y: foo(x, y * y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + def test_bwd_rule_shape_mismatch_disable(self): + # TODO(mattjj): remove this test when the config option is removed + @jax.custom_vjp + def foo(x, y): + return x + + def foo_fwd(x, y): + return x, None + + def foo_bwd(_, g): + return jnp.zeros(3), jnp.zeros(3) + + foo.defvjp(foo_fwd, foo_bwd) + + with config.disable_bwd_checks(True): + jax.grad(lambda x, y: foo(x, y).sum(), 1)(jnp.ones(3), jnp.ones(4)) + def test_bwd_rule_can_produce_list_or_tuple(self): @jax.custom_vjp def f(x, y): From b4a983c9024e606b9c95ae2017bada5e22987197 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Thu, 18 Dec 2025 20:22:06 +0000 Subject: [PATCH 274/315] Pin the NumPy version for the type/lint presubmit more strictly. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 638004b2e178..6e619a81854e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: - id: mypy files: (jax/|tests/typing_test\.py) exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jax/nn/__init__.py|jaxlib/_jax/.* # Use pyi instead - additional_dependencies: [types-requests==2.31.0, numpy>=2.2.0, scipy-stubs] + additional_dependencies: [types-requests==2.31.0, numpy~=2.3.0, scipy-stubs] args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext From 8887722b6651b5ec8b9919832e03c90ad1d0f590 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Thu, 18 Dec 2025 21:35:54 +0000 Subject: [PATCH 275/315] Skip PickleTest.testPickleOfKeyArray0 on python3.11. Failing with the newly built jaxlib/jax 0.8.2. --- tests/pickle_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pickle_test.py b/tests/pickle_test.py index c1466c09058f..3f0f2c70de52 100644 --- a/tests/pickle_test.py +++ b/tests/pickle_test.py @@ -14,6 +14,7 @@ import copy import pickle +import sys import unittest from absl.testing import absltest @@ -124,6 +125,8 @@ def testPickleOfArrayWeakType(self): self.assertIsInstance(y, type(x)) self.assertEqual(x.aval, y.aval) + @unittest.skipIf(sys.version_info[:2] == (3, 11), + "cannot pickle: b/470129766") @jtu.sample_product(prng_name=['threefry2x32', 'rbg', 'unsafe_rbg']) def testPickleOfKeyArray(self, prng_name): with jax.default_prng_impl(prng_name): From 45c0a4c3156b2ae220382d49eaea8863698212e3 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Thu, 18 Dec 2025 15:54:49 -0800 Subject: [PATCH 276/315] Suppress FutureWarning from TensorFlow imports regarding `np.object`. Due to Keras updating 3.12.0 -> 3.13.0, this warning started popping up a few tests that import TensorFlow (interoperability tests), for some reason. PiperOrigin-RevId: 846446873 --- jax/experimental/jax2tf/tests/jax2tf_test.py | 8 +++++++- .../tests/multiprocess/jax2tf_multiprocess_test.py | 8 +++++++- jax/experimental/jax2tf/tests/sharding_test.py | 8 +++++++- tests/array_interoperability_test.py | 9 ++++++++- 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index e3c80e865b12..80e44f8938c1 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -19,6 +19,7 @@ import os import re import unittest +import warnings from absl import logging from absl.testing import absltest, parameterized @@ -40,7 +41,12 @@ import numpy as np try: - import tensorflow as tf + # TODO(b/470156950): Remove this once a proper fix is in place + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=FutureWarning, + message=".*np.object.*") + import tensorflow as tf from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util JaxToTfTestCase = tf_test_util.JaxToTfTestCase diff --git a/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py b/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py index 2576a277ab83..fa4861f55d92 100644 --- a/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py +++ b/jax/experimental/jax2tf/tests/multiprocess/jax2tf_multiprocess_test.py @@ -22,9 +22,15 @@ from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P import unittest +import warnings try: - import tensorflow as tf + # TODO(b/470156950): Remove this once a proper fix is in place + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=FutureWarning, + message=".*np.object.*") + import tensorflow as tf from jax.experimental import jax2tf from jax.experimental.jax2tf.tests import tf_test_util JaxToTfTestCase = tf_test_util.JaxToTfTestCase diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 0167c3c45ea8..f0bc0ffa78d5 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -24,6 +24,7 @@ import re from typing import Any import unittest +import warnings from absl import app from absl.testing import absltest @@ -46,7 +47,12 @@ import numpy as np -import tensorflow as tf +# TODO(b/470156950): Remove this once a proper fix is in place +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=FutureWarning, + message=".*np.object.*") + import tensorflow as tf config.parse_flags_with_absl() jtu.request_cpu_devices(8) diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index e28e0f11461e..6960ab84c765 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +import warnings from absl.testing import absltest import numpy as np @@ -34,7 +35,13 @@ cupy = None try: - import tensorflow as tf + # TODO(b/470156950): Remove this once a proper fix is in place + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + category=FutureWarning, + message=".*np.object.*") + import tensorflow as tf + tf_version = tuple( int(x) for x in tf.version.VERSION.split("-")[0].split(".")) except ImportError: From d1673991f8b97fdad1b23f5b64f2375694fd7a31 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 18 Dec 2025 17:07:09 -0800 Subject: [PATCH 277/315] anselm refs PiperOrigin-RevId: 846470902 --- jax/_src/ad_checkpoint.py | 1 + jax/_src/state/primitives.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 5976aa192e9c..62002cd8eb51 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -735,6 +735,7 @@ def remat_transpose(out_cts, *args, jaxpr, prevent_cse, **params): # TODO(mattjj): avoid round-tripping into UndefinedPrimals args_ = [ad.UndefinedPrimal(x.aval) if isinstance(x, ad.ValAccum) else x for x in args] + if any(isinstance(x, ad.GradAccum) for x in args_): raise NotImplementedError assert not jaxpr.constvars in_linear = [ad.is_undefined_primal(x) for x in args_] diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index d3b7e6fae71b..42dfc614298f 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -1115,9 +1115,11 @@ def _ref_lin(nzs, x, *, memory_space, kind): nz, = nzs x_ref = core.ref_p.bind(x, memory_space=memory_space, kind=kind) def mut_lin(_, x_dot): + if kind == 'anselm_ref': + return ad.Zero(AbstractRef(core.typeof(x_dot))) zero = ad_util.instantiate(x_dot) return core.ref_p.bind(zero, memory_space=memory_space, kind=kind) - return x_ref, True, None, mut_lin + return x_ref, kind != 'anselm_ref', None, mut_lin ad.primitive_jvps[core.ref_p] = _ref_jvp ad.primitive_linearizations[core.ref_p] = _ref_lin From f53e1f3ada032d815aae24295672c610322eab3f Mon Sep 17 00:00:00 2001 From: Hyeontaek Lim Date: Thu, 18 Dec 2025 17:11:46 -0800 Subject: [PATCH 278/315] [JAX] Refresh a custom layout if a buffer is copied across clients or memories This change retrieves a new custom layout if a buffer is copied across PjRt clients or memory spaces because the detail of the buffer layout often changes (e.g., tiling is added if a buffer is copied from a CPU client to a TPU client). Without this change, the newly added test would fail with an error: `AssertionError: Layou[14 chars]or=(1, 0), tiling=(), sub_byte_element_size_in_bits=0) != Layou[14 chars]or=(1, 0), tiling=((8, 128),), sub_byte_element_size_in_bits=0)` PiperOrigin-RevId: 846472304 --- tests/layout_test.py | 60 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/layout_test.py b/tests/layout_test.py index 154fd53f3898..a6ef84864eea 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -16,6 +16,7 @@ from functools import partial from absl.testing import absltest +from absl.testing import parameterized import numpy as np import jax @@ -720,6 +721,65 @@ def f(x): self.assertEqual(out.format, l) self.assertEqual(out.sharding, s) + def test_valid_custom_layout_after_copy_across_clients(self): + if jax._src.lib.ifrt_version < 45: + self.skipTest('Only works for JAX_IFRT_VERSION_NUMBER >= 45') + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + + custom_dll = Layout(major_to_minor=(1, 0)) + + cpu_sharding = jax.sharding.SingleDeviceSharding( + jax.local_devices(backend='cpu')[0]) + cpu_format = Format(custom_dll, cpu_sharding) + cpu_array = jax.device_put(np.ones((128, 8)), cpu_format) + + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + tpu_sharding = jax.sharding.NamedSharding(mesh, P()) + tpu_format = Format(custom_dll, tpu_sharding) + + copied_tpu_array = jax.device_put(cpu_array, tpu_format.sharding) + canonical_tpu_array = jax.device_put(np.ones((128, 8)), tpu_format) + self.assertEqual( + copied_tpu_array.format.layout, canonical_tpu_array.format.layout) + + @parameterized.named_parameters( + ('device_to_pinned_host', 'device', 'pinned_host'), + ('pinned_host_to_device', 'pinned_host', 'device'), + ('device_to_unpinned_host', 'device', 'unpinned_host'), + ('unpinned_host_to_device', 'unpinned_host', 'device'), + ('pinned_host_to_unpinned_host', 'pinned_host', 'unpinned_host'), + ('unpinned_host_to_pinned_host', 'unpinned_host', 'pinned_host'), + ) + def test_valid_layout_after_copy_across_memories( + self, src_memory_kind, dst_memory_kind): + if not jtu.test_device_matches(['tpu']): + self.skipTest('Only works for TPU') + custom_dll = Layout(major_to_minor=(1, 0)) + + mesh = jtu.create_mesh((1, 1), ('x', 'y')) + src_tpu_sharding = jax.sharding.NamedSharding( + mesh, P(), memory_kind=src_memory_kind) + dst_tpu_sharding = jax.sharding.NamedSharding( + mesh, P(), memory_kind=dst_memory_kind) + + # TPU unpinned_host memories do not support custom layouts. + if src_memory_kind == 'unpinned_host': + src_tpu_format = src_tpu_sharding + else: + src_tpu_format = Format(custom_dll, src_tpu_sharding) + if dst_memory_kind == 'unpinned_host': + dst_tpu_format = dst_tpu_sharding + else: + dst_tpu_format = Format(custom_dll, dst_tpu_sharding) + + tpu_array = jax.device_put(np.ones((128, 8)), src_tpu_format) + + copied_tpu_array = jax.device_put(tpu_array, dst_tpu_sharding) + canonical_tpu_array = jax.device_put(np.ones((128, 8)), dst_tpu_format) + self.assertEqual( + copied_tpu_array.format.layout, canonical_tpu_array.format.layout) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) From 662f93cf02e54814e0d455eeb23e1065ce09b65e Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Thu, 18 Dec 2025 17:20:45 -0800 Subject: [PATCH 279/315] Update the requirements' lock files post the JAX 0.8.2 release. PiperOrigin-RevId: 846474763 --- build/requirements.in | 2 +- build/requirements_lock_3_11.txt | 446 ++++++++++++++------------- build/requirements_lock_3_12.txt | 452 ++++++++++++++-------------- build/requirements_lock_3_13.txt | 452 ++++++++++++++-------------- build/requirements_lock_3_13_ft.txt | 370 +++++++++++------------ build/requirements_lock_3_14.txt | 384 +++++++++++------------ build/requirements_lock_3_14_ft.txt | 336 ++++++++++----------- 7 files changed, 1213 insertions(+), 1229 deletions(-) diff --git a/build/requirements.in b/build/requirements.in index 252bdb4be8ce..e37b6cec1659 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -18,7 +18,7 @@ wheel # the requirements files. jaxlib -jax-cuda12-plugin; sys_platform == "linux" and python_version<"3.14" +jax-cuda12-plugin; sys_platform == "linux" jax-cuda13-plugin jax-cuda12-pjrt; sys_platform == "linux" jax-cuda13-pjrt diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 640a3462158e..6e9a776a38c5 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -243,9 +243,9 @@ execnet==2.1.2 \ --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.20.0 \ - --hash=sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2 \ - --hash=sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4 +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt flatbuffers==25.9.23 \ --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ @@ -253,65 +253,57 @@ flatbuffers==25.9.23 \ # via # -r build/test-requirements.txt # tensorflow -fonttools==4.60.1 \ - --hash=sha256:022beaea4b73a70295b688f817ddc24ed3e3418b5036ffcd5658141184ef0d0c \ - --hash=sha256:026290e4ec76583881763fac284aca67365e0be9f13a7fb137257096114cb3bc \ - --hash=sha256:0b0835ed15dd5b40d726bb61c846a688f5b4ce2208ec68779bc81860adb5851a \ - --hash=sha256:0eae96373e4b7c9e45d099d7a523444e3554360927225c1cdae221a58a45b856 \ - --hash=sha256:122e1a8ada290423c493491d002f622b1992b1ab0b488c68e31c413390dc7eb2 \ - --hash=sha256:1410155d0e764a4615774e5c2c6fc516259fe3eca5882f034eb9bfdbee056259 \ - --hash=sha256:145daa14bf24824b677b9357c5e44fd8895c2a8f53596e1b9ea3496081dc692c \ - --hash=sha256:1525796c3ffe27bb6268ed2a1bb0dcf214d561dfaf04728abf01489eb5339dce \ - --hash=sha256:154cb6ee417e417bf5f7c42fe25858c9140c26f647c7347c06f0cc2d47eff003 \ - --hash=sha256:2299df884c11162617a66b7c316957d74a18e3758c0274762d2cc87df7bc0272 \ - --hash=sha256:2409d5fb7b55fd70f715e6d34e7a6e4f7511b8ad29a49d6df225ee76da76dd77 \ - --hash=sha256:268ecda8ca6cb5c4f044b1fb9b3b376e8cd1b361cef275082429dc4174907038 \ - --hash=sha256:282dafa55f9659e8999110bd8ed422ebe1c8aecd0dc396550b038e6c9a08b8ea \ - --hash=sha256:2ee06fc57512144d8b0445194c2da9f190f61ad51e230f14836286470c99f854 \ - --hash=sha256:3630e86c484263eaac71d117085d509cbcf7b18f677906824e4bace598fb70d2 \ - --hash=sha256:398447f3d8c0c786cbf1209711e79080a40761eb44b27cdafffb48f52bcec258 \ - --hash=sha256:4ba4bd646e86de16160f0fb72e31c3b9b7d0721c3e5b26b9fa2fc931dfdb2652 \ - --hash=sha256:5664fd1a9ea7f244487ac8f10340c4e37664675e8667d6fee420766e0fb3cf08 \ - --hash=sha256:583b7f8e3c49486e4d489ad1deacfb8d5be54a8ef34d6df824f6a171f8511d99 \ - --hash=sha256:596ecaca36367027d525b3b426d8a8208169d09edcf8c7506aceb3a38bfb55c7 \ - --hash=sha256:5c1015318e4fec75dd4943ad5f6a206d9727adf97410d58b7e32ab644a807914 \ - --hash=sha256:66929e2ea2810c6533a5184f938502cfdaea4bc3efb7130d8cc02e1c1b4108d6 \ - --hash=sha256:6ec722ee589e89a89f5b7574f5c45604030aa6ae24cb2c751e2707193b466fed \ - --hash=sha256:6f68576bb4bbf6060c7ab047b1574a1ebe5c50a17de62830079967b211059ebb \ - --hash=sha256:7473a8ed9ed09aeaa191301244a5a9dbe46fe0bf54f9d6cd21d83044c3321217 \ - --hash=sha256:7b0c6d57ab00dae9529f3faf187f2254ea0aa1e04215cf2f1a8ec277c96661bc \ - --hash=sha256:7b4c32e232a71f63a5d00259ca3d88345ce2a43295bb049d21061f338124246f \ - --hash=sha256:8177ec9676ea6e1793c8a084a90b65a9f778771998eb919d05db6d4b1c0b114c \ - --hash=sha256:839565cbf14645952d933853e8ade66a463684ed6ed6c9345d0faf1f0e868877 \ - --hash=sha256:875cb7764708b3132637f6c5fb385b16eeba0f7ac9fa45a69d35e09b47045801 \ - --hash=sha256:8a44788d9d91df72d1a5eac49b31aeb887a5f4aab761b4cffc4196c74907ea85 \ - --hash=sha256:8b4eb332f9501cb1cd3d4d099374a1e1306783ff95489a1026bde9eb02ccc34a \ - --hash=sha256:906306ac7afe2156fcf0042173d6ebbb05416af70f6b370967b47f8f00103bbb \ - --hash=sha256:992775c9fbe2cf794786fa0ffca7f09f564ba3499b8fe9f2f80bd7197db60383 \ - --hash=sha256:996a4d1834524adbb423385d5a629b868ef9d774670856c63c9a0408a3063401 \ - --hash=sha256:9a52f254ce051e196b8fe2af4634c2d2f02c981756c6464dc192f1b6050b4e28 \ - --hash=sha256:9d0ced62b59e0430b3690dbc5373df1c2aa7585e9a8ce38eff87f0fd993c5b01 \ - --hash=sha256:a140761c4ff63d0cb9256ac752f230460ee225ccef4ad8f68affc723c88e2036 \ - --hash=sha256:a184b2ea57b13680ab6d5fbde99ccef152c95c06746cb7718c583abd8f945ccc \ - --hash=sha256:a3db56f153bd4c5c2b619ab02c5db5192e222150ce5a1bc10f16164714bc39ac \ - --hash=sha256:a46b2f450bc79e06ef3b6394f0c68660529ed51692606ad7f953fc2e448bc903 \ - --hash=sha256:a884aef09d45ba1206712c7dbda5829562d3fea7726935d3289d343232ecb0d3 \ - --hash=sha256:b2cf105cee600d2de04ca3cfa1f74f1127f8455b71dbad02b9da6ec266e116d6 \ - --hash=sha256:b33a7884fabd72bdf5f910d0cf46be50dce86a0362a65cfc746a4168c67eb96c \ - --hash=sha256:b42d86938e8dda1cd9a1a87a6d82f1818eaf933348429653559a458d027446da \ - --hash=sha256:b6379e7546ba4ae4b18f8ae2b9bc5960936007a1c0e30b342f662577e8bc3299 \ - --hash=sha256:c7420a2696a44650120cdd269a5d2e56a477e2bfa9d95e86229059beb1c19e15 \ - --hash=sha256:c8651e0d4b3bdeda6602b85fdc2abbefc1b41e573ecb37b6779c4ca50753a199 \ - --hash=sha256:d066ea419f719ed87bc2c99a4a4bfd77c2e5949cb724588b9dd58f3fd90b92bf \ - --hash=sha256:e6c58beb17380f7c2ea181ea11e7db8c0ceb474c9dd45f48e71e2cb577d146a1 \ - --hash=sha256:e852d9dda9f93ad3651ae1e3bb770eac544ec93c3807888798eccddf84596537 \ - --hash=sha256:ec3681a0cb34c255d76dd9d865a55f260164adb9fa02628415cdc2d43ee2c05d \ - --hash=sha256:ee0c0b3b35b34f782afc673d503167157094a16f442ace7c6c5e0ca80b08f50c \ - --hash=sha256:eedacb5c5d22b7097482fa834bda0dafa3d914a4e829ec83cdea2a01f8c813c4 \ - --hash=sha256:ef00af0439ebfee806b25f24c8f92109157ff3fac5731dc7867957812e87b8d9 \ - --hash=sha256:f0e8817c7d1a0c2eedebf57ef9a9896f3ea23324769a9a2061a80fe8852705ed \ - --hash=sha256:f3d5be054c461d6a2268831f04091dc82753176f6ea06dc6047a5e168265a987 \ - --hash=sha256:f4b5c37a5f40e4d733d3bbaaef082149bee5a5ea3156a785ff64d949bd1353fa +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib fsspec==2025.10.0 \ --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ @@ -451,69 +443,69 @@ iniconfig==2.3.0 \ --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -jax-cuda12-pjrt==0.8.1 ; sys_platform == "linux" \ - --hash=sha256:452b70ee10cb9ac5d7dfca55ffbcdb89b6c8bc6ba70a45af7c490d1dcea98eb7 \ - --hash=sha256:a631d0689903354afd7b3d2ec595b7da06a6230a76da00ff9548f542b21b6250 +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin==0.8.1 ; sys_platform == "linux" and python_version < "3.14" \ - --hash=sha256:1052c29157c99ca01d74d3073bbf4f711eb94465c0b4f5a4322d5e46233b1b2f \ - --hash=sha256:13311b72ca703a1bbad1ec516ac9ef750019a2d2c421d4c1daf8acf2720b822e \ - --hash=sha256:385001f56f852959f061ae15ad157c39cc4471c8d1d2544dfc3f805684ac2213 \ - --hash=sha256:479ca438d555024dac8dd1058371efbf8f479819a70aea213f3f3037ece99d74 \ - --hash=sha256:5a154723cb6c4e1e7969581a923dacf378f7515b0d53b5f1920e25e51cf6cecc \ - --hash=sha256:6a4b6fda687ca8361322029d58444bc0326798204806a3f90f231dc8ca5541a5 \ - --hash=sha256:7342c8810cc947de78f28c7287a30b2e201b0f51578543dd2553692b79a49942 \ - --hash=sha256:836eb0cd3af612d17bf17efc7eee175c6b9827989d5370df8ba919947fcb67cf \ - --hash=sha256:9968c15b87fd3867b6da0ce30681673a7fc4eedebaadcd24dce892e3f9fe1a52 \ - --hash=sha256:b3383bdc0b9f6260d9adc4ca0d1f68bf241158dfe69d726b267b0681382ea7a7 \ - --hash=sha256:b60bf0bbda24cec6fa71170bd69b613359f01a376d8e09fe34bf67ecc9a3164f \ - --hash=sha256:da7c0f2ef1c697f9ade51a71cfad211e2bff25407a6855dddde372c0190fc468 +jax-cuda12-plugin==0.8.2 ; sys_platform == "linux" and python_version < "3.14" \ + --hash=sha256:0b0a3304ce7e494acd8d9c593490c112a32cdb6010fe1afc584d9e41fd863167 \ + --hash=sha256:1b4828242d57f233b394d17ebaa599c503c1fb9b7c754012a06eb84dbc935fc8 \ + --hash=sha256:20165861b3d3e66ebb2c0f63a547d1d5ee17ea44ac3be7153c7908c9ca8c88f3 \ + --hash=sha256:377e4be17e22dde0343b3f3c05bf69235b3dbf11d766cca9c5a93da47971dcb7 \ + --hash=sha256:403d5e07731b5cdac3bd9fb3f448bd8480062cb2c0ab61ea2ad23fcd0a65479a \ + --hash=sha256:58c51473fc622e03138035985f741833564d70a4bd5a2178f61b62cdaa32ff94 \ + --hash=sha256:637387dc3408cd204562668502f9e95f76c6edde0a6d2e48f055162dc2aebf0d \ + --hash=sha256:70d33222484ad5c375b8f8357b7c23cacb844f6ecfc39567f8dd47fde6e87858 \ + --hash=sha256:82c6798be66bf8c773386918e4c8e5cd8119753f3bfb3ca4bbc46818283750c6 \ + --hash=sha256:a5898bac1d8ab6020b54546440256409f2c66bcbbb3a1099ca473c84843addad \ + --hash=sha256:d68a6d8b4a45ee561746bac7a6468da8203832626b0b39ad4ac43011f61f875d \ + --hash=sha256:dd4f7c34d4512ff5a36fd1b01584ef7781cad615e3f9e71880eae2f4998e5108 # via -r build/requirements.in -jax-cuda13-pjrt==0.8.1 \ - --hash=sha256:86a6926da76aebf6080922747a7a98d321f4ca27101077357fa148032bc3cd1d \ - --hash=sha256:f3b1c1c7118b4570f72740ed756cbed289a3f8fa813570a0dbf16f186bccb8c9 +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 # via # -r build/requirements.in # jax-cuda13-plugin -jax-cuda13-plugin==0.8.1 \ - --hash=sha256:07625aed1aa769c701213e84d6b2a46902019a1d2af8a09ce6dfd9575163bfc6 \ - --hash=sha256:0d503c312d2daefea62a00c74534579deeacd46e15c364074d27a8d95a100032 \ - --hash=sha256:12a7aac712a7c6dc228ef9991578e85e3bcab7c324193bdfb2b5acf059bae6d6 \ - --hash=sha256:16ee16b13393baf9672b6612566308675cebdc8d785b61fac2b93ce8c97825ff \ - --hash=sha256:4e589ed8197f1bea7e7fd20d866ccc5c2a1276d7acd02224e3a5b07983df61e2 \ - --hash=sha256:64df1f1414d899ab7a84751d6f78515365555b54fb64b3e318bd70519de99c86 \ - --hash=sha256:7a373fd3e5f11ecad01b8add1e277eb6559b4966b0745d92dc91c585579fac35 \ - --hash=sha256:92238530152890c3405addacd1fc021c87022cbf99fa66418cfa2e9f68a5c49d \ - --hash=sha256:a4c5a4a69346be6520c729675d5d80e85d610399f4840d74bdfae9c6ebedc8bc \ - --hash=sha256:af33f737ccf5426155cf5c7d175bf765ca25724b94af5109ef2df891b410f997 \ - --hash=sha256:d81222989fb30496cc6554d42deaca6ab003721f56ff669da7f80d61fea1219d \ - --hash=sha256:ef218c47b2cde8c700ad2a56d04320f9d1490439fc6db20747f56e91de7289c2 +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af # via -r build/requirements.in -jaxlib==0.8.1 \ - --hash=sha256:117f2fe2c19479e560ad85a3ef2fcc0b1d24816456f0d039f865c2acbab63b5a \ - --hash=sha256:1a4001ed3ba9ed5a812da1b16f52eebb5d473a4480c1523828c7bd3dae8d1375 \ - --hash=sha256:1bc76edec2bc74a7adb5e29329ece51a67c57cd011a06d55d07da62fbabe3389 \ - --hash=sha256:22f489fb5c8be0da7be5e4957a10936b3760a169668f8b25c5d09c51c3ef47f6 \ - --hash=sha256:24ec3f3a9c45d6de060020dc94c444d69e18099fab927ea3979ff8cedf0ed2c9 \ - --hash=sha256:4933298fcfb07a5aa2d1fed21c111d07cea50e6f180dba2cdb5463c13fb98f2f \ - --hash=sha256:63fc25c4b5d03256798796a024125e29bcf254acc3eae5dc3239d1c30b86b866 \ - --hash=sha256:7a5d381fad89622750fae29fab83c0847e2931ad8d6a34dc13b28fc4d67f75a3 \ - --hash=sha256:865add56139883405f3f15c9b0de6a64ab8f4aa549dff196b72dbc86be6ccc1f \ - --hash=sha256:88bde0f535eeea6689e0cd57d40b7660d5206ac95c7d42e09562a109b963a49f \ - --hash=sha256:8e118e1fbe714f37a94ba26777c17faab7dca4a33646a3d98cd1d99673bbd6b1 \ - --hash=sha256:90e48973f8dbded7edc8728be84c01ae00412190187fb06622abfa4edd42c0a8 \ - --hash=sha256:92c41c9b9862c08521eb90515a7c5bcc840c6d30f86230cebf94aea2d6a0af81 \ - --hash=sha256:a0349f6e8179dc897d33aeb90ec66b4a8041330fbbba8d071dc6167cd2271539 \ - --hash=sha256:af4924189fc53b69237715b56ebcbfc71bb91ca16184143dcef0d430c8173de6 \ - --hash=sha256:bd697c171ace1e2e9d6ed910a78f385b3c4095cee290b0255aa58848f2acdeab \ - --hash=sha256:bed1e94ae8c7c16bca4476d8d7f582f0d1a102a4e69c3a9bd2069a0dc42274a9 \ - --hash=sha256:c14c8c19a7eb694aa14092b6d2fffb9d2bdd8a603b63d6f26fbeaf129c204f9f \ - --hash=sha256:d245bd6a279c72ca5f796df84cdd64d7c9c8abc4b8d89adf4acf45898dab958b \ - --hash=sha256:f2f11491b077d05249d63813e811401194a41edc8e9cc60af8f4b554057cfad0 \ - --hash=sha256:fdbbf2336c08bbf8f30548e204c8c9d77f8b2a3a5b7fc7985749246feb8852b0 \ - --hash=sha256:ff32b6320d729131efaf22939825b52d75957c84c32af2b0b1bdb33cf27ba75f +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de # via -r build/requirements.in keras==3.12.0 \ --hash=sha256:02b69e007d5df8042286c3bcc2a888539e3e487590ffb08f6be1b4354df50aa8 \ @@ -634,13 +626,13 @@ libclang==18.1.1 \ --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe # via tensorflow -libtpu==0.0.30 ; sys_platform == "linux" and platform_machine == "x86_64" \ - --hash=sha256:26442f0a51d243cf7259407bba8f5d849c9024297efe97044d64b5244283ad63 \ - --hash=sha256:5fabff9a041674bb889fb59ac0b5c54b9dbcf492a8c782e083ef86a8194dbb0f \ - --hash=sha256:8be30562743a63c1c1353e7ba78f0dbfbb051e8d1e9d3bb2b5da9b720363bb0a \ - --hash=sha256:b1fc44915dad56c0ceb733311a4d4396b88dc9a1c7c01acd7617da90e7ec22f2 \ - --hash=sha256:babab04ca663da2c4e4b3ab036c4d465f2f4674c480d08239c5d4965b7ce9e1c \ - --hash=sha256:f9aa040895ec25fafebcd4e1a0e1a9524ff3bd778ca88543731e308f6e516dd1 +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 # via -r build/requirements.in markdown==3.10 \ --hash=sha256:37062d4f2aa4b2b6b32aefb80faa300f82cc790cb949a35b8caede34f2b68c0e \ @@ -741,62 +733,62 @@ markupsafe==3.0.3 \ --hash=sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a \ --hash=sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50 # via werkzeug -matplotlib==3.10.7 \ - --hash=sha256:07124afcf7a6504eafcb8ce94091c5898bbdd351519a1beb5c45f7a38c67e77f \ - --hash=sha256:09d7945a70ea43bf9248f4b6582734c2fe726723204a76eca233f24cffc7ef67 \ - --hash=sha256:0d8c32b7ea6fb80b1aeff5a2ceb3fb9778e2759e899d9beff75584714afcc5ee \ - --hash=sha256:11ae579ac83cdf3fb72573bb89f70e0534de05266728740d478f0f818983c695 \ - --hash=sha256:15112bcbaef211bd663fa935ec33313b948e214454d949b723998a43357b17b0 \ - --hash=sha256:1d9d3713a237970569156cfb4de7533b7c4eacdd61789726f444f96a0d28f57f \ - --hash=sha256:1e4bbad66c177a8fdfa53972e5ef8be72a5f27e6a607cec0d8579abd0f3102b1 \ - --hash=sha256:2222c7ba2cbde7fe63032769f6eb7e83ab3227f47d997a8453377709b7fe3a5a \ - --hash=sha256:22df30ffaa89f6643206cf13877191c63a50e8f800b038bc39bee9d2d4957632 \ - --hash=sha256:31963603041634ce1a96053047b40961f7a29eb8f9a62e80cc2c0427aa1d22a2 \ - --hash=sha256:37a1fea41153dd6ee061d21ab69c9cf2cf543160b1b85d89cd3d2e2a7902ca4c \ - --hash=sha256:3886e47f64611046bc1db523a09dd0a0a6bed6081e6f90e13806dd1d1d1b5e91 \ - --hash=sha256:4645fc5d9d20ffa3a39361fcdbcec731382763b623b72627806bf251b6388866 \ - --hash=sha256:4a11c2e9e72e7de09b7b72e62f3df23317c888299c875e2b778abf1eda8c0a42 \ - --hash=sha256:4a74f79fafb2e177f240579bc83f0b60f82cc47d2f1d260f422a0627207008ca \ - --hash=sha256:4c14b6acd16cddc3569a2d515cfdd81c7a68ac5639b76548cfc1a9e48b20eb65 \ - --hash=sha256:53b492410a6cd66c7a471de6c924f6ede976e963c0f3097a3b7abfadddc67d0a \ - --hash=sha256:53cc80662dd197ece414dd5b66e07370201515a3eaf52e7c518c68c16814773b \ - --hash=sha256:5c09cf8f2793f81368f49f118b6f9f937456362bee282eac575cca7f84cda537 \ - --hash=sha256:5e38c2d581d62ee729a6e144c47a71b3f42fb4187508dbbf4fe71d5612c3433b \ - --hash=sha256:5f3f6d315dcc176ba7ca6e74c7768fb7e4cf566c49cb143f6bc257b62e634ed8 \ - --hash=sha256:6516ce375109c60ceec579e699524e9d504cd7578506f01150f7a6bc174a775e \ - --hash=sha256:667ecd5d8d37813a845053d8f5bf110b534c3c9f30e69ebd25d4701385935a6d \ - --hash=sha256:6f1851eab59ca082c95df5a500106bad73672645625e04538b3ad0f69471ffcc \ - --hash=sha256:702590829c30aada1e8cef0568ddbffa77ca747b4d6e36c6d173f66e301f89cc \ - --hash=sha256:7146d64f561498764561e9cd0ed64fcf582e570fc519e6f521e2d0cfd43365e1 \ - --hash=sha256:744991e0cc863dd669c8dc9136ca4e6e0082be2070b9d793cbd64bec872a6815 \ - --hash=sha256:786656bb13c237bbcebcd402f65f44dd61ead60ee3deb045af429d889c8dbc67 \ - --hash=sha256:7a0edb7209e21840e8361e91ea84ea676658aa93edd5f8762793dec77a4a6748 \ - --hash=sha256:7ac81eee3b7c266dd92cee1cd658407b16c57eed08c7421fa354ed68234de380 \ - --hash=sha256:90ad854c0a435da3104c01e2c6f0028d7e719b690998a2333d7218db80950722 \ - --hash=sha256:9257be2f2a03415f9105c486d304a321168e61ad450f6153d77c69504ad764bb \ - --hash=sha256:932c55d1fa7af4423422cb6a492a31cbcbdbe68fd1a9a3f545aa5e7a143b5355 \ - --hash=sha256:a06ba7e2a2ef9131c79c49e63dad355d2d878413a0376c1727c8b9335ff731c7 \ - --hash=sha256:aebed7b50aa6ac698c90f60f854b47e48cd2252b30510e7a1feddaf5a3f72cbf \ - --hash=sha256:b172db79759f5f9bc13ef1c3ef8b9ee7b37b0247f987fbbbdaa15e4f87fd46a9 \ - --hash=sha256:b3c4ea4948d93c9c29dc01c0c23eef66f2101bf75158c291b88de6525c55c3d1 \ - --hash=sha256:b498e9e4022f93de2d5a37615200ca01297ceebbb56fe4c833f46862a490f9e3 \ - --hash=sha256:b4d41379b05528091f00e1728004f9a8d7191260f3862178b88e8fd770206318 \ - --hash=sha256:b69676845a0a66f9da30e87f48be36734d6748024b525ec4710be40194282c84 \ - --hash=sha256:c17398b709a6cce3d9fdb1595c33e356d91c098cd9486cb2cc21ea2ea418e715 \ - --hash=sha256:c380371d3c23e0eadf8ebff114445b9f970aff2010198d498d4ab4c3b41eea4f \ - --hash=sha256:cb783436e47fcf82064baca52ce748af71725d0352e1d31564cbe9c95df92b9c \ - --hash=sha256:cc1c51b846aca49a5a8b44fbba6a92d583a35c64590ad9e1e950dc88940a4297 \ - --hash=sha256:d0b181e9fa8daf1d9f2d4c547527b167cb8838fc587deabca7b5c01f97199e84 \ - --hash=sha256:d2a959c640cdeecdd2ec3136e8ea0441da59bcaf58d67e9c590740addba2cb68 \ - --hash=sha256:d5f256d49fea31f40f166a5e3131235a5d2f4b7f44520b1cf0baf1ce568ccff0 \ - --hash=sha256:d883460c43e8c6b173fef244a2341f7f7c0e9725c7fe68306e8e44ed9c8fb100 \ - --hash=sha256:d8eb7194b084b12feb19142262165832fc6ee879b945491d1c3d4660748020c4 \ - --hash=sha256:d9749313deb729f08207718d29c86246beb2ea3fdba753595b55901dee5d2fd6 \ - --hash=sha256:de66744b2bb88d5cd27e80dfc2ec9f0517d0a46d204ff98fe9e5f2864eb67657 \ - --hash=sha256:e91f61a064c92c307c5a9dc8c05dc9f8a68f0a3be199d9a002a0622e13f874a1 \ - --hash=sha256:f19410b486fdd139885ace124e57f938c1e6a3210ea13dd29cab58f5d4bc12c7 \ - --hash=sha256:f79d5de970fc90cd5591f60053aecfce1fcd736e0303d9f0bf86be649fa68fb8 \ - --hash=sha256:fba2974df0bf8ce3c995fa84b79cde38326e0f7b5409e7a3a481c1141340bcf7 +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ @@ -916,9 +908,9 @@ numpy==2.0.2 ; python_version <= "3.12" \ # tensorboard # tensorflow # tensorstore -numpy-typing-compat==20250818.2.0 \ - --hash=sha256:042da86a786b6eb164f900efdfc3ba132f4371a2e44a93109976b1d7538253ed \ - --hash=sha256:3f77ba873ec9668e9b7bd15ae083cc16c82aa732b651ed2bf5aa284cdd0dc71d +numpy-typing-compat==20251206.2.0 \ + --hash=sha256:413171c4333c4175cbad4206c94e58422d291d20426c42581865380156715493 \ + --hash=sha256:7db9d5e991af03b2ade38f43253e4eb03ab88925230931bff7f559c020676fb1 # via optype nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ @@ -1298,17 +1290,17 @@ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -protobuf==6.33.1 \ - --hash=sha256:023af8449482fa884d88b4563d85e83accab54138ae098924a985bcbb734a213 \ - --hash=sha256:0f4cf01222c0d959c2b399142deb526de420be8236f22c71356e2a544e153c53 \ - --hash=sha256:8fd7d5e0eb08cd5b87fd3df49bc193f5cfd778701f47e11d127d0afc6c39f1d1 \ - --hash=sha256:923aa6d27a92bf44394f6abf7ea0500f38769d4b07f4be41cb52bd8b1123b9ed \ - --hash=sha256:97f65757e8d09870de6fd973aeddb92f85435607235d20b2dfed93405d00c85b \ - --hash=sha256:d595a9fd694fdeb061a62fbe10eb039cc1e444df81ec9bb70c7fc59ebcb1eafa \ - --hash=sha256:df051de4fd7e5e4371334e234c62ba43763f15ab605579e04c7008c05735cd82 \ - --hash=sha256:f8adba2e44cde2d7618996b3fc02341f03f5bc3f2748be72dc7b063319276178 \ - --hash=sha256:f8d3fdbc966aaab1d05046d0240dd94d40f2a8c62856d41eaa141ff64a79de6b \ - --hash=sha256:fe34575f2bdde76ac429ec7b570235bf0c788883e70aee90068e9981806f2490 +protobuf==6.33.2 \ + --hash=sha256:1f8017c48c07ec5859106533b682260ba3d7c5567b1ca1f24297ce03384d1b4f \ + --hash=sha256:2981c58f582f44b6b13173e12bb8656711189c2a70250845f264b877f00b1913 \ + --hash=sha256:56dc370c91fbb8ac85bc13582c9e373569668a290aa2e66a590c2a0d35ddb9e4 \ + --hash=sha256:7109dcc38a680d033ffb8bf896727423528db9163be1b6a02d6a49606dcadbfe \ + --hash=sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c \ + --hash=sha256:87eb388bd2d0f78febd8f4c8779c79247b26a5befad525008e49a6955787ff3d \ + --hash=sha256:8cd7640aee0b7828b6d03ae518b5b4806fdfc1afe8de82f79c3454f8aef29872 \ + --hash=sha256:b5d3b5625192214066d99b2b605f5783483575656784de223f00a8d00754fc0e \ + --hash=sha256:d9b19771ca75935b3a4422957bc518b0cecb978b31d1dd12037b088f6bcc0e43 \ + --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 # via # tensorboard # tensorflow @@ -1440,9 +1432,9 @@ scipy==1.16.3 ; python_version <= "3.12" \ # via # -r build/requirements.in # jaxlib -scipy-stubs==1.16.3.0 \ - --hash=sha256:90e5d82ced2183ef3c5c0a28a77df8cc227458624364fa0ff975ad24fa89d6ad \ - --hash=sha256:d6943c085e47a1ed431309f9ca582b6a206a9db808a036132a0bf01ebc34b506 +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ @@ -1486,28 +1478,32 @@ tensorflow==2.20.0 ; python_version < "3.14" \ --hash=sha256:dd71a7e7c3270239f4185915e8f2c5d39608c5e18973d6e1d101b153993841eb \ --hash=sha256:e5f169f8f5130ab255bbe854c5f0ae152e93d3d1ac44f42cb1866003b81a5357 # via -r build/nonfreethreading-requirements.txt -tensorstore==0.1.79 \ - --hash=sha256:0fd6165f3df49abc7c9de029b2b72d74bebd2ff2481a5ced003607eb61c56d3e \ - --hash=sha256:108c0e867aa2c87d4982cc6325a2de0c4f5bd63c2bea18adb193a370c40594ce \ - --hash=sha256:11a2c62694ea9c21770bc5a09938d3d15c4b9662b738ae6e1e513c26ed96251a \ - --hash=sha256:1e8e2d098829919caac6a62cf568902e34789069ceddb28497d6e36ebcb95c0b \ - --hash=sha256:29cf4336153af136ac8ac528e2ed46df19367edae7e14e37bca1a8b7c4848ef2 \ - --hash=sha256:5e152d334bf34fbabdfe8e5bc35b87d1f9947065924ff83c29e659308b36e948 \ - --hash=sha256:608f7178ec6e4e4a3c26545b0a44f44bf83438d04bf2d960cd0e7699eaa99ef6 \ - --hash=sha256:6c98c6b74c00e00eba7969292144e471d5c45d67088f0dc08e3a4c60a15ee191 \ - --hash=sha256:6f8f5a940eab434a951c2dadcc7c0516c7bef6d8b7a7144054f7a0c56152b5f5 \ - --hash=sha256:71aa9b45436d888c37b965f7b71195916d15438119b7dccb66a3b0776bfba367 \ - --hash=sha256:7af9422269c2bfcdecf9dd55309060665ab9c2d7f6c892377ed32c032400feea \ - --hash=sha256:83072ee0e551d6dca582e154b64c8b8066d276ec0759784e3149c28212a61f18 \ - --hash=sha256:847982652273fb7b2d694b789205747aaf3e50ae64738c5cb7b5eb03d86a9947 \ - --hash=sha256:8dad44a8a7f2952a5d0030a8bd868b3cfdff048bd40ab53e7226f3d8b0881c5e \ - --hash=sha256:94d8fc9df1721b0287046aca7209fd5040889cad4202e7b73a1fdb77cd9b71c6 \ - --hash=sha256:97756d2cba3c5ce21e15602c2af5a02521cc0ecda7f9fb6d18da2f3bd51827f4 \ - --hash=sha256:a071c6c255b7e412957a6aa563bc4250242c7894edad06ae6358e3d30b7d88ce \ - --hash=sha256:bbd8c1ab7d2e3c03ded3d40bb373ee9a67668e33a564484927865ce43b210386 \ - --hash=sha256:c4230b8fd29795e88e441f749d881973eca8dadf33c5262b367839fb8891f79b \ - --hash=sha256:c9f2dc3342e4686af98f6e259dc9fb377f1bf657b649c247bf6647bbe4f98090 \ - --hash=sha256:debd435042c00be68ba1fb3cf59325a7babb3f4a3cf4744c87dde346802cbbb4 +tensorstore==0.1.80 \ + --hash=sha256:04c29d979eb8b8ee48f873dc13d2701bfd49425500ffc5b848e4ec55b2548281 \ + --hash=sha256:07e4a84bacf70b78305831897068a9b5ad30326e63bbeb92c4bf7e565fcf5e9e \ + --hash=sha256:1113a6982fc0fa8dda8fcc0495715e647ac3360909a86ff13f2e04564f82d54a \ + --hash=sha256:189d924eaec394c9331e284a9c513ed583e336472a925823b5151cb26f41d091 \ + --hash=sha256:1b2b2ed0051dfab7e25295b14e6620520729e6e2ddf505f98c8d3917569614bf \ + --hash=sha256:246641a8780ee5e04e88bc95c8e31faac6471bab1180d1f5cdc9804b29a77c04 \ + --hash=sha256:4158fe76b96f62d12a37d7868150d836e089b5280b2bdd363c43c5d651f10e26 \ + --hash=sha256:46136fe42ee6dd835d957db37073058aea0b78fdfbe2975941640131b7740824 \ + --hash=sha256:4baee67fce95f29f593fbab4866119347115eaace887732aa92cfcbb9e6b0748 \ + --hash=sha256:53fd121ccd332bc4cc397f7af45889360c668b43dc3ff6bc3264df0f9886c11a \ + --hash=sha256:6b7c5dd434bba4ee08fe46bbbdb25c60dd3d47ccb4b8561a9751cf1526da52b8 \ + --hash=sha256:6c8dbbdd31cbb28eccfb23dbbd4218fe67bfc32e9cb452875a485b81031c949d \ + --hash=sha256:7451b30f99d9f31a2b9d70e6ef61815713dc782c58c6d817f91781341e4dac05 \ + --hash=sha256:8cd11027b5a8b66db8d344085a31a1666c78621dac27039c4d571bc4974804a1 \ + --hash=sha256:9c088e8c9f67c266ef4dae3703bd617f7c0cb0fd98e99c4500692e38a4328140 \ + --hash=sha256:a92505189731fcb03f1c69a84ea4460abb24204bfac1f339448a0621e7def77c \ + --hash=sha256:acb8d52fadcefafef4ef8ecca3fc99b1d0e3c5c5a888766484c3e39f050be7f5 \ + --hash=sha256:b193a7a1c4f455a61e60ed2dd67271a3daab0910ddb4bd9db51390d1b36d9996 \ + --hash=sha256:bc28a58c580253a526a4b6d239d18181ef96f1e285a502dbb03ff15eeec07a5b \ + --hash=sha256:c0529afab3800749dd245843d3bf0d061a109a8edb77fb345f476e8bccda51b8 \ + --hash=sha256:d2b353b0bd53fedd77fc5a12a1c1a91cacc3cf59e3dd785529c5a54b31d1c7b1 \ + --hash=sha256:de63843706fdfe9565a45567238c5b1e55a0b28bbde6524200b31d29043a9a16 \ + --hash=sha256:e93df6d34ff5f0f6be245f4d29b99a7c1eef8ad91b50686adf57a5eeea99cb74 \ + --hash=sha256:f65dfaf9e737a41389e29a5a2ea52ca5d14c8d6f48b402c723d800cd16d322b0 \ + --hash=sha256:f8b51d7e685bbb63f6becd7d2ac8634d5ab67ec7e53038e597182e2db2c7aa90 # via -r build/nonfreethreading-requirements.txt termcolor==3.2.0 \ --hash=sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58 \ @@ -1522,13 +1518,13 @@ typing-extensions==4.15.0 \ # optree # optype # tensorflow -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.2 \ + --hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \ + --hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd # via requests -werkzeug==3.1.3 \ - --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ - --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e # via tensorboard wheel==0.46.1 \ --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index 393f6f849af7..c6aa89903a08 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -243,9 +243,9 @@ execnet==2.1.2 \ --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.20.0 \ - --hash=sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2 \ - --hash=sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4 +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt flatbuffers==25.9.23 \ --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ @@ -253,65 +253,57 @@ flatbuffers==25.9.23 \ # via # -r build/test-requirements.txt # tensorflow -fonttools==4.60.1 \ - --hash=sha256:022beaea4b73a70295b688f817ddc24ed3e3418b5036ffcd5658141184ef0d0c \ - --hash=sha256:026290e4ec76583881763fac284aca67365e0be9f13a7fb137257096114cb3bc \ - --hash=sha256:0b0835ed15dd5b40d726bb61c846a688f5b4ce2208ec68779bc81860adb5851a \ - --hash=sha256:0eae96373e4b7c9e45d099d7a523444e3554360927225c1cdae221a58a45b856 \ - --hash=sha256:122e1a8ada290423c493491d002f622b1992b1ab0b488c68e31c413390dc7eb2 \ - --hash=sha256:1410155d0e764a4615774e5c2c6fc516259fe3eca5882f034eb9bfdbee056259 \ - --hash=sha256:145daa14bf24824b677b9357c5e44fd8895c2a8f53596e1b9ea3496081dc692c \ - --hash=sha256:1525796c3ffe27bb6268ed2a1bb0dcf214d561dfaf04728abf01489eb5339dce \ - --hash=sha256:154cb6ee417e417bf5f7c42fe25858c9140c26f647c7347c06f0cc2d47eff003 \ - --hash=sha256:2299df884c11162617a66b7c316957d74a18e3758c0274762d2cc87df7bc0272 \ - --hash=sha256:2409d5fb7b55fd70f715e6d34e7a6e4f7511b8ad29a49d6df225ee76da76dd77 \ - --hash=sha256:268ecda8ca6cb5c4f044b1fb9b3b376e8cd1b361cef275082429dc4174907038 \ - --hash=sha256:282dafa55f9659e8999110bd8ed422ebe1c8aecd0dc396550b038e6c9a08b8ea \ - --hash=sha256:2ee06fc57512144d8b0445194c2da9f190f61ad51e230f14836286470c99f854 \ - --hash=sha256:3630e86c484263eaac71d117085d509cbcf7b18f677906824e4bace598fb70d2 \ - --hash=sha256:398447f3d8c0c786cbf1209711e79080a40761eb44b27cdafffb48f52bcec258 \ - --hash=sha256:4ba4bd646e86de16160f0fb72e31c3b9b7d0721c3e5b26b9fa2fc931dfdb2652 \ - --hash=sha256:5664fd1a9ea7f244487ac8f10340c4e37664675e8667d6fee420766e0fb3cf08 \ - --hash=sha256:583b7f8e3c49486e4d489ad1deacfb8d5be54a8ef34d6df824f6a171f8511d99 \ - --hash=sha256:596ecaca36367027d525b3b426d8a8208169d09edcf8c7506aceb3a38bfb55c7 \ - --hash=sha256:5c1015318e4fec75dd4943ad5f6a206d9727adf97410d58b7e32ab644a807914 \ - --hash=sha256:66929e2ea2810c6533a5184f938502cfdaea4bc3efb7130d8cc02e1c1b4108d6 \ - --hash=sha256:6ec722ee589e89a89f5b7574f5c45604030aa6ae24cb2c751e2707193b466fed \ - --hash=sha256:6f68576bb4bbf6060c7ab047b1574a1ebe5c50a17de62830079967b211059ebb \ - --hash=sha256:7473a8ed9ed09aeaa191301244a5a9dbe46fe0bf54f9d6cd21d83044c3321217 \ - --hash=sha256:7b0c6d57ab00dae9529f3faf187f2254ea0aa1e04215cf2f1a8ec277c96661bc \ - --hash=sha256:7b4c32e232a71f63a5d00259ca3d88345ce2a43295bb049d21061f338124246f \ - --hash=sha256:8177ec9676ea6e1793c8a084a90b65a9f778771998eb919d05db6d4b1c0b114c \ - --hash=sha256:839565cbf14645952d933853e8ade66a463684ed6ed6c9345d0faf1f0e868877 \ - --hash=sha256:875cb7764708b3132637f6c5fb385b16eeba0f7ac9fa45a69d35e09b47045801 \ - --hash=sha256:8a44788d9d91df72d1a5eac49b31aeb887a5f4aab761b4cffc4196c74907ea85 \ - --hash=sha256:8b4eb332f9501cb1cd3d4d099374a1e1306783ff95489a1026bde9eb02ccc34a \ - --hash=sha256:906306ac7afe2156fcf0042173d6ebbb05416af70f6b370967b47f8f00103bbb \ - --hash=sha256:992775c9fbe2cf794786fa0ffca7f09f564ba3499b8fe9f2f80bd7197db60383 \ - --hash=sha256:996a4d1834524adbb423385d5a629b868ef9d774670856c63c9a0408a3063401 \ - --hash=sha256:9a52f254ce051e196b8fe2af4634c2d2f02c981756c6464dc192f1b6050b4e28 \ - --hash=sha256:9d0ced62b59e0430b3690dbc5373df1c2aa7585e9a8ce38eff87f0fd993c5b01 \ - --hash=sha256:a140761c4ff63d0cb9256ac752f230460ee225ccef4ad8f68affc723c88e2036 \ - --hash=sha256:a184b2ea57b13680ab6d5fbde99ccef152c95c06746cb7718c583abd8f945ccc \ - --hash=sha256:a3db56f153bd4c5c2b619ab02c5db5192e222150ce5a1bc10f16164714bc39ac \ - --hash=sha256:a46b2f450bc79e06ef3b6394f0c68660529ed51692606ad7f953fc2e448bc903 \ - --hash=sha256:a884aef09d45ba1206712c7dbda5829562d3fea7726935d3289d343232ecb0d3 \ - --hash=sha256:b2cf105cee600d2de04ca3cfa1f74f1127f8455b71dbad02b9da6ec266e116d6 \ - --hash=sha256:b33a7884fabd72bdf5f910d0cf46be50dce86a0362a65cfc746a4168c67eb96c \ - --hash=sha256:b42d86938e8dda1cd9a1a87a6d82f1818eaf933348429653559a458d027446da \ - --hash=sha256:b6379e7546ba4ae4b18f8ae2b9bc5960936007a1c0e30b342f662577e8bc3299 \ - --hash=sha256:c7420a2696a44650120cdd269a5d2e56a477e2bfa9d95e86229059beb1c19e15 \ - --hash=sha256:c8651e0d4b3bdeda6602b85fdc2abbefc1b41e573ecb37b6779c4ca50753a199 \ - --hash=sha256:d066ea419f719ed87bc2c99a4a4bfd77c2e5949cb724588b9dd58f3fd90b92bf \ - --hash=sha256:e6c58beb17380f7c2ea181ea11e7db8c0ceb474c9dd45f48e71e2cb577d146a1 \ - --hash=sha256:e852d9dda9f93ad3651ae1e3bb770eac544ec93c3807888798eccddf84596537 \ - --hash=sha256:ec3681a0cb34c255d76dd9d865a55f260164adb9fa02628415cdc2d43ee2c05d \ - --hash=sha256:ee0c0b3b35b34f782afc673d503167157094a16f442ace7c6c5e0ca80b08f50c \ - --hash=sha256:eedacb5c5d22b7097482fa834bda0dafa3d914a4e829ec83cdea2a01f8c813c4 \ - --hash=sha256:ef00af0439ebfee806b25f24c8f92109157ff3fac5731dc7867957812e87b8d9 \ - --hash=sha256:f0e8817c7d1a0c2eedebf57ef9a9896f3ea23324769a9a2061a80fe8852705ed \ - --hash=sha256:f3d5be054c461d6a2268831f04091dc82753176f6ea06dc6047a5e168265a987 \ - --hash=sha256:f4b5c37a5f40e4d733d3bbaaef082149bee5a5ea3156a785ff64d949bd1353fa +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib fsspec==2025.10.0 \ --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ @@ -451,69 +443,69 @@ iniconfig==2.3.0 \ --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -jax-cuda12-pjrt==0.8.1 ; sys_platform == "linux" \ - --hash=sha256:452b70ee10cb9ac5d7dfca55ffbcdb89b6c8bc6ba70a45af7c490d1dcea98eb7 \ - --hash=sha256:a631d0689903354afd7b3d2ec595b7da06a6230a76da00ff9548f542b21b6250 +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin==0.8.1 ; sys_platform == "linux" and python_version < "3.14" \ - --hash=sha256:1052c29157c99ca01d74d3073bbf4f711eb94465c0b4f5a4322d5e46233b1b2f \ - --hash=sha256:13311b72ca703a1bbad1ec516ac9ef750019a2d2c421d4c1daf8acf2720b822e \ - --hash=sha256:385001f56f852959f061ae15ad157c39cc4471c8d1d2544dfc3f805684ac2213 \ - --hash=sha256:479ca438d555024dac8dd1058371efbf8f479819a70aea213f3f3037ece99d74 \ - --hash=sha256:5a154723cb6c4e1e7969581a923dacf378f7515b0d53b5f1920e25e51cf6cecc \ - --hash=sha256:6a4b6fda687ca8361322029d58444bc0326798204806a3f90f231dc8ca5541a5 \ - --hash=sha256:7342c8810cc947de78f28c7287a30b2e201b0f51578543dd2553692b79a49942 \ - --hash=sha256:836eb0cd3af612d17bf17efc7eee175c6b9827989d5370df8ba919947fcb67cf \ - --hash=sha256:9968c15b87fd3867b6da0ce30681673a7fc4eedebaadcd24dce892e3f9fe1a52 \ - --hash=sha256:b3383bdc0b9f6260d9adc4ca0d1f68bf241158dfe69d726b267b0681382ea7a7 \ - --hash=sha256:b60bf0bbda24cec6fa71170bd69b613359f01a376d8e09fe34bf67ecc9a3164f \ - --hash=sha256:da7c0f2ef1c697f9ade51a71cfad211e2bff25407a6855dddde372c0190fc468 +jax-cuda12-plugin==0.8.2 ; sys_platform == "linux" and python_version < "3.14" \ + --hash=sha256:0b0a3304ce7e494acd8d9c593490c112a32cdb6010fe1afc584d9e41fd863167 \ + --hash=sha256:1b4828242d57f233b394d17ebaa599c503c1fb9b7c754012a06eb84dbc935fc8 \ + --hash=sha256:20165861b3d3e66ebb2c0f63a547d1d5ee17ea44ac3be7153c7908c9ca8c88f3 \ + --hash=sha256:377e4be17e22dde0343b3f3c05bf69235b3dbf11d766cca9c5a93da47971dcb7 \ + --hash=sha256:403d5e07731b5cdac3bd9fb3f448bd8480062cb2c0ab61ea2ad23fcd0a65479a \ + --hash=sha256:58c51473fc622e03138035985f741833564d70a4bd5a2178f61b62cdaa32ff94 \ + --hash=sha256:637387dc3408cd204562668502f9e95f76c6edde0a6d2e48f055162dc2aebf0d \ + --hash=sha256:70d33222484ad5c375b8f8357b7c23cacb844f6ecfc39567f8dd47fde6e87858 \ + --hash=sha256:82c6798be66bf8c773386918e4c8e5cd8119753f3bfb3ca4bbc46818283750c6 \ + --hash=sha256:a5898bac1d8ab6020b54546440256409f2c66bcbbb3a1099ca473c84843addad \ + --hash=sha256:d68a6d8b4a45ee561746bac7a6468da8203832626b0b39ad4ac43011f61f875d \ + --hash=sha256:dd4f7c34d4512ff5a36fd1b01584ef7781cad615e3f9e71880eae2f4998e5108 # via -r build/requirements.in -jax-cuda13-pjrt==0.8.1 \ - --hash=sha256:86a6926da76aebf6080922747a7a98d321f4ca27101077357fa148032bc3cd1d \ - --hash=sha256:f3b1c1c7118b4570f72740ed756cbed289a3f8fa813570a0dbf16f186bccb8c9 +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 # via # -r build/requirements.in # jax-cuda13-plugin -jax-cuda13-plugin==0.8.1 \ - --hash=sha256:07625aed1aa769c701213e84d6b2a46902019a1d2af8a09ce6dfd9575163bfc6 \ - --hash=sha256:0d503c312d2daefea62a00c74534579deeacd46e15c364074d27a8d95a100032 \ - --hash=sha256:12a7aac712a7c6dc228ef9991578e85e3bcab7c324193bdfb2b5acf059bae6d6 \ - --hash=sha256:16ee16b13393baf9672b6612566308675cebdc8d785b61fac2b93ce8c97825ff \ - --hash=sha256:4e589ed8197f1bea7e7fd20d866ccc5c2a1276d7acd02224e3a5b07983df61e2 \ - --hash=sha256:64df1f1414d899ab7a84751d6f78515365555b54fb64b3e318bd70519de99c86 \ - --hash=sha256:7a373fd3e5f11ecad01b8add1e277eb6559b4966b0745d92dc91c585579fac35 \ - --hash=sha256:92238530152890c3405addacd1fc021c87022cbf99fa66418cfa2e9f68a5c49d \ - --hash=sha256:a4c5a4a69346be6520c729675d5d80e85d610399f4840d74bdfae9c6ebedc8bc \ - --hash=sha256:af33f737ccf5426155cf5c7d175bf765ca25724b94af5109ef2df891b410f997 \ - --hash=sha256:d81222989fb30496cc6554d42deaca6ab003721f56ff669da7f80d61fea1219d \ - --hash=sha256:ef218c47b2cde8c700ad2a56d04320f9d1490439fc6db20747f56e91de7289c2 +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af # via -r build/requirements.in -jaxlib==0.8.1 \ - --hash=sha256:117f2fe2c19479e560ad85a3ef2fcc0b1d24816456f0d039f865c2acbab63b5a \ - --hash=sha256:1a4001ed3ba9ed5a812da1b16f52eebb5d473a4480c1523828c7bd3dae8d1375 \ - --hash=sha256:1bc76edec2bc74a7adb5e29329ece51a67c57cd011a06d55d07da62fbabe3389 \ - --hash=sha256:22f489fb5c8be0da7be5e4957a10936b3760a169668f8b25c5d09c51c3ef47f6 \ - --hash=sha256:24ec3f3a9c45d6de060020dc94c444d69e18099fab927ea3979ff8cedf0ed2c9 \ - --hash=sha256:4933298fcfb07a5aa2d1fed21c111d07cea50e6f180dba2cdb5463c13fb98f2f \ - --hash=sha256:63fc25c4b5d03256798796a024125e29bcf254acc3eae5dc3239d1c30b86b866 \ - --hash=sha256:7a5d381fad89622750fae29fab83c0847e2931ad8d6a34dc13b28fc4d67f75a3 \ - --hash=sha256:865add56139883405f3f15c9b0de6a64ab8f4aa549dff196b72dbc86be6ccc1f \ - --hash=sha256:88bde0f535eeea6689e0cd57d40b7660d5206ac95c7d42e09562a109b963a49f \ - --hash=sha256:8e118e1fbe714f37a94ba26777c17faab7dca4a33646a3d98cd1d99673bbd6b1 \ - --hash=sha256:90e48973f8dbded7edc8728be84c01ae00412190187fb06622abfa4edd42c0a8 \ - --hash=sha256:92c41c9b9862c08521eb90515a7c5bcc840c6d30f86230cebf94aea2d6a0af81 \ - --hash=sha256:a0349f6e8179dc897d33aeb90ec66b4a8041330fbbba8d071dc6167cd2271539 \ - --hash=sha256:af4924189fc53b69237715b56ebcbfc71bb91ca16184143dcef0d430c8173de6 \ - --hash=sha256:bd697c171ace1e2e9d6ed910a78f385b3c4095cee290b0255aa58848f2acdeab \ - --hash=sha256:bed1e94ae8c7c16bca4476d8d7f582f0d1a102a4e69c3a9bd2069a0dc42274a9 \ - --hash=sha256:c14c8c19a7eb694aa14092b6d2fffb9d2bdd8a603b63d6f26fbeaf129c204f9f \ - --hash=sha256:d245bd6a279c72ca5f796df84cdd64d7c9c8abc4b8d89adf4acf45898dab958b \ - --hash=sha256:f2f11491b077d05249d63813e811401194a41edc8e9cc60af8f4b554057cfad0 \ - --hash=sha256:fdbbf2336c08bbf8f30548e204c8c9d77f8b2a3a5b7fc7985749246feb8852b0 \ - --hash=sha256:ff32b6320d729131efaf22939825b52d75957c84c32af2b0b1bdb33cf27ba75f +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de # via -r build/requirements.in keras==3.12.0 \ --hash=sha256:02b69e007d5df8042286c3bcc2a888539e3e487590ffb08f6be1b4354df50aa8 \ @@ -634,13 +626,13 @@ libclang==18.1.1 \ --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe # via tensorflow -libtpu==0.0.30 ; sys_platform == "linux" and platform_machine == "x86_64" \ - --hash=sha256:26442f0a51d243cf7259407bba8f5d849c9024297efe97044d64b5244283ad63 \ - --hash=sha256:5fabff9a041674bb889fb59ac0b5c54b9dbcf492a8c782e083ef86a8194dbb0f \ - --hash=sha256:8be30562743a63c1c1353e7ba78f0dbfbb051e8d1e9d3bb2b5da9b720363bb0a \ - --hash=sha256:b1fc44915dad56c0ceb733311a4d4396b88dc9a1c7c01acd7617da90e7ec22f2 \ - --hash=sha256:babab04ca663da2c4e4b3ab036c4d465f2f4674c480d08239c5d4965b7ce9e1c \ - --hash=sha256:f9aa040895ec25fafebcd4e1a0e1a9524ff3bd778ca88543731e308f6e516dd1 +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 # via -r build/requirements.in markdown==3.10 \ --hash=sha256:37062d4f2aa4b2b6b32aefb80faa300f82cc790cb949a35b8caede34f2b68c0e \ @@ -741,62 +733,62 @@ markupsafe==3.0.3 \ --hash=sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a \ --hash=sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50 # via werkzeug -matplotlib==3.10.7 \ - --hash=sha256:07124afcf7a6504eafcb8ce94091c5898bbdd351519a1beb5c45f7a38c67e77f \ - --hash=sha256:09d7945a70ea43bf9248f4b6582734c2fe726723204a76eca233f24cffc7ef67 \ - --hash=sha256:0d8c32b7ea6fb80b1aeff5a2ceb3fb9778e2759e899d9beff75584714afcc5ee \ - --hash=sha256:11ae579ac83cdf3fb72573bb89f70e0534de05266728740d478f0f818983c695 \ - --hash=sha256:15112bcbaef211bd663fa935ec33313b948e214454d949b723998a43357b17b0 \ - --hash=sha256:1d9d3713a237970569156cfb4de7533b7c4eacdd61789726f444f96a0d28f57f \ - --hash=sha256:1e4bbad66c177a8fdfa53972e5ef8be72a5f27e6a607cec0d8579abd0f3102b1 \ - --hash=sha256:2222c7ba2cbde7fe63032769f6eb7e83ab3227f47d997a8453377709b7fe3a5a \ - --hash=sha256:22df30ffaa89f6643206cf13877191c63a50e8f800b038bc39bee9d2d4957632 \ - --hash=sha256:31963603041634ce1a96053047b40961f7a29eb8f9a62e80cc2c0427aa1d22a2 \ - --hash=sha256:37a1fea41153dd6ee061d21ab69c9cf2cf543160b1b85d89cd3d2e2a7902ca4c \ - --hash=sha256:3886e47f64611046bc1db523a09dd0a0a6bed6081e6f90e13806dd1d1d1b5e91 \ - --hash=sha256:4645fc5d9d20ffa3a39361fcdbcec731382763b623b72627806bf251b6388866 \ - --hash=sha256:4a11c2e9e72e7de09b7b72e62f3df23317c888299c875e2b778abf1eda8c0a42 \ - --hash=sha256:4a74f79fafb2e177f240579bc83f0b60f82cc47d2f1d260f422a0627207008ca \ - --hash=sha256:4c14b6acd16cddc3569a2d515cfdd81c7a68ac5639b76548cfc1a9e48b20eb65 \ - --hash=sha256:53b492410a6cd66c7a471de6c924f6ede976e963c0f3097a3b7abfadddc67d0a \ - --hash=sha256:53cc80662dd197ece414dd5b66e07370201515a3eaf52e7c518c68c16814773b \ - --hash=sha256:5c09cf8f2793f81368f49f118b6f9f937456362bee282eac575cca7f84cda537 \ - --hash=sha256:5e38c2d581d62ee729a6e144c47a71b3f42fb4187508dbbf4fe71d5612c3433b \ - --hash=sha256:5f3f6d315dcc176ba7ca6e74c7768fb7e4cf566c49cb143f6bc257b62e634ed8 \ - --hash=sha256:6516ce375109c60ceec579e699524e9d504cd7578506f01150f7a6bc174a775e \ - --hash=sha256:667ecd5d8d37813a845053d8f5bf110b534c3c9f30e69ebd25d4701385935a6d \ - --hash=sha256:6f1851eab59ca082c95df5a500106bad73672645625e04538b3ad0f69471ffcc \ - --hash=sha256:702590829c30aada1e8cef0568ddbffa77ca747b4d6e36c6d173f66e301f89cc \ - --hash=sha256:7146d64f561498764561e9cd0ed64fcf582e570fc519e6f521e2d0cfd43365e1 \ - --hash=sha256:744991e0cc863dd669c8dc9136ca4e6e0082be2070b9d793cbd64bec872a6815 \ - --hash=sha256:786656bb13c237bbcebcd402f65f44dd61ead60ee3deb045af429d889c8dbc67 \ - --hash=sha256:7a0edb7209e21840e8361e91ea84ea676658aa93edd5f8762793dec77a4a6748 \ - --hash=sha256:7ac81eee3b7c266dd92cee1cd658407b16c57eed08c7421fa354ed68234de380 \ - --hash=sha256:90ad854c0a435da3104c01e2c6f0028d7e719b690998a2333d7218db80950722 \ - --hash=sha256:9257be2f2a03415f9105c486d304a321168e61ad450f6153d77c69504ad764bb \ - --hash=sha256:932c55d1fa7af4423422cb6a492a31cbcbdbe68fd1a9a3f545aa5e7a143b5355 \ - --hash=sha256:a06ba7e2a2ef9131c79c49e63dad355d2d878413a0376c1727c8b9335ff731c7 \ - --hash=sha256:aebed7b50aa6ac698c90f60f854b47e48cd2252b30510e7a1feddaf5a3f72cbf \ - --hash=sha256:b172db79759f5f9bc13ef1c3ef8b9ee7b37b0247f987fbbbdaa15e4f87fd46a9 \ - --hash=sha256:b3c4ea4948d93c9c29dc01c0c23eef66f2101bf75158c291b88de6525c55c3d1 \ - --hash=sha256:b498e9e4022f93de2d5a37615200ca01297ceebbb56fe4c833f46862a490f9e3 \ - --hash=sha256:b4d41379b05528091f00e1728004f9a8d7191260f3862178b88e8fd770206318 \ - --hash=sha256:b69676845a0a66f9da30e87f48be36734d6748024b525ec4710be40194282c84 \ - --hash=sha256:c17398b709a6cce3d9fdb1595c33e356d91c098cd9486cb2cc21ea2ea418e715 \ - --hash=sha256:c380371d3c23e0eadf8ebff114445b9f970aff2010198d498d4ab4c3b41eea4f \ - --hash=sha256:cb783436e47fcf82064baca52ce748af71725d0352e1d31564cbe9c95df92b9c \ - --hash=sha256:cc1c51b846aca49a5a8b44fbba6a92d583a35c64590ad9e1e950dc88940a4297 \ - --hash=sha256:d0b181e9fa8daf1d9f2d4c547527b167cb8838fc587deabca7b5c01f97199e84 \ - --hash=sha256:d2a959c640cdeecdd2ec3136e8ea0441da59bcaf58d67e9c590740addba2cb68 \ - --hash=sha256:d5f256d49fea31f40f166a5e3131235a5d2f4b7f44520b1cf0baf1ce568ccff0 \ - --hash=sha256:d883460c43e8c6b173fef244a2341f7f7c0e9725c7fe68306e8e44ed9c8fb100 \ - --hash=sha256:d8eb7194b084b12feb19142262165832fc6ee879b945491d1c3d4660748020c4 \ - --hash=sha256:d9749313deb729f08207718d29c86246beb2ea3fdba753595b55901dee5d2fd6 \ - --hash=sha256:de66744b2bb88d5cd27e80dfc2ec9f0517d0a46d204ff98fe9e5f2864eb67657 \ - --hash=sha256:e91f61a064c92c307c5a9dc8c05dc9f8a68f0a3be199d9a002a0622e13f874a1 \ - --hash=sha256:f19410b486fdd139885ace124e57f938c1e6a3210ea13dd29cab58f5d4bc12c7 \ - --hash=sha256:f79d5de970fc90cd5591f60053aecfce1fcd736e0303d9f0bf86be649fa68fb8 \ - --hash=sha256:fba2974df0bf8ce3c995fa84b79cde38326e0f7b5409e7a3a481c1141340bcf7 +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ @@ -916,9 +908,9 @@ numpy==2.0.2 ; python_version <= "3.12" \ # tensorboard # tensorflow # tensorstore -numpy-typing-compat==20250818.2.0 \ - --hash=sha256:042da86a786b6eb164f900efdfc3ba132f4371a2e44a93109976b1d7538253ed \ - --hash=sha256:3f77ba873ec9668e9b7bd15ae083cc16c82aa732b651ed2bf5aa284cdd0dc71d +numpy-typing-compat==20251206.2.0 \ + --hash=sha256:413171c4333c4175cbad4206c94e58422d291d20426c42581865380156715493 \ + --hash=sha256:7db9d5e991af03b2ade38f43253e4eb03ab88925230931bff7f559c020676fb1 # via optype nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ @@ -1178,9 +1170,9 @@ optree==0.18.0 \ --hash=sha256:fa8e3878a1857761d64f08a23b32140d29754a53f85f7c87186ced2b5b1b49cb \ --hash=sha256:ff7326f36ed70d84c3fd62fb39bc6858f699640b8ab238c3cb8dafe1e200af59 # via keras -optype[numpy]==0.14.0 \ - --hash=sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b \ - --hash=sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6 +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c # via scipy-stubs packaging==25.0 \ --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ @@ -1298,17 +1290,17 @@ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -protobuf==6.33.1 \ - --hash=sha256:023af8449482fa884d88b4563d85e83accab54138ae098924a985bcbb734a213 \ - --hash=sha256:0f4cf01222c0d959c2b399142deb526de420be8236f22c71356e2a544e153c53 \ - --hash=sha256:8fd7d5e0eb08cd5b87fd3df49bc193f5cfd778701f47e11d127d0afc6c39f1d1 \ - --hash=sha256:923aa6d27a92bf44394f6abf7ea0500f38769d4b07f4be41cb52bd8b1123b9ed \ - --hash=sha256:97f65757e8d09870de6fd973aeddb92f85435607235d20b2dfed93405d00c85b \ - --hash=sha256:d595a9fd694fdeb061a62fbe10eb039cc1e444df81ec9bb70c7fc59ebcb1eafa \ - --hash=sha256:df051de4fd7e5e4371334e234c62ba43763f15ab605579e04c7008c05735cd82 \ - --hash=sha256:f8adba2e44cde2d7618996b3fc02341f03f5bc3f2748be72dc7b063319276178 \ - --hash=sha256:f8d3fdbc966aaab1d05046d0240dd94d40f2a8c62856d41eaa141ff64a79de6b \ - --hash=sha256:fe34575f2bdde76ac429ec7b570235bf0c788883e70aee90068e9981806f2490 +protobuf==6.33.2 \ + --hash=sha256:1f8017c48c07ec5859106533b682260ba3d7c5567b1ca1f24297ce03384d1b4f \ + --hash=sha256:2981c58f582f44b6b13173e12bb8656711189c2a70250845f264b877f00b1913 \ + --hash=sha256:56dc370c91fbb8ac85bc13582c9e373569668a290aa2e66a590c2a0d35ddb9e4 \ + --hash=sha256:7109dcc38a680d033ffb8bf896727423528db9163be1b6a02d6a49606dcadbfe \ + --hash=sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c \ + --hash=sha256:87eb388bd2d0f78febd8f4c8779c79247b26a5befad525008e49a6955787ff3d \ + --hash=sha256:8cd7640aee0b7828b6d03ae518b5b4806fdfc1afe8de82f79c3454f8aef29872 \ + --hash=sha256:b5d3b5625192214066d99b2b605f5783483575656784de223f00a8d00754fc0e \ + --hash=sha256:d9b19771ca75935b3a4422957bc518b0cecb978b31d1dd12037b088f6bcc0e43 \ + --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 # via # tensorboard # tensorflow @@ -1440,9 +1432,9 @@ scipy==1.16.3 ; python_version <= "3.12" \ # via # -r build/requirements.in # jaxlib -scipy-stubs==1.16.3.0 \ - --hash=sha256:90e5d82ced2183ef3c5c0a28a77df8cc227458624364fa0ff975ad24fa89d6ad \ - --hash=sha256:d6943c085e47a1ed431309f9ca582b6a206a9db808a036132a0bf01ebc34b506 +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ @@ -1486,28 +1478,32 @@ tensorflow==2.20.0 ; python_version < "3.14" \ --hash=sha256:dd71a7e7c3270239f4185915e8f2c5d39608c5e18973d6e1d101b153993841eb \ --hash=sha256:e5f169f8f5130ab255bbe854c5f0ae152e93d3d1ac44f42cb1866003b81a5357 # via -r build/nonfreethreading-requirements.txt -tensorstore==0.1.79 \ - --hash=sha256:0fd6165f3df49abc7c9de029b2b72d74bebd2ff2481a5ced003607eb61c56d3e \ - --hash=sha256:108c0e867aa2c87d4982cc6325a2de0c4f5bd63c2bea18adb193a370c40594ce \ - --hash=sha256:11a2c62694ea9c21770bc5a09938d3d15c4b9662b738ae6e1e513c26ed96251a \ - --hash=sha256:1e8e2d098829919caac6a62cf568902e34789069ceddb28497d6e36ebcb95c0b \ - --hash=sha256:29cf4336153af136ac8ac528e2ed46df19367edae7e14e37bca1a8b7c4848ef2 \ - --hash=sha256:5e152d334bf34fbabdfe8e5bc35b87d1f9947065924ff83c29e659308b36e948 \ - --hash=sha256:608f7178ec6e4e4a3c26545b0a44f44bf83438d04bf2d960cd0e7699eaa99ef6 \ - --hash=sha256:6c98c6b74c00e00eba7969292144e471d5c45d67088f0dc08e3a4c60a15ee191 \ - --hash=sha256:6f8f5a940eab434a951c2dadcc7c0516c7bef6d8b7a7144054f7a0c56152b5f5 \ - --hash=sha256:71aa9b45436d888c37b965f7b71195916d15438119b7dccb66a3b0776bfba367 \ - --hash=sha256:7af9422269c2bfcdecf9dd55309060665ab9c2d7f6c892377ed32c032400feea \ - --hash=sha256:83072ee0e551d6dca582e154b64c8b8066d276ec0759784e3149c28212a61f18 \ - --hash=sha256:847982652273fb7b2d694b789205747aaf3e50ae64738c5cb7b5eb03d86a9947 \ - --hash=sha256:8dad44a8a7f2952a5d0030a8bd868b3cfdff048bd40ab53e7226f3d8b0881c5e \ - --hash=sha256:94d8fc9df1721b0287046aca7209fd5040889cad4202e7b73a1fdb77cd9b71c6 \ - --hash=sha256:97756d2cba3c5ce21e15602c2af5a02521cc0ecda7f9fb6d18da2f3bd51827f4 \ - --hash=sha256:a071c6c255b7e412957a6aa563bc4250242c7894edad06ae6358e3d30b7d88ce \ - --hash=sha256:bbd8c1ab7d2e3c03ded3d40bb373ee9a67668e33a564484927865ce43b210386 \ - --hash=sha256:c4230b8fd29795e88e441f749d881973eca8dadf33c5262b367839fb8891f79b \ - --hash=sha256:c9f2dc3342e4686af98f6e259dc9fb377f1bf657b649c247bf6647bbe4f98090 \ - --hash=sha256:debd435042c00be68ba1fb3cf59325a7babb3f4a3cf4744c87dde346802cbbb4 +tensorstore==0.1.80 \ + --hash=sha256:04c29d979eb8b8ee48f873dc13d2701bfd49425500ffc5b848e4ec55b2548281 \ + --hash=sha256:07e4a84bacf70b78305831897068a9b5ad30326e63bbeb92c4bf7e565fcf5e9e \ + --hash=sha256:1113a6982fc0fa8dda8fcc0495715e647ac3360909a86ff13f2e04564f82d54a \ + --hash=sha256:189d924eaec394c9331e284a9c513ed583e336472a925823b5151cb26f41d091 \ + --hash=sha256:1b2b2ed0051dfab7e25295b14e6620520729e6e2ddf505f98c8d3917569614bf \ + --hash=sha256:246641a8780ee5e04e88bc95c8e31faac6471bab1180d1f5cdc9804b29a77c04 \ + --hash=sha256:4158fe76b96f62d12a37d7868150d836e089b5280b2bdd363c43c5d651f10e26 \ + --hash=sha256:46136fe42ee6dd835d957db37073058aea0b78fdfbe2975941640131b7740824 \ + --hash=sha256:4baee67fce95f29f593fbab4866119347115eaace887732aa92cfcbb9e6b0748 \ + --hash=sha256:53fd121ccd332bc4cc397f7af45889360c668b43dc3ff6bc3264df0f9886c11a \ + --hash=sha256:6b7c5dd434bba4ee08fe46bbbdb25c60dd3d47ccb4b8561a9751cf1526da52b8 \ + --hash=sha256:6c8dbbdd31cbb28eccfb23dbbd4218fe67bfc32e9cb452875a485b81031c949d \ + --hash=sha256:7451b30f99d9f31a2b9d70e6ef61815713dc782c58c6d817f91781341e4dac05 \ + --hash=sha256:8cd11027b5a8b66db8d344085a31a1666c78621dac27039c4d571bc4974804a1 \ + --hash=sha256:9c088e8c9f67c266ef4dae3703bd617f7c0cb0fd98e99c4500692e38a4328140 \ + --hash=sha256:a92505189731fcb03f1c69a84ea4460abb24204bfac1f339448a0621e7def77c \ + --hash=sha256:acb8d52fadcefafef4ef8ecca3fc99b1d0e3c5c5a888766484c3e39f050be7f5 \ + --hash=sha256:b193a7a1c4f455a61e60ed2dd67271a3daab0910ddb4bd9db51390d1b36d9996 \ + --hash=sha256:bc28a58c580253a526a4b6d239d18181ef96f1e285a502dbb03ff15eeec07a5b \ + --hash=sha256:c0529afab3800749dd245843d3bf0d061a109a8edb77fb345f476e8bccda51b8 \ + --hash=sha256:d2b353b0bd53fedd77fc5a12a1c1a91cacc3cf59e3dd785529c5a54b31d1c7b1 \ + --hash=sha256:de63843706fdfe9565a45567238c5b1e55a0b28bbde6524200b31d29043a9a16 \ + --hash=sha256:e93df6d34ff5f0f6be245f4d29b99a7c1eef8ad91b50686adf57a5eeea99cb74 \ + --hash=sha256:f65dfaf9e737a41389e29a5a2ea52ca5d14c8d6f48b402c723d800cd16d322b0 \ + --hash=sha256:f8b51d7e685bbb63f6becd7d2ac8634d5ab67ec7e53038e597182e2db2c7aa90 # via -r build/nonfreethreading-requirements.txt termcolor==3.2.0 \ --hash=sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58 \ @@ -1522,13 +1518,13 @@ typing-extensions==4.15.0 \ # optree # optype # tensorflow -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.2 \ + --hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \ + --hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd # via requests -werkzeug==3.1.3 \ - --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ - --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e # via tensorboard wheel==0.46.1 \ --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 1fd63b0cf32d..c6206076d2c1 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -243,9 +243,9 @@ execnet==2.1.2 \ --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.20.0 \ - --hash=sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2 \ - --hash=sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4 +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt flatbuffers==25.9.23 \ --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ @@ -253,65 +253,57 @@ flatbuffers==25.9.23 \ # via # -r build/test-requirements.txt # tensorflow -fonttools==4.60.1 \ - --hash=sha256:022beaea4b73a70295b688f817ddc24ed3e3418b5036ffcd5658141184ef0d0c \ - --hash=sha256:026290e4ec76583881763fac284aca67365e0be9f13a7fb137257096114cb3bc \ - --hash=sha256:0b0835ed15dd5b40d726bb61c846a688f5b4ce2208ec68779bc81860adb5851a \ - --hash=sha256:0eae96373e4b7c9e45d099d7a523444e3554360927225c1cdae221a58a45b856 \ - --hash=sha256:122e1a8ada290423c493491d002f622b1992b1ab0b488c68e31c413390dc7eb2 \ - --hash=sha256:1410155d0e764a4615774e5c2c6fc516259fe3eca5882f034eb9bfdbee056259 \ - --hash=sha256:145daa14bf24824b677b9357c5e44fd8895c2a8f53596e1b9ea3496081dc692c \ - --hash=sha256:1525796c3ffe27bb6268ed2a1bb0dcf214d561dfaf04728abf01489eb5339dce \ - --hash=sha256:154cb6ee417e417bf5f7c42fe25858c9140c26f647c7347c06f0cc2d47eff003 \ - --hash=sha256:2299df884c11162617a66b7c316957d74a18e3758c0274762d2cc87df7bc0272 \ - --hash=sha256:2409d5fb7b55fd70f715e6d34e7a6e4f7511b8ad29a49d6df225ee76da76dd77 \ - --hash=sha256:268ecda8ca6cb5c4f044b1fb9b3b376e8cd1b361cef275082429dc4174907038 \ - --hash=sha256:282dafa55f9659e8999110bd8ed422ebe1c8aecd0dc396550b038e6c9a08b8ea \ - --hash=sha256:2ee06fc57512144d8b0445194c2da9f190f61ad51e230f14836286470c99f854 \ - --hash=sha256:3630e86c484263eaac71d117085d509cbcf7b18f677906824e4bace598fb70d2 \ - --hash=sha256:398447f3d8c0c786cbf1209711e79080a40761eb44b27cdafffb48f52bcec258 \ - --hash=sha256:4ba4bd646e86de16160f0fb72e31c3b9b7d0721c3e5b26b9fa2fc931dfdb2652 \ - --hash=sha256:5664fd1a9ea7f244487ac8f10340c4e37664675e8667d6fee420766e0fb3cf08 \ - --hash=sha256:583b7f8e3c49486e4d489ad1deacfb8d5be54a8ef34d6df824f6a171f8511d99 \ - --hash=sha256:596ecaca36367027d525b3b426d8a8208169d09edcf8c7506aceb3a38bfb55c7 \ - --hash=sha256:5c1015318e4fec75dd4943ad5f6a206d9727adf97410d58b7e32ab644a807914 \ - --hash=sha256:66929e2ea2810c6533a5184f938502cfdaea4bc3efb7130d8cc02e1c1b4108d6 \ - --hash=sha256:6ec722ee589e89a89f5b7574f5c45604030aa6ae24cb2c751e2707193b466fed \ - --hash=sha256:6f68576bb4bbf6060c7ab047b1574a1ebe5c50a17de62830079967b211059ebb \ - --hash=sha256:7473a8ed9ed09aeaa191301244a5a9dbe46fe0bf54f9d6cd21d83044c3321217 \ - --hash=sha256:7b0c6d57ab00dae9529f3faf187f2254ea0aa1e04215cf2f1a8ec277c96661bc \ - --hash=sha256:7b4c32e232a71f63a5d00259ca3d88345ce2a43295bb049d21061f338124246f \ - --hash=sha256:8177ec9676ea6e1793c8a084a90b65a9f778771998eb919d05db6d4b1c0b114c \ - --hash=sha256:839565cbf14645952d933853e8ade66a463684ed6ed6c9345d0faf1f0e868877 \ - --hash=sha256:875cb7764708b3132637f6c5fb385b16eeba0f7ac9fa45a69d35e09b47045801 \ - --hash=sha256:8a44788d9d91df72d1a5eac49b31aeb887a5f4aab761b4cffc4196c74907ea85 \ - --hash=sha256:8b4eb332f9501cb1cd3d4d099374a1e1306783ff95489a1026bde9eb02ccc34a \ - --hash=sha256:906306ac7afe2156fcf0042173d6ebbb05416af70f6b370967b47f8f00103bbb \ - --hash=sha256:992775c9fbe2cf794786fa0ffca7f09f564ba3499b8fe9f2f80bd7197db60383 \ - --hash=sha256:996a4d1834524adbb423385d5a629b868ef9d774670856c63c9a0408a3063401 \ - --hash=sha256:9a52f254ce051e196b8fe2af4634c2d2f02c981756c6464dc192f1b6050b4e28 \ - --hash=sha256:9d0ced62b59e0430b3690dbc5373df1c2aa7585e9a8ce38eff87f0fd993c5b01 \ - --hash=sha256:a140761c4ff63d0cb9256ac752f230460ee225ccef4ad8f68affc723c88e2036 \ - --hash=sha256:a184b2ea57b13680ab6d5fbde99ccef152c95c06746cb7718c583abd8f945ccc \ - --hash=sha256:a3db56f153bd4c5c2b619ab02c5db5192e222150ce5a1bc10f16164714bc39ac \ - --hash=sha256:a46b2f450bc79e06ef3b6394f0c68660529ed51692606ad7f953fc2e448bc903 \ - --hash=sha256:a884aef09d45ba1206712c7dbda5829562d3fea7726935d3289d343232ecb0d3 \ - --hash=sha256:b2cf105cee600d2de04ca3cfa1f74f1127f8455b71dbad02b9da6ec266e116d6 \ - --hash=sha256:b33a7884fabd72bdf5f910d0cf46be50dce86a0362a65cfc746a4168c67eb96c \ - --hash=sha256:b42d86938e8dda1cd9a1a87a6d82f1818eaf933348429653559a458d027446da \ - --hash=sha256:b6379e7546ba4ae4b18f8ae2b9bc5960936007a1c0e30b342f662577e8bc3299 \ - --hash=sha256:c7420a2696a44650120cdd269a5d2e56a477e2bfa9d95e86229059beb1c19e15 \ - --hash=sha256:c8651e0d4b3bdeda6602b85fdc2abbefc1b41e573ecb37b6779c4ca50753a199 \ - --hash=sha256:d066ea419f719ed87bc2c99a4a4bfd77c2e5949cb724588b9dd58f3fd90b92bf \ - --hash=sha256:e6c58beb17380f7c2ea181ea11e7db8c0ceb474c9dd45f48e71e2cb577d146a1 \ - --hash=sha256:e852d9dda9f93ad3651ae1e3bb770eac544ec93c3807888798eccddf84596537 \ - --hash=sha256:ec3681a0cb34c255d76dd9d865a55f260164adb9fa02628415cdc2d43ee2c05d \ - --hash=sha256:ee0c0b3b35b34f782afc673d503167157094a16f442ace7c6c5e0ca80b08f50c \ - --hash=sha256:eedacb5c5d22b7097482fa834bda0dafa3d914a4e829ec83cdea2a01f8c813c4 \ - --hash=sha256:ef00af0439ebfee806b25f24c8f92109157ff3fac5731dc7867957812e87b8d9 \ - --hash=sha256:f0e8817c7d1a0c2eedebf57ef9a9896f3ea23324769a9a2061a80fe8852705ed \ - --hash=sha256:f3d5be054c461d6a2268831f04091dc82753176f6ea06dc6047a5e168265a987 \ - --hash=sha256:f4b5c37a5f40e4d733d3bbaaef082149bee5a5ea3156a785ff64d949bd1353fa +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib fsspec==2025.10.0 \ --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ @@ -451,69 +443,69 @@ iniconfig==2.3.0 \ --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -jax-cuda12-pjrt==0.8.1 ; sys_platform == "linux" \ - --hash=sha256:452b70ee10cb9ac5d7dfca55ffbcdb89b6c8bc6ba70a45af7c490d1dcea98eb7 \ - --hash=sha256:a631d0689903354afd7b3d2ec595b7da06a6230a76da00ff9548f542b21b6250 +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin==0.8.1 ; sys_platform == "linux" and python_version < "3.14" \ - --hash=sha256:1052c29157c99ca01d74d3073bbf4f711eb94465c0b4f5a4322d5e46233b1b2f \ - --hash=sha256:13311b72ca703a1bbad1ec516ac9ef750019a2d2c421d4c1daf8acf2720b822e \ - --hash=sha256:385001f56f852959f061ae15ad157c39cc4471c8d1d2544dfc3f805684ac2213 \ - --hash=sha256:479ca438d555024dac8dd1058371efbf8f479819a70aea213f3f3037ece99d74 \ - --hash=sha256:5a154723cb6c4e1e7969581a923dacf378f7515b0d53b5f1920e25e51cf6cecc \ - --hash=sha256:6a4b6fda687ca8361322029d58444bc0326798204806a3f90f231dc8ca5541a5 \ - --hash=sha256:7342c8810cc947de78f28c7287a30b2e201b0f51578543dd2553692b79a49942 \ - --hash=sha256:836eb0cd3af612d17bf17efc7eee175c6b9827989d5370df8ba919947fcb67cf \ - --hash=sha256:9968c15b87fd3867b6da0ce30681673a7fc4eedebaadcd24dce892e3f9fe1a52 \ - --hash=sha256:b3383bdc0b9f6260d9adc4ca0d1f68bf241158dfe69d726b267b0681382ea7a7 \ - --hash=sha256:b60bf0bbda24cec6fa71170bd69b613359f01a376d8e09fe34bf67ecc9a3164f \ - --hash=sha256:da7c0f2ef1c697f9ade51a71cfad211e2bff25407a6855dddde372c0190fc468 +jax-cuda12-plugin==0.8.2 ; sys_platform == "linux" and python_version < "3.14" \ + --hash=sha256:0b0a3304ce7e494acd8d9c593490c112a32cdb6010fe1afc584d9e41fd863167 \ + --hash=sha256:1b4828242d57f233b394d17ebaa599c503c1fb9b7c754012a06eb84dbc935fc8 \ + --hash=sha256:20165861b3d3e66ebb2c0f63a547d1d5ee17ea44ac3be7153c7908c9ca8c88f3 \ + --hash=sha256:377e4be17e22dde0343b3f3c05bf69235b3dbf11d766cca9c5a93da47971dcb7 \ + --hash=sha256:403d5e07731b5cdac3bd9fb3f448bd8480062cb2c0ab61ea2ad23fcd0a65479a \ + --hash=sha256:58c51473fc622e03138035985f741833564d70a4bd5a2178f61b62cdaa32ff94 \ + --hash=sha256:637387dc3408cd204562668502f9e95f76c6edde0a6d2e48f055162dc2aebf0d \ + --hash=sha256:70d33222484ad5c375b8f8357b7c23cacb844f6ecfc39567f8dd47fde6e87858 \ + --hash=sha256:82c6798be66bf8c773386918e4c8e5cd8119753f3bfb3ca4bbc46818283750c6 \ + --hash=sha256:a5898bac1d8ab6020b54546440256409f2c66bcbbb3a1099ca473c84843addad \ + --hash=sha256:d68a6d8b4a45ee561746bac7a6468da8203832626b0b39ad4ac43011f61f875d \ + --hash=sha256:dd4f7c34d4512ff5a36fd1b01584ef7781cad615e3f9e71880eae2f4998e5108 # via -r build/requirements.in -jax-cuda13-pjrt==0.8.1 \ - --hash=sha256:86a6926da76aebf6080922747a7a98d321f4ca27101077357fa148032bc3cd1d \ - --hash=sha256:f3b1c1c7118b4570f72740ed756cbed289a3f8fa813570a0dbf16f186bccb8c9 +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 # via # -r build/requirements.in # jax-cuda13-plugin -jax-cuda13-plugin==0.8.1 \ - --hash=sha256:07625aed1aa769c701213e84d6b2a46902019a1d2af8a09ce6dfd9575163bfc6 \ - --hash=sha256:0d503c312d2daefea62a00c74534579deeacd46e15c364074d27a8d95a100032 \ - --hash=sha256:12a7aac712a7c6dc228ef9991578e85e3bcab7c324193bdfb2b5acf059bae6d6 \ - --hash=sha256:16ee16b13393baf9672b6612566308675cebdc8d785b61fac2b93ce8c97825ff \ - --hash=sha256:4e589ed8197f1bea7e7fd20d866ccc5c2a1276d7acd02224e3a5b07983df61e2 \ - --hash=sha256:64df1f1414d899ab7a84751d6f78515365555b54fb64b3e318bd70519de99c86 \ - --hash=sha256:7a373fd3e5f11ecad01b8add1e277eb6559b4966b0745d92dc91c585579fac35 \ - --hash=sha256:92238530152890c3405addacd1fc021c87022cbf99fa66418cfa2e9f68a5c49d \ - --hash=sha256:a4c5a4a69346be6520c729675d5d80e85d610399f4840d74bdfae9c6ebedc8bc \ - --hash=sha256:af33f737ccf5426155cf5c7d175bf765ca25724b94af5109ef2df891b410f997 \ - --hash=sha256:d81222989fb30496cc6554d42deaca6ab003721f56ff669da7f80d61fea1219d \ - --hash=sha256:ef218c47b2cde8c700ad2a56d04320f9d1490439fc6db20747f56e91de7289c2 +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af # via -r build/requirements.in -jaxlib==0.8.1 \ - --hash=sha256:117f2fe2c19479e560ad85a3ef2fcc0b1d24816456f0d039f865c2acbab63b5a \ - --hash=sha256:1a4001ed3ba9ed5a812da1b16f52eebb5d473a4480c1523828c7bd3dae8d1375 \ - --hash=sha256:1bc76edec2bc74a7adb5e29329ece51a67c57cd011a06d55d07da62fbabe3389 \ - --hash=sha256:22f489fb5c8be0da7be5e4957a10936b3760a169668f8b25c5d09c51c3ef47f6 \ - --hash=sha256:24ec3f3a9c45d6de060020dc94c444d69e18099fab927ea3979ff8cedf0ed2c9 \ - --hash=sha256:4933298fcfb07a5aa2d1fed21c111d07cea50e6f180dba2cdb5463c13fb98f2f \ - --hash=sha256:63fc25c4b5d03256798796a024125e29bcf254acc3eae5dc3239d1c30b86b866 \ - --hash=sha256:7a5d381fad89622750fae29fab83c0847e2931ad8d6a34dc13b28fc4d67f75a3 \ - --hash=sha256:865add56139883405f3f15c9b0de6a64ab8f4aa549dff196b72dbc86be6ccc1f \ - --hash=sha256:88bde0f535eeea6689e0cd57d40b7660d5206ac95c7d42e09562a109b963a49f \ - --hash=sha256:8e118e1fbe714f37a94ba26777c17faab7dca4a33646a3d98cd1d99673bbd6b1 \ - --hash=sha256:90e48973f8dbded7edc8728be84c01ae00412190187fb06622abfa4edd42c0a8 \ - --hash=sha256:92c41c9b9862c08521eb90515a7c5bcc840c6d30f86230cebf94aea2d6a0af81 \ - --hash=sha256:a0349f6e8179dc897d33aeb90ec66b4a8041330fbbba8d071dc6167cd2271539 \ - --hash=sha256:af4924189fc53b69237715b56ebcbfc71bb91ca16184143dcef0d430c8173de6 \ - --hash=sha256:bd697c171ace1e2e9d6ed910a78f385b3c4095cee290b0255aa58848f2acdeab \ - --hash=sha256:bed1e94ae8c7c16bca4476d8d7f582f0d1a102a4e69c3a9bd2069a0dc42274a9 \ - --hash=sha256:c14c8c19a7eb694aa14092b6d2fffb9d2bdd8a603b63d6f26fbeaf129c204f9f \ - --hash=sha256:d245bd6a279c72ca5f796df84cdd64d7c9c8abc4b8d89adf4acf45898dab958b \ - --hash=sha256:f2f11491b077d05249d63813e811401194a41edc8e9cc60af8f4b554057cfad0 \ - --hash=sha256:fdbbf2336c08bbf8f30548e204c8c9d77f8b2a3a5b7fc7985749246feb8852b0 \ - --hash=sha256:ff32b6320d729131efaf22939825b52d75957c84c32af2b0b1bdb33cf27ba75f +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de # via -r build/requirements.in keras==3.12.0 \ --hash=sha256:02b69e007d5df8042286c3bcc2a888539e3e487590ffb08f6be1b4354df50aa8 \ @@ -634,13 +626,13 @@ libclang==18.1.1 \ --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe # via tensorflow -libtpu==0.0.30 ; sys_platform == "linux" and platform_machine == "x86_64" \ - --hash=sha256:26442f0a51d243cf7259407bba8f5d849c9024297efe97044d64b5244283ad63 \ - --hash=sha256:5fabff9a041674bb889fb59ac0b5c54b9dbcf492a8c782e083ef86a8194dbb0f \ - --hash=sha256:8be30562743a63c1c1353e7ba78f0dbfbb051e8d1e9d3bb2b5da9b720363bb0a \ - --hash=sha256:b1fc44915dad56c0ceb733311a4d4396b88dc9a1c7c01acd7617da90e7ec22f2 \ - --hash=sha256:babab04ca663da2c4e4b3ab036c4d465f2f4674c480d08239c5d4965b7ce9e1c \ - --hash=sha256:f9aa040895ec25fafebcd4e1a0e1a9524ff3bd778ca88543731e308f6e516dd1 +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 # via -r build/requirements.in markdown==3.10 \ --hash=sha256:37062d4f2aa4b2b6b32aefb80faa300f82cc790cb949a35b8caede34f2b68c0e \ @@ -741,62 +733,62 @@ markupsafe==3.0.3 \ --hash=sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a \ --hash=sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50 # via werkzeug -matplotlib==3.10.7 \ - --hash=sha256:07124afcf7a6504eafcb8ce94091c5898bbdd351519a1beb5c45f7a38c67e77f \ - --hash=sha256:09d7945a70ea43bf9248f4b6582734c2fe726723204a76eca233f24cffc7ef67 \ - --hash=sha256:0d8c32b7ea6fb80b1aeff5a2ceb3fb9778e2759e899d9beff75584714afcc5ee \ - --hash=sha256:11ae579ac83cdf3fb72573bb89f70e0534de05266728740d478f0f818983c695 \ - --hash=sha256:15112bcbaef211bd663fa935ec33313b948e214454d949b723998a43357b17b0 \ - --hash=sha256:1d9d3713a237970569156cfb4de7533b7c4eacdd61789726f444f96a0d28f57f \ - --hash=sha256:1e4bbad66c177a8fdfa53972e5ef8be72a5f27e6a607cec0d8579abd0f3102b1 \ - --hash=sha256:2222c7ba2cbde7fe63032769f6eb7e83ab3227f47d997a8453377709b7fe3a5a \ - --hash=sha256:22df30ffaa89f6643206cf13877191c63a50e8f800b038bc39bee9d2d4957632 \ - --hash=sha256:31963603041634ce1a96053047b40961f7a29eb8f9a62e80cc2c0427aa1d22a2 \ - --hash=sha256:37a1fea41153dd6ee061d21ab69c9cf2cf543160b1b85d89cd3d2e2a7902ca4c \ - --hash=sha256:3886e47f64611046bc1db523a09dd0a0a6bed6081e6f90e13806dd1d1d1b5e91 \ - --hash=sha256:4645fc5d9d20ffa3a39361fcdbcec731382763b623b72627806bf251b6388866 \ - --hash=sha256:4a11c2e9e72e7de09b7b72e62f3df23317c888299c875e2b778abf1eda8c0a42 \ - --hash=sha256:4a74f79fafb2e177f240579bc83f0b60f82cc47d2f1d260f422a0627207008ca \ - --hash=sha256:4c14b6acd16cddc3569a2d515cfdd81c7a68ac5639b76548cfc1a9e48b20eb65 \ - --hash=sha256:53b492410a6cd66c7a471de6c924f6ede976e963c0f3097a3b7abfadddc67d0a \ - --hash=sha256:53cc80662dd197ece414dd5b66e07370201515a3eaf52e7c518c68c16814773b \ - --hash=sha256:5c09cf8f2793f81368f49f118b6f9f937456362bee282eac575cca7f84cda537 \ - --hash=sha256:5e38c2d581d62ee729a6e144c47a71b3f42fb4187508dbbf4fe71d5612c3433b \ - --hash=sha256:5f3f6d315dcc176ba7ca6e74c7768fb7e4cf566c49cb143f6bc257b62e634ed8 \ - --hash=sha256:6516ce375109c60ceec579e699524e9d504cd7578506f01150f7a6bc174a775e \ - --hash=sha256:667ecd5d8d37813a845053d8f5bf110b534c3c9f30e69ebd25d4701385935a6d \ - --hash=sha256:6f1851eab59ca082c95df5a500106bad73672645625e04538b3ad0f69471ffcc \ - --hash=sha256:702590829c30aada1e8cef0568ddbffa77ca747b4d6e36c6d173f66e301f89cc \ - --hash=sha256:7146d64f561498764561e9cd0ed64fcf582e570fc519e6f521e2d0cfd43365e1 \ - --hash=sha256:744991e0cc863dd669c8dc9136ca4e6e0082be2070b9d793cbd64bec872a6815 \ - --hash=sha256:786656bb13c237bbcebcd402f65f44dd61ead60ee3deb045af429d889c8dbc67 \ - --hash=sha256:7a0edb7209e21840e8361e91ea84ea676658aa93edd5f8762793dec77a4a6748 \ - --hash=sha256:7ac81eee3b7c266dd92cee1cd658407b16c57eed08c7421fa354ed68234de380 \ - --hash=sha256:90ad854c0a435da3104c01e2c6f0028d7e719b690998a2333d7218db80950722 \ - --hash=sha256:9257be2f2a03415f9105c486d304a321168e61ad450f6153d77c69504ad764bb \ - --hash=sha256:932c55d1fa7af4423422cb6a492a31cbcbdbe68fd1a9a3f545aa5e7a143b5355 \ - --hash=sha256:a06ba7e2a2ef9131c79c49e63dad355d2d878413a0376c1727c8b9335ff731c7 \ - --hash=sha256:aebed7b50aa6ac698c90f60f854b47e48cd2252b30510e7a1feddaf5a3f72cbf \ - --hash=sha256:b172db79759f5f9bc13ef1c3ef8b9ee7b37b0247f987fbbbdaa15e4f87fd46a9 \ - --hash=sha256:b3c4ea4948d93c9c29dc01c0c23eef66f2101bf75158c291b88de6525c55c3d1 \ - --hash=sha256:b498e9e4022f93de2d5a37615200ca01297ceebbb56fe4c833f46862a490f9e3 \ - --hash=sha256:b4d41379b05528091f00e1728004f9a8d7191260f3862178b88e8fd770206318 \ - --hash=sha256:b69676845a0a66f9da30e87f48be36734d6748024b525ec4710be40194282c84 \ - --hash=sha256:c17398b709a6cce3d9fdb1595c33e356d91c098cd9486cb2cc21ea2ea418e715 \ - --hash=sha256:c380371d3c23e0eadf8ebff114445b9f970aff2010198d498d4ab4c3b41eea4f \ - --hash=sha256:cb783436e47fcf82064baca52ce748af71725d0352e1d31564cbe9c95df92b9c \ - --hash=sha256:cc1c51b846aca49a5a8b44fbba6a92d583a35c64590ad9e1e950dc88940a4297 \ - --hash=sha256:d0b181e9fa8daf1d9f2d4c547527b167cb8838fc587deabca7b5c01f97199e84 \ - --hash=sha256:d2a959c640cdeecdd2ec3136e8ea0441da59bcaf58d67e9c590740addba2cb68 \ - --hash=sha256:d5f256d49fea31f40f166a5e3131235a5d2f4b7f44520b1cf0baf1ce568ccff0 \ - --hash=sha256:d883460c43e8c6b173fef244a2341f7f7c0e9725c7fe68306e8e44ed9c8fb100 \ - --hash=sha256:d8eb7194b084b12feb19142262165832fc6ee879b945491d1c3d4660748020c4 \ - --hash=sha256:d9749313deb729f08207718d29c86246beb2ea3fdba753595b55901dee5d2fd6 \ - --hash=sha256:de66744b2bb88d5cd27e80dfc2ec9f0517d0a46d204ff98fe9e5f2864eb67657 \ - --hash=sha256:e91f61a064c92c307c5a9dc8c05dc9f8a68f0a3be199d9a002a0622e13f874a1 \ - --hash=sha256:f19410b486fdd139885ace124e57f938c1e6a3210ea13dd29cab58f5d4bc12c7 \ - --hash=sha256:f79d5de970fc90cd5591f60053aecfce1fcd736e0303d9f0bf86be649fa68fb8 \ - --hash=sha256:fba2974df0bf8ce3c995fa84b79cde38326e0f7b5409e7a3a481c1141340bcf7 +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ @@ -926,9 +918,9 @@ numpy==2.1.3 ; python_version == "3.13" \ # tensorboard # tensorflow # tensorstore -numpy-typing-compat==20250818.2.1 \ - --hash=sha256:36e073e82f93a1754526f71f8fc7896fa209e0eb19a6e278a74456ab198e2bda \ - --hash=sha256:7626eda39e42b513d44285a70e1a9f07f13d3b658cc4d4d83671dc134b232de0 +numpy-typing-compat==20251206.2.1 \ + --hash=sha256:703ae61be7877ab0af562298776b89eae609a3985414d92011a39a42350b42e1 \ + --hash=sha256:8a868da29e8d076c2aaef8ea9ebb602af917c9752063cfe7c95d6cb60c7b9ea3 # via optype nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ @@ -1188,9 +1180,9 @@ optree==0.18.0 \ --hash=sha256:fa8e3878a1857761d64f08a23b32140d29754a53f85f7c87186ced2b5b1b49cb \ --hash=sha256:ff7326f36ed70d84c3fd62fb39bc6858f699640b8ab238c3cb8dafe1e200af59 # via keras -optype[numpy]==0.14.0 \ - --hash=sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b \ - --hash=sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6 +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c # via scipy-stubs packaging==25.0 \ --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ @@ -1308,17 +1300,17 @@ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -protobuf==6.33.1 \ - --hash=sha256:023af8449482fa884d88b4563d85e83accab54138ae098924a985bcbb734a213 \ - --hash=sha256:0f4cf01222c0d959c2b399142deb526de420be8236f22c71356e2a544e153c53 \ - --hash=sha256:8fd7d5e0eb08cd5b87fd3df49bc193f5cfd778701f47e11d127d0afc6c39f1d1 \ - --hash=sha256:923aa6d27a92bf44394f6abf7ea0500f38769d4b07f4be41cb52bd8b1123b9ed \ - --hash=sha256:97f65757e8d09870de6fd973aeddb92f85435607235d20b2dfed93405d00c85b \ - --hash=sha256:d595a9fd694fdeb061a62fbe10eb039cc1e444df81ec9bb70c7fc59ebcb1eafa \ - --hash=sha256:df051de4fd7e5e4371334e234c62ba43763f15ab605579e04c7008c05735cd82 \ - --hash=sha256:f8adba2e44cde2d7618996b3fc02341f03f5bc3f2748be72dc7b063319276178 \ - --hash=sha256:f8d3fdbc966aaab1d05046d0240dd94d40f2a8c62856d41eaa141ff64a79de6b \ - --hash=sha256:fe34575f2bdde76ac429ec7b570235bf0c788883e70aee90068e9981806f2490 +protobuf==6.33.2 \ + --hash=sha256:1f8017c48c07ec5859106533b682260ba3d7c5567b1ca1f24297ce03384d1b4f \ + --hash=sha256:2981c58f582f44b6b13173e12bb8656711189c2a70250845f264b877f00b1913 \ + --hash=sha256:56dc370c91fbb8ac85bc13582c9e373569668a290aa2e66a590c2a0d35ddb9e4 \ + --hash=sha256:7109dcc38a680d033ffb8bf896727423528db9163be1b6a02d6a49606dcadbfe \ + --hash=sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c \ + --hash=sha256:87eb388bd2d0f78febd8f4c8779c79247b26a5befad525008e49a6955787ff3d \ + --hash=sha256:8cd7640aee0b7828b6d03ae518b5b4806fdfc1afe8de82f79c3454f8aef29872 \ + --hash=sha256:b5d3b5625192214066d99b2b605f5783483575656784de223f00a8d00754fc0e \ + --hash=sha256:d9b19771ca75935b3a4422957bc518b0cecb978b31d1dd12037b088f6bcc0e43 \ + --hash=sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4 # via # tensorboard # tensorflow @@ -1450,9 +1442,9 @@ scipy==1.16.3 ; python_version >= "3.13" \ # via # -r build/requirements.in # jaxlib -scipy-stubs==1.16.3.0 \ - --hash=sha256:90e5d82ced2183ef3c5c0a28a77df8cc227458624364fa0ff975ad24fa89d6ad \ - --hash=sha256:d6943c085e47a1ed431309f9ca582b6a206a9db808a036132a0bf01ebc34b506 +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ @@ -1496,28 +1488,32 @@ tensorflow==2.20.0 ; python_version < "3.14" \ --hash=sha256:dd71a7e7c3270239f4185915e8f2c5d39608c5e18973d6e1d101b153993841eb \ --hash=sha256:e5f169f8f5130ab255bbe854c5f0ae152e93d3d1ac44f42cb1866003b81a5357 # via -r build/nonfreethreading-requirements.txt -tensorstore==0.1.79 \ - --hash=sha256:0fd6165f3df49abc7c9de029b2b72d74bebd2ff2481a5ced003607eb61c56d3e \ - --hash=sha256:108c0e867aa2c87d4982cc6325a2de0c4f5bd63c2bea18adb193a370c40594ce \ - --hash=sha256:11a2c62694ea9c21770bc5a09938d3d15c4b9662b738ae6e1e513c26ed96251a \ - --hash=sha256:1e8e2d098829919caac6a62cf568902e34789069ceddb28497d6e36ebcb95c0b \ - --hash=sha256:29cf4336153af136ac8ac528e2ed46df19367edae7e14e37bca1a8b7c4848ef2 \ - --hash=sha256:5e152d334bf34fbabdfe8e5bc35b87d1f9947065924ff83c29e659308b36e948 \ - --hash=sha256:608f7178ec6e4e4a3c26545b0a44f44bf83438d04bf2d960cd0e7699eaa99ef6 \ - --hash=sha256:6c98c6b74c00e00eba7969292144e471d5c45d67088f0dc08e3a4c60a15ee191 \ - --hash=sha256:6f8f5a940eab434a951c2dadcc7c0516c7bef6d8b7a7144054f7a0c56152b5f5 \ - --hash=sha256:71aa9b45436d888c37b965f7b71195916d15438119b7dccb66a3b0776bfba367 \ - --hash=sha256:7af9422269c2bfcdecf9dd55309060665ab9c2d7f6c892377ed32c032400feea \ - --hash=sha256:83072ee0e551d6dca582e154b64c8b8066d276ec0759784e3149c28212a61f18 \ - --hash=sha256:847982652273fb7b2d694b789205747aaf3e50ae64738c5cb7b5eb03d86a9947 \ - --hash=sha256:8dad44a8a7f2952a5d0030a8bd868b3cfdff048bd40ab53e7226f3d8b0881c5e \ - --hash=sha256:94d8fc9df1721b0287046aca7209fd5040889cad4202e7b73a1fdb77cd9b71c6 \ - --hash=sha256:97756d2cba3c5ce21e15602c2af5a02521cc0ecda7f9fb6d18da2f3bd51827f4 \ - --hash=sha256:a071c6c255b7e412957a6aa563bc4250242c7894edad06ae6358e3d30b7d88ce \ - --hash=sha256:bbd8c1ab7d2e3c03ded3d40bb373ee9a67668e33a564484927865ce43b210386 \ - --hash=sha256:c4230b8fd29795e88e441f749d881973eca8dadf33c5262b367839fb8891f79b \ - --hash=sha256:c9f2dc3342e4686af98f6e259dc9fb377f1bf657b649c247bf6647bbe4f98090 \ - --hash=sha256:debd435042c00be68ba1fb3cf59325a7babb3f4a3cf4744c87dde346802cbbb4 +tensorstore==0.1.80 \ + --hash=sha256:04c29d979eb8b8ee48f873dc13d2701bfd49425500ffc5b848e4ec55b2548281 \ + --hash=sha256:07e4a84bacf70b78305831897068a9b5ad30326e63bbeb92c4bf7e565fcf5e9e \ + --hash=sha256:1113a6982fc0fa8dda8fcc0495715e647ac3360909a86ff13f2e04564f82d54a \ + --hash=sha256:189d924eaec394c9331e284a9c513ed583e336472a925823b5151cb26f41d091 \ + --hash=sha256:1b2b2ed0051dfab7e25295b14e6620520729e6e2ddf505f98c8d3917569614bf \ + --hash=sha256:246641a8780ee5e04e88bc95c8e31faac6471bab1180d1f5cdc9804b29a77c04 \ + --hash=sha256:4158fe76b96f62d12a37d7868150d836e089b5280b2bdd363c43c5d651f10e26 \ + --hash=sha256:46136fe42ee6dd835d957db37073058aea0b78fdfbe2975941640131b7740824 \ + --hash=sha256:4baee67fce95f29f593fbab4866119347115eaace887732aa92cfcbb9e6b0748 \ + --hash=sha256:53fd121ccd332bc4cc397f7af45889360c668b43dc3ff6bc3264df0f9886c11a \ + --hash=sha256:6b7c5dd434bba4ee08fe46bbbdb25c60dd3d47ccb4b8561a9751cf1526da52b8 \ + --hash=sha256:6c8dbbdd31cbb28eccfb23dbbd4218fe67bfc32e9cb452875a485b81031c949d \ + --hash=sha256:7451b30f99d9f31a2b9d70e6ef61815713dc782c58c6d817f91781341e4dac05 \ + --hash=sha256:8cd11027b5a8b66db8d344085a31a1666c78621dac27039c4d571bc4974804a1 \ + --hash=sha256:9c088e8c9f67c266ef4dae3703bd617f7c0cb0fd98e99c4500692e38a4328140 \ + --hash=sha256:a92505189731fcb03f1c69a84ea4460abb24204bfac1f339448a0621e7def77c \ + --hash=sha256:acb8d52fadcefafef4ef8ecca3fc99b1d0e3c5c5a888766484c3e39f050be7f5 \ + --hash=sha256:b193a7a1c4f455a61e60ed2dd67271a3daab0910ddb4bd9db51390d1b36d9996 \ + --hash=sha256:bc28a58c580253a526a4b6d239d18181ef96f1e285a502dbb03ff15eeec07a5b \ + --hash=sha256:c0529afab3800749dd245843d3bf0d061a109a8edb77fb345f476e8bccda51b8 \ + --hash=sha256:d2b353b0bd53fedd77fc5a12a1c1a91cacc3cf59e3dd785529c5a54b31d1c7b1 \ + --hash=sha256:de63843706fdfe9565a45567238c5b1e55a0b28bbde6524200b31d29043a9a16 \ + --hash=sha256:e93df6d34ff5f0f6be245f4d29b99a7c1eef8ad91b50686adf57a5eeea99cb74 \ + --hash=sha256:f65dfaf9e737a41389e29a5a2ea52ca5d14c8d6f48b402c723d800cd16d322b0 \ + --hash=sha256:f8b51d7e685bbb63f6becd7d2ac8634d5ab67ec7e53038e597182e2db2c7aa90 # via -r build/nonfreethreading-requirements.txt termcolor==3.2.0 \ --hash=sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58 \ @@ -1531,13 +1527,13 @@ typing-extensions==4.15.0 \ # grpcio # optree # tensorflow -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.2 \ + --hash=sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797 \ + --hash=sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd # via requests -werkzeug==3.1.3 \ - --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ - --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e # via tensorboard wheel==0.46.1 \ --hash=sha256:f796f65d72750ccde090663e466d0ca37cd72b62870f7520b96d34cdc07d86d8 \ diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index 7157d2a08ebf..815ab33cc09a 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -116,73 +116,65 @@ execnet==2.1.2 \ --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.20.0 \ - --hash=sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2 \ - --hash=sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4 +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt flatbuffers==25.9.23 \ --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 # via -r build/test-requirements.txt -fonttools==4.60.1 \ - --hash=sha256:022beaea4b73a70295b688f817ddc24ed3e3418b5036ffcd5658141184ef0d0c \ - --hash=sha256:026290e4ec76583881763fac284aca67365e0be9f13a7fb137257096114cb3bc \ - --hash=sha256:0b0835ed15dd5b40d726bb61c846a688f5b4ce2208ec68779bc81860adb5851a \ - --hash=sha256:0eae96373e4b7c9e45d099d7a523444e3554360927225c1cdae221a58a45b856 \ - --hash=sha256:122e1a8ada290423c493491d002f622b1992b1ab0b488c68e31c413390dc7eb2 \ - --hash=sha256:1410155d0e764a4615774e5c2c6fc516259fe3eca5882f034eb9bfdbee056259 \ - --hash=sha256:145daa14bf24824b677b9357c5e44fd8895c2a8f53596e1b9ea3496081dc692c \ - --hash=sha256:1525796c3ffe27bb6268ed2a1bb0dcf214d561dfaf04728abf01489eb5339dce \ - --hash=sha256:154cb6ee417e417bf5f7c42fe25858c9140c26f647c7347c06f0cc2d47eff003 \ - --hash=sha256:2299df884c11162617a66b7c316957d74a18e3758c0274762d2cc87df7bc0272 \ - --hash=sha256:2409d5fb7b55fd70f715e6d34e7a6e4f7511b8ad29a49d6df225ee76da76dd77 \ - --hash=sha256:268ecda8ca6cb5c4f044b1fb9b3b376e8cd1b361cef275082429dc4174907038 \ - --hash=sha256:282dafa55f9659e8999110bd8ed422ebe1c8aecd0dc396550b038e6c9a08b8ea \ - --hash=sha256:2ee06fc57512144d8b0445194c2da9f190f61ad51e230f14836286470c99f854 \ - --hash=sha256:3630e86c484263eaac71d117085d509cbcf7b18f677906824e4bace598fb70d2 \ - --hash=sha256:398447f3d8c0c786cbf1209711e79080a40761eb44b27cdafffb48f52bcec258 \ - --hash=sha256:4ba4bd646e86de16160f0fb72e31c3b9b7d0721c3e5b26b9fa2fc931dfdb2652 \ - --hash=sha256:5664fd1a9ea7f244487ac8f10340c4e37664675e8667d6fee420766e0fb3cf08 \ - --hash=sha256:583b7f8e3c49486e4d489ad1deacfb8d5be54a8ef34d6df824f6a171f8511d99 \ - --hash=sha256:596ecaca36367027d525b3b426d8a8208169d09edcf8c7506aceb3a38bfb55c7 \ - --hash=sha256:5c1015318e4fec75dd4943ad5f6a206d9727adf97410d58b7e32ab644a807914 \ - --hash=sha256:66929e2ea2810c6533a5184f938502cfdaea4bc3efb7130d8cc02e1c1b4108d6 \ - --hash=sha256:6ec722ee589e89a89f5b7574f5c45604030aa6ae24cb2c751e2707193b466fed \ - --hash=sha256:6f68576bb4bbf6060c7ab047b1574a1ebe5c50a17de62830079967b211059ebb \ - --hash=sha256:7473a8ed9ed09aeaa191301244a5a9dbe46fe0bf54f9d6cd21d83044c3321217 \ - --hash=sha256:7b0c6d57ab00dae9529f3faf187f2254ea0aa1e04215cf2f1a8ec277c96661bc \ - --hash=sha256:7b4c32e232a71f63a5d00259ca3d88345ce2a43295bb049d21061f338124246f \ - --hash=sha256:8177ec9676ea6e1793c8a084a90b65a9f778771998eb919d05db6d4b1c0b114c \ - --hash=sha256:839565cbf14645952d933853e8ade66a463684ed6ed6c9345d0faf1f0e868877 \ - --hash=sha256:875cb7764708b3132637f6c5fb385b16eeba0f7ac9fa45a69d35e09b47045801 \ - --hash=sha256:8a44788d9d91df72d1a5eac49b31aeb887a5f4aab761b4cffc4196c74907ea85 \ - --hash=sha256:8b4eb332f9501cb1cd3d4d099374a1e1306783ff95489a1026bde9eb02ccc34a \ - --hash=sha256:906306ac7afe2156fcf0042173d6ebbb05416af70f6b370967b47f8f00103bbb \ - --hash=sha256:992775c9fbe2cf794786fa0ffca7f09f564ba3499b8fe9f2f80bd7197db60383 \ - --hash=sha256:996a4d1834524adbb423385d5a629b868ef9d774670856c63c9a0408a3063401 \ - --hash=sha256:9a52f254ce051e196b8fe2af4634c2d2f02c981756c6464dc192f1b6050b4e28 \ - --hash=sha256:9d0ced62b59e0430b3690dbc5373df1c2aa7585e9a8ce38eff87f0fd993c5b01 \ - --hash=sha256:a140761c4ff63d0cb9256ac752f230460ee225ccef4ad8f68affc723c88e2036 \ - --hash=sha256:a184b2ea57b13680ab6d5fbde99ccef152c95c06746cb7718c583abd8f945ccc \ - --hash=sha256:a3db56f153bd4c5c2b619ab02c5db5192e222150ce5a1bc10f16164714bc39ac \ - --hash=sha256:a46b2f450bc79e06ef3b6394f0c68660529ed51692606ad7f953fc2e448bc903 \ - --hash=sha256:a884aef09d45ba1206712c7dbda5829562d3fea7726935d3289d343232ecb0d3 \ - --hash=sha256:b2cf105cee600d2de04ca3cfa1f74f1127f8455b71dbad02b9da6ec266e116d6 \ - --hash=sha256:b33a7884fabd72bdf5f910d0cf46be50dce86a0362a65cfc746a4168c67eb96c \ - --hash=sha256:b42d86938e8dda1cd9a1a87a6d82f1818eaf933348429653559a458d027446da \ - --hash=sha256:b6379e7546ba4ae4b18f8ae2b9bc5960936007a1c0e30b342f662577e8bc3299 \ - --hash=sha256:c7420a2696a44650120cdd269a5d2e56a477e2bfa9d95e86229059beb1c19e15 \ - --hash=sha256:c8651e0d4b3bdeda6602b85fdc2abbefc1b41e573ecb37b6779c4ca50753a199 \ - --hash=sha256:d066ea419f719ed87bc2c99a4a4bfd77c2e5949cb724588b9dd58f3fd90b92bf \ - --hash=sha256:e6c58beb17380f7c2ea181ea11e7db8c0ceb474c9dd45f48e71e2cb577d146a1 \ - --hash=sha256:e852d9dda9f93ad3651ae1e3bb770eac544ec93c3807888798eccddf84596537 \ - --hash=sha256:ec3681a0cb34c255d76dd9d865a55f260164adb9fa02628415cdc2d43ee2c05d \ - --hash=sha256:ee0c0b3b35b34f782afc673d503167157094a16f442ace7c6c5e0ca80b08f50c \ - --hash=sha256:eedacb5c5d22b7097482fa834bda0dafa3d914a4e829ec83cdea2a01f8c813c4 \ - --hash=sha256:ef00af0439ebfee806b25f24c8f92109157ff3fac5731dc7867957812e87b8d9 \ - --hash=sha256:f0e8817c7d1a0c2eedebf57ef9a9896f3ea23324769a9a2061a80fe8852705ed \ - --hash=sha256:f3d5be054c461d6a2268831f04091dc82753176f6ea06dc6047a5e168265a987 \ - --hash=sha256:f4b5c37a5f40e4d733d3bbaaef082149bee5a5ea3156a785ff64d949bd1353fa +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib fsspec==2025.10.0 \ --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ @@ -200,69 +192,69 @@ iniconfig==2.3.0 \ --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -jax-cuda12-pjrt==0.8.1 ; sys_platform == "linux" \ - --hash=sha256:452b70ee10cb9ac5d7dfca55ffbcdb89b6c8bc6ba70a45af7c490d1dcea98eb7 \ - --hash=sha256:a631d0689903354afd7b3d2ec595b7da06a6230a76da00ff9548f542b21b6250 +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 # via # -r build/requirements.in # jax-cuda12-plugin -jax-cuda12-plugin==0.8.1 ; sys_platform == "linux" and python_version < "3.14" \ - --hash=sha256:1052c29157c99ca01d74d3073bbf4f711eb94465c0b4f5a4322d5e46233b1b2f \ - --hash=sha256:13311b72ca703a1bbad1ec516ac9ef750019a2d2c421d4c1daf8acf2720b822e \ - --hash=sha256:385001f56f852959f061ae15ad157c39cc4471c8d1d2544dfc3f805684ac2213 \ - --hash=sha256:479ca438d555024dac8dd1058371efbf8f479819a70aea213f3f3037ece99d74 \ - --hash=sha256:5a154723cb6c4e1e7969581a923dacf378f7515b0d53b5f1920e25e51cf6cecc \ - --hash=sha256:6a4b6fda687ca8361322029d58444bc0326798204806a3f90f231dc8ca5541a5 \ - --hash=sha256:7342c8810cc947de78f28c7287a30b2e201b0f51578543dd2553692b79a49942 \ - --hash=sha256:836eb0cd3af612d17bf17efc7eee175c6b9827989d5370df8ba919947fcb67cf \ - --hash=sha256:9968c15b87fd3867b6da0ce30681673a7fc4eedebaadcd24dce892e3f9fe1a52 \ - --hash=sha256:b3383bdc0b9f6260d9adc4ca0d1f68bf241158dfe69d726b267b0681382ea7a7 \ - --hash=sha256:b60bf0bbda24cec6fa71170bd69b613359f01a376d8e09fe34bf67ecc9a3164f \ - --hash=sha256:da7c0f2ef1c697f9ade51a71cfad211e2bff25407a6855dddde372c0190fc468 +jax-cuda12-plugin==0.8.2 ; sys_platform == "linux" and python_version < "3.14" \ + --hash=sha256:0b0a3304ce7e494acd8d9c593490c112a32cdb6010fe1afc584d9e41fd863167 \ + --hash=sha256:1b4828242d57f233b394d17ebaa599c503c1fb9b7c754012a06eb84dbc935fc8 \ + --hash=sha256:20165861b3d3e66ebb2c0f63a547d1d5ee17ea44ac3be7153c7908c9ca8c88f3 \ + --hash=sha256:377e4be17e22dde0343b3f3c05bf69235b3dbf11d766cca9c5a93da47971dcb7 \ + --hash=sha256:403d5e07731b5cdac3bd9fb3f448bd8480062cb2c0ab61ea2ad23fcd0a65479a \ + --hash=sha256:58c51473fc622e03138035985f741833564d70a4bd5a2178f61b62cdaa32ff94 \ + --hash=sha256:637387dc3408cd204562668502f9e95f76c6edde0a6d2e48f055162dc2aebf0d \ + --hash=sha256:70d33222484ad5c375b8f8357b7c23cacb844f6ecfc39567f8dd47fde6e87858 \ + --hash=sha256:82c6798be66bf8c773386918e4c8e5cd8119753f3bfb3ca4bbc46818283750c6 \ + --hash=sha256:a5898bac1d8ab6020b54546440256409f2c66bcbbb3a1099ca473c84843addad \ + --hash=sha256:d68a6d8b4a45ee561746bac7a6468da8203832626b0b39ad4ac43011f61f875d \ + --hash=sha256:dd4f7c34d4512ff5a36fd1b01584ef7781cad615e3f9e71880eae2f4998e5108 # via -r build/requirements.in -jax-cuda13-pjrt==0.8.1 \ - --hash=sha256:86a6926da76aebf6080922747a7a98d321f4ca27101077357fa148032bc3cd1d \ - --hash=sha256:f3b1c1c7118b4570f72740ed756cbed289a3f8fa813570a0dbf16f186bccb8c9 +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 # via # -r build/requirements.in # jax-cuda13-plugin -jax-cuda13-plugin==0.8.1 \ - --hash=sha256:07625aed1aa769c701213e84d6b2a46902019a1d2af8a09ce6dfd9575163bfc6 \ - --hash=sha256:0d503c312d2daefea62a00c74534579deeacd46e15c364074d27a8d95a100032 \ - --hash=sha256:12a7aac712a7c6dc228ef9991578e85e3bcab7c324193bdfb2b5acf059bae6d6 \ - --hash=sha256:16ee16b13393baf9672b6612566308675cebdc8d785b61fac2b93ce8c97825ff \ - --hash=sha256:4e589ed8197f1bea7e7fd20d866ccc5c2a1276d7acd02224e3a5b07983df61e2 \ - --hash=sha256:64df1f1414d899ab7a84751d6f78515365555b54fb64b3e318bd70519de99c86 \ - --hash=sha256:7a373fd3e5f11ecad01b8add1e277eb6559b4966b0745d92dc91c585579fac35 \ - --hash=sha256:92238530152890c3405addacd1fc021c87022cbf99fa66418cfa2e9f68a5c49d \ - --hash=sha256:a4c5a4a69346be6520c729675d5d80e85d610399f4840d74bdfae9c6ebedc8bc \ - --hash=sha256:af33f737ccf5426155cf5c7d175bf765ca25724b94af5109ef2df891b410f997 \ - --hash=sha256:d81222989fb30496cc6554d42deaca6ab003721f56ff669da7f80d61fea1219d \ - --hash=sha256:ef218c47b2cde8c700ad2a56d04320f9d1490439fc6db20747f56e91de7289c2 +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af # via -r build/requirements.in -jaxlib==0.8.1 \ - --hash=sha256:117f2fe2c19479e560ad85a3ef2fcc0b1d24816456f0d039f865c2acbab63b5a \ - --hash=sha256:1a4001ed3ba9ed5a812da1b16f52eebb5d473a4480c1523828c7bd3dae8d1375 \ - --hash=sha256:1bc76edec2bc74a7adb5e29329ece51a67c57cd011a06d55d07da62fbabe3389 \ - --hash=sha256:22f489fb5c8be0da7be5e4957a10936b3760a169668f8b25c5d09c51c3ef47f6 \ - --hash=sha256:24ec3f3a9c45d6de060020dc94c444d69e18099fab927ea3979ff8cedf0ed2c9 \ - --hash=sha256:4933298fcfb07a5aa2d1fed21c111d07cea50e6f180dba2cdb5463c13fb98f2f \ - --hash=sha256:63fc25c4b5d03256798796a024125e29bcf254acc3eae5dc3239d1c30b86b866 \ - --hash=sha256:7a5d381fad89622750fae29fab83c0847e2931ad8d6a34dc13b28fc4d67f75a3 \ - --hash=sha256:865add56139883405f3f15c9b0de6a64ab8f4aa549dff196b72dbc86be6ccc1f \ - --hash=sha256:88bde0f535eeea6689e0cd57d40b7660d5206ac95c7d42e09562a109b963a49f \ - --hash=sha256:8e118e1fbe714f37a94ba26777c17faab7dca4a33646a3d98cd1d99673bbd6b1 \ - --hash=sha256:90e48973f8dbded7edc8728be84c01ae00412190187fb06622abfa4edd42c0a8 \ - --hash=sha256:92c41c9b9862c08521eb90515a7c5bcc840c6d30f86230cebf94aea2d6a0af81 \ - --hash=sha256:a0349f6e8179dc897d33aeb90ec66b4a8041330fbbba8d071dc6167cd2271539 \ - --hash=sha256:af4924189fc53b69237715b56ebcbfc71bb91ca16184143dcef0d430c8173de6 \ - --hash=sha256:bd697c171ace1e2e9d6ed910a78f385b3c4095cee290b0255aa58848f2acdeab \ - --hash=sha256:bed1e94ae8c7c16bca4476d8d7f582f0d1a102a4e69c3a9bd2069a0dc42274a9 \ - --hash=sha256:c14c8c19a7eb694aa14092b6d2fffb9d2bdd8a603b63d6f26fbeaf129c204f9f \ - --hash=sha256:d245bd6a279c72ca5f796df84cdd64d7c9c8abc4b8d89adf4acf45898dab958b \ - --hash=sha256:f2f11491b077d05249d63813e811401194a41edc8e9cc60af8f4b554057cfad0 \ - --hash=sha256:fdbbf2336c08bbf8f30548e204c8c9d77f8b2a3a5b7fc7985749246feb8852b0 \ - --hash=sha256:ff32b6320d729131efaf22939825b52d75957c84c32af2b0b1bdb33cf27ba75f +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de # via -r build/requirements.in kiwisolver==1.4.9 \ --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ @@ -367,74 +359,74 @@ kiwisolver==1.4.9 \ --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 # via matplotlib -libtpu==0.0.30 ; sys_platform == "linux" and platform_machine == "x86_64" \ - --hash=sha256:26442f0a51d243cf7259407bba8f5d849c9024297efe97044d64b5244283ad63 \ - --hash=sha256:5fabff9a041674bb889fb59ac0b5c54b9dbcf492a8c782e083ef86a8194dbb0f \ - --hash=sha256:8be30562743a63c1c1353e7ba78f0dbfbb051e8d1e9d3bb2b5da9b720363bb0a \ - --hash=sha256:b1fc44915dad56c0ceb733311a4d4396b88dc9a1c7c01acd7617da90e7ec22f2 \ - --hash=sha256:babab04ca663da2c4e4b3ab036c4d465f2f4674c480d08239c5d4965b7ce9e1c \ - --hash=sha256:f9aa040895ec25fafebcd4e1a0e1a9524ff3bd778ca88543731e308f6e516dd1 +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 # via -r build/requirements.in markdown-it-py==4.0.0 \ --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 # via rich -matplotlib==3.10.7 \ - --hash=sha256:07124afcf7a6504eafcb8ce94091c5898bbdd351519a1beb5c45f7a38c67e77f \ - --hash=sha256:09d7945a70ea43bf9248f4b6582734c2fe726723204a76eca233f24cffc7ef67 \ - --hash=sha256:0d8c32b7ea6fb80b1aeff5a2ceb3fb9778e2759e899d9beff75584714afcc5ee \ - --hash=sha256:11ae579ac83cdf3fb72573bb89f70e0534de05266728740d478f0f818983c695 \ - --hash=sha256:15112bcbaef211bd663fa935ec33313b948e214454d949b723998a43357b17b0 \ - --hash=sha256:1d9d3713a237970569156cfb4de7533b7c4eacdd61789726f444f96a0d28f57f \ - --hash=sha256:1e4bbad66c177a8fdfa53972e5ef8be72a5f27e6a607cec0d8579abd0f3102b1 \ - --hash=sha256:2222c7ba2cbde7fe63032769f6eb7e83ab3227f47d997a8453377709b7fe3a5a \ - --hash=sha256:22df30ffaa89f6643206cf13877191c63a50e8f800b038bc39bee9d2d4957632 \ - --hash=sha256:31963603041634ce1a96053047b40961f7a29eb8f9a62e80cc2c0427aa1d22a2 \ - --hash=sha256:37a1fea41153dd6ee061d21ab69c9cf2cf543160b1b85d89cd3d2e2a7902ca4c \ - --hash=sha256:3886e47f64611046bc1db523a09dd0a0a6bed6081e6f90e13806dd1d1d1b5e91 \ - --hash=sha256:4645fc5d9d20ffa3a39361fcdbcec731382763b623b72627806bf251b6388866 \ - --hash=sha256:4a11c2e9e72e7de09b7b72e62f3df23317c888299c875e2b778abf1eda8c0a42 \ - --hash=sha256:4a74f79fafb2e177f240579bc83f0b60f82cc47d2f1d260f422a0627207008ca \ - --hash=sha256:4c14b6acd16cddc3569a2d515cfdd81c7a68ac5639b76548cfc1a9e48b20eb65 \ - --hash=sha256:53b492410a6cd66c7a471de6c924f6ede976e963c0f3097a3b7abfadddc67d0a \ - --hash=sha256:53cc80662dd197ece414dd5b66e07370201515a3eaf52e7c518c68c16814773b \ - --hash=sha256:5c09cf8f2793f81368f49f118b6f9f937456362bee282eac575cca7f84cda537 \ - --hash=sha256:5e38c2d581d62ee729a6e144c47a71b3f42fb4187508dbbf4fe71d5612c3433b \ - --hash=sha256:5f3f6d315dcc176ba7ca6e74c7768fb7e4cf566c49cb143f6bc257b62e634ed8 \ - --hash=sha256:6516ce375109c60ceec579e699524e9d504cd7578506f01150f7a6bc174a775e \ - --hash=sha256:667ecd5d8d37813a845053d8f5bf110b534c3c9f30e69ebd25d4701385935a6d \ - --hash=sha256:6f1851eab59ca082c95df5a500106bad73672645625e04538b3ad0f69471ffcc \ - --hash=sha256:702590829c30aada1e8cef0568ddbffa77ca747b4d6e36c6d173f66e301f89cc \ - --hash=sha256:7146d64f561498764561e9cd0ed64fcf582e570fc519e6f521e2d0cfd43365e1 \ - --hash=sha256:744991e0cc863dd669c8dc9136ca4e6e0082be2070b9d793cbd64bec872a6815 \ - --hash=sha256:786656bb13c237bbcebcd402f65f44dd61ead60ee3deb045af429d889c8dbc67 \ - --hash=sha256:7a0edb7209e21840e8361e91ea84ea676658aa93edd5f8762793dec77a4a6748 \ - --hash=sha256:7ac81eee3b7c266dd92cee1cd658407b16c57eed08c7421fa354ed68234de380 \ - --hash=sha256:90ad854c0a435da3104c01e2c6f0028d7e719b690998a2333d7218db80950722 \ - --hash=sha256:9257be2f2a03415f9105c486d304a321168e61ad450f6153d77c69504ad764bb \ - --hash=sha256:932c55d1fa7af4423422cb6a492a31cbcbdbe68fd1a9a3f545aa5e7a143b5355 \ - --hash=sha256:a06ba7e2a2ef9131c79c49e63dad355d2d878413a0376c1727c8b9335ff731c7 \ - --hash=sha256:aebed7b50aa6ac698c90f60f854b47e48cd2252b30510e7a1feddaf5a3f72cbf \ - --hash=sha256:b172db79759f5f9bc13ef1c3ef8b9ee7b37b0247f987fbbbdaa15e4f87fd46a9 \ - --hash=sha256:b3c4ea4948d93c9c29dc01c0c23eef66f2101bf75158c291b88de6525c55c3d1 \ - --hash=sha256:b498e9e4022f93de2d5a37615200ca01297ceebbb56fe4c833f46862a490f9e3 \ - --hash=sha256:b4d41379b05528091f00e1728004f9a8d7191260f3862178b88e8fd770206318 \ - --hash=sha256:b69676845a0a66f9da30e87f48be36734d6748024b525ec4710be40194282c84 \ - --hash=sha256:c17398b709a6cce3d9fdb1595c33e356d91c098cd9486cb2cc21ea2ea418e715 \ - --hash=sha256:c380371d3c23e0eadf8ebff114445b9f970aff2010198d498d4ab4c3b41eea4f \ - --hash=sha256:cb783436e47fcf82064baca52ce748af71725d0352e1d31564cbe9c95df92b9c \ - --hash=sha256:cc1c51b846aca49a5a8b44fbba6a92d583a35c64590ad9e1e950dc88940a4297 \ - --hash=sha256:d0b181e9fa8daf1d9f2d4c547527b167cb8838fc587deabca7b5c01f97199e84 \ - --hash=sha256:d2a959c640cdeecdd2ec3136e8ea0441da59bcaf58d67e9c590740addba2cb68 \ - --hash=sha256:d5f256d49fea31f40f166a5e3131235a5d2f4b7f44520b1cf0baf1ce568ccff0 \ - --hash=sha256:d883460c43e8c6b173fef244a2341f7f7c0e9725c7fe68306e8e44ed9c8fb100 \ - --hash=sha256:d8eb7194b084b12feb19142262165832fc6ee879b945491d1c3d4660748020c4 \ - --hash=sha256:d9749313deb729f08207718d29c86246beb2ea3fdba753595b55901dee5d2fd6 \ - --hash=sha256:de66744b2bb88d5cd27e80dfc2ec9f0517d0a46d204ff98fe9e5f2864eb67657 \ - --hash=sha256:e91f61a064c92c307c5a9dc8c05dc9f8a68f0a3be199d9a002a0622e13f874a1 \ - --hash=sha256:f19410b486fdd139885ace124e57f938c1e6a3210ea13dd29cab58f5d4bc12c7 \ - --hash=sha256:f79d5de970fc90cd5591f60053aecfce1fcd736e0303d9f0bf86be649fa68fb8 \ - --hash=sha256:fba2974df0bf8ce3c995fa84b79cde38326e0f7b5409e7a3a481c1141340bcf7 +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ @@ -552,9 +544,9 @@ numpy==2.2.6 ; python_version == "3.13" \ # numpy-typing-compat # optype # scipy -numpy-typing-compat==20250818.2.2 \ - --hash=sha256:84f50c86908bf796857180856f1acb7da3c5bf22f461558de1cd225128c028ba \ - --hash=sha256:8b6c551952fd46e887ee905e75b6e4977d97defe1c63ae1b516343e9913e1534 +numpy-typing-compat==20251206.2.2 \ + --hash=sha256:93c9442985ef73dc5a18d29d6bc0f7d47a9afe95372d0a9fc68ca4802ea7ad86 \ + --hash=sha256:9d5bf8bca75a27ee1254fea5a2a783b5c862dd9f3e726d12bd4b6143932effd2 # via optype nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ @@ -710,9 +702,9 @@ opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac # via -r build/requirements.in -optype[numpy]==0.14.0 \ - --hash=sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b \ - --hash=sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6 +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c # via scipy-stubs packaging==25.0 \ --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ @@ -948,9 +940,9 @@ scipy==1.16.3 ; python_version >= "3.13" \ # via # -r build/requirements.in # jaxlib -scipy-stubs==1.16.3.0 \ - --hash=sha256:90e5d82ced2183ef3c5c0a28a77df8cc227458624364fa0ff975ad24fa89d6ad \ - --hash=sha256:d6943c085e47a1ed431309f9ca582b6a206a9db808a036132a0bf01ebc34b506 +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ diff --git a/build/requirements_lock_3_14.txt b/build/requirements_lock_3_14.txt index 3d2bf8d0cc96..24ed1fec57a6 100644 --- a/build/requirements_lock_3_14.txt +++ b/build/requirements_lock_3_14.txt @@ -116,65 +116,65 @@ execnet==2.1.2 \ --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.20.0 \ - --hash=sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2 \ - --hash=sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4 +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt flatbuffers==25.9.23 \ --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 # via -r build/test-requirements.txt -fonttools==4.61.0 \ - --hash=sha256:0011d640afa61053bc6590f9a3394bd222de7cfde19346588beabac374e9d8ac \ - --hash=sha256:02bdf8e04d1a70476564b8640380f04bb4ac74edc1fc71f1bacb840b3e398ee9 \ - --hash=sha256:0bdcf2e29d65c26299cc3d502f4612365e8b90a939f46cd92d037b6cb7bb544a \ - --hash=sha256:13e3e20a5463bfeb77b3557d04b30bd6a96a6bb5c15c7b2e7908903e69d437a0 \ - --hash=sha256:14a290c5c93fcab76b7f451e6a4b7721b712d90b3b5ed6908f1abcf794e90d6d \ - --hash=sha256:14fafda386377b6131d9e448af42d0926bad47e038de0e5ba1d58c25d621f028 \ - --hash=sha256:1cfa2eb9bae650e58f0e8ad53c49d19a844d6034d6b259f30f197238abc1ccee \ - --hash=sha256:276f14c560e6f98d24ef7f5f44438e55ff5a67f78fa85236b218462c9f5d0635 \ - --hash=sha256:2cb5e45a824ce14b90510024d0d39dae51bd4fbb54c42a9334ea8c8cf4d95cbe \ - --hash=sha256:2de14557d113faa5fb519f7f29c3abe4d69c17fe6a5a2595cc8cda7338029219 \ - --hash=sha256:2f0bafc8a3b3749c69cc610e5aa3da832d39c2a37a68f03d18ec9a02ecaac04a \ - --hash=sha256:328a9c227984bebaf69f3ac9062265f8f6acc7ddf2e4e344c63358579af0aa3d \ - --hash=sha256:3b2065d94e5d63aafc2591c8b6ccbdb511001d9619f1bca8ad39b745ebeb5efa \ - --hash=sha256:4238120002e68296d55e091411c09eab94e111c8ce64716d17df53fd0eb3bb3d \ - --hash=sha256:46cb3d9279f758ac0cf671dc3482da877104b65682679f01b246515db03dbb72 \ - --hash=sha256:58b4f1b78dfbfe855bb8a6801b31b8cdcca0e2847ec769ad8e0b0b692832dd3b \ - --hash=sha256:59587bbe455dbdf75354a9dbca1697a35a8903e01fab4248d6b98a17032cee52 \ - --hash=sha256:5a9b78da5d5faa17e63b2404b77feeae105c1b7e75f26020ab7a27b76e02039f \ - --hash=sha256:627216062d90ab0d98215176d8b9562c4dd5b61271d35f130bcd30f6a8aaa33a \ - --hash=sha256:63c7125d31abe3e61d7bb917329b5543c5b3448db95f24081a13aaf064360fc8 \ - --hash=sha256:6781e7a4bb010be1cd69a29927b0305c86b843395f2613bdabe115f7d6ea7f34 \ - --hash=sha256:67d841aa272be5500de7f447c40d1d8452783af33b4c3599899319f6ef9ad3c1 \ - --hash=sha256:68704a8bbe0b61976262b255e90cde593dc0fe3676542d9b4d846bad2a890a76 \ - --hash=sha256:6b493c32d2555e9944ec1b911ea649ff8f01a649ad9cba6c118d6798e932b3f0 \ - --hash=sha256:6e5ca8c62efdec7972dfdfd454415c4db49b89aeaefaaacada432f3b7eea9866 \ - --hash=sha256:70e2a0c0182ee75e493ef33061bfebf140ea57e035481d2f95aa03b66c7a0e05 \ - --hash=sha256:787ef9dfd1ea9fe49573c272412ae5f479d78e671981819538143bec65863865 \ - --hash=sha256:7b446623c9cd5f14a59493818eaa80255eec2468c27d2c01b56e05357c263195 \ - --hash=sha256:7fb5b84f48a6a733ca3d7f41aa9551908ccabe8669ffe79586560abcc00a9cfd \ - --hash=sha256:9064b0f55b947e929ac669af5311ab1f26f750214db6dd9a0c97e091e918f486 \ - --hash=sha256:96dfc9bc1f2302224e48e6ee37e656eddbab810b724b52e9d9c13a57a6abad01 \ - --hash=sha256:9821ed77bb676736b88fa87a737c97b6af06e8109667e625a4f00158540ce044 \ - --hash=sha256:a32a16951cbf113d38f1dd8551b277b6e06e0f6f776fece0f99f746d739e1be3 \ - --hash=sha256:a5c5fff72bf31b0e558ed085e4fd7ed96eb85881404ecc39ed2a779e7cf724eb \ - --hash=sha256:ad751319dc532a79bdf628b8439af167181b4210a0cd28a8935ca615d9fdd727 \ - --hash=sha256:adbb4ecee1a779469a77377bbe490565effe8fce6fb2e6f95f064de58f8bac85 \ - --hash=sha256:b2b734d8391afe3c682320840c8191de9bd24e7eb85768dd4dc06ed1b63dbb1b \ - --hash=sha256:b5ca59b7417d149cf24e4c1933c9f44b2957424fc03536f132346d5242e0ebe5 \ - --hash=sha256:b6ceac262cc62bec01b3bb59abccf41b24ef6580869e306a4e88b7e56bb4bdda \ - --hash=sha256:ba774b8cbd8754f54b8eb58124e8bd45f736b2743325ab1a5229698942b9b433 \ - --hash=sha256:c53b47834ae41e8e4829171cc44fec0fdf125545a15f6da41776b926b9645a9a \ - --hash=sha256:c84b430616ed73ce46e9cafd0bf0800e366a3e02fb7e1ad7c1e214dbe3862b1f \ - --hash=sha256:dc25a4a9c1225653e4431a9413d0381b1c62317b0f543bdcec24e1991f612f33 \ - --hash=sha256:df8cbce85cf482eb01f4551edca978c719f099c623277bda8332e5dbe7dba09d \ - --hash=sha256:e074bc07c31406f45c418e17c1722e83560f181d122c412fa9e815df0ff74810 \ - --hash=sha256:e0d87e81e4d869549585ba0beb3f033718501c1095004f5e6aef598d13ebc216 \ - --hash=sha256:e24a1565c4e57111ec7f4915f8981ecbb61adf66a55f378fdc00e206059fcfef \ - --hash=sha256:e2bfacb5351303cae9f072ccf3fc6ecb437a6f359c0606bae4b1ab6715201d87 \ - --hash=sha256:e6cd0d9051b8ddaf7385f99dd82ec2a058e2b46cf1f1961e68e1ff20fcbb61af \ - --hash=sha256:ec520a1f0c7758d7a858a00f090c1745f6cde6a7c5e76fb70ea4044a15f712e7 +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib fsspec==2025.10.0 \ --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ @@ -192,53 +192,53 @@ iniconfig==2.3.0 \ --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -jax-cuda12-pjrt==0.8.1 ; sys_platform == "linux" \ - --hash=sha256:452b70ee10cb9ac5d7dfca55ffbcdb89b6c8bc6ba70a45af7c490d1dcea98eb7 \ - --hash=sha256:a631d0689903354afd7b3d2ec595b7da06a6230a76da00ff9548f542b21b6250 +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 # via -r build/requirements.in -jax-cuda13-pjrt==0.8.1 \ - --hash=sha256:86a6926da76aebf6080922747a7a98d321f4ca27101077357fa148032bc3cd1d \ - --hash=sha256:f3b1c1c7118b4570f72740ed756cbed289a3f8fa813570a0dbf16f186bccb8c9 +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 # via # -r build/requirements.in # jax-cuda13-plugin -jax-cuda13-plugin==0.8.1 \ - --hash=sha256:07625aed1aa769c701213e84d6b2a46902019a1d2af8a09ce6dfd9575163bfc6 \ - --hash=sha256:0d503c312d2daefea62a00c74534579deeacd46e15c364074d27a8d95a100032 \ - --hash=sha256:12a7aac712a7c6dc228ef9991578e85e3bcab7c324193bdfb2b5acf059bae6d6 \ - --hash=sha256:16ee16b13393baf9672b6612566308675cebdc8d785b61fac2b93ce8c97825ff \ - --hash=sha256:4e589ed8197f1bea7e7fd20d866ccc5c2a1276d7acd02224e3a5b07983df61e2 \ - --hash=sha256:64df1f1414d899ab7a84751d6f78515365555b54fb64b3e318bd70519de99c86 \ - --hash=sha256:7a373fd3e5f11ecad01b8add1e277eb6559b4966b0745d92dc91c585579fac35 \ - --hash=sha256:92238530152890c3405addacd1fc021c87022cbf99fa66418cfa2e9f68a5c49d \ - --hash=sha256:a4c5a4a69346be6520c729675d5d80e85d610399f4840d74bdfae9c6ebedc8bc \ - --hash=sha256:af33f737ccf5426155cf5c7d175bf765ca25724b94af5109ef2df891b410f997 \ - --hash=sha256:d81222989fb30496cc6554d42deaca6ab003721f56ff669da7f80d61fea1219d \ - --hash=sha256:ef218c47b2cde8c700ad2a56d04320f9d1490439fc6db20747f56e91de7289c2 +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af # via -r build/requirements.in -jaxlib==0.8.1 \ - --hash=sha256:117f2fe2c19479e560ad85a3ef2fcc0b1d24816456f0d039f865c2acbab63b5a \ - --hash=sha256:1a4001ed3ba9ed5a812da1b16f52eebb5d473a4480c1523828c7bd3dae8d1375 \ - --hash=sha256:1bc76edec2bc74a7adb5e29329ece51a67c57cd011a06d55d07da62fbabe3389 \ - --hash=sha256:22f489fb5c8be0da7be5e4957a10936b3760a169668f8b25c5d09c51c3ef47f6 \ - --hash=sha256:24ec3f3a9c45d6de060020dc94c444d69e18099fab927ea3979ff8cedf0ed2c9 \ - --hash=sha256:4933298fcfb07a5aa2d1fed21c111d07cea50e6f180dba2cdb5463c13fb98f2f \ - --hash=sha256:63fc25c4b5d03256798796a024125e29bcf254acc3eae5dc3239d1c30b86b866 \ - --hash=sha256:7a5d381fad89622750fae29fab83c0847e2931ad8d6a34dc13b28fc4d67f75a3 \ - --hash=sha256:865add56139883405f3f15c9b0de6a64ab8f4aa549dff196b72dbc86be6ccc1f \ - --hash=sha256:88bde0f535eeea6689e0cd57d40b7660d5206ac95c7d42e09562a109b963a49f \ - --hash=sha256:8e118e1fbe714f37a94ba26777c17faab7dca4a33646a3d98cd1d99673bbd6b1 \ - --hash=sha256:90e48973f8dbded7edc8728be84c01ae00412190187fb06622abfa4edd42c0a8 \ - --hash=sha256:92c41c9b9862c08521eb90515a7c5bcc840c6d30f86230cebf94aea2d6a0af81 \ - --hash=sha256:a0349f6e8179dc897d33aeb90ec66b4a8041330fbbba8d071dc6167cd2271539 \ - --hash=sha256:af4924189fc53b69237715b56ebcbfc71bb91ca16184143dcef0d430c8173de6 \ - --hash=sha256:bd697c171ace1e2e9d6ed910a78f385b3c4095cee290b0255aa58848f2acdeab \ - --hash=sha256:bed1e94ae8c7c16bca4476d8d7f582f0d1a102a4e69c3a9bd2069a0dc42274a9 \ - --hash=sha256:c14c8c19a7eb694aa14092b6d2fffb9d2bdd8a603b63d6f26fbeaf129c204f9f \ - --hash=sha256:d245bd6a279c72ca5f796df84cdd64d7c9c8abc4b8d89adf4acf45898dab958b \ - --hash=sha256:f2f11491b077d05249d63813e811401194a41edc8e9cc60af8f4b554057cfad0 \ - --hash=sha256:fdbbf2336c08bbf8f30548e204c8c9d77f8b2a3a5b7fc7985749246feb8852b0 \ - --hash=sha256:ff32b6320d729131efaf22939825b52d75957c84c32af2b0b1bdb33cf27ba75f +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de # via -r build/requirements.in kiwisolver==1.4.9 \ --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ @@ -343,74 +343,74 @@ kiwisolver==1.4.9 \ --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 # via matplotlib -libtpu==0.0.30 ; sys_platform == "linux" and platform_machine == "x86_64" \ - --hash=sha256:26442f0a51d243cf7259407bba8f5d849c9024297efe97044d64b5244283ad63 \ - --hash=sha256:5fabff9a041674bb889fb59ac0b5c54b9dbcf492a8c782e083ef86a8194dbb0f \ - --hash=sha256:8be30562743a63c1c1353e7ba78f0dbfbb051e8d1e9d3bb2b5da9b720363bb0a \ - --hash=sha256:b1fc44915dad56c0ceb733311a4d4396b88dc9a1c7c01acd7617da90e7ec22f2 \ - --hash=sha256:babab04ca663da2c4e4b3ab036c4d465f2f4674c480d08239c5d4965b7ce9e1c \ - --hash=sha256:f9aa040895ec25fafebcd4e1a0e1a9524ff3bd778ca88543731e308f6e516dd1 +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 # via -r build/requirements.in markdown-it-py==4.0.0 \ --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 # via rich -matplotlib==3.10.7 \ - --hash=sha256:07124afcf7a6504eafcb8ce94091c5898bbdd351519a1beb5c45f7a38c67e77f \ - --hash=sha256:09d7945a70ea43bf9248f4b6582734c2fe726723204a76eca233f24cffc7ef67 \ - --hash=sha256:0d8c32b7ea6fb80b1aeff5a2ceb3fb9778e2759e899d9beff75584714afcc5ee \ - --hash=sha256:11ae579ac83cdf3fb72573bb89f70e0534de05266728740d478f0f818983c695 \ - --hash=sha256:15112bcbaef211bd663fa935ec33313b948e214454d949b723998a43357b17b0 \ - --hash=sha256:1d9d3713a237970569156cfb4de7533b7c4eacdd61789726f444f96a0d28f57f \ - --hash=sha256:1e4bbad66c177a8fdfa53972e5ef8be72a5f27e6a607cec0d8579abd0f3102b1 \ - --hash=sha256:2222c7ba2cbde7fe63032769f6eb7e83ab3227f47d997a8453377709b7fe3a5a \ - --hash=sha256:22df30ffaa89f6643206cf13877191c63a50e8f800b038bc39bee9d2d4957632 \ - --hash=sha256:31963603041634ce1a96053047b40961f7a29eb8f9a62e80cc2c0427aa1d22a2 \ - --hash=sha256:37a1fea41153dd6ee061d21ab69c9cf2cf543160b1b85d89cd3d2e2a7902ca4c \ - --hash=sha256:3886e47f64611046bc1db523a09dd0a0a6bed6081e6f90e13806dd1d1d1b5e91 \ - --hash=sha256:4645fc5d9d20ffa3a39361fcdbcec731382763b623b72627806bf251b6388866 \ - --hash=sha256:4a11c2e9e72e7de09b7b72e62f3df23317c888299c875e2b778abf1eda8c0a42 \ - --hash=sha256:4a74f79fafb2e177f240579bc83f0b60f82cc47d2f1d260f422a0627207008ca \ - --hash=sha256:4c14b6acd16cddc3569a2d515cfdd81c7a68ac5639b76548cfc1a9e48b20eb65 \ - --hash=sha256:53b492410a6cd66c7a471de6c924f6ede976e963c0f3097a3b7abfadddc67d0a \ - --hash=sha256:53cc80662dd197ece414dd5b66e07370201515a3eaf52e7c518c68c16814773b \ - --hash=sha256:5c09cf8f2793f81368f49f118b6f9f937456362bee282eac575cca7f84cda537 \ - --hash=sha256:5e38c2d581d62ee729a6e144c47a71b3f42fb4187508dbbf4fe71d5612c3433b \ - --hash=sha256:5f3f6d315dcc176ba7ca6e74c7768fb7e4cf566c49cb143f6bc257b62e634ed8 \ - --hash=sha256:6516ce375109c60ceec579e699524e9d504cd7578506f01150f7a6bc174a775e \ - --hash=sha256:667ecd5d8d37813a845053d8f5bf110b534c3c9f30e69ebd25d4701385935a6d \ - --hash=sha256:6f1851eab59ca082c95df5a500106bad73672645625e04538b3ad0f69471ffcc \ - --hash=sha256:702590829c30aada1e8cef0568ddbffa77ca747b4d6e36c6d173f66e301f89cc \ - --hash=sha256:7146d64f561498764561e9cd0ed64fcf582e570fc519e6f521e2d0cfd43365e1 \ - --hash=sha256:744991e0cc863dd669c8dc9136ca4e6e0082be2070b9d793cbd64bec872a6815 \ - --hash=sha256:786656bb13c237bbcebcd402f65f44dd61ead60ee3deb045af429d889c8dbc67 \ - --hash=sha256:7a0edb7209e21840e8361e91ea84ea676658aa93edd5f8762793dec77a4a6748 \ - --hash=sha256:7ac81eee3b7c266dd92cee1cd658407b16c57eed08c7421fa354ed68234de380 \ - --hash=sha256:90ad854c0a435da3104c01e2c6f0028d7e719b690998a2333d7218db80950722 \ - --hash=sha256:9257be2f2a03415f9105c486d304a321168e61ad450f6153d77c69504ad764bb \ - --hash=sha256:932c55d1fa7af4423422cb6a492a31cbcbdbe68fd1a9a3f545aa5e7a143b5355 \ - --hash=sha256:a06ba7e2a2ef9131c79c49e63dad355d2d878413a0376c1727c8b9335ff731c7 \ - --hash=sha256:aebed7b50aa6ac698c90f60f854b47e48cd2252b30510e7a1feddaf5a3f72cbf \ - --hash=sha256:b172db79759f5f9bc13ef1c3ef8b9ee7b37b0247f987fbbbdaa15e4f87fd46a9 \ - --hash=sha256:b3c4ea4948d93c9c29dc01c0c23eef66f2101bf75158c291b88de6525c55c3d1 \ - --hash=sha256:b498e9e4022f93de2d5a37615200ca01297ceebbb56fe4c833f46862a490f9e3 \ - --hash=sha256:b4d41379b05528091f00e1728004f9a8d7191260f3862178b88e8fd770206318 \ - --hash=sha256:b69676845a0a66f9da30e87f48be36734d6748024b525ec4710be40194282c84 \ - --hash=sha256:c17398b709a6cce3d9fdb1595c33e356d91c098cd9486cb2cc21ea2ea418e715 \ - --hash=sha256:c380371d3c23e0eadf8ebff114445b9f970aff2010198d498d4ab4c3b41eea4f \ - --hash=sha256:cb783436e47fcf82064baca52ce748af71725d0352e1d31564cbe9c95df92b9c \ - --hash=sha256:cc1c51b846aca49a5a8b44fbba6a92d583a35c64590ad9e1e950dc88940a4297 \ - --hash=sha256:d0b181e9fa8daf1d9f2d4c547527b167cb8838fc587deabca7b5c01f97199e84 \ - --hash=sha256:d2a959c640cdeecdd2ec3136e8ea0441da59bcaf58d67e9c590740addba2cb68 \ - --hash=sha256:d5f256d49fea31f40f166a5e3131235a5d2f4b7f44520b1cf0baf1ce568ccff0 \ - --hash=sha256:d883460c43e8c6b173fef244a2341f7f7c0e9725c7fe68306e8e44ed9c8fb100 \ - --hash=sha256:d8eb7194b084b12feb19142262165832fc6ee879b945491d1c3d4660748020c4 \ - --hash=sha256:d9749313deb729f08207718d29c86246beb2ea3fdba753595b55901dee5d2fd6 \ - --hash=sha256:de66744b2bb88d5cd27e80dfc2ec9f0517d0a46d204ff98fe9e5f2864eb67657 \ - --hash=sha256:e91f61a064c92c307c5a9dc8c05dc9f8a68f0a3be199d9a002a0622e13f874a1 \ - --hash=sha256:f19410b486fdd139885ace124e57f938c1e6a3210ea13dd29cab58f5d4bc12c7 \ - --hash=sha256:f79d5de970fc90cd5591f60053aecfce1fcd736e0303d9f0bf86be649fa68fb8 \ - --hash=sha256:fba2974df0bf8ce3c995fa84b79cde38326e0f7b5409e7a3a481c1141340bcf7 +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ @@ -549,9 +549,9 @@ numpy==2.3.5 ; python_version >= "3.14" \ # optype # scipy # tensorstore -numpy-typing-compat==20250818.2.3 \ - --hash=sha256:72e83d535b635d668ba7315e43ae80be1469a6faea6fc96d312516f39b3d8fa5 \ - --hash=sha256:930413d34dd9083c0bf418815576222f1c66ea2d68950f447fd27ea1a78b26b0 +numpy-typing-compat==20251206.2.4 \ + --hash=sha256:59882d23aaff054a2536da80564012cdce33487657be4d79c5925bb8705fcabc \ + --hash=sha256:a82e723bd20efaa4cf2886709d4264c144f1f2b609bda83d1545113b7e47a5b5 # via optype nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ @@ -707,9 +707,9 @@ opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac # via -r build/requirements.in -optype[numpy]==0.14.0 \ - --hash=sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b \ - --hash=sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6 +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c # via scipy-stubs packaging==25.0 \ --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ @@ -945,9 +945,9 @@ scipy==1.16.3 ; python_version >= "3.13" \ # via # -r build/requirements.in # jaxlib -scipy-stubs==1.16.3.0 \ - --hash=sha256:90e5d82ced2183ef3c5c0a28a77df8cc227458624364fa0ff975ad24fa89d6ad \ - --hash=sha256:d6943c085e47a1ed431309f9ca582b6a206a9db808a036132a0bf01ebc34b506 +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ @@ -957,28 +957,32 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis -tensorstore==0.1.79 \ - --hash=sha256:0fd6165f3df49abc7c9de029b2b72d74bebd2ff2481a5ced003607eb61c56d3e \ - --hash=sha256:108c0e867aa2c87d4982cc6325a2de0c4f5bd63c2bea18adb193a370c40594ce \ - --hash=sha256:11a2c62694ea9c21770bc5a09938d3d15c4b9662b738ae6e1e513c26ed96251a \ - --hash=sha256:1e8e2d098829919caac6a62cf568902e34789069ceddb28497d6e36ebcb95c0b \ - --hash=sha256:29cf4336153af136ac8ac528e2ed46df19367edae7e14e37bca1a8b7c4848ef2 \ - --hash=sha256:5e152d334bf34fbabdfe8e5bc35b87d1f9947065924ff83c29e659308b36e948 \ - --hash=sha256:608f7178ec6e4e4a3c26545b0a44f44bf83438d04bf2d960cd0e7699eaa99ef6 \ - --hash=sha256:6c98c6b74c00e00eba7969292144e471d5c45d67088f0dc08e3a4c60a15ee191 \ - --hash=sha256:6f8f5a940eab434a951c2dadcc7c0516c7bef6d8b7a7144054f7a0c56152b5f5 \ - --hash=sha256:71aa9b45436d888c37b965f7b71195916d15438119b7dccb66a3b0776bfba367 \ - --hash=sha256:7af9422269c2bfcdecf9dd55309060665ab9c2d7f6c892377ed32c032400feea \ - --hash=sha256:83072ee0e551d6dca582e154b64c8b8066d276ec0759784e3149c28212a61f18 \ - --hash=sha256:847982652273fb7b2d694b789205747aaf3e50ae64738c5cb7b5eb03d86a9947 \ - --hash=sha256:8dad44a8a7f2952a5d0030a8bd868b3cfdff048bd40ab53e7226f3d8b0881c5e \ - --hash=sha256:94d8fc9df1721b0287046aca7209fd5040889cad4202e7b73a1fdb77cd9b71c6 \ - --hash=sha256:97756d2cba3c5ce21e15602c2af5a02521cc0ecda7f9fb6d18da2f3bd51827f4 \ - --hash=sha256:a071c6c255b7e412957a6aa563bc4250242c7894edad06ae6358e3d30b7d88ce \ - --hash=sha256:bbd8c1ab7d2e3c03ded3d40bb373ee9a67668e33a564484927865ce43b210386 \ - --hash=sha256:c4230b8fd29795e88e441f749d881973eca8dadf33c5262b367839fb8891f79b \ - --hash=sha256:c9f2dc3342e4686af98f6e259dc9fb377f1bf657b649c247bf6647bbe4f98090 \ - --hash=sha256:debd435042c00be68ba1fb3cf59325a7babb3f4a3cf4744c87dde346802cbbb4 +tensorstore==0.1.80 \ + --hash=sha256:04c29d979eb8b8ee48f873dc13d2701bfd49425500ffc5b848e4ec55b2548281 \ + --hash=sha256:07e4a84bacf70b78305831897068a9b5ad30326e63bbeb92c4bf7e565fcf5e9e \ + --hash=sha256:1113a6982fc0fa8dda8fcc0495715e647ac3360909a86ff13f2e04564f82d54a \ + --hash=sha256:189d924eaec394c9331e284a9c513ed583e336472a925823b5151cb26f41d091 \ + --hash=sha256:1b2b2ed0051dfab7e25295b14e6620520729e6e2ddf505f98c8d3917569614bf \ + --hash=sha256:246641a8780ee5e04e88bc95c8e31faac6471bab1180d1f5cdc9804b29a77c04 \ + --hash=sha256:4158fe76b96f62d12a37d7868150d836e089b5280b2bdd363c43c5d651f10e26 \ + --hash=sha256:46136fe42ee6dd835d957db37073058aea0b78fdfbe2975941640131b7740824 \ + --hash=sha256:4baee67fce95f29f593fbab4866119347115eaace887732aa92cfcbb9e6b0748 \ + --hash=sha256:53fd121ccd332bc4cc397f7af45889360c668b43dc3ff6bc3264df0f9886c11a \ + --hash=sha256:6b7c5dd434bba4ee08fe46bbbdb25c60dd3d47ccb4b8561a9751cf1526da52b8 \ + --hash=sha256:6c8dbbdd31cbb28eccfb23dbbd4218fe67bfc32e9cb452875a485b81031c949d \ + --hash=sha256:7451b30f99d9f31a2b9d70e6ef61815713dc782c58c6d817f91781341e4dac05 \ + --hash=sha256:8cd11027b5a8b66db8d344085a31a1666c78621dac27039c4d571bc4974804a1 \ + --hash=sha256:9c088e8c9f67c266ef4dae3703bd617f7c0cb0fd98e99c4500692e38a4328140 \ + --hash=sha256:a92505189731fcb03f1c69a84ea4460abb24204bfac1f339448a0621e7def77c \ + --hash=sha256:acb8d52fadcefafef4ef8ecca3fc99b1d0e3c5c5a888766484c3e39f050be7f5 \ + --hash=sha256:b193a7a1c4f455a61e60ed2dd67271a3daab0910ddb4bd9db51390d1b36d9996 \ + --hash=sha256:bc28a58c580253a526a4b6d239d18181ef96f1e285a502dbb03ff15eeec07a5b \ + --hash=sha256:c0529afab3800749dd245843d3bf0d061a109a8edb77fb345f476e8bccda51b8 \ + --hash=sha256:d2b353b0bd53fedd77fc5a12a1c1a91cacc3cf59e3dd785529c5a54b31d1c7b1 \ + --hash=sha256:de63843706fdfe9565a45567238c5b1e55a0b28bbde6524200b31d29043a9a16 \ + --hash=sha256:e93df6d34ff5f0f6be245f4d29b99a7c1eef8ad91b50686adf57a5eeea99cb74 \ + --hash=sha256:f65dfaf9e737a41389e29a5a2ea52ca5d14c8d6f48b402c723d800cd16d322b0 \ + --hash=sha256:f8b51d7e685bbb63f6becd7d2ac8634d5ab67ec7e53038e597182e2db2c7aa90 # via -r build/nonfreethreading-requirements.txt typing-extensions==4.15.0 \ --hash=sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466 \ diff --git a/build/requirements_lock_3_14_ft.txt b/build/requirements_lock_3_14_ft.txt index 5dd84e11ee63..a5d099e40ff0 100644 --- a/build/requirements_lock_3_14_ft.txt +++ b/build/requirements_lock_3_14_ft.txt @@ -116,65 +116,65 @@ execnet==2.1.2 \ --hash=sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd \ --hash=sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec # via pytest-xdist -filelock==3.20.0 \ - --hash=sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2 \ - --hash=sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4 +filelock==3.20.1 \ + --hash=sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a \ + --hash=sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c # via -r build/test-requirements.txt flatbuffers==25.9.23 \ --hash=sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2 \ --hash=sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12 # via -r build/test-requirements.txt -fonttools==4.61.0 \ - --hash=sha256:0011d640afa61053bc6590f9a3394bd222de7cfde19346588beabac374e9d8ac \ - --hash=sha256:02bdf8e04d1a70476564b8640380f04bb4ac74edc1fc71f1bacb840b3e398ee9 \ - --hash=sha256:0bdcf2e29d65c26299cc3d502f4612365e8b90a939f46cd92d037b6cb7bb544a \ - --hash=sha256:13e3e20a5463bfeb77b3557d04b30bd6a96a6bb5c15c7b2e7908903e69d437a0 \ - --hash=sha256:14a290c5c93fcab76b7f451e6a4b7721b712d90b3b5ed6908f1abcf794e90d6d \ - --hash=sha256:14fafda386377b6131d9e448af42d0926bad47e038de0e5ba1d58c25d621f028 \ - --hash=sha256:1cfa2eb9bae650e58f0e8ad53c49d19a844d6034d6b259f30f197238abc1ccee \ - --hash=sha256:276f14c560e6f98d24ef7f5f44438e55ff5a67f78fa85236b218462c9f5d0635 \ - --hash=sha256:2cb5e45a824ce14b90510024d0d39dae51bd4fbb54c42a9334ea8c8cf4d95cbe \ - --hash=sha256:2de14557d113faa5fb519f7f29c3abe4d69c17fe6a5a2595cc8cda7338029219 \ - --hash=sha256:2f0bafc8a3b3749c69cc610e5aa3da832d39c2a37a68f03d18ec9a02ecaac04a \ - --hash=sha256:328a9c227984bebaf69f3ac9062265f8f6acc7ddf2e4e344c63358579af0aa3d \ - --hash=sha256:3b2065d94e5d63aafc2591c8b6ccbdb511001d9619f1bca8ad39b745ebeb5efa \ - --hash=sha256:4238120002e68296d55e091411c09eab94e111c8ce64716d17df53fd0eb3bb3d \ - --hash=sha256:46cb3d9279f758ac0cf671dc3482da877104b65682679f01b246515db03dbb72 \ - --hash=sha256:58b4f1b78dfbfe855bb8a6801b31b8cdcca0e2847ec769ad8e0b0b692832dd3b \ - --hash=sha256:59587bbe455dbdf75354a9dbca1697a35a8903e01fab4248d6b98a17032cee52 \ - --hash=sha256:5a9b78da5d5faa17e63b2404b77feeae105c1b7e75f26020ab7a27b76e02039f \ - --hash=sha256:627216062d90ab0d98215176d8b9562c4dd5b61271d35f130bcd30f6a8aaa33a \ - --hash=sha256:63c7125d31abe3e61d7bb917329b5543c5b3448db95f24081a13aaf064360fc8 \ - --hash=sha256:6781e7a4bb010be1cd69a29927b0305c86b843395f2613bdabe115f7d6ea7f34 \ - --hash=sha256:67d841aa272be5500de7f447c40d1d8452783af33b4c3599899319f6ef9ad3c1 \ - --hash=sha256:68704a8bbe0b61976262b255e90cde593dc0fe3676542d9b4d846bad2a890a76 \ - --hash=sha256:6b493c32d2555e9944ec1b911ea649ff8f01a649ad9cba6c118d6798e932b3f0 \ - --hash=sha256:6e5ca8c62efdec7972dfdfd454415c4db49b89aeaefaaacada432f3b7eea9866 \ - --hash=sha256:70e2a0c0182ee75e493ef33061bfebf140ea57e035481d2f95aa03b66c7a0e05 \ - --hash=sha256:787ef9dfd1ea9fe49573c272412ae5f479d78e671981819538143bec65863865 \ - --hash=sha256:7b446623c9cd5f14a59493818eaa80255eec2468c27d2c01b56e05357c263195 \ - --hash=sha256:7fb5b84f48a6a733ca3d7f41aa9551908ccabe8669ffe79586560abcc00a9cfd \ - --hash=sha256:9064b0f55b947e929ac669af5311ab1f26f750214db6dd9a0c97e091e918f486 \ - --hash=sha256:96dfc9bc1f2302224e48e6ee37e656eddbab810b724b52e9d9c13a57a6abad01 \ - --hash=sha256:9821ed77bb676736b88fa87a737c97b6af06e8109667e625a4f00158540ce044 \ - --hash=sha256:a32a16951cbf113d38f1dd8551b277b6e06e0f6f776fece0f99f746d739e1be3 \ - --hash=sha256:a5c5fff72bf31b0e558ed085e4fd7ed96eb85881404ecc39ed2a779e7cf724eb \ - --hash=sha256:ad751319dc532a79bdf628b8439af167181b4210a0cd28a8935ca615d9fdd727 \ - --hash=sha256:adbb4ecee1a779469a77377bbe490565effe8fce6fb2e6f95f064de58f8bac85 \ - --hash=sha256:b2b734d8391afe3c682320840c8191de9bd24e7eb85768dd4dc06ed1b63dbb1b \ - --hash=sha256:b5ca59b7417d149cf24e4c1933c9f44b2957424fc03536f132346d5242e0ebe5 \ - --hash=sha256:b6ceac262cc62bec01b3bb59abccf41b24ef6580869e306a4e88b7e56bb4bdda \ - --hash=sha256:ba774b8cbd8754f54b8eb58124e8bd45f736b2743325ab1a5229698942b9b433 \ - --hash=sha256:c53b47834ae41e8e4829171cc44fec0fdf125545a15f6da41776b926b9645a9a \ - --hash=sha256:c84b430616ed73ce46e9cafd0bf0800e366a3e02fb7e1ad7c1e214dbe3862b1f \ - --hash=sha256:dc25a4a9c1225653e4431a9413d0381b1c62317b0f543bdcec24e1991f612f33 \ - --hash=sha256:df8cbce85cf482eb01f4551edca978c719f099c623277bda8332e5dbe7dba09d \ - --hash=sha256:e074bc07c31406f45c418e17c1722e83560f181d122c412fa9e815df0ff74810 \ - --hash=sha256:e0d87e81e4d869549585ba0beb3f033718501c1095004f5e6aef598d13ebc216 \ - --hash=sha256:e24a1565c4e57111ec7f4915f8981ecbb61adf66a55f378fdc00e206059fcfef \ - --hash=sha256:e2bfacb5351303cae9f072ccf3fc6ecb437a6f359c0606bae4b1ab6715201d87 \ - --hash=sha256:e6cd0d9051b8ddaf7385f99dd82ec2a058e2b46cf1f1961e68e1ff20fcbb61af \ - --hash=sha256:ec520a1f0c7758d7a858a00f090c1745f6cde6a7c5e76fb70ea4044a15f712e7 +fonttools==4.61.1 \ + --hash=sha256:0de30bfe7745c0d1ffa2b0b7048fb7123ad0d71107e10ee090fa0b16b9452e87 \ + --hash=sha256:10d88e55330e092940584774ee5e8a6971b01fc2f4d3466a1d6c158230880796 \ + --hash=sha256:11f35ad7805edba3aac1a3710d104592df59f4b957e30108ae0ba6c10b11dd75 \ + --hash=sha256:15acc09befd16a0fb8a8f62bc147e1a82817542d72184acca9ce6e0aeda9fa6d \ + --hash=sha256:17d2bf5d541add43822bcf0c43d7d847b160c9bb01d15d5007d84e2217aaa371 \ + --hash=sha256:2180f14c141d2f0f3da43f3a81bc8aa4684860f6b0e6f9e165a4831f24e6a23b \ + --hash=sha256:21e7c8d76f62ab13c9472ccf74515ca5b9a761d1bde3265152a6dc58700d895b \ + --hash=sha256:41a7170d042e8c0024703ed13b71893519a1a6d6e18e933e3ec7507a2c26a4b2 \ + --hash=sha256:41ed4b5ec103bd306bb68f81dc166e77409e5209443e5773cb4ed837bcc9b0d3 \ + --hash=sha256:497c31ce314219888c0e2fce5ad9178ca83fe5230b01a5006726cdf3ac9f24d9 \ + --hash=sha256:4c1b526c8d3f615a7b1867f38a9410849c8f4aef078535742198e942fba0e9bd \ + --hash=sha256:4d7092bb38c53bbc78e9255a59158b150bcdc115a1e3b3ce0b5f267dc35dd63c \ + --hash=sha256:4f5686e1fe5fce75d82d93c47a438a25bf0d1319d2843a926f741140b2b16e0c \ + --hash=sha256:58b0ee0ab5b1fc9921eccfe11d1435added19d6494dde14e323f25ad2bc30c56 \ + --hash=sha256:5ce02f38a754f207f2f06557523cd39a06438ba3aafc0639c477ac409fc64e37 \ + --hash=sha256:5fade934607a523614726119164ff621e8c30e8fa1ffffbbd358662056ba69f0 \ + --hash=sha256:5fe9fd43882620017add5eabb781ebfbc6998ee49b35bd7f8f79af1f9f99a958 \ + --hash=sha256:64102ca87e84261419c3747a0d20f396eb024bdbeb04c2bfb37e2891f5fadcb5 \ + --hash=sha256:664c5a68ec406f6b1547946683008576ef8b38275608e1cee6c061828171c118 \ + --hash=sha256:6675329885c44657f826ef01d9e4fb33b9158e9d93c537d84ad8399539bc6f69 \ + --hash=sha256:75c1a6dfac6abd407634420c93864a1e274ebc1c7531346d9254c0d8f6ca00f9 \ + --hash=sha256:75da8f28eff26defba42c52986de97b22106cb8f26515b7c22443ebc9c2d3261 \ + --hash=sha256:77efb033d8d7ff233385f30c62c7c79271c8885d5c9657d967ede124671bbdfb \ + --hash=sha256:78a7d3ab09dc47ac1a363a493e6112d8cabed7ba7caad5f54dbe2f08676d1b47 \ + --hash=sha256:7c7db70d57e5e1089a274cbb2b1fd635c9a24de809a231b154965d415d6c6d24 \ + --hash=sha256:8c56c488ab471628ff3bfa80964372fc13504ece601e0d97a78ee74126b2045c \ + --hash=sha256:91669ccac46bbc1d09e9273546181919064e8df73488ea087dcac3e2968df9ba \ + --hash=sha256:9b666a475a65f4e839d3d10473fad6d47e0a9db14a2f4a224029c5bfde58ad2c \ + --hash=sha256:9cfef3ab326780c04d6646f68d4b4742aae222e8b8ea1d627c74e38afcbc9d91 \ + --hash=sha256:a13fc8aeb24bad755eea8f7f9d409438eb94e82cf86b08fe77a03fbc8f6a96b1 \ + --hash=sha256:a75c301f96db737e1c5ed5fd7d77d9c34466de16095a266509e13da09751bd19 \ + --hash=sha256:a76d4cb80f41ba94a6691264be76435e5f72f2cb3cab0b092a6212855f71c2f6 \ + --hash=sha256:aed04cabe26f30c1647ef0e8fbb207516fd40fe9472e9439695f5c6998e60ac5 \ + --hash=sha256:b148b56f5de675ee16d45e769e69f87623a4944f7443850bf9a9376e628a89d2 \ + --hash=sha256:b501c862d4901792adaec7c25b1ecc749e2662543f68bb194c42ba18d6eec98d \ + --hash=sha256:b846a1fcf8beadeb9ea4f44ec5bdde393e2f1569e17d700bfc49cd69bde75881 \ + --hash=sha256:b931ae8f62db78861b0ff1ac017851764602288575d65b8e8ff1963fed419063 \ + --hash=sha256:c33ab3ca9d3ccd581d58e989d67554e42d8d4ded94ab3ade3508455fe70e65f7 \ + --hash=sha256:c6604b735bb12fef8e0efd5578c9fb5d3d8532d5001ea13a19cddf295673ee09 \ + --hash=sha256:d8db08051fc9e7d8bc622f2112511b8107d8f27cd89e2f64ec45e9825e8288da \ + --hash=sha256:d9203500f7c63545b4ce3799319fe4d9feb1a1b89b28d3cb5abd11b9dd64147e \ + --hash=sha256:dc492779501fa723b04d0ab1f5be046797fee17d27700476edc7ee9ae535a61e \ + --hash=sha256:e6bcdf33aec38d16508ce61fd81838f24c83c90a1d1b8c68982857038673d6b8 \ + --hash=sha256:e76ce097e3c57c4bcb67c5aa24a0ecdbd9f74ea9219997a707a4061fbe2707aa \ + --hash=sha256:eff1ac3cc66c2ac7cda1e64b4e2f3ffef474b7335f92fc3833fc632d595fcee6 \ + --hash=sha256:f3cb4a569029b9f291f88aafc927dd53683757e640081ca8c412781ea144565e \ + --hash=sha256:f79b168428351d11e10c5aeb61a74e1851ec221081299f4cf56036a95431c43a \ + --hash=sha256:fa646ecec9528bef693415c79a86e733c70a4965dd938e9a226b0fc64c9d2e6c \ + --hash=sha256:fe2efccb324948a11dd09d22136fe2ac8a97d6c1347cf0b58a911dcd529f66b7 \ + --hash=sha256:fff4f534200a04b4a36e7ae3cb74493afe807b517a09e99cb4faa89a34ed6ecd # via matplotlib fsspec==2025.10.0 \ --hash=sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d \ @@ -192,53 +192,53 @@ iniconfig==2.3.0 \ --hash=sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730 \ --hash=sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12 # via pytest -jax-cuda12-pjrt==0.8.1 ; sys_platform == "linux" \ - --hash=sha256:452b70ee10cb9ac5d7dfca55ffbcdb89b6c8bc6ba70a45af7c490d1dcea98eb7 \ - --hash=sha256:a631d0689903354afd7b3d2ec595b7da06a6230a76da00ff9548f542b21b6250 +jax-cuda12-pjrt==0.8.2 ; sys_platform == "linux" \ + --hash=sha256:717a1b196a642409ce195ddf031c20bbeadcc886f55e49a1d3f4927373aeedae \ + --hash=sha256:e3bab41ca7c48e4163db9e7efd271b3aa85f0fe45f5ed0708d6bbed93a59f977 # via -r build/requirements.in -jax-cuda13-pjrt==0.8.1 \ - --hash=sha256:86a6926da76aebf6080922747a7a98d321f4ca27101077357fa148032bc3cd1d \ - --hash=sha256:f3b1c1c7118b4570f72740ed756cbed289a3f8fa813570a0dbf16f186bccb8c9 +jax-cuda13-pjrt==0.8.2 \ + --hash=sha256:370e9cb9a76af6a6b00e8f6c6ae8e5a0158886b65994a114cf3acf56ea1570f3 \ + --hash=sha256:c0ebc284d4bea9c0b278017b4bb62eb871755fe582cbf22ce1fb770d30fb9af6 # via # -r build/requirements.in # jax-cuda13-plugin -jax-cuda13-plugin==0.8.1 \ - --hash=sha256:07625aed1aa769c701213e84d6b2a46902019a1d2af8a09ce6dfd9575163bfc6 \ - --hash=sha256:0d503c312d2daefea62a00c74534579deeacd46e15c364074d27a8d95a100032 \ - --hash=sha256:12a7aac712a7c6dc228ef9991578e85e3bcab7c324193bdfb2b5acf059bae6d6 \ - --hash=sha256:16ee16b13393baf9672b6612566308675cebdc8d785b61fac2b93ce8c97825ff \ - --hash=sha256:4e589ed8197f1bea7e7fd20d866ccc5c2a1276d7acd02224e3a5b07983df61e2 \ - --hash=sha256:64df1f1414d899ab7a84751d6f78515365555b54fb64b3e318bd70519de99c86 \ - --hash=sha256:7a373fd3e5f11ecad01b8add1e277eb6559b4966b0745d92dc91c585579fac35 \ - --hash=sha256:92238530152890c3405addacd1fc021c87022cbf99fa66418cfa2e9f68a5c49d \ - --hash=sha256:a4c5a4a69346be6520c729675d5d80e85d610399f4840d74bdfae9c6ebedc8bc \ - --hash=sha256:af33f737ccf5426155cf5c7d175bf765ca25724b94af5109ef2df891b410f997 \ - --hash=sha256:d81222989fb30496cc6554d42deaca6ab003721f56ff669da7f80d61fea1219d \ - --hash=sha256:ef218c47b2cde8c700ad2a56d04320f9d1490439fc6db20747f56e91de7289c2 +jax-cuda13-plugin==0.8.2 \ + --hash=sha256:0ab3a13f798f973416754cb00b1696a8e6d2425c39496e2145293661bcd28446 \ + --hash=sha256:2c7532f2a6c9dcf50ba4b8ebbe4c78988b8c260695e50f9fdbd68f9417a8bb51 \ + --hash=sha256:358f4b4d0d2b92f62fe7f07d355a9c57c5684da3d4e6f483afec38b23e66f02e \ + --hash=sha256:6703552ee10b6b41b43fe4fc62b4672bec13d271eb709a29bf14b0fa1ec805dd \ + --hash=sha256:8fe2c5874cbae18a01a0fa722dfbd547b5504e7bf42aaaae6ad3e1b676e8daa2 \ + --hash=sha256:950f2c1f2e0f5e13893c4c27a4df540a96a26ebd15d7954fa85ffb8cc8b1d9a1 \ + --hash=sha256:9835076d0a917345700370cfaee5cfd7876c0b0982e99206a8b4848af0ea6e5e \ + --hash=sha256:bb059d99eb960718358f82f3eaa0829f96fc7d7e978ded821f2c40256a8aee4b \ + --hash=sha256:d5f9470db46aee8a64e9ab7817927335b77c6d3698d5113e0da2dd962c1468c4 \ + --hash=sha256:f1643b216ccb41a2d7c0e240359c729c96c5eca76e4525ffcd6313049e348b9a \ + --hash=sha256:f35f687d9b839365a49c2b0052eaeb6b9f0a9879813b4fa020f8d9caa3ddde46 \ + --hash=sha256:fa6a6faf3130d6ef05f9ee6cbc27e0a2e4db0f23720136403991eccb3a0144af # via -r build/requirements.in -jaxlib==0.8.1 \ - --hash=sha256:117f2fe2c19479e560ad85a3ef2fcc0b1d24816456f0d039f865c2acbab63b5a \ - --hash=sha256:1a4001ed3ba9ed5a812da1b16f52eebb5d473a4480c1523828c7bd3dae8d1375 \ - --hash=sha256:1bc76edec2bc74a7adb5e29329ece51a67c57cd011a06d55d07da62fbabe3389 \ - --hash=sha256:22f489fb5c8be0da7be5e4957a10936b3760a169668f8b25c5d09c51c3ef47f6 \ - --hash=sha256:24ec3f3a9c45d6de060020dc94c444d69e18099fab927ea3979ff8cedf0ed2c9 \ - --hash=sha256:4933298fcfb07a5aa2d1fed21c111d07cea50e6f180dba2cdb5463c13fb98f2f \ - --hash=sha256:63fc25c4b5d03256798796a024125e29bcf254acc3eae5dc3239d1c30b86b866 \ - --hash=sha256:7a5d381fad89622750fae29fab83c0847e2931ad8d6a34dc13b28fc4d67f75a3 \ - --hash=sha256:865add56139883405f3f15c9b0de6a64ab8f4aa549dff196b72dbc86be6ccc1f \ - --hash=sha256:88bde0f535eeea6689e0cd57d40b7660d5206ac95c7d42e09562a109b963a49f \ - --hash=sha256:8e118e1fbe714f37a94ba26777c17faab7dca4a33646a3d98cd1d99673bbd6b1 \ - --hash=sha256:90e48973f8dbded7edc8728be84c01ae00412190187fb06622abfa4edd42c0a8 \ - --hash=sha256:92c41c9b9862c08521eb90515a7c5bcc840c6d30f86230cebf94aea2d6a0af81 \ - --hash=sha256:a0349f6e8179dc897d33aeb90ec66b4a8041330fbbba8d071dc6167cd2271539 \ - --hash=sha256:af4924189fc53b69237715b56ebcbfc71bb91ca16184143dcef0d430c8173de6 \ - --hash=sha256:bd697c171ace1e2e9d6ed910a78f385b3c4095cee290b0255aa58848f2acdeab \ - --hash=sha256:bed1e94ae8c7c16bca4476d8d7f582f0d1a102a4e69c3a9bd2069a0dc42274a9 \ - --hash=sha256:c14c8c19a7eb694aa14092b6d2fffb9d2bdd8a603b63d6f26fbeaf129c204f9f \ - --hash=sha256:d245bd6a279c72ca5f796df84cdd64d7c9c8abc4b8d89adf4acf45898dab958b \ - --hash=sha256:f2f11491b077d05249d63813e811401194a41edc8e9cc60af8f4b554057cfad0 \ - --hash=sha256:fdbbf2336c08bbf8f30548e204c8c9d77f8b2a3a5b7fc7985749246feb8852b0 \ - --hash=sha256:ff32b6320d729131efaf22939825b52d75957c84c32af2b0b1bdb33cf27ba75f +jaxlib==0.8.2 \ + --hash=sha256:023de6f3f56da2af7037970996500586331fdb50b530ecbb54b9666da633bd00 \ + --hash=sha256:05b958f497e49824c432e734bb059723b7dfe69e2ad696a9f9c8ad82fff7c3f8 \ + --hash=sha256:1bfbcf6c3de221784fa4cdb6765a09d71cb4298b15626b3d0409b3dfcd8a8667 \ + --hash=sha256:28eec1a4e0639a0d8702cea3cb70dd3663053dbfa344452994ea48dc6ceadaa5 \ + --hash=sha256:2b9789bd08f8b0cc5a5c12ae896fe432d5942e32e417091b8b5a96a9a6fd5cf1 \ + --hash=sha256:3b16e50c5b730c9dd0a49e55f1acfaa722b00b1af0522a591558dcc0464252f2 \ + --hash=sha256:490bf0cb029c73c65c9431124b86cdc95082dbc1fb76fc549d24d75da33e5454 \ + --hash=sha256:4d006db96be020c8165212a1216372f8acac4ff4f8fb067743d694ef2b301ace \ + --hash=sha256:68108dff0de74adc468016be9a19f80efe48c660c0d5a122287094b44b092afc \ + --hash=sha256:7c304f3a016965b9d1f5239a8a0399a73925f5604fe914c5ca66ecf734bf6422 \ + --hash=sha256:7da8127557c786264049ae55460d1b8d04cc3cdf0403a087f2fc1e6d313ec722 \ + --hash=sha256:964626f581beab31ee6826b228fcc2ec5181b05cecf94a528dff97921c145dbc \ + --hash=sha256:a397ea7dcb37d689ce79173eeb99b2f1347637a36be9a27f20ae6848bfc58bfc \ + --hash=sha256:aa8701b6356f098e8452c3cec762fb5f706fcb8f67ffd65964f63982479aa23b \ + --hash=sha256:bb89be452b1b808d3f88fc01c415b364a260be4cc7ac120c038009f6150a32dc \ + --hash=sha256:beffb004e7eeb5c9afb24439e2b2cf45a4ee3e3e8adf45e355edf2af62acf8b8 \ + --hash=sha256:ccf77da917a20935247c990691decfcbdd06c25ef0ac94d914a04aadb22f714c \ + --hash=sha256:dffc22b5b732b9556d92c918b251c61bcc046617c4dbb51e1f7a656587fddffb \ + --hash=sha256:e6a97dfb0232eed9a2bb6e3828e4f682dbac1a7fea840bfda574cae2dbf5faf9 \ + --hash=sha256:f205e91c3a152a2a76c0bc59a6a2de03e87ec261b91e8812922777185e7b08f5 \ + --hash=sha256:f28edac8c226fc07fa3e8af6f9defede8ac2c307429e3291edce8739d39becc9 \ + --hash=sha256:f472cc72e3058e50b5f0230b236d5a1183bf6c3d5423d2a52eff07bcf34908de # via -r build/requirements.in kiwisolver==1.4.9 \ --hash=sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c \ @@ -343,74 +343,74 @@ kiwisolver==1.4.9 \ --hash=sha256:fb940820c63a9590d31d88b815e7a3aa5915cad3ce735ab45f0c730b39547de1 \ --hash=sha256:fc1795ac5cd0510207482c3d1d3ed781143383b8cfd36f5c645f3897ce066220 # via matplotlib -libtpu==0.0.30 ; sys_platform == "linux" and platform_machine == "x86_64" \ - --hash=sha256:26442f0a51d243cf7259407bba8f5d849c9024297efe97044d64b5244283ad63 \ - --hash=sha256:5fabff9a041674bb889fb59ac0b5c54b9dbcf492a8c782e083ef86a8194dbb0f \ - --hash=sha256:8be30562743a63c1c1353e7ba78f0dbfbb051e8d1e9d3bb2b5da9b720363bb0a \ - --hash=sha256:b1fc44915dad56c0ceb733311a4d4396b88dc9a1c7c01acd7617da90e7ec22f2 \ - --hash=sha256:babab04ca663da2c4e4b3ab036c4d465f2f4674c480d08239c5d4965b7ce9e1c \ - --hash=sha256:f9aa040895ec25fafebcd4e1a0e1a9524ff3bd778ca88543731e308f6e516dd1 +libtpu==0.0.32 ; sys_platform == "linux" and platform_machine == "x86_64" \ + --hash=sha256:37f6aefe6d69d24e5e268e74c1e90cd0ce0fa6a796720709445d8938605912ee \ + --hash=sha256:4a4ae6db0a90f0e33902cd345923a5d77dfc5bbab9a3e524c79bf876858104ee \ + --hash=sha256:6453995774b902e43d1caf0d82298a2e90f2f731ddd74d6bef9510d91732775f \ + --hash=sha256:7ffd9cdb89da22d32bb60ce70b88e9e76861c5d704ebf673f32987e19e9913f3 \ + --hash=sha256:98d9713a39d9581c9b10a7bdeaa575cc874a4a7f543cfb40d3a723ee27b428b5 \ + --hash=sha256:d2a0e7abe4a029b79049bc95da20e4c2a64f8ef09823f75286cecfbb4b0b6da1 # via -r build/requirements.in markdown-it-py==4.0.0 \ --hash=sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147 \ --hash=sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3 # via rich -matplotlib==3.10.7 \ - --hash=sha256:07124afcf7a6504eafcb8ce94091c5898bbdd351519a1beb5c45f7a38c67e77f \ - --hash=sha256:09d7945a70ea43bf9248f4b6582734c2fe726723204a76eca233f24cffc7ef67 \ - --hash=sha256:0d8c32b7ea6fb80b1aeff5a2ceb3fb9778e2759e899d9beff75584714afcc5ee \ - --hash=sha256:11ae579ac83cdf3fb72573bb89f70e0534de05266728740d478f0f818983c695 \ - --hash=sha256:15112bcbaef211bd663fa935ec33313b948e214454d949b723998a43357b17b0 \ - --hash=sha256:1d9d3713a237970569156cfb4de7533b7c4eacdd61789726f444f96a0d28f57f \ - --hash=sha256:1e4bbad66c177a8fdfa53972e5ef8be72a5f27e6a607cec0d8579abd0f3102b1 \ - --hash=sha256:2222c7ba2cbde7fe63032769f6eb7e83ab3227f47d997a8453377709b7fe3a5a \ - --hash=sha256:22df30ffaa89f6643206cf13877191c63a50e8f800b038bc39bee9d2d4957632 \ - --hash=sha256:31963603041634ce1a96053047b40961f7a29eb8f9a62e80cc2c0427aa1d22a2 \ - --hash=sha256:37a1fea41153dd6ee061d21ab69c9cf2cf543160b1b85d89cd3d2e2a7902ca4c \ - --hash=sha256:3886e47f64611046bc1db523a09dd0a0a6bed6081e6f90e13806dd1d1d1b5e91 \ - --hash=sha256:4645fc5d9d20ffa3a39361fcdbcec731382763b623b72627806bf251b6388866 \ - --hash=sha256:4a11c2e9e72e7de09b7b72e62f3df23317c888299c875e2b778abf1eda8c0a42 \ - --hash=sha256:4a74f79fafb2e177f240579bc83f0b60f82cc47d2f1d260f422a0627207008ca \ - --hash=sha256:4c14b6acd16cddc3569a2d515cfdd81c7a68ac5639b76548cfc1a9e48b20eb65 \ - --hash=sha256:53b492410a6cd66c7a471de6c924f6ede976e963c0f3097a3b7abfadddc67d0a \ - --hash=sha256:53cc80662dd197ece414dd5b66e07370201515a3eaf52e7c518c68c16814773b \ - --hash=sha256:5c09cf8f2793f81368f49f118b6f9f937456362bee282eac575cca7f84cda537 \ - --hash=sha256:5e38c2d581d62ee729a6e144c47a71b3f42fb4187508dbbf4fe71d5612c3433b \ - --hash=sha256:5f3f6d315dcc176ba7ca6e74c7768fb7e4cf566c49cb143f6bc257b62e634ed8 \ - --hash=sha256:6516ce375109c60ceec579e699524e9d504cd7578506f01150f7a6bc174a775e \ - --hash=sha256:667ecd5d8d37813a845053d8f5bf110b534c3c9f30e69ebd25d4701385935a6d \ - --hash=sha256:6f1851eab59ca082c95df5a500106bad73672645625e04538b3ad0f69471ffcc \ - --hash=sha256:702590829c30aada1e8cef0568ddbffa77ca747b4d6e36c6d173f66e301f89cc \ - --hash=sha256:7146d64f561498764561e9cd0ed64fcf582e570fc519e6f521e2d0cfd43365e1 \ - --hash=sha256:744991e0cc863dd669c8dc9136ca4e6e0082be2070b9d793cbd64bec872a6815 \ - --hash=sha256:786656bb13c237bbcebcd402f65f44dd61ead60ee3deb045af429d889c8dbc67 \ - --hash=sha256:7a0edb7209e21840e8361e91ea84ea676658aa93edd5f8762793dec77a4a6748 \ - --hash=sha256:7ac81eee3b7c266dd92cee1cd658407b16c57eed08c7421fa354ed68234de380 \ - --hash=sha256:90ad854c0a435da3104c01e2c6f0028d7e719b690998a2333d7218db80950722 \ - --hash=sha256:9257be2f2a03415f9105c486d304a321168e61ad450f6153d77c69504ad764bb \ - --hash=sha256:932c55d1fa7af4423422cb6a492a31cbcbdbe68fd1a9a3f545aa5e7a143b5355 \ - --hash=sha256:a06ba7e2a2ef9131c79c49e63dad355d2d878413a0376c1727c8b9335ff731c7 \ - --hash=sha256:aebed7b50aa6ac698c90f60f854b47e48cd2252b30510e7a1feddaf5a3f72cbf \ - --hash=sha256:b172db79759f5f9bc13ef1c3ef8b9ee7b37b0247f987fbbbdaa15e4f87fd46a9 \ - --hash=sha256:b3c4ea4948d93c9c29dc01c0c23eef66f2101bf75158c291b88de6525c55c3d1 \ - --hash=sha256:b498e9e4022f93de2d5a37615200ca01297ceebbb56fe4c833f46862a490f9e3 \ - --hash=sha256:b4d41379b05528091f00e1728004f9a8d7191260f3862178b88e8fd770206318 \ - --hash=sha256:b69676845a0a66f9da30e87f48be36734d6748024b525ec4710be40194282c84 \ - --hash=sha256:c17398b709a6cce3d9fdb1595c33e356d91c098cd9486cb2cc21ea2ea418e715 \ - --hash=sha256:c380371d3c23e0eadf8ebff114445b9f970aff2010198d498d4ab4c3b41eea4f \ - --hash=sha256:cb783436e47fcf82064baca52ce748af71725d0352e1d31564cbe9c95df92b9c \ - --hash=sha256:cc1c51b846aca49a5a8b44fbba6a92d583a35c64590ad9e1e950dc88940a4297 \ - --hash=sha256:d0b181e9fa8daf1d9f2d4c547527b167cb8838fc587deabca7b5c01f97199e84 \ - --hash=sha256:d2a959c640cdeecdd2ec3136e8ea0441da59bcaf58d67e9c590740addba2cb68 \ - --hash=sha256:d5f256d49fea31f40f166a5e3131235a5d2f4b7f44520b1cf0baf1ce568ccff0 \ - --hash=sha256:d883460c43e8c6b173fef244a2341f7f7c0e9725c7fe68306e8e44ed9c8fb100 \ - --hash=sha256:d8eb7194b084b12feb19142262165832fc6ee879b945491d1c3d4660748020c4 \ - --hash=sha256:d9749313deb729f08207718d29c86246beb2ea3fdba753595b55901dee5d2fd6 \ - --hash=sha256:de66744b2bb88d5cd27e80dfc2ec9f0517d0a46d204ff98fe9e5f2864eb67657 \ - --hash=sha256:e91f61a064c92c307c5a9dc8c05dc9f8a68f0a3be199d9a002a0622e13f874a1 \ - --hash=sha256:f19410b486fdd139885ace124e57f938c1e6a3210ea13dd29cab58f5d4bc12c7 \ - --hash=sha256:f79d5de970fc90cd5591f60053aecfce1fcd736e0303d9f0bf86be649fa68fb8 \ - --hash=sha256:fba2974df0bf8ce3c995fa84b79cde38326e0f7b5409e7a3a481c1141340bcf7 +matplotlib==3.10.8 \ + --hash=sha256:00270d217d6b20d14b584c521f810d60c5c78406dc289859776550df837dcda7 \ + --hash=sha256:0a33deb84c15ede243aead39f77e990469fff93ad1521163305095b77b72ce4a \ + --hash=sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f \ + --hash=sha256:12d90df9183093fcd479f4172ac26b322b1248b15729cb57f42f71f24c7e37a3 \ + --hash=sha256:15d30132718972c2c074cd14638c7f4592bd98719e2308bccea40e0538bc0cb5 \ + --hash=sha256:18821ace09c763ec93aef5eeff087ee493a24051936d7b9ebcad9662f66501f9 \ + --hash=sha256:1ae029229a57cd1e8fe542485f27e7ca7b23aa9e8944ddb4985d0bc444f1eca2 \ + --hash=sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3 \ + --hash=sha256:238b7ce5717600615c895050239ec955d91f321c209dd110db988500558e70d6 \ + --hash=sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f \ + --hash=sha256:25d380fe8b1dc32cf8f0b1b448470a77afb195438bafdf1d858bfb876f3edf7b \ + --hash=sha256:2c1998e92cd5999e295a731bcb2911c75f597d937341f3030cc24ef2733d78a8 \ + --hash=sha256:2cf5bd12cecf46908f286d7838b2abc6c91cda506c0445b8223a7c19a00df008 \ + --hash=sha256:32f8dce744be5569bebe789e46727946041199030db8aeb2954d26013a0eb26b \ + --hash=sha256:37b3c1cc42aa184b3f738cfa18c1c1d72fd496d85467a6cf7b807936d39aa656 \ + --hash=sha256:3a48a78d2786784cc2413e57397981fb45c79e968d99656706018d6e62e57958 \ + --hash=sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04 \ + --hash=sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b \ + --hash=sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6 \ + --hash=sha256:3f5c3e4da343bba819f0234186b9004faba952cc420fbc522dc4e103c1985908 \ + --hash=sha256:41703cc95688f2516b480f7f339d8851a6035f18e100ee6a32bc0b8536a12a9c \ + --hash=sha256:495672de149445ec1b772ff2c9ede9b769e3cb4f0d0aa7fa730d7f59e2d4e1c1 \ + --hash=sha256:4cf267add95b1c88300d96ca837833d4112756045364f5c734a2276038dae27d \ + --hash=sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1 \ + --hash=sha256:595ba4d8fe983b88f0eec8c26a241e16d6376fe1979086232f481f8f3f67494c \ + --hash=sha256:5f62550b9a30afde8c1c3ae450e5eb547d579dd69b25c2fc7a1c67f934c1717a \ + --hash=sha256:646d95230efb9ca614a7a594d4fcacde0ac61d25e37dd51710b36477594963ce \ + --hash=sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a \ + --hash=sha256:6be43b667360fef5c754dda5d25a32e6307a03c204f3c0fc5468b78fa87b4160 \ + --hash=sha256:6da7c2ce169267d0d066adcf63758f0604aa6c3eebf67458930f9d9b79ad1db1 \ + --hash=sha256:83d282364ea9f3e52363da262ce32a09dfe241e4080dcedda3c0db059d3c1f11 \ + --hash=sha256:9153c3292705be9f9c64498a8872118540c3f4123d1a1c840172edf262c8be4a \ + --hash=sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466 \ + --hash=sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486 \ + --hash=sha256:a2b336e2d91a3d7006864e0990c83b216fcdca64b5a6484912902cef87313d78 \ + --hash=sha256:a48f2b74020919552ea25d222d5cc6af9ca3f4eb43a93e14d068457f545c2a17 \ + --hash=sha256:ad3d9833a64cf48cc4300f2b406c3d0f4f4724a91c0bd5640678a6ba7c102077 \ + --hash=sha256:b44d07310e404ba95f8c25aa5536f154c0a8ec473303535949e52eb71d0a1565 \ + --hash=sha256:b53285e65d4fa4c86399979e956235deb900be5baa7fc1218ea67fbfaeaadd6f \ + --hash=sha256:b5a2b97dbdc7d4f353ebf343744f1d1f1cca8aa8bfddb4262fcf4306c3761d50 \ + --hash=sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58 \ + --hash=sha256:bab485bcf8b1c7d2060b4fcb6fc368a9e6f4cd754c9c2fea281f4be21df394a2 \ + --hash=sha256:c108a1d6fa78a50646029cb6d49808ff0fc1330fda87fa6f6250c6b5369b6645 \ + --hash=sha256:d56a1efd5bfd61486c8bc968fa18734464556f0fb8e51690f4ac25d85cbbbbc2 \ + --hash=sha256:d9050fee89a89ed57b4fb2c1bfac9a3d0c57a0d55aed95949eedbc42070fea39 \ + --hash=sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf \ + --hash=sha256:e8ea3e2d4066083e264e75c829078f9e149fa119d27e19acd503de65e0b13149 \ + --hash=sha256:eb3823f11823deade26ce3b9f40dcb4a213da7a670013929f31d5f5ed1055b22 \ + --hash=sha256:ee40c27c795bda6a5292e9cff9890189d32f7e3a0bf04e0e3c9430c4a00c37df \ + --hash=sha256:efb30e3baaea72ce5928e32bab719ab4770099079d66726a62b11b1ef7273be4 \ + --hash=sha256:f254d118d14a7f99d616271d6c3c27922c092dac11112670b157798b89bf4933 \ + --hash=sha256:f89c151aab2e2e23cb3fe0acad1e8b82841fd265379c4cecd0f3fcb34c15e0f6 \ + --hash=sha256:f97aeb209c3d2511443f8797e3e5a569aebb040d4f8bc79aa3ee78a8fb9e3dd8 \ + --hash=sha256:f9b587c9c7274c1613a30afabf65a272114cd6cdbe67b3406f818c79d7ab2e2a \ + --hash=sha256:fb061f596dad3a0f52b60dc6a5dec4a0c300dec41e058a7efe09256188d170b7 # via -r build/test-requirements.txt mdurl==0.1.2 \ --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ @@ -547,9 +547,9 @@ numpy==2.3.5 ; python_version >= "3.14" \ # numpy-typing-compat # optype # scipy -numpy-typing-compat==20250818.2.3 \ - --hash=sha256:72e83d535b635d668ba7315e43ae80be1469a6faea6fc96d312516f39b3d8fa5 \ - --hash=sha256:930413d34dd9083c0bf418815576222f1c66ea2d68950f447fd27ea1a78b26b0 +numpy-typing-compat==20251206.2.4 \ + --hash=sha256:59882d23aaff054a2536da80564012cdce33487657be4d79c5925bb8705fcabc \ + --hash=sha256:a82e723bd20efaa4cf2886709d4264c144f1f2b609bda83d1545113b7e47a5b5 # via optype nvidia-cublas==13.1.0.3 ; sys_platform == "linux" \ --hash=sha256:2a3b94a37def342471c59fad7856caee4926809a72dd5270155d6a31b5b277be \ @@ -705,9 +705,9 @@ opt-einsum==3.4.0 \ --hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \ --hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac # via -r build/requirements.in -optype[numpy]==0.14.0 \ - --hash=sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b \ - --hash=sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6 +optype[numpy]==0.15.0 \ + --hash=sha256:457d6ca9e7da19967ec16d42bdf94e240b33b5d70a56fbbf5b427e5ea39cf41e \ + --hash=sha256:caba40ece9ea39b499fa76c036a82e0d452a432dd4dd3e8e0d30892be2e8c76c # via scipy-stubs packaging==25.0 \ --hash=sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484 \ @@ -943,9 +943,9 @@ scipy==1.16.3 ; python_version >= "3.13" \ # via # -r build/requirements.in # jaxlib -scipy-stubs==1.16.3.0 \ - --hash=sha256:90e5d82ced2183ef3c5c0a28a77df8cc227458624364fa0ff975ad24fa89d6ad \ - --hash=sha256:d6943c085e47a1ed431309f9ca582b6a206a9db808a036132a0bf01ebc34b506 +scipy-stubs==1.16.3.3 \ + --hash=sha256:af47578875d5557567225a16ec1b9b38a48c4c4377d92396413ebd65406c44ee \ + --hash=sha256:f6316b36cd0fb272c994ae5b10c4a73c644a7e156ed8d32bcd9c35303d0e1b7e # via -r build/test-requirements.txt six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ From ea14daa79a8cf7216c416cb701e14d060cc5b201 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 18 Dec 2025 17:32:51 -0800 Subject: [PATCH 280/315] Fix multi_broadcast_in_dim which was doing replicated -> unreduced casts which was unnecessary. PiperOrigin-RevId: 846477858 --- jax/_src/interpreters/mlir.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 4be9077d4e34..b501d7bdb298 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2675,9 +2675,15 @@ def multi_broadcast_in_dim(ctx: LoweringRuleContext, out_aval = core.ShapedArray( out_shape, op_aval.dtype, sharding=out_sharding) # type: ignore if core.definitely_equal_shape(op_aval_shape, out_shape): - out.append(op if op_aval_sharding == out_sharding else - lower_with_sharding_in_types(ctx, op, out_aval)) + if op_aval_sharding.spec.unreduced or op_aval_sharding.spec.reduced: + out.append(op) + elif op_aval_sharding == out_sharding: + out.append(op) + else: + out.append(lower_with_sharding_in_types(ctx, op, out_aval)) else: + if op_aval_sharding.spec.unreduced or op_aval_sharding.spec.reduced: + raise NotImplementedError() assert len(op_aval_shape) <= len(out_shape), (op_aval_shape, out_shape) broadcast_dimensions = list(range(len(out_shape) - len(op_aval_shape), len(out_shape))) b_out = broadcast_in_dim( From 996cbeb0aee791a7757a78bb66e02715b0884cb6 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 18 Dec 2025 20:23:00 -0800 Subject: [PATCH 281/315] anselm refs PiperOrigin-RevId: 846525618 --- jax/_src/state/primitives.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 42dfc614298f..69d26237db54 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -1116,7 +1116,8 @@ def _ref_lin(nzs, x, *, memory_space, kind): x_ref = core.ref_p.bind(x, memory_space=memory_space, kind=kind) def mut_lin(_, x_dot): if kind == 'anselm_ref': - return ad.Zero(AbstractRef(core.typeof(x_dot))) + aval = x_dot.aval if type(x_dot) is ad.Zero else core.typeof(x_dot) + return ad.Zero(AbstractRef(aval)) zero = ad_util.instantiate(x_dot) return core.ref_p.bind(zero, memory_space=memory_space, kind=kind) return x_ref, kind != 'anselm_ref', None, mut_lin From b54bdaf477063e0218f1f1716b29fc3a4379b3fd Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 18 Dec 2025 21:01:00 -0800 Subject: [PATCH 282/315] Automated Code Change PiperOrigin-RevId: 846536230 --- jaxlib/ffi.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jaxlib/ffi.cc b/jaxlib/ffi.cc index ff1dd96958f5..e5f31079b332 100644 --- a/jaxlib/ffi.cc +++ b/jaxlib/ffi.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include #include +#include +#include +#include #include #include "absl/base/casts.h" From 842804da390fe0eae7ac221bbd789098adf72871 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 19 Dec 2025 00:06:28 -0800 Subject: [PATCH 283/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/c45b8fed642a7ce99e315f979dee6fc45e08f79f PiperOrigin-RevId: 846596619 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 02a6712d7e3d..5adedc2e48ce 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "913ae2eaa3cb88971003592a90959685a78c9e30" -XLA_SHA256 = "b0420fdca3789e659e314cae7ee38d1f13c613c458c00376a9d44dde51740d7f" +XLA_COMMIT = "c45b8fed642a7ce99e315f979dee6fc45e08f79f" +XLA_SHA256 = "75ba4f7a261fa43834791c218b8a8909d003fcf4bf28426f32fabcbf09682352" From fe0c066b0a831f6c481b5f035f20ac2b83fb2e5a Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Fri, 19 Dec 2025 00:52:18 -0800 Subject: [PATCH 284/315] [Pallas:MGPU] Add support for indexing untiled dimensions under WG semantic. Indexing is lowered to `lax.slice_p` + `lax.squeeze_p`. We add the missing rule for `lax.squeeze_p` under WG semantics. We lower `lax.squeeze_p` to `vector.shape_cast`. So `x[1, 1]` where `x` is a JAX array will be lowered to: ``` %1 = vector.extract_strided_slice %0 {offsets = [1, 1, ...], sizes = [1, 1, ...]} : vector to vector<1x1x...> %2 = vector.shape_cast %2 : vector<1x1x...> to vector<...> ``` This gets simplified to: ``` %2 = vector.extract %0[1, 1] : vector<...> from vector ``` So we need to add support for `vector.extract` to the Mosaic GPU dialect. We add the respective lowering and layout inference rules. PiperOrigin-RevId: 846611677 --- jax/_src/pallas/mosaic_gpu/lowering.py | 13 +++++++ .../mosaic/gpu/dialect_lowering.py | 21 ++++++++++++ .../mosaic/gpu/layout_inference.py | 25 ++++++++++++++ tests/mosaic/gpu_layout_inference_test.py | 25 ++++++++++++++ tests/pallas/mosaic_gpu_test.py | 34 ++++++++++++++++--- 5 files changed, 114 insertions(+), 4 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 481d038178d5..edfe33cb8134 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2441,6 +2441,19 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): return _ensure_fa(x, x_aval.dtype).reshape(y_aval.shape) +@register_lowering_rule(lax.squeeze_p, mgpu.LoweringSemantics.Warpgroup) +def _squeeze_lowering_rule_wg(ctx: LoweringRuleContext, x, dimensions): + [x_aval] = ctx.avals_in + [y_aval] = ctx.avals_out + x = _ensure_ir_value(x, x_aval.dtype) + if y_aval.ndim == 0: # scalar + # TODO(allanrenucci): Lower to `vector.extract` once we support scalar + # results in MGPU dialect lowering. + raise NotImplementedError("Squeeze to scalar is not supported.") + res_ty = ir.VectorType.get(y_aval.shape, ir.VectorType(x.type).element_type) + return vector_dialect.shape_cast(res_ty, x) + + def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs): [x_aval] = ctx.avals_in match x.layout: diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 7e3cc39783da..fc3c10d79150 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -659,6 +659,27 @@ def _vector_extract_strided_slice_op_lowering_rule( return [fragmented_array_to_ir(result, out_vec_ty)] +@_register_lowering(vector.ExtractOp) +def _vector_extract_op_lowering_rule( + ctx: LoweringContext, op: vector.ExtractOp +) -> Sequence[ir.Value]: + del ctx + if not ir.VectorType.isinstance(op.result.type): + raise NotImplementedError("Scalar element extraction is not supported.") + if op.dynamic_position: + raise NotImplementedError("Only slicing with static indices allowed.") + [in_layout] = inference_utils.in_layouts(op) + [out_layout] = inference_utils.out_layouts(op) + assert in_layout == out_layout + a = _fragmented_array_from_ir(op.source, in_layout) + result_type = ir.VectorType(op.result.type) + slices = tuple(slice(i, i + 1) for i in op.static_position) + # TODO(allanrenucci): Add direct support for indexing to FragmentedArray. + result = a[slices].reshape(tuple(result_type.shape)) + assert result.layout == layouts.from_layout_attr(out_layout) + return [fragmented_array_to_ir(result, result_type)] + + @_register_lowering(vector.ReductionOp) def _vector_reduction_op_lowering_rule( ctx: LoweringContext, op: vector.ReductionOp diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 542811ee509f..280200d5ca3d 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -1103,6 +1103,31 @@ def _extract_strided_slice_constraint_system( ) +@_add_constraint_system_derivation_rule(vector.ExtractOp) +def _vector_extract_constraint_system( + ctx: DerivationContext, op: vector.ExtractOp +) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: + del ctx + if not ir.VectorType.isinstance(op.result.type): + raise NotImplementedError("Scalar element extraction is not supported.") + if op.dynamic_position: + raise NotImplementedError("Only slicing with static indices allowed.") + operand = ValueSite(op, VariableType.OPERAND, 0) + result = ValueSite(op, VariableType.RESULT, 0) + variable = cs.Variable(operand) + constraints = [ + cs.Divides(variable, tuple(op.result.type.shape)), + # TODO(allanrenucci): Remove once vectors with splat and strided layouts + # can be sliced. + cs.NotOfType(variable, fa.WGSplatFragLayout), + cs.NotOfType(variable, fa.WGStridedFragLayout), + ] + return ( + cs.ConstraintSystem(constraints=constraints), + {variable: [operand, result]}, + ) + + @_add_constraint_system_derivation_rule(mgpu.CustomPrimitiveOp) def _custom_primitive_constraint_system( ctx: DerivationContext, diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index c52a0bdc1d46..ed7ae72c1ced 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -1979,6 +1979,31 @@ def test_infer_layout_for_vector_extract_strided_slice_fails( ): mgpu.infer_layout(self.module) + def test_infer_layout_for_vector_extract(self): + layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT) + with ir.InsertionPoint(self.module.body): + i16 = ir.IntegerType.get_signless(16) + src_ty = ir.VectorType.get([2, 3, 64, 8], i16) + [src] = undefs(src_ty) + src = mgpu.dialect.layout_cast(src, layout) + op = vector.ExtractOp(src, dynamic_position=[], static_position=[1, 1]) + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [layout]) + self.checkOutLayouts(op, [layout]) + + def test_infer_layout_for_vector_extract_fails_if_not_dividing_result_shape(self): + layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT) + with ir.InsertionPoint(self.module.body): + i16 = ir.IntegerType.get_signless(16) + src_ty = ir.VectorType.get([64, 64], i16) + [src] = undefs(src_ty) + src = mgpu.dialect.layout_cast(src, layout) + vector.extract(src, dynamic_position=[], static_position=[0]) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts." + ): + mgpu.infer_layout(self.module) + def test_infer_tmem_layout_for_slice_tmem_op(self): # in and out layouts can be different. in_layout = layouts.to_layout_attr(tcgen05.tmem_default_layout(packing=1)) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 5b65671f91c7..6628fd73be6b 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -108,6 +108,22 @@ def fn(_): return fn() +def _array_splat(value, shape: tuple[int, ...]): + """Same as `jnp.full(shape, value, jnp.float32)` but implemented using `inline_mgpu`. + + This is useful to prevent the result from being optimized away. + """ + @plgpu.inline_mgpu( + return_type=plgpu.ShapeDtypeStruct( + shape, jnp.float32, layout=plgpu.Layout.WG_SPLAT(shape) + ), + ) + def fn(_): + ir_value = mgpu.c(value, ir.F32Type.get()) + return mgpu.FragmentedArray.splat(ir_value, shape) + return fn() + + class PallasTestMetaclass(parameterized.TestGeneratorMetaclass): def __new__(mcs, *args, lowering_semantics=plgpu.LoweringSemantics.Lane): @@ -352,7 +368,6 @@ def kernel(out_ref): ) def test_slice_untiled_dim(self): - self.skip_if_wg_semantics() shape = (2, 3, 64, 8) @functools.partial( @@ -360,12 +375,24 @@ def test_slice_untiled_dim(self): out_shape=jax.ShapeDtypeStruct(shape[2:], jnp.float32), ) def kernel(x_ref, out_ref): - y = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False)[1, 1] - out_ref[...] = y + x = plgpu.load(x_ref, (), layout=plgpu.Layout.WGMMA, optimized=False) + out_ref[...] = x[1, 1] x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32) np.testing.assert_array_equal(kernel(x), x[1, 1]) + def test_squeeze_to_scalar(self): + self.skip_if_wg_semantics() # Scalar element extraction is not supported for `vector.extract`. + @functools.partial( + self.kernel, + out_shape=jax.ShapeDtypeStruct((), jnp.float32), + ) + def kernel(out_ref): + x = _array_splat(42, (1, 1, 1)) + out_ref[...] = lax.squeeze(x, dimensions=(0, 1, 2)) + + np.testing.assert_array_equal(kernel(), jnp.array(42, dtype=jnp.float32)) + def test_add_xy_indexed(self): @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct([128], jnp.float32) @@ -2863,7 +2890,6 @@ def test_missing_primitive_lowerings_are_tracked(self): pallas_primitives.semaphore_read_p, pallas_primitives.delay_p, checkify.check_p, - lax.squeeze_p, } self.assertSetEqual(actual_missing_primitives, expected_missing_primitives) From b4447d96759686bc9e4bdfe164a132b09ab71fb8 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Fri, 19 Dec 2025 03:35:20 -0800 Subject: [PATCH 285/315] [Mosaic GPU] Add support for indexing into splat array. We allows scalar results when lowering `vector.extract` for inputs with splat layout. We extend `FragmentedArray.__getitem__` to support arbitrary indexing for splat layouts. PiperOrigin-RevId: 846661907 --- jax/_src/pallas/mosaic_gpu/lowering.py | 11 +++++----- .../mosaic/gpu/dialect_lowering.py | 10 +++++++-- .../mosaic/gpu/fragmented_array.py | 5 ++++- .../mosaic/gpu/layout_inference.py | 10 +++++++-- tests/mosaic/gpu_layout_inference_test.py | 10 +++++++++ tests/mosaic/gpu_test.py | 22 +++++++++++++++++++ tests/pallas/mosaic_gpu_test.py | 1 - 7 files changed, 58 insertions(+), 11 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index edfe33cb8134..57834b6b7f63 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2447,11 +2447,12 @@ def _squeeze_lowering_rule_wg(ctx: LoweringRuleContext, x, dimensions): [y_aval] = ctx.avals_out x = _ensure_ir_value(x, x_aval.dtype) if y_aval.ndim == 0: # scalar - # TODO(allanrenucci): Lower to `vector.extract` once we support scalar - # results in MGPU dialect lowering. - raise NotImplementedError("Squeeze to scalar is not supported.") - res_ty = ir.VectorType.get(y_aval.shape, ir.VectorType(x.type).element_type) - return vector_dialect.shape_cast(res_ty, x) + return vector_dialect.extract( + x, dynamic_position=[], static_position=[0] * x_aval.ndim + ) + else: + res_ty = ir.VectorType.get(y_aval.shape, ir.VectorType(x.type).element_type) + return vector_dialect.shape_cast(res_ty, x) def _reduce_lowering_rule(op, ctx: LoweringRuleContext, x, *, axes, **kwargs): diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index fc3c10d79150..1d4f46c1cf9f 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -664,11 +664,17 @@ def _vector_extract_op_lowering_rule( ctx: LoweringContext, op: vector.ExtractOp ) -> Sequence[ir.Value]: del ctx - if not ir.VectorType.isinstance(op.result.type): - raise NotImplementedError("Scalar element extraction is not supported.") if op.dynamic_position: raise NotImplementedError("Only slicing with static indices allowed.") + [in_layout] = inference_utils.in_layouts(op) + a = _fragmented_array_from_ir(op.source, in_layout) + + if not ir.VectorType.isinstance(op.result.type): # scalar result + result = a[tuple(op.static_position)] + assert isinstance(result.layout, fa.WGSplatFragLayout) + return [result.registers.item()] + [out_layout] = inference_utils.out_layouts(op) assert in_layout == out_layout a = _fragmented_array_from_ir(op.source, in_layout) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 795a3c3062c4..0093ed01caf0 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1725,9 +1725,12 @@ def bitcast( ) def __getitem__(self, idx) -> FragmentedArray: + base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) + if isinstance(self.layout, WGSplatFragLayout): + shape = tuple(d for d, s in zip(slice_shape, is_squeezed) if not s) + return self.splat(self.registers.item(), shape, is_signed=self.is_signed) if not isinstance(self.layout, TiledLayout): raise NotImplementedError("Only arrays with tiled layouts can be sliced") - base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape) if any(isinstance(idx, ir.Value) for idx in base_idx): raise ValueError("Only slicing with static indices allowed") if any(is_squeezed): diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 280200d5ca3d..cca47f1775aa 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -1108,8 +1108,14 @@ def _vector_extract_constraint_system( ctx: DerivationContext, op: vector.ExtractOp ) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: del ctx - if not ir.VectorType.isinstance(op.result.type): - raise NotImplementedError("Scalar element extraction is not supported.") + if not ir.VectorType.isinstance(op.result.type): # scalar result + operand = ValueSite(op, VariableType.OPERAND, 0) + variable = cs.Variable(operand) + layout = fa.WGSplatFragLayout(tuple(op.source.type.shape)) + # We only support indexing for splat layout. + assignments = {variable: cs.RegisterLayout(layout)} + return cs.ConstraintSystem(assignments), {variable: [operand]} + if op.dynamic_position: raise NotImplementedError("Only slicing with static indices allowed.") operand = ValueSite(op, VariableType.OPERAND, 0) diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index ed7ae72c1ced..facacb604b51 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -1991,6 +1991,16 @@ def test_infer_layout_for_vector_extract(self): self.checkInLayouts(op, [layout]) self.checkOutLayouts(op, [layout]) + def test_infer_layout_for_vector_extract_to_scalar(self): + with ir.InsertionPoint(self.module.body): + i16 = ir.IntegerType.get_signless(16) + src_ty = ir.VectorType.get([64, 8], i16) + [src] = undefs(src_ty) + op = vector.ExtractOp(src, dynamic_position=[], static_position=[1, 1]) + mgpu.infer_layout(self.module) + self.checkInLayouts(op, [mgpu.WGSplatFragLayout(tuple(src_ty.shape))]) + self.assertNotIn("out_layouts", op.attributes) + def test_infer_layout_for_vector_extract_fails_if_not_dividing_result_shape(self): layout = layouts.to_layout_attr(fa.WGMMA_LAYOUT) with ir.InsertionPoint(self.module.body): diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index c9c919daedf6..a81e6b29d043 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -4149,6 +4149,28 @@ def kernel(ctx, src, dst, scratch): )(x) np.testing.assert_array_equal(y, x) + @parameterized.parameters( + ((32, 32), (0, 5)), + ((32, 128), (3,)), + ((32, 32, 128), (slice(1, 3), 0)), + ) + def test_splat_indexing(self, shape, indices): + def _kernel(ctx, out_ref, scratch): + del ctx, scratch + splat = mgpu.FragmentedArray.splat(c(1.0, ir.F32Type.get()), shape) + splat[indices].store_untiled(out_ref) + + expected = np.ones(shape, dtype=jnp.float32)[indices] + kernel = mgpu.as_gpu_kernel( + _kernel, + grid=(1, 1, 1), + block=(128, 1, 1), + in_shape=(), + out_shape=expected, + smem_scratch_shape=(), + ) + np.testing.assert_array_equal(kernel(), expected) + class ProfilerTest(TestCase, jtu.JaxTestCase): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 6628fd73be6b..7b5d844e7015 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -382,7 +382,6 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(kernel(x), x[1, 1]) def test_squeeze_to_scalar(self): - self.skip_if_wg_semantics() # Scalar element extraction is not supported for `vector.extract`. @functools.partial( self.kernel, out_shape=jax.ShapeDtypeStruct((), jnp.float32), From 8c0c43c02728dfadec12b5aede467fde9bdb6239 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Melissa=20Weber=20Mendon=C3=A7a?= Date: Wed, 10 Dec 2025 18:23:57 -0300 Subject: [PATCH 286/315] Complete reorganization of tutorials and guides * Split Advanced AD into sections * Add debugging section to quickstart * Move Is JAX faster than NumPy question to H2 * Move discussion on jitting class methods to Sharp Bits * Fix section label * Fix ToC --- docs/advanced-autodiff.md | 1777 ----------------- docs/advanced_autodiff.md | 11 + docs/advanced_guides.rst | 4 +- docs/automatic-differentiation.md | 4 +- docs/complex-differentiation.md | 207 ++ docs/debugging.md | 8 - docs/debugging/flags.md | 24 +- docs/debugging/index.md | 3 +- docs/faq.rst | 189 +- docs/gradient-checkpointing.md | 4 +- docs/higher-order.md | 336 ++++ docs/jacobian-vector-products.md | 358 ++++ docs/jax-primitives.md | 2 +- docs/notebooks/Common_Gotchas_in_JAX.ipynb | 380 +++- docs/notebooks/Common_Gotchas_in_JAX.md | 187 +- ...tom_derivative_rules_for_Python_code.ipynb | 50 +- ...Custom_derivative_rules_for_Python_code.md | 45 +- docs/notebooks/autodiff_cookbook.ipynb | 2 +- docs/notebooks/autodiff_cookbook.md | 2 +- docs/notebooks/thinking_in_jax.ipynb | 68 + docs/notebooks/thinking_in_jax.md | 43 + 21 files changed, 1656 insertions(+), 2048 deletions(-) delete mode 100644 docs/advanced-autodiff.md create mode 100644 docs/advanced_autodiff.md create mode 100644 docs/complex-differentiation.md create mode 100644 docs/higher-order.md create mode 100644 docs/jacobian-vector-products.md diff --git a/docs/advanced-autodiff.md b/docs/advanced-autodiff.md deleted file mode 100644 index f8b5000c2b47..000000000000 --- a/docs/advanced-autodiff.md +++ /dev/null @@ -1,1777 +0,0 @@ ---- -jupytext: - formats: md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.4 -kernelspec: - display_name: Python 3 - language: python - name: python3 ---- - -(advanced-autodiff)= -# Advanced automatic differentiation - - - -In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful. - -Make sure to check out the {ref}`automatic-differentiation` tutorial to go over the JAX autodiff basics, if you haven't already. - -## Setup - -```{code-cell} -import jax -import jax.numpy as jnp -from jax import grad, jit, vmap -from jax import random - -key = random.key(0) -``` - -## Taking gradients (part 2) - -### Higher-order derivatives - -JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. - -The single-variable case was covered in the {ref}`automatic-differentiation` tutorial, where the example showed how to use {func}`jax.grad` to compute the derivative of $f(x) = x^3 + 2x^2 - 3x + 1$. - -In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to: - -$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$ - -The Hessian of a real-valued function of several variables, $f: \mathbb R^n\to\mathbb R$, can be identified with the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of its gradient. - -JAX provides two transformations for computing the Jacobian of a function, {func}`jax.jacfwd` and {func}`jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – refer to the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY). - -```{code-cell} -def hessian(f): - return jax.jacfwd(jax.grad(f)) -``` - -Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}$. - -if $i=j$, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2$. Otherwise, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0$. - -```{code-cell} -def f(x): - return jnp.dot(x, x) - -hessian(f)(jnp.array([1., 2., 3.])) -``` - -## Higher-order optimization - -Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier: - -```python -def meta_loss_fn(params, data): - """Computes the loss after one step of SGD.""" - grads = jax.grad(loss_fn)(params, data) - return loss_fn(params - lr * grads, data) - -meta_grads = jax.grad(meta_loss_fn)(params, data) -``` - -(stopping-gradients)= -### Stopping gradients - -Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph. - -Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1}$) in a state $s_{t-1}$ is parameterised by a linear function. - -```{code-cell} -# Value function and initial parameters -value_fn = lambda theta, state: jnp.dot(theta, state) -theta = jnp.array([0.1, -0.1, 0.]) -``` - -Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which you observed the reward $r_t$ - -```{code-cell} -# An example transition. -s_tm1 = jnp.array([1., 2., -1.]) -r_t = jnp.array(1.) -s_t = jnp.array([2., 1., 0.]) -``` - -The TD(0) update to the network parameters is: - -$$ -\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1}) -$$ - -This update is not the gradient of any loss function. - -However, it can be **written** as the gradient of the pseudo loss function - -$$ -L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2 -$$ - -if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored. - -How can you implement this in JAX? If you write the pseudo loss naively, you get: - -```{code-cell} -def td_loss(theta, s_tm1, r_t, s_t): - v_tm1 = value_fn(theta, s_tm1) - target = r_t + value_fn(theta, s_t) - return -0.5 * ((target - v_tm1) ** 2) - -td_update = jax.grad(td_loss) -delta_theta = td_update(theta, s_tm1, r_t, s_t) - -delta_theta -``` - -But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$. - -You can use {func}`jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$: - -```{code-cell} -def td_loss(theta, s_tm1, r_t, s_t): - v_tm1 = value_fn(theta, s_tm1) - target = r_t + value_fn(theta, s_t) - return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2) - -td_update = jax.grad(td_loss) -delta_theta = td_update(theta, s_tm1, r_t, s_t) - -delta_theta -``` - -This will treat `target` as if it did **not** depend on the parameters $\theta$ and compute the correct update to the parameters. - -Now, let's also calculate $\Delta \theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using {func}`jax.grad` and your knowledge so far. Here's our solution: - -```{code-cell} -s_grad = jax.grad(value_fn)(theta, s_tm1) -delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad - -delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta` -``` - -`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss). - - -### Straight-through estimator using `stop_gradient` - -The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`: - -```{code-cell} -def f(x): - return jnp.round(x) # non-differentiable - -def straight_through_f(x): - # Create an exactly-zero expression with Sterbenz lemma that has - # an exactly-one gradient. - zero = x - jax.lax.stop_gradient(x) - return zero + jax.lax.stop_gradient(f(x)) - -print("f(x): ", f(3.2)) -print("straight_through_f(x):", straight_through_f(3.2)) - -print("grad(f)(x):", jax.grad(f)(3.2)) -print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2)) -``` - -### Per-example gradients - -While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch. - -For instance, this is needed to prioritize data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis. - -In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients are typically very inefficient. - -In JAX, you can define the code to compute the gradient per-sample in an easy but efficient way. - -Just combine the {func}`jax.jit`, {func}`jax.vmap` and {func}`jax.grad` transformations together: - -```{code-cell} -perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0))) - -# Test it: -batched_s_tm1 = jnp.stack([s_tm1, s_tm1]) -batched_r_t = jnp.stack([r_t, r_t]) -batched_s_t = jnp.stack([s_t, s_t]) - -perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -Let's go through this one transformation at a time. - -First, you apply {func}`jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs: - -```{code-cell} -dtdloss_dtheta = jax.grad(td_loss) - -dtdloss_dtheta(theta, s_tm1, r_t, s_t) -``` - -This function computes one row of the array above. - -Then, you vectorise this function using {func}`jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, you produce a batch of outputs — each output in the batch corresponds to the gradient for the corresponding member of the input batch. - -```{code-cell} -almost_perex_grads = jax.vmap(dtdloss_dtheta) - -batched_theta = jnp.stack([theta, theta]) -almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the {func}`jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want: - -```{code-cell} -inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0)) - -inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -This does what we want, but is slower than it has to be. Now, you wrap the whole thing in a {func}`jax.jit` to get the compiled, efficient version of the same function: - -```{code-cell} -perex_grads = jax.jit(inefficient_perex_grads) - -perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) -``` - -```{code-cell} -%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() -%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() -``` - -### Hessian-vector products with `jax.grad`-of-`jax.grad` - -One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.) - -A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)). - -For a scalar-valued function $f : \mathbb{R}^n \to \mathbb{R}$ with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point $x \in \mathbb{R}^n$ is written as $\partial^2 f(x)$. A Hessian-vector product function is then able to evaluate - -$\qquad v \mapsto \partial^2 f(x) \cdot v$ - -for any $v \in \mathbb{R}^n$. - -The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store. - -Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity: - -$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$, - -where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient. - -In JAX code, you can just write this: - -```{code-cell} -def hvp(f, x, v): - return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) -``` - -This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused. - -You will check this implementation a few cells down, once you learn how to compute dense Hessian matrices. You'll also write an even better version that uses both forward-mode and reverse-mode. - - -### Jacobians and Hessians using `jax.jacfwd` and `jax.jacrev` - -You can compute full Jacobian matrices using the {func}`jax.jacfwd` and {func}`jax.jacrev` functions: - -```{code-cell} -from jax import jacfwd, jacrev - -# Define a sigmoid function. -def sigmoid(x): - return 0.5 * (jnp.tanh(x / 2) + 1) - -# Outputs probability of a label being true. -def predict(W, b, inputs): - return sigmoid(jnp.dot(inputs, W) + b) - -# Build a toy dataset. -inputs = jnp.array([[0.52, 1.12, 0.77], - [0.88, -1.08, 0.15], - [0.52, 0.06, -1.30], - [0.74, -2.49, 1.39]]) - -# Initialize random model coefficients -key, W_key, b_key = random.split(key, 3) -W = random.normal(W_key, (3,)) -b = random.normal(b_key, ()) - -# Isolate the function from the weight matrix to the predictions -f = lambda W: predict(W, b, inputs) - -J = jacfwd(f)(W) -print("jacfwd result, with shape", J.shape) -print(J) - -J = jacrev(f)(W) -print("jacrev result, with shape", J.shape) -print(J) -``` - -These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`. - -You can also use {func}`jax.jacfwd` and {func}`jax.jacrev` with container types: - -```{code-cell} -def predict_dict(params, inputs): - return predict(params['W'], params['b'], inputs) - -J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs) -for k, v in J_dict.items(): - print("Jacobian from {} to logits is".format(k)) - print(v) -``` - -For more details on forward- and reverse-mode, as well as how to implement {func}`jax.jacfwd` and {func}`jax.jacrev` as efficiently as possible, read on! - -Using a composition of two of these functions gives us a way to compute dense Hessian matrices: - -```{code-cell} -def hessian(f): - return jacfwd(jacrev(f)) - -H = hessian(f)(W) -print("hessian, with shape", H.shape) -print(H) -``` - -This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ you expect to get the shapes: - -* $f(x) \in \mathbb{R}^m$, the value of $f$ at $x$, -* $\partial f(x) \in \mathbb{R}^{m \times n}$, the Jacobian matrix at $x$, -* $\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$, the Hessian at $x$, - -and so on. - -To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. - - -## How it's made: Two foundational autodiff functions - -### Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff) - -JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background. - - -#### JVPs in math - -Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian of $f$ evaluated at an input point $x \in \mathbb{R}^n$, denoted $\partial f(x)$, is often thought of as a matrix in $\mathbb{R}^m \times \mathbb{R}^n$: - -$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$. - -But you can also think of $\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\mathbb{R}^m$): - -$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. - -This map is called the [pushforward map](https://en.wikipedia.org/wiki/Pushforward_(differential)) of $f$ at $x$. The Jacobian matrix is just the matrix for this linear map on a standard basis. - -If you don't commit to one specific input point $x$, then you can think of the function $\partial f$ as first taking an input point and returning the Jacobian linear map at that input point: - -$\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$. - -In particular, you can uncurry things so that given input point $x \in \mathbb{R}^n$ and a tangent vector $v \in \mathbb{R}^n$, you get back an output tangent vector in $\mathbb{R}^m$. We call that mapping, from $(x, v)$ pairs to output tangent vectors, the *Jacobian-vector product*, and write it as: - -$\qquad (x, v) \mapsto \partial f(x) v$ - - -#### JVPs in JAX code - -Back in Python code, JAX's {func}`jax.jvp` function models this transformation. Given a Python function that evaluates $f$, JAX's {func}`jax.jvp` is a way to get a Python function for evaluating $(x, v) \mapsto (f(x), \partial f(x) v)$. - -```{code-cell} -from jax import jvp - -# Isolate the function from the weight matrix to the predictions -f = lambda W: predict(W, b, inputs) - -key, subkey = random.split(key) -v = random.normal(subkey, W.shape) - -# Push forward the vector `v` along `f` evaluated at `W` -y, u = jvp(f, (W,), (v,)) -``` - -In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), you could write: - -```haskell -jvp :: (a -> b) -> a -> T a -> (b, T b) -``` - -where `T a` is used to denote the type of the tangent space for `a`. - -In other words, `jvp` takes as arguments a function of type `a -> b`, a value of type `a`, and a tangent vector value of type `T a`. It gives back a pair consisting of a value of type `b` and an output tangent vector of type `T b`. - -The `jvp`-transformed function is evaluated much like the original function, but paired up with each primal value of type `a` it pushes along tangent values of type `T a`. For each primitive numerical operation that the original function would have applied, the `jvp`-transformed function executes a "JVP rule" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values. - -That evaluation strategy has some immediate implications about computational complexity. Since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example `sin(x)`; one unit for linearizing, like `cos(x)`; and one unit for applying the linearized function to a vector, like `cos_x * v`). Put another way, for a fixed primal point $x$, we can evaluate $v \mapsto \partial f(x) \cdot v$ for about the same marginal cost as evaluating $f$. - -That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning? - -To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians. - -If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale. - -To do better for functions like this, you just need to use reverse-mode. - - -### Vector-Jacobian products (VJPs, a.k.a. reverse-mode autodiff) - -Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time. - - -#### VJPs in math - -Let's again consider a function $f : \mathbb{R}^n \to \mathbb{R}^m$. -Starting from our notation for JVPs, the notation for VJPs is pretty simple: - -$\qquad (x, v) \mapsto v \partial f(x)$, - -where $v$ is an element of the cotangent space of $f$ at $x$ (isomorphic to another copy of $\mathbb{R}^m$). When being rigorous, we should think of $v$ as a linear map $v : \mathbb{R}^m \to \mathbb{R}$, and when we write $v \partial f(x)$ we mean function composition $v \circ \partial f(x)$, where the types work out because $\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. But in the common case we can identify $v$ with a vector in $\mathbb{R}^m$ and use the two almost interchangeably, just like we might sometimes flip between "column vectors" and "row vectors" without much comment. - -With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP: - -$\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v$. - -For a given point $x$, we can write the signature as - -$\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n$. - -The corresponding map on cotangent spaces is often called the [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)) -of $f$ at $x$. The key for our purposes is that it goes from something that looks like the output of $f$ to something that looks like the input of $f$, just like we might expect from a transposed linear function. - -#### VJPs in JAX code - -Switching from math back to Python, the JAX function `vjp` can take a Python function for evaluating $f$ and give us back a Python function for evaluating the VJP $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$. - -```{code-cell} -from jax import vjp - -# Isolate the function from the weight matrix to the predictions -f = lambda W: predict(W, b, inputs) - -y, vjp_fun = vjp(f, W) - -key, subkey = random.split(key) -u = random.normal(subkey, y.shape) - -# Pull back the covector `u` along `f` evaluated at `W` -v = vjp_fun(u) -``` - -In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), we could write - -```haskell -vjp :: (a -> b) -> a -> (b, CT b -> CT a) -``` - -where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`. - -This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. - -There's a cost, though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). - -For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/). - - -### Vector-valued gradients with VJPs - -If you're interested in taking vector-valued gradients (like `tf.gradients`): - -```{code-cell} -def vgrad(f, x): - y, vjp_fn = vjp(f, x) - return vjp_fn(jnp.ones(y.shape))[0] - -print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2)))) -``` - -### Hessian-vector products using both forward- and reverse-mode - -In a previous section, you implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives): - -```{code-cell} -def hvp(f, x, v): - return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) -``` - -That's efficient, but you can do even better and save some memory by using forward-mode together with reverse-mode. - -Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}$ to differentiate, a point $x \in \mathbb{R}^n$ at which to linearize the function, and a vector $v \in \mathbb{R}^n$, the Hessian-vector product function we want is: - -$(x, v) \mapsto \partial^2 f(x) v$ - -Consider the helper function $g : \mathbb{R}^n \to \mathbb{R}^n$ defined to be the derivative (or gradient) of $f$, namely $g(x) = \partial f(x)$. All you need is its JVP, since that will give us: - -$(x, v) \mapsto \partial g(x) v = \partial^2 f(x) v$. - -We can translate that almost directly into code: - -```{code-cell} -# forward-over-reverse -def hvp(f, primals, tangents): - return jvp(grad(f), primals, tangents)[1] -``` - -Even better, since you didn't have to call {func}`jnp.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on {mod}`jax.numpy`. - -Here's an example of how to use it: - -```{code-cell} -def f(X): - return jnp.sum(jnp.tanh(X)**2) - -key, subkey1, subkey2 = random.split(key, 3) -X = random.normal(subkey1, (30, 40)) -V = random.normal(subkey2, (30, 40)) - -ans1 = hvp(f, (X,), (V,)) -ans2 = jnp.tensordot(hessian(f)(X), V, 2) - -print(jnp.allclose(ans1, ans2, 1e-4, 1e-4)) -``` - -Another way you might consider writing this is using reverse-over-forward: - -```{code-cell} -# Reverse-over-forward -def hvp_revfwd(f, primals, tangents): - g = lambda primals: jvp(f, primals, tangents)[1] - return grad(g)(primals) -``` - -That's not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best: - -```{code-cell} -# Reverse-over-reverse, only works for single arguments -def hvp_revrev(f, primals, tangents): - x, = primals - v, = tangents - return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) - - -print("Forward over reverse") -%timeit -n10 -r3 hvp(f, (X,), (V,)) -print("Reverse over forward") -%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,)) -print("Reverse over reverse") -%timeit -n10 -r3 hvp_revrev(f, (X,), (V,)) - -print("Naive full Hessian materialization") -%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2) -``` - -## Composing VJPs, JVPs, and `jax.vmap` - -### Jacobian-Matrix and Matrix-Jacobian products - -Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: - -```{code-cell} -# Isolate the function from the weight matrix to the predictions -f = lambda W: predict(W, b, inputs) - -# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`. -# First, use a list comprehension to loop over rows in the matrix M. -def loop_mjp(f, x, M): - y, vjp_fun = vjp(f, x) - return jnp.vstack([vjp_fun(mi) for mi in M]) - -# Now, use vmap to build a computation that does a single fast matrix-matrix -# multiply, rather than an outer loop over vector-matrix multiplies. -def vmap_mjp(f, x, M): - y, vjp_fun = vjp(f, x) - outs, = vmap(vjp_fun)(M) - return outs - -key = random.key(0) -num_covecs = 128 -U = random.normal(key, (num_covecs,) + y.shape) - -loop_vs = loop_mjp(f, W, M=U) -print('Non-vmapped Matrix-Jacobian product') -%timeit -n10 -r3 loop_mjp(f, W, M=U) - -print('\nVmapped Matrix-Jacobian product') -vmap_vs = vmap_mjp(f, W, M=U) -%timeit -n10 -r3 vmap_mjp(f, W, M=U) - -assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical' -``` - -```{code-cell} -def loop_jmp(f, W, M): - # jvp immediately returns the primal and tangent values as a tuple, - # so we'll compute and select the tangents in a list comprehension - return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M]) - -def vmap_jmp(f, W, M): - _jvp = lambda s: jvp(f, (W,), (s,))[1] - return vmap(_jvp)(M) - -num_vecs = 128 -S = random.normal(key, (num_vecs,) + W.shape) - -loop_vs = loop_jmp(f, W, M=S) -print('Non-vmapped Jacobian-Matrix product') -%timeit -n10 -r3 loop_jmp(f, W, M=S) -vmap_vs = vmap_jmp(f, W, M=S) -print('\nVmapped Jacobian-Matrix product') -%timeit -n10 -r3 vmap_jmp(f, W, M=S) - -assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical' -``` - -### The implementation of `jax.jacfwd` and `jax.jacrev` - -Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write {func}`jax.jacfwd` and {func}`jax.jacrev`. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once. - -```{code-cell} -from jax import jacrev as builtin_jacrev - -def our_jacrev(f): - def jacfun(x): - y, vjp_fun = vjp(f, x) - # Use vmap to do a matrix-Jacobian product. - # Here, the matrix is the Euclidean basis, so we get all - # entries in the Jacobian at once. - J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) - return J - return jacfun - -assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!' -``` - -```{code-cell} -from jax import jacfwd as builtin_jacfwd - -def our_jacfwd(f): - def jacfun(x): - _jvp = lambda s: jvp(f, (x,), (s,))[1] - Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x))) - return jnp.transpose(Jt) - return jacfun - -assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!' -``` - -Interestingly, the [Autograd](https://github.com/hips/autograd) library couldn't do this. The [implementation](https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) of reverse-mode `jacobian` in Autograd had to pull back one vector at a time with an outer-loop `map`. Pushing one vector at a time through the computation is much less efficient than batching it all together with {func}`jax.vmap`. - -Another thing that Autograd couldn't do is {func}`jax.jit`. Interestingly, no matter how much Python dynamism you use in your function to be differentiated, we could always use {func}`jax.jit` on the linear part of the computation. For example: - -```{code-cell} -def f(x): - try: - if x < 3: - return 2 * x ** 3 - else: - raise ValueError - except ValueError: - return jnp.pi * x - -y, f_vjp = vjp(f, 4.) -print(jit(f_vjp)(1.)) -``` - -## Complex numbers and differentiation - -JAX is great at complex numbers and differentiation. To support both [holomorphic and non-holomorphic differentiation](https://en.wikipedia.org/wiki/Holomorphic_function), it helps to think in terms of JVPs and VJPs. - -Consider a complex-to-complex function $f: \mathbb{C} \to \mathbb{C}$ and identify it with a corresponding function $g: \mathbb{R}^2 \to \mathbb{R}^2$, - -```{code-cell} -def f(z): - x, y = jnp.real(z), jnp.imag(z) - return u(x, y) + v(x, y) * 1j - -def g(x, y): - return (u(x, y), v(x, y)) -``` - -That is, we've decomposed $f(z) = u(x, y) + v(x, y) i$ where $z = x + y i$, and identified $\mathbb{C}$ with $\mathbb{R}^2$ to get $g$. - -Since $g$ only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector $(c, d) \in \mathbb{R}^2$, namely: - -$\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} -\begin{bmatrix} c \\ d \end{bmatrix}$. - -To get a JVP for the original function $f$ applied to a tangent vector $c + di \in \mathbb{C}$, we just use the same definition and identify the result as another complex number, - -$\partial f(x + y i)(c + d i) = -\begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} -\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} -\begin{bmatrix} c \\ d \end{bmatrix}$. - -That's our definition of the JVP of a $\mathbb{C} \to \mathbb{C}$ function! Notice it doesn't matter whether or not $f$ is holomorphic: the JVP is unambiguous. - -Here's a check: - -```{code-cell} -def check(seed): - key = random.key(seed) - - # random coeffs for u and v - key, subkey = random.split(key) - a, b, c, d = random.uniform(subkey, (4,)) - - def fun(z): - x, y = jnp.real(z), jnp.imag(z) - return u(x, y) + v(x, y) * 1j - - def u(x, y): - return a * x + b * y - - def v(x, y): - return c * x + d * y - - # primal point - key, subkey = random.split(key) - x, y = random.uniform(subkey, (2,)) - z = x + y * 1j - - # tangent vector - key, subkey = random.split(key) - c, d = random.uniform(subkey, (2,)) - z_dot = c + d * 1j - - # check jvp - _, ans = jvp(fun, (z,), (z_dot,)) - expected = (grad(u, 0)(x, y) * c + - grad(u, 1)(x, y) * d + - grad(v, 0)(x, y) * c * 1j+ - grad(v, 1)(x, y) * d * 1j) - print(jnp.allclose(ans, expected)) -``` - -```{code-cell} -check(0) -check(1) -check(2) -``` - -What about VJPs? We do something pretty similar: for a cotangent vector $c + di \in \mathbb{C}$ we define the VJP of $f$ as - -$(c + di)^* \; \partial f(x + y i) = -\begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} -\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} -\begin{bmatrix} 1 \\ -i \end{bmatrix}$. - -What's with the negatives? They're just to take care of complex conjugation, and the fact that we're working with covectors. - -Here's a check of the VJP rules: - -```{code-cell} -def check(seed): - key = random.key(seed) - - # random coeffs for u and v - key, subkey = random.split(key) - a, b, c, d = random.uniform(subkey, (4,)) - - def fun(z): - x, y = jnp.real(z), jnp.imag(z) - return u(x, y) + v(x, y) * 1j - - def u(x, y): - return a * x + b * y - - def v(x, y): - return c * x + d * y - - # primal point - key, subkey = random.split(key) - x, y = random.uniform(subkey, (2,)) - z = x + y * 1j - - # cotangent vector - key, subkey = random.split(key) - c, d = random.uniform(subkey, (2,)) - z_bar = jnp.array(c + d * 1j) # for dtype control - - # check vjp - _, fun_vjp = vjp(fun, z) - ans, = fun_vjp(z_bar) - expected = (grad(u, 0)(x, y) * c + - grad(v, 0)(x, y) * (-d) + - grad(u, 1)(x, y) * c * (-1j) + - grad(v, 1)(x, y) * (-d) * (-1j)) - assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5) -``` - -```{code-cell} -check(0) -check(1) -check(2) -``` - -What about convenience wrappers like {func}`jax.grad`, {func}`jax.jacfwd`, and {func}`jax.jacrev`? - -For $\mathbb{R} \to \mathbb{R}$ functions, recall we defined `grad(f)(x)` as being `vjp(f, x)[1](1.0)`, which works because applying a VJP to a `1.0` value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for $\mathbb{C} \to \mathbb{R}$ functions: we can still use `1.0` as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian: - -```{code-cell} -def f(z): - x, y = jnp.real(z), jnp.imag(z) - return x**2 + y**2 - -z = 3. + 4j -grad(f)(z) -``` - -For general $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`. - -Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when {func}`jax.grad` is used for a complex-output function: - -```{code-cell} -def f(z): - return jnp.sin(z) - -z = 3. + 4j -grad(f, holomorphic=True)(z) -``` - -All the `holomorphic=True` promise does is disable the error when the output is complex-valued. We can still write `holomorphic=True` when the function isn't holomorphic, but the answer we get out won't represent the full Jacobian. Instead, it'll be the Jacobian of the function where we just discard the imaginary part of the output: - -```{code-cell} -def f(z): - return jnp.conjugate(z) - -z = 3. + 4j -grad(f, holomorphic=True)(z) # f is not actually holomorphic! -``` - -There are some useful upshots for how {func}`jax.grad` works here: - -1. We can use {func}`jax.grad` on holomorphic $\mathbb{C} \to \mathbb{C}$ functions. -2. We can use {func}`jax.grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)`. -3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then {func}`jax.grad` still works and we get the same result that an implementation using only real values would have given. - -In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\mathbb{C} \to \mathbb{C}$ function, we can do it with JVPs or VJPs! - - -You should expect complex numbers to work everywhere in JAX. Here's differentiating through a Cholesky decomposition of a complex matrix: - -```{code-cell} -A = jnp.array([[5., 2.+3j, 5j], - [2.-3j, 7., 1.+7j], - [-5j, 1.-7j, 12.]]) - -def f(X): - L = jnp.linalg.cholesky(X) - return jnp.sum((L - jnp.sin(L))**2) - -grad(f, holomorphic=True)(A) -``` - -(advanced-autodiff-custom-derivative-rules)= -## Custom derivative rules for JAX-transformable Python functions - -There are two ways to define differentiation rules in JAX: - -1. Using {func}`jax.custom_jvp` and {func}`jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and -2. Defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. - -This notebook is about #1. To read instead about #2, refer to the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). - - -### TL;DR: Custom JVPs with {func}`jax.custom_jvp` - -```{code-cell} -from jax import custom_jvp - -@custom_jvp -def f(x, y): - return jnp.sin(x) * y - -@f.defjvp -def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - primal_out = f(x, y) - tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot - return primal_out, tangent_out -``` - -```{code-cell} -print(f(2., 3.)) -y, y_dot = jvp(f, (2., 3.), (1., 0.)) -print(y) -print(y_dot) -print(grad(f)(2., 3.)) -``` - -```{code-cell} -# Equivalent alternative using the `defjvps` convenience wrapper - -@custom_jvp -def f(x, y): - return jnp.sin(x) * y - -f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y, - lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) -``` - -```{code-cell} -print(f(2., 3.)) -y, y_dot = jvp(f, (2., 3.), (1., 0.)) -print(y) -print(y_dot) -print(grad(f)(2., 3.)) -``` - -### TL;DR: Custom VJPs with `jax.custom_vjp` - -```{code-cell} -from jax import custom_vjp - -@custom_vjp -def f(x, y): - return jnp.sin(x) * y - -def f_fwd(x, y): -# Returns primal output and residuals to be used in backward pass by `f_bwd`. - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - -def f_bwd(res, g): - cos_x, sin_x, y = res # Gets residuals computed in `f_fwd` - return (cos_x * g * y, sin_x * g) - -f.defvjp(f_fwd, f_bwd) -``` - -```{code-cell} -print(grad(f)(2., 3.)) -``` - -### Example problems - -To get an idea of what problems {func}`jax.custom_jvp` and {func}`jax.custom_vjp` are meant to solve, let's go over a few examples. A more thorough introduction to the {func}`jax.custom_jvp` and {func}`jax.custom_vjp` APIs is in the next section. - - -#### Example: Numerical stability - -One application of {func}`jax.custom_jvp` is to improve the numerical stability of differentiation. - -Say we want to write a function called `log1pexp`, which computes $x \mapsto \log ( 1 + e^x )$. We can write that using `jax.numpy`: - -```{code-cell} -def log1pexp(x): - return jnp.log(1. + jnp.exp(x)) - -log1pexp(3.) -``` - -Since it's written in terms of `jax.numpy`, it's JAX-transformable: - -```{code-cell} -print(jit(log1pexp)(3.)) -print(jit(grad(log1pexp))(3.)) -print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) -``` - -But there's a numerical stability problem lurking here: - -```{code-cell} -print(grad(log1pexp)(100.)) -``` - -That doesn't seem right! After all, the derivative of $x \mapsto \log (1 + e^x)$ is $x \mapsto \frac{e^x}{1 + e^x}$, and so for large values of $x$ we'd expect the value to be about 1. - -We can get a bit more insight into what's going on by looking at the jaxpr for the gradient computation: - -```{code-cell} -from jax import make_jaxpr - -make_jaxpr(grad(log1pexp))(100.) -``` - -Stepping through how the jaxpr would be evaluated, notice that the last line would involve multiplying values that floating point math will round to 0 and $\infty$, respectively, which is never a good idea. That is, we're effectively evaluating `lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)` for large `x`, which effectively turns into `0. * jnp.inf`. - -Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \frac{1}{1 + e^x}$, with no cancellation in sight. - -This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with {func}`jax.jit`, {func}`jax.vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better. - -This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like {func}`jax.jit`, {func}`jax.vmap`, ...). - -Here's a solution using {func}`jax.custom_jvp`: - -```{code-cell} -@custom_jvp -def log1pexp(x): - return jnp.log(1. + jnp.exp(x)) - -@log1pexp.defjvp -def log1pexp_jvp(primals, tangents): - x, = primals - x_dot, = tangents - ans = log1pexp(x) - ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot - return ans, ans_dot -``` - -```{code-cell} -print(grad(log1pexp)(100.)) -``` - -```{code-cell} -print(jit(log1pexp)(3.)) -print(jit(grad(log1pexp))(3.)) -print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) -``` - -Here's a `defjvps` convenience wrapper to express the same thing: - -```{code-cell} -@custom_jvp -def log1pexp(x): - return jnp.log(1. + jnp.exp(x)) - -log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t) -``` - -```{code-cell} -print(grad(log1pexp)(100.)) -print(jit(log1pexp)(3.)) -print(jit(grad(log1pexp))(3.)) -print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) -``` - -#### Example: Enforcing a differentiation convention - -A related application is to enforce a differentiation convention, perhaps at a boundary. - -Consider the function $f : \mathbb{R}_+ \to \mathbb{R}_+$ with $f(x) = \frac{x}{1 + \sqrt{x}}$, where we take $\mathbb{R}_+ = [0, \infty)$. We might implement $f$ as a program like this: - -```{code-cell} -def f(x): - return x / (1 + jnp.sqrt(x)) -``` - -As a mathematical function on $\mathbb{R}$ (the full real line), $f$ is not differentiable at zero (because the limit defining the derivative doesn't exist from the left). Correspondingly, autodiff produces a `nan` value: - -```{code-cell} -print(grad(f)(0.)) -``` - -But mathematically if we think of $f$ as a function on $\mathbb{R}_+$ then it is differentiable at 0 [Rudin's Principles of Mathematical Analysis Definition 5.1, or Tao's Analysis I 3rd ed. Definition 10.1.1 and Example 10.1.6]. Alternatively, we might say as a convention we want to consider the directional derivative from the right. So there is a sensible value for the Python function `grad(f)` to return at `0.0`, namely `1.0`. By default, JAX's machinery for differentiation assumes all functions are defined over $\mathbb{R}$ and thus doesn't produce `1.0` here. - -We can use a custom JVP rule! In particular, we can define the JVP rule in terms of the derivative function $x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)^2}$ on $\mathbb{R}_+$, - -```{code-cell} -@custom_jvp -def f(x): - return x / (1 + jnp.sqrt(x)) - -@f.defjvp -def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - ans = f(x) - ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot - return ans, ans_dot -``` - -```{code-cell} -print(grad(f)(0.)) -``` - -Here's the convenience wrapper version: - -```{code-cell} -@custom_jvp -def f(x): - return x / (1 + jnp.sqrt(x)) - -f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t) -``` - -```{code-cell} -print(grad(f)(0.)) -``` - -#### Example: Gradient clipping - -While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping. - -For gradient clipping, we can use {func}`jnp.clip` together with a {func}`jax.custom_vjp` reverse-mode-only rule: - -```{code-cell} -from functools import partial - -@custom_vjp -def clip_gradient(lo, hi, x): - return x # identity function - -def clip_gradient_fwd(lo, hi, x): - return x, (lo, hi) # save bounds as residuals - -def clip_gradient_bwd(res, g): - lo, hi = res - return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi - -clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) -``` - -```{code-cell} -import matplotlib.pyplot as plt - -t = jnp.linspace(0, 10, 1000) - -plt.plot(jnp.sin(t)) -plt.plot(vmap(grad(jnp.sin))(t)) -``` - -```{code-cell} -def clip_sin(x): - x = clip_gradient(-0.75, 0.75, x) - return jnp.sin(x) - -plt.plot(clip_sin(t)) -plt.plot(vmap(grad(clip_sin))(t)) -``` - -#### Example: Python debugging - -Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff. - -When trying to track down the source of a `nan` runtime error, or just examine carefully the cotangent (gradient) values being propagated, it can be useful to insert a debugger at a point in the backward pass that corresponds to a specific point in the primal computation. You can do that with {func}`jax.custom_vjp`. - -We'll defer an example until the next section. - - - -#### Example: Implicit function differentiation of iterative implementations - -This example gets pretty deep in the mathematical weeds! - -Another application for {func}`jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by {func}`jax.jit`, {func}`jax.vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve {func}`jax.lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without "side-effecting" interactions through infeed/outfeed.) - -For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`: - -```{code-cell} -from jax.lax import while_loop - -def fixed_point(f, a, x_guess): - def cond_fun(carry): - x_prev, x = carry - return jnp.abs(x_prev - x) > 1e-6 - - def body_fun(carry): - _, x = carry - return x, f(a, x) - - _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess))) - return x_star -``` - -This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \mapsto x^*(a)$ that is implicitly defined by equation $x = f(a, x)$. - -We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides: - -```{code-cell} -def newton_sqrt(a): - update = lambda a, x: 0.5 * (x + a / x) - return fixed_point(update, a, a) -``` - -```{code-cell} -print(newton_sqrt(2.)) -``` - -We can {func}`jax.vmap` or {func}`jax.jit` the function as well: - -```{code-cell} -print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.]))) -``` - -We can't apply reverse-mode automatic differentiation because of the `while_loop`, but it turns out we wouldn't want to anyway: instead of differentiating through the implementation of `fixed_point` and all its iterations, we can exploit the mathematical structure to do something that is much more memory-efficient (and FLOP-efficient in this case, too!). We can instead use the implicit function theorem [Prop A.25 of Bertsekas's Nonlinear Programming, 2nd ed.], which guarantees (under some conditions) the existence of the mathematical objects we're about to use. In essence, we linearize the solution and solve those linear equations iteratively to compute the derivatives we want. - -Consider again the equation $x = f(a, x)$ and the function $x^*$. We want to evaluate vector-Jacobian products like $v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^*(a_0)$. - -At least in an open neighborhood around the point $a_0$ at which we want to differentiate, let's assume that the equation $x^*(a) = f(a, x^*(a))$ holds for all $a$. Since the two sides are equal as functions of $a$, their derivatives must be equal as well, so let's differentiate both sides: - -$\qquad \partial x^*(a) = \partial_0 f(a, x^*(a)) + \partial_1 f(a, x^*(a)) \partial x^*(a)$. - -Setting $A = \partial_1 f(a_0, x^*(a_0))$ and $B = \partial_0 f(a_0, x^*(a_0))$, we can write the quantity we're after more simply as: - -$\qquad \partial x^*(a_0) = B + A \partial x^*(a_0)$, - -or, by rearranging, - -$\qquad \partial x^*(a_0) = (I - A)^{-1} B$. - -That means we can evaluate vector-Jacobian products, such as: - -$\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B$, - -where $w^\mathsf{T} = v^\mathsf{T} (I - A)^{-1}$, or equivalently $w^\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A$, or equivalently $w^\mathsf{T}$ is the fixed point of the map $u^\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A$. That last characterization gives us a way to write the VJP for `fixed_point` in terms of a call to `fixed_point`! Moreover, after expanding $A$ and $B$ back out, you can conclude you need only to evaluate VJPs of $f$ at $(a_0, x^*(a_0))$. - -Here's the upshot: - -```{code-cell} -@partial(custom_vjp, nondiff_argnums=(0,)) -def fixed_point(f, a, x_guess): - def cond_fun(carry): - x_prev, x = carry - return jnp.abs(x_prev - x) > 1e-6 - - def body_fun(carry): - _, x = carry - return x, f(a, x) - - _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess))) - return x_star - -def fixed_point_fwd(f, a, x_init): - x_star = fixed_point(f, a, x_init) - return x_star, (a, x_star) - -def fixed_point_rev(f, res, x_star_bar): - a, x_star = res - _, vjp_a = vjp(lambda a: f(a, x_star), a) - a_bar, = vjp_a(fixed_point(partial(rev_iter, f), - (a, x_star, x_star_bar), - x_star_bar)) - return a_bar, jnp.zeros_like(x_star) - -def rev_iter(f, packed, u): - a, x_star, x_star_bar = packed - _, vjp_x = vjp(lambda x: f(a, x), x_star) - return x_star_bar + vjp_x(u)[0] - -fixed_point.defvjp(fixed_point_fwd, fixed_point_rev) -``` - -```{code-cell} -print(newton_sqrt(2.)) -``` - -```{code-cell} -print(grad(newton_sqrt)(2.)) -print(grad(grad(newton_sqrt))(2.)) -``` - -We can check our answers by differentiating {func}`jnp.sqrt`, which uses a totally different implementation: - -```{code-cell} -print(grad(jnp.sqrt)(2.)) -print(grad(grad(jnp.sqrt))(2.)) -``` - -A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions. - - -### Basic usage of `jax.custom_jvp` and `jax.custom_vjp` APIs - -#### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules - -Here's a canonical basic example of using {func}`jax.custom_jvp`, where the comments use -[Haskell-like type signatures](https://wiki.haskell.org/Type_signature): - -```{code-cell} -# f :: a -> b -@custom_jvp -def f(x): - return jnp.sin(x) - -# f_jvp :: (a, T a) -> (b, T b) -def f_jvp(primals, tangents): - x, = primals - t, = tangents - return f(x), jnp.cos(x) * t - -f.defjvp(f_jvp) -``` - -```{code-cell} -print(f(3.)) - -y, y_dot = jvp(f, (3.,), (1.,)) -print(y) -print(y_dot) -``` - -In other words, we start with a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it a JVP rule function `f_jvp` that takes a pair of inputs representing the primal inputs of type `a` and the corresponding tangent inputs of type `T a`, and produces a pair of outputs representing the primal outputs of type `b` and tangent outputs of type `T b`. The tangent outputs should be a linear function of the tangent inputs. - -You can also use `f.defjvp` as a decorator, as in - -```python -@custom_jvp -def f(x): - ... - -@f.defjvp -def f_jvp(primals, tangents): - ... -``` - -Even though we defined only a JVP rule and no VJP rule, we can use both forward- and reverse-mode differentiation on `f`. JAX will automatically transpose the linear computation on tangent values from our custom JVP rule, computing the VJP as efficiently as if we had written the rule by hand: - -```{code-cell} -print(grad(f)(3.)) -print(grad(grad(f))(3.)) -``` - -For automatic transposition to work, the JVP rule's output tangents must be linear as a function of the input tangents. Otherwise a transposition error is raised. - -Multiple arguments work like this: - -```{code-cell} -@custom_jvp -def f(x, y): - return x ** 2 * y - -@f.defjvp -def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - primal_out = f(x, y) - tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot - return primal_out, tangent_out -``` - -```{code-cell} -print(grad(f)(2., 3.)) -``` - -The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed: - -```{code-cell} -@custom_jvp -def f(x): - return jnp.sin(x) - -f.defjvps(lambda t, ans, x: jnp.cos(x) * t) -``` - -```{code-cell} -print(grad(f)(3.)) -``` - -Here's a `defjvps` example with multiple arguments: - -```{code-cell} -@custom_jvp -def f(x, y): - return x ** 2 * y - -f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot, - lambda y_dot, primal_out, x, y: x ** 2 * y_dot) -``` - -```{code-cell} -print(grad(f)(2., 3.)) -print(grad(f, 0)(2., 3.)) # same as above -print(grad(f, 1)(2., 3.)) -``` - -As a shorthand, with `defjvps` you can pass a `None` value to indicate that the JVP for a particular argument is zero: - -```{code-cell} -@custom_jvp -def f(x, y): - return x ** 2 * y - -f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot, - None) -``` - -```{code-cell} -print(grad(f)(2., 3.)) -print(grad(f, 0)(2., 3.)) # same as above -print(grad(f, 1)(2., 3.)) -``` - -Calling a {func}`jax.custom_jvp` function with keyword arguments, or writing a {func}`jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism. - -When you're not performing differentiation, the function `f` is called just as if it weren't decorated by {func}`jax.custom_jvp`: - -```{code-cell} -@custom_jvp -def f(x): - print('called f!') # a harmless side-effect - return jnp.sin(x) - -@f.defjvp -def f_jvp(primals, tangents): - print('called f_jvp!') # a harmless side-effect - x, = primals - t, = tangents - return f(x), jnp.cos(x) * t -``` - -```{code-cell} -print(f(3.)) -``` - -```{code-cell} -print(vmap(f)(jnp.arange(3.))) -print(jit(f)(3.)) -``` - -The custom JVP rule is invoked during differentiation, whether forward or reverse: - -```{code-cell} -y, y_dot = jvp(f, (3.,), (1.,)) -print(y_dot) -``` - -```{code-cell} -print(grad(f)(3.)) -``` - -Notice that `f_jvp` calls `f` to compute the primal outputs. In the context of higher-order differentiation, each application of a differentiation transform will use the custom JVP rule if and only if the rule calls the original `f` to compute the primal outputs. (This represents a kind of fundamental tradeoff, where we can't make use of intermediate values from the evaluation of `f` in our rule _and also_ have the rule apply in all orders of higher-order differentiation.) - -```{code-cell} -grad(grad(f))(3.) -``` - -You can use Python control flow with {func}`jax.custom_jvp`: - -```{code-cell} -@custom_jvp -def f(x): - if x > 0: - return jnp.sin(x) - else: - return jnp.cos(x) - -@f.defjvp -def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - ans = f(x) - if x > 0: - return ans, 2 * x_dot - else: - return ans, 3 * x_dot -``` - -```{code-cell} -print(grad(f)(1.)) -print(grad(f)(-1.)) -``` - -#### Use `jax.custom_vjp` to define custom reverse-mode-only rules - -While {func}`jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with {func}`jax.custom_vjp`: - -```{code-cell} -from jax import custom_vjp - -# f :: a -> b -@custom_vjp -def f(x): - return jnp.sin(x) - -# f_fwd :: a -> (b, c) -def f_fwd(x): - return f(x), jnp.cos(x) - -# f_bwd :: (c, CT b) -> CT a -def f_bwd(cos_x, y_bar): - return (cos_x * y_bar,) - -f.defvjp(f_fwd, f_bwd) -``` - -```{code-cell} -print(f(3.)) -print(grad(f)(3.)) -``` - -In other words, we again start with a primal function `f` that takes inputs of type `a` and produces outputs of type `b`. We associate with it two functions, `f_fwd` and `f_bwd`, which describe how to perform the forward- and backward-passes of reverse-mode autodiff, respectively. - -The function `f_fwd` describes the forward pass, not only the primal computation but also what values to save for use on the backward pass. Its input signature is just like that of the primal function `f`, in that it takes a primal input of type `a`. But as output it produces a pair, where the first element is the primal output `b` and the second element is any "residual" data of type `c` to be stored for use by the backward pass. (This second output is analogous to [PyTorch's save_for_backward mechanism](https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html).) - -The function `f_bwd` describes the backward pass. It takes two inputs, where the first is the residual data of type `c` produced by `f_fwd` and the second is the output cotangents of type `CT b` corresponding to the output of the primal function. It produces an output of type `CT a` representing the cotangents corresponding to the input of the primal function. In particular, the output of `f_bwd` must be a sequence (e.g. a tuple) of length equal to the number of arguments to the primal function. - -So multiple arguments work like this: - -```{code-cell} -@custom_vjp -def f(x, y): - return jnp.sin(x) * y - -def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - -def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - -f.defvjp(f_fwd, f_bwd) -``` - -```{code-cell} -print(grad(f)(2., 3.)) -``` - -Calling a {func}`jax.custom_vjp` function with keyword arguments, or writing a {func}`jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism. - -As with {func}`jax.custom_jvp`, the custom VJP rule composed of `f_fwd` and `f_bwd` is not invoked if differentiation is not applied. If the function is evaluated, or transformed with {func}`jax.jit`, {func}`jax.vmap`, or other non-differentiation transformations, then only `f` is called. - -```{code-cell} -@custom_vjp -def f(x): - print("called f!") - return jnp.sin(x) - -def f_fwd(x): - print("called f_fwd!") - return f(x), jnp.cos(x) - -def f_bwd(cos_x, y_bar): - print("called f_bwd!") - return (cos_x * y_bar,) - -f.defvjp(f_fwd, f_bwd) -``` - -```{code-cell} -print(f(3.)) -``` - -```{code-cell} -print(grad(f)(3.)) -``` - -```{code-cell} -y, f_vjp = vjp(f, 3.) -print(y) -``` - -```{code-cell} -print(f_vjp(1.)) -``` - -**Forward-mode autodiff cannot be used on the** {func}`jax.custom_vjp` **function** and will raise an error: - -```{code-cell} -:tags: [raises-exception] - -from jax import jvp - -try: - jvp(f, (3.,), (1.,)) -except TypeError as e: - print('ERROR! {}'.format(e)) -``` - -If you want to use both forward- and reverse-mode, use {func}`jax.custom_jvp` instead. - -We can use {func}`jax.custom_vjp` together with `pdb` to insert a debugger trace in the backward pass: - -```{code-cell} -import pdb - -@custom_vjp -def debug(x): - return x # acts like identity - -def debug_fwd(x): - return x, x - -def debug_bwd(x, g): - import pdb; pdb.set_trace() - return g - -debug.defvjp(debug_fwd, debug_bwd) -``` - -```{code-cell} -def foo(x): - y = x ** 2 - y = debug(y) # insert pdb in corresponding backward pass step - return jnp.sin(y) -``` - -```python -jax.grad(foo)(3.) - -> (12)debug_bwd() --> return g -(Pdb) p x -Array(9., dtype=float32) -(Pdb) p g -Array(-0.91113025, dtype=float32) -(Pdb) q -``` - - -### More features and details - -#### Working with `list` / `tuple` / `dict` containers (and other pytrees) - -You should expect standard Python containers like lists, tuples, namedtuples, and dicts to just work, along with nested versions of those. In general, any [pytrees](https://docs.jax.dev/en/latest/pytrees.html) are permissible, so long as their structures are consistent according to the type constraints. - -Here's a contrived example with {func}`jax.custom_jvp`: - -```{code-cell} -from collections import namedtuple -Point = namedtuple("Point", ["x", "y"]) - -@custom_jvp -def f(pt): - x, y = pt.x, pt.y - return {'a': x ** 2, - 'b': (jnp.sin(x), jnp.cos(y))} - -@f.defjvp -def f_jvp(primals, tangents): - pt, = primals - pt_dot, = tangents - ans = f(pt) - ans_dot = {'a': 2 * pt.x * pt_dot.x, - 'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)} - return ans, ans_dot - -def fun(pt): - dct = f(pt) - return dct['a'] + dct['b'][0] -``` - -```{code-cell} -pt = Point(1., 2.) - -print(f(pt)) -``` - -```{code-cell} -print(grad(fun)(pt)) -``` - -And an analogous contrived example with {func}`jax.custom_vjp`: - -```{code-cell} -@custom_vjp -def f(pt): - x, y = pt.x, pt.y - return {'a': x ** 2, - 'b': (jnp.sin(x), jnp.cos(y))} - -def f_fwd(pt): - return f(pt), pt - -def f_bwd(pt, g): - a_bar, (b0_bar, b1_bar) = g['a'], g['b'] - x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar - y_bar = -jnp.sin(pt.y) * b1_bar - return (Point(x_bar, y_bar),) - -f.defvjp(f_fwd, f_bwd) - -def fun(pt): - dct = f(pt) - return dct['a'] + dct['b'][0] -``` - -```{code-cell} -pt = Point(1., 2.) - -print(f(pt)) -``` - -```{code-cell} -print(grad(fun)(pt)) -``` - -#### Handling non-differentiable arguments - -Some use cases, like the final example problem, call for non-differentiable arguments like function-valued arguments to be passed to functions with custom differentiation rules, and for those arguments to also be passed to the rules themselves. In the case of `fixed_point`, the function argument `f` was such a non-differentiable argument. A similar situation arises with `jax.experimental.odeint`. - -##### `jax.custom_jvp` with `nondiff_argnums` - -Use the optional `nondiff_argnums` parameter to {func}`jax.custom_jvp` to indicate arguments like these. Here's an example with {func}`jax.custom_jvp`: - -```{code-cell} -from functools import partial - -@partial(custom_jvp, nondiff_argnums=(0,)) -def app(f, x): - return f(x) - -@app.defjvp -def app_jvp(f, primals, tangents): - x, = primals - x_dot, = tangents - return f(x), 2. * x_dot -``` - -```{code-cell} -print(app(lambda x: x ** 3, 3.)) -``` - -```{code-cell} -print(grad(app, 1)(lambda x: x ** 3, 3.)) -``` - -Notice the gotcha here: no matter where in the argument list these parameters appear, they're placed at the *start* of the signature of the corresponding JVP rule. Here's another example: - -```{code-cell} -@partial(custom_jvp, nondiff_argnums=(0, 2)) -def app2(f, x, g): - return f(g((x))) - -@app2.defjvp -def app2_jvp(f, g, primals, tangents): - x, = primals - x_dot, = tangents - return f(g(x)), 3. * x_dot -``` - -```{code-cell} -print(app2(lambda x: x ** 3, 3., lambda y: 5 * y)) -``` - -```{code-cell} -print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y)) -``` - -##### `jax.custom_vjp` with `nondiff_argnums` - -A similar option exists for {func}`jax.custom_vjp`, and, similarly, the convention is that the non-differentiable arguments are passed as the first arguments to the `_bwd` rule, no matter where they appear in the signature of the original function. The signature of the `_fwd` rule remains unchanged - it is the same as the signature of the primal function. Here's an example: - -```{code-cell} -@partial(custom_vjp, nondiff_argnums=(0,)) -def app(f, x): - return f(x) - -def app_fwd(f, x): - return f(x), x - -def app_bwd(f, x, g): - return (5 * g,) - -app.defvjp(app_fwd, app_bwd) -``` - -```{code-cell} -print(app(lambda x: x ** 2, 4.)) -``` - -```{code-cell} -print(grad(app, 1)(lambda x: x ** 2, 4.)) -``` - -Refer to `fixed_point` above for another usage example. - -**You don't need to use** `nondiff_argnums` **with array-valued arguments**, such as, for example, ones with the integer dtype. Instead, `nondiff_argnums` should only be used for argument values that don't correspond to JAX types (essentially don't correspond to array types), like Python callables or strings. If JAX detects that an argument indicated by `nondiff_argnums` contains a JAX Tracer, then an error is raised. The `clip_gradient` function above is a good example of not using `nondiff_argnums` for integer-dtype array arguments. - -## Next steps - -There's a whole world of other autodiff tricks and functionality out there. Topics that weren't covered in this tutorial but can be worth pursuing include: - - - Gauss-Newton Vector Products, linearizing once - - Custom VJPs and JVPs - - Efficient derivatives at fixed-points - - Estimating the trace of a Hessian using random Hessian-vector products - - Forward-mode autodiff using only reverse-mode autodiff - - Taking derivatives with respect to custom data types - - Checkpointing (binomial checkpointing for efficient reverse-mode, not model snapshotting) - - Optimizing VJPs with Jacobian pre-accumulation diff --git a/docs/advanced_autodiff.md b/docs/advanced_autodiff.md new file mode 100644 index 000000000000..43b62b0e0d5c --- /dev/null +++ b/docs/advanced_autodiff.md @@ -0,0 +1,11 @@ +# Advanced Automatic Differentiation + +```{toctree} +:caption: Advanced automatic differentiation +:maxdepth: 1 + +higher-order +jacobian-vector-products +complex-differentiation +notebooks/Custom_derivative_rules_for_Python_code +``` diff --git a/docs/advanced_guides.rst b/docs/advanced_guides.rst index e090efa67b29..4a7624e08262 100644 --- a/docs/advanced_guides.rst +++ b/docs/advanced_guides.rst @@ -32,9 +32,8 @@ operations. :maxdepth: 1 notebooks/autodiff_cookbook - notebooks/Custom_derivative_rules_for_Python_code notebooks/autodiff_remat - advanced-autodiff + advanced_autodiff .. toctree:: :maxdepth: 1 @@ -43,7 +42,6 @@ operations. errors debugging debugging/index - debugging/flags transfer_guard .. toctree:: diff --git a/docs/automatic-differentiation.md b/docs/automatic-differentiation.md index 07af05e3d973..221dd19c5121 100644 --- a/docs/automatic-differentiation.md +++ b/docs/automatic-differentiation.md @@ -26,7 +26,7 @@ Computing gradients is a critical part of modern machine learning methods, and t - {ref}`automatic-differentiation-evaluating-using-jax-value_and_grad` - {ref}`automatic-differentiation-checking-against-numerical-differences` -Make sure to also check out the {ref}`advanced-autodiff` tutorial for more advanced topics. +Make sure to also check out the {ref}`"Advanced automatic differentiation" guides ` for more advanced topics. While understanding how automatic differentiation works "under the hood" isn't crucial for using JAX in most contexts, you are encouraged to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on. @@ -230,4 +230,4 @@ check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives ## Next steps -The {ref}`advanced-autodiff` tutorial provides more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`advanced-autodiff-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section in the {ref}`advanced-autodiff` tutorial if you are interested. +The {ref}`"Advanced automatic differentiation" guides ` provide more advanced and detailed explanations of how the ideas covered in this document are implemented in the JAX backend. Some features, such as {ref}`advanced-autodiff-custom-derivative-rules`, depend on understanding advanced automatic differentiation, so do check out that section if you are interested. diff --git a/docs/complex-differentiation.md b/docs/complex-differentiation.md new file mode 100644 index 000000000000..cf31b90a45ef --- /dev/null +++ b/docs/complex-differentiation.md @@ -0,0 +1,207 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + name: python3 +--- + +# Complex numbers and differentiation + +JAX is great at complex numbers and differentiation. To support both [holomorphic and non-holomorphic differentiation](https://en.wikipedia.org/wiki/Holomorphic_function), it helps to think in terms of JVPs and VJPs. + +Consider a complex-to-complex function $f: \mathbb{C} \to \mathbb{C}$ and identify it with a corresponding function $g: \mathbb{R}^2 \to \mathbb{R}^2$, + +```{code-cell} +import jax.numpy as jnp + +def f(z): + x, y = jnp.real(z), jnp.imag(z) + return u(x, y) + v(x, y) * 1j + +def g(x, y): + return (u(x, y), v(x, y)) +``` + +That is, we've decomposed $f(z) = u(x, y) + v(x, y) i$ where $z = x + y i$, and identified $\mathbb{C}$ with $\mathbb{R}^2$ to get $g$. + +Since $g$ only involves real inputs and outputs, we already know how to write a Jacobian-vector product for it, say given a tangent vector $(c, d) \in \mathbb{R}^2$, namely: + +$\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} +\begin{bmatrix} c \\ d \end{bmatrix}$. + +To get a JVP for the original function $f$ applied to a tangent vector $c + di \in \mathbb{C}$, we just use the same definition and identify the result as another complex number, + +$\partial f(x + y i)(c + d i) = +\begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} +\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} +\begin{bmatrix} c \\ d \end{bmatrix}$. + +That's our definition of the JVP of a $\mathbb{C} \to \mathbb{C}$ function! Notice it doesn't matter whether or not $f$ is holomorphic: the JVP is unambiguous. + +Here's a check: + +```{code-cell} +from jax import random, grad, jvp + +def check(seed): + key = random.key(seed) + + # random coeffs for u and v + key, subkey = random.split(key) + a, b, c, d = random.uniform(subkey, (4,)) + + def fun(z): + x, y = jnp.real(z), jnp.imag(z) + return u(x, y) + v(x, y) * 1j + + def u(x, y): + return a * x + b * y + + def v(x, y): + return c * x + d * y + + # primal point + key, subkey = random.split(key) + x, y = random.uniform(subkey, (2,)) + z = x + y * 1j + + # tangent vector + key, subkey = random.split(key) + c, d = random.uniform(subkey, (2,)) + z_dot = c + d * 1j + + # check jvp + _, ans = jvp(fun, (z,), (z_dot,)) + expected = (grad(u, 0)(x, y) * c + + grad(u, 1)(x, y) * d + + grad(v, 0)(x, y) * c * 1j+ + grad(v, 1)(x, y) * d * 1j) + print(jnp.allclose(ans, expected)) +``` + +```{code-cell} +check(0) +check(1) +check(2) +``` + +What about VJPs? We do something pretty similar: for a cotangent vector $c + di \in \mathbb{C}$ we define the VJP of $f$ as + +$(c + di)^* \; \partial f(x + y i) = +\begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} +\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} +\begin{bmatrix} 1 \\ -i \end{bmatrix}$. + +What's with the negatives? They're just to take care of complex conjugation, and the fact that we're working with covectors. + +Here's a check of the VJP rules: + +```{code-cell} +from jax import vjp + +def check(seed): + key = random.key(seed) + + # random coeffs for u and v + key, subkey = random.split(key) + a, b, c, d = random.uniform(subkey, (4,)) + + def fun(z): + x, y = jnp.real(z), jnp.imag(z) + return u(x, y) + v(x, y) * 1j + + def u(x, y): + return a * x + b * y + + def v(x, y): + return c * x + d * y + + # primal point + key, subkey = random.split(key) + x, y = random.uniform(subkey, (2,)) + z = x + y * 1j + + # cotangent vector + key, subkey = random.split(key) + c, d = random.uniform(subkey, (2,)) + z_bar = jnp.array(c + d * 1j) # for dtype control + + # check vjp + _, fun_vjp = vjp(fun, z) + ans, = fun_vjp(z_bar) + expected = (grad(u, 0)(x, y) * c + + grad(v, 0)(x, y) * (-d) + + grad(u, 1)(x, y) * c * (-1j) + + grad(v, 1)(x, y) * (-d) * (-1j)) + assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5) +``` + +```{code-cell} +check(0) +check(1) +check(2) +``` + +What about convenience wrappers like {func}`jax.grad`, {func}`jax.jacfwd`, and {func}`jax.jacrev`? + +For $\mathbb{R} \to \mathbb{R}$ functions, recall we defined `grad(f)(x)` as being `vjp(f, x)[1](1.0)`, which works because applying a VJP to a `1.0` value reveals the gradient (i.e. Jacobian, or derivative). We can do the same thing for $\mathbb{C} \to \mathbb{R}$ functions: we can still use `1.0` as the cotangent vector, and we just get out a complex number result summarizing the full Jacobian: + +```{code-cell} +def f(z): + x, y = jnp.real(z), jnp.imag(z) + return x**2 + y**2 + +z = 3. + 4j +grad(f)(z) +``` + +For general $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`. + +Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when {func}`jax.grad` is used for a complex-output function: + +```{code-cell} +def f(z): + return jnp.sin(z) + +z = 3. + 4j +grad(f, holomorphic=True)(z) +``` + +All the `holomorphic=True` promise does is disable the error when the output is complex-valued. We can still write `holomorphic=True` when the function isn't holomorphic, but the answer we get out won't represent the full Jacobian. Instead, it'll be the Jacobian of the function where we just discard the imaginary part of the output: + +```{code-cell} +def f(z): + return jnp.conjugate(z) + +z = 3. + 4j +grad(f, holomorphic=True)(z) # f is not actually holomorphic! +``` + +There are some useful upshots for how {func}`jax.grad` works here: + +1. We can use {func}`jax.grad` on holomorphic $\mathbb{C} \to \mathbb{C}$ functions. +2. We can use {func}`jax.grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)`. +3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then {func}`jax.grad` still works and we get the same result that an implementation using only real values would have given. + +In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\mathbb{C} \to \mathbb{C}$ function, we can do it with JVPs or VJPs! + + +You should expect complex numbers to work everywhere in JAX. Here's differentiating through a Cholesky decomposition of a complex matrix: + +```{code-cell} +A = jnp.array([[5., 2.+3j, 5j], + [2.-3j, 7., 1.+7j], + [-5j, 1.-7j, 12.]]) + +def f(X): + L = jnp.linalg.cholesky(X) + return jnp.sum((L - jnp.sin(L))**2) + +grad(f, holomorphic=True)(A) +``` diff --git a/docs/debugging.md b/docs/debugging.md index b86e32cb6522..9aa646f1ecb3 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -275,11 +275,3 @@ Read more in [](debugging/flags). ## Next steps Check out the {ref}`advanced-debugging` to learn more about debugging in JAX. - -```{toctree} -:hidden: - -debugging/print_breakpoint -debugging/checkify_guide -debugging/flags -``` diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 53009500f8fe..a879fb69e16e 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -1,3 +1,17 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + (debugging-flags)= # JAX debugging flags @@ -9,13 +23,13 @@ JAX offers flags and context managers that enable catching errors more easily. **Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code. -`jax_debug_nans` is a JAX flag that when enabled, will cause computations to error-out immediately on production of a NaN. Switching this option on adds a NaN check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jax.jit`. +`jax_debug_nans` is a JAX flag that when enabled, will cause computations to error-out immediately on production of a NaN. Switching this option on adds a NaN check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an `@jax.jit`. For code under an `@jax.jit`, the output of every `@jax.jit` function is checked and if a NaN is present it will re-run the function in de-optimized op-by-op mode, effectively removing one level of `@jax.jit` at a time. There could be tricky situations that arise, like NaNs that only occur under a `@jax.jit` but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute. -If the NaNs are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. +If the NaNs are being produced in the backward pass of a gradient evaluation, when an exception is raised several frames up in the stack trace you will be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. ### Usage @@ -27,7 +41,7 @@ If you want to trace where NaNs are occurring in your functions or gradients, yo ### Example(s) -```python +```{code-cell} import jax import jax.numpy as jnp import traceback @@ -46,7 +60,7 @@ except FloatingPointError as e: The NaN generated was caught. By running `%debug`, we can get a post-mortem debugger. This also works with functions under `@jax.jit`, as the example below shows. -```python +```{code-cell} :tags: [raises-exception] jax.jit(f)(5.) @@ -56,7 +70,7 @@ When this code sees a NaN in the output of an `@jax.jit` function, it calls into The `jax.debug_nans` context manager can be used to activate/deactivate NaN debugging. Since we activated it above with `jax.config.update`, let's deactivate it: -```python +```{code-cell} with jax.debug_nans(False): print(jax.jit(f)(5.)) ``` diff --git a/docs/debugging/index.md b/docs/debugging/index.md index 936c701f0c00..724d29af34b6 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -138,9 +138,8 @@ ENTRY main.5 { :caption: Read more :maxdepth: 1 +flags print_breakpoint checkify_guide -./flags xla_metadata ``` - diff --git a/docs/faq.rst b/docs/faq.rst index 2d3c920498f6..5653ff1cbb26 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -137,195 +137,14 @@ on GitHub. How to use ``jit`` with methods? -------------------------------- -Most examples of :func:`jax.jit` concern decorating stand-alone Python functions, -but decorating a method within a class introduces some complication. For example, -consider the following simple class, where we've used a standard :func:`~jax.jit` -annotation on a method:: - - >>> import jax.numpy as jnp - >>> from jax import jit - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... @jit # <---- How to do this correctly? - ... def calc(self, y): - ... if self.mul: - ... return self.x * y - ... return y - -However, this approach will result in an error when you attempt to call this method:: - - >>> c = CustomClass(2, True) - >>> c.calc(3) # doctest: +SKIP - --------------------------------------------------------------------------- - TypeError Traceback (most recent call last) - File "", line 1, in ' of type is not a valid JAX type. - -The problem is that the first argument to the function is ``self``, which has type -``CustomClass``, and JAX does not know how to handle this type. -There are three basic strategies we might use in this case, and we'll discuss -them below. - -Strategy 1: JIT-compiled helper function -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The most straightforward approach is to create a helper function external to the class -that can be JIT-decorated in the normal way. For example:: - - >>> from functools import partial - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... def calc(self, y): - ... return _calc(self.mul, self.x, y) - - >>> @partial(jit, static_argnums=0) - ... def _calc(mul, x, y): - ... if mul: - ... return x * y - ... return y - -The result will work as expected:: - - >>> c = CustomClass(2, True) - >>> print(c.calc(3)) - 6 - -The benefit of such an approach is that it is simple, explicit, and it avoids the need -to teach JAX how to handle objects of type ``CustomClass``. However, you may wish to -keep all the method logic in the same place. - -Strategy 2: Marking ``self`` as static -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Another common pattern is to use ``static_argnums`` to mark the ``self`` argument as static. -But this must be done with care to avoid unexpected results. -You may be tempted to simply do this:: - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... # WARNING: this example is broken, as we'll see below. Don't copy & paste! - ... @partial(jit, static_argnums=0) - ... def calc(self, y): - ... if self.mul: - ... return self.x * y - ... return y - -If you call the method, it will no longer raise an error:: - - >>> c = CustomClass(2, True) - >>> print(c.calc(3)) - 6 - -However, there is a catch: if you mutate the object after the first method call, the -subsequent method call may return an incorrect result:: - - >>> c.mul = False - >>> print(c.calc(3)) # Should print 3 - 6 - -Why is this? When you mark an object as static, it will effectively be used as a dictionary -key in JIT's internal compilation cache, meaning its hash (i.e. ``hash(obj)``) equality -(i.e. ``obj1 == obj2``) and object identity (i.e. ``obj1 is obj2``) will be assumed to have -consistent behavior. The default ``__hash__`` for a custom object is its object ID, and so -JAX has no way of knowing that a mutated object should trigger a re-compilation. - -You can partially address this by defining an appropriate ``__hash__`` and ``__eq__`` methods -for your object; for example:: - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... @partial(jit, static_argnums=0) - ... def calc(self, y): - ... if self.mul: - ... return self.x * y - ... return y - ... - ... def __hash__(self): - ... return hash((self.x, self.mul)) - ... - ... def __eq__(self, other): - ... return (isinstance(other, CustomClass) and - ... (self.x, self.mul) == (other.x, other.mul)) - -(see the :meth:`object.__hash__` documentation for more discussion of the requirements -when overriding ``__hash__``). - -This should work correctly with JIT and other transforms **so long as you never mutate -your object**. Mutations of objects used as hash keys lead to several subtle problems, -which is why for example mutable Python containers (e.g. :class:`dict`, :class:`list`) -don't define ``__hash__``, while their immutable counterparts (e.g. :class:`tuple`) do. - -If your class relies on in-place mutations (such as setting ``self.attr = ...`` within its -methods), then your object is not really "static" and marking it as such may lead to problems. -Fortunately, there's another option for this case. - -Strategy 3: Making ``CustomClass`` a PyTree -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The most flexible approach to correctly JIT-compiling a class method is to register the -type as a custom PyTree object; see :ref:`pytrees-custom-pytree-nodes`. This lets you specify -exactly which components of the class should be treated as static and which should be -treated as dynamic. Here's how it might look:: - - >>> class CustomClass: - ... def __init__(self, x: jnp.ndarray, mul: bool): - ... self.x = x - ... self.mul = mul - ... - ... @jit - ... def calc(self, y): - ... if self.mul: - ... return self.x * y - ... return y - ... - ... def _tree_flatten(self): - ... children = (self.x,) # arrays / dynamic values - ... aux_data = {'mul': self.mul} # static values - ... return (children, aux_data) - ... - ... @classmethod - ... def _tree_unflatten(cls, aux_data, children): - ... return cls(*children, **aux_data) - - >>> from jax import tree_util - >>> tree_util.register_pytree_node(CustomClass, - ... CustomClass._tree_flatten, - ... CustomClass._tree_unflatten) - -This is certainly more involved, but it solves all the issues associated with the simpler -approaches used above:: - - >>> c = CustomClass(2, True) - >>> print(c.calc(3)) - 6 - - >>> c.mul = False # mutation is detected - >>> print(c.calc(3)) - 3 - - >>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported - >>> print(c.calc(3)) - 6 - -So long as your ``tree_flatten`` and ``tree_unflatten`` functions correctly handle all -relevant attributes in the class, you should be able to use objects of this type directly -as arguments to JIT-compiled functions, without any special annotations. + +Moved to :ref:`jax-jit-class-methods`. .. _faq-jax-vs-numpy: Is JAX faster than NumPy? -~~~~~~~~~~~~~~~~~~~~~~~~~ +------------------------- + One question users frequently attempt to answer with such benchmarks is whether JAX is faster than NumPy; due to the difference in the two packages, there is not a simple answer. diff --git a/docs/gradient-checkpointing.md b/docs/gradient-checkpointing.md index 929956ad6e31..1b6463f65024 100644 --- a/docs/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -19,7 +19,7 @@ kernelspec: In this tutorial, you will learn how to control JAX automatic differentiation's saved values using {func}`jax.checkpoint` (also known as {func}`jax.remat`), which can be particularly helpful in machine learning. -If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has {ref}`automatic-differentiation` and {ref}`advanced-autodiff` tutorials. +If you are new to automatic differentiation (autodiff) or need to refresh your memory, JAX has an {ref}`automatic-differentiation` tutorial and several {ref}`Advanced automatic differentiation guides `. **TL;DR** Use the {func}`jax.checkpoint` decorator (aliased as {func}`jax.remat`) with {func}`jax.grad` to control which intermediates are saved on the forward pass versus the recomputed intermediates on the backward pass, trading off memory and FLOPs. @@ -144,7 +144,7 @@ print_fwd_bwd(f3, W1, W2, W3, x) ### Let's think step by step -**Note:** It may help to check out the {ref}`advanced-autodiff` tutorial prior to continuing here. +**Note:** It may help to check out the {ref}`"Advanced automatic differentiation" guides ` prior to continuing here. #### `jax.checkpoint` fundamentals diff --git a/docs/higher-order.md b/docs/higher-order.md new file mode 100644 index 000000000000..e835d3af82cc --- /dev/null +++ b/docs/higher-order.md @@ -0,0 +1,336 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +# Higher-order derivatives + +## Taking gradients (part 2) + +JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. + +The single-variable case was covered in the {ref}`automatic-differentiation` tutorial, where the example showed how to use {func}`jax.grad` to compute the derivative of $f(x) = x^3 + 2x^2 - 3x + 1$. + +In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to: + +$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$ + +The Hessian of a real-valued function of several variables, $f: \mathbb R^n\to\mathbb R$, can be identified with the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) of its gradient. + +JAX provides two transformations for computing the Jacobian of a function, {func}`jax.jacfwd` and {func}`jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances – refer to the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY). + +```{code-cell} +import jax + +def hessian(f): + return jax.jacfwd(jax.grad(f)) +``` + +Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}$. + +if $i=j$, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2$. Otherwise, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0$. + +```{code-cell} +import jax.numpy as jnp + +def f(x): + return jnp.dot(x, x) + +hessian(f)(jnp.array([1., 2., 3.])) +``` + +## Higher-order derivative applications + +Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier: + +```python +def meta_loss_fn(params, data): + """Computes the loss after one step of SGD.""" + grads = jax.grad(loss_fn)(params, data) + return loss_fn(params - lr * grads, data) + +meta_grads = jax.grad(meta_loss_fn)(params, data) +``` + +(stopping-gradients)= +### Stopping gradients + +Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph. + +Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1}$) in a state $s_{t-1}$ is parameterised by a linear function. + +```{code-cell} +# Value function and initial parameters +value_fn = lambda theta, state: jnp.dot(theta, state) +theta = jnp.array([0.1, -0.1, 0.]) +``` + +Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which you observed the reward $r_t$ + +```{code-cell} +# An example transition. +s_tm1 = jnp.array([1., 2., -1.]) +r_t = jnp.array(1.) +s_t = jnp.array([2., 1., 0.]) +``` + +The TD(0) update to the network parameters is: + +$$ +\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1}) +$$ + +This update is not the gradient of any loss function. + +However, it can be **written** as the gradient of the pseudo loss function + +$$ +L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2 +$$ + +if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored. + +How can you implement this in JAX? If you write the pseudo loss naively, you get: + +```{code-cell} +def td_loss(theta, s_tm1, r_t, s_t): + v_tm1 = value_fn(theta, s_tm1) + target = r_t + value_fn(theta, s_t) + return -0.5 * ((target - v_tm1) ** 2) + +td_update = jax.grad(td_loss) +delta_theta = td_update(theta, s_tm1, r_t, s_t) + +delta_theta +``` + +But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$. + +You can use {func}`jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$: + +```{code-cell} +def td_loss(theta, s_tm1, r_t, s_t): + v_tm1 = value_fn(theta, s_tm1) + target = r_t + value_fn(theta, s_t) + return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2) + +td_update = jax.grad(td_loss) +delta_theta = td_update(theta, s_tm1, r_t, s_t) + +delta_theta +``` + +This will treat `target` as if it did **not** depend on the parameters $\theta$ and compute the correct update to the parameters. + +Now, let's also calculate $\Delta \theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using {func}`jax.grad` and your knowledge so far. Here's our solution: + +```{code-cell} +s_grad = jax.grad(value_fn)(theta, s_tm1) +delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad + +delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta` +``` + +`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss). + +### Straight-through estimator using `stop_gradient` + +The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`: + +```{code-cell} +def f(x): + return jnp.round(x) # non-differentiable + +def straight_through_f(x): + # Create an exactly-zero expression with Sterbenz lemma that has + # an exactly-one gradient. + zero = x - jax.lax.stop_gradient(x) + return zero + jax.lax.stop_gradient(f(x)) + +print("f(x): ", f(3.2)) +print("straight_through_f(x):", straight_through_f(3.2)) + +print("grad(f)(x):", jax.grad(f)(3.2)) +print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2)) +``` + +### Per-example gradients + +While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch. + +For instance, this is needed to prioritize data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis. + +In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients are typically very inefficient. + +In JAX, you can define the code to compute the gradient per-sample in an easy but efficient way. + +Just combine the {func}`jax.jit`, {func}`jax.vmap` and {func}`jax.grad` transformations together: + +```{code-cell} +perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0))) + +# Test it: +batched_s_tm1 = jnp.stack([s_tm1, s_tm1]) +batched_r_t = jnp.stack([r_t, r_t]) +batched_s_t = jnp.stack([s_t, s_t]) + +perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) +``` + +Let's go through this one transformation at a time. + +First, you apply {func}`jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs: + +```{code-cell} +dtdloss_dtheta = jax.grad(td_loss) + +dtdloss_dtheta(theta, s_tm1, r_t, s_t) +``` + +This function computes one row of the array above. + +Then, you vectorise this function using {func}`jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, you produce a batch of outputs — each output in the batch corresponds to the gradient for the corresponding member of the input batch. + +```{code-cell} +almost_perex_grads = jax.vmap(dtdloss_dtheta) + +batched_theta = jnp.stack([theta, theta]) +almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t) +``` + +This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the {func}`jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want: + +```{code-cell} +inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0)) + +inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) +``` + +This does what we want, but is slower than it has to be. Now, you wrap the whole thing in a {func}`jax.jit` to get the compiled, efficient version of the same function: + +```{code-cell} +perex_grads = jax.jit(inefficient_perex_grads) + +perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t) +``` + +```{code-cell} +%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() +%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready() +``` + +### Hessian-vector products with `jax.grad`-of-`jax.grad` + +One thing you can do with higher-order {func}`jax.grad` is build a Hessian-vector product function. (Later on you'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.) + +A Hessian-vector product function can be useful in a [truncated Newton Conjugate-Gradient algorithm](https://en.wikipedia.org/wiki/Truncated_Newton_method) for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. [1](https://arxiv.org/abs/1406.2572), [2](https://arxiv.org/abs/1811.07062), [3](https://arxiv.org/abs/1706.04454), [4](https://arxiv.org/abs/1802.03451)). + +For a scalar-valued function $f : \mathbb{R}^n \to \mathbb{R}$ with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point $x \in \mathbb{R}^n$ is written as $\partial^2 f(x)$. A Hessian-vector product function is then able to evaluate + +$\qquad v \mapsto \partial^2 f(x) \cdot v$ + +for any $v \in \mathbb{R}^n$. + +The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store. + +Luckily, {func}`jax.grad` already gives us a way to write an efficient Hessian-vector product function. You just have to use the identity: + +$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$, + +where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that you're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where you know {func}`jax.grad` is efficient. + +In JAX code, you can just write this: + +```{code-cell} +def hvp(f, x, v): + return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x) +``` + +This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused. + +You will check this implementation a few cells down, once you learn how to compute dense Hessian matrices. You'll also write an even better version that uses both forward-mode and reverse-mode. + +### Jacobians and Hessians using `jax.jacfwd` and `jax.jacrev` + +You can compute full Jacobian matrices using the {func}`jax.jacfwd` and {func}`jax.jacrev` functions: + +```{code-cell} +from jax import jacfwd, jacrev + +# Define a sigmoid function. +def sigmoid(x): + return 0.5 * (jnp.tanh(x / 2) + 1) + +# Outputs probability of a label being true. +def predict(W, b, inputs): + return sigmoid(jnp.dot(inputs, W) + b) + +# Build a toy dataset. +inputs = jnp.array([[0.52, 1.12, 0.77], + [0.88, -1.08, 0.15], + [0.52, 0.06, -1.30], + [0.74, -2.49, 1.39]]) + +# Initialize random model coefficients +key = jax.random.key(0) +key, W_key, b_key = jax.random.split(key, 3) +W = jax.random.normal(W_key, (3,)) +b = jax.random.normal(b_key, ()) + +# Isolate the function from the weight matrix to the predictions +f = lambda W: predict(W, b, inputs) + +J = jacfwd(f)(W) +print("jacfwd result, with shape", J.shape) +print(J) + +J = jacrev(f)(W) +print("jacrev result, with shape", J.shape) +print(J) +``` + +These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`. + +You can also use {func}`jax.jacfwd` and {func}`jax.jacrev` with container types: + +```{code-cell} +def predict_dict(params, inputs): + return predict(params['W'], params['b'], inputs) + +J_dict = jax.jacrev(predict_dict)({'W': W, 'b': b}, inputs) +for k, v in J_dict.items(): + print("Jacobian from {} to logits is".format(k)) + print(v) +``` + +For more details on forward- and reverse-mode, as well as how to implement {func}`jax.jacfwd` and {func}`jax.jacrev` as efficiently as possible, read on! + +Using a composition of two of these functions gives us a way to compute dense Hessian matrices: + +```{code-cell} +def hessian(f): + return jax.jacfwd(jax.jacrev(f)) + +H = hessian(f)(W) +print("hessian, with shape", H.shape) +print(H) +``` + +This shape makes sense: if you start with a function $f : \mathbb{R}^n \to \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ you expect to get the shapes: + +* $f(x) \in \mathbb{R}^m$, the value of $f$ at $x$, +* $\partial f(x) \in \mathbb{R}^{m \times n}$, the Jacobian matrix at $x$, +* $\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$, the Hessian at $x$, + +and so on. + +To implement `hessian`, you could have used `jacfwd(jacrev(f))` or `jacrev(jacfwd(f))` or any other composition of these two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out. diff --git a/docs/jacobian-vector-products.md b/docs/jacobian-vector-products.md new file mode 100644 index 000000000000..bbc678d02d1c --- /dev/null +++ b/docs/jacobian-vector-products.md @@ -0,0 +1,358 @@ +--- +jupytext: + formats: md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + name: python3 +--- + +(advanced-guides-jvp-vjp)= +# Forward- and reverse-mode autodiff in JAX + +## Jacobian-Vector products (JVPs, a.k.a. forward-mode autodiff) + +JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar {func}`jax.grad` function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, you need a bit of math background. + +### JVPs in math + +Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian of $f$ evaluated at an input point $x \in \mathbb{R}^n$, denoted $\partial f(x)$, is often thought of as a matrix in $\mathbb{R}^m \times \mathbb{R}^n$: + +$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$. + +But you can also think of $\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\mathbb{R}^m$): + +$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. + +This map is called the [pushforward map](https://en.wikipedia.org/wiki/Pushforward_(differential)) of $f$ at $x$. The Jacobian matrix is just the matrix for this linear map on a standard basis. + +If you don't commit to one specific input point $x$, then you can think of the function $\partial f$ as first taking an input point and returning the Jacobian linear map at that input point: + +$\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$. + +In particular, you can uncurry things so that given input point $x \in \mathbb{R}^n$ and a tangent vector $v \in \mathbb{R}^n$, you get back an output tangent vector in $\mathbb{R}^m$. We call that mapping, from $(x, v)$ pairs to output tangent vectors, the *Jacobian-vector product*, and write it as: + +$\qquad (x, v) \mapsto \partial f(x) v$ + +### JVPs in JAX code + +Back in Python code, JAX's {func}`jax.jvp` function models this transformation. Given a Python function that evaluates $f$, JAX's {func}`jax.jvp` is a way to get a Python function for evaluating $(x, v) \mapsto (f(x), \partial f(x) v)$. + +```{code-cell} +import jax +import jax.numpy as jnp + +key = jax.random.key(0) + +# Initialize random model coefficients +key, W_key, b_key = jax.random.split(key, 3) +W = jax.random.normal(W_key, (3,)) +b = jax.random.normal(b_key, ()) + +# Define a sigmoid function. +def sigmoid(x): + return 0.5 * (jnp.tanh(x / 2) + 1) + +# Outputs probability of a label being true. +def predict(W, b, inputs): + return sigmoid(jnp.dot(inputs, W) + b) + +# Build a toy dataset. +inputs = jnp.array([[0.52, 1.12, 0.77], + [0.88, -1.08, 0.15], + [0.52, 0.06, -1.30], + [0.74, -2.49, 1.39]]) + +# Isolate the function from the weight matrix to the predictions +f = lambda W: predict(W, b, inputs) + +key, subkey = jax.random.split(key) +v = jax.random.normal(subkey, W.shape) + +# Push forward the vector `v` along `f` evaluated at `W` +y, u = jax.jvp(f, (W,), (v,)) +``` + +In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), you could write: + +```haskell +jvp :: (a -> b) -> a -> T a -> (b, T b) +``` + +where `T a` is used to denote the type of the tangent space for `a`. + +In other words, `jvp` takes as arguments a function of type `a -> b`, a value of type `a`, and a tangent vector value of type `T a`. It gives back a pair consisting of a value of type `b` and an output tangent vector of type `T b`. + +The `jvp`-transformed function is evaluated much like the original function, but paired up with each primal value of type `a` it pushes along tangent values of type `T a`. For each primitive numerical operation that the original function would have applied, the `jvp`-transformed function executes a "JVP rule" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values. + +That evaluation strategy has some immediate implications about computational complexity. Since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the `jvp`-transformed function is about 3x the cost of just evaluating the function (one unit of work for evaluating the original function, for example `sin(x)`; one unit for linearizing, like `cos(x)`; and one unit for applying the linearized function to a vector, like `cos_x * v`). Put another way, for a fixed primal point $x$, we can evaluate $v \mapsto \partial f(x) \cdot v$ for about the same marginal cost as evaluating $f$. + +That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning? + +To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians. + +If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale. + +To do better for functions like this, you just need to use reverse-mode. + +## Vector-Jacobian products (VJPs, a.k.a. reverse-mode autodiff) + +Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time. + +### VJPs in math + +Let's again consider a function $f : \mathbb{R}^n \to \mathbb{R}^m$. +Starting from our notation for JVPs, the notation for VJPs is pretty simple: + +$\qquad (x, v) \mapsto v \partial f(x)$, + +where $v$ is an element of the cotangent space of $f$ at $x$ (isomorphic to another copy of $\mathbb{R}^m$). When being rigorous, we should think of $v$ as a linear map $v : \mathbb{R}^m \to \mathbb{R}$, and when we write $v \partial f(x)$ we mean function composition $v \circ \partial f(x)$, where the types work out because $\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$. But in the common case we can identify $v$ with a vector in $\mathbb{R}^m$ and use the two almost interchangeably, just like we might sometimes flip between "column vectors" and "row vectors" without much comment. + +With that identification, we can alternatively think of the linear part of a VJP as the transpose (or adjoint conjugate) of the linear part of a JVP: + +$\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v$. + +For a given point $x$, we can write the signature as + +$\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n$. + +The corresponding map on cotangent spaces is often called the [pullback](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)) +of $f$ at $x$. The key for our purposes is that it goes from something that looks like the output of $f$ to something that looks like the input of $f$, just like we might expect from a transposed linear function. + +### VJPs in JAX code + +Switching from math back to Python, the JAX function `vjp` can take a Python function for evaluating $f$ and give us back a Python function for evaluating the VJP $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$. + +```{code-cell} +from jax import vjp + +# Isolate the function from the weight matrix to the predictions +f = lambda W: predict(W, b, inputs) + +y, vjp_fun = vjp(f, W) + +key, subkey = jax.random.split(key) +u = jax.random.normal(subkey, y.shape) + +# Pull back the covector `u` along `f` evaluated at `W` +v = vjp_fun(u) +``` + +In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), we could write + +```haskell +vjp :: (a -> b) -> a -> (b, CT b -> CT a) +``` + +where we use `CT a` to denote the type for the cotangent space for `a`. In words, `vjp` takes as arguments a function of type `a -> b` and a point of type `a`, and gives back a pair consisting of a value of type `b` and a linear map of type `CT b -> CT a`. + +This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about three times the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how {func}`jax.grad` is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters. + +There's a cost, though the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!). + +For more on how reverse-mode works, check out [this tutorial video from the Deep Learning Summer School in 2017](http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/). + +## Vector-valued gradients with VJPs + +If you're interested in taking vector-valued gradients (like `tf.gradients`): + +```{code-cell} +def vgrad(f, x): + y, vjp_fn = jax.vjp(f, x) + return vjp_fn(jnp.ones(y.shape))[0] + +print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2)))) +``` + +## Hessian-vector products using both forward- and reverse-mode + +In a previous section, you implemented a Hessian-vector product function just using reverse-mode (assuming continuous second derivatives): + +```{code-cell} +def hvp(f, x, v): + return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x) +``` + +That's efficient, but you can do even better and save some memory by using forward-mode together with reverse-mode. + +Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}$ to differentiate, a point $x \in \mathbb{R}^n$ at which to linearize the function, and a vector $v \in \mathbb{R}^n$, the Hessian-vector product function we want is: + +$(x, v) \mapsto \partial^2 f(x) v$ + +Consider the helper function $g : \mathbb{R}^n \to \mathbb{R}^n$ defined to be the derivative (or gradient) of $f$, namely $g(x) = \partial f(x)$. All you need is its JVP, since that will give us: + +$(x, v) \mapsto \partial g(x) v = \partial^2 f(x) v$. + +We can translate that almost directly into code: + +```{code-cell} +# forward-over-reverse +def hvp(f, primals, tangents): + return jax.jvp(jax.grad(f), primals, tangents)[1] +``` + +Even better, since you didn't have to call {func}`jnp.dot` directly, this `hvp` function works with arrays of any shape and with arbitrary container types (like vectors stored as nested lists/dicts/tuples), and doesn't even have a dependence on {mod}`jax.numpy`. + +Here's an example of how to use it: + +```{code-cell} +def f(X): + return jnp.sum(jnp.tanh(X)**2) + +key, subkey1, subkey2 = jax.random.split(key, 3) +X = jax.random.normal(subkey1, (30, 40)) +V = jax.random.normal(subkey2, (30, 40)) + +def hessian(f): + return jax.jacfwd(jax.jacrev(f)) + +ans1 = hvp(f, (X,), (V,)) +ans2 = jnp.tensordot(hessian(f)(X), V, 2) + +print(jnp.allclose(ans1, ans2, 1e-4, 1e-4)) +``` + +Another way you might consider writing this is using reverse-over-forward: + +```{code-cell} +# Reverse-over-forward +def hvp_revfwd(f, primals, tangents): + g = lambda primals: jax.jvp(f, primals, tangents)[1] + return jax.grad(g)(primals) +``` + +That's not quite as good, though, because forward-mode has less overhead than reverse-mode, and since the outer differentiation operator here has to differentiate a larger computation than the inner one, keeping forward-mode on the outside works best: + +```{code-cell} +# Reverse-over-reverse, only works for single arguments +def hvp_revrev(f, primals, tangents): + x, = primals + v, = tangents + return jax.grad(lambda x: jnp.vdot(jax.grad(f)(x), v))(x) + + +print("Forward over reverse") +%timeit -n10 -r3 hvp(f, (X,), (V,)) +print("Reverse over forward") +%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,)) +print("Reverse over reverse") +%timeit -n10 -r3 hvp_revrev(f, (X,), (V,)) + +print("Naive full Hessian materialization") +%timeit -n10 -r3 jnp.tensordot(jax.hessian(f)(X), V, 2) +``` + +## Composing VJPs, JVPs, and `jax.vmap` + +## Jacobian-Matrix and Matrix-Jacobian products + +Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: + +```{code-cell} +# Isolate the function from the weight matrix to the predictions +f = lambda W: predict(W, b, inputs) + +# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`. +# First, use a list comprehension to loop over rows in the matrix M. +def loop_mjp(f, x, M): + y, vjp_fun = jax.vjp(f, x) + return jnp.vstack([vjp_fun(mi) for mi in M]) + +# Now, use vmap to build a computation that does a single fast matrix-matrix +# multiply, rather than an outer loop over vector-matrix multiplies. +def vmap_mjp(f, x, M): + y, vjp_fun = jax.vjp(f, x) + outs, = jax.vmap(vjp_fun)(M) + return outs + +key = jax.random.key(0) +num_covecs = 128 +U = jax.random.normal(key, (num_covecs,) + y.shape) + +loop_vs = loop_mjp(f, W, M=U) +print('Non-vmapped Matrix-Jacobian product') +%timeit -n10 -r3 loop_mjp(f, W, M=U) + +print('\nVmapped Matrix-Jacobian product') +vmap_vs = vmap_mjp(f, W, M=U) +%timeit -n10 -r3 vmap_mjp(f, W, M=U) + +assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical' +``` + +```{code-cell} +def loop_jmp(f, W, M): + # jvp immediately returns the primal and tangent values as a tuple, + # so we'll compute and select the tangents in a list comprehension + return jnp.vstack([jax.jvp(f, (W,), (mi,))[1] for mi in M]) + +def vmap_jmp(f, W, M): + _jvp = lambda s: jax.jvp(f, (W,), (s,))[1] + return jax.vmap(_jvp)(M) +num_vecs = 128 +S = jax.random.normal(key, (num_vecs,) + W.shape) + +loop_vs = loop_jmp(f, W, M=S) +print('Non-vmapped Jacobian-Matrix product') +%timeit -n10 -r3 loop_jmp(f, W, M=S) +vmap_vs = vmap_jmp(f, W, M=S) +print('\nVmapped Jacobian-Matrix product') +%timeit -n10 -r3 vmap_jmp(f, W, M=S) + +assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical' +``` + +## The implementation of `jax.jacfwd` and `jax.jacrev` + +Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write {func}`jax.jacfwd` and {func}`jax.jacrev`. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once. + +```{code-cell} +from jax import jacrev as builtin_jacrev + +def our_jacrev(f): + def jacfun(x): + y, vjp_fun = jax.vjp(f, x) + # Use vmap to do a matrix-Jacobian product. + # Here, the matrix is the Euclidean basis, so we get all + # entries in the Jacobian at once. + J, = jax.vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) + return J + return jacfun + +assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!' +``` + +```{code-cell} +from jax import jacfwd as builtin_jacfwd + +def our_jacfwd(f): + def jacfun(x): + _jvp = lambda s: jax.jvp(f, (x,), (s,))[1] + Jt = jax.vmap(_jvp, in_axes=1)(jnp.eye(len(x))) + return jnp.transpose(Jt) + return jacfun + +assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!' +``` + +Interestingly, the [Autograd](https://github.com/hips/autograd) library couldn't do this. The [implementation](https://github.com/HIPS/autograd/blob/96a03f44da43cd7044c61ac945c483955deba957/autograd/differential_operators.py#L60) of reverse-mode `jacobian` in Autograd had to pull back one vector at a time with an outer-loop `map`. Pushing one vector at a time through the computation is much less efficient than batching it all together with {func}`jax.vmap`. + +Another thing that Autograd couldn't do is {func}`jax.jit`. Interestingly, no matter how much Python dynamism you use in your function to be differentiated, we could always use {func}`jax.jit` on the linear part of the computation. For example: + +```{code-cell} +def f(x): + try: + if x < 3: + return 2 * x ** 3 + else: + raise ValueError + except ValueError: + return jnp.pi * x + +y, f_vjp = jax.vjp(f, 4.) +print(jax.jit(f_vjp)(1.)) +``` diff --git a/docs/jax-primitives.md b/docs/jax-primitives.md index 43dea9bced1a..92579e06ab16 100644 --- a/docs/jax-primitives.md +++ b/docs/jax-primitives.md @@ -321,7 +321,7 @@ assert api.jit(lambda x, y: square_add_prim(x, y), ### Forward differentiation -JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in {ref}`advanced-autodiff`). +JAX implements forward differentiation in the form of a Jacobian-Vector Product (JVP) (you can learn more about it in {ref}`advanced-guides-jvp-vjp`). If you attempt to compute the `jvp` function, you'll get an error because you have not yet told JAX how to differentiate the `multiply_add` primitive. diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index bbf8a29fa286..38d2a0e84383 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -564,6 +564,377 @@ "For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(jax-jit-class-methods)=\n", + "## 🔪 Using `jax.jit` with class methods\n", + "\n", + "Most examples of [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) concern decorating stand-alone Python functions, but decorating a method within a class introduces some complication. For example, consider the following simple class, where we've used a standard `jax.jit` annotation on a method:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "from jax import jit\n", + "\n", + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " @jit # <---- How to do this correctly?\n", + " def calc(self, y):\n", + " if self.mul:\n", + " return self.x * y\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, this approach will result in an error when you attempt to call this method:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Error interpreting argument to as an abstract array. The problematic value is of type and was passed to the function at path self.\nThis typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m c = CustomClass(\u001b[32m2\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mc\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcalc\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n", + " \u001b[31m[... skipping hidden 5 frame]\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/mamba/envs/jax-dev/lib/python3.12/site-packages/jax/_src/pjit.py:659\u001b[39m, in \u001b[36m_infer_input_type\u001b[39m\u001b[34m(fun, dbg_fn, explicit_args)\u001b[39m\n\u001b[32m 657\u001b[39m dbg = dbg_fn()\n\u001b[32m 658\u001b[39m arg_description = \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mpath \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdbg.arg_names[i]\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mif\u001b[39;00m\u001b[38;5;250m \u001b[39mdbg.arg_names\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01mis\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01melse\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[33m'\u001b[39m\u001b[33munknown\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m \u001b[38;5;66;03m# pytype: disable=name-error\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m659\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 660\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mError interpreting argument to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m as an abstract array.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 661\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m The problematic value is of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(x)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m and was passed to\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;66;03m# pytype: disable=name-error\u001b[39;00m\n\u001b[32m 662\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m the function at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg_description\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 663\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mThis typically means that a jit-wrapped function was called with a non-array\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 664\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m argument, and this argument was not marked as static using the\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 665\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m static_argnums or static_argnames parameters of jax.jit.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 666\u001b[39m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 667\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m config.mutable_array_checks.value:\n\u001b[32m 668\u001b[39m check_no_aliased_ref_args(dbg_fn, avals, explicit_args)\n", + "\u001b[31mTypeError\u001b[39m: Error interpreting argument to as an abstract array. The problematic value is of type and was passed to the function at path self.\nThis typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit." + ] + } + ], + "source": [ + "c = CustomClass(2, True)\n", + "c.calc(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The problem is that the first argument to the function is `self`, which has type `CustomClass`, and JAX does not know how to handle this type. There are three basic strategies we might use in this case, and we'll discuss them below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Strategy 1: JIT-compiled helper function\n", + "\n", + "The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " def calc(self, y):\n", + " return _calc(self.mul, self.x, y)\n", + "\n", + "@partial(jit, static_argnums=0)\n", + "def _calc(mul, x, y):\n", + " if mul:\n", + " return x * y\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The result will work as expected:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c = CustomClass(2, True)\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The benefit of such an approach is that it is simple, explicit, and it avoids the need to teach JAX how to handle objects of type `CustomClass`. However, you may wish to keep all the method logic in the same place." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Strategy 2: Marking `self` as static\n", + "\n", + "Another common pattern is to use `static_argnums` to mark the `self` argument as static. But this must be done with care to avoid unexpected results. You may be tempted to simply do this:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " # WARNING: this example is broken, as we'll see below. Don't copy & paste!\n", + " @partial(jit, static_argnums=0)\n", + " def calc(self, y):\n", + " if self.mul:\n", + " return self.x * y\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you call the method, it will no longer raise an error:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c = CustomClass(2, True)\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c.mul = False\n", + "print(c.calc(3)) # Should print 3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Why is this? When you mark an object as static, it will effectively be used as a dictionary key in JIT's internal compilation cache, meaning its hash (i.e. `hash(obj)`) equality (i.e. `obj1 == obj2`) and object identity (i.e. `obj1 is obj2`) will be assumed to have consistent behavior. The default `__hash__` for a custom object is its object ID, and so JAX has no way of knowing that a mutated object should trigger a re-compilation.\n", + "\n", + "You can partially address this by defining an appropriate `__hash__` and `__eq__` methods for your object; for example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " @partial(jit, static_argnums=0)\n", + " def calc(self, y):\n", + " if self.mul:\n", + " return self.x * y\n", + " return y\n", + "\n", + " def __hash__(self):\n", + " return hash((self.x, self.mul))\n", + "\n", + " def __eq__(self, other):\n", + " return (isinstance(other, CustomClass) and\n", + " (self.x, self.mul) == (other.x, other.mul))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(see the [`object.__hash__`](https://docs.python.org/3/reference/datamodel.html#object.__hash__) documentation for more discussion of the requirements\n", + "when overriding `__hash__`).\n", + "\n", + "This should work correctly with JIT and other transforms **so long as you never mutate your object**. Mutations of objects used as hash keys lead to several subtle problems, which is why for example mutable Python containers (e.g. [`dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`list`](https://docs.python.org/3/library/stdtypes.html#list)) don't define `__hash__`, while their immutable counterparts (e.g. [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple)) do.\n", + "\n", + "If your class relies on in-place mutations (such as setting `self.attr = ...` within its methods), then your object is not really \"static\" and marking it as such may lead to problems. Fortunately, there's another option for this case." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Strategy 3: Making `CustomClass` a PyTree\n", + "\n", + "The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see [Custom pytree nodes](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes). This lets you specify exactly which components of the class should be treated as static and which should be\n", + "treated as dynamic. Here's how it might look:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomClass:\n", + " def __init__(self, x: jnp.ndarray, mul: bool):\n", + " self.x = x\n", + " self.mul = mul\n", + "\n", + " @jit\n", + " def calc(self, y):\n", + " if self.mul:\n", + " return self.x * y\n", + " return y\n", + "\n", + " def _tree_flatten(self):\n", + " children = (self.x,) # arrays / dynamic values\n", + " aux_data = {'mul': self.mul} # static values\n", + " return (children, aux_data)\n", + "\n", + " @classmethod\n", + " def _tree_unflatten(cls, aux_data, children):\n", + " return cls(*children, **aux_data)\n", + "\n", + "from jax import tree_util\n", + "tree_util.register_pytree_node(CustomClass,\n", + " CustomClass._tree_flatten,\n", + " CustomClass._tree_unflatten)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is certainly more involved, but it solves all the issues associated with the simpler approaches used above:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c = CustomClass(2, True)\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + } + ], + "source": [ + "c.mul = False # mutation is detected\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n" + ] + } + ], + "source": [ + "c = CustomClass(jnp.array(2), True) # non-hashable x is supported\n", + "print(c.calc(3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So long as your `tree_flatten` and `tree_unflatten` functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations." + ] + }, { "cell_type": "markdown", "metadata": { @@ -1231,7 +1602,7 @@ "formats": "ipynb,md:myst" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "jax-dev", "language": "python", "name": "python3" }, @@ -1245,15 +1616,10 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.2 (v3.8.2:7b3ab5921f, Feb 24 2020, 17:52:18) \n[Clang 6.0 (clang-600.0.57)]" + "version": "3.12.12" }, "mystnb": { "render_error_lexer": "none" - }, - "vscode": { - "interpreter": { - "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" - } } }, "nbformat": 4, diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index f675b65d9f45..40a40640d8ef 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -7,7 +7,7 @@ jupytext: format_version: 0.13 jupytext_version: 1.16.4 kernelspec: - display_name: Python 3 + display_name: jax-dev language: python name: python3 --- @@ -285,6 +285,191 @@ print(new_jax_array) For more details on indexed array updates, see the [documentation for the `.at` property](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at). ++++ + +(jax-jit-class-methods)= +## 🔪 Using `jax.jit` with class methods + +Most examples of [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html) concern decorating stand-alone Python functions, but decorating a method within a class introduces some complication. For example, consider the following simple class, where we've used a standard `jax.jit` annotation on a method: + +```{code-cell} ipython3 +import jax.numpy as jnp +from jax import jit + +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + @jit # <---- How to do this correctly? + def calc(self, y): + if self.mul: + return self.x * y + return y +``` + +However, this approach will result in an error when you attempt to call this method: + +```{code-cell} ipython3 +:tags: [raises-exception] + +c = CustomClass(2, True) +c.calc(3) +``` + +The problem is that the first argument to the function is `self`, which has type `CustomClass`, and JAX does not know how to handle this type. There are three basic strategies we might use in this case, and we'll discuss them below. + ++++ + +### Strategy 1: JIT-compiled helper function + +The most straightforward approach is to create a helper function external to the class that can be JIT-decorated in the normal way. For example: + +```{code-cell} ipython3 +from functools import partial + +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + def calc(self, y): + return _calc(self.mul, self.x, y) + +@partial(jit, static_argnums=0) +def _calc(mul, x, y): + if mul: + return x * y + return y +``` + +The result will work as expected: + +```{code-cell} ipython3 +c = CustomClass(2, True) +print(c.calc(3)) +``` + +The benefit of such an approach is that it is simple, explicit, and it avoids the need to teach JAX how to handle objects of type `CustomClass`. However, you may wish to keep all the method logic in the same place. + ++++ + +### Strategy 2: Marking `self` as static + +Another common pattern is to use `static_argnums` to mark the `self` argument as static. But this must be done with care to avoid unexpected results. You may be tempted to simply do this: + +```{code-cell} ipython3 +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + # WARNING: this example is broken, as we'll see below. Don't copy & paste! + @partial(jit, static_argnums=0) + def calc(self, y): + if self.mul: + return self.x * y + return y +``` + +If you call the method, it will no longer raise an error: + +```{code-cell} ipython3 +c = CustomClass(2, True) +print(c.calc(3)) +``` + +However, there is a catch: if you mutate the object after the first method call, the subsequent method call may return an incorrect result: + +```{code-cell} ipython3 +c.mul = False +print(c.calc(3)) # Should print 3 +``` + +Why is this? When you mark an object as static, it will effectively be used as a dictionary key in JIT's internal compilation cache, meaning its hash (i.e. `hash(obj)`) equality (i.e. `obj1 == obj2`) and object identity (i.e. `obj1 is obj2`) will be assumed to have consistent behavior. The default `__hash__` for a custom object is its object ID, and so JAX has no way of knowing that a mutated object should trigger a re-compilation. + +You can partially address this by defining an appropriate `__hash__` and `__eq__` methods for your object; for example: + +```{code-cell} ipython3 +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + @partial(jit, static_argnums=0) + def calc(self, y): + if self.mul: + return self.x * y + return y + + def __hash__(self): + return hash((self.x, self.mul)) + + def __eq__(self, other): + return (isinstance(other, CustomClass) and + (self.x, self.mul) == (other.x, other.mul)) +``` + +(see the [`object.__hash__`](https://docs.python.org/3/reference/datamodel.html#object.__hash__) documentation for more discussion of the requirements +when overriding `__hash__`). + +This should work correctly with JIT and other transforms **so long as you never mutate your object**. Mutations of objects used as hash keys lead to several subtle problems, which is why for example mutable Python containers (e.g. [`dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`list`](https://docs.python.org/3/library/stdtypes.html#list)) don't define `__hash__`, while their immutable counterparts (e.g. [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple)) do. + +If your class relies on in-place mutations (such as setting `self.attr = ...` within its methods), then your object is not really "static" and marking it as such may lead to problems. Fortunately, there's another option for this case. + ++++ + +### Strategy 3: Making `CustomClass` a PyTree + +The most flexible approach to correctly JIT-compiling a class method is to register the type as a custom PyTree object; see [Custom pytree nodes](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes). This lets you specify exactly which components of the class should be treated as static and which should be +treated as dynamic. Here's how it might look: + +```{code-cell} ipython3 +class CustomClass: + def __init__(self, x: jnp.ndarray, mul: bool): + self.x = x + self.mul = mul + + @jit + def calc(self, y): + if self.mul: + return self.x * y + return y + + def _tree_flatten(self): + children = (self.x,) # arrays / dynamic values + aux_data = {'mul': self.mul} # static values + return (children, aux_data) + + @classmethod + def _tree_unflatten(cls, aux_data, children): + return cls(*children, **aux_data) + +from jax import tree_util +tree_util.register_pytree_node(CustomClass, + CustomClass._tree_flatten, + CustomClass._tree_unflatten) +``` + +This is certainly more involved, but it solves all the issues associated with the simpler approaches used above: + +```{code-cell} ipython3 +c = CustomClass(2, True) +print(c.calc(3)) +``` + +```{code-cell} ipython3 +c.mul = False # mutation is detected +print(c.calc(3)) +``` + +```{code-cell} ipython3 +c = CustomClass(jnp.array(2), True) # non-hashable x is supported +print(c.calc(3)) +``` + +So long as your `tree_flatten` and `tree_unflatten` functions correctly handle all relevant attributes in the class, you should be able to use objects of this type directly as arguments to JIT-compiled functions, without any special annotations. + +++ {"id": "oZ_jE2WAypdL"} ## 🔪 Out-of-bounds indexing diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index e80c7ae94687..27f53cf32778 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -6,20 +6,21 @@ "id": "LqiaKasFjH82" }, "source": [ - "# Custom derivative rules\n", + "(advanced-autodiff-custom-derivative-rules)=\n", + "# Custom derivative rules for JAX-transformable Python functions\n", "\n", - "\n", + "\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", - "1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", + "1. using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html) and [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) to define custom differentiation rules for Python functions that are already JAX-transformable; and\n", "2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems.\n", "\n", "This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html).\n", "\n", - "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs." + "For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [jax.grad](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), and the mathematical meaning of JVPs and VJPs." ] }, { @@ -28,16 +29,7 @@ "id": "9Fg3NFNY-2RY" }, "source": [ - "## Summary" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZgMNRtXyWIW8" - }, - "source": [ - "### Custom JVPs with `jax.custom_jvp`" + "### TL;DR: Custom JVPs with `jax.custom_jvp`" ] }, { @@ -144,7 +136,7 @@ "id": "N2DOGCREWXFj" }, "source": [ - "### Custom VJPs with `jax.custom_vjp`" + "### TL;DR: Custom VJPs with `jax.custom_vjp`" ] }, { @@ -209,7 +201,7 @@ "id": "AR02eyd1GQhC" }, "source": [ - "### Numerical stability\n", + "### Example: Numerical stability\n", "\n", "One application of `jax.custom_jvp` is to improve the numerical stability of differentiation." ] @@ -370,7 +362,7 @@ "\n", "Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \\frac{1}{1 + e^x}$, with no cancellation in sight.\n", "\n", - "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n", + "This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with [`jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), [`vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better.\n", "\n", "This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...).\n", "\n", @@ -450,7 +442,7 @@ "id": "9sVUGbGkUOqO" }, "source": [ - "Here's a `defjvps` convenience wrapper to express the same thing:" + "Here's a [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper to express the same thing:" ] }, { @@ -500,7 +492,7 @@ "id": "V9tHAfrSF1N-" }, "source": [ - "### Enforcing a differentiation convention\n", + "### Example: Enforcing a differentiation convention\n", "\n", "A related application is to enforce a differentiation convention, perhaps at a boundary." ] @@ -657,11 +649,11 @@ "id": "7J2A85wbSAmF" }, "source": [ - "### Gradient clipping\n", + "### Example: Gradient clipping\n", "\n", "While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping.\n", "\n", - "For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule:" + "For gradient clipping, we can use [`jnp.clip`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html) together with a [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) reverse-mode-only rule:" ] }, { @@ -782,7 +774,7 @@ "id": "CICQuI86WK4_" }, "source": [ - "### Python debugging\n", + "### Example: Python debugging\n", "\n", "Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff." ] @@ -804,7 +796,7 @@ "id": "IC7tEcr1-Fc5" }, "source": [ - "### Implicit function differentiation of iterative implementations\n", + "### Example: Implicit function differentiation of iterative implementations\n", "\n", "This example gets pretty deep in the mathematical weeds!" ] @@ -815,7 +807,7 @@ "id": "szAt97t80hew" }, "source": [ - "Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve `lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)\n", + "Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve [`lax.while_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html). (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.)\n", "\n", "For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`:" ] @@ -1069,7 +1061,7 @@ "id": "HowvqayEuy-H" }, "source": [ - "A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for deriviatives in closed-over variables with custom root-finding functions." + "A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions." ] }, { @@ -1089,7 +1081,7 @@ "source": [ "### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules\n", "\n", - "Here's a canonical basic example of using `jax.custom_jvp`, where the comments use\n", + "Here's a canonical basic example of using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html), where the comments use\n", "[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):" ] }, @@ -1272,7 +1264,7 @@ "id": "YPsPS3rdaGo2" }, "source": [ - "The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:" + "The [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed:" ] }, { @@ -1656,7 +1648,7 @@ "source": [ "### Use `jax.custom_vjp` to define custom reverse-mode-only rules\n", "\n", - "While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with `jax.custom_vjp`:" + "While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html):" ] }, { @@ -2200,7 +2192,7 @@ "id": "JKTNivxbmKWO" }, "source": [ - "### Handling non-differentiable arguments" + "### Handling non-differentiable arguments" ] }, { diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 82b97e195bd9..ccdc709bd48b 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -13,28 +13,25 @@ kernelspec: +++ {"id": "LqiaKasFjH82"} -# Custom derivative rules +(advanced-autodiff-custom-derivative-rules)= +# Custom derivative rules for JAX-transformable Python functions - + [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) There are two ways to define differentiation rules in JAX: -1. using `jax.custom_jvp` and `jax.custom_vjp` to define custom differentiation rules for Python functions that are already JAX-transformable; and +1. using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html) and [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) to define custom differentiation rules for Python functions that are already JAX-transformable; and 2. defining new `core.Primitive` instances along with all their transformation rules, for example to call into functions from other systems like solvers, simulators, or general numerical computing systems. This notebook is about #1. To read instead about #2, see the [notebook on adding primitives](https://docs.jax.dev/en/latest/notebooks/How_JAX_primitives_work.html). -For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/jax.html#jax.jvp) and [jax.grad](https://docs.jax.dev/en/latest/jax.html#jax.grad), and the mathematical meaning of JVPs and VJPs. +For an introduction to JAX's automatic differentiation API, see [The Autodiff Cookbook](https://docs.jax.dev/en/latest/notebooks/autodiff_cookbook.html). This notebook assumes some familiarity with [jax.jvp](https://docs.jax.dev/en/latest/_autosummary/jax.jvp.html) and [jax.grad](https://docs.jax.dev/en/latest/_autosummary/jax.grad.html), and the mathematical meaning of JVPs and VJPs. +++ {"id": "9Fg3NFNY-2RY"} -## Summary - -+++ {"id": "ZgMNRtXyWIW8"} - -### Custom JVPs with `jax.custom_jvp` +### TL;DR: Custom JVPs with `jax.custom_jvp` ```{code-cell} ipython3 :id: zXic8tr--1PK @@ -94,7 +91,7 @@ print(grad(f)(2., 3.)) +++ {"id": "N2DOGCREWXFj"} -### Custom VJPs with `jax.custom_vjp` +### TL;DR: Custom VJPs with `jax.custom_vjp` ```{code-cell} ipython3 :id: 35ScHqhrBwPh @@ -131,7 +128,7 @@ To get an idea of what problems `jax.custom_jvp` and `jax.custom_vjp` are meant +++ {"id": "AR02eyd1GQhC"} -### Numerical stability +### Example: Numerical stability One application of `jax.custom_jvp` is to improve the numerical stability of differentiation. @@ -197,7 +194,7 @@ Stepping through how the jaxpr would be evaluated, we can see that the last line Instead of generating such large and small values, hoping for a cancellation that floats can't always provide, we'd rather just express the derivative function as a more numerically stable program. In particular, we can write a program that more closely evaluates the equal mathematical expression $1 - \frac{1}{1 + e^x}$, with no cancellation in sight. -This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with `jit`, `vmap`, ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better. +This problem is interesting because even though our definition of `log1pexp` could already be JAX-differentiated (and transformed with [`jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), [`vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), ...), we're not happy with the result of applying standard autodiff rules to the primitives comprising `log1pexp` and composing the result. Instead, we'd like to specify how the whole function `log1pexp` should be differentiated, as a unit, and thus arrange those exponentials better. This is one application of custom derivative rules for Python functions that are already JAX transformable: specifying how a composite function should be differentiated, while still using its original Python definition for other transformations (like `jit`, `vmap`, ...). @@ -239,7 +236,7 @@ print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) +++ {"id": "9sVUGbGkUOqO"} -Here's a `defjvps` convenience wrapper to express the same thing: +Here's a [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper to express the same thing: ```{code-cell} ipython3 :id: xfQTp8F7USEM @@ -263,7 +260,7 @@ print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) +++ {"id": "V9tHAfrSF1N-"} -### Enforcing a differentiation convention +### Example: Enforcing a differentiation convention A related application is to enforce a differentiation convention, perhaps at a boundary. @@ -341,11 +338,11 @@ print(grad(f)(0.)) +++ {"id": "7J2A85wbSAmF"} -### Gradient clipping +### Example: Gradient clipping While in some cases we want to express a mathematical differentiation computation, in other cases we may even want to take a step away from mathematics to adjust the computation autodiff performs. One canonical example is reverse-mode gradient clipping. -For gradient clipping, we can use `jnp.clip` together with a `jax.custom_vjp` reverse-mode-only rule: +For gradient clipping, we can use [`jnp.clip`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.clip.html) together with a [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html) reverse-mode-only rule: ```{code-cell} ipython3 :id: 8jfjSanIW_tJ @@ -394,7 +391,7 @@ plt.plot(vmap(grad(clip_sin))(t)) +++ {"id": "CICQuI86WK4_"} -### Python debugging +### Example: Python debugging Another application that is motivated by development workflow rather than numerics is to set a `pdb` debugger trace in the backward pass of reverse-mode autodiff. @@ -406,13 +403,13 @@ We'll defer an example until the next section. +++ {"id": "IC7tEcr1-Fc5"} -### Implicit function differentiation of iterative implementations +### Example: Implicit function differentiation of iterative implementations This example gets pretty deep in the mathematical weeds! +++ {"id": "szAt97t80hew"} -Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve `lax.while_loop`. (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.) +Another application for `jax.custom_vjp` is reverse-mode differentiation of functions that are JAX-transformable (by `jit`, `vmap`, ...) but not efficiently JAX-differentiable for some reason, perhaps because they involve [`lax.while_loop`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.while_loop.html). (It's not possible to produce an XLA HLO program that efficiently computes the reverse-mode derivative of an XLA HLO While loop because that would require a program with unbounded memory use, which isn't possible to express in XLA HLO, at least without side-effecting interactions through infeed/outfeed.) For example, consider this `fixed_point` routine which computes a fixed point by iteratively applying a function in a `while_loop`: @@ -559,7 +556,7 @@ print(grad(grad(jnp.sqrt))(2.)) +++ {"id": "HowvqayEuy-H"} -A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for deriviatives in closed-over variables with custom root-finding functions. +A limitation to this approach is that the argument `f` can't close over any values involved in differentiation. That is, you might notice that we kept the parameter `a` explicit in the argument list of `fixed_point`. For this use case, consider using the low-level primitive `lax.custom_root`, which allows for derivatives in closed-over variables with custom root-finding functions. +++ {"id": "Dr0aNkBslfQf"} @@ -569,7 +566,7 @@ A limitation to this approach is that the argument `f` can't close over any valu ### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules -Here's a canonical basic example of using `jax.custom_jvp`, where the comments use +Here's a canonical basic example of using [`jax.custom_jvp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.html), where the comments use [Haskell-like type signatures](https://wiki.haskell.org/Type_signature): ```{code-cell} ipython3 @@ -670,7 +667,7 @@ print(grad(f)(2., 3.)) +++ {"id": "YPsPS3rdaGo2"} -The `defjvps` convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed: +The [`defjvps`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_jvp.defjvps.html) convenience wrapper lets us define a JVP for each argument separately, and the results are computed separately then summed: ```{code-cell} ipython3 :id: CsQIUhUkajua @@ -845,7 +842,7 @@ print(grad(f)(-1.)) ### Use `jax.custom_vjp` to define custom reverse-mode-only rules -While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with `jax.custom_vjp`: +While `jax.custom_jvp` suffices for controlling both forward- and, via JAX's automatic transposition, reverse-mode differentiation behavior, in some cases we may want to directly control a VJP rule, for example in the latter two example problems presented above. We can do that with [`jax.custom_vjp`](https://docs.jax.dev/en/latest/_autosummary/jax.custom_vjp.html): ```{code-cell} ipython3 :id: zAZk1n3dUw76 @@ -1141,7 +1138,7 @@ print(grad(fun)(pt)) +++ {"id": "JKTNivxbmKWO"} -### Handling non-differentiable arguments +### Handling non-differentiable arguments +++ {"id": "7g9sXSp_uc36"} diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 5538b70dac93..46f887f8986f 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -1637,7 +1637,7 @@ "source": [ "## More advanced autodiff\n", "\n", - "In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. \n", + "In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. For more details, check out the [\"Advanced automatic differentiation\" section in the JAX advanced guides](https://jax.readthedocs.io/en/latest/advanced_guides.html).\n", "\n", "There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an \"Advanced Autodiff Cookbook\" include:\n", "\n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index db6fde8051d1..d2cb091bc0e8 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -960,7 +960,7 @@ grad(f, holomorphic=True)(A) ## More advanced autodiff -In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. +In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. For more details, check out the ["Advanced automatic differentiation" section in the JAX advanced guides](https://jax.readthedocs.io/en/latest/advanced_guides.html). There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an "Advanced Autodiff Cookbook" include: diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 3fd8913459ea..26809769c981 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -797,8 +797,76 @@ }, { "cell_type": "markdown", + "id": "b79e0c62", "metadata": {}, "source": [ + "## Debugging\n", + "\n", + "Debugging JAX code can be challenging due to its functional programming model and the fact that JAX code is often transformed via JIT compilation or vectorization. However, JAX provides several tools to help with debugging.\n", + "\n", + "### `jax.debug.print`\n", + "\n", + "For simple inspection, use [`jax.debug.print`](https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html).\n", + "\n", + "Python's built-in `print` executes at trace-time, before the runtime values exist. Because of this, `print` will only show tracer values within `jax.jit`-decorated code." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61675ec9", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "@jax.jit\n", + "def f(x):\n", + " print(\"print(x) ->\", x)\n", + " y = jnp.sin(x)\n", + " print(\"print(y) ->\", y)\n", + " return y\n", + "\n", + "result = f(2.)" + ] + }, + { + "cell_type": "markdown", + "id": "a34c34bb", + "metadata": {}, + "source": [ + "If you want to print the actual runtime values, you can use `jax.debug.print`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49b5cb05", + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def f(x):\n", + " jax.debug.print(\"jax.debug.print(x) -> {x}\", x=x)\n", + " y = jnp.sin(x)\n", + " jax.debug.print(\"jax.debug.print(y) -> {y}\", y=y)\n", + " return y\n", + "\n", + "result = f(2.)" + ] + }, + { + "cell_type": "markdown", + "id": "515495d4", + "metadata": {}, + "source": [ + "### Debugging flags\n", + "\n", + "JAX offers flags and context managers that enable catching errors more easily. For example, you can enable the `jax.debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code. You can also enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`.\n", + "\n", + "For more details, see [Introduction to debugging](https://docs.jax.dev/en/latest/debugging.html).\n", + "\n", "---\n", "\n", "This is just a taste of what JAX can do. We're really excited to see what you do with it!" diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 71cbe8a58e54..77f8797c4f5f 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -474,6 +474,49 @@ For more on pseudo random numbers in JAX, see the [Pseudorandom numbers tutorial +++ +## Debugging + +Debugging JAX code can be challenging due to its functional programming model and the fact that JAX code is often transformed via JIT compilation or vectorization. However, JAX provides several tools to help with debugging. + +### `jax.debug.print` + +For simple inspection, use [`jax.debug.print`](https://docs.jax.dev/en/latest/_autosummary/jax.debug.print.html). + +Python's built-in `print` executes at trace-time, before the runtime values exist. Because of this, `print` will only show tracer values within `jax.jit`-decorated code. + +```{code-cell} ipython3 +import jax +import jax.numpy as jnp + +@jax.jit +def f(x): + print("print(x) ->", x) + y = jnp.sin(x) + print("print(y) ->", y) + return y + +result = f(2.) +``` + +If you want to print the actual runtime values, you can use `jax.debug.print`: + +```{code-cell} ipython3 +@jax.jit +def f(x): + jax.debug.print("jax.debug.print(x) -> {x}", x=x) + y = jnp.sin(x) + jax.debug.print("jax.debug.print(y) -> {y}", y=y) + return y + +result = f(2.) +``` + +### Debugging flags + +JAX offers flags and context managers that enable catching errors more easily. For example, you can enable the `jax.debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code. You can also enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. + +For more details, see [Introduction to debugging](https://docs.jax.dev/en/latest/debugging.html). + --- This is just a taste of what JAX can do. We're really excited to see what you do with it! From 55c215d07468d1f5cb8ee801a53e01bbf8885d4a Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Fri, 19 Dec 2025 04:30:47 -0800 Subject: [PATCH 287/315] [MGPU] Add WGStridedFragLayout broadcast support on major dim. PiperOrigin-RevId: 846677494 --- .../mosaic/gpu/fragmented_array.py | 32 +++++++++++++++++++ tests/mosaic/gpu_test.py | 23 +++++++++++++ tests/pallas/mosaic_gpu_test.py | 26 +++++++++++++++ 3 files changed, 81 insertions(+) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 0093ed01caf0..7a4d21dbc45f 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -2454,6 +2454,38 @@ def reduce( ) def broadcast(self, shape) -> FragmentedArray: + if isinstance(self.layout, WGStridedFragLayout): + src_shape, dst_shape = self.layout.shape, shape + if len(src_shape) > len(dst_shape): + raise ValueError( + f"Shape length mismatch. Expected len({src_shape}) <= len({dst_shape})" + ) + if not all(s == 1 or s == d for s, d in zip(src_shape[::-1], dst_shape[::-1])): + raise ValueError( + "Can broadcast if all source dimensions match trailing target" + " dimensions by being equal or set to 1. Broadcasting from" + f" {src_shape} to {dst_shape}" + ) + rank_diff = len(dst_shape) - len(src_shape) + src_shape = tuple([1] * rank_diff + list(src_shape)) + + assert len(src_shape) == len(dst_shape), (src_shape, dst_shape) + len_suffix = next( + (i for i in range(len(src_shape)) if src_shape[~i] != dst_shape[~i]), + len(src_shape) + ) + if len_suffix > 0 and all(x == 1 for x in src_shape[:-len_suffix]): + return FragmentedArray( + _registers=np.tile(self.registers, np.prod(dst_shape[:-len_suffix])), + _layout=WGStridedFragLayout(shape, self.layout.vec_size), + _is_signed=self.is_signed, + ) + + raise NotImplementedError( + "Only major-most broadcast for WGStridedFragLayout is implemented." + f" Broadcasting from: {src_shape}, to: {dst_shape}." + ) + if not isinstance(self.layout, WGSplatFragLayout): raise NotImplementedError(self.layout) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index a81e6b29d043..deb8772937d0 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3839,6 +3839,29 @@ def kernel(ctx, gmem_input, gmem_output, _): out_ref = jax.lax.broadcast_in_dim(inp, (m, n), (1,)) np.testing.assert_array_equal(result, out_ref) + @parameterized.parameters( + ((128), (4, 128)), + ((1, 128), (2, 128)), + ((1, 128), (4, 128)), + ((1, 256), (2, 256)), + ((128, ), (1, 3, 1, 2, 4, 128)), + ((1, 1, 128,), (1, 3, 1, 2, 4, 128)), + ((1, 1, 1, 1, 1, 128,), (1, 3, 1, 2, 4, 128)), + ((2, 4, 128,), (1, 3, 1, 2, 4, 128)), + ((1, 1, 1, 2, 4, 128,), (1, 3, 1, 2, 4, 128)), + ((2, 8, 8), (2, 8, 8)), + ) + def test_broadcast_major_strided(self, in_shape, out_shape): + dtype = jnp.float16 + def kernel(ctx, gmem_input, gmem_output, _): + t = mgpu.FragmentedArray.load_strided(gmem_input, vec_size=1) + t.broadcast(out_shape).store_untiled(gmem_output, optimized=False) + inp = self.prng.uniform(-1, 1, in_shape).astype(dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (inp,), jax.ShapeDtypeStruct(out_shape, dtype), inp + )(inp) + np.testing.assert_array_equal(result, jnp.broadcast_to(inp, out_shape)) + @parameterized.parameters(*mtu.RegisterLayout) def test_broadcast_splat(self, layout): out_shape = (128, 128) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7b5d844e7015..05c2ef48516a 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -2486,6 +2486,32 @@ def kernel(x_ref, y_ref): result = jax.random.uniform(jax.random.key(0), shape=(128,), dtype=jnp.float32) np.testing.assert_array_equal(kernel(result), jnp.broadcast_to(result[None,:], (256, 128))) + @parameterized.parameters( + ((4, 128),), + ((2, 4, 128),), + ) + def test_broadcast_wg_strided_majormost_dim(self, out_shape): + self.skip_if_wg_semantics() # Lowering not implemented. + dtype = jnp.float32 + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct(out_shape, dtype) + ) + def kernel(x_ref, side_load_ref, y_ref): + x_strided = plgpu.load( + x_ref, (), layout=plgpu.Layout.WG_STRIDED((128,), vec_size=1) + ) + side_load_strided = plgpu.load( + side_load_ref, (), layout=plgpu.Layout.WG_STRIDED(out_shape, vec_size=1) + ) + for _ in range(len(out_shape) - 1): + x_strided = x_strided[None, ...] + y_ref[...] = x_strided + side_load_strided[...] + + inp = jax.random.uniform(jax.random.key(0), (128,), dtype) + side_load = jax.random.uniform(jax.random.key(1), out_shape, dtype) + np.testing.assert_array_equal(kernel(inp, side_load), + jnp.broadcast_to(inp, out_shape) + side_load) + def test_broadcast_in_dim_tcgen05_native_layout(self): @functools.partial( self.kernel, From 585016cbf83999f5fcbf36a830ef13f2b3a92ff5 Mon Sep 17 00:00:00 2001 From: Alexey Buslavyev Date: Fri, 19 Dec 2025 06:13:16 -0800 Subject: [PATCH 288/315] internal changes PiperOrigin-RevId: 846704711 --- jax/BUILD | 1 + jaxlib/jax.bzl | 3 +++ 2 files changed, 4 insertions(+) diff --git a/jax/BUILD b/jax/BUILD index 0aaed7b61d2c..6431c08656d1 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -154,6 +154,7 @@ py_library_providing_imports_info( ], # TODO(dsuo): Consider moving these files out of experimental if they're in the public API. ) + ["//jax/experimental:jax_public"], + lazy_imports = True, lib_rule = pytype_library, pytype_srcs = glob( [ diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 0fd5faf39398..abb6658cf640 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -138,11 +138,13 @@ jax2tf_deps = [] def pytype_library(name, pytype_srcs = None, **kwargs): _ = pytype_srcs # @unused + kwargs.pop("lazy_imports", None) py_library(name = name, **kwargs) def pytype_strict_library(name, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} + new_kwargs.pop("lazy_imports", None) py_library(name = name, data = data, **new_kwargs) py_strict_library = py_library @@ -151,6 +153,7 @@ py_strict_test = py_test def py_library_providing_imports_info(*, name, lib_rule = py_library, pytype_srcs = [], **kwargs): data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} + new_kwargs.pop("lazy_imports", None) lib_rule(name = name, data = data, **new_kwargs) def py_extension(name, srcs, copts, deps, linkopts = []): From d00f43f85f6aed96a44f873968a697b69c78ce55 Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Fri, 19 Dec 2025 07:14:40 -0800 Subject: [PATCH 289/315] [Mosaic GPU] Add a sanity check ensuring that layout inference only ever assigns layouts compatible with variables. Previously, when no such check existed, we would sometimes allow assigning layouts incompatible with the shape of a variable, especially in the case of layout casts. As a result of this change, a few tests that were assigning invalid layouts had to be updated, as well as the rules that perform variable assignments, i.e. layout casts in different memory spaces, and MMA ops. To propagate the errors, we allow the derivation rules to return `Unsatisfiable` directly. Later, we may want to think about improving error reporting, since this should allow for more helpful diagnostics. PiperOrigin-RevId: 846722181 --- jax/experimental/mosaic/gpu/constraints.py | 5 +- .../mosaic/gpu/layout_inference.py | 249 +++++++++++++----- tests/mosaic/gpu_layout_inference_test.py | 48 +++- 3 files changed, 236 insertions(+), 66 deletions(-) diff --git a/jax/experimental/mosaic/gpu/constraints.py b/jax/experimental/mosaic/gpu/constraints.py index 409dd1d7982b..b6fda3d520fc 100644 --- a/jax/experimental/mosaic/gpu/constraints.py +++ b/jax/experimental/mosaic/gpu/constraints.py @@ -258,6 +258,7 @@ def reduce_expression( case _: assert_never(expr) + @dataclasses.dataclass(frozen=True) class Equals: """States that `lhs` and `rhs` are equal.""" @@ -586,8 +587,10 @@ def extract_variables(expr: Expression) -> None: return free_variables def __and__( - self, other: ConstraintSystem + self, other: ConstraintSystem | Unsatisfiable ) -> ConstraintSystem | Unsatisfiable: + if isinstance(other, Unsatisfiable): + return Unsatisfiable() for variable, assignment in self.assignments.items(): if variable in other.assignments and assignment != other.assignments[variable]: return Unsatisfiable() diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index cca47f1775aa..22c805fa0f59 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -82,6 +82,7 @@ class MemorySpace(enum.Enum): _op_name_regex = re.compile(r"^(%\d+ = )?\S+") + @dataclasses.dataclass(frozen=True) class ValueSite: """A unique identifier for a variable. @@ -114,6 +115,11 @@ def value(self) -> ir.Value: else: return self.operation.regions[self.region_index].blocks[0].arguments[self.index] + @property + def shape(self) -> tuple[int, ...]: + """Returns the shape of the underlying value.""" + return tuple(self.value.type.shape) # pytype: disable=attribute-error + @property def memory_space(self) -> MemorySpace: """Returns the memory space associated with this value.""" @@ -459,9 +465,12 @@ def producer_ref(self, operand: ValueSite) -> cs.Variable: # and each identifier in the mapping must be keyed by exactly one variable. # Lastly, the mapping must only refer to variables and # operands/results/arguments that correspond to the given operation. +ConstraintSystemDerivationRuleResult = cs.Unsatisfiable | tuple[ + cs.ConstraintSystem, ValueSitesForVariable +] ConstraintSystemDerivationRule = Callable[ [DerivationContext, ir.OpView], - tuple[cs.ConstraintSystem, ValueSitesForVariable], + ConstraintSystemDerivationRuleResult, ] _constraint_system_derivation_rules: dict[ str, ConstraintSystemDerivationRule @@ -492,7 +501,7 @@ def _is_tmem_ref(v: ir.Value) -> bool: def _pointwise_op_constraint_system( ctx: DerivationContext, op: ir.OpView, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx all_value_sites = vector_value_sites(op) variable = cs.Variable(all_value_sites[-1]) @@ -548,7 +557,7 @@ def _pointwise_op_constraint_system( def _vector_load_constraint_system( ctx: DerivationContext, op: mgpu.VectorLoadOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: # TODO(b/447079781): Investigate whether we should check for contiguous # strides here. An initial implementation of this failed the # test_gmem_to_smem_with_multiple_smem_indexers_and_transforms test, but @@ -576,7 +585,7 @@ def _vector_load_constraint_system( def _vector_store_constraint_system( ctx: DerivationContext, op: mgpu.VectorStoreOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: # TODO(b/447079781): Investigate whether we should check for contiguous # strides here. An initial implementaiton of this failed the # test_gmem_to_smem_with_multiple_smem_indexers_and_transforms test, but @@ -604,7 +613,7 @@ def _vector_store_constraint_system( def _debug_print_constraint_system( ctx: DerivationContext, op: mgpu.DebugPrintOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx value = ValueSite(op, VariableType.OPERAND, 0) return cs.ConstraintSystem(), {cs.Variable(value): [value]} @@ -614,7 +623,7 @@ def _debug_print_constraint_system( def _print_layout_constraint_system( ctx: DerivationContext, op: mgpu.PrintLayoutOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: value = ValueSite(op, VariableType.OPERAND, 0) var = cs.Variable(value) if is_vector(op.value) else ctx.producer_ref(value) return cs.ConstraintSystem(), {var: [value]} @@ -624,7 +633,7 @@ def _print_layout_constraint_system( def _broadcasted_iota_constraint_system( ctx: DerivationContext, op: mgpu.BroadcastedIotaOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx value = ValueSite(op, VariableType.RESULT, 0) var = cs.Variable(value) @@ -636,7 +645,7 @@ def _broadcasted_iota_constraint_system( def _optimization_barrier_constraint_system( ctx: DerivationContext, op: ir.OpView, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx value_sites_for_variable: ValueSitesForVariable = {} @@ -656,7 +665,7 @@ def _optimization_barrier_constraint_system( def _vector_splat_constraint_system( ctx: DerivationContext, op: ir.OpView, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx result = ValueSite(op, VariableType.RESULT, 0) variable = cs.Variable(result) @@ -671,7 +680,7 @@ def _vector_splat_constraint_system( def _constant_constraint_system( ctx: DerivationContext, constant_op: arith.ConstantOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx value = constant_op.value result = ValueSite(constant_op, VariableType.RESULT, 0) @@ -708,7 +717,7 @@ def _terminator( def _for_constraint_system( ctx: DerivationContext, op: scf.ForOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: [block] = op.region.blocks yield_op = _terminator(block, scf.YieldOp) value_sites_for_variable: ValueSitesForVariable = {} @@ -772,7 +781,7 @@ def dynamic_gcd(a: int, b: ir.Value) -> int: def _while_constraint_system( ctx: DerivationContext, op: scf.WhileOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx [before_block] = op.before.blocks [after_block] = op.after.blocks @@ -811,7 +820,7 @@ def _while_constraint_system( def _index_switch_constraint_system( ctx: DerivationContext, op: scf.IndexSwitchOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx value_sites_for_variable: ValueSitesForVariable = { cs.Variable(o): [o] for o in vector_value_sites(op) @@ -833,14 +842,19 @@ def _index_switch_constraint_system( def _layout_cast_constraint_system( ctx: DerivationContext, op: mgpu.LayoutCastOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx operand = ValueSite(op, VariableType.OPERAND, 0) result = ValueSite(op, VariableType.RESULT, 0) variable = cs.Variable(operand) - out_layout = cs.RegisterLayout(layouts_lib.from_layout_attr(op.new_layout)) + out_layout = layouts_lib.from_layout_attr(op.new_layout) + # TODO(bchetioui): think about raising a better error here. + if not is_valid_register_layout_assignment(operand.shape, out_layout): + return cs.Unsatisfiable() return ( - cs.ConstraintSystem(assignments={variable: out_layout}), + cs.ConstraintSystem( + assignments={variable: cs.RegisterLayout(out_layout)} + ), {variable: [operand, result]}, ) @@ -911,35 +925,47 @@ def _infer_wgmma_tiling( def _wgmma_constraint_system( ctx: DerivationContext, op: mgpu.WGMMAOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: assignments: dict[cs.Variable, cs.Constant] = {} value_sites_for_variable: ValueSitesForVariable = {} acc_out = ValueSite(op, VariableType.RESULT, 0) acc_in = ValueSite(op, VariableType.OPERAND, 0) acc_var = cs.Variable(acc_out) - assignments[acc_var] = cs.RegisterLayout(fa.WGMMA_LAYOUT) + acc_layout = fa.WGMMA_LAYOUT + assignments[acc_var] = cs.RegisterLayout(acc_layout) + acc_is_valid = is_valid_register_layout_assignment(acc_out.shape, acc_layout) value_sites_for_variable[acc_var] = [acc_in, acc_out] a_tiling, b_tiling = _infer_wgmma_tiling(op.a.type, op.b.type) b = ValueSite(op, VariableType.OPERAND, 2) b_var = ctx.producer_ref(b) - assignments[b_var] = cs.SMEMTiling(lc.TileTransform(b_tiling)) + b_tile_transform = lc.TileTransform(b_tiling) + b_is_valid = is_valid_smem_layout_assignment(b.shape, b_tile_transform) + assignments[b_var] = cs.SMEMTiling(b_tile_transform) value_sites_for_variable[b_var] = [b] a = ValueSite(op, VariableType.OPERAND, 1) if _is_smem_ref(op.a): a_var = ctx.producer_ref(a) - assignments[a_var] = cs.SMEMTiling(lc.TileTransform(a_tiling)) + a_tile_transform = lc.TileTransform(a_tiling) + assignments[a_var] = cs.SMEMTiling(a_tile_transform) + a_is_valid = is_valid_smem_layout_assignment(a.shape, a_tile_transform) else: assert a_tiling is None a_var = cs.Variable(a) if ir.IntegerType.get_signless(8) == ir.VectorType(op.a.type).element_type: - assignments[a_var] = cs.RegisterLayout(fa.WGMMA_LAYOUT_8BIT) + layout = fa.WGMMA_LAYOUT_8BIT else: - assignments[a_var] = cs.RegisterLayout(fa.WGMMA_LAYOUT) + layout = fa.WGMMA_LAYOUT + assignments[a_var] = cs.RegisterLayout(layout) + a_is_valid = is_valid_register_layout_assignment(a.shape, layout) + value_sites_for_variable[a_var] = [a] + # TODO(bchetioui): think about raising a better error here. + if not a_is_valid or not b_is_valid or not acc_is_valid: + return cs.Unsatisfiable() return cs.ConstraintSystem(assignments), value_sites_for_variable @@ -947,7 +973,7 @@ def _wgmma_constraint_system( def _vector_broadcast_constraint_system( ctx: DerivationContext, op: vector.BroadcastOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx # This is not expected to be necessary at the moment. We should be using # mgpu.BroadcastInDimOp instead when dealing with broadcasting vectors. @@ -965,7 +991,7 @@ def _vector_broadcast_constraint_system( def _vector_reduction_constraint_system( ctx: DerivationContext, op: vector.ReductionOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx in_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0)) return cs.ConstraintSystem(), {in_variable: [in_variable.key]} @@ -987,7 +1013,7 @@ def _reduction_constraints( def _multi_dim_reduction_constraint_system( ctx: DerivationContext, op: vector.MultiDimReductionOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx source = ValueSite(op, VariableType.OPERAND, 0) acc = ValueSite(op, VariableType.OPERAND, 1) @@ -1013,7 +1039,7 @@ def _multi_dim_reduction_constraint_system( def _broadcast_in_dim_constraint_system( ctx: DerivationContext, op: mgpu.BroadcastInDimOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx out_variable = cs.Variable(ValueSite(op, VariableType.RESULT, 0)) source_variable = cs.Variable(ValueSite(op, VariableType.OPERAND, 0)) @@ -1037,7 +1063,7 @@ def _broadcast_in_dim_constraint_system( @_add_constraint_system_derivation_rule(vector.ShapeCastOp) def _shape_cast_constraint_system( ctx: DerivationContext, op: vector.ShapeCastOp -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx in_shape = tuple(cast(ir.ShapedType, op.source.type).shape) out_shape = tuple(cast(ir.ShapedType, op.result.type).shape) @@ -1080,7 +1106,7 @@ def _shape_cast_constraint_system( @_add_constraint_system_derivation_rule(vector.ExtractStridedSliceOp) def _extract_strided_slice_constraint_system( ctx: DerivationContext, op: vector.ExtractStridedSliceOp -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx if any(ir.IntegerAttr(s).value != 1 for s in op.strides): raise NotImplementedError("`strides` must contain only 1s.") @@ -1138,7 +1164,7 @@ def _vector_extract_constraint_system( def _custom_primitive_constraint_system( ctx: DerivationContext, op: mgpu.CustomPrimitiveOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: assignments: dict[cs.Variable, cs.Constant] = {} constraints: list[cs.Constraint] = [] in_layouts = iter(op.in_layouts) @@ -1199,11 +1225,14 @@ def _tmem_layout_from_layout_attr( def _tmem_layout_cast_constraint_system( ctx: DerivationContext, op: mgpu.TmemLayoutCastOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: operand = ValueSite(op, VariableType.OPERAND, 0) variable = ctx.producer_ref(operand) result = ValueSite(op, VariableType.RESULT, 0) - out_layout = cs.TMEMLayout(_tmem_layout_from_layout_attr(op.new_layout)) + tmem_layout = _tmem_layout_from_layout_attr(op.new_layout) + if not is_valid_tmem_layout_assignment(operand.shape, tmem_layout): + return cs.Unsatisfiable() + out_layout = cs.TMEMLayout(tmem_layout) return ( cs.ConstraintSystem(assignments={variable: out_layout}), {variable: [operand, result]}, @@ -1214,7 +1243,7 @@ def _tmem_layout_cast_constraint_system( def _tmem_alloc_constraint_system( ctx: DerivationContext, op: mgpu.TmemAllocOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx result = ValueSite(op, VariableType.RESULT, 0) result_var = cs.Variable(result) @@ -1231,7 +1260,7 @@ def _tmem_alloc_constraint_system( def _tmem_dealloc_constraint_system( ctx: DerivationContext, op: mgpu.TmemDeallocOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: operand = ValueSite(op, VariableType.OPERAND, 0) variable = ctx.producer_ref(operand) return cs.ConstraintSystem(), {variable: [operand]} @@ -1241,7 +1270,7 @@ def _tmem_dealloc_constraint_system( def _tcgen05_mma_constraint_system( ctx: DerivationContext, op: mgpu.TcGen05MMAOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: assignments: dict[cs.Variable, cs.Constant] = {} operands_for_variable: ValueSitesForVariable = {} @@ -1253,6 +1282,7 @@ def _tcgen05_mma_constraint_system( tuple(acc_type.shape), op.collective, packing=1 ) assignments[acc_variable] = cs.TMEMLayout(acc_layout) + acc_is_valid = is_valid_tmem_layout_assignment(acc.shape, acc_layout) operands_for_variable[acc_variable] = [acc] if _is_tmem_ref(op.a): @@ -1265,6 +1295,19 @@ def _tcgen05_mma_constraint_system( ) assignments[a_var] = cs.TMEMLayout(a_layout) operands_for_variable[a_var] = [a] + a_is_valid = is_valid_tmem_layout_assignment(a.shape, a_layout) + else: + assert _is_smem_ref(op.a) + a_tiling = _infer_tiling_for_mma_ref( + ir.MemRefType(op.a.type), + max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle, + ) + a = ValueSite(op, VariableType.OPERAND, 1) + a_var = ctx.producer_ref(a) + a_tile_transform = lc.TileTransform(a_tiling) + assignments[a_var] = cs.SMEMTiling(a_tile_transform) + operands_for_variable[a_var] = [a] + a_is_valid = is_valid_smem_layout_assignment(a.shape, a_tile_transform) # SMEM M = op.accumulator.type.shape[0] @@ -1284,18 +1327,14 @@ def _tcgen05_mma_constraint_system( b_tiling = _infer_tiling_for_mma_ref(ir.MemRefType(op.b.type), max_b_swizzle) b = ValueSite(op, VariableType.OPERAND, 2) b_var = ctx.producer_ref(b) - assignments[b_var] = cs.SMEMTiling(lc.TileTransform(b_tiling)) + b_tile_transform = lc.TileTransform(b_tiling) + assignments[b_var] = cs.SMEMTiling(b_tile_transform) operands_for_variable[b_var] = [b] + b_is_valid = is_valid_smem_layout_assignment(b.shape, b_tile_transform) - if _is_smem_ref(op.a): - a_tiling = _infer_tiling_for_mma_ref( - ir.MemRefType(op.a.type), - max_swizzle=mgpu.SwizzlingMode.k128ByteSwizzle, - ) - a = ValueSite(op, VariableType.OPERAND, 1) - a_var = ctx.producer_ref(a) - assignments[a_var] = cs.SMEMTiling(lc.TileTransform(a_tiling)) - operands_for_variable[a_var] = [a] + # TODO(bchetioui): think about raising a better error here. + if not a_is_valid or not b_is_valid or not acc_is_valid: + return cs.Unsatisfiable() return cs.ConstraintSystem(assignments=assignments), operands_for_variable @@ -1304,7 +1343,7 @@ def _tcgen05_mma_constraint_system( def _async_load_tmem_constraint_system( ctx: DerivationContext, op: mgpu.AsyncLoadTmemOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: source = ValueSite(op, VariableType.OPERAND, 0) source_variable = ctx.producer_ref(source) destination = ValueSite(op, VariableType.RESULT, 0) @@ -1324,7 +1363,7 @@ def _async_load_tmem_constraint_system( def _slice_tmem_constraint_system( ctx: DerivationContext, op: mgpu.SliceTmemOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: operand = ValueSite(op, VariableType.OPERAND, 0) operand_variable = ctx.producer_ref(operand) result = ValueSite(op, VariableType.RESULT, 0) @@ -1339,7 +1378,7 @@ def _slice_tmem_constraint_system( def _async_store_tmem_constraint_system( ctx: DerivationContext, op: mgpu.AsyncStoreTmemOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: source = ValueSite(op, VariableType.OPERAND, 0) source_variable = cs.Variable(source) destination = ValueSite(op, VariableType.OPERAND, 1) @@ -1359,7 +1398,7 @@ def _async_store_tmem_constraint_system( def _slice_smem_constraint_system( ctx: DerivationContext, op: mgpu.SliceSMEMOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx res = ValueSite(op, VariableType.RESULT, 0) res_var = cs.Variable(res) @@ -1370,7 +1409,7 @@ def _slice_smem_constraint_system( def _memref_subview_constraint_system( ctx: DerivationContext, op: memref.SubViewOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: source = ValueSite(op, VariableType.OPERAND, 0) dest = ValueSite(op, VariableType.RESULT, 0) source_dest_var = ctx.producer_ref(source) @@ -1415,7 +1454,7 @@ def _memref_subview_constraint_system( def _memref_cast_op_constraint_system( ctx: DerivationContext, op: memref.CastOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: source = ValueSite(op, VariableType.OPERAND, 0) var_source_dest = ctx.producer_ref(source) dest = ValueSite(op, VariableType.RESULT, 0) @@ -1426,7 +1465,7 @@ def _memref_cast_op_constraint_system( def _memref_transpose_op_constraint_system( ctx: DerivationContext, op: memref.TransposeOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: in_ty = ir.MemRefType(op.in_.type) if len(in_ty.shape) != 2: raise NotImplementedError(f"Only 2D memrefs are supported, got {in_ty}") @@ -1454,7 +1493,7 @@ def _memref_transpose_op_constraint_system( def _memref_expand_shape_op_equation_system( ctx: DerivationContext, op: memref.ExpandShapeOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: if utils.is_memref_transposed(ir.MemRefType(op.src.type)): raise NotImplementedError( "Transposed memrefs are not supported in ExpandShapeOp." @@ -1485,7 +1524,7 @@ def _memref_expand_shape_op_equation_system( def _memref_load_store_op_constraint_system( ctx: DerivationContext, op: memref.LoadOp | memref.StoreOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: del ctx ref_shape = ir.MemRefType(op.memref.type).shape @@ -1534,11 +1573,15 @@ def _extract_smem_tiling_from_custom_transform_attrs( def _with_transforms_constraint_system( ctx: DerivationContext, op: mgpu.WithTransformsOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: source = ValueSite(op, VariableType.OPERAND, 0) dest = ValueSite(op, VariableType.RESULT, 0) var = ctx.producer_ref(source) tiling = _extract_smem_tiling_from_custom_transform_attrs(op.ref.type, op.transforms) + if tiling.value is not None: + # TODO(bchetioui): think about raising a better error here. + if not is_valid_smem_layout_assignment(source.shape, tiling.value): + return cs.Unsatisfiable() assignments: dict[cs.Variable, cs.Constant] = {var: tiling} return cs.ConstraintSystem(assignments=assignments), {var: [source, dest]} @@ -1548,7 +1591,7 @@ def _with_transforms_constraint_system( def _async_load_store_constraint_system( ctx: DerivationContext, op: mgpu.AsyncLoadOp | mgpu.AsyncStoreOp, -) -> tuple[cs.ConstraintSystem, ValueSitesForVariable]: +) -> ConstraintSystemDerivationRuleResult: tiling_multiple = [] for size, index in zip(op.slice_lengths, op.indices, strict=True): if size == -1: @@ -1857,6 +1900,77 @@ def traverse_op( traverse_op(block_op, callback) +def is_valid_register_layout_assignment( + shape: tuple[int, ...], layout: fa.FragmentedLayout +) -> bool: + match layout: + case fa.WGStridedFragLayout() as strided_layout: + return strided_layout.shape == shape + case fa.WGSplatFragLayout() as splat_layout: + return splat_layout.shape == shape + case fa.TiledLayout(tiling=tiling): + try: + # `tiling.tile_shape` will raise if the shape is not tileable. + _ = tiling.tile_shape(shape) + except ValueError: + return False + return True + case _: + assert False, f"Unreachable {shape}, {layout}" + + +def is_valid_smem_layout_assignment( + shape: tuple[int, ...], tiling: lc.TileTransform +) -> bool: + try: + # `tiling.transform_shape` will raise if the shape is not tileable. + _ = tiling.transform_shape(shape) + except ValueError: + return False + return True + + +def is_valid_tmem_layout_assignment( + shape: tuple[int, ...], layout: tcgen05.TMEMLayout +) -> bool: + try: + # `layout.tiling.tile_shape` will raise if the shape is not tileable. + _ = layout.tiling.tile_shape(shape) + except ValueError: + return False + return True + + +def check_layout_assignment(v: ValueSite, layout: cs.Constant) -> None: + """Raises if the given layout can not be assigned to the given `ValueSite`.""" + match v.memory_space, layout: + case MemorySpace.REG, cs.RegisterLayout(value=reg_layout): + if not is_valid_register_layout_assignment(v.shape, reg_layout): + raise ValueError( + f"Layout {reg_layout} is not compatible with register variable " + f"{v.value}. This is a bug." + ) + case MemorySpace.TMEM, cs.TMEMLayout(value=tmem_layout): + if not is_valid_tmem_layout_assignment(v.shape, tmem_layout): + raise ValueError( + f"Layout {tmem_layout} is not compatible with TMEM variable " + f"{v.value}. This is a bug." + ) + case MemorySpace.SMEM, cs.SMEMTiling(value=tiling_or_none): + if tiling_or_none is None: + return + if not is_valid_smem_layout_assignment(v.shape, tiling_or_none): + raise ValueError( + f"Layout {tiling_or_none} is not compatible with SMEM variable " + f"{v.value}. This is a bug." + ) + case _: + raise ValueError( + f"Variable {v.value} in memory space {v.memory_space} should not be " + f"assigned a layout of type {type(layout)}. This is a bug." + ) + + def infer_layout( module: ir.Module, *, fuel: int = _DEFAULT_LAYOUT_INFERENCE_FUEL ): @@ -1895,13 +2009,21 @@ def gather_constraints(op: ir.Operation): rule = _constraint_system_derivation_rules.get(op.OPERATION_NAME, None) # pytype: disable=attribute-error if rule is None: raise NotImplementedError(f"No layout inference rule defined for {op}") - constraint_system, mapping = rule(ctx, op) - ctx.update(mapping) + rule_result = rule(ctx, op) nonlocal global_constraint_system + if isinstance(rule_result, cs.Unsatisfiable): + global_constraint_system = cs.Unsatisfiable() + return + constraint_system, mapping = rule_result global_constraint_system &= constraint_system + ctx.update(mapping) for op in module.body: traverse_op(op, gather_constraints) + # Short-circuit if we have an unsatisfiable constraint system, we won't + # construct anything useful anymore. + if isinstance(global_constraint_system, cs.Unsatisfiable): + break if isinstance(global_constraint_system, cs.Unsatisfiable): raise ValueError( @@ -1940,11 +2062,14 @@ def gather_constraints(op: ir.Operation): "user-provided layout casts are unsatisfiable." ) - layout_for_value_site = { - k: solution[v] - for v, ks in ctx.value_sites_for_variable.items() - for k in ks - } + layout_for_value_site: dict[ValueSite, cs.Constant] = {} + for variable, value_sites in ctx.value_sites_for_variable.items(): + for value_site in value_sites: + layout = solution[variable] + # Ensure that the layout assignment is valid for the value site. This + # should only ever fail if our implementation is buggy. + check_layout_assignment(value_site, layout) + layout_for_value_site[value_site] = layout # Assigns the layouts that we found to the ops. assign_layouts(layout_for_value_site) diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index facacb604b51..c0f369affab2 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -426,8 +426,8 @@ def test_infer_layout_from_yield_op_in_layouts_for_for_op( add = layout_cast(arith.addf(loop_a, loop_b), layout) transforms = ir.ArrayAttr.get([ - mgpu.dialect.TileTransformAttr.get((8, 64)), - mgpu.dialect.SwizzleTransformAttr.get(128), + mgpu.dialect.TileTransformAttr.get((8, 32)), + mgpu.dialect.SwizzleTransformAttr.get(64), ]) loop_ref = mgpu.dialect.with_transforms(loop_ref, transforms) @@ -624,7 +624,7 @@ def test_optimization_barrier_op_propagates_user_layouts(self): wgmma_layout = layouts.to_layout_attr(mgpu.WGMMA_LAYOUT) with ir.InsertionPoint(self.module.body): - ty = ir.VectorType.get((32, 4), ir.BF16Type.get()) + ty = ir.VectorType.get((64, 16), ir.BF16Type.get()) lhs, rhs = undefs(ty, ty) optimization_barrier = mgpu.dialect.OptimizationBarrierOp([lhs, rhs]) lhs, rhs = optimization_barrier.results @@ -1209,6 +1209,9 @@ def test_memref_load_store_op_transforms_are_empty(self): lhs_in_registers=(False, True), ) def test_infer_transforms_for_wgmma_op(self, swizzle, dtype, lhs_in_registers): + if swizzle == mgpu.dialect.SwizzlingMode.kNoSwizzle: + self.skipTest("kNoSwizzle is not supported by this test.") + swizzle_elems = swizzle // np.dtype(dtype).itemsize m = 64 # Note: `group_m` and `group_k` should be coprime with 2 for the test to be @@ -1286,6 +1289,9 @@ def test_infer_layouts_for_8bits_wgmma_op(self, dtype, lhs_in_registers): def test_infer_transforms_for_tcgen05_mma_op( self, swizzle_lhs, swizzle_rhs, dtype, lhs_in_tmem ): + if mgpu.dialect.SwizzlingMode.kNoSwizzle in (swizzle_lhs, swizzle_rhs): + self.skipTest("kNoSwizzle is not supported by this test.") + swizzle_elems_lhs = swizzle_lhs // np.dtype(dtype).itemsize swizzle_elems_rhs = swizzle_rhs // np.dtype(dtype).itemsize m = 128 @@ -2134,6 +2140,42 @@ def test_infer_layout_for_memref_expand_shape_op(self, input_shape, reassociatio [out_transform] = inference_utils.out_transforms(op) self.assertSequenceEqual(out_transform, transforms) + def test_layout_cast_incompatible_with_vector_shape_is_unsatisfiable(self): + with ir.InsertionPoint(self.module.body): + [vec] = undefs(ir.VectorType.get((4, 4), ir.BF16Type.get())) + mgpu.dialect.layout_cast(vec, layouts.to_layout_attr(fa.WGMMA_LAYOUT)) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts" + ): + mgpu.infer_layout(self.module) + + def test_tmem_layout_cast_incompatible_with_ref_shape_is_unsatisfiable(self): + with ir.InsertionPoint(self.module.body): + f32 = ir.F32Type.get() + ref_ty = ir.MemRefType.get((4, 4), f32, memory_space=mgpu.utils.tmem()) + [ref] = undefs(ref_ty) + mgpu.dialect.tmem_layout_cast( + ref, layouts.to_layout_attr(mgpu.TMEM_NATIVE_LAYOUT) + ) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts" + ): + mgpu.infer_layout(self.module) + + def test_with_transforms_incompatible_with_smem_shape_is_unsatisfiable(self): + with ir.InsertionPoint(self.module.body): + f32 = ir.F32Type.get() + ref_ty = ir.MemRefType.get((4, 4), f32, memory_space=mgpu.utils.smem()) + [ref] = undefs(ref_ty) + transforms = ir.ArrayAttr.get([ + mgpu.dialect.TileTransformAttr.get((8, 2)), + ]) + mgpu.dialect.with_transforms(ref, transforms) + with self.assertRaisesRegex( + ValueError, "Failed to infer a possible set of layouts" + ): + mgpu.infer_layout(self.module) + if __name__ == "__main__": parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) From 7bc428b8a0204b384e4236181ebe5bf1cac02b5c Mon Sep 17 00:00:00 2001 From: Liam Miller-Cushon Date: Fri, 19 Dec 2025 08:49:00 -0800 Subject: [PATCH 290/315] Automated Code Change PiperOrigin-RevId: 846751675 --- jax/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 6431c08656d1..6315c67e41ac 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -28,7 +28,7 @@ load( ) package( - default_applicable_licenses = [], + default_applicable_licenses = [":license"], default_visibility = [":internal"], ) From e27c8dd1fbd2d0063e6083c705997a144fa9c5c6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 19 Dec 2025 09:53:23 -0800 Subject: [PATCH 291/315] [dep] remove several dozen finalized deprecations for v0.9.0 These names were already deprecated in v0.8.0 or prior; this change means that the custom AttributeError will be replaced with a generic AttributeError. --- jax/__init__.py | 35 -------------- jax/custom_derivatives.py | 19 -------- jax/dlpack.py | 16 ------- jax/errors.py | 11 ----- .../compilation_cache/compilation_cache.py | 27 ----------- jax/lib/xla_bridge.py | 23 --------- jax/lib/xla_client.py | 48 ------------------- jax/lib/xla_extension.py | 35 -------------- 8 files changed, 214 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index 936b7f914377..874c5f119fdf 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -208,41 +208,6 @@ "jax.device_put_sharded is deprecated; use jax.device_put instead.", _deprecated_device_put_sharded ), - # Finalized 2025-03-25; remove after 2025-06-25 - "treedef_is_leaf": ( - "jax.treedef_is_leaf was removed in JAX v0.6.0: use jax.tree_util.treedef_is_leaf.", - None - ), - "tree_flatten": ( - "jax.tree_flatten was removed in JAX v0.6.0: use jax.tree.flatten (jax v0.4.25 or newer) " - "or jax.tree_util.tree_flatten (any JAX version).", - None - ), - "tree_leaves": ( - "jax.tree_leaves was removed in JAX v0.6.0: use jax.tree.leaves (jax v0.4.25 or newer) " - "or jax.tree_util.tree_leaves (any JAX version).", - None - ), - "tree_structure": ( - "jax.tree_structure was removed in JAX v0.6.0: use jax.tree.structure (jax v0.4.25 or newer) " - "or jax.tree_util.tree_structure (any JAX version).", - None - ), - "tree_transpose": ( - "jax.tree_transpose was removed in JAX v0.6.0: use jax.tree.transpose (jax v0.4.25 or newer) " - "or jax.tree_util.tree_transpose (any JAX version).", - None - ), - "tree_unflatten": ( - "jax.tree_unflatten was removed in JAX v0.6.0: use jax.tree.unflatten (jax v0.4.25 or newer) " - "or jax.tree_util.tree_unflatten (any JAX version).", - None - ), - "tree_map": ( - "jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) " - "or jax.tree_util.tree_map (any JAX version).", - None - ), } import typing as _typing diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 2dde0d3cacbb..edefdae40c44 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -35,22 +35,3 @@ SymbolicZero as SymbolicZero, zero_from_primal as zero_from_primal ) - -_deprecations = { - # Finalized for v0.8.0; remove in v0.9.0 - "custom_jvp_call_jaxpr_p": ( - ("jax.custom_derivatives.custom_jvp_call_jaxpr_p was deprecated in v0.7.0" - " and removed in v0.8.0. use jax.extend.core.primitives.custom_jvp_call_p" - " instead, and please note that you must `import jax.extend` explicitly."), - None, - ), -} - -import typing -if typing.TYPE_CHECKING: - pass -else: - from jax._src.deprecations import deprecation_getattr as _deprecation_getattr - __getattr__ = _deprecation_getattr(__name__, _deprecations) - del _deprecation_getattr -del typing diff --git a/jax/dlpack.py b/jax/dlpack.py index da04ed7119d7..6fa73748ee8b 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -16,19 +16,3 @@ from_dlpack as from_dlpack, is_supported_dtype as is_supported_dtype, ) - -_deprecations = { - # Deprecated in JAX v0.7.0 - "SUPPORTED_DTYPES": ( - ( - "jax.SUPPORTED_DTYPES is deprecated in JAX v0.7.0 and will be removed" - " in JAX v0.8.0. Use jax.dlpack.is_supported_dtype() instead." - ), - None, - ), -} - - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/errors.py b/jax/errors.py index 928ab6c8a7f2..a4a6c5388db2 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -31,14 +31,3 @@ JaxRuntimeError = _jax.JaxRuntimeError JaxRuntimeError.__module__ = "jax.errors" del _jax - -_deprecations = { - "SimplifiedTraceback": ( - "jax.errors.SimplifiedTraceback is deprecated and will be removed in JAX v0.8.", - None, - ), -} - -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index 8b993c1c142a..8c820c5434fe 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -16,30 +16,3 @@ set_cache_dir as set_cache_dir, reset_cache as reset_cache, ) - -_deprecations = { - # Finalized for v0.8.0; remove in v0.9.0 - "is_initialized": ( - ( - "compilation_cache.is_initialized was deprecated in JAX v0.4.24 and" - " removed in JAX v0.8.0." - ), - None, - ), - "initialize_cache": ( - ( - "compilation_cache.initialize_cache was deprecated in JAX v0.4.24 and" - " removed in JAX v0.8.0. use compilation_cache.set_cache_dir instead." - ), - None, - ), -} - -import typing as _typing -if _typing.TYPE_CHECKING: - pass -else: - from jax._src.deprecations import deprecation_getattr - __getattr__ = deprecation_getattr(__name__, _deprecations) - del deprecation_getattr -del _typing diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index 39bda685f4ec..9dc9d269ef05 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -22,27 +22,4 @@ ), stacklevel=4 ) - -_deprecations = { - # Finalized in JAX v0.8.0; remove these messages in v0.9.0. - "get_backend": ( - ( - "jax.lib.xla_bridge.get_backend is deprecated and will be removed" - " in JAX v0.8.0; use jax.extend.backend.get_backend, and please" - " note that you must `import jax.extend` explicitly." - ), - None, - ), - "get_compile_options": ( - ( - "jax.lib.xla_bridge.get_compile_options is deprecated in JAX v0.7.0" - " and will be removed in JAX v0.8.0. Use" - " jax.extend.backend.get_compile_options, and please note that you" - " must `import jax.extend` explicitly." - ), - None, - ), -} - -__getattr__ = _deps.deprecation_getattr(__name__, _deprecations) del _deps diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index ecebb1a7b9a6..7cc1fb88ab15 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -22,52 +22,4 @@ ), stacklevel=4 ) - -_deprecations = { - # Finalized in JAX v0.8.0; remove these messages in v0.9.0. - "Client": ( - ( - "jax.lib.xla_client.Client was deprecated in JAX v0.6.0 and will be" - " removed in JAX v0.8.0" - ), - None, - ), - "CompileOptions": ( - ( - "jax.lib.xla_client.CompileOptions was deprecated in JAX v0.6.0 and" - " will be removed in JAX v0.8.0" - ), - None, - ), - "Frame": ( - ( - "jax.lib.xla_client.Frame was deprecated in JAX v0.6.0 and will be" - " removed in JAX v0.8.0" - ), - None, - ), - "HloSharding": ( - ( - "jax.lib.xla_client.HloSharding was deprecated in JAX v0.6.0 and" - " will be removed in JAX v0.8.0" - ), - None, - ), - "OpSharding": ( - ( - "jax.lib.xla_client.OpSharding was deprecated in JAX v0.6.0 and" - " will be removed in JAX v0.8.0" - ), - None, - ), - "Traceback": ( - ( - "jax.lib.xla_client.Traceback was deprecated in JAX v0.6.0 and will" - " be removed in JAX v0.8.0" - ), - None, - ), -} - -__getattr__ = _deps.deprecation_getattr(__name__, _deprecations) del _deps diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py index c02710c081ad..3c0f2fd5a3e1 100644 --- a/jax/lib/xla_extension.py +++ b/jax/lib/xla_extension.py @@ -22,39 +22,4 @@ ), stacklevel=4 ) - -_deprecations = { - # Finalized in JAX v0.8.0; remove these messages in v0.9.0. - "ifrt_proxy": ( - "jax.lib.xla_extension.ifrt_proxy is deprecated.", - None, - ), - "mlir": ("jax.lib.xla_extension.mlir is deprecated.", None), - "profiler": ( - "jax.lib.xla_extension.profiler is deprecated.", - None, - ), - "hlo_module_cost_analysis": ( - "jax.lib.xla_extension.hlo_module_cost_analysis is deprecated.", - None, - ), - "hlo_module_to_dot_graph": ( - "jax.lib.xla_extension.hlo_module_to_dot_graph is deprecated.", - None, - ), - "HloPrintOptions": ( - "jax.lib.xla_extension.HloPrintOptions is deprecated.", - None, - ), - "PjitFunction": ( - "jax.lib.xla_extension.PjitFunction is deprecated.", - None, - ), - "PmapFunction": ( - "jax.lib.xla_extension.PmapFunction is deprecated.", - None, - ), -} - -__getattr__ = _deps.deprecation_getattr(__name__, _deprecations) del _deps From a26a544174214748e08cccea078078b8eb3ef874 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Fri, 19 Dec 2025 10:18:33 -0800 Subject: [PATCH 292/315] Remove an erroneously added license. PiperOrigin-RevId: 846781195 --- jax/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/BUILD b/jax/BUILD index 6315c67e41ac..6431c08656d1 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -28,7 +28,7 @@ load( ) package( - default_applicable_licenses = [":license"], + default_applicable_licenses = [], default_visibility = [":internal"], ) From 89401bd3143e448eec169732295470ccb9ff97af Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 19 Dec 2025 10:35:13 -0800 Subject: [PATCH 293/315] Default make_mesh to Explicit axes for the upcoming 0.9.0 release PiperOrigin-RevId: 846787187 --- jax/_src/deprecations.py | 1 - jax/_src/sharding_impls.py | 15 +-------------- tests/array_test.py | 19 ++----------------- 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index b650e0fd5a2f..a72230b04e3d 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -134,5 +134,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-scipy-special-sph-harm') register('safer-randint-config') register('jax-pmap-no-rank-reduction') -register('jax-make-mesh-default-explicit') register('pltpu-memory-space-any') diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 2cdb0ec2fe56..48247bcce753 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -20,7 +20,6 @@ import dataclasses import functools import math -import warnings from typing import Any, NamedTuple, cast from jax._src import config @@ -33,7 +32,6 @@ from jax._src import source_info_util from jax._src import xla_bridge as xb from jax._src import mesh_utils -from jax._src import deprecations from jax._src.lib import xla_client as xc from jax._src.lib.mlir.dialects import sdy from jax._src.named_sharding import ( # noqa: F401 @@ -1195,18 +1193,7 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], '`jax.make_mesh` does not support multi-slice topologies. Please use' ' jax.experimental.mesh_utils.create_hybrid_device_mesh') if axis_types is None: - if deprecations.is_accelerated('jax-make-mesh-default-explicit'): - axis_types = (mesh_lib.AxisType.Explicit,) * len(mesh_devices.shape) - else: - warnings.warn( - 'The default axis_types will change in JAX v0.9.0 to' - ' jax.sharding.AxisType.Explicit. To maintain the old behavior, pass' - ' `axis_types=(jax.sharding.AxisType.Auto,) * len(axis_names)`. To' - ' opt-into the new behavior, pass' - ' `axis_types=(jax.sharding.AxisType.Explicit,) * len(axis_names)', - category=DeprecationWarning, - stacklevel=2, - ) + axis_types = (mesh_lib.AxisType.Explicit,) * len(mesh_devices.shape) return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types) class set_mesh: diff --git a/tests/array_test.py b/tests/array_test.py index 1404b9416321..70c2dcbf9b94 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -25,7 +25,6 @@ from jax._src import config from jax._src import core from jax._src import dispatch -from jax._src import deprecations from jax._src import op_shardings from jax._src import test_util as jtu from jax._src import xla_bridge as xb @@ -1345,13 +1344,8 @@ def test_make_mesh_axis_types(self): mesh2 = jax.sharding.AbstractMesh((2,), 'x', axis_types=Auto) self.assertEqual(mesh1, mesh2) - if deprecations.is_accelerated('jax-make-mesh-default-explicit'): - mesh = jax.make_mesh((1, 1), ('x', 'y')) - self.assertTupleEqual(mesh.axis_types, (AxisType.Explicit,) * 2) - else: - mesh = jax.make_mesh((1, 1), ('x', 'y'), - axis_types=(AxisType.Explicit,) * 2) - self.assertTupleEqual(mesh.axis_types, (AxisType.Explicit,) * 2) + mesh = jax.make_mesh((1, 1), ('x', 'y')) + self.assertTupleEqual(mesh.axis_types, (AxisType.Explicit,) * 2) mesh = jax.make_mesh((1, 1, 1), ('x', 'y', 'z'), axis_types=(Explicit, Auto, Manual)) @@ -1576,15 +1570,6 @@ def test_nested_tuple_pspec_error(self): "A tuple inside PartitionSpec cannot contain a nested tuple"): jax.P((('a', 'b'), 'c')) - def test_make_mesh_accelerate_explicit(self): - if deprecations.is_accelerated('jax-make-mesh-default-explicit'): - mesh = jax.make_mesh((1,), 'x') - self.assertTupleEqual(mesh.axis_types, (AxisType.Explicit,)) - else: - with self.assertWarnsRegex(DeprecationWarning, "The default axis_types"): - mesh = jax.make_mesh((1,), 'x') - self.assertTupleEqual(mesh.axis_types, (AxisType.Auto,)) - class RngShardingTest(jtu.JaxTestCase): # tests that the PRNGs are automatically sharded as expected From cfc997cd41ba8145d4141b3984c99f375bafdab4 Mon Sep 17 00:00:00 2001 From: Yulia Baturina Date: Fri, 19 Dec 2025 11:01:31 -0800 Subject: [PATCH 294/315] Enable using custom hermetic NCCL version. The NCCL version can be chosen via `HERMETIC_NCCL_VERSION` env var. See docs [here](https://github.com/google-ml-infra/rules_ml_toolchain/blob/main/gpu/README.md#environment-variables-controlling-the-hermetic-cudacudnnnvshmem-versions). PiperOrigin-RevId: 846797606 --- .bazelrc | 2 ++ WORKSPACE | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index 371f4b78a376..9fbc3435f74d 100644 --- a/.bazelrc +++ b/.bazelrc @@ -178,6 +178,7 @@ common:clang --copt=-Wno-error=c23-extensions common:cuda_v12 --repo_env=HERMETIC_CUDA_VERSION="12.9.1" common:cuda_v12 --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" common:cuda_v12 --repo_env=HERMETIC_NVSHMEM_VERSION="3.3.9" +common:cuda_v12 --repo_env=HERMETIC_NCCL_VERSION="2.27.7" # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. common:cuda_v12 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120" @@ -185,6 +186,7 @@ common:cuda_v12 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70 common:cuda_v13 --repo_env=HERMETIC_CUDA_VERSION="13.0.0" common:cuda_v13 --repo_env=HERMETIC_CUDNN_VERSION="9.12.0" common:cuda_v13 --repo_env=HERMETIC_NVSHMEM_VERSION="3.3.20" +common:cuda_v13 --repo_env=HERMETIC_NCCL_VERSION="2.27.7" common:cuda_v13 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_75,sm_80,sm_90,sm_100,compute_120" common:cuda_common --repo_env TF_NEED_CUDA=1 diff --git a/WORKSPACE b/WORKSPACE index 6b3d0e2aa010..d032c404553d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -17,10 +17,10 @@ xla_workspace3() # Details: https://github.com/google-ml-infra/rules_ml_toolchain tf_http_archive( name = "rules_ml_toolchain", - sha256 = "53905ede50e3eebc782266e20e9b9ac1d7166ef68b877bea593d3600dcfe03e6", - strip_prefix = "rules_ml_toolchain-a1ff84835e407b41eef5fd1a865a23748c294db6", + sha256 = "1c2c530a054e9e8b3c811ec21ed8a687fc865bec3abbc8ff65beb829b1d67ae4", + strip_prefix = "rules_ml_toolchain-6734d2a174bf29e731d3f473743d1cc1a86100c3", urls = tf_mirror_urls( - "https://github.com/google-ml-infra/rules_ml_toolchain/archive/a1ff84835e407b41eef5fd1a865a23748c294db6.tar.gz", + "https://github.com/google-ml-infra/rules_ml_toolchain/archive/6734d2a174bf29e731d3f473743d1cc1a86100c3.tar.gz", ), ) From a9dac1167db7a33909caa1e1a8778b303ae3995f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 19 Dec 2025 11:17:23 -0800 Subject: [PATCH 295/315] [dep] remove deprecated `jax_safer_randint` configuration for JAX v0.9.0 PiperOrigin-RevId: 846803402 --- jax/_src/config.py | 21 --------------------- jax/_src/deprecations.py | 1 - jax/_src/random.py | 26 +++++++++++--------------- 3 files changed, 11 insertions(+), 37 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index ad4fec0f3c91..507228cac0d7 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1249,27 +1249,6 @@ def _validate_jax_pjrt_client_create_options(new_val): include_in_trace_context=True, ) -def _safer_randint_deprecation(new_val): - if not new_val: - deprecations.warn( - 'safer-randint-config', - ( - 'The jax_safer_randint configuration is deprecated in JAX v0.7.2' - ' and will be removed in JAX v0.9.0.' - ), - stacklevel=4 - ) - -# TODO(jakevdp): remove this flag. -safer_randint = bool_state( - name='jax_safer_randint', - default=True, - help='Use a safer randint algorithm for 8-bit and 16-bit dtypes.', - include_in_jit_key=True, - upgrade=True, - validator=_safer_randint_deprecation -) - class LegacyPrngKeyState(enum.StrEnum): ALLOW = 'allow' WARN = 'warn' diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index a72230b04e3d..946699483f5a 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -132,6 +132,5 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: register('jax-numpy-astype-complex-to-real') register('jax-numpy-clip-args') register('jax-scipy-special-sph-harm') -register('safer-randint-config') register('jax-pmap-no-rank-reduction') register('pltpu-memory-space-any') diff --git a/jax/_src/random.py b/jax/_src/random.py index bdd55ea2f013..5da171c269af 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -565,21 +565,17 @@ def randint_via_uniform(key, shape, minval, maxval, dtype): if not dtypes.issubdtype(dtype, np.integer): raise TypeError(f"randint only accepts integer dtypes, got {dtype}") - # TODO(jakevdp): migrate users to safer randint and remove the old version. - if config.safer_randint.value: - info = dtypes.iinfo(dtype) - dtype_for_sampling = dtype - if info.bits < 32: - # Sample in 32 bits to avoid biased results. - dtype_for_sampling = np.dtype('int32') - minval = jnp.asarray(minval).astype('int32').clip(int(info.min), int(info.max)) - maxval = jnp.asarray(maxval).astype('int32').clip(int(info.min), int(info.max) + 1) - - return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype_for_sampling)( - key, minval, maxval).astype(dtype) - - return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype)( - key, minval, maxval) + info = dtypes.iinfo(dtype) + dtype_for_sampling = dtype + if info.bits < 32: + # Sample in 32 bits to avoid biased results. + dtype_for_sampling = np.dtype('int32') + minval = jnp.asarray(minval).astype('int32').clip(int(info.min), int(info.max)) + maxval = jnp.asarray(maxval).astype('int32').clip(int(info.min), int(info.max) + 1) + + return maybe_auto_axes(_randint, out_sharding, shape=shape, dtype=dtype_for_sampling)( + key, minval, maxval).astype(dtype) + @jit(static_argnums=(3, 4)) def _randint(key, minval, maxval, shape, dtype) -> Array: From f314722eda76016d3ad59c4451198aa20d7a0fdd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 19 Dec 2025 11:41:00 -0800 Subject: [PATCH 296/315] [dep] finalize deprecation of lax.dot positional arguments for JAX v0.9.0. PiperOrigin-RevId: 846811504 --- jax/_src/BUILD | 1 - jax/_src/deprecations.py | 1 - jax/_src/lax/lax.py | 28 +++++----------------------- tests/lax_test.py | 21 ++------------------- 4 files changed, 7 insertions(+), 44 deletions(-) diff --git a/jax/_src/BUILD b/jax/_src/BUILD index f514c931c3bd..3a2db9d307bb 100644 --- a/jax/_src/BUILD +++ b/jax/_src/BUILD @@ -437,7 +437,6 @@ py_library_providing_imports_info( ":core", ":custom_derivatives", ":custom_partitioning_sharding_rule", - ":deprecations", ":dtypes", ":effects", ":ffi", diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 946699483f5a..6ad6f6dfaf44 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -125,7 +125,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: # always registered by the time `accelerate` and `is_acelerated` are called. register('default-dtype-bits-config') register('jax-checkpoint-concrete') -register('jax-lax-dot-positional-args') register('jax-lib-module') register('jax-nn-one-hot-float-input') register('jax-numpy-arange-complex') diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4aaa0c88539f..609367132004 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -34,7 +34,6 @@ from jax._src import array from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -2411,6 +2410,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, preferred_element_type=preferred_element_type, out_sharding=out_sharding) +# TODO(jakevdp): replace `*args`` with `*` in v0.10.0 def dot(lhs: ArrayLike, rhs: ArrayLike, *args, dimension_numbers: DotDimensionNumbers | None = None, precision: PrecisionLike = None, @@ -2465,30 +2465,12 @@ def dot(lhs: ArrayLike, rhs: ArrayLike, *args, .. _stablehlo.dot_general: https://openxla.org/stablehlo/spec#dot_general .. _DotGeneral: https://www.openxla.org/xla/operation_semantics#dotgeneral """ - # TODO(jakevdp): keyword warning added for JAX v0.7.1; finalize this for v0.9.0. if args: - deprecations.warn( - "jax-lax-dot-positional-args", - ( - "jax.lax.dot: passing precision or preferred_element_type by position" - " is deprecated; pass them by keyword instead." - ), - stacklevel=2 + raise TypeError( + f"dot() takes 2 positional arguments but {2 + len(args)} were given." + " Passing precision or preferred_element_type by position is not allowed" + " as of JAX v0.9.0; pass them by keyword instead." ) - # Prior to merging dot and dot_general, dot() had two additional positional args: - # `precision` and `preferred_element_type`. - if len(args) == 1: - if precision is not None: - raise TypeError("jax.lax.dot got multiple values for argument 'precision'") - precision, = args - elif len(args) == 2: - if precision is not None: - raise TypeError("jax.lax.dot got multiple values for argument 'precision'") - if preferred_element_type is not None: - raise TypeError("jax.lax.dot got multiple values for argument 'preferred_element_type'") - precision, preferred_element_type = args - else: - raise TypeError("Too many positional arguments passed to jax.lax.dot.") del args lhs_shape = np.shape(lhs) diff --git a/tests/lax_test.py b/tests/lax_test.py index 626982b1c88d..b7aed37f56be 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -38,7 +38,6 @@ from jax._src import array from jax._src import config from jax._src import core -from jax._src import deprecations from jax._src import dtypes from jax._src import lax_reference from jax._src import test_util as jtu @@ -1107,25 +1106,9 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): def testDotPositionalArgumentDeprecation(self): lhs = jnp.arange(5.0) rhs = jnp.arange(5.0) - msg = "jax.lax.dot: passing precision or preferred_element_type by position" - multiple_args_msg = "jax.lax.dot got multiple values for argument" - with self.assertDeprecationWarnsOrRaises("jax-lax-dot-positional-args", msg): - lax.dot(lhs, rhs, lax.Precision.DEFAULT, jnp.float32) - - with self.assertDeprecationWarnsOrRaises("jax-lax-dot-positional-args", msg): - with self.assertRaises(TypeError): - lax.dot(lhs, rhs, lax.Precision.DEFAULT, precision=lax.Precision.DEFAULT) - - if deprecations.is_accelerated("jax-lax-dot-positional-args"): - with self.assertRaisesRegex(ValueError, msg): - lax.dot(lhs, rhs, lax.Precision.DEFAULT, jnp.float32, - preferred_element_type=jnp.float32) - else: - with self.assertWarnsRegex(DeprecationWarning, msg): - with self.assertRaisesRegex(TypeError, multiple_args_msg): - lax.dot(lhs, rhs, lax.Precision.DEFAULT, jnp.float32, - preferred_element_type=jnp.float32) + with self.assertRaisesRegex(TypeError, r"dot\(\) takes 2 positional arguments"): + lax.dot(lhs, rhs, lax.Precision.DEFAULT) @parameterized.parameters([ (algorithm, dtype) From d6944d3d89ff98ffa9e430b26288e841ce51f3c0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 19 Dec 2025 12:25:45 -0800 Subject: [PATCH 297/315] [dep] remove references to the already-deprecated interpolation argument --- jax/_src/numpy/reductions.py | 39 ++++++++------------------------ tests/lax_numpy_reducers_test.py | 8 ------- tests/lax_numpy_test.py | 8 +++---- 3 files changed, 14 insertions(+), 41 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 05cc64bf1568..1d6869d2140d 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -34,7 +34,7 @@ from jax._src.lax import other as lax_other from jax._src.lax import parallel as lax_parallel from jax._src.lax import slicing as lax_slicing -from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg +from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.util import canonicalize_axis, canonicalize_axis_tuple, maybe_named_axis, set_module @@ -2371,12 +2371,11 @@ def cumulative_prod( # Quantiles -# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export -@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False) -> Array: """Compute the quantile of the data along the specified axis. JAX implementation of :func:`numpy.quantile`. @@ -2418,18 +2417,14 @@ def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = No if overwrite_input or out is not None: raise ValueError("jax.numpy.quantile does not support overwrite_input=True " "or out != None") - # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 - if not isinstance(interpolation, DeprecatedArg): - raise TypeError("quantile() argument interpolation was removed in JAX" - " v0.8.0. Use method instead.") return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, False) -# TODO(jakevdp): interpolation argument deprecated 2024-05-16 + @export -@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False) -> Array: """Compute the quantile of the data along the specified axis, ignoring NaNs. JAX implementation of :func:`numpy.nanquantile`. @@ -2473,10 +2468,6 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") raise ValueError(msg) - # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 - if not isinstance(interpolation, DeprecatedArg): - raise TypeError("nanquantile() argument interpolation was removed in JAX" - " v0.8.0. Use method instead.") return _quantile(lax.asarray(a), lax.asarray(q), axis, method, keepdims, True) def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, @@ -2603,13 +2594,12 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, return lax.convert_element_type(result, a.dtype) -# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export -@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False) -> Array: """Compute the percentile of the data along the specified axis. JAX implementation of :func:`numpy.percentile`. @@ -2649,21 +2639,16 @@ def percentile(a: ArrayLike, q: ArrayLike, """ a, q = ensure_arraylike("percentile", a, q) q, = promote_dtypes_inexact(q) - # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 - if not isinstance(interpolation, DeprecatedArg): - raise TypeError("percentile() argument interpolation was removed in JAX" - " v0.8.0. Use method instead.") return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) -# TODO(jakevdp): interpolation argument deprecated 2024-05-16 @export -@api.jit(static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) +@api.jit(static_argnames=('axis', 'overwrite_input', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", - keepdims: bool = False, *, interpolation: DeprecatedArg = DeprecatedArg()) -> Array: + keepdims: bool = False) -> Array: """Compute the percentile of the data along the specified axis, ignoring NaN values. JAX implementation of :func:`numpy.nanpercentile`. @@ -2706,10 +2691,6 @@ def nanpercentile(a: ArrayLike, q: ArrayLike, a, q = ensure_arraylike("nanpercentile", a, q) q, = promote_dtypes_inexact(q) q = q / 100 - # TODO(jakevdp): remove the interpolation argument in JAX v0.9.0 - if not isinstance(interpolation, DeprecatedArg): - raise TypeError("nanpercentile() argument interpolation was removed in JAX" - " v0.8.0. Use method instead.") return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 8a863f68d5e7..b868c2fa4694 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -861,14 +861,6 @@ def np_fun(*args): tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) - @jtu.sample_product( - op=['quantile', 'nanquantile', 'percentile', 'nanpercentile'] - ) - def testQuantileDeprecatedArgs(self, op): - func = getattr(jnp, op) - with self.assertRaisesRegex(TypeError, rf"{op}\(\) argument interpolation"): - func(jnp.arange(4), 0.5, interpolation='linear') - @unittest.skipIf(not config.enable_x64.value, "test requires X64") @jtu.run_on_devices("cpu") # test is for CPU float64 precision def testPercentilePrecision(self): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 677f0555878d..87383be65577 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6314,16 +6314,16 @@ def testWrappedSignaturesMatch(self): 'frompyfunc': ['kwargs'], 'fromstring': ['like'], 'load': ['mmap_mode', 'allow_pickle', 'fix_imports', 'encoding', 'max_header_size'], - 'nanpercentile': ['weights'], - 'nanquantile': ['weights'], + 'nanpercentile': ['interpolation', 'weights'], + 'nanquantile': ['interpolation', 'weights'], 'nanstd': ['correction'], 'nanvar': ['correction'], 'ones': ['order', 'like'], 'ones_like': ['subok', 'order'], 'partition': ['kind', 'order'], - 'percentile': ['weights'], + 'percentile': ['interpolation', 'weights'], 'promote_types': ['type1', 'type2'], - 'quantile': ['weights'], + 'quantile': ['interpolation', 'weights'], 'row_stack': ['casting'], 'stack': ['casting'], 'tri': ['like'], From 1085d3a00ab2284350fc95b79a6c3ed238a7481a Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Fri, 19 Dec 2025 12:38:16 -0800 Subject: [PATCH 298/315] Remove pre-0.8.2 version guards. PiperOrigin-RevId: 846830895 --- jax/_src/core.py | 2 +- tests/lax_test.py | 2 - tests/pallas/indexing_test.py | 2 - tests/pallas/tpu_ops_test.py | 16 ---- tests/pallas/tpu_pallas_test.py | 85 ---------------------- tests/pallas/tpu_side_effects_test.py | 2 - tests/pallas/tpu_sparsecore_pallas_test.py | 31 -------- tests/pjit_test.py | 19 ----- tests/profiler_test.py | 4 - 9 files changed, 1 insertion(+), 162 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index 4d3e3435e583..d86dc6f7ce08 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -895,7 +895,7 @@ def _aval_property(name): return property(lambda self: getattr(self.aval, name)) -if TYPE_CHECKING or jaxlib_extension_version < 388: +if TYPE_CHECKING: # We want Python type checkers to accept `some_tracer: jax.Array`, even though # tracers can represent non-arrays. That is, ideally we would only accept that # annotation when the Tracer instance has a ShapedArray aval, but we can't diff --git a/tests/lax_test.py b/tests/lax_test.py index b7aed37f56be..ae025485aa7e 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -5008,8 +5008,6 @@ def test_ragged_dot_use_ragged_dot_instruction(self, use_instruction): {"m": 10, "k": 9, "n": 8, "num_groups": 2}, ) def test_ragged_dot_small_m(self, m, k, n, num_groups): - if not jtu.is_cloud_tpu_at_least(2025, 10, 14): - self.skipTest("Requires libtpu built after 2025-10-14") lhs_shape = (m, k) rhs_shape = (num_groups, k, n) group_sizes_shape = (num_groups,) diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 59aaa0dc4dbc..d057387e0b9b 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -543,8 +543,6 @@ def body(x_ref, y_ref1, y_ref2): @hp.given(hps.data()) def test_load_and_broadcast_with_stride_0(self, data): - if not jtu.is_cloud_tpu_at_least(2025, 11, 25): - self.skipTest("Requires libtpu built after 2025-11-25") if self.INTERPRET: self.skipTest("TODO: fails in interpret mode.") dtype = jnp.float32 diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index f4cc095f921f..52d92f78fb3f 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -194,8 +194,6 @@ def body(x_ref, o_ref): ) def test_sum_of_two_matmuls(self): - if not jtu.is_cloud_tpu_at_least(2025, 11, 15): - self.skipTest("Test requires libtpu from 2025/11/15 or later") if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Test requires TPUv5+") @@ -349,8 +347,6 @@ def kernel(x, out): keepdims=[False, True], ) def test_reduce_index(self, axis, in_shape, reduce_func, keepdims): - if not keepdims and not jtu.is_cloud_tpu_at_least(2025, 11, 24): - self.skipTest("Requires libtpu built after 2025-11-24") dtype = jnp.float32 rank = len(in_shape) if axis >= rank: @@ -389,8 +385,6 @@ def kernel(x, out): dtype=[jnp.float32, jnp.bfloat16], ) def test_i1_relayout_bw(self, shape, msk_dtype, dtype): - if shape[0] < 8 and not jtu.is_cloud_tpu_at_least(2025, 11, 9): - self.skipTest("Requires libtpu built after 2025-11-09") msk_bitwidth = dtypes.itemsize_bits(msk_dtype) bitwidth = dtypes.itemsize_bits(dtype) if jtu.get_tpu_version() < 5 and msk_bitwidth < 32: @@ -424,8 +418,6 @@ def kernel(x_ref, mask_ref, o_ref): ) def test_i1_relayout_bw_tiling(self, msk_dtype, dtype): self.skipTest("TODO: jevinjiang - Enable once presubmits pass.") - if not jtu.is_cloud_tpu_at_least(2025, 10, 7): - self.skipTest("Requires libtpu built after 2025-10-07") shape = (256, 256) bitwidth = dtypes.itemsize_bits(dtype) msk_bitwidth = dtypes.itemsize_bits(msk_dtype) @@ -708,8 +700,6 @@ def else_0(): self.assertEqual(output, 0) def test_retiling_with_replicated_lane(self): - if not jtu.is_cloud_tpu_at_least(2025, 11, 5): - self.skipTest("Test requires libtpu from 2025/11/5 or later") shape = (32, 1) broadcast_shape = (32, 256) @@ -733,8 +723,6 @@ def kernel(x_ref, o_ref): def test_stochastic_round(self, target_dtype): if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Requires TPU v5+") - if not jtu.is_cloud_tpu_at_least(2025, 10, 29): - self.skipTest("Test requires libtpu from 2025/10/29 or later") def kernel(x_ref, b_ref, o_ref): o_ref[...] = pltpu.stochastic_round( @@ -807,8 +795,6 @@ def test_pack_elementwise(self, config, shape): unpacked_dtype, packed_dtype = config if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Requires TPU v5+") - if not jtu.is_cloud_tpu_at_least(2025, 11, 7): - self.skipTest("Test requires libtpu from 2025/11/7 or later") bitwidth = dtypes.itemsize_bits(packed_dtype) num_sources = 32 // bitwidth @@ -842,8 +828,6 @@ def test_unpack_elementwise(self, config, index, shape): unpacked_dtype, packed_dtype = config if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Requires TPU v5+") - if not jtu.is_cloud_tpu_at_least(2025, 11, 7): - self.skipTest("Test requires libtpu from 2025/11/7 or later") bitwidth = dtypes.itemsize_bits(packed_dtype) packing_factor = 32 // bitwidth diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index b292cd3d409a..95e0030cb91d 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1928,12 +1928,6 @@ def reduce(): reduce_value = jnp.sum(jnp.full(shape, x), dtype=dty) np.testing.assert_allclose(z, reduce_value) - if not jtu.is_cloud_tpu_at_least(2025, 10, 12): - self.skipTest( - 'New CompilerParams shape_invariant_numerics was added on Oct 12,' - ' 2025' - ) - @jax.jit def reduce_with_shape_invariant_numerics(): return self.pallas_call( @@ -2039,12 +2033,6 @@ def reduce(x): expected = reduce_func(dilated_x, axis=reduced_dims).reshape(red_shape) np.testing.assert_allclose(y, expected) - if not jtu.is_cloud_tpu_at_least(2025, 10, 12): - self.skipTest( - 'New CompilerParams shape_invariant_numerics was added on Oct 12,' - ' 2025' - ) - @jax.jit def reduce_with_shape_invariant_numerics(x): return self.pallas_call( @@ -2169,8 +2157,6 @@ def kernel(x_ref, y_ref): pl.Buffered(2), ]) def test_vmem_oom_error_message_basics(self, pmode: pl.Buffered): - if not jtu.is_cloud_tpu_at_least(2025, 11, 12): - self.skipTest('Support added on Nov 12, 2025') if jtu.is_device_tpu(version=5, variant='e') or jtu.is_device_tpu( version=6, variant='e' @@ -2232,8 +2218,6 @@ def test_vmem_oom_error_message_dynamic_grid_scalar_prefetch_and_vmem_scratch( ): if jax.device_count() > 1: self.skipTest("Test only works with a single device.") - if not jtu.is_cloud_tpu_at_least(2025, 10, 14): - self.skipTest('Support added on Oct 14, 2025') def body(s_ref, x_hbm_ref, o_hbm_ref, vmem_scratch_ref): del s_ref, vmem_scratch_ref @@ -2283,8 +2267,6 @@ def run(num_grid, s, x): def test_automatic_single_buffering(self,): if self.INTERPRET: self.skipTest('OOM tests need us to compile the kernels') - if not jtu.is_cloud_tpu_at_least(2025, 11, 12): - self.skipTest('Support added on Oct 14, 2025') def body(*_): pass # We only want to compile the kernel. @@ -2577,9 +2559,6 @@ def test_scalar_integer_addition(self, dtype): def kernel(x_ref, y_ref): y_ref[0] = x_ref[0] + x_ref[0] - if not jtu.is_cloud_tpu_at_least(2025, 9, 13): - self.skipTest('Scalar integer addition support was added on Sep 13, 2025') - x = jnp.asarray([3], dtype=dtype) if dtype in [jnp.int32, jnp.uint32]: @@ -2616,9 +2595,6 @@ def test_vector_integer_addition(self, dtype): def kernel(x_ref, y_ref): y_ref[...] = x_ref[...] + x_ref[...] - if not jtu.is_cloud_tpu_at_least(2025, 9, 15): - self.skipTest('Descriptive message was added on Sep 15, 2025') - x = jnp.full((128, 16), 7, dtype=dtype) if dtype in [jnp.int32, jnp.uint32, jnp.int16, jnp.uint16]: @@ -2650,9 +2626,6 @@ def test_max_operation(self, dtype): def kernel(x_ref, y_ref): y_ref[0] = jnp.maximum(x_ref[0], x_ref[1]) - if not jtu.is_cloud_tpu_at_least(2025, 9, 20): - self.skipTest('Support added on Sep 20, 2025') - x = jnp.asarray([242, 87], dtype=dtype) y = pl.pallas_call( @@ -2672,9 +2645,6 @@ def test_min_operation(self, dtype): def kernel(x_ref, y_ref): y_ref[0] = jnp.minimum(x_ref[0], x_ref[1]) - if not jtu.is_cloud_tpu_at_least(2025, 9, 20): - self.skipTest('Support added on Sep 20, 2025') - x = jnp.asarray([242, 87], dtype=dtype) y = pl.pallas_call( @@ -2701,9 +2671,6 @@ def test_bool_select_operation(self, dtype): def kernel(condlist, choicelist, out_ref): out_ref[...] = jnp.where(condlist[...], choicelist[...], 0) - if not jtu.is_cloud_tpu_at_least(2025, 10, 15): - self.skipTest('Support added on Oct 15, 2025') - if dtype in [jnp.int4, jnp.uint4] and not jtu.is_device_tpu_at_least(4): self.skipTest('i4 is not supported on TPU generations <= 3') @@ -2741,9 +2708,6 @@ def wrapper(*args, **kwargs): def _integer_ops_canonicalization_helper(self, kernel, result, dtype): """For integer scalar ops, only i1 and i32 are supported.""" - if not jtu.is_cloud_tpu_at_least(2025, 9, 27): - self.skipTest('Error message was changed on Sep 27, 2025') - x = jnp.arange(3, dtype=dtype) if dtype in [jnp.int32, jnp.uint32]: @@ -3502,8 +3466,6 @@ class MiscellaneousTest(ptu.PallasTPUTest): def test_casting_bool_to_i8(self): if not jtu.is_device_tpu_at_least(5): self.skipTest("Operation not supported on this TPU version.") - if not jtu.is_cloud_tpu_at_least(2025, 9, 12): - self.skipTest("Needs a newer libtpu") def greater_than(x: jax.Array, y: jax.Array): def kernel(x_ref, y_ref, out_ref): @@ -3545,8 +3507,6 @@ def kernel(x_ref, y_ref, out_ref): np.testing.assert_array_equal(out, np.stack([x, y], axis=1)) def test_lane_to_chunk_reshape_bf16(self): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if not jtu.is_device_tpu_at_least(4): self.skipTest('Operation not supported on this TPU version.') x = np.arange(256 * 1024, dtype=jnp.bfloat16).reshape(1, 256, 1024) @@ -3617,12 +3577,6 @@ def kernel(x_ref, out_ref): def test_roll_partial_with_static_shift( self, shape: tuple[int, int], shift: int, axis: int ): - if ( - not jtu.is_cloud_tpu_at_least(2025, 7, 19) - and shape[0] % 8 - and axis == 0 - ): - self.skipTest('Needs a newer libtpu for non-sublane-aligned shape') x = np.arange(math.prod(shape), dtype=jnp.float32).reshape(shape) def kernel(x_ref, out_ref): @@ -3654,8 +3608,6 @@ def kernel(x_ref, out_ref): )(x) def test_retiling1(self): - if not jtu.is_cloud_tpu_at_least(2025, 7, 2): - self.skipTest('Needs a newer libtpu') x = np.arange(1024, dtype=jnp.bfloat16).reshape(1024) def kernel(x_ref, out_ref): @@ -3684,8 +3636,6 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x[:, 7, :], (1, 8, 128))) def test_sublane_adding_shape_cast_f32(self): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') x = np.arange(8 * 128, dtype=jnp.float32).reshape(8, 128) def kernel(x_ref, out_ref): @@ -3698,8 +3648,6 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128))) def test_sublane_adding_shape_cast_bf16(self): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if not jtu.is_device_tpu_at_least(4): self.skipTest('Operation not supported on this TPU version.') x = np.arange(8 * 128, dtype=jnp.bfloat16).reshape(8, 128) @@ -3742,9 +3690,6 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.zeros((8, 2, 128), dtype=jnp.float32)) def test_transpose(self): - if not jtu.is_cloud_tpu_at_least(2025, 9, 19): - self.skipTest('Needs a newer libTPU') - x = np.zeros((8, 2, 8, 128), dtype=jnp.float32) def kernel(x_ref, out_ref): @@ -3763,9 +3708,6 @@ def kernel(x_ref, out_ref): (5, 1, 4096, jnp.int8), ) def test_1d_tiling_major_minor_transpose(self, q, m, n, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 12, 10): - self.skipTest('Needs a newer libTPU') - in_shape = (q, n) mid_shape = (q, m, n) out_shape = (m, q, n) @@ -3799,8 +3741,6 @@ def kernel(x_ref, o_ref): ) ) def test_reshape_two_minor_dims_to_R2(self, q, m, n, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -3835,8 +3775,6 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_two_minor_dims_to_R3(self, q, m, n, k, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -3941,8 +3879,6 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_four_minor_dims_to_R2(self, p, q, m, n, k, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -3971,8 +3907,6 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_two_minor_dims_preserve_rank(self, q, m, n, k, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -4013,8 +3947,6 @@ def kernel(x_ref, y_ref): def test_reshape_fold_two_leading_dims_and_two_minor_dims_R4_to_R2( self, q, m, n, k, dtype ): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -4046,8 +3978,6 @@ def kernel(x_ref, y_ref): def test_reshape_unfold_leading_dim_and_fold_two_minor_dims_R3_to_R3( self, q, m, n, k, dtype ): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -4081,8 +4011,6 @@ def kernel(x_ref, y_ref): def test_reshape_unfold_leading_and_minor_dims_R2_to_R4( self, q, m, n, k, dtype ): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -4112,8 +4040,6 @@ def kernel(x_ref, y_ref): def test_reshape_fold_leading_dims_and_unfold_minor_dim( self, q, m, n, k, dtype ): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -4141,8 +4067,6 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_fold_middle_dims(self, q, m, n, k, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -4170,8 +4094,6 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_unfold_middle_dims(self, q, m, n, k, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -4188,8 +4110,6 @@ def kernel(x_ref, y_ref): @parameterized.parameters([jnp.int8, jnp.bfloat16, jnp.float32]) def test_reshape_shift_factor_from_minor_to_major(self, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 7, 12): - self.skipTest('Needs a newer libTPU') if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) ): @@ -4210,9 +4130,6 @@ def kernel(x_ref, y_ref): dtype=[jnp.float32, jnp.bfloat16, jnp.float8_e4m3fn], ) def test_reshape_fold_minormost_dim(self, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 10, 22): - self.skipTest('Needs a newer libTPU') - packing = 32 // (8 * np.dtype(dtype).itemsize) in_shape = (8 * packing, 128) out_shape = (1, math.prod(in_shape)) @@ -4231,8 +4148,6 @@ def kernel(x_ref, y_ref): def test_dynamic_grid_with_smem_output(self): if self.INTERPRET: self.skipTest('Fail on interpreter.') - if not jtu.is_cloud_tpu_at_least(2025, 11, 3): - self.skipTest('Needs a newer libTPU') def body(_, o_ref): o_ref[0] = lax.cond( diff --git a/tests/pallas/tpu_side_effects_test.py b/tests/pallas/tpu_side_effects_test.py index c5109c66605f..ad650921c889 100644 --- a/tests/pallas/tpu_side_effects_test.py +++ b/tests/pallas/tpu_side_effects_test.py @@ -30,8 +30,6 @@ def setUp(self): super().setUp() if not jtu.is_device_tpu(): self.skipTest("TPU required") - if not jtu.is_cloud_tpu_at_least(2025, 11, 11): - self.skipTest("Newer libtpu required") @parameterized.named_parameters( ("pure", pltpu.SideEffectType.PURE), diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index 5f5c1726ffba..e79805fa726f 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -531,8 +531,6 @@ def kernel(x_hbm_ref, indices_ref, o_ref): def test_gather_1d_with_dynamically_sized_2d_ref(self): self.skip_if_tc_tiling() - if not jtu.is_cloud_tpu_at_least(2025, 10, 22): - self.skipTest("Needs a newer libtpu") x = jnp.arange(16) indices = jax.random.permutation( @@ -752,8 +750,6 @@ def kernel(x_ref, indices_ref, o_ref): ) def test_store_scatter_2d(self): - if not jtu.is_cloud_tpu_at_least(2025, 10, 31): - self.skipTest("Needs a newer libtpu") num_steps = 4 x = jnp.arange(num_steps * 8).reshape(num_steps, 8) @@ -1075,8 +1071,6 @@ def scoped_kernel(scratch_ref): @parameterized.product(sizes=[[1, 1], [2, 2], [1, 1, 1, 1]]) def test_split_concatenate(self, sizes): - if not jtu.is_cloud_tpu_at_least(2025, 10, 26): - self.skipTest("Test requires a newer libtpu") shape = (sum(sizes), 8) x = jnp.arange(math.prod(shape)).reshape(-1, 8) @@ -1366,9 +1360,6 @@ def kernel(x_ref, o_ref): kernel(x) def test_multiple_of(self): - if not jtu.is_cloud_tpu_at_least(2025, 10, 16): - self.skipTest("Test requires a newer libtpu") - x = jnp.arange(16) @self.vector_subcore_kernel(out_shape=x) @@ -1408,9 +1399,6 @@ def _(i): np.testing.assert_array_equal(kernel(), expected) def test_barrier_via_pallas_call(self): - if not jtu.is_cloud_tpu_at_least(2025, 11, 22): - self.skipTest("Test requires a newer libtpu") - self.skip_if_tc_tiling() mesh = plsc.VectorSubcoreMesh( @@ -1605,9 +1593,6 @@ def kernel(in_ref, o_ref, scratch_ref): ("exp", jnp.exp), ("neg", lambda x: -x), ("abs", jnp.abs) ) def test_unary_ops(self, op): - if not jtu.is_cloud_tpu_at_least(2025, 11, 30): - self.skipTest("Test requires a newer libtpu") - x = jnp.arange(8, dtype=jnp.float32) @self.vector_subcore_kernel(out_shape=x) @@ -1618,9 +1603,6 @@ def kernel(x_ref, o_ref): @parameterized.product(dtype=[np.int32, np.float32]) def test_vector_gather(self, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 12, 2): - self.skipTest("Test requires a newer libtpu") - vec_dim = self.sc_info.num_lanes x = np.arange(vec_dim, dtype=dtype) indices = np.random.randint(0, vec_dim, size=vec_dim, dtype=np.int32) @@ -1640,9 +1622,6 @@ def kernel(x_ref, indices_ref, out_ref): descending=[False, True], ) def test_sort_key_val(self, keys_dtype, values_dtype, use_mask, descending): - if not jtu.is_cloud_tpu_at_least(2025, 12, 2): - self.skipTest("Test requires a newer libtpu") - vec_dim = self.sc_info.num_lanes keys = np.arange(vec_dim, dtype=keys_dtype) np.random.shuffle(keys) @@ -1685,9 +1664,6 @@ def kernel(*args): @parameterized.product(dtype=[np.int32, np.float32]) def test_rev_and_sort_desc(self, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 12, 2): - self.skipTest("Test requires a newer libtpu") - vec_dim = self.sc_info.num_lanes keys = np.arange(vec_dim, dtype=dtype) np.random.shuffle(keys) @@ -1707,9 +1683,6 @@ def kernel(x_ref, o1_ref, o2_ref): values_dtypes=[(), (np.int32,), (np.float32, np.int32)], ) def test_sort(self, keys_dtype, values_dtypes): - if not jtu.is_cloud_tpu_at_least(2025, 11, 30): - self.skipTest("Test requires a newer libtpu") - vec_dim = self.sc_info.num_lanes keys = np.arange(vec_dim, dtype=keys_dtype) np.random.shuffle(keys) @@ -1906,8 +1879,6 @@ class PallasSparsecoreAsyncTest(PallasSCTest): def setUp(self): super().setUp() - if not jtu.is_cloud_tpu_at_least(2025, 12, 14): - self.skipTest("Needs a newer libtpu") @parameterized.product( shape=[ @@ -1925,8 +1896,6 @@ def setUp(self): dtype=[jnp.int32, jnp.float32, jnp.bfloat16], ) def test_basic_async_kernel(self, shape, dtype): - if not jtu.is_cloud_tpu_at_least(2025, 12, 8): - self.skipTest("Need newer libtpu") x = jnp.arange(shape[0] * shape[1], dtype=dtype).reshape(shape) @jax.jit diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 03655c64903e..74e31da86864 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -62,7 +62,6 @@ from jax._src.mesh import AxisType from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc -from jax._src.lib import ifrt_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -9620,8 +9619,6 @@ def body(c, _): @jtu.with_explicit_mesh((2,), ('x',)) def test_reduced_sin_fwd_mul_bwd(self, mesh): - if not jtu.is_cloud_tpu_at_least(2025, 11, 7): - self.skipTest('Requires libtpu built after 2025-11-6') np_inp1 = np.arange(8.).reshape(4, 2) np_inp2 = np.arange(16.).reshape(2, 8) @@ -9748,10 +9745,6 @@ def f(x, y): ) @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_sharded_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): - if ifrt_version < 40: - self.skipTest('Requires ifrt_version >= 40') - if not jtu.is_cloud_tpu_at_least(2025, 12, 15): - self.skipTest('Requires libtpu built after 2025-12-15') np1 = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(np1, orig_spec) @@ -9768,10 +9761,6 @@ def test_sharded_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): ) @jtu.with_explicit_mesh((2,), 'x') def test_one_input_sharded_another_reduced(self, func, mesh): - if ifrt_version < 40: - self.skipTest('Requires ifrt_version >= 40') - if not jtu.is_cloud_tpu_at_least(2025, 12, 15): - self.skipTest('Requires libtpu built after 2025-12-15') np1 = np.arange(8.) arr1 = jax.device_put(np1, P('x')) arr2 = jax.device_put(np1, P(None, reduced={'x'})) @@ -9802,10 +9791,6 @@ def g(x, y): @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduced_reshard_unreduced_bwd(self, mesh): - if ifrt_version < 40: - self.skipTest('Requires ifrt_version >= 40') - if not jtu.is_cloud_tpu_at_least(2025, 12, 15): - self.skipTest('Requires libtpu built after 2025-12-15') np1 = np.arange(4.) arr = jax.device_put(np1, P(None, reduced={'x'})) @@ -9833,10 +9818,6 @@ def g(x): @jtu.with_explicit_mesh((2, 2), ('x', 'y')) def test_reduced_reshard_unreduced_bwd_sharded(self, mesh): - if ifrt_version < 40: - self.skipTest('Requires ifrt_version >= 40') - if not jtu.is_cloud_tpu_at_least(2025, 12, 15): - self.skipTest('Requires libtpu built after 2025-12-15') np1 = np.arange(8.).reshape(4, 2) arr = jax.device_put(np1, P('x', None, reduced={'y'})) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 0db0e2dffab7..a803e010859f 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -32,7 +32,6 @@ import jax._src.test_util as jtu from jax._src import profiler -from jax._src.lib import ifrt_version from jax import jit @@ -510,9 +509,6 @@ def on_profile(): ) def test_advanced_configuration_getter(self): - if ifrt_version < 41: - self.skipTest("advanced_configuration getter is newly added") - options = jax.profiler.ProfileOptions() advanced_config = { "tpu_trace_mode": "TRACE_COMPUTE", From 82f02eb97638478aa660f0eab2d2dc0af78dc162 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 19 Dec 2025 12:39:55 -0800 Subject: [PATCH 299/315] [dep] Remove deprecated jax.lib submodules for JAX v0.9.0. The following modules are removed: - `jax.lib.xla_bridge` - `jax.lib.xla_client` - `jax.lib.xla_extension` All contents of these submodules were deprecated and removed as of JAX v0.8.0; the modules themselves have been raising warnings on import since this release. PiperOrigin-RevId: 846831442 --- jax/_src/deprecations.py | 1 - jax/lib/__init__.py | 8 -------- jax/lib/xla_bridge.py | 25 ------------------------- jax/lib/xla_client.py | 25 ------------------------- jax/lib/xla_extension.py | 25 ------------------------- 5 files changed, 84 deletions(-) delete mode 100644 jax/lib/xla_bridge.py delete mode 100644 jax/lib/xla_client.py delete mode 100644 jax/lib/xla_extension.py diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 6ad6f6dfaf44..39894348e557 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -125,7 +125,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: # always registered by the time `accelerate` and `is_acelerated` are called. register('default-dtype-bits-config') register('jax-checkpoint-concrete') -register('jax-lib-module') register('jax-nn-one-hot-float-input') register('jax-numpy-arange-complex') register('jax-numpy-astype-complex-to-real') diff --git a/jax/lib/__init__.py b/jax/lib/__init__.py index 989bcc944067..46b3668e0fdf 100644 --- a/jax/lib/__init__.py +++ b/jax/lib/__init__.py @@ -16,11 +16,3 @@ from jax._src.lib import ( version_str as __version__, ) - -# Dynamically load submodules because they warn on import. -# TODO(jakevdp): remove this in JAX v0.9.0. -def __getattr__(attr): - if attr in {'xla_bridge', 'xla_client', 'xla_extension'}: - import importlib - return importlib.import_module(f'jax.lib.{attr}') - raise AttributeError(f"module 'jax.lib' has no attribute {attr!r}") diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py deleted file mode 100644 index 9dc9d269ef05..000000000000 --- a/jax/lib/xla_bridge.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax._src import deprecations as _deps - -_deps.warn( - 'jax-lib-module', - ( - 'jax.lib.xla_bridge module will be removed in JAX v0.9.0;' - ' all its APIs were deprecated and removed by JAX v0.8.0.' - ), - stacklevel=4 -) -del _deps diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py deleted file mode 100644 index 7cc1fb88ab15..000000000000 --- a/jax/lib/xla_client.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax._src import deprecations as _deps - -_deps.warn( - 'jax-lib-module', - ( - 'jax.lib.xla_client module will be removed in JAX v0.9.0;' - ' all its APIs were deprecated and removed by JAX v0.8.0.' - ), - stacklevel=4 -) -del _deps diff --git a/jax/lib/xla_extension.py b/jax/lib/xla_extension.py deleted file mode 100644 index 3c0f2fd5a3e1..000000000000 --- a/jax/lib/xla_extension.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from jax._src import deprecations as _deps - -_deps.warn( - 'jax-lib-module', - ( - 'jax.lib.xla_extension module will be removed in JAX v0.9.0;' - ' all its APIs were deprecated and removed by JAX v0.8.0.' - ), - stacklevel=4 -) -del _deps From ec5015fcd732d2110b997d7f291843dc1bf0899f Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 19 Dec 2025 16:53:43 +0000 Subject: [PATCH 300/315] [hijax] upgrade VJPHiPrimitive.def_fwd to support symbolic zeros --- jax/_src/BUILD | 1 + jax/_src/hijax.py | 50 ++++++++++++------- jax/_src/interpreters/partial_eval.py | 4 +- jax/_src/tree_util.py | 7 +++ tests/hijax_test.py | 71 +++++++++++++++++++++++++-- 5 files changed, 108 insertions(+), 25 deletions(-) diff --git a/jax/_src/BUILD b/jax/_src/BUILD index f514c931c3bd..1858a7e684b1 100644 --- a/jax/_src/BUILD +++ b/jax/_src/BUILD @@ -1189,6 +1189,7 @@ pytype_strict_library( ":effects", ":tree_util", ":util", + ":partial_eval", ], ) diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py index fcb9b9b8f9f3..4eaca75131a1 100644 --- a/jax/_src/hijax.py +++ b/jax/_src/hijax.py @@ -24,9 +24,12 @@ from jax._src import effects from jax._src.interpreters import ad from jax._src.interpreters import batching +from jax._src.interpreters import partial_eval as pe from jax._src import ad_util from jax._src.util import safe_zip, safe_map, split_list -from jax._src.tree_util import tree_flatten, tree_unflatten, tree_leaves, tree_map +from jax._src.tree_util import ( + tree_map, tree_flatten, tree_unflatten, tree_leaves, tree_leaves_checked, + broadcast_prefix) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -361,7 +364,7 @@ def __init__(self): def expand(self, *args): raise NotImplementedError(f"subclass {type(self)} must implement `expand`") - def vjp_fwd(self, *args): + def vjp_fwd(self, nzs_in, *args): raise NotImplementedError(f"for grad support, subclass {type(self)} must " "implement `vjp_fwd`") @@ -388,8 +391,13 @@ def __call__(self, *args): return tree_unflatten(self.out_tree, ans_flat) def check(self, *arg_tys): - # subclass can optionally override this to add checking logic - return + return # subclass can optionally override this to add checking logic + + def staging(self, trace, source_info, *args): + args_flat = tree_leaves_checked(self.in_tree, args) + ans_flat = trace.default_process_primitive( + call_hi_primitive_p, args_flat, dict(prim=self), source_info) + return tree_unflatten(self.out_tree, ans_flat) def __repr__(self): return f"{self.__class__.__name__}[{self.params}]" @@ -400,11 +408,6 @@ def __hash__(self): def __eq__(self, other): return type(self) is type(other) and self.params == other.params -def tree_leaves_checked(treedef_expected, tree): - flat_vals, treedef_actual = tree_flatten(tree) - assert treedef_actual == treedef_expected - return flat_vals - call_hi_primitive_p = core.Primitive("call_hi_primitive") call_hi_primitive_p.multiple_results = True call_hi_primitive_p.is_high = lambda *args, prim: True # type: ignore @@ -412,9 +415,17 @@ def tree_leaves_checked(treedef_expected, tree): def _call_hi_primitive_abstract_eval(*_args, prim): return prim.out_avals_flat +def _call_hi_primitive_staging(trace, source_info, *args_flat, prim): + trace.frame.is_high = True + args = tree_unflatten(prim.in_tree, args_flat) + ans = prim.staging(trace, source_info, *args) + return tree_leaves_checked(prim.out_tree, ans) +pe.custom_staging_rules[call_hi_primitive_p] = _call_hi_primitive_staging + def _call_hi_primitive_to_lojax(*args_flat, prim): args = tree_unflatten(prim.in_tree, args_flat) - return tree_leaves_checked(prim.out_tree, prim.expand(*args)) + ans = prim.expand(*args) + return tree_leaves_checked(prim.out_tree, ans) call_hi_primitive_p.to_lojax = _call_hi_primitive_to_lojax def _call_hi_primitive_batcher(axis_data, args_flat, dims_flat, prim): @@ -428,20 +439,21 @@ def _call_hi_primitive_batcher(axis_data, args_flat, dims_flat, prim): def _call_hi_primitive_linearize(nz_in_flat, *args_flat, prim): args = tree_unflatten(prim.in_tree, args_flat) - ans, residuals = prim.vjp_fwd(*args) - # TODO(dougalm): does the fwd/bwd API force us to assume the nzs_out are all False - # (except in the case that all the nzs_in are True, which is handled in - # LinearizeTrace.ProcessPrimitive)? + nzs_in = tree_unflatten(prim.in_tree, nz_in_flat) + ans, residuals, *maybe_nzs_out = prim.vjp_fwd(nzs_in, *args) ans_flat = tree_leaves_checked(prim.out_tree, ans) - nzs_out = [True for _ in ans_flat] - return (ans_flat, nzs_out, residuals, partial(fake_linear_op, prim, nz_in_flat)) + nzs_out = True if maybe_nzs_out == [] else maybe_nzs_out[0] + nzs_out_flat = broadcast_prefix(nzs_out, ans) + linearized = partial(fake_linear_op, prim, nz_in_flat) + return (ans_flat, nzs_out_flat, residuals, linearized) def fake_linear_op(prim, nz_in_flat, rs, *tangents): residuals_flat, residuals_tree = tree_flatten(rs) - tangents_flat, _ = tree_flatten(tangents) # prune symbolic zeros + assert nz_in_flat == [not isinstance(t, ad_util.Zero) for t in tangents] + nz_tangents = tree_leaves(tangents) return call_hi_primitive_linearized_p.bind( - *residuals_flat, *tangents_flat, - residuals_tree=residuals_tree, nz_in_flat=tuple(nz_in_flat), prim=prim) + *residuals_flat, *nz_tangents, residuals_tree=residuals_tree, prim=prim, + nz_in_flat=tuple(nz_in_flat)) ad.primitive_linearizations[call_hi_primitive_p] = _call_hi_primitive_linearize diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ef6b803de368..7fcd9f3645d8 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2034,8 +2034,8 @@ def default_process_primitive(self, primitive, tracers, params, # TODO(mattjj,dougalm): clean up how we check for new-style hi primitives if primitive is call_hi_primitive_p: out_avals, effs = params['prim'].out_avals_flat, set() # TODO effs - elif (primitive.name == "custom_lin" or - primitive.is_effectful and primitive.is_effectful(params)): + elif (primitive.name in ("custom_lin", "call_hi_primitive_linearized") or + primitive.is_effectful and primitive.is_effectful(params)): out_avals, effs = primitive.abstract_eval(*aval_qdds, **params) else: try: diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 5fc4c4018b9f..ad8ac40ff836 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -93,6 +93,13 @@ def tree_leaves(tree: Any, return default_registry.flatten(tree, is_leaf)[0] +@export +def tree_leaves_checked(treedef_expected: PyTreeDef, tree: Any) -> list[Leaf]: + flat_vals, treedef_actual = tree_flatten(tree) + assert treedef_actual == treedef_expected + return flat_vals + + @export def tree_structure(tree: Any, is_leaf: None | (Callable[[Any], diff --git a/tests/hijax_test.py b/tests/hijax_test.py index bd1e199864d2..b675c3db67c3 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -522,7 +522,7 @@ def __init__(self, in_aval, *, power): def expand(self, x): return x ** self.power - def vjp_fwd(self, x): + def vjp_fwd(self, nzs_in, x): ans = self(x) return (ans, x) @@ -570,7 +570,7 @@ def __init__(self, in_aval, *, power): def expand(self, x): return x ** self.power - def vjp_fwd(self, x): + def vjp_fwd(self, nzs_in, x): ans = self(x) return (ans, x) @@ -671,7 +671,7 @@ def expand(self, x): qvalue = jnp.round(x / scale).astype(jnp.int8) return QArray(qvalue, scale) - def vjp_fwd(self, x): + def vjp_fwd(self, nzs_in, x): return self(x), None def vjp_bwd_retval(self, _, g): @@ -688,7 +688,7 @@ def __init__(self, quantized_aval): def expand(self, qx): return qx.qvalue * qx.scale - def vjp_fwd(self, qx): + def vjp_fwd(self, nzs_in, qx): return self(qx), None def vjp_bwd_retval(self, _, g): @@ -700,6 +700,69 @@ def f(x): x = jax.random.normal(jax.random.key(0), (3, 3), dtype='float32') g = jax.grad(f)(x) + def test_symbolic_zeros(self): + + class Mul(VJPHiPrimitive): + def __init__(self, aval): + self.in_avals = (aval, aval) + self.out_aval = aval + self.params = {} + super().__init__() + + def expand(self, x, y): + return x * y + + def vjp_fwd(self, nzs_in, x, y): + assert list(nzs_in) == list(nzs_in_) # defined below + ans = self(x, y) + return ans, (x, y) + + def vjp_bwd(self, res, g, x_acc, y_acc): + assert list(nzs_in_) == [not isinstance(x_acc, ad.NullAccum), + not isinstance(y_acc, ad.NullAccum)] + x, y = res + x_acc.accum(g * y) + y_acc.accum(x * g) + + def mul(x, y): + return Mul(typeof(x))(x, y) + + nzs_in_ = (True, False) + self.assertAllClose(jax.grad(mul)(2., 3.), 3., check_dtypes=False) + + nzs_in_ = (False, True) + self.assertAllClose(jax.grad(mul, 1)(2., 3.), 2., check_dtypes=False) + + def test_symbolic_zeros_retval(self): + + class Mul(VJPHiPrimitive): + def __init__(self, aval): + self.in_avals = (aval, aval) + self.out_aval = aval + self.params = {} + super().__init__() + + def expand(self, x, y): + return x * y + + def vjp_fwd(self, nzs_in, x, y): + assert list(nzs_in) == list(nzs_in_) # defined below + ans = self(x, y) + return ans, (x, y) + + def vjp_bwd_retval(self, res, g): + x, y = res + return (g * y, x * g) + + def mul(x, y): + return Mul(typeof(x))(x, y) + + nzs_in_ = (True, False) + self.assertAllClose(jax.grad(mul)(2., 3.), 3., check_dtypes=False) + + nzs_in_ = (False, True) + self.assertAllClose(jax.grad(mul, 1)(2., 3.), 2., check_dtypes=False) + class BoxTest(jtu.JaxTestCase): From 4f0f93eba331aa326f414b569fa01dabddad55f6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Fri, 19 Dec 2025 14:02:02 -0800 Subject: [PATCH 301/315] Integrate LLVM at llvm/llvm-project@7d381f2a5634 Updates LLVM usage to match [7d381f2a5634](https://github.com/llvm/llvm-project/commit/7d381f2a5634) PiperOrigin-RevId: 846858892 --- jax/_src/interpreters/mlir.py | 17 ++++++++++------- jax/experimental/mosaic/gpu/layout_inference.py | 12 ++++++------ jax/experimental/mosaic/gpu/utils.py | 2 +- jaxlib/mosaic/gpu/gpu_module_to_assembly.cc | 5 +++-- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index b501d7bdb298..7af9770793ca 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2291,6 +2291,9 @@ def _platforms_for_eqn(ctx: LoweringRuleContext) -> tuple[str, ...]: return tuple(_platforms_for_eqn_ctx(ctx.jaxpr_eqn_ctx) or ctx.platforms or ctx.module_context.platforms) +def _get_owner(v: ir.Value): + owner = v.owner + return owner.operation if isinstance(owner, ir.OpView) else owner def lower_per_platform(ctx: LoweringRuleContext, description: str, @@ -2367,11 +2370,11 @@ def lower_per_platform(ctx: LoweringRuleContext, if len(kept_rules) == 1: output = kept_rules[0](ctx, *rule_args, **rule_kwargs) foreach( - lambda o: wrap_compute_type_in_place(ctx, o.owner), + lambda o: wrap_compute_type_in_place(ctx, _get_owner(o)), filter(_is_not_block_argument, flatten_ir_values(output)), ) foreach( - lambda o: wrap_xla_metadata_in_place(ctx, o.owner), + lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)), flatten_ir_values(output), ) return output @@ -2412,11 +2415,11 @@ def lower_per_platform(ctx: LoweringRuleContext, raise ValueError("Output of translation rule must be iterable: " f"{description}, got output {output}") from e foreach( - lambda o: wrap_compute_type_in_place(ctx, o.owner), + lambda o: wrap_compute_type_in_place(ctx, _get_owner(o)), filter(_is_not_block_argument, out_nodes), ) foreach( - lambda o: wrap_xla_metadata_in_place(ctx, o.owner), + lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)), out_nodes, ) if inner_ctx.tokens_out is not None: @@ -2610,11 +2613,11 @@ def wrap_compute_type_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> No "_xla_stream_annotation": ir.StringAttr.get(stream), "inlineable": ir.StringAttr.get("false"), } - op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) + op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) else: dict_attr = {"_xla_compute_type": ir.StringAttr.get( map_compute_type(ctx.jaxpr_eqn_ctx.compute_type))} - op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) + op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) def wrap_xla_metadata_in_place(ctx: LoweringRuleContext, op: ir.Operation) -> None: @@ -2659,7 +2662,7 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, out = hlo.broadcast_in_dim( aval_to_ir_type(aval_out), op, dense_int_array(broadcast_dimensions)) - wrap_compute_type_in_place(ctx, out.owner) + wrap_compute_type_in_place(ctx, _get_owner(out)) return out def multi_broadcast_in_dim(ctx: LoweringRuleContext, diff --git a/jax/experimental/mosaic/gpu/layout_inference.py b/jax/experimental/mosaic/gpu/layout_inference.py index 22c805fa0f59..736a6b3cfd9b 100644 --- a/jax/experimental/mosaic/gpu/layout_inference.py +++ b/jax/experimental/mosaic/gpu/layout_inference.py @@ -187,7 +187,7 @@ def _default_tmem_layout_for_variable( ) -> tcgen05.TMEMLayout | None: """Returns a default TMEM layout for the given variable, if one is defined.""" value = variable.key.value - parent = value.owner.opview + parent = value.owner if isinstance(parent, mgpu.TmemAllocOp): return tcgen05._infer_tmem_layout( tuple(value.type.shape), parent.collective, packing=1 @@ -768,8 +768,8 @@ def dynamic_gcd(a: int, b: ir.Value) -> int: raise ValueError("a must be strictly positive") if not ir.IntegerType.isinstance(b.type) and not ir.IndexType.isinstance(b.type): raise ValueError(f"Expected an integer dynamic value, got a {b.type}") - if isinstance(b.owner, ir.Operation) and isinstance(b.owner.opview, arith.ConstantOp): - return math.gcd(a, b.owner.opview.literal_value) + if isinstance(b.owner, arith.ConstantOp): + return math.gcd(a, b.owner.literal_value) running_gcd = 1 for factor in prime_decomposition(a): if utils.is_known_divisible(b, running_gcd * factor): @@ -1804,9 +1804,9 @@ def producer_result(operand: ValueSite) -> ValueSite: assert operand.type == VariableType.OPERAND value = operand.value producer = value.owner - if isinstance(producer, ir.Operation): + if isinstance(producer, ir.OpView): index = list(producer.results).index(value) - return ValueSite(producer.opview, VariableType.RESULT, index) + return ValueSite(producer, VariableType.RESULT, index) if isinstance(producer, ir.Block): index = list(producer.arguments).index(value) @@ -1825,7 +1825,7 @@ def consumer_operands(result: ValueSite) -> Sequence[ValueSite]: # The layout can also be chosen from the layout of the consumers of the # results. for use in result.value.uses: - consumer = use.owner.opview # pytype: disable=attribute-error + consumer = use.owner index = use.operand_number consumer_operands.append(ValueSite(consumer, VariableType.OPERAND, index)) return consumer_operands diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index c06555242740..94f23066a88a 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1854,7 +1854,7 @@ def is_known_divisible(value, divisor, max_depth=10) -> bool: """Returns True if the value is statically known to be divisible by the divisor.""" if divisor == 1: return True - if max_depth < 0 or not isinstance(value.owner, ir.Operation): + if max_depth < 0 or not isinstance(value.owner, ir.OpView): return False new_depth = max_depth - 1 diff --git a/jaxlib/mosaic/gpu/gpu_module_to_assembly.cc b/jaxlib/mosaic/gpu/gpu_module_to_assembly.cc index 140066320f4d..dd83e06b9523 100644 --- a/jaxlib/mosaic/gpu/gpu_module_to_assembly.cc +++ b/jaxlib/mosaic/gpu/gpu_module_to_assembly.cc @@ -94,8 +94,9 @@ std::optional> ModuleToAssembly::moduleToObject( << triple << ", can't optimize with LLVM\n"; return std::nullopt; } - std::optional ptx = translateToISA(llvm_module, **machine); - if (!ptx) { + llvm::FailureOr ptx = translateModuleToISA( + llvm_module, **machine, [&]() { return getOperation().emitError(); }); + if (failed(ptx)) { getOperation().emitError() << "Failed translating the module to PTX."; return std::nullopt; } From 514b7c47d41a1c7e85c687f9fb79d7f7a0c1098e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 19 Dec 2025 14:56:15 -0800 Subject: [PATCH 302/315] [dep] make jax_default_dtype_bits config a no-op for JAX v0.9.0. This has been deprecated since JAX v0.8.0; after this change the flag still exists, but setting it raises a warning and otherwise has no effect. It will be removed in JAX v0.10.0. PiperOrigin-RevId: 846877149 --- jax/_src/config.py | 29 +++++++++++++-------------- jax/_src/dtypes.py | 36 ++++++++-------------------------- jax/_src/numpy/scalar_types.py | 8 ++++---- jax/_src/test_util.py | 6 ++---- tests/dtypes_test.py | 10 ++++------ tests/lax_numpy_test.py | 10 ++++------ 6 files changed, 36 insertions(+), 63 deletions(-) diff --git a/jax/_src/config.py b/jax/_src/config.py index 507228cac0d7..70826d6bf82a 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -23,6 +23,7 @@ import os import sys from typing import Any, Generic, NoReturn, Optional, Protocol, Type, TypeVar, cast +import warnings from jax._src import deprecations from jax._src.lib import _jax @@ -1527,25 +1528,23 @@ class LegacyPrngKeyState(enum.StrEnum): 'what they are trying to achieve should set it.'), ) -def _default_dtype_bits_deprecation(new_val): - if new_val != '64': - deprecations.warn( - 'default-dtype-bits-config', - ( - 'The jax_default_dtype_bits configuration is deprecated in JAX v0.7.1' - ' and will be removed in JAX v0.9.0.' - ), - stacklevel=4 - ) +def _default_dtype_bits_deprecation(val): + if val != '_default': + warnings.warn( + ( + 'The jax_default_dtype_bits configuration is deprecated in JAX v0.7.1' + ' and has no effect as of JAX v0.9.0. It will be removed in JAX v0.10.0.' + ), + category=DeprecationWarning, + stacklevel=4) default_dtype_bits = enum_state( name='jax_default_dtype_bits', - enum_values=['32', '64'], - default='64', - help=('[deprecated]. This flag was an experiment in allowing users to specify the' - ' default bit width. It was never fully supported or tested. It will ' - ' have no effect after JAX v0.9.0, and be removed entirely in JAX v0.10.0.'), + enum_values=['_default', '32', '64'], + default='_default', + help=('[deprecated]. This has no effect starting with JAX v0.9.0, and' + ' will be removed in JAX v0.10.0.'), extra_validator=_default_dtype_bits_deprecation) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 7373f0ad3815..03fbfe91fd1b 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -184,20 +184,10 @@ def supports_inf(dtype: DTypeLike) -> bool: # Default types. bool_ = np.bool_ -int_: type[Any] -uint: type[Any] -float_: type[Any] -complex_: type[Any] -if config.default_dtype_bits.value == '32': - int_ = np.int32 - uint = np.uint32 - float_ = np.float32 - complex_ = np.complex64 -else: - int_ = np.int64 - uint = np.uint64 - float_ = np.float64 - complex_ = np.complex128 +int_: type[Any] = np.int64 +uint: type[Any] = np.uint64 +float_: type[Any] = np.float64 +complex_: type[Any] = np.complex128 # Default dtypes. These are intended to have the same semantics as, say, @@ -206,33 +196,23 @@ def supports_inf(dtype: DTypeLike) -> bool: def default_int_dtype() -> DType: - return ( - np.dtype(np.int64) - if config.enable_x64.value and config.default_dtype_bits.value == '64' - else np.dtype(np.int32) - ) + return np.dtype(np.int64) if config.enable_x64.value else np.dtype(np.int32) def default_uint_dtype() -> DType: - return ( - np.dtype(np.uint64) - if config.enable_x64.value and config.default_dtype_bits.value == '64' - else np.dtype(np.uint32) - ) + return np.dtype(np.uint64) if config.enable_x64.value else np.dtype(np.uint32) def default_float_dtype() -> DType: return ( - np.dtype(np.float64) - if config.enable_x64.value and config.default_dtype_bits.value == '64' - else np.dtype(np.float32) + np.dtype(np.float64) if config.enable_x64.value else np.dtype(np.float32) ) def default_complex_dtype() -> DType: return ( np.dtype(np.complex128) - if config.enable_x64.value and config.default_dtype_bits.value == '64' + if config.enable_x64.value else np.dtype(np.complex64) ) diff --git a/jax/_src/numpy/scalar_types.py b/jax/_src/numpy/scalar_types.py index 360cb96ed1ed..4ebf75b020a7 100644 --- a/jax/_src/numpy/scalar_types.py +++ b/jax/_src/numpy/scalar_types.py @@ -102,7 +102,7 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: complex64 = csingle = _make_scalar_type(np.complex64) complex128 = cdouble = _make_scalar_type(np.complex128) -int_ = int32 if dtypes.int_ == np.int32 else int64 -uint = uint32 if dtypes.uint == np.uint32 else uint64 -float_: Any = float32 if dtypes.float_ == np.float32 else float64 -complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128 +int_ = int64 +uint = uint64 +float_ = float64 +complex_ = complex128 diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 503f100efdd4..b9397babb7b7 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -131,8 +131,7 @@ def to_default_dtype(arr: ArrayLike) -> np.ndarray: """Convert a value to an array with JAX's default dtype. This is generally used for type conversions of values returned by numpy functions, - to make their dtypes take into account the state of the ``jax_enable_x64`` and - ``jax_default_dtype_bits`` flags. + to make their dtypes take into account the state of the ``jax_enable_x64`` flag. """ arr = np.asarray(arr) dtype_fn = _dtypes.default_types.get(arr.dtype.kind) @@ -143,8 +142,7 @@ def with_jax_dtype_defaults(func: Callable[..., Any], use_defaults: bool = True) This is generally used to wrap numpy functions within tests, in order to make their default output dtypes match those of corresponding JAX functions, taking - into account the state of the ``jax_enable_x64`` and ``jax_default_dtype_bits`` - flags. + into account the state of the ``jax_enable_x64`` flag. Args: use_defaults : whether to convert any given output to the default dtype. May be diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index bbb1f0ea7afc..c598f71c8291 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -504,13 +504,11 @@ def testDtypeFromNone(self): dtypes.dtype(None) def testDefaultDtypes(self): - precision = config.default_dtype_bits.value - assert precision in ['32', '64'] self.assertEqual(dtypes.bool_, np.bool_) - self.assertEqual(dtypes.int_, np.int32 if precision == '32' else np.int64) - self.assertEqual(dtypes.uint, np.uint32 if precision == '32' else np.uint64) - self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64) - self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128) + self.assertEqual(dtypes.int_, np.int64) + self.assertEqual(dtypes.uint, np.uint64) + self.assertEqual(dtypes.float_, np.float64) + self.assertEqual(dtypes.complex_, np.complex128) def test_check_dtype_non_hashable(self): # regression test for issue with checking non-hashable custom dtype diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 87383be65577..9a0ae02d5fe3 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5789,13 +5789,11 @@ def np_op(x1, x2): self._CompileAndCheck(jnp.logaddexp2, args_maker, rtol=tol, atol=tol) def testDefaultDtypes(self): - precision = config.default_dtype_bits.value - assert precision in ['32', '64'] self.assertEqual(jnp.bool_, np.bool_) - self.assertEqual(jnp.int_, np.int32 if precision == '32' else np.int64) - self.assertEqual(jnp.uint, np.uint32 if precision == '32' else np.uint64) - self.assertEqual(jnp.float_, np.float32 if precision == '32' else np.float64) - self.assertEqual(jnp.complex_, np.complex64 if precision == '32' else np.complex128) + self.assertEqual(jnp.int_, np.int64) + self.assertEqual(jnp.uint, np.uint64) + self.assertEqual(jnp.float_, np.float64) + self.assertEqual(jnp.complex_, np.complex128) def testFromBuffer(self): buf = b'\x01\x02\x03' From dd7ff373363ca1db9cd297e66f990b8e272d8a34 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 19 Dec 2025 16:40:47 -0800 Subject: [PATCH 303/315] Support weakref_lru_cache.evict. PiperOrigin-RevId: 846909748 --- jaxlib/weakref_lru_cache.cc | 44 +++++++++++++++++++++++--------- jaxlib/weakref_lru_cache.pyi | 1 + jaxlib/weakref_lru_cache_test.py | 33 ++++++++++++++++++++++++ jaxlib/xla_client.py | 2 +- 4 files changed, 67 insertions(+), 13 deletions(-) diff --git a/jaxlib/weakref_lru_cache.cc b/jaxlib/weakref_lru_cache.cc index e43d0cb55813..9160503b86a2 100644 --- a/jaxlib/weakref_lru_cache.cc +++ b/jaxlib/weakref_lru_cache.cc @@ -113,6 +113,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this { nb::object Call(nb::object weakref_key, nb::args args, nb::kwargs kwargs); + void EvictWeakref(nb::object weakref_key); + std::vector GetKeys(); struct CacheInfo { @@ -208,6 +210,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this { return value.cache; } + WeakrefCacheKey MakeWeakrefKey(const nb::object& weakref_key); + nb::callable cache_context_fn_; nb::callable fn_; std::shared_ptr lru_list_; @@ -226,17 +230,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this { static int tp_clear(PyObject* self); }; -nb::object WeakrefLRUCache::Call(nb::object weakref_key, nb::args args, - nb::kwargs kwargs) - ABSL_NO_THREAD_SAFETY_ANALYSIS { - nb::object context = cache_context_fn_(); - - // We precompute all of the hash values needed by the various maps rather - // than computing them during the std::unordered_map insertions. At the very - // least, MSVC's std::unordered_map has undefined behavior if the hash - // function throws an exception - // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). - Key key(context, args, kwargs); +WeakrefLRUCache::WeakrefCacheKey WeakrefLRUCache::MakeWeakrefKey( + const nb::object& weakref_key) { size_t wrcache_hash = static_cast(nb::hash(weakref_key)); // No hash computations after this point. @@ -267,7 +262,31 @@ nb::object WeakrefLRUCache::Call(nb::object weakref_key, nb::args args, cache->entries_.erase(it); }); nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback); - WeakrefCacheKey wrcache_key{weakref, wrcache_hash}; + return WeakrefCacheKey{std::move(weakref), wrcache_hash}; +} + +void WeakrefLRUCache::EvictWeakref(nb::object weakref_key) { + auto it = entries_.find(MakeWeakrefKey(weakref_key)); + if (it == entries_.end()) { + return; + } + // Create temp-var to avoid re-entrant erase. + auto tmp = std::move(it->second); + entries_.erase(it); +} + +nb::object WeakrefLRUCache::Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) + ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); + // We precompute all of the hash values needed by the various maps rather + // than computing them during the std::unordered_map insertions. At the very + // least, MSVC's std::unordered_map has undefined behavior if the hash + // function throws an exception + // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). + Key key(context, args, kwargs); + auto wrcache_key = MakeWeakrefKey(weakref_key); + std::shared_ptr cache_ptr = GetCache(wrcache_key); Cache& cache = *cache_ptr; ++total_queries_; @@ -420,6 +439,7 @@ NB_MODULE(weakref_lru_cache, m) { nb::is_weak_referenceable(), nb::type_slots(WeakrefLRUCache::slots_)) .def("__call__", &WeakrefLRUCache::Call, nb::lock_self()) + .def("evict_weakref", &WeakrefLRUCache::EvictWeakref, nb::lock_self()) .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self()) .def("cache_info", &WeakrefLRUCache::GetCacheInfo, nb::lock_self()) .def("cache_clear", &WeakrefLRUCache::Clear, nb::lock_self()); diff --git a/jaxlib/weakref_lru_cache.pyi b/jaxlib/weakref_lru_cache.pyi index ed965d7be811..9209939afabb 100644 --- a/jaxlib/weakref_lru_cache.pyi +++ b/jaxlib/weakref_lru_cache.pyi @@ -18,6 +18,7 @@ from typing import Any class WeakrefLRUCache: def __call__(self, arg0: Any, /, *args, **kwargs) -> Any: ... + def evict_weakref(self, arg0: Any) -> None: ... def cache_keys(self) -> list[Any]: ... def cache_info(self) -> WeakrefLRUCache.WeakrefLRUCacheInfo: ... def cache_clear(self) -> None: ... diff --git a/jaxlib/weakref_lru_cache_test.py b/jaxlib/weakref_lru_cache_test.py index aae2c2fec31b..c47ec680bf6a 100644 --- a/jaxlib/weakref_lru_cache_test.py +++ b/jaxlib/weakref_lru_cache_test.py @@ -14,6 +14,7 @@ # ============================================================================== import gc +import random import threading import time import weakref @@ -280,6 +281,38 @@ def __hash__(self): for _ in range(100): cache(wrkey, ReentrantKey()) + def testEvictWeakref(self): + dtor_list = [] + + class NoisyDestructor: + + def __init__(self, v): + self.v = v + + def __del__(self): + dtor_list.append(self.v) + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: NoisyDestructor(y) + ) + + class WRKey: + pass + + N = 100 + expected_deletes = [] + plan = list(range(N)) * 2 + random.shuffle(plan) + keys = [None] * N + for i in plan: + if keys[i] is None: + keys[i] = WRKey() + cache(keys[i], i) + else: + cache.evict_weakref(keys[i]) + expected_deletes.append(i) + self.assertEqual(dtor_list, expected_deletes) + if __name__ == "__main__": absltest.main() diff --git a/jaxlib/xla_client.py b/jaxlib/xla_client.py index c26c41ccf4fb..0ba6898a4dee 100644 --- a/jaxlib/xla_client.py +++ b/jaxlib/xla_client.py @@ -47,7 +47,7 @@ # Please suffix the version number with a brief description of your change # in a comment. The goal here is to force a merge conflict if two changes # attempt to grab the same version number. -_version = 391 # ResultHandler.pre_wrap +_version = 392 # weakref_lru_cache.evict_weakref # An internal increasing version number for protecting jaxlib code against # ifrt changes. From d3ed2f4aaa6c05f591c0cfde5ebedc03824498c0 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 19 Dec 2025 18:06:03 -0800 Subject: [PATCH 304/315] Add replicated -> unreduced test coverage PiperOrigin-RevId: 846930629 --- tests/pjit_test.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 74e31da86864..edf345bd70bb 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -62,6 +62,7 @@ from jax._src.mesh import AxisType from jax._src.interpreters import pxla from jax._src.lib import xla_client as xc +from jax._src.lib import ifrt_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -9691,6 +9692,10 @@ def test_jnp_repeat_arraylike(self, mesh): ) @jtu.with_explicit_mesh((2,), 'x') def test_both_inputs_reduced(self, func, mesh): + if ifrt_version < 46: + self.skipTest('Requires ifrt_version >= 46') + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Requires libtpu built after 2025-12-22') arr1 = jax.device_put(np.arange(8.), P(reduced={'x'})) arr2 = jax.device_put(np.arange(8.), P(reduced={'x'})) @@ -9705,6 +9710,12 @@ def f(x, y): self.assertEqual(out2.sharding, NamedSharding(mesh, P(None, unreduced={'x'}))) + arr3 = jax.device_put(np.arange(8.), P()) + arr4 = jax.device_put(np.arange(8.), P()) + ex_out1, ex_out2 = jax.jit(jax.grad(f, argnums=(0, 1)))(arr3, arr4) + self.assertArraysEqual(reshard(out1, P()), ex_out1) + self.assertArraysEqual(reshard(out2, P()), ex_out2) + @parameterized.named_parameters( ('mul', jax.lax.mul), ('add', jax.lax.add), @@ -9755,6 +9766,57 @@ def test_sharded_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): self.assertArraysEqual(arr, arr3) self.assertEqual(arr.sharding, arr3.sharding) + @jtu.with_explicit_mesh((2,), ('x',)) + def test_scalar_to_unreduced(self, mesh): + if ifrt_version < 46: + self.skipTest('Requires ifrt_version >= 46') + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Requires libtpu built after 2025-12-22') + inp = jnp.array(1) + for s in inp.addressable_shards: + self.assertArraysEqual(s.data, inp) + + out = reshard(inp, P(unreduced={'x'})) + expected_out = [inp, jnp.array(0)] + for s, ex_out in zip(out.addressable_shards, expected_out): + self.assertArraysEqual(s.data, ex_out) + + out2 = reshard(out, P()) + for s, inp_s in zip(out2.addressable_shards, inp.addressable_shards): + self.assertArraysEqual(s.data, inp_s.data) + + @parameterized.parameters( + ((4,), P(None), P(None, unreduced={'x'})), + ((4,), P(None), P(None, unreduced={'y'})), + ((4,), P(None), P(None, unreduced={'x', 'y'})), + ((4, 2), P(None, None), P(None, None, unreduced={'x'})), + ((4, 2), P(None, None), P(None, None, unreduced={'y'})), + ((4, 2), P(None, None), P(None, None, unreduced={'x', 'y'})), + ((4, 4), P('x', None), P('x', None, unreduced={'y'})), + ((4, 4), P(None, 'y'), P(None, 'y', unreduced={'x'})), + ((4, 4), P('x', None), P(None, None, unreduced={'x', 'y'})), + ((4, 4), P('y', None), P(None, None, unreduced={'x', 'y'})), + ((4, 4), P('x', 'y'), P(None, None, unreduced={'x', 'y', 'z'})), + ((4, 4), P('x', 'z'), P(None, None, unreduced={'x', 'y', 'z'})), + ((4, 4), P(('x', 'z'), 'y'), P(None, None, unreduced={'x', 'y', 'z'})), + ((4, 4), P(('x', 'z'), 'y'), P(None, None, unreduced={'y', 'z'})), + ) + @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) + def test_replicated_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): + if ifrt_version < 46: + self.skipTest('Requires ifrt_version >= 46') + if not jtu.is_cloud_tpu_at_least(2025, 12, 22): + self.skipTest('Requires libtpu built after 2025-12-22') + np1 = np.arange(math.prod(shape)).reshape(shape) + arr = jax.device_put(np1, orig_spec) + + arr2 = reshard(arr, un_spec) + self.assertEqual(arr2.sharding, NamedSharding(mesh, un_spec)) + + arr3 = reshard(arr2, orig_spec) + self.assertArraysEqual(arr, arr3) + self.assertEqual(arr.sharding, arr3.sharding) + @parameterized.named_parameters( ('mul', jax.lax.mul), ('add', jax.lax.add), From 20bdca7a2610707e2c2719dffa41828f900951bc Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 19 Dec 2025 21:45:50 -0800 Subject: [PATCH 305/315] Fix a broken sharded -> unreduced test PiperOrigin-RevId: 846986071 --- tests/pjit_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index edf345bd70bb..16e1e5948b57 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -9799,10 +9799,13 @@ def test_scalar_to_unreduced(self, mesh): ((4, 4), P('x', 'y'), P(None, None, unreduced={'x', 'y', 'z'})), ((4, 4), P('x', 'z'), P(None, None, unreduced={'x', 'y', 'z'})), ((4, 4), P(('x', 'z'), 'y'), P(None, None, unreduced={'x', 'y', 'z'})), - ((4, 4), P(('x', 'z'), 'y'), P(None, None, unreduced={'y', 'z'})), + ((4, 4), P(('x', 'z'), 'y'), P('x', None, unreduced={'y', 'z'})), + ((4, 4), P('z', 'y'), P(None, None, unreduced={'y', 'z'})), + ((4, 4), P('z', 'y'), P(None, None, unreduced={'x', 'y', 'z'})), ) @jtu.with_explicit_mesh((2, 2, 2), ('x', 'y', 'z')) - def test_replicated_unreduced_roundtrip(self, shape, orig_spec, un_spec, mesh): + def test_replicated_sharded_unreduced_roundtrip( + self, shape, orig_spec, un_spec, mesh): if ifrt_version < 46: self.skipTest('Requires ifrt_version >= 46') if not jtu.is_cloud_tpu_at_least(2025, 12, 22): From 6469f170235d91adcff60c9495cb1181b6b2dba1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sat, 20 Dec 2025 00:06:58 -0800 Subject: [PATCH 306/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/66dbbf501ffd74f83c6a5d8fc201c756b1198d64 PiperOrigin-RevId: 847021722 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 5adedc2e48ce..2b793f2b9fd7 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "c45b8fed642a7ce99e315f979dee6fc45e08f79f" -XLA_SHA256 = "75ba4f7a261fa43834791c218b8a8909d003fcf4bf28426f32fabcbf09682352" +XLA_COMMIT = "66dbbf501ffd74f83c6a5d8fc201c756b1198d64" +XLA_SHA256 = "be875b43335f20c93fd1a910fa21d086fbd0a6c61b2a270ad2c837d3fdb42607" From 0ad46d99f3c27277570ae6131f700e2eeb58a453 Mon Sep 17 00:00:00 2001 From: jax authors Date: Sun, 21 Dec 2025 00:06:17 -0800 Subject: [PATCH 307/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/08872f587d442a05802cbdb052e8c9e6e87423f4 PiperOrigin-RevId: 847325574 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 2b793f2b9fd7..35d5cca6bb85 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "66dbbf501ffd74f83c6a5d8fc201c756b1198d64" -XLA_SHA256 = "be875b43335f20c93fd1a910fa21d086fbd0a6c61b2a270ad2c837d3fdb42607" +XLA_COMMIT = "08872f587d442a05802cbdb052e8c9e6e87423f4" +XLA_SHA256 = "9d0919042ff878d270bec2f7e711e248adc2a957b65445c461ed065e63479eea" From f29645bbde291adadc41f91489a4af7e7e2c64e1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 22 Dec 2025 00:05:57 -0800 Subject: [PATCH 308/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/7521349eccb22a780f50fa5f6f09dbaa1d09f470 PiperOrigin-RevId: 847636908 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 35d5cca6bb85..6e053326816f 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "08872f587d442a05802cbdb052e8c9e6e87423f4" -XLA_SHA256 = "9d0919042ff878d270bec2f7e711e248adc2a957b65445c461ed065e63479eea" +XLA_COMMIT = "7521349eccb22a780f50fa5f6f09dbaa1d09f470" +XLA_SHA256 = "7857bf7a1e0bd7e3bb8a9e90f3fae4186c179f005b80a0a0c3ac4d001e2fc606" From 2da971365f9e15fc9f9583768103c2687ce7063c Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 22 Dec 2025 01:37:23 -0800 Subject: [PATCH 309/315] [Mosaic GPU] Add `nvshmemx_cumodule_finalize` support. PiperOrigin-RevId: 847665237 --- jaxlib/mosaic/gpu/BUILD | 3 ++- jaxlib/mosaic/gpu/nvshmem.h | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index b7359495f184..5c9bff0c3e1e 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -232,10 +232,11 @@ cc_library( "-Wl,--export-dynamic-symbol='mosaic_gpu_*'", "-Wl,--export-dynamic-symbol='nvshmem_my_pe'", "-Wl,--export-dynamic-symbol='nvshmem_ptr'", - "-Wl,--export-dynamic-symbol='nvshmemx_mc_ptr'", "-Wl,--export-dynamic-symbol='nvshmemx_barrier_all_on_stream'", + "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_finalize'", "-Wl,--export-dynamic-symbol='nvshmemx_cumodule_init'", "-Wl,--export-dynamic-symbol='nvshmemx_init_status'", + "-Wl,--export-dynamic-symbol='nvshmemx_mc_ptr'", ], deps = [ ":nvshmem", diff --git a/jaxlib/mosaic/gpu/nvshmem.h b/jaxlib/mosaic/gpu/nvshmem.h index dbd11aa1d373..7869b55b7a31 100644 --- a/jaxlib/mosaic/gpu/nvshmem.h +++ b/jaxlib/mosaic/gpu/nvshmem.h @@ -54,6 +54,11 @@ class NvshmemApi { return nvshmemx_cumodule_init(module); } + int cumodule_finalize(CUmodule module) { + std::lock_guard lock(mutex_); + return nvshmemx_cumodule_finalize(module); + } + void barrier_all_on_stream(cudaStream_t stream) { nvshmemx_barrier_all_on_stream(stream); } @@ -77,11 +82,13 @@ class NvshmemApi { } NVSHMEM_SET_FN(nvshmemx_barrier_all_on_stream) + NVSHMEM_SET_FN(nvshmemx_cumodule_finalize) NVSHMEM_SET_FN(nvshmemx_cumodule_init) NVSHMEM_SET_FN(nvshmemx_init_status) } int (*nvshmemx_barrier_all_on_stream)(cudaStream_t); + int (*nvshmemx_cumodule_finalize)(CUmodule); int (*nvshmemx_cumodule_init)(CUmodule); int (*nvshmemx_init_status)(); From 6d41fa0c5a9400f40a4c8fb53c32b45e460903cd Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Mon, 22 Dec 2025 04:38:21 -0800 Subject: [PATCH 310/315] [Autotuner] Prepare the CuDNN fusion test for the new autotuner. PiperOrigin-RevId: 847713430 --- tests/cudnn_fusion_test.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index a06c78074be7..5e04d69889fc 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -12,21 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from absl.testing import absltest, parameterized from unittest import SkipTest -from jax._src import test_util as jtu +from absl.testing import absltest, parameterized import jax -import jax.numpy as jnp +from jax._src import test_util as jtu from jax._src.cudnn import cudnn_fusion +import jax.numpy as jnp jax.config.parse_flags_with_absl() class CudnnFusionTest(jtu.JaxTestCase): + def setUp(self): - if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("8.0")): + if not jtu.test_device_matches( + ["cuda"] + ) or not jtu.is_cuda_compute_capability_at_least("8.0"): self.skipTest("Only works on >= sm80 GPUs") super().setUp() @@ -38,11 +40,11 @@ def test_cudnn_fusion(self, mode): batch_size = 2 if mode == "pmap" and jax.device_count() < batch_size: - raise SkipTest("pmap test requires 2 GPUs") + raise SkipTest("pmap test requires 2 GPUs") @cudnn_fusion def comp1(x, y, z): - return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z k = jax.random.key(0) s = batch_size, 16, 16 @@ -61,7 +63,14 @@ def comp1(x, y, z): self.assertIn('custom_call_target="__cudnn$fusion"', hlo) self.assertIn("called_computations=", hlo) - compiled = lowered.compile({"xla_gpu_cublas_fallback": False}) + compiled = lowered.compile({ + # Disable Cublas to make sure CuDNN is used. + "xla_gpu_cublas_fallback": False, + # Enable CuDNN fusions. + "xla_gpu_cudnn_gemm_fusion_level": 2, + # Disable autotuning to pick first config to ensure CuDNN is always used. + "xla_gpu_autotune_level": 0, + }) hlo_after_opt = compiled.as_text() self.assertIn("kind=kCustom", hlo_after_opt) @@ -70,5 +79,5 @@ def comp1(x, y, z): self.assertAllClose(compiled(x, y, z), fn(x, y, z)) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) From c25a5149812bf327a7e421b0f58f0429b718437a Mon Sep 17 00:00:00 2001 From: Subham Soni Date: Mon, 22 Dec 2025 10:03:29 -0800 Subject: [PATCH 311/315] Add programmatic profiling test with session_id support in JAX. PiperOrigin-RevId: 847801848 --- tests/BUILD | 10 +++++ tests/profiler_session_test.py | 76 ++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 tests/profiler_session_test.py diff --git a/tests/BUILD b/tests/BUILD index 3f12aea68415..c2ac19688917 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1211,6 +1211,16 @@ jax_multiplatform_test( ]), ) +jax_py_test( + name = "profiler_session_test", + srcs = ["profiler_session_test.py"], + deps = [ + "//jax", + "//jax/_src:profiler", + "//jax/_src:test_util", + ] + py_deps("absl/testing"), +) + jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], diff --git a/tests/profiler_session_test.py b/tests/profiler_session_test.py new file mode 100644 index 000000000000..695a1552a9fc --- /dev/null +++ b/tests/profiler_session_test.py @@ -0,0 +1,76 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +import jax.numpy as jnp + +_TEST_SESSION_ID = 'my_custom_session_123' + + +@jtu.thread_unsafe_test_class() +class ProfilerSessionTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + # Ensure that any running profiler is stopped before starting the test. + # This is in setUp rather than tearDown to defend against previous tests + # that may have crashed or failed to clean up properly. + try: + jax.profiler.stop_trace() + except RuntimeError: + pass + + @parameterized.named_parameters( + dict(testcase_name='without_session_id', session_id=None), + dict(testcase_name='with_empty_session_id', session_id=''), + dict(testcase_name='with_custom_session_id', session_id=_TEST_SESSION_ID), + ) + def test_programmatic_profiling(self, session_id: str | None): + tmpdir = pathlib.Path(self.create_tempdir()) + + options = jax.profiler.ProfileOptions() + if session_id is not None: + options.session_id = session_id + + with jax.profiler.trace(tmpdir, profiler_options=options): + jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')( + jnp.ones(jax.local_device_count()) + ).block_until_ready() + + profile_plugin_dir = tmpdir / 'plugins' / 'profile' + self.assertTrue(profile_plugin_dir.exists(), f'Not found at {profile_plugin_dir}') + + subdirs = [x.name for x in profile_plugin_dir.iterdir() if x.is_dir()] + self.assertLen(subdirs, 1) + + if session_id is None or not session_id: + self.assertNotIn(_TEST_SESSION_ID, subdirs) + self.assertNotIn('', subdirs) + target_dir = subdirs[0] + else: + self.assertIn(session_id, subdirs) + target_dir = session_id + + session_dir = profile_plugin_dir / target_dir + pb_files = list(session_dir.glob('*.xplane.pb')) + self.assertNotEmpty(pb_files, f'No .xplane.pb files found in {session_dir}') + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) From 9fe0430dabc1fc23cb84ab513365a9278ebd85ec Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Mon, 22 Dec 2025 13:42:04 -0800 Subject: [PATCH 312/315] [Mosaic] Extend tpu.pack_elementwise to support non-32-bit integers. PiperOrigin-RevId: 847871002 --- jax/_src/pallas/mosaic/lowering.py | 10 ++++---- jax/_src/pallas/mosaic/primitives.py | 12 +++++++++- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 16 ++++++++++--- jaxlib/mosaic/dialect/tpu/tpu_ops.td | 14 +++++++---- tests/pallas/tpu_ops_test.py | 35 ++++++++++++++++++---------- 5 files changed, 61 insertions(+), 26 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 229b29a1a130..045628dbd445 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -3593,7 +3593,7 @@ def _stochastic_round_lowering_rule( return tpu.stochastic_convert(out_type, x, random_bits) -def _check_elementwise_packing_dtypes(unpacked_dtype, packed_dtype): +def _check_elementwise_unpack_dtypes(unpacked_dtype, packed_dtype): if unpacked_dtype == jnp.float32 and packed_dtype == jnp.bfloat16: return if unpacked_dtype == jnp.int32 and packed_dtype in [ @@ -3611,11 +3611,9 @@ def _pack_elementwise_lowering_rule( ctx: LoweringRuleContext, *xs, packed_dtype ): in_aval = ctx.avals_in[0] - _check_elementwise_packing_dtypes(in_aval.dtype, packed_dtype) + out_aval = ctx.avals_out[0] packed_ir_type = _dtype_to_ir_type(packed_dtype) - out_type = ir.VectorType.get( - in_aval.shape, _dtype_to_ir_type(jnp.uint32) - ) + out_type = ir.VectorType.get(in_aval.shape, _dtype_to_ir_type(out_aval.dtype)) return tpu.pack_elementwise(out_type, xs, target_type=packed_ir_type) @@ -3624,7 +3622,7 @@ def _unpack_elementwise_lowering_rule( ctx: LoweringRuleContext, x, index, packed_dtype, unpacked_dtype ): in_aval = ctx.avals_in[0] - _check_elementwise_packing_dtypes(unpacked_dtype, packed_dtype) + _check_elementwise_unpack_dtypes(unpacked_dtype, packed_dtype) out_type = ir.VectorType.get( in_aval.shape, _dtype_to_ir_type(unpacked_dtype) ) diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 1b94e5e75512..67aaab743d24 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -1079,6 +1079,7 @@ def _stochastic_round_abstract_eval(x, random_bits, *, target_dtype): ) return jax_core.ShapedArray(x.shape, target_dtype) + def _get_elementwise_packing_factor(unpacked_dtype, packed_dtype): unpacked_bitwidth = dtypes.itemsize_bits(unpacked_dtype) packed_bitwidth = dtypes.itemsize_bits(packed_dtype) @@ -1105,13 +1106,22 @@ def _pack_elementwise_abstract_eval(*xs, packed_dtype): raise ValueError("All sources must have the same shape") if not all(x.dtype == first.dtype for x in xs): raise ValueError("All sources must have the same dtype") + if not (first.dtype == jnp.float32 and packed_dtype == jnp.bfloat16) and not ( + jnp.issubdtype(first.dtype, jnp.integer) + and jnp.issubdtype(packed_dtype, jnp.integer) + ): + raise ValueError( + "Only f32 -> bf16 and int -> int are supported. Got" + f" {first.dtype} and {packed_dtype}" + ) packing_factor = _get_elementwise_packing_factor(first.dtype, packed_dtype) if len(xs) != packing_factor: raise ValueError( "The number of sources must match the packing factor " f"({packing_factor}), got {len(xs)}" ) - return jax_core.ShapedArray(first.shape, jnp.uint32) + out_dtype = jnp.dtype(f"uint{dtypes.itemsize_bits(first.dtype)}") + return jax_core.ShapedArray(first.shape, out_dtype) unpack_elementwise_p = jax_core.Primitive("unpack_elementwise") diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 855625337e64..ec56b3561591 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -2042,9 +2042,19 @@ LogicalResult PackElementwiseOp::verify() { return emitOpError("At least one source is required"); } const auto src_vty = cast(getSources().front().getType()); - if (failed(verifyElementwisePacking(*this, src_vty.getElementType(), - getTargetType()))) { - return failure(); + if (getElementTypeBitwidth(src_vty) != getElementTypeBitwidth(getType())) { + return emitOpError("All sources must have the same bitwidth as the result"); + } + if (!getType().getElementType().isSignlessInteger()) { + return emitOpError("Output type must be a signless integer type"); + } + + auto src_elem_ty = src_vty.getElementType(); + auto tgt_elem_ty = getTargetType(); + if (!(src_elem_ty.isF32() && tgt_elem_ty.isBF16()) && + !(src_elem_ty.isSignlessInteger() && tgt_elem_ty.isSignlessInteger())) { + return emitOpError( + "Only packing f32 -> bf16 and integer -> integer is supported"); } const int packing_factor = getElementTypeBitwidth(src_vty) / getTypeBitwidth(getTargetType()); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.td b/jaxlib/mosaic/dialect/tpu/tpu_ops.td index 3329220cdd91..58f36f78d499 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.td +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.td @@ -754,14 +754,20 @@ def TPU_PackElementwiseOp : TPU_Op<"pack_elementwise", [Pure, SameTypeOperands, The number of `sources` must equal the packing factor, which is the ratio of the element bitwidth of the `sources` to the element bitwidth of the `target_type`. Elements from the `sources` are interleaved and packed into - each word of the `output`, ordered from lowest to highest bits, - corresponding to their order in the `sources`. + the `output`, ordered from lowest to highest bits, corresponding to their + order in the `sources`. The `output` is then bitcasted to the signless + integer type of the same bitwidth as the `sources`. + + Note that for integer packing, the bits in `sources` that exceed the + bitwidth of the `target_type` are just truncated. + For example, given the `sources` are int8 xxxx'1001 and yyyy'0011, + `target_type` is int4, the output will be 0011'1001. }]; let arguments = (ins - Variadic>:$sources, + Variadic>:$sources, TypeAttr:$target_type ); - let results = (outs VectorOfNonZeroRankOf<[I32]>:$output); + let results = (outs VectorOfNonZeroRankOf<[AnyInteger]>:$output); let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }]; let hasVerifier = 1; } diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 52d92f78fb3f..e22b391b8e62 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -758,28 +758,31 @@ def kernel(x_ref, b_ref, o_ref): def _pack_unpack_elementwise_test_data( self, shape, unpacked_dtype, packed_dtype): """Generates data for test_pack_elementwise and test_unpack_elementwise.""" - bitwidth = dtypes.itemsize_bits(packed_dtype) - num_sources = 32 // bitwidth - if unpacked_dtype == jnp.int32: + unpacked_bitwidth = dtypes.itemsize_bits(unpacked_dtype) + packed_bitwidth = dtypes.itemsize_bits(packed_dtype) + num_sources = unpacked_bitwidth // packed_bitwidth + if jnp.issubdtype(unpacked_dtype, jnp.integer): stacked_sources = jax.random.randint( jax.random.key(0), (num_sources, *shape), minval=-1000, maxval=1000, - dtype=unpacked_dtype, - ) + dtype=jnp.int32, + ).astype(unpacked_dtype) else: stacked_sources = jax.random.uniform( jax.random.key(0), (num_sources, *shape), dtype=unpacked_dtype ) stacked_results = ( stacked_sources.astype(packed_dtype) - .view(getattr(jnp, f"uint{bitwidth}")) - .astype(jnp.uint32) + .view(getattr(jnp, f"uint{packed_bitwidth}")) + .astype(getattr(jnp, f"uint{unpacked_bitwidth}")) ) - shifts = jnp.arange(num_sources, dtype=jnp.uint32) * bitwidth + shifts = jnp.arange(num_sources, dtype=jnp.uint32) * packed_bitwidth shifts = jnp.expand_dims(shifts, axis=tuple(range(1, stacked_results.ndim))) - packed_data = jnp.bitwise_or.reduce(stacked_results << shifts, axis=0) + packed_data = jnp.bitwise_or.reduce( + stacked_results.astype(jnp.uint32) << shifts, axis=0 + ).astype(getattr(jnp, f"uint{unpacked_bitwidth}")) return stacked_sources, packed_data @parameterized.product( @@ -788,6 +791,8 @@ def _pack_unpack_elementwise_test_data( (jnp.int32, jnp.int16), (jnp.int32, jnp.int8), (jnp.int32, jnp.int4), + (jnp.int16, jnp.int8), + (jnp.int8, jnp.int4), ], shape=[(8, 128), (2, 15, 300)], ) @@ -795,9 +800,15 @@ def test_pack_elementwise(self, config, shape): unpacked_dtype, packed_dtype = config if not jtu.is_device_tpu_at_least(version=5): self.skipTest("Requires TPU v5+") + if dtypes.itemsize_bits( + unpacked_dtype + ) != 32 and not jtu.is_cloud_tpu_at_least(2026, 1, 2): + self.skipTest("Test requires libtpu from 2026/01/02 or later") - bitwidth = dtypes.itemsize_bits(packed_dtype) - num_sources = 32 // bitwidth + src_bitwidth = dtypes.itemsize_bits(unpacked_dtype) + tgt_bitwidth = dtypes.itemsize_bits(packed_dtype) + num_sources = src_bitwidth // tgt_bitwidth + output_dtype = getattr(jnp, f"uint{src_bitwidth}") def kernel(xs_ref, o_ref): xs = [xs_ref[i] for i in range(num_sources)] @@ -809,7 +820,7 @@ def kernel(xs_ref, o_ref): result = self.pallas_call( kernel, - out_shape=jax.ShapeDtypeStruct(shape, jnp.uint32), + out_shape=jax.ShapeDtypeStruct(shape, output_dtype), )(stacked_sources) np.testing.assert_array_equal(result, expected) From 5173d18a1cbf073420006a7a23f174894ca3df88 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 22 Dec 2025 15:45:08 -0800 Subject: [PATCH 313/315] Adjust tile_n assert to accommodate 2 cta in tcgen05 blockscale mma. PiperOrigin-RevId: 847908521 --- jax/experimental/mosaic/gpu/tcgen05.py | 2 +- tests/mosaic/gpu_test.py | 2 +- tests/pallas/mosaic_gpu_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/tcgen05.py b/jax/experimental/mosaic/gpu/tcgen05.py index 79fffad7f511..c623a1b3349c 100644 --- a/jax/experimental/mosaic/gpu/tcgen05.py +++ b/jax/experimental/mosaic/gpu/tcgen05.py @@ -666,7 +666,7 @@ def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]) scale_id, scale_id, a_transpose, b_transpose ) assert m == 128 - assert n % 128 == 0 + assert (n * num_cta) % 128 == 0 # A scales are sharded, B scales are replicated across CTAs. a_scale_addr_offset = arith.constant(i32, k_step // scale_steps * 4) b_scale_addr_offset = arith.constant(i32, k_step // scale_steps * n // 32 * num_cta) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index deb8772937d0..364a27756fbe 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1760,7 +1760,7 @@ def format_scales(scales): @parameterized.product( m=(256,), - n=(256,), + n=(128, 256), scale_jax_dtype=(jnp.float8_e8m0fnu, jnp.float8_e4m3fn), ) def test_mma_block_scaled_collective(self, m, n, scale_jax_dtype): diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 05c2ef48516a..4ba116b53aa6 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -3920,7 +3920,7 @@ def format_scales(scales): @parameterized.product( m=[256], - n=[256], + n=[128, 256], scale_jax_dtype=[jnp.float8_e8m0fnu, jnp.float8_e4m3fn], ) def test_collective_scaled_matmul(self, m, n, scale_jax_dtype): From 82ae1b1cde42a5b93e00d8c3376cde627c2d83bb Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 23 Dec 2025 00:06:27 -0800 Subject: [PATCH 314/315] Update XLA dependency to use revision http://github.com/openxla/xla/commit/d1635d1c99de225c8029d82e56c4dd03f90b013f PiperOrigin-RevId: 848049355 --- third_party/xla/revision.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 6e053326816f..294b54b92c4a 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. # buildifier: disable=module-docstring -XLA_COMMIT = "7521349eccb22a780f50fa5f6f09dbaa1d09f470" -XLA_SHA256 = "7857bf7a1e0bd7e3bb8a9e90f3fae4186c179f005b80a0a0c3ac4d001e2fc606" +XLA_COMMIT = "d1635d1c99de225c8029d82e56c4dd03f90b013f" +XLA_SHA256 = "a2ca41e647fe8166c0ec5597420a784c9f45fbf1edfea79ab5704e814c21d86f" From 5c2e646435979b533d9c52a49aacf76a5c6ab00d Mon Sep 17 00:00:00 2001 From: Rob Taylor Date: Wed, 24 Dec 2025 00:15:07 +0000 Subject: [PATCH 315/315] Add iree_metal to platforms with buffer donation support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The IREE Metal PJRT plugin now supports buffer donation, which allows JAX to reuse input buffers for outputs when the input is no longer needed. This optimization reduces memory allocation overhead. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- jax/_src/interpreters/mlir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 7af9770793ca..69c50c07bad1 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1034,7 +1034,7 @@ class LoweringResult(NamedTuple): shape_poly_state: ShapePolyLoweringState -_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu", "neuron"] +_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu", "neuron", "iree_metal"] def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim):