From 067dd46c05c77123c5e154b5a20897b0fdd2643f Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 26 Feb 2026 19:16:17 +0000 Subject: [PATCH 001/100] [hijax] optimize code, much better Co-authored-by: --- jax/_src/hijax.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py index c1189c7d21ee..01aa2ed04c23 100644 --- a/jax/_src/hijax.py +++ b/jax/_src/hijax.py @@ -506,13 +506,8 @@ def vjp_bwd_retval(self, res_, g): return tree_map(partial(unmap_zero, self.axis_data), self.in_dims, out, is_leaf=lambda x: x is None) # type: ignore def batch_dim_rule(self, axis_data, in_dims): - - def fix_dim(dim, prev_dim): - if dim is None: - return None - return dim if prev_dim is None else (dim - (prev_dim < dim)) - - in_dims_ = tree_map(fix_dim, in_dims, self.in_dims, is_leaf=lambda x: x is None) + fix = lambda d, d_: d if (d is None or d_ is None) else d - (d_ < d) # type: ignore + in_dims_ = tree_map(fix, in_dims, self.in_dims, is_leaf=lambda x: x is None) # type: ignore out_dim = self.prim.batch_dim_rule(axis_data, in_dims_) # type: ignore return tree_map(lambda d, d_: d + (d_ < d), out_dim, self.out_dim) # type: ignore From dd118091ac2756a8e5a2d78f97a2f27c79f22f8f Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 2 Mar 2026 09:31:52 +0000 Subject: [PATCH 002/100] Even more Pyrefly fixes --- .../mosaic/interpret/interpret_pallas_call.py | 8 ++--- jax/_src/pallas/mosaic/lowering.py | 16 ++++----- .../pallas/mosaic/pallas_call_registration.py | 2 ++ jax/_src/pallas/mosaic/sc_lowering.py | 2 +- jax/_src/pallas/mosaic/tpu_info.py | 10 ++++-- jax/_src/pallas/mosaic_gpu/core.py | 4 +-- jax/_src/pallas/mosaic_gpu/lowering.py | 33 ++++++++++--------- jax/_src/pallas/mosaic_gpu/primitives.py | 18 +++++----- jax/_src/pallas/mosaic_gpu/torch.py | 4 +-- jax/_src/pallas/primitives.py | 2 +- jax/_src/pallas/triton/lowering.py | 32 +++++++++--------- .../pallas/triton/pallas_call_registration.py | 2 +- 
jax/experimental/mosaic/gpu/utils.py | 8 ++++- 13 files changed, 79 insertions(+), 62 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py index f0ceaa306ee2..381774a2effd 100644 --- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -1845,15 +1845,15 @@ def interpret_pallas_call( is_input = i < grid_mapping.num_inputs is_output = (output_idx >= 0) and (output_idx < grid_mapping.num_outputs) aval = var.aval - memory_space = _forward_any_to_hbm(aval.memory_space) + memory_space = _forward_any_to_hbm(aval.memory_space) # pyrefly: ignore[missing-attribute] if memory_space is _SEMAPHORE: kernel_buffer_ids.append( callback.io_callback( _allocate_semaphores, - jax.ShapeDtypeStruct(aval.shape, jnp.int16), + jax.ShapeDtypeStruct(aval.shape, jnp.int16), # pyrefly: ignore[missing-attribute] device_id, None, # local_core_id - aval.shape, + aval.shape, # pyrefly: ignore[missing-attribute] ordered=True, ) ) @@ -1877,7 +1877,7 @@ def interpret_pallas_call( None, # local_core_id, TPU_MEMORY_SPACE_IDXS[memory_space], interpret_params.get_uninitialized_array( - var.aval.shape, var.aval.dtype + var.aval.shape, var.aval.dtype # pyrefly: ignore[missing-attribute] ), ordered=True, ) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 9342d143f2ef..caba6bb3f6be 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -153,8 +153,8 @@ def _maybe_physicalize_block_shape(aval, block_shape): class LoweringDynamicShapeEnv: def __init__(self): - self.dim_expr_to_placeholder: dict[shape_poly._DimExpr, int] = {} - self.placeholder_to_dim_expr: dict[int, shape_poly._DimExpr] = {} + self.dim_expr_to_placeholder: dict[shape_poly._DimExpr, Any] = {} + self.placeholder_to_dim_expr: dict[Any, shape_poly._DimExpr] = {} def to_placeholder(self, dim_expr: Any) -> 
ir.Value: if jax_core.is_constant_dim(dim_expr): @@ -820,7 +820,7 @@ def lower_jaxpr_into_module( def dynamic_shape_replacement_fn( shape: jax_core.Shape, - ) -> tuple[int, ...]: + ) -> tuple[Any, ...]: assert _mosaic_lowering_dynamic_shape_env is not None return tuple( _mosaic_lowering_dynamic_shape_env.to_placeholder(dim_expr) @@ -1006,7 +1006,7 @@ def dynamic_shape_replacement_fn( args_dimvars = shape_poly.all_dim_vars(invars) # This is dimexpr var -> placeholder value for when we jit the dim expr - env: dict[str, int] = {} + env: dict[str, Any] = {} for aval in args_dimvars: env[aval] = _mosaic_lowering_dynamic_shape_env.to_placeholder(aval) @@ -1432,7 +1432,7 @@ def _maybe_cast_to_index(cast_to_index, x): def _index_to_start_size_stride( - idx: indexing.Slice | int | ir.Value, cast_to_index: bool + idx: Any, cast_to_index: bool ) -> tuple[ir.Value, int | ir.Value, int, bool]: assert not isinstance(idx, slice) if isinstance(idx, indexing.Slice): @@ -1452,7 +1452,7 @@ def _index_to_start_size_stride( size = 1 stride = 1 squeeze = True - return start, size, stride, squeeze + return start, size, stride, squeeze # pyrefly: ignore[bad-return] def _indexer_to_start_size_stride( @@ -1833,7 +1833,7 @@ def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree def _maybe_cast_load_to_bool( ctx: LoweringRuleContext, out_aval, val: ir.Value -) -> tuple[ir.Value, jnp.dtype]: +) -> ir.Value: """Casts a memref load value to bool if the requested value is a bool. Mosaic does not support boolean-type memrefs, since booleans @@ -1845,7 +1845,7 @@ def _maybe_cast_load_to_bool( val: The input value. Returns: - The loaded value, and the JAX dtype of the input value. + The casted value. 
""" if out_aval.dtype != jnp.bool_: return val diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 9bc84874940f..521bcb473563 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -93,6 +93,8 @@ def _get_memory_space_from_aval( raise ValueError(f"Invalid kernel type for semaphore: {kernel_type}") case tpu_core.MemorySpace.HOST: return tpu_custom_call.MemorySpace.HOST + case _: + pass return None diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py index be97c0f74f0e..f6eb06683cb3 100644 --- a/jax/_src/pallas/mosaic/sc_lowering.py +++ b/jax/_src/pallas/mosaic/sc_lowering.py @@ -1024,7 +1024,7 @@ def _extract_indirect_offsets_from_indexer( def _extract_indirect_offsets( - transforms: Sequence[ir.Value], expected_shape: tuple[int, ...] + transforms: Sequence[state.Transform], expected_shape: tuple[int, ...] ) -> tuple[ir.Value | None, Sequence[state.Transform]]: for i, indexer in enumerate(transforms): if not isinstance(indexer, indexing.NDIndexer): diff --git a/jax/_src/pallas/mosaic/tpu_info.py b/jax/_src/pallas/mosaic/tpu_info.py index 99f7d58b5e67..636ff9a16edb 100644 --- a/jax/_src/pallas/mosaic/tpu_info.py +++ b/jax/_src/pallas/mosaic/tpu_info.py @@ -72,8 +72,8 @@ def __str__(self) -> str: return self.value @property - def num_physical_tensor_cores_per_chip(self) -> int: - match self: + def _num_physical_tensor_cores_per_chip(self) -> int: # pyrefly: ignore[bad-return] # pyrefly#2080 + match self: # pyrefly: ignore[non-exhaustive-match] # pyrefly#2080 case ( ChipVersion.TPU_V2 | ChipVersion.TPU_V3 @@ -86,6 +86,11 @@ def num_physical_tensor_cores_per_chip(self) -> int: case ChipVersion.TPU_V4I | ChipVersion.TPU_V5E | ChipVersion.TPU_V6E: return 1 + @property + def num_physical_tensor_cores_per_chip(self) -> int: + # TODO(slebedev): Remove this wrapper once pyrefly#2080 is fixed. 
+ return cast(int, self._num_physical_tensor_cores_per_chip) # type: ignore[redundant-cast] + @property def supports_megacore(self) -> bool: match self: @@ -280,7 +285,6 @@ def _get_tpu_info_impl(chip_version: ChipVersion, num_cores: int) -> TpuInfo: MXU_COLUMN_SIZE_GEN_LT_6 = 128 MXU_COLUMN_SIZE_GEN_GE_6 = 256 tensor_cores_per_chip = chip_version.num_physical_tensor_cores_per_chip - match chip_version: case ChipVersion.TPU_V2: return TpuInfo( diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 4e72d395a3ce..4be24c0ce8f5 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -1447,7 +1447,7 @@ def __post_init__(self): def to_mgpu(self, *args, **kwargs) -> mgpu.FragmentedLayout: if args or kwargs: raise ValueError(f"Can't instantiate {self} with arguments.") - return self.layout_cls.to_mgpu(*self.args, **self.kwargs) + return self.layout_cls.to_mgpu(*self.args, **self.kwargs) # pyrefly: ignore[bad-return] @dataclasses.dataclass(frozen=True) @@ -1492,7 +1492,7 @@ class Layout(SomeLayout, enum.Enum): def __call__(self, *args, **kwargs) -> ParameterizedLayout: return ParameterizedLayout(self, args, kwargs) - def to_mgpu(self, *args, **kwargs) -> mgpu.FragmentedLayout: # pyrefly: ignore[bad-override] + def to_mgpu(self, *args, **kwargs) -> mgpu.FragmentedLayout: # pyrefly: ignore[bad-override, bad-return] def check_no_args(): if args or kwargs: raise ValueError(f"Can't instantiate {self} with arguments.") diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 8a0cd88b3ff4..deadd6317e03 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -441,7 +441,7 @@ class ModuleContext: outer_traceback: xc.Traceback | None = None @property - def single_lane_predicate(self) -> ir.Value: + def single_lane_predicate(self) -> ir.Value | None: """Returns a predicate that is True for a single lane within the current thread 
semantics. """ @@ -914,7 +914,7 @@ def lower_jaxpr_to_module( jaxpr, ) - def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): + def body(launch_ctx: mgpu.LaunchContext, *buffers: Any): *buffers_gmem, ( runtime_smem, runtime_barriers, @@ -1441,13 +1441,13 @@ def _extract_aliased_ref( raise NotImplementedError("Only byte-aligned bitcasts are supported.") assert offset % gpu_core.SMEM_ALIGNMENT == 0 ref_bytes = ref_bits // 8 - ref = mgpu.memref_slice(ref, slice(offset, offset + ref_bytes)) + ref = mgpu.memref_slice(ref, slice(offset, offset + ref_bytes)) # pyrefly: ignore[bad-argument-type] ref = _handle_dtype_bitcast( ref, ir.MemRefType(ref.type).element_type, mgpu_utils.dtype_to_ir_type(dtype), ) - ref = mgpu.memref_reshape(ref, transformed_shape) + ref = mgpu.memref_reshape(ref, transformed_shape) # pyrefly: ignore[bad-assignment] return ( ref, ref_aval, @@ -1576,7 +1576,7 @@ def _handle_transforms( ref, ref_aval, transform_avals, transforms = _extract_aliased_ref( ref, ref_aval, transform_avals, transforms ) - transformed_ref = ref + transformed_ref: Any = ref new_transforms = [] new_transforms_avals = [] peer_device_id = None @@ -1671,18 +1671,18 @@ def _handle_transforms( " primitive." 
) transformed_ref = ctx.launch_ctx.to_remote( - transformed_ref, _ensure_ir_value(peer_device_id, jnp.int32) + transformed_ref, _ensure_ir_value(peer_device_id, jnp.int32) # pyrefly: ignore[bad-argument-type] ) if is_multicast: - transformed_ref = ctx.launch_ctx.to_remote_multicast(transformed_ref) + transformed_ref = ctx.launch_ctx.to_remote_multicast(transformed_ref) # pyrefly: ignore[bad-argument-type] assert isinstance(ref_aval, state_types.AbstractRef) - return transformed_ref, ref_aval, new_transforms + return transformed_ref, ref_aval, new_transforms # pyrefly: ignore[bad-return] def _ndindexer_indices( indexer: indexing.NDIndexer, allow_arrays: bool = False -) -> tuple[gpu_core.Index | mgpu.FragmentedArray | ir.Value, ...]: - indices = [] +) -> tuple[Any, ...]: + indices: list[Any] = [] for idx in indexer.indices: if (isinstance(idx, mgpu.FragmentedArray) and idx.shape) or ( isinstance(idx, ir.Value) and isinstance(idx.type, ir.VectorType) # pytype: disable=attribute-error @@ -1699,7 +1699,7 @@ def _ndindexer_indices( raise NotImplementedError("Dynamic slice size not supported.") indices.append( mgpu.DynamicSlice( - _as_index(idx.start) if idx.is_dynamic_start else idx.start, + _as_index(idx.start) if idx.is_dynamic_start else idx.start, # pyrefly: ignore[bad-argument-type] int(idx.size), ) ) @@ -2825,11 +2825,11 @@ def _reduce_lowering_rule_wg( ) reduction = vector_dialect.ReductionOp(out_type, kind, x) else: - acc = vector_dialect.broadcast( + acc_vec = vector_dialect.broadcast( ir.VectorType.get(out_aval.shape, out_type), _ensure_ir_value(acc, out_aval.dtype), ) - reduction = vector_dialect.MultiDimReductionOp(kind, x, acc, axes) + reduction = vector_dialect.MultiDimReductionOp(kind, x, acc_vec, axes) def i32_attr(value: int) -> ir.IntegerAttr: return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value) reduction.attributes["offset"] = i32_attr(ctx.module_ctx.smem_used_bytes) @@ -3304,7 +3304,7 @@ def _run_state_lowering_rule( ) assert not 
new_consts outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, () + ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, () # pyrefly: ignore[bad-argument-type] ) # Await the accumulators and extract their final values. nvvm_dialect.wgmma_wait_group_sync_aligned(0) @@ -3577,7 +3577,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches, index_aval, *_arg_avals = ctx.avals_in def _yielded_values(outs, avals): - ret = [] + ret: list[Any] = [] for out, aval in zip(outs, avals): if isinstance(out, (mgpu.WGMMAAccumulator, mgpu.FragmentedArray)): ret.append(out) @@ -3818,7 +3818,7 @@ def _ensure_ir_value(x: Any, dtype: jnp.dtype) -> ir.Value: return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) -def _ensure_ir_value_device_id(device_id: Any) -> ir.Value: +def _ensure_ir_value_device_id(device_id: Any) -> Any: ensure_i32 = functools.partial(_ensure_ir_value, dtype=jnp.int32) if isinstance(device_id, tuple): return tuple(map(ensure_i32, device_id)) @@ -3997,6 +3997,7 @@ def _semaphore_signal_lowering_rule( raise NotImplementedError( f"Only JAX mesh axes can be used in device_id, but found {other_axes}" ) + assert device_id is not None sem = ctx.launch_ctx.to_remote(sem, device_id) sem_ptr = mgpu.utils.memref_ptr(sem) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 047bcd301510..a76f30769e33 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -98,7 +98,7 @@ def _print_layout_lowering( if transforms_leaves: assert isinstance(ctx.avals_in[0], state_types.AbstractRef) transform_avals = transforms_tree.unflatten(ctx.avals_in[1:]) - x, _, remaining_transforms = lowering._handle_transforms( + x, _, remaining_transforms = lowering._handle_transforms( # pyrefly: ignore[bad-specialization] ctx, ctx.avals_in[0], x, transform_avals, transforms_tree.unflatten(transforms_leaves), ) @@ -107,7 
+107,7 @@ def _print_layout_lowering( f"Unsupported transforms {remaining_transforms}." ) if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: - print(fmt.format(mgpu.dialect_lowering.pprint_layout(x))) + print(fmt.format(mgpu.dialect_lowering.pprint_layout(x))) # pyrefly: ignore[bad-argument-type] else: assert isinstance(x, ir.Value) mgpu.dialect.print_layout(fmt, x) @@ -638,7 +638,7 @@ def _copy_gmem_to_smem_lowering( collective=collective, partitioned=partitioned_axis, **copy_params, - **predicate_kwarg, + **predicate_kwarg, # pyrefly: ignore[bad-argument-type] ) return () i32 = ir.IntegerType.get_signless(32) @@ -795,7 +795,7 @@ def _async_prefetch_lowering( collective=collective, partitioned=partitioned_axis, **copy_params, - **predicate_kwarg, + **predicate_kwarg, # type: ignore[arg-type] ) return () @@ -861,7 +861,7 @@ def async_prefetch( def _extract_barrier_slice_base(transforms) -> ir.Value | None: if not transforms: return None - base_index = None + base_index: ir.Value | None = None while transforms: match transforms: case [indexing.NDIndexer(indices=[idx]) as indexer, *transforms]: @@ -1521,6 +1521,7 @@ def _wgmma_accumulator_store_abstract_eval(acc, val): # the discharge rule re-binds the primitive and acc becomes a ShapedArray. 
if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef): inner = acc.inner_aval + assert isinstance(inner, jax_core.ShapedArray) elif isinstance(acc, jax_core.ShapedArray): inner = acc else: @@ -1533,7 +1534,7 @@ def _wgmma_accumulator_store_abstract_eval(acc, val): raise ValueError( f"Accumulator dtype {inner.dtype} does not match value dtype {val.dtype}" ) - effects = {gpu_core._wgmma_pipeline_effect} + effects: set[jax_core.Effect] = {gpu_core._wgmma_pipeline_effect} if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef): effects.add(state.WriteEffect(0)) return inner, effects @@ -1991,7 +1992,7 @@ def _tcgen05_mma_lowering( ) assert isinstance(a_sparse_metadata_ref_aval, state_types.AbstractRef) a_sparse_metadata_ref, _, a_sparse_metadata_transforms = ( - lowering._handle_transforms( + lowering._handle_transforms( # pyrefly: ignore[bad-specialization] ctx, a_sparse_metadata_ref_aval, a_sparse_metadata_ref, a_sparse_metadata_transform_avals, a_sparse_metadata_transforms) ) @@ -2022,6 +2023,7 @@ def _tcgen05_mma_lowering( collective=collective, ) if arrive: + assert barrier_ref is not None tcgen05.commit_arrive(barrier_ref, collective=collective, ctx=ctx.launch_ctx) @@ -2856,7 +2858,7 @@ def _populate_custom_primitive_op_block( are returned. 
""" with ir.InsertionPoint(block): - fn_inputs = [] + fn_inputs: list[ir.Value | mgpu.FragmentedArray] = [] in_layouts_it = iter(in_layouts) in_transforms_it = iter(in_transforms) avals_in = ctx.avals_in[:pytree_args.num_leaves] diff --git a/jax/_src/pallas/mosaic_gpu/torch.py b/jax/_src/pallas/mosaic_gpu/torch.py index 2f5445575d41..768824ab5e48 100644 --- a/jax/_src/pallas/mosaic_gpu/torch.py +++ b/jax/_src/pallas/mosaic_gpu/torch.py @@ -189,8 +189,8 @@ def prepare_args(*user_args, device): for thunk in to_evaluate: thunk(env, device) return tuple(env[name] for name in mgpu_arg_names) - output_input_aliases = [None] * len(mgpu_call.results) - for alias in mgpu_call.output_operand_aliases: + output_input_aliases: list[int | None] = [None] * len(mgpu_call.results) + for alias in mgpu_call.output_operand_aliases or []: alias = hlo.OutputOperandAlias(alias) if alias.operand_tuple_indices: raise NotImplementedError("Tupled operand indices not supported") diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 7ea91d8b0eb9..56e437e56d64 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -1488,7 +1488,7 @@ def device_id_to_logical( ), ), non_mesh_axes elif device_id_type is DeviceIdType.LOGICAL: - return device_id, non_mesh_axes + return device_id, non_mesh_axes # pyrefly: ignore[bad-return] raise NotImplementedError(f"Unsupported device id type: {device_id_type}") diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 45276e24d311..18d085e3f7d3 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -183,7 +183,7 @@ def _bcast( x_aval: jax_core.ShapedArray, y_aval: jax_core.ShapedArray, out_aval: jax_core.ShapedArray, -) -> ir.Value: +) -> tuple[ir.Value, ir.Value]: if isinstance( x, (np.ndarray, np.number, int, float, literals.TypedNdArray) ): @@ -250,7 +250,7 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): prog_id_dims = 
launch_grid[num_collapse:] if len(collapse_dims) == 0: - prog_ids = [None] * len(prog_id_dims) + prog_ids: list[ir.Value | None] = [None] * len(prog_id_dims) for i in range(len(prog_id_dims)): prog_ids[launch_grid_to_pallas_grid[i]] = _program_id(i, prog_id_dims) @@ -261,7 +261,7 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): assert new_grid[0] < 2**31 - 1, \ "Cannot fix pallas kernel launch grid within CUDA limits" - out_indices = [None] * len(grid_mapping.grid) + out_indices: list[ir.Value | None] = [None] * len(grid_mapping.grid) grid0 = _program_id(0, new_grid) for i, s in enumerate(collapse_dims): @@ -416,7 +416,7 @@ def write_env(var: jax_core.Var, val): rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 try: with source_info_util.user_context(eqn.source_info.traceback), loc: - outvals = rule(rule_ctx, *invals, **eqn.params) + outvals: Any = rule(rule_ctx, *invals, **eqn.params) except LoweringError: raise # We only add the extra info to the innermost exception. 
except Exception as e: @@ -635,7 +635,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: for aval, arg_type in zip(avals, self.arg_types) ) - def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]): + def lower(self, ctx: LoweringRuleContext, *args: ir.Value): [out_aval] = ctx.avals_out bcast_args = [] for aval, arg, arg_type in zip(ctx.avals_in, args, self.arg_types): @@ -670,7 +670,7 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: for aval, arg_class in zip(avals, self.arg_classes) ) - def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]): + def lower(self, ctx: LoweringRuleContext, *args: ir.Value): [out_aval] = ctx.avals_out bcast_args = [] for aval, arg in zip(ctx.avals_in, args): @@ -1381,7 +1381,7 @@ def debug_print_lowering_rule( def _set_attr(v: ir.Value, name: str, attr: ir.Attribute) -> None: if not isinstance(v, ir.BlockArgument): - v.owner.attributes[name] = attr + v.owner.attributes[name] = attr # pyrefly: ignore[missing-attribute] return arg = ir.BlockArgument(v) @@ -1433,7 +1433,7 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): if is_reciprocal: y = -y - acc = None + acc: ir.Value | None = None while y > 0: y, mod = divmod(y, 2) if mod: @@ -1554,7 +1554,7 @@ def _make_range(start: int, end: int) -> ir.Value: ) -def _full(t: ir.Type, v: Any) -> ir.Type: +def _full(t: ir.Type, v: Any) -> ir.Value: element_type = _element_type(t) if isinstance(element_type, ir.IntegerType): result = arith_dialect.constant(element_type, int(v)) @@ -1921,9 +1921,11 @@ def _compute_offsets_from_indices( if isinstance(index, primitives.Slice): if index.is_dynamic_start or (index.stride != 1): - start = index.start if not index.is_dynamic_start: - start = _ir_constant(start, offset_eltype) + start = _ir_constant(index.start, offset_eltype) + else: + assert isinstance(index.start, ir.Value) + start = index.start start = _ir_cast(start, offset_eltype, signed=False) iota = _ir_cast( @@ 
-2273,7 +2275,7 @@ def _masked_swap_lowering_rule( other = _bcast_to(value, shape) old_value = _load(ptr, mask=mask, other=other) - _store(ptr, value, mask=mask, eviction_policy=eviction_policy) + _store(ptr, value, mask=mask, eviction_policy=eviction_policy) # pyrefly: ignore[bad-argument-type] return old_value @@ -2605,11 +2607,11 @@ def _lower_jaxpr_to_for_loop( if step != 1: raise NotImplementedError if bound_type is None or bound_type.width == 32: - step = _i32_constant(step) + step_val = _i32_constant(step) else: - step = _i64_constant(step) + step_val = _i64_constant(step) - for_op = scf_dialect.ForOp(lower_bound, upper_bound, step, args) + for_op = scf_dialect.ForOp(lower_bound, upper_bound, step_val, args) with ir.InsertionPoint.at_block_begin(for_op.body): loop_index = for_op.induction_variable for_body_args = [for_op.body.arguments[i + 1] for i, _ in enumerate(args)] diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 2e671c350ac3..b1e772130e4e 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -191,7 +191,7 @@ def pallas_call_lowering( # TODO(b/392558289): Migrate to ``jax.ffi``. 
return mlir.custom_call( call_target_name="triton_kernel_call", - result_types=mlir.flatten_ir_values( + result_types=mlir.flatten_ir_types( map(mlir.aval_to_ir_type, ctx.avals_out) ), operands=in_nodes, diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index cf4e32e2cde1..4c395dce29da 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -21,7 +21,7 @@ import functools import math import typing -from typing import Any, Literal +from typing import Any, Literal, overload import jax from jax import numpy as jnp @@ -671,6 +671,12 @@ def fold_until(shape, off, target) -> tuple[int, int]: return ref +@overload +def memref_reshape(ref: ir.Value, shape: tuple[int, ...]) -> ir.Value: ... + +@overload +def memref_reshape(ref: MultimemRef, shape: tuple[int, ...]) -> MultimemRef: ... # type: ignore[overload-cannot-match] + def memref_reshape( ref: ir.Value | MultimemRef, shape: tuple[int, ...] ) -> ir.Value | MultimemRef: From 3db21b7249d3ed5a56705d0e16a8f1773c2746b4 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 2 Mar 2026 10:36:07 -0800 Subject: [PATCH 003/100] [Mosaic GPU] Add support for s8 and u8 warp-level MMAs PiperOrigin-RevId: 877469331 --- jax/experimental/mosaic/gpu/mma.py | 49 +++++++++++++++++++----------- tests/mosaic/gpu_test.py | 39 ++++++++++++++++-------- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/jax/experimental/mosaic/gpu/mma.py b/jax/experimental/mosaic/gpu/mma.py index 31caf1702991..6c68782e9a51 100644 --- a/jax/experimental/mosaic/gpu/mma.py +++ b/jax/experimental/mosaic/gpu/mma.py @@ -54,11 +54,15 @@ def __init__(self, element_type: ir.Type): ) -def _ptx_dtype_str(dtype: ir.Type) -> str: +def _ptx_dtype_str(dtype: ir.Type, *, is_signed: bool | None = None) -> str: if isinstance(dtype, ir.Float8E4M3FNType): return "e4m3" elif isinstance(dtype, ir.Float8E5M2Type): return "e5m2" + elif isinstance(dtype, ir.IntegerType): + if is_signed is None: + 
raise ValueError("is_signed must be specified for integer types") + return "s8" if is_signed else "u8" return str(dtype) @@ -66,13 +70,16 @@ def _mma_single_tile( acc: fa.FragmentedArray, a: fa.FragmentedArray, b: fa.FragmentedArray ) -> fa.FragmentedArray: """Performs `acc + a @ b.T` using warp level MMA instructions.""" + i32 = ir.IntegerType.get_signless(32) k_tile = 32 // utils.bytewidth(a.mlir_dtype) assert a.shape == (64, k_tile) assert b.shape == (8, k_tile) assert acc.shape == (64, 8) assert a.mlir_dtype == b.mlir_dtype - assert acc.mlir_dtype == ir.F32Type.get() + is_integer = isinstance(a.mlir_dtype, ir.IntegerType) + assert acc.mlir_dtype == i32 if is_integer else ir.F32Type.get() + assert acc.is_signed in {None, True} assert ( isinstance(acc.layout, fa.TiledLayout) and isinstance(a.layout, fa.TiledLayout) @@ -89,7 +96,6 @@ def _mma_single_tile( for reg in acc.registers.flatten() for pos in range(acc.layout.vector_length) ] - i32 = ir.IntegerType.get_signless(32) a_regs = [utils.bitcast(r, i32) for r in a.registers.flatten()] b_regs = [utils.bitcast(r, i32) for r in b.registers.flatten()] @@ -98,9 +104,11 @@ def _mma_single_tile( assert len(acc_regs) == 4 assert len(b_regs) == 2 - a_ptx_dtype = _ptx_dtype_str(a.mlir_dtype) - b_ptx_dtype = _ptx_dtype_str(b.mlir_dtype) - instr = f"mma.sync.aligned.m16n8k{k_tile}.row.col.f32.{a_ptx_dtype}.{b_ptx_dtype}.f32" + a_ptx_dtype = _ptx_dtype_str(a.mlir_dtype, is_signed=a.is_signed) + b_ptx_dtype = _ptx_dtype_str(b.mlir_dtype, is_signed=b.is_signed) + acc_ptx_dtype = "s32" if is_integer else "f32" + acc_constraint = "r" if is_integer else "f" + instr = f"mma.sync.aligned.m16n8k{k_tile}.row.col.{acc_ptx_dtype}.{a_ptx_dtype}.{b_ptx_dtype}.{acc_ptx_dtype}" counter = itertools.count() n_regs_str = lambda n: ( "{" + ",".join([f"${next(counter)}" for _ in range(n)]) + "}" @@ -112,10 +120,10 @@ def _mma_single_tile( ptx = f"{instr} {out_regs_str}, {a_regs_str}, {b_regs_str}, {c_regs_str};" # See: 
https://llvm.org/docs/LangRef.html#inline-assembler-expressions constraints = ( - f"{','.join(['=f']*num_acc_regs)}," # Output accumulator regs - f"{','.join(['r']*num_a_regs)}," # Input A regs + f"{','.join([f'={acc_constraint}']*num_acc_regs)}," + f"{','.join(['r']*num_a_regs)}," f"{','.join(['r']*num_b_regs)}," - f"{','.join(['f']*num_acc_regs)}" # Input accumulator regs + f"{','.join([acc_constraint]*num_acc_regs)}" ) in_operands = [*a_regs, *b_regs, *acc_regs] @@ -141,7 +149,7 @@ def _mma_single_tile( vec_regs.append(vec) out_regs = np.asarray(vec_regs, dtype=object).reshape(acc.registers.shape) return fa.FragmentedArray( - _registers=out_regs, _layout=acc.layout, _is_signed=None + _registers=out_regs, _layout=acc.layout, _is_signed=acc.is_signed ) @@ -185,19 +193,26 @@ def mma( # sharded across warps. bf16 = ir.BF16Type.get() f16 = ir.F16Type.get() + i8 = ir.IntegerType.get_signless(8) + i32 = ir.IntegerType.get_signless(32) f8e4m3fn = ir.Float8E4M3FNType.get() f8e5m2 = ir.Float8E5M2Type.get() - if a.mlir_dtype != b.mlir_dtype: + if (element_type := a.mlir_dtype) != b.mlir_dtype: raise ValueError(f"Dtype mismatch: {a.mlir_dtype} != {b.mlir_dtype}") - if a.mlir_dtype not in (bf16, f16, f8e4m3fn, f8e5m2): + if element_type not in (bf16, f16, f8e4m3fn, f8e5m2, i8): raise NotImplementedError( - "Only bf16, f16, float8_e4m3fn and float8_e5m2 supported for the" + "Only bf16, f16, float8_e4m3fn, float8_e5m2 and i8 supported for the" " operands." 
) - if acc.mlir_dtype != ir.F32Type.get(): - raise NotImplementedError("Only f32 accumulator supported.") - - layouts = MMALayouts(a.mlir_dtype) + if element_type == i8: + if acc.mlir_dtype != i32: + raise NotImplementedError("Only s32 accumulator supported for i8 operands.") + if not acc.is_signed: + raise ValueError("Only signed accumulator supported for i8 operands.") + elif acc.mlir_dtype != ir.F32Type.get(): + raise NotImplementedError("Only f32 accumulator supported for floating operands.") + + layouts = MMALayouts(element_type) if layouts.lhs != a.layout: raise ValueError("Expected MMALayouts.lhs layout for A") if layouts.rhs != b.layout: diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 6bebcc9fe1ac..bcbd3d028ccc 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -4683,17 +4683,20 @@ def kernel(ctx, dst, _): np.testing.assert_array_equal(result, (iota > 10).astype(jnp.uint8)) @parameterized.product( - dtype=(jnp.bfloat16, jnp.float16, jnp.float8_e4m3fn, jnp.float8_e5m2), + dtype=(jnp.bfloat16, jnp.float16, jnp.float8_e4m3fn, jnp.float8_e5m2, + jnp.int8, jnp.uint8), ) def test_warp_mma(self, dtype): - dtype = jnp.dtype(dtype) m, n, k = 128, 128, 128 + dtype = jnp.dtype(dtype) + is_integer = jnp.issubdtype(dtype, jnp.integer) + acc_dtype = jnp.int32 if is_integer else jnp.float32 k_tile = 32 if dtype.itemsize == 1 else 16 def kernel(ctx: mgpu.LaunchContext, acc, a, b, out, scratch): (acc_smem, a_smem, b_smem), barrier = scratch layouts = mgpu.MMALayouts(utils.dtype_to_ir_type(dtype)) - def load(x, x_smem, layout, swizzle=32): + def load(x, x_smem, layout, dtype, swizzle=32): ctx.async_copy( src_ref=x, dst_ref=x_smem, @@ -4702,11 +4705,13 @@ def load(x, x_smem, layout, swizzle=32): barrier=barrier, ) barrier.wait() - return fa.FragmentedArray.load_tiled(x_smem, swizzle=swizzle, layout=layout) + return fa.FragmentedArray.load_tiled( + x_smem, swizzle=swizzle, layout=layout, is_signed=utils.is_signed(dtype) + ) - b_fa = 
load(b, b_smem, layouts.rhs) - a_fa = load(a, a_smem, layouts.lhs) - acc_fa = load(acc, acc_smem, layouts.acc) + b_fa = load(b, b_smem, layouts.rhs, dtype) + a_fa = load(a, a_smem, layouts.lhs, dtype) + acc_fa = load(acc, acc_smem, layouts.acc, acc_dtype) result_fa: mgpu.FragmentedArray = mgpu.mma(acc_fa, a_fa, b_fa) result_fa.store_tiled(acc_smem, swizzle=32) mgpu.commit_shared() @@ -4718,11 +4723,16 @@ def load(x, x_smem, layout, swizzle=32): ) ctx.await_async_copy(0) - a = self.prng.uniform(-1, 1, (m, k)).astype(dtype) - b = self.prng.uniform(-1, 1, (n, k)).astype(dtype) - acc = self.prng.uniform(-1, 1, (m, n)).astype(jnp.float32) + if is_integer: + a = self.prng.integers(-32, 32, (m, k)).astype(dtype) + b = self.prng.integers(-32, 32, (n, k)).astype(dtype) + acc = self.prng.integers(-100, 100, (m, n)).astype(acc_dtype) + else: + a = self.prng.uniform(-1, 1, (m, k)).astype(dtype) + b = self.prng.uniform(-1, 1, (n, k)).astype(dtype) + acc = self.prng.uniform(-1, 1, (m, n)).astype(acc_dtype) - expected = acc + a.astype(jnp.float32) @ b.astype(jnp.float32).T + expected = acc + a.astype(acc_dtype) @ b.astype(acc_dtype).T result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), @@ -4731,14 +4741,17 @@ def load(x, x_smem, layout, swizzle=32): out_shape=expected, smem_scratch_shape=( mgpu.Union([ - jax.ShapeDtypeStruct(mgpu.tile_shape((m, n), (8, 8)), dtype=jnp.float32), + jax.ShapeDtypeStruct(mgpu.tile_shape((m, n), (8, 8)), dtype=acc_dtype), jax.ShapeDtypeStruct(mgpu.tile_shape((m, k), (8, k_tile)), dtype=dtype), jax.ShapeDtypeStruct(mgpu.tile_shape((n, k), (8, k_tile)), dtype=dtype), ]), mgpu.Barrier(1) ), )(acc, a, b) - np.testing.assert_allclose(result, expected, atol=1e-5) + if is_integer: + np.testing.assert_array_equal(result, expected) + else: + np.testing.assert_allclose(result, expected, atol=1e-5) @parameterized.parameters( (jnp.uint8, jnp.uint16, 255), From c097ebedba41eb7d7b73773d08b29317a84be004 Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Mon, 2 Mar 2026 
10:48:19 -0800 Subject: [PATCH 004/100] Add TPU v7 to `test_vmem_oom_error_message_basics`. PiperOrigin-RevId: 877475185 --- tests/pallas/tpu_pallas_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index c0b44f7c7d37..a74e17bcc1d3 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -2269,6 +2269,8 @@ def test_vmem_oom_error_message_basics(self, pmode: pl.Buffered): version=6, variant='e' ): block_shape = (4096 // pmode.buffer_count, 8192) + elif jtu.is_device_tpu(version=7, variant='x'): + block_shape = (2048 // pmode.buffer_count, 8192) elif jtu.is_device_tpu(version=5, variant='p'): block_shape = (1024, 8192) else: From 1bd107b633c5fd9d1a9e9ddc27a8d735dee5f5f8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:53:11 +0000 Subject: [PATCH 005/100] Bump actions/upload-artifact from 6.0.0 to 7.0.0 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 6.0.0 to 7.0.0. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/b7c566a772e6b6bfb58ed0dc250532a479d7789f...bbbca2ddaa5d8feaa63e36b76fdaad77386f024f) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: 7.0.0 dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] --- .github/workflows/rocm-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 1065e9eb2aea..7eaf4a58b2f2 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -46,7 +46,7 @@ jobs: dist_docker \ --image-tag $TEST_IMAGE - name: Archive jax wheels - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }} path: ${{ env.WORKSPACE_DIR }}/dist/*.whl From a4dd1ddceb87c310102c70583d8bba942dbec87a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:53:40 +0000 Subject: [PATCH 006/100] Bump actions/checkout from 6.0.0 to 6.0.2 Bumps [actions/checkout](https://github.com/actions/checkout) from 6.0.0 to 6.0.2. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v6...de0fac2e4500dabe0009e67214ff5f5447ce83dd) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: 6.0.2 dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] --- .github/workflows/bazel_rocm.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/bazel_rocm.yml b/.github/workflows/bazel_rocm.yml index 862d42203912..1e2674497a25 100644 --- a/.github/workflows/bazel_rocm.yml +++ b/.github/workflows/bazel_rocm.yml @@ -117,7 +117,7 @@ jobs: name: "linux x86, jaxlib=${{ inputs.jaxlib-version }}, ROCM=${{ inputs.rocm-version }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}" # End Presubmit Naming Check github-rocm-presubmits steps: - - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: ROCm Info From c0c98433ec3a28222819d79e2bf49530ce6db23f Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Mon, 2 Mar 2026 11:13:53 -0800 Subject: [PATCH 007/100] [NFC] Remove and modify some out-of-date test skips. 
PiperOrigin-RevId: 877489290 --- tests/pallas/tpu_ops_test.py | 31 ++++++++++++------------------ tests/pallas/tpu_pallas_test.py | 34 ++------------------------------- 2 files changed, 14 insertions(+), 51 deletions(-) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index e04ae582ad7b..0ef112cdfcd7 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -135,8 +135,6 @@ def kernel(x_ref, y_ref): self.assertAllClose(y, x + 1) def test_interleave_vectors(self): - if not jtu.is_device_tpu_at_least(version=4): - self.skipTest("Expect TPUv4+") def kernel(x_ref, y_ref, out_ref): x = pltpu.bitcast(x_ref[...].astype(jnp.float32), jnp.int32) @@ -229,8 +227,10 @@ def pallas_fn(a, b, c, d): @parameterized.product(from_dtype=_JAX_INT_DTYPES, to_dtype=_JAX_INT_DTYPES) def test_integer_cast(self, from_dtype, to_dtype): - if not jtu.is_device_tpu_at_least(4): - self.skipTest("Expect TPUv4+") + if ( + jnp.iinfo(from_dtype).bits < 8 or jnp.iinfo(to_dtype).bits < 8 + ) and not jtu.is_device_tpu_at_least(4): + self.skipTest("sub-byte types casting requires TPUv4+") # Generate both low and high values to better cover the entire range # of the source dtype. 
min_val = from_dtype(jnp.iinfo(from_dtype).min) @@ -386,14 +386,11 @@ def kernel(x, out): ) def test_i1_relayout_bw(self, shape, msk_dtype, dtype): msk_bitwidth = dtypes.itemsize_bits(msk_dtype) - bitwidth = dtypes.itemsize_bits(dtype) - if jtu.get_tpu_version() < 5 and msk_bitwidth < 32: + if jtu.get_tpu_version() < 5 and msk_bitwidth < 16: self.skipTest( "Not implemented: cast vector to mask with bitwidth ==" f" {msk_bitwidth}" ) - if jtu.get_tpu_version() < 5 and bitwidth < 32: - self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}") @functools.partial( pl.pallas_call, @@ -417,18 +414,14 @@ def kernel(x_ref, mask_ref, o_ref): dtype=[jnp.float32, jnp.bfloat16, jnp.int8], ) def test_i1_relayout_bw_tiling(self, msk_dtype, dtype): - self.skipTest("TODO: jevinjiang - Enable once presubmits pass.") shape = (256, 256) - bitwidth = dtypes.itemsize_bits(dtype) msk_bitwidth = dtypes.itemsize_bits(msk_dtype) msk_packing = 32 // msk_bitwidth - if jtu.get_tpu_version() < 5 and msk_bitwidth < 32: + if jtu.get_tpu_version() < 5 and msk_bitwidth < 16: self.skipTest( "Not implemented: cast vector to mask with bitwidth ==" f" {msk_bitwidth}" ) - if jtu.get_tpu_version() < 5 and bitwidth < 32: - self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}") # Creating large tiling for masks by passing i32 vector first and # then bitcast to msk_dtype so the tiling is also bitcasted from @@ -515,8 +508,6 @@ def kernel(x, indices, out): @parameterized.product(dtype=[jnp.float32, jnp.bfloat16]) def test_float_div(self, dtype): - if not jtu.is_device_tpu_at_least(version=4): - self.skipTest("Requires TPUv4+") kwargs = {} if jtu.is_device_tpu_at_least(version=6): kwargs.update(dict(rtol=1e-2)) @@ -537,7 +528,7 @@ def kernel(x, y, out): ) def test_concat_mask(self, dtype): bitwidth = dtypes.itemsize_bits(dtype) - if jtu.get_tpu_version() < 5 and bitwidth < 32: + if jtu.get_tpu_version() < 5 and bitwidth < 16: self.skipTest( f"Not implemented: cast vector to 
mask with bitwidth == {bitwidth}" ) @@ -788,9 +779,9 @@ def _pack_unpack_elementwise_test_data( ) def test_pack_elementwise(self, config, shape): unpacked_dtype, packed_dtype = config - if not jtu.is_device_tpu_at_least(version=5): - self.skipTest("Requires TPU v5+") if packed_dtype == jnp.int2: + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("Requires TPU v5+") if not jtu.is_cloud_tpu_at_least(2026, 3, 1): raise self.skipTest( "int2 is only supported for tpu at least 03/01/2026" @@ -832,7 +823,9 @@ def kernel(xs_ref, o_ref): ) def test_unpack_elementwise(self, config, index, shape): unpacked_dtype, packed_dtype = config - if not jtu.is_device_tpu_at_least(version=5): + if packed_dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least( + version=5 + ): self.skipTest("Requires TPU v5+") bitwidth = dtypes.itemsize_bits(packed_dtype) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index a74e17bcc1d3..53a3ac140898 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1247,8 +1247,6 @@ def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): np.testing.assert_array_equal(sem_val, 0) def test_set_dma_priority(self): - if jtu.get_tpu_version() < 5: - self.skipTest('Target does not support DMA prefetch between HBM and VMEM') def kernel(x1, x2, y1, y2, scratch1, scratch2, sem1, sem2): copy1 = pltpu.async_copy(x1, scratch1, sem1, priority=1) copy2 = pltpu.async_copy(x2, scratch2, sem2, priority=0) @@ -1541,8 +1539,8 @@ def body(x_ref, y_ref, sem): np.testing.assert_allclose(y, x) def test_dma_with_regular_semaphore(self): - if not jtu.is_device_tpu_at_least(6): - self.skipTest('Regular semaphores in DMAs require TPU v6+') + if not jtu.is_device_tpu_at_least(5): + self.skipTest('Regular semaphores in DMAs require TPU v5+') if not jtu.is_cloud_tpu_at_least(2026, 3, 2): self.skipTest("Test requires a newer libtpu") @@ -1775,9 +1773,6 @@ def _(): np.testing.assert_array_equal(y, x + 3) def 
test_hoisted_smem_space(self): - # TODO(sharadmv,apaszke): enable SMEM scratch spaces - # TODO(sharadmv,apaszke): add support for ()-shaped SMEM refs - self.skipTest('Currently doesn\'t work') def kernel(y_ref, scratch_ref): scratch_ref[0, 0] = pl.program_id(0) y_ref[...] = jnp.broadcast_to(scratch_ref[0, 0], y_ref.shape) @@ -2102,12 +2097,6 @@ def kernel(x, out): def test_replicated_broadcast_reduction( self, m, replicated, reduced_dims, dty, reduce_func ): - # TODO(b/395579834): Remove this skip later. - if ( - dty == jnp.int32 - and 1 in reduced_dims - ): - self.skipTest('Requires libtpu built after 2025-09-01') if not jtu.is_device_tpu_at_least(4) and len(replicated) == 2: self.skipTest( 'Brodcast in both sublanes and lanes not supported on this hardware' @@ -3417,8 +3406,6 @@ def kernel(mask_ref, true_ref, false_ref, o_ref): np.testing.assert_array_equal(result, expected) def test_bool_dma_not_implemented(self): - if not jtu.is_device_tpu_at_least(4): - self.skipTest('DMAs not supported on TPU generations <= 3') if self.INTERPRET: self.skipTest('Test only applies to non-interpret mode.') num_devices = jax.local_device_count() @@ -3664,8 +3651,6 @@ def kernel(x_ref, y_ref, out_ref): np.testing.assert_array_equal(out, np.stack([x, y], axis=1)) def test_lane_to_chunk_reshape_bf16(self): - if not jtu.is_device_tpu_at_least(4): - self.skipTest('Operation not supported on this TPU version.') x = np.arange(256 * 1024, dtype=jnp.bfloat16).reshape(1, 256, 1024) def kernel(x_ref, out_ref): @@ -3805,8 +3790,6 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128))) def test_sublane_adding_shape_cast_bf16(self): - if not jtu.is_device_tpu_at_least(4): - self.skipTest('Operation not supported on this TPU version.') x = np.arange(8 * 128, dtype=jnp.bfloat16).reshape(8, 128) def kernel(x_ref, out_ref): @@ -3898,10 +3881,6 @@ def kernel(x_ref, o_ref): ) ) def test_reshape_two_minor_dims_to_R2(self, q, m, n, dtype): - if (dtype == 
jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( - dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) - ): - self.skipTest('Operation not supported on this TPU version.') def kernel(x_ref, y_ref): y_ref[...] = x_ref[...].reshape( x_ref.shape[0], x_ref.shape[1] * x_ref.shape[2] @@ -3931,10 +3910,6 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_two_minor_dims_to_R3(self, q, m, n, k, dtype): - if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( - dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) - ): - self.skipTest('Operation not supported on this TPU version.') def kernel(x_ref, y_ref): y_ref[...] = x_ref[...].reshape( x_ref.shape[0], x_ref.shape[1], x_ref.shape[2] * x_ref.shape[3] @@ -4059,10 +4034,6 @@ def kernel(x_ref, y_ref): ) ) def test_reshape_two_minor_dims_preserve_rank(self, q, m, n, k, dtype): - if (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) or ( - dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5) - ): - self.skipTest('Operation not supported on this TPU version.') def kernel(x_ref, y_ref): y_ref[...] = ( x_ref[...] @@ -4074,7 +4045,6 @@ def kernel(x_ref, y_ref): ) ) - q, m, n, k = 10, 1, 4, 256 x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n, k) out = self.pallas_call( kernel, From 6bbd83d949ed8543dfd6db5ce58c543b2ef926a5 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 2 Mar 2026 11:44:58 -0800 Subject: [PATCH 008/100] Skip lax_numpy_reducers_test reducer tests with where. Requires a new nightly libtpu (latest is libtpu-0.0.37.dev20260224). 
PiperOrigin-RevId: 877504588 --- tests/lax_numpy_reducers_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index b868c2fa4694..e388f71e08f9 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -425,6 +425,8 @@ def np_fun(x): )) def testReducerWhere(self, name, rng_factory, shape, dtype, axis, keepdims, initial, inexact, whereshape, tol): + if not jtu.is_cloud_tpu_at_least(2026, 2, 28): + self.skipTest("Requires a newer libTPU") np_op = getattr(np, name) jnp_op = getattr(jnp, name) if (shape in [()] + scalar_shapes and @@ -488,6 +490,8 @@ def testReducerWhereNonBooleanErrorNoInitial(self, rec): )) def testReducerWhereNoInitial(self, name, rng_factory, shape, dtype, axis, keepdims, inexact, whereshape, tol): + if not jtu.is_cloud_tpu_at_least(2026, 2, 28): + self.skipTest("b/412684823: Requires a newer libTPU") np_op = getattr(np, name) jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) From 2260303e3e7ffb8044f3460c0ddff96fb5f04194 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 2 Mar 2026 11:45:00 -0800 Subject: [PATCH 009/100] Skip installing collecting profile requirements under 3.13-nogil. xprof pulls in cffi, which is not supported under 3.13-nogil. PiperOrigin-RevId: 877504615 --- .github/workflows/pytest_tpu.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pytest_tpu.yml b/.github/workflows/pytest_tpu.yml index b6cf6c9d5037..22cd877d063e 100644 --- a/.github/workflows/pytest_tpu.yml +++ b/.github/workflows/pytest_tpu.yml @@ -93,7 +93,11 @@ jobs: gcs_download_uri: ${{ inputs.gcs_download_uri }} - name: Install Python dependencies run: | - $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt -r build/collect-profile-requirements.txt + $JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt + # xprof depends on cffi, which doesn't support Python 3.13 free-threaded. 
+ if [[ "${JAXCI_HERMETIC_PYTHON_VERSION}" != "3.13-nogil" ]]; then + $JAXCI_PYTHON -m uv pip install -r build/collect-profile-requirements.txt + fi - name: Set up libtpu wheels run: | if [[ "${INPUTS_LIBTPU_VERSION_TYPE}" == "nightly" ]]; then @@ -124,10 +128,6 @@ jobs: - name: Run Pytest TPU tests timeout-minutes: ${{ github.event_name == 'pull_request' && 30 || 210 }} run: | - if [[ ${INPUTS_PYTHON} == "3.13-nogil" ]]; then - echo "Uninstalling xprof as it is not compatible with python 3.13t." - $JAXCI_PYTHON -m uv pip uninstall xprof - fi ./ci/run_pytest_tpu.sh env: INPUTS_PYTHON: ${{ inputs.python }} From c9de85a78e656137fb845dd2e1b041b9777fd01d Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Mar 2026 11:50:15 -0800 Subject: [PATCH 010/100] [pyrefly] fix typing errors in jax/_src/array.py --- jax/_src/array.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 1ce52849a9de..f2a0c2888928 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -334,7 +334,7 @@ def __format__(self, format_spec): else: return repr(self) - def __getitem__(self, idx): + def __getitem__(self, idx): # pyrefly: ignore[bad-param-name-override] from jax._src.lax import lax # pytype: disable=import-error from jax._src.numpy import indexing # pytype: disable=import-error self._check_if_deleted() @@ -360,7 +360,7 @@ def __getitem__(self, idx): dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int)) # Squeeze on committed arrays to avoid data movement to shard 0. 
out = lax.squeeze(out, dimensions=dims) - + assert isinstance(out, ArrayImpl) return ArrayImpl( out.aval, sharding, [out], committed=False, _skip_checks=True) @@ -372,7 +372,7 @@ def __iter__(self): else: assert self.is_fully_replicated or self.is_fully_addressable if dispatch.is_single_device_sharding(self.sharding) or self.is_fully_replicated: - return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) + return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack()) # pyrefly: ignore[missing-attribute] elif isinstance(self.sharding, PmapSharding): return (self[i] for i in range(self.shape[0])) else: @@ -433,10 +433,10 @@ def is_fully_addressable(self) -> bool: """ return self.sharding.is_fully_addressable - def __array__(self, dtype=None, context=None, copy=None): + def __array__(self, dtype=None, context=None, copy=None): # pyrefly: ignore[bad-override] # copy argument is supported by np.asarray starting in numpy 2.0 kwds = {} if copy is None else {'copy': copy} - return np.asarray(self._value, dtype=dtype, **kwds) + return np.asarray(self._value, dtype=dtype, **kwds) # pyrefly: ignore[no-matching-overload] def __dlpack__(self, *, stream: int | Any | None = None, max_version: tuple[int, int] | None = None, @@ -464,10 +464,10 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: from jax._src.dlpack import DLDeviceType # pytype: disable=import-error # pylint: disable=g-import-not-at-top - if self.platform() == "cpu": + if self.platform() == "cpu": # pyrefly: ignore[missing-attribute] return DLDeviceType.kDLCPU, 0 - elif self.platform() == "gpu": + elif self.platform() == "gpu": # pyrefly: ignore[missing-attribute] platform_version = _get_device(self).client.platform_version if "cuda" in platform_version: dl_device_type = DLDeviceType.kDLCUDA @@ -486,7 +486,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: else: raise BufferError( "__dlpack__ device only supported for CPU and GPU, got platform: " - f"{self.platform()}" + 
f"{self.platform()}" # pyrefly: ignore[missing-attribute] ) def __reduce__(self): @@ -533,7 +533,7 @@ def device_buffers(self): def addressable_data(self, index: int) -> ArrayImpl: self._check_if_deleted() if self.is_fully_replicated: - return self._fully_replicated_shard() + return self._fully_replicated_shard() # pyrefly: ignore[missing-attribute] return self._arrays[index] @functools.cached_property @@ -550,7 +550,7 @@ def format(self): if self.is_deleted(): return Format(None, self.sharding) try: - return Format(Layout.from_pjrt_layout(self._pjrt_layout), + return Format(Layout.from_pjrt_layout(self._pjrt_layout), # pyrefly: ignore[missing-attribute] self.sharding) except _jax.JaxRuntimeError as e: msg, *_ = e.args @@ -586,8 +586,8 @@ def delete(self): return for buf in self._arrays: buf.delete() - self._arrays = None - self._npy_value = None + self._arrays = None # pyrefly: ignore[bad-assignment] + self._npy_value = None # pyrefly: ignore[bad-assignment] @use_cpp_method() def is_deleted(self): @@ -760,11 +760,12 @@ def make_array_from_callback( raise TypeError( "`Layout.AUTO` cannot be used in place of a device-local" f" layout when calling `jax.make_array_from_callback`. Got {sharding}") - sharding = sharding.sharding if isinstance(sharding, Format) else sharding - if not isinstance(sharding, Sharding): + processed_sharding = sharding.sharding if isinstance(sharding, Format) else sharding + if not isinstance(processed_sharding, Sharding): raise TypeError( - f"sharding should be an instance of `jax.sharding`. Got {sharding} of" - f" type {type(sharding)}") + f"sharding should be an instance of `jax.sharding`. 
Got {processed_sharding} of" + f" type {type(processed_sharding)}") + sharding = processed_sharding def get_data( index: Index | None, @@ -1102,11 +1103,11 @@ def make_array_from_single_device_arrays( if dtypes.issubdtype(aval.dtype, dtypes.extended): return aval.dtype._rules.make_sharded_array(aval, sharding, arrays, committed=True) - arrays = list(arrays) if isinstance(arrays, tuple) else arrays + arrays = list(arrays) if isinstance(arrays, tuple) else arrays # pyrefly: ignore[no-matching-overload] # pyrefly#2607 # TODO(phawkins): ideally the cast() could be checked. try: return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), - committed=True) + committed=True) except TypeError: if not isinstance(arrays, list): raise TypeError("jax.make_array_from_single_device_arrays `arrays` " @@ -1155,7 +1156,7 @@ def as_slice_indices(arr: Any, idx: Index) -> tuple[ removed_dims: list[int] = [] tuple_idx = idx if isinstance(idx, tuple) else (idx,) - for dim, sub_idx in enumerate(tuple_idx): + for dim, sub_idx in enumerate(tuple_idx): # pyrefly: ignore[bad-argument-type] if isinstance(sub_idx, int): start_indices[dim] = sub_idx limit_indices[dim] = sub_idx + 1 From 13792508444e18d170258b8c00c17ead6d14ead7 Mon Sep 17 00:00:00 2001 From: Levon Ter-Grigoryan Date: Mon, 2 Mar 2026 11:55:28 -0800 Subject: [PATCH 011/100] [Mosaic:GPU] Fix the barrier value creation. 
PiperOrigin-RevId: 877509481 --- jaxlib/mosaic/gpu/custom_call.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 48bdd462a3da..0012488f07aa 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -1213,7 +1213,7 @@ absl::Status MosaicGpuInitialize( &barrier_signal_buffer_address, barrier_signal_buffer_address.size())); } - if (device_state.metadata_handle.address().is_null()) { + if (device_state.barrier_signal_value_buffer_handle.address().is_null()) { device_state.barrier_signal_value_buffer_handle = se::DeviceAddressHandle{ collective_params->executor, collective_params->executor->Allocate( From 5ef2c50381057ff0216f989eb62ff96f3415a7c6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Mar 2026 12:34:38 -0800 Subject: [PATCH 012/100] [pyrefly] fix errors in jax/_src/shard_map.py --- jax/_src/core.py | 6 +++--- jax/_src/custom_derivatives.py | 4 ++-- jax/_src/custom_transpose.py | 2 +- jax/_src/shard_map.py | 38 ++++++++++++++++++++-------------- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index c96b6c862853..6c9160c9e234 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -653,7 +653,7 @@ def _true_bind(self, *args, **params): finally: trace_ctx.set_trace(prev_trace) - def bind_with_trace(self, trace, args, params): + def bind_with_trace(self, trace, args, params, /): # TODO(mattjj,dougalm): remove this block? 
try: in_type = map(typeof, args) except: pass # try lojax error message @@ -2997,7 +2997,7 @@ class CallPrimitive(Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, fun_and_args, params): + def bind_with_trace(self, trace, fun_and_args, params, /): fun = fun_and_args[0] args = fun_and_args[1:] return trace.process_call(self, fun, args, params) @@ -3040,7 +3040,7 @@ class MapPrimitive(Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, fun_and_args, params): + def bind_with_trace(self, trace, fun_and_args, params, /): fun: lu.WrappedFun = fun_and_args[0] args = fun_and_args[1:] assert len(params['in_axes']) == len(args) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 33b4da6dfe95..0ceaf3ae1c8c 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -389,7 +389,7 @@ class CustomJVPCallPrimitive(core.Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, args, params): + def bind_with_trace(self, trace, args, params, /): fun, jvp, tracers = args[0], args[1], args[2:] return trace.process_custom_jvp_call(self, fun, jvp, tracers, **params) @@ -997,7 +997,7 @@ class CustomVJPCallPrimitive(core.Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, args, params): + def bind_with_trace(self, trace, args, params, /): fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:] return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 1762522a8deb..2d2c0f4285f2 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -170,7 +170,7 @@ class CustomTransposePrimitive(core.Primitive): def bind(self, *args, **params): return 
self._true_bind(*args, **params) - def bind_with_trace(self, trace, call_args, params): + def bind_with_trace(self, trace, call_args, params, /): call, tracers = call_args[0], call_args[1:] return trace.process_custom_transpose(self, call, tracers, **params) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 618791b9ba30..49a320805aa5 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -403,7 +403,8 @@ def _shmap_checks(mesh, axis_names, in_specs, out_specs, _smap): def _manual_spec(manual_axes, spec: P, mesh) -> P: - out = [] # type: ignore + out: list[str | tuple[str, ...] | None] = [] # type: ignore + s: str | None | tuple[str, ...] for s in spec: if s is None: out.append(s) @@ -413,7 +414,7 @@ def _manual_spec(manual_axes, spec: P, mesh) -> P: temp.pop() if None in temp: raise ValueError(f"Invalid spec: {spec}") - out.append(None if len(temp) == 0 else tuple(temp)) + out.append(None if len(temp) == 0 else tuple(temp)) # type: ignore[arg-type] else: out.append(s if s in manual_axes else None) _check_unreduced(SpecErrorType.input, mesh, manual_axes, spec) @@ -539,6 +540,7 @@ def _spec_rank_error( ba = _try_infer_args(f, tree) else: prefix, base = 'out', f'{fun_name}(*args)' + ba = None msgs = [] for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): extra = "" @@ -722,7 +724,7 @@ class ShardMapPrimitive(core.Primitive): def bind(self, *args, **params): return self._true_bind(*args, **params) - def bind_with_trace(self, trace, fun_and_args, params): + def bind_with_trace(self, trace, fun_and_args, params, /): fun: lu.WrappedFun fun, *args = fun_and_args return trace.process_shard_map(shard_map_p, fun, args, **params) @@ -772,6 +774,8 @@ def _shard_map_staging( args = [lo_val for x in args for lo_val in typeof(x).lower_val(x)] out_specs_thunk = (lambda t: lambda: [x for s in t() for x in s.to_lo()])(out_specs_thunk) f, hi_avals_out = _lojax_traceable(f, hi_avals_in, unk_names=True) + else: + hi_avals_out = None 
to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) in_tracers = map(to_jaxpr_tracer, args) # pyrefly: ignore[bad-assignment] # pyrefly#2385 inner_mesh = _as_manual_mesh(mesh, manual_axes) @@ -799,6 +803,7 @@ def _shard_map_staging( out = trace.emit_eqn([*const_tracers, *in_tracers], out_avals, prim, params, effs, source_info) if trace.requires_low: + assert hi_avals_out is not None out = pe.raise_lo_outs(hi_avals_out(), out) return out pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging @@ -1162,7 +1167,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs, out_specs_thunk, in_vma = map(_spec_to_vma, in_specs) outs, out_vma = _run_shmap(fun, mesh, manual_axes, args, in_vma, check_vma) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] - _check_names(out_specs_thunk(), out_avals) # pytype: disable=wrong-arg-types + _check_names(out_specs_thunk(), out_avals) # type: ignore[arg-type] if check_vma: _check_vmas(mesh, out_specs_thunk(), out_avals) src_pspecs = tuple(_vma_to_spec(mesh, r) for r in out_vma) @@ -1198,7 +1203,7 @@ def _unmatch2(mesh, prev_manual, spec, x): src = P(order_wrt_mesh(mesh, prev_manual), *spec) newly_manual = _spec_to_vma(spec) dst = P(order_wrt_mesh(mesh, prev_manual | newly_manual)) - return shard_map(lambda x: x, in_specs=src, out_specs=dst, + return shard_map(lambda x: x, in_specs=src, out_specs=dst, # pyrefly: ignore[no-matching-overload] axis_names=prev_manual | newly_manual)(x) def _match_spec2(mesh, prev_manual, spec, x) -> JaxType: @@ -1210,7 +1215,7 @@ def _match2(mesh, prev_manual, spec, x): newly_manual = _spec_to_vma(spec) src = P(order_wrt_mesh(mesh, prev_manual | newly_manual)) dst = P(order_wrt_mesh(mesh, prev_manual), *spec) - return shard_map(lambda x: x, in_specs=src, out_specs=dst, + return shard_map(lambda x: x, in_specs=src, out_specs=dst, # pyrefly: ignore[no-matching-overload] axis_names=prev_manual | newly_manual)(x) @@ -1406,28 +1411,28 @@ def 
__init__(self, trace, vma, val): @property def aval(self): aval = core.get_aval(self.val) - vma = self.vma if self._trace.check else self._trace.manual_axes - size = prod(self._trace.mesh.shape[n] for n in vma) + vma = self.vma if self._trace.check else self._trace.manual_axes # pyrefly: ignore[missing-attribute] + size = prod(self._trace.mesh.shape[n] for n in vma) # pyrefly: ignore[missing-attribute] out = core.mapped_aval(size, 0, aval) new_sharding = NamedSharding( - _as_manual_mesh(self._trace.amesh, self._trace.manual_axes), - out.sharding.spec) # pytype: disable=attribute-error + _as_manual_mesh(self._trace.amesh, self._trace.manual_axes), # pyrefly: ignore[missing-attribute] + out.sharding.spec) # type: ignore[missing-attribute] vma = self.vma if config._check_vma.value else frozenset() return out.update(sharding=new_sharding, vma=vma) def to_concrete_value(self): - if self._trace.check and self.vma == frozenset(): - with core.eval_context(), use_abstract_mesh(self._trace.amesh): + if self._trace.check and self.vma == frozenset(): # pyrefly: ignore[missing-attribute] + with core.eval_context(), use_abstract_mesh(self._trace.amesh): # pyrefly: ignore[missing-attribute] return core.to_concrete_value(self.val[0]) else: return None def __str__(self) -> str: - pb_names = set(self._trace.mesh.axis_names) - self.vma + pb_names = set(self._trace.mesh.axis_names) - self.vma # pyrefly: ignore[missing-attribute] self = pvary(self, tuple(pb_names)) - with core.eval_context(), use_abstract_mesh(self._trace.amesh): + with core.eval_context(), use_abstract_mesh(self._trace.amesh): # pyrefly: ignore[missing-attribute] blocks = list(self.val) - mesh = self._trace.mesh + mesh = self._trace.mesh # pyrefly: ignore[missing-attribute] axis_names = f"({', '.join(map(str, mesh.axis_names))},)" return '\n'.join( f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" @@ -1832,6 +1837,7 @@ def new_out_specs_thunk(): raise e2 from None else: 
api_util._raise_no_nan_in_deoptimized(e) + raise # will never get here. except _RepError as e: fails, = e.args msg = _inout_vma_error( @@ -1910,7 +1916,7 @@ def _add_reshapes(which: Sequence[bool], jaxpr_known: core.Jaxpr, jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]: # add singleton axes to residuals which are from jaxpr_known and are scalars - which_ = [w and not v.aval.shape # pytype: disable=attribute-error + which_ = [w and not v.aval.shape # type: ignore[missing-attribute] for w, v in zip(which, jaxpr_staged.invars[:len(which)])] if not any(which_): return jaxpr_known, jaxpr_staged assert not jaxpr_known.constvars and not jaxpr_staged.constvars From b80f8e8207840d7f4bb27c59d7e4de36b00b178c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Mar 2026 12:40:56 -0800 Subject: [PATCH 013/100] [pyrefly] fix pyrefly errors in jax._src.core --- jax/_src/basearray.pyi | 3 +- jax/_src/core.py | 62 ++++++++++++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 1d6cb167a026..0ad9fb5a11fe 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -42,7 +42,8 @@ PrecisionLike = Any class Array: - aval: Any + @property + def aval(self) -> Any: ... @property def dtype(self) -> np.dtype: ... 
diff --git a/jax/_src/core.py b/jax/_src/core.py index c96b6c862853..adf6b9358f41 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -196,7 +196,7 @@ def pretty_print(self, *, source_info=False, print_shapes=True, self, source_info=source_info, print_shapes=print_shapes, custom_pp_eqn_rules=custom_pp_eqn_rules, name_stack=name_stack, print_effects=print_effects) - return doc.format(**kwargs) + return doc.format(**kwargs) # pyrefly: ignore[missing-attribute] def _repr_pretty_(self, p, cycle): return p.text(self.pretty_print(use_color=True)) @@ -958,12 +958,16 @@ def full_lower(self): raise NotImplementedError("must override: ", type(self)) def __iter__(self): + if not hasattr(self.aval, "_iter"): + raise TypeError(f"Value of type {type(self)} is not iterable.") return iter(self.aval._iter(self)) def __reversed__(self): return iter(self[::-1]) def __len__(self): + if not hasattr(self.aval, "_len"): + raise TypeError(f"Value of type {type(self)} has no length.") return self.aval._len(self) def to_concrete_value(self): @@ -1003,10 +1007,12 @@ def addressable_shards(self): @property def at(self): + if not hasattr(self.aval, "at"): + raise TypeError(f"Value of type {type(self)} does not support at().") return self.aval.at.fget(self) @property - def aval(self): + def aval(self) -> AbstractValue: raise NotImplementedError("must override") def get_referent(self) -> Any: @@ -1015,34 +1021,48 @@ def get_referent(self) -> Any: def __bool__(self): if is_concrete(self): return bool(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_bool_conversion(self) + if not hasattr(self.aval, "_bool"): + raise TypeError(f"Value of type {type(self)} is not convertible to boolean.") return self.aval._bool(self) def __int__(self): if is_concrete(self): return int(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_scalar_conversion(self) + if not hasattr(self.aval, "_int"): + raise TypeError(f"Value of type {type(self)} is not convertible to integer.") return 
self.aval._int(self) def __float__(self): check_scalar_conversion(self) + if not hasattr(self.aval, "_float"): + raise TypeError(f"Value of type {type(self)} is not convertible to float.") return self.aval._float(self) def __complex__(self): check_scalar_conversion(self) + if not hasattr(self.aval, "_complex"): + raise TypeError(f"Value of type {type(self)} is not convertible to complex.") return self.aval._complex(self) def __hex__(self): if is_concrete(self): return hex(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) + if not hasattr(self.aval, "_hex"): + raise TypeError(f"Value of type {type(self)} is not convertible to hex.") return self.aval._hex(self) def __oct__(self): if is_concrete(self): return oct(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) + if not hasattr(self.aval, "_oct"): + raise TypeError(f"Value of type {type(self)} is not convertible to oct.") return self.aval._oct(self) def __index__(self): if is_concrete(self): return operator.index(self.to_concrete_value()) # pytype: disable=wrong-arg-types check_integer_conversion(self) + if not hasattr(self.aval, "_index"): + raise TypeError(f"Value of type {type(self)} is not convertible to integer index.") return self.aval._index(self) # raises a useful error on attempts to pickle a Tracer. @@ -1052,15 +1072,28 @@ def __reduce__(self): "indicate an attempt to serialize/pickle a traced value.")) # raises the better error message from ShapedArray - def __setitem__(self, idx, val): return self.aval._setitem(self, idx, val) + def __setitem__(self, key, value): + if not hasattr(self.aval, "_setitem"): + raise TypeError(f"Value of type {type(self)} is not indexable.") + return self.aval._setitem(self, key, value) # NumPy also only looks up special methods on classes. 
- def __array_module__(self, types): return self.aval._array_module(self, types) + def __array_module__(self, types): + if not hasattr(self.aval, "_array_module"): + raise TypeError(f"Value of type {type(self)} is not compatible with the Array API.") + return self.aval._array_module(self, types) def __getattr__(self, name): # if the aval property raises an AttributeError, gets caught here assert not config.enable_checks.value or name != "aval" + # These must raise AttributeError in the base class for backward compatibility. + # TODO(jakevdp): can we change this and make them raise NotImplementedError instead? + if name in ["block_until_ready", "copy_to_host_async"]: + raise AttributeError( + f"The '{name}' method is not available on {self._error_repr()}." + f"{self._origin_msg()}") + if name == 'sharding': raise AttributeError( f"The 'sharding' attribute is not available on {self._error_repr()}. " @@ -1100,7 +1133,7 @@ def _pretty_print(self, verbose: bool = False) -> pp.Doc: return base def __repr__(self): - return self._pretty_print(verbose=False).format() + return self._pretty_print(verbose=False).format() # pyrefly: ignore[missing-attribute] def _contents(self): try: @@ -1117,20 +1150,6 @@ def addressable_data(self, index): f"The addressable_data() method was called on {self._error_repr()}." f"{self._origin_msg()}") - @property - def block_until_ready(self): - # Raise AttributeError for backward compatibility with hasattr() and getattr() checks. - raise AttributeError( - f"The 'block_until_ready' method is not available on {self._error_repr()}." - f"{self._origin_msg()}") - - @property - def copy_to_host_async(self): - # Raise AttributeError for backward compatibility with hasattr() and getattr() checks. - raise AttributeError( - f"The 'copy_to_host_async' method is not available on {self._error_repr()}." - f"{self._origin_msg()}") - def delete(self): raise ConcretizationTypeError(self, f"The delete() method was called on {self._error_repr()}." 
@@ -1751,7 +1770,7 @@ def valid_jaxtype(x) -> bool: return True -def mem_kind_to_space(mem_kind: str) -> MemorySpace: +def mem_kind_to_space(mem_kind: str | None) -> MemorySpace: if mem_kind == 'pinned_host': return MemorySpace.Host return MemorySpace.Device @@ -2558,7 +2577,7 @@ def unsafe_buffer_pointer(self): return self._refs._buf.unsafe_buffer_pointer() def at(self): raise NotImplementedError() # TODO(mattjj) class ArrayRefImpl: - _aval: ShapedArray + _aval: AbstractValue _buf: Array # mutable field def __init__(self, aval, buf): @@ -4035,6 +4054,7 @@ def __eq__(self, other): def get_opaque_trace_state(convention=None): del convention + assert trace_ctx.trace is not None return OpaqueTraceState(trace_ctx.trace._weakref) def nonempty_axis_env() -> bool: From 926bf9730afbe378de0cab339110e1f19698e7c8 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Mar 2026 12:53:53 -0800 Subject: [PATCH 014/100] [typing] fix pyi type signature for argsort --- jax/numpy/__init__.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 2c0ce335441b..e7fb28046d9e 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -165,6 +165,7 @@ def argsort( descending: builtins.bool = ..., kind: str | None = ..., order: None = ..., + dtype: DTypeLike | None = ..., ) -> Array: ... 
def argwhere( a: ArrayLike, From 153f2142e2314d29e2e1ae2646ea3931f54b592a Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 2 Mar 2026 13:44:19 -0800 Subject: [PATCH 015/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/ddf15ca00ef5693e02e2d870c8d720b7d8d060f6 PiperOrigin-RevId: 877558591 --- MODULE.bazel | 6 +++--- third_party/xla/revision.bzl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 7714e6b019b9..acd8df3d6e6b 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -27,9 +27,9 @@ archive_override( bazel_dep(name = "xla") archive_override( module_name = "xla", - integrity = "sha256-4C4sa4TtEsEo8EqOsc3uUV4ceXDHr0EnZ3EvhBvtHBs=", - strip_prefix = "xla-ba5bbceae1ff6c0f03f0234ba6beadbcdae74635", - urls = ["https://github.com/openxla/xla/archive/ba5bbceae1ff6c0f03f0234ba6beadbcdae74635.tar.gz"], + integrity = "sha256-MrdCTZVTSkHyzA8scC4dYJ/A5p2hmB+xM73ucuFfcuo=", + strip_prefix = "xla-ddf15ca00ef5693e02e2d870c8d720b7d8d060f6", + urls = ["https://github.com/openxla/xla/archive/ddf15ca00ef5693e02e2d870c8d720b7d8d060f6.tar.gz"], ) # TODO: upstream, otherwise we have to duplicate the patches in jax diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index 9d2fd2cfd50b..b4e88e8d0bf3 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. 
# buildifier: disable=module-docstring -XLA_COMMIT = "ba5bbceae1ff6c0f03f0234ba6beadbcdae74635" -XLA_SHA256 = "e02e2c6b84ed12c128f04a8eb1cdee515e1c7970c7af412767712f841bed1c1b" +XLA_COMMIT = "ddf15ca00ef5693e02e2d870c8d720b7d8d060f6" +XLA_SHA256 = "32b7424d95534a41f2cc0f2c702e1d609fc0e69da1981fb133bdee72e15f72ea" From 208b29ece09674497d786f5a67e5c0f657925ad1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 2 Mar 2026 13:44:36 -0800 Subject: [PATCH 016/100] Reverts 0b7d6672a7e69616e84048752b6d5e94ffeac8ce PiperOrigin-RevId: 877558703 --- jaxlib/mosaic/dialect/tpu/tpu_ops.cc | 9 +++++++++ jaxlib/mosaic/dialect/tpu/tpu_ops.td | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 4f5a0b2605b7..3f92118d5e5f 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1655,6 +1655,15 @@ LogicalResult EnqueueIndirectDMAOp::canonicalize(EnqueueIndirectDMAOp op, return propagateTiledLayoutToConsumer(op, rewriter); } +// TODO(b/395630795): Remove after 2025-08-10. +LogicalResult WaitDMAOp::verify() { + auto sem_type = getMemRefType(getSemaphore()); + if (sem_type.getRank() != 0) { + return emitOpError("DMA wait semaphore must be rank 0"); + } + return success(); +} + void WaitDMA2Op::build(OpBuilder &builder, OperationState &state, Value semaphore, Value src, Value dst) { build(builder, state, semaphore, src, dst, /*device_id=*/nullptr, diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.td b/jaxlib/mosaic/dialect/tpu/tpu_ops.td index 87ba90fa6d50..653870c227dc 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.td +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.td @@ -1378,6 +1378,15 @@ def TPU_WaitDMA2Op : TPU_Op<"wait_dma2", [AttrSizedOperandSegments]> { let hasCanonicalizeMethod = 1; } +// TODO(b/395630795): Remove after 2025-08-10. 
+def TPU_WaitDMAOp : TPU_Op<"wait_dma"> { + let arguments = (ins + MemRefOf<[TPU_DMASemaphoreType]>:$semaphore, + AnyMemRef:$ref + ); + let hasVerifier = 1; +} + // Like tpu.wait_dma2, but for indirect DMAs. // // The number of bytes to wait for is determined based on the size of the From c318de0a4dc9f029a9231a8446776240c41e6c83 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 2 Mar 2026 15:55:18 -0800 Subject: [PATCH 017/100] [Pallas/SC] Fix bug where axis name for SC mesh is always "core" PiperOrigin-RevId: 877616830 --- jax/_src/pallas/mosaic/sc_core.py | 8 ++-- tests/pallas/tpu_sparsecore_pallas_test.py | 46 ++++++++++++++++++---- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic/sc_core.py b/jax/_src/pallas/mosaic/sc_core.py index 375b9f24088e..32eea186b1b8 100644 --- a/jax/_src/pallas/mosaic/sc_core.py +++ b/jax/_src/pallas/mosaic/sc_core.py @@ -169,7 +169,7 @@ def default_memory_space(self) -> tpu_core.MemorySpace: @property def shape(self): - return collections.OrderedDict(core=self.num_cores) + return collections.OrderedDict({self.axis_name: self.num_cores}) @property def dimension_semantics(self) -> Sequence[str]: @@ -272,8 +272,10 @@ def default_memory_space(self) -> tpu_core.MemorySpace: @property def shape(self): - return collections.OrderedDict( - core=self.num_cores, subcore=self.num_subcores) + return collections.OrderedDict({ + self.core_axis_name: self.num_cores, + self.subcore_axis_name: self.num_subcores, + }) @property def dimension_semantics(self) -> Sequence[str]: diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index 593a45dd261f..34330d126932 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -39,6 +39,39 @@ jax.config.parse_flags_with_absl() +class PallasSCMeshTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.is_device_tpu(5, "p") and not jtu.is_device_tpu_at_least(6): + 
self.skipTest("SparseCore only supported on TPU v5p+") + super().setUp() + + def test_scalar_subcore_mesh(self): + sc_info = plsc.get_sparse_core_info() + mesh = sc_core.ScalarSubcoreMesh(axis_name="x", num_cores=sc_info.num_cores) + self.assertEqual( + mesh.shape, collections.OrderedDict({"x": sc_info.num_cores}) + ) + self.assertEqual(mesh.dimension_semantics, ["core_parallel"]) + self.assertEqual(mesh.default_memory_space, pltpu.MemorySpace.HBM) + + def test_vector_subcore_mesh(self): + sc_info = plsc.get_sparse_core_info() + num_cores = sc_info.num_cores + num_subcores = sc_info.num_subcores + mesh = sc_core.VectorSubcoreMesh( + core_axis_name="x", num_cores=num_cores, subcore_axis_name="y" + ) + self.assertEqual( + mesh.shape, + collections.OrderedDict({"x": num_cores, "y": num_subcores}), + ) + self.assertEqual( + mesh.dimension_semantics, ["core_parallel", "subcore_parallel"] + ) + self.assertEqual(mesh.default_memory_space, pltpu.MemorySpace.HBM) + + class PallasSCTest(jtu.JaxTestCase): USE_TC_TILING = False @@ -145,7 +178,7 @@ def test_scalar_subcore(self): @self.kernel( out_shape=int32s, mesh=plsc.ScalarSubcoreMesh( - axis_name="core", num_cores=self.sc_info.num_cores + axis_name="x", num_cores=self.sc_info.num_cores ), ) def kernel(int32s_hbm_ref, int16s_hbm_ref, int8s_hbm_ref, o_hbm_ref): @@ -155,7 +188,7 @@ def kernel(int32s_hbm_ref, int16s_hbm_ref, int8s_hbm_ref, o_hbm_ref): sem=pltpu.SemaphoreType.DMA, ) def _(tmp_ref, sem): - @pl.when(lax.axis_index("core") == 0) + @pl.when(lax.axis_index("x") == 0) def _(): pltpu.async_copy(int32s_hbm_ref, tmp_ref, sem).wait() pltpu.async_copy(tmp_ref, o_hbm_ref, sem).wait() @@ -176,7 +209,6 @@ def _(): ) with jtu.capture_stderr() as get_output: jax.block_until_ready(compiled_kernel(int32s, int16s, int8s)) - print(get_output()) self.assertIn("s32 array, data: s32", get_output()) self.assertIn( "{ " + ", ".join(map(str, range(nl, 2 * nl))) + " }", get_output() @@ -1879,12 +1911,12 @@ def test_copy(self): 
@self.kernel( out_shape=x, mesh=plsc.ScalarSubcoreMesh( - axis_name="core", num_cores=self.sc_info.num_cores + axis_name="x", num_cores=self.sc_info.num_cores ), ) def kernel(x_ref, o_ref): lax.cond( - lax.axis_index("core") == lax.axis_size("core") - 1, + lax.axis_index("x") == lax.axis_size("x") - 1, lambda: pltpu.sync_copy(x_ref, o_ref), lambda: None, ) @@ -1900,13 +1932,13 @@ def test_sliced_copy(self): @self.kernel( out_shape=x, mesh=plsc.ScalarSubcoreMesh( - axis_name="core", num_cores=self.sc_info.num_cores + axis_name="x", num_cores=self.sc_info.num_cores ), ) def kernel(x_ref, o_ref): @functools.partial(pl.run_scoped, sems=pltpu.SemaphoreType.DMA(4)) def _(sems): - core_id = lax.axis_index("core") + core_id = lax.axis_index("x") pltpu.async_copy( x_ref.at[core_id], o_ref.at[core_id], sems.at[core_id] ).wait() From 9e0dbb48290199756c1e9b1b5db49ff54295a411 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 2 Mar 2026 16:05:23 -0800 Subject: [PATCH 018/100] Use `itertools.chain` directly when the iterable is a literal PiperOrigin-RevId: 877621017 --- jax/_src/export/_export.py | 5 +++-- jax/_src/interpreters/pxla.py | 5 ++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index fafc571da0a9..e42ad942c42c 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -813,8 +813,9 @@ def _export_lowered( cur_mesh = None if config.use_shardy_partitioner.value: - for sharding in itertools.chain.from_iterable([ - all_in_shardings, lowering.compile_args["out_shardings"]]): + for sharding in itertools.chain( + all_in_shardings, lowering.compile_args["out_shardings"] + ): if isinstance(sharding, sharding_impls.NamedSharding): cur_mesh = sharding.mesh break diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 8d86327e2a05..d03922bf8d52 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2236,8 +2236,7 @@ def 
lower_sharding_computation( number of out_avals might not be known at that time and lower_sharding_computation calculates the number of out_avals so it can apply the singleton UNSPECIFIED to all out_avals.""" - auto_spmd_lowering = check_if_any_auto( - it.chain.from_iterable([in_shardings, out_shardings])) + auto_spmd_lowering = check_if_any_auto(it.chain(in_shardings, out_shardings)) all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr._debug_info) @@ -3054,7 +3053,7 @@ def from_hlo(name: str, mesh = None if auto_spmd_lowering: - for i in it.chain.from_iterable([in_shardings, out_shardings]): + for i in it.chain(in_shardings, out_shardings): if isinstance(i, AUTO): mesh = i.mesh break From c8890f664ac5ccd730ba15e4a6f3135d41d76ff4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Mar 2026 16:13:36 -0800 Subject: [PATCH 019/100] [pyrefly] fix errors in jax._src.interpreters.pxla --- jax/_src/interpreters/pxla.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index d03922bf8d52..9dfb0c70e970 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -575,8 +575,8 @@ def _match_annot(axis_name: core.AxisName, axis_size: int, val: Any, if src == dst: outval = val elif type(src) == type(dst) == int: - outval = batching.moveaxis(val, src, dst) - shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst) + outval = batching.moveaxis(val, src, dst) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + shard_axis_out = _moveaxis(np.ndim(val), shard_axis_src, src, dst) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 elif src is None and dst is not None: outval = batching.broadcast(val, axis_size, dst, None) shard_axis_out = {n: d + (dst <= d) for n, d in shard_axis_out.items()} @@ -1414,7 +1414,7 @@ def _pmap_partial_eval_custom_params_updater( def _pmap_partial_eval_custom_res_maker(params_known, aval): 
return core.unmapped_aval(params_known['axis_size'], 0, aval) -def _pmap_dce_rule(used_outputs, eqn): +def _pmap_dce_rule(used_outputs, eqn: core.JaxprEqn): # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes if not any(used_outputs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None @@ -1872,7 +1872,7 @@ def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...], gspmd_shardings = [ s if (isinstance(s, (UnspecifiedValue, AUTO)) or (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh))) - else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error + else to_gspmd_sharding(s, a.ndim) # type: ignore[missing-attribute] for s, a in zip(shardings, avals)] self._gspmd_shardings = gspmd_shardings self.shardings = shardings @@ -2179,7 +2179,7 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment, if device_assignment is None: return shardings - out = [] + out: list[UnspecifiedValue | JSharding] = [] for s, a, mem_kind in zip(shardings, avals, out_mem_kinds): if isinstance(s, UnspecifiedValue) and isinstance(a, core.ShapedArray): if a.sharding.mesh.empty: @@ -2582,6 +2582,7 @@ def get_out_shardings_from_executable( return [sharding_impls.GSPMDSharding.get_replicated(device_list, memory_kind=mk) for mk in omk] + out_op_shardings: Sequence[xc.OpSharding] _, out_op_shardings = get_op_sharding_from_executable(xla_executable) if not out_op_shardings: return None @@ -2670,7 +2671,7 @@ def _gspmd_to_single_device_sharding( def _get_out_sharding_from_orig_sharding( out_shardings, out_avals, orig_in_s, orig_aval): - out = [] + out: list[JSharding] = [] orig_handler = _orig_out_sharding_handlers[type(orig_in_s)] for o, out_aval in safe_zip(out_shardings, out_avals): if (isinstance(o, sharding_impls.GSPMDSharding) and @@ -2912,7 +2913,7 @@ def _maybe_get_and_check_out_shardings( dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval.shape, aval.dtype, 
xla_s) try: - new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig)) # pytype: disable=wrong-arg-types + new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig)) # type: ignore[arg-type] except: new_out_shardings.append(xla_s) else: @@ -3053,7 +3054,7 @@ def from_hlo(name: str, mesh = None if auto_spmd_lowering: - for i in it.chain(in_shardings, out_shardings): + for i in it.chain(in_shardings, out_shardings): # pyrefly: ignore[bad-argument-type] if isinstance(i, AUTO): mesh = i.mesh break @@ -3231,12 +3232,12 @@ def xla_extension_executable(self): return self.xla_executable def call(self, *args): - args_after_dce = [a for i, a in enumerate(args) if i in self._kept_var_idx] + args_after_dce = tuple(a for i, a in enumerate(args) if i in self._kept_var_idx) if (self._all_args_info is not None and self._all_args_info.debug_info.arg_names is not None): - arg_names_after_dce = [ + arg_names_after_dce = tuple( n for i, n in enumerate(self._all_args_info.debug_info.arg_names) - if i in self._kept_var_idx] + if i in self._kept_var_idx) else: arg_names_after_dce = ("",) * len(args_after_dce) @@ -3280,6 +3281,7 @@ def aot_cache_miss(*args, **kwargs): use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat) and not self._mut) else: + out_tree_dispatch = None use_fastpath = False if use_fastpath: From ac89644bf676bc5e3ded2e2dd48245e4f418730e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Mar 2026 15:40:56 -0800 Subject: [PATCH 020/100] Make main process_*() arguments position-only. This matches how they are currently used, and allows child classes flexibility in choice of argument names. 
--- jax/_src/core.py | 22 ++++++++++----------- jax/_src/interpreters/ad.py | 16 +++++++-------- jax/_src/interpreters/batching.py | 12 ++++++------ jax/_src/interpreters/partial_eval.py | 28 +++++++++++++-------------- jax/_src/interpreters/pxla.py | 10 +++++----- jax/_src/interpreters/remat.py | 2 +- jax/_src/shard_map.py | 10 +++++----- jax/experimental/jet.py | 8 ++++---- jax/experimental/sparse/transform.py | 6 +++--- 9 files changed, 57 insertions(+), 57 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index adf6b9358f41..e3505eca5291 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -785,7 +785,7 @@ def __init__(self): self._weakref = weakref.ref(self) self.requires_low = True - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): raise NotImplementedError("must override") def invalidate(self): @@ -797,29 +797,29 @@ def is_valid(self): def __repr__(self): return f'{self.__class__.__name__}' - def process_call(self, call_primitive, f, tracers, params): + def process_call(self, call_primitive, f, tracers, params, /): msg = (f"{type(self)} must override process_call to handle call-like " "primitives") raise NotImplementedError(msg) - def process_map(self, map_primitive, f, tracers, params): + def process_map(self, map_primitive, f, tracers, params, /): msg = (f"{type(self)} must override process_map to handle map-like " "primitives") raise NotImplementedError(msg) - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, + def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, *, symbolic_zeros): msg = (f"{type(self)} must override process_custom_jvp_call " "to handle custom_jvp primitives") raise NotImplementedError(msg) def process_custom_transpose(self, prim: Primitive, - call: lu.WrappedFun, tracers, **params): + call: lu.WrappedFun, tracers, /, **params): msg = (f"{type(self)} must override process_custom_transpose " "to handle custom_transpose_call 
primitives") raise NotImplementedError(msg) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, *, out_trees, symbolic_zeros): msg = (f"{type(self)} must override process_custom_vjp_call " "to handle custom_vjp primitives") @@ -1215,7 +1215,7 @@ def check_eval_args(args): class EvalTrace(Trace): - def process_primitive(self, primitive, args, params): + def process_primitive(self, primitive, args, params, /): if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error @@ -1226,7 +1226,7 @@ def process_primitive(self, primitive, args, params): check_eval_args(args) return primitive.impl(*args, **params) - def process_call(self, primitive, f, tracers, params): + def process_call(self, primitive, f, tracers, params, /): if config.debug_key_reuse.value: # Import here to avoid circular imports from jax.experimental.key_reuse._core import call_impl_with_key_reuse_checks # pytype: disable=import-error @@ -1235,15 +1235,15 @@ def process_call(self, primitive, f, tracers, params): return primitive.impl(f, *tracers, **params) process_map = process_call - def process_custom_transpose(self, primitive, call, tracers, **_): + def process_custom_transpose(self, primitive, call, tracers, /, **_): del primitive, _ return call.call_wrapped(*tracers) - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, **_): + def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, **_): del primitive, jvp, _ # Unused. return fun.call_wrapped(*tracers) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_): # pytype: disable=signature-mismatch + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, **_): del primitive, fwd, bwd, _ # Unused. 
return fun.call_wrapped(*tracers) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 784c8ecb3ff3..47591d696a84 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -556,7 +556,7 @@ def to_primal_tangent_pair(self, val): tangent_zero = p2tz(val) return (val, tangent_zero) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if (all(type(t) is Zero for t in tangents_in) and primitive is not core.ref_p and primitive is not core.empty_ref_p and @@ -579,7 +579,7 @@ def cur_qdd(self, x): with core.set_current_trace(self.parent_trace): return core.cur_qdd(p) - def process_call(self, call_primitive, f, tracers, params): + def process_call(self, call_primitive, f, tracers, params, /): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) which_nz = [ type(t) is not Zero for t in tangents] @@ -615,7 +615,7 @@ def new_out_axes_thunk(): def process_map(self, map_primitive, f, tracers, params): return self.process_call(map_primitive, f, tracers, params) - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, symbolic_zeros): + def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, *, symbolic_zeros): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): return primitive.bind_with_trace(self.parent_trace, (fun, jvp, *primals_in), @@ -631,7 +631,7 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, symbolic_zeros): tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out) return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees, + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, *, out_trees, 
symbolic_zeros): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): @@ -791,7 +791,7 @@ def to_primal_tangent_pair(self, val): tangent_zero = p2tz(val) return (val, tangent_zero) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) tangent_nzs = [type(t) is not Zero for t in tangents_in] if (all(type(t) is Zero for t in tangents_in) and @@ -818,7 +818,7 @@ def cur_qdd(self, x): return core.cur_qdd(p) def process_custom_jvp_call(self, primitive, fun: lu.WrappedFun, - jvp: lu.WrappedFun, tracers, *, + jvp: lu.WrappedFun, tracers, /, *, symbolic_zeros: bool): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) if all(type(t) is Zero for t in tangents_in): @@ -845,7 +845,7 @@ def _f_jvp(primals, tangents): for x, nz, t in zip(primals_out, tangent_nzs_out, tangents_out)] def process_custom_vjp_call(self, primitive, fun, fwd, - bwd: lu.WrappedFun, tracers, + bwd: lu.WrappedFun, tracers, /, *, out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers)) @@ -875,7 +875,7 @@ def process_custom_vjp_call(self, primitive, fun, fwd, tangent_nzs_out = [type(t) is not Zero for t in tangents_out] return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out) - def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): + def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params, /): assert call_primitive.multiple_results primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers)) nzs_in = tuple(type(t) is not Zero for t in tangents) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 8337678c24fc..46271a03564d 100644 --- 
a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -249,7 +249,7 @@ def cur_qdd(self, x): with core.set_current_trace(self.parent_trace): return core.cur_qdd(val) - def process_primitive(self, p, tracers, params): # pyrefly: ignore[bad-param-name-override] + def process_primitive(self, p, tracers, params, /): vals_in, dims_in = unzip2(map(self.to_batch_info, tracers)) args_not_mapped = all(bdim is not_mapped for bdim in dims_in) if p in fancy_primitive_batchers: @@ -271,7 +271,7 @@ def process_primitive(self, p, tracers, params): # pyrefly: ignore[bad-param-na else: raise NotImplementedError(f"Batching rule for '{p}' not implemented") - def process_call(self, call_primitive, f, tracers, params): + def process_call(self, call_primitive, f, tracers, params, /): assert call_primitive.multiple_results params = dict(params, name=params.get('name', f.__name__)) vals, dims = unzip2(map(self.to_batch_info, tracers)) @@ -282,7 +282,7 @@ def process_call(self, call_primitive, f, tracers, params): src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())] - def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): + def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params, /): vals, dims = unzip2(map(self.to_batch_info, tracers)) # The logic for the dimension math below is as follows: # ╔═════════════╦════════════════════════════════════════╦═══════════╗ @@ -320,7 +320,7 @@ def new_out_axes_thunk(): src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out_)] - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # pyrefly: ignore[bad-param-name-override] + def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros): in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims) jvp, out_dims2 = 
batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims) @@ -330,8 +330,8 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): src = source_info_util.current() return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)] - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, # pyrefly: ignore[bad-override] - symbolic_zeros): # pytype: disable=signature-mismatch + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, /, *, out_trees, + symbolic_zeros): in_vals, in_dims = unzip2(map(self.to_batch_info, tracers)) fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]] diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index cedf24b94a8b..bbc8916e5180 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -183,7 +183,7 @@ def cur_qdd(self, x): with core.set_current_trace(self.parent_trace): return core.cur_qdd(const) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): with core.set_current_trace(self.parent_trace): if primitive in custom_partial_eval_rules: tracers = map(self.to_jaxpr_tracer, tracers) @@ -222,7 +222,7 @@ def default_process_primitive(self, primitive, tracers, params): out_tracer.recipe = eqn return out_tracer - def process_call(self, primitive, f: lu.WrappedFun, tracers, params): # pyrefly: ignore[bad-param-name-override] + def process_call(self, primitive, f: lu.WrappedFun, tracers, params, /): tracers = map(self.to_jaxpr_tracer, tracers) rule = call_partial_eval_rules.get(primitive) if rule: @@ -278,7 +278,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params): # pyrefly for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) - def process_map(self, primitive, f: lu.WrappedFun, tracers, params): # pyrefly: ignore[bad-param-name-override] + def 
process_map(self, primitive, f: lu.WrappedFun, tracers, params, /): tracers = map(self.to_jaxpr_tracer, tracers) update_params = call_param_updaters.get(primitive) or (lambda p, _, __: p) in_knowns, in_avals, in_consts = partition_pvals([t.pval for t in tracers]) @@ -350,7 +350,7 @@ def const_out_axes_thunk(): def _current_truncated_name_stack(self): return source_info_util.current_name_stack()[len(self.name_stack):] - def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): # pyrefly: ignore[bad-override] + def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros): tracers = map(self.to_jaxpr_tracer, tracers) if all(t.is_known() for t in tracers): with core.set_current_trace(self.parent_trace): @@ -362,7 +362,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, symbolic_zeros): # p with core.set_current_trace(self): return fun.call_wrapped(*tracers) - def process_custom_transpose(self, prim, call, tracers, **params): + def process_custom_transpose(self, prim, call, tracers, /, **params): tracers = map(self.to_jaxpr_tracer, tracers) res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves]) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 assert all(t.is_known() for t in res_ts) @@ -381,7 +381,7 @@ def process_custom_transpose(self, prim, call, tracers, **params): for t in out_tracers: t.recipe = eqn return out_tracers - def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, out_trees, symbolic_zeros): # pyrefly: ignore[bad-param-name-override] + def process_custom_vjp_call(self, prim, f, fwd, bwd, tracers, /, *, out_trees, symbolic_zeros): tracers = map(self.to_jaxpr_tracer, tracers) if all(t.is_known() for t in tracers): vals = [t.pval[1] for t in tracers] @@ -2041,7 +2041,7 @@ def cur_qdd(self, x): source_info = source_info_util.current() return self.to_jaxpr_tracer(x, source_info=source_info).mutable_qdd.cur_val - def process_primitive(self, primitive, tracers, params): + def 
process_primitive(self, primitive, tracers, params, /): self.frame.is_high |= primitive.is_high(*map(typeof, tracers), **params) if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return primitive.bind_with_trace(core.eval_trace, tracers, params) @@ -2101,8 +2101,8 @@ def default_process_primitive(self, primitive, tracers, params, self.frame.add_eqn(eqn) # pyrefly: ignore[bad-argument-type] return out_tracers if primitive.multiple_results else out_tracers.pop() - def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, # pyrefly: ignore[bad-param-name-override] - params): + def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, + params, /): source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) in_type = (tuple(get_aval(t) for t in in_tracers) if f.in_type is None @@ -2129,7 +2129,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, # pyrefly: [*const_tracers, *in_tracers], out_avals, call_primitive, new_params, new_params['call_jaxpr'].effects, source_info=source_info) - def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): + def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params, /): source_info = source_info_util.current() to_jaxpr_tracer = partial(self.to_jaxpr_tracer, source_info=source_info) tracers = map(to_jaxpr_tracer, tracers) @@ -2163,8 +2163,8 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params): [*const_tracers, *tracers], out_avals, map_primitive, new_params, effs, source_info=source_info) return out_tracers - def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, # pyrefly: ignore[bad-override] - jvp: lu.WrappedFun, tracers, + def process_custom_jvp_call(self, prim, fun: lu.WrappedFun, + jvp: lu.WrappedFun, tracers, /, *, symbolic_zeros: bool): if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): return 
prim.bind_with_trace(core.eval_trace, (fun, jvp, *tracers), @@ -2198,9 +2198,9 @@ def jvp_jaxpr_thunk(*in_zeros): fun_jaxpr.effects, source_info=source_info) - def process_custom_vjp_call(self, prim: core.Primitive, # pyrefly: ignore[bad-param-name-override] + def process_custom_vjp_call(self, prim: core.Primitive, fun: lu.WrappedFun, - fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, + fwd: lu.WrappedFun, bwd: lu.WrappedFun, tracers, /, *, out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]], symbolic_zeros: bool): if config.eager_constant_folding.value and not any(isinstance(x, Tracer) for x in tracers): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index d03922bf8d52..25db10f2add3 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -476,7 +476,7 @@ def to_map_tracer(self, val): else: return MapTracer(self, val, {}) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): from jax._src.lax import parallel # pytype: disable=import-error if primitive is parallel.axis_index_p: return self.process_axis_index(**params) # pytype: disable=missing-parameter @@ -500,10 +500,10 @@ def process_primitive(self, primitive, tracers, params): return [MapTracer(self, val, out_shard_axes) for val in outvals] return MapTracer(self, outvals, out_shard_axes) - def process_call(self, call_primitive, fun, tracers, params): + def process_call(self, call_primitive, fun, tracers, params, /): raise NotImplementedError - def process_map(self, map_primitive, fun, tracers, params): + def process_map(self, map_primitive, fun, tracers, params, /): if params['devices'] is not None: raise ValueError("Nested pmap with explicit devices argument.") if not config.disable_jit.value: @@ -528,7 +528,7 @@ def process_map(self, map_primitive, fun, tracers, params): for v, s, dst in zip(out, outaxes, out_axes_thunk())) return map(partial(MapTracer, self), out, 
outaxes) - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros): if symbolic_zeros: msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. " "Please open an issue at https://github.com/jax-ml/jax/issues !") @@ -537,7 +537,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): with core.set_current_trace(self): return fun.call_wrapped(*tracers) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, *, out_trees, symbolic_zeros): if symbolic_zeros: msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. " diff --git a/jax/_src/interpreters/remat.py b/jax/_src/interpreters/remat.py index 72100287d9ec..2310d6333caa 100644 --- a/jax/_src/interpreters/remat.py +++ b/jax/_src/interpreters/remat.py @@ -83,7 +83,7 @@ def to_val_tracer_pair(self, x): else: raise NotImplementedError # TODO(mattjj) - def process_primitive(self, prim, tracers, params): + def process_primitive(self, prim, tracers, params, /): in_vals, in_vals2 = unzip2(map(self.to_val_tracer_pair, tracers)) if prim in rules: with core.set_current_trace(self.parent_trace): diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 618791b9ba30..359b1f82089a 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1299,7 +1299,7 @@ def to_val_vma_pair(self, val): P(), val) return val_, frozenset() - def process_primitive(self, prim, tracers, params): + def process_primitive(self, prim, tracers, params, /): in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) if self.check: out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params) @@ -1356,20 +1356,20 @@ def process_shard_map(self, prim, fun, args, mesh, in_specs, out_vmas = [v - _spec_to_vma(spec) for v, spec in zip(out_vmas_, out_specs)] return 
map(partial(ShardMapTracer, self), out_vmas, out_vals) - def process_call(self, call_primitive, fun, tracers, params): + def process_call(self, call_primitive, fun, tracers, params, /): raise NotImplementedError( f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " "yet supported. Put a `jax.jit` around the `shard_map`-decorated " "function, and open a feature request at " "https://github.com/jax-ml/jax/issues !") - def process_map(self, map_primitive, fun, tracers, params): + def process_map(self, map_primitive, fun, tracers, params, /): raise NotImplementedError( "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." "Put a `jax.jit` around the `shard_map`-decorated function, and open " "a feature request at https://github.com/jax-ml/jax/issues !") - def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros): # Since ShardMapTrace is only used as a base main, we can drop the jvp. 
del prim, jvp, symbolic_zeros in_vals, in_vma = unzip2(map(self.to_val_vma_pair, tracers)) @@ -1377,7 +1377,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): in_vma, self.check) return map(partial(ShardMapTracer, self), out_vma, out_vals) - def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, /, *, out_trees, symbolic_zeros): if symbolic_zeros: msg = ("custom_vjp symbolic_zeros support with shard_map is not " diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 4352f72e655d..135f6115ac5b 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -234,7 +234,7 @@ def to_primal_terms_pair(self, val): else: return val, zero_series - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): order = self.order # pytype: disable=attribute-error primals_in, series_in = unzip2(map(self.to_primal_terms_pair, tracers)) @@ -259,7 +259,7 @@ def process_primitive(self, primitive, tracers, params): else: return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)] - def process_call(self, call_primitive, f, tracers, params): + def process_call(self, call_primitive, f, tracers, params, /): primals_in, series_in = unzip2(map(self.to_primal_terms_pair, tracers)) primals_and_series, in_tree_def = tree_flatten((primals_in, series_in)) f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def) @@ -270,13 +270,13 @@ def process_call(self, call_primitive, f, tracers, params): primals_out, series_out = tree_unflatten(out_tree_def(), result) return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)] - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, + def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, *, symbolic_zeros): # TODO(mattjj): don't just ignore custom jvp rules? del primitive, jvp # Unused. 
return fun.call_wrapped(*tracers) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees): + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, *, out_trees): del primitive, fwd, bwd, out_trees # Unused. return fun.call_wrapped(*tracers) diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 6a17fb759040..8f4bf4053b96 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -314,7 +314,7 @@ def to_sparse_tracer(self, val): spvalue, = arrays_to_spvalues(self.spenv, [val]) return SparseTracer(self, spvalue=spvalue) - def process_primitive(self, primitive, tracers, params): + def process_primitive(self, primitive, tracers, params, /): tracers = [self.to_sparse_tracer(t) for t in tracers] spvalues = [t._spvalue for t in tracers] if any(spvalue.is_sparse() for spvalue in spvalues): @@ -328,7 +328,7 @@ def process_primitive(self, primitive, tracers, params): out_tracers = tuple(SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues) return out_tracers if primitive.multiple_results else out_tracers[0] - def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): + def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params, /): assert False spvalues = tuple(t._spvalue for t in tracers) in_bufs = self.spenv._buffers @@ -339,7 +339,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): _bufs_out = call_primitive.bind(fun, *in_bufs, **params) return [SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues()] - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros): + def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, *, symbolic_zeros): # TODO(jakevdp): handle the jvp here del primitive, jvp, symbolic_zeros with core.set_current_trace(self): From c2aba31be440a2475e81994e1d193b31bc7aaf02 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 
2 Mar 2026 16:29:17 -0800 Subject: [PATCH 021/100] [mosaic:gpu] Guard nvvm.elect_sync in blackwell examples. Recent LLVM change didn't appear to update the examples. PiperOrigin-RevId: 877630264 --- jax/experimental/mosaic/gpu/examples/matmul_blackwell.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py index 954028ae0d7f..5754427b241c 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py +++ b/jax/experimental/mosaic/gpu/examples/matmul_blackwell.py @@ -19,6 +19,7 @@ import jax from jax._src.interpreters import mlir +from jax._src.lib import jaxlib_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import gpu @@ -90,7 +91,10 @@ def kernel(ctx, a, b, d, smem): (ab_full_barriers, ab_empty_barriers) = barriers warp_idx = mgpu.warp_idx(sync=True) - is_warp_leader = nvvm.elect_sync() + if jaxlib_extension_version >= 412: + is_warp_leader = nvvm.elect_sync() + else: + is_warp_leader = nvvm.elect_sync(i1) is_leader_of = lambda i: arith.andi( arith.cmpi(arith.CmpIPredicate.eq, warp_idx, c(i, i32)), is_warp_leader ) From c5bf3aea3148c25e753f20b512d2efb5a741ffe9 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Mon, 2 Mar 2026 19:07:17 -0800 Subject: [PATCH 022/100] [pallas] Fix: Pallas dot_general on TPU gives wrong results for unsigned ints. Pallas Mosaic TPU dot_general mlir doesn't seem to distinguish between signed and unsigned (integer) types which lets the user do a dot with uint4 (or any other unsigned type) and get wrong results. Adding an explicit check should catch that. 
This is hopefully going to give a sensible error in cases like this: https://github.com/jax-ml/jax/issues/35492 PiperOrigin-RevId: 877682297 --- jax/_src/pallas/mosaic/lowering.py | 10 +++++++++- tests/pallas/tpu_ops_test.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index caba6bb3f6be..e869b46579c7 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -341,7 +341,7 @@ def _dtype_to_ir_type(dtype: DTypeLike, dtype = BOOL_MEMREF_TYPE # TODO(justinfu): Remove after mosaic supports unsigned types. # This conversion makes mosaic interpret all unsigned types as signed types. - type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) + type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) if isinstance(type, ir.IntegerType): return ir.IntegerType.get_signless(type.width) else: @@ -2229,6 +2229,14 @@ def _dot_general_lowering_rule( preferred_element_type, **_, ): + for aval in ctx.avals_in: + if jnp.issubdtype(aval.dtype, jnp.unsignedinteger): + raise NotImplementedError( + f"Unsigned integer dtype {aval.dtype} is not supported for" + " dot_general (matmul) on the Pallas Mosaic TPU backend because" + " dot_general interprets all integer inputs as signed. Consider" + " casting to a signed type before the dot operation." 
+ ) (lhs_dims, rhs_dims), _ = dimension_numbers (aval_out,) = ctx.avals_out out_type = ctx.aval_to_ir_type(aval_out) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 0ef112cdfcd7..b5e1672fd73e 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -912,6 +912,22 @@ def kernel(x_ref, o_ref): jax.nn.sigmoid(x), ) + @parameterized.parameters(jnp.uint4, jnp.uint8, jnp.uint16, jnp.uint32) + def test_unsigned_dtype_dot_raises(self, dtype): + k = 256 + packing = 32 // jnp.iinfo(dtype).bits + lhs = jnp.zeros((8 * packing, k), dtype=dtype) + rhs = jnp.zeros((k, 128), dtype=dtype) + + def kernel(lhs_ref, rhs_ref, o_ref): + o_ref[...] = pl.dot(lhs_ref[...], rhs_ref[...]) + + out_shape = jax.ShapeDtypeStruct((8 * packing, 128), dtype) + with self.assertRaisesRegex( + NotImplementedError, "Unsigned integer dtype.*dot_general.*matmul" + ): + self.pallas_call(kernel, out_shape=out_shape)(lhs, rhs) + if __name__ == "__main__": absltest.main() From 69a609b19ec48e0ab52c7989a5b9c07be8a34f9a Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 2 Mar 2026 20:35:16 -0800 Subject: [PATCH 023/100] [Pallas/SC] Enable closing over scalars in SCS kernels We basically transform the kernel into one where we sync copy the values from HBM to SMEM. PiperOrigin-RevId: 877710503 --- jax/_src/pallas/mosaic/core.py | 203 ++++++++++++++++++-------- jax/_src/pallas/mosaic/sc_core.py | 19 ++- tests/pallas/tpu_pallas_state_test.py | 28 +++- 3 files changed, 178 insertions(+), 72 deletions(-) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 603d1a8f6e15..711a2f56c74d 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Contains TPU-specific Pallas abstractions.""" + from __future__ import annotations import collections @@ -54,6 +55,7 @@ class GridDimensionSemantics(enum.Enum): SUBCORE_PARALLEL = "subcore_parallel" ARBITRARY = "arbitrary" + PARALLEL = GridDimensionSemantics.PARALLEL CORE_PARALLEL = GridDimensionSemantics.CORE_PARALLEL SUBCORE_PARALLEL = GridDimensionSemantics.SUBCORE_PARALLEL @@ -212,12 +214,15 @@ def __getattr__(self, name): return super().__getattr__(name) # type: ignore -class dma_semaphore(pallas_core.semaphore_dtype): pass +class dma_semaphore(pallas_core.semaphore_dtype): + pass + class DMASemaphore(pallas_core.AbstractSemaphoreTy): type = dma_semaphore name = "dma_sem" + class SemaphoreType(enum.Enum): REGULAR = "regular" DMA = "dma" @@ -231,8 +236,9 @@ def __call__(self, shape: tuple[int, ...]): dtype = pallas_core.BarrierSemaphore() else: dtype = pallas_core.Semaphore() - return pallas_core.MemoryRef(jax_core.ShapedArray(shape, dtype), - MemorySpace.SEMAPHORE) + return pallas_core.MemoryRef( + jax_core.ShapedArray(shape, dtype), MemorySpace.SEMAPHORE + ) def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: return self(()).get_array_aval() @@ -240,6 +246,7 @@ def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: def get_ref_aval(self) -> state.AbstractRef: return self(()).get_ref_aval() + @dataclasses.dataclass(frozen=True) class AbstractSemaphore(jax_core.AbstractValue): sem_type: SemaphoreType @@ -255,15 +262,16 @@ def __init__( grid: pallas_core.Grid = (), in_specs: pallas_core.BlockSpecTree = no_block_spec, out_specs: pallas_core.BlockSpecTree = no_block_spec, - scratch_shapes: pallas_core.ScratchShapeTree = () + scratch_shapes: pallas_core.ScratchShapeTree = (), ): super().__init__(grid, in_specs, out_specs, scratch_shapes) self.num_scalar_prefetch = num_scalar_prefetch self.scratch_shapes = tuple(scratch_shapes) def _make_scalar_ref_aval(self, aval): - return state.AbstractRef(jax_core.ShapedArray(aval.shape, 
aval.dtype), - MemorySpace.SMEM) + return state.AbstractRef( + jax_core.ShapedArray(aval.shape, aval.dtype), MemorySpace.SMEM + ) @dataclasses.dataclass(frozen=True) @@ -274,6 +282,7 @@ class TensorCore: @dataclasses.dataclass(frozen=True) class TensorCoreMesh: """A mesh of TensorCores.""" + devices: np.ndarray axis_names: Sequence[str] @@ -315,7 +324,7 @@ def create_tensorcore_mesh( num_cores: int | None = None, ) -> TensorCoreMesh: if devices is not None and num_cores is not None: - raise ValueError('cannot specify both devices and num_cores') + raise ValueError("cannot specify both devices and num_cores") if num_cores is None: if devices is None: abstract_device = jax.sharding.get_abstract_mesh().abstract_device @@ -330,6 +339,120 @@ def create_tensorcore_mesh( ) +def pass_scalars_as_refs( + jaxpr: jax_core.Jaxpr, + args: Sequence[Any], + in_avals: Sequence[jax_core.AbstractValue], + out_avals: Sequence[jax_core.AbstractValue], + mesh, + copy_to_smem: bool = False, +) -> tuple[ + jax_core.Jaxpr, + tuple[Any, ...], + tuple[jax_core.AbstractValue, ...], + tuple[jax_core.AbstractValue, ...], + tuple[bool, ...], +]: + """Rewrites a jaxpr to pass scalars as refs instead of values.""" + def allowed_aval(aval): + if isinstance(aval, state.AbstractRef): + return True + if isinstance(aval, jax_core.ShapedArray): + # Only scalars are allowed. 
+ return not aval.shape + return False + + assert all(allowed_aval(v.aval) for v in jaxpr.constvars + jaxpr.invars) + + is_scalar_const = [ + isinstance(v.aval, jax_core.ShapedArray) and not v.aval.shape + for v in jaxpr.constvars + ] + if not any(is_scalar_const): + return ( + jaxpr, + tuple(in_avals), + tuple(out_avals), + tuple(args), + tuple(is_scalar_const), + ) + non_scalar_const_avals, scalar_const_avals = util.partition_list( + is_scalar_const, + [v.aval for v in jaxpr.constvars], + ) + non_scalar_consts, scalar_consts = util.partition_list( + is_scalar_const, args + ) + if copy_to_smem: + smem_alloc = [ + state.AbstractRef( + jax_core.ShapedArray((1,), aval.dtype), + memory_space=MemorySpace.SMEM, + ) + for aval in scalar_const_avals + ] + else: + smem_alloc = [] + + # Rewrite body jaxpr to take in scalar values as Refs. + def new_body(*args): + scalar_const_refs, non_scalar_const_refs, args = util.split_list( + args, [len(scalar_consts), len(non_scalar_consts)] + ) + if copy_to_smem: + smem, args = util.split_list(args, [len(smem_alloc)]) + assert len(smem) == len(scalar_const_refs) + from jax._src.pallas.mosaic.helpers import sync_copy + + sync_copy(scalar_const_refs, smem) + else: + smem = scalar_const_refs + scalar_const_values = [s[0] for s in smem] + new_consts = util.merge_lists( + is_scalar_const, non_scalar_const_refs, scalar_const_values + ) + return jax_core.eval_jaxpr(jaxpr, new_consts, *args) + + # TODO(sharadmv): Remove this once Mosaic support passing scalars as values. 
+ scalar_const_trace_avals = [ + state.AbstractRef( + jax_core.ShapedArray((1,), aval.dtype), + memory_space=MemorySpace.HBM if copy_to_smem else MemorySpace.SMEM, + ) + for aval in scalar_const_avals + ] + new_trace_avals = [ + *scalar_const_trace_avals, + *non_scalar_const_avals, + *smem_alloc, + *[v.aval for v in jaxpr.invars], + ] + with ( + pallas_core.tracing_grid_env( + tuple(mesh.shape.values()), mapped_dims=() + ), + jax_core.extend_axis_env_nd(mesh.shape.items()), + ): + new_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( + lu.wrap_init( + new_body, debug_info=jaxpr.debug_info.with_unknown_names() + ), + new_trace_avals, + ) + jaxpr = new_jaxpr.replace( + constvars=new_jaxpr.invars[: len(jaxpr.constvars)], + invars=new_jaxpr.invars[len(jaxpr.constvars) :], + ) + args = [ + *[a[None] for a in scalar_consts], + *non_scalar_consts, + ] + in_avals, out_avals, _ = util.split_list( + new_trace_avals, [len(in_avals), len(out_avals)] + ) + return jaxpr, tuple(in_avals), tuple(out_avals), tuple(args), tuple(is_scalar_const) + + def _tensorcore_mesh_discharge_rule( in_avals, out_avals, @@ -345,17 +468,13 @@ def _tensorcore_mesh_discharge_rule( ): assert isinstance(mesh, TensorCoreMesh) if compiler_params and not isinstance(compiler_params, CompilerParams): - raise ValueError( - "compiler_params must be a pltpu.CompilerParams" - ) + raise ValueError("compiler_params must be a pltpu.CompilerParams") if not compiler_params: compiler_params = CompilerParams() if len(mesh.shape) > 1: raise NotImplementedError("Mesh must be 1D") if compiler_params.dimension_semantics is not None: - raise ValueError( - "dimension_semantics must be None for TensorCoreMesh" - ) + raise ValueError("dimension_semantics must be None for TensorCoreMesh") num_cores = len(mesh.devices) if num_cores > 1: # Since each core will have its own VMEM, we currently disallow VMEM inputs @@ -369,54 +488,10 @@ def _tensorcore_mesh_discharge_rule( "TensorCoreMesh does not support VMEM inputs/outputs when there are" 
" >1 cores. Use HBM or ANY instead." ) - def allowed_aval(aval): - if isinstance(aval, state.AbstractRef): - return True - if isinstance(aval, jax_core.ShapedArray): - # Only scalars are allowed. - return not aval.shape - return False - assert all(allowed_aval(v.aval) for v in jaxpr.constvars + jaxpr.invars) - - is_scalar_const = [ - isinstance(v.aval, jax_core.ShapedArray) and not v.aval.shape - for v in jaxpr.constvars - ] - if any(is_scalar_const): - # Rewrite body jaxpr to take in scalar values as Refs. - def new_body(*args): - args = [ - a[0] if is_scalar else a - for a, is_scalar in zip(args, is_scalar_const) - ] - return jax_core.eval_jaxpr(jaxpr, args) - # TODO(sharadmv): Remove this once Mosaic support passing scalars as values. - new_trace_avals = [ - state.AbstractRef( # pylint: disable=g-long-ternary - jax_core.ShapedArray((1,), v.aval.dtype), - memory_space=MemorySpace.SMEM, - ) - if is_scalar - else v.aval - for v, is_scalar in zip(jaxpr.constvars, is_scalar_const) - ] - with ( - pallas_core.tracing_grid_env(tuple(mesh.shape.values()), mapped_dims=()), - jax_core.extend_axis_env_nd(mesh.shape.items()), - ): - new_jaxpr, _, _ = pe.trace_to_jaxpr_dynamic( - lu.wrap_init( - new_body, debug_info=jaxpr.debug_info.with_unknown_names() - ), - new_trace_avals, - ) - jaxpr = new_jaxpr.replace(invars=[], constvars=new_jaxpr.invars) - args = tuple( - a[None] if is_scalar else a - for a, is_scalar in zip(args, is_scalar_const) - ) - in_avals, out_avals = util.split_list(new_trace_avals, [len(in_avals)]) - return pallas_core.default_mesh_discharge_rule( + jaxpr, in_avals, out_avals, args, is_scalar_const = pass_scalars_as_refs( + jaxpr, args, in_avals, out_avals, mesh + ) + refs_out, out = pallas_core.default_mesh_discharge_rule( in_avals, out_avals, *args, @@ -429,6 +504,11 @@ def new_body(*args): name=name, metadata=metadata, ) + refs_out = [ + a if not is_scalar else None + for is_scalar, a in zip(is_scalar_const, refs_out) + ] + return refs_out, out 
pallas_core._core_map_mesh_rules[TensorCoreMesh] = ( @@ -452,6 +532,7 @@ def get_device_kind() -> str: return abstract_device.device_kind return jex_backend.get_default_device().device_kind + def get_num_device_cores() -> int: if abstract_device := jax.sharding.get_abstract_mesh().abstract_device: return abstract_device.num_cores diff --git a/jax/_src/pallas/mosaic/sc_core.py b/jax/_src/pallas/mosaic/sc_core.py index 32eea186b1b8..228ca1ac55b6 100644 --- a/jax/_src/pallas/mosaic/sc_core.py +++ b/jax/_src/pallas/mosaic/sc_core.py @@ -206,13 +206,13 @@ def _scalar_subcore_mesh_discharge_rule( compiler_params = tpu_core.CompilerParams() if compiler_params.dimension_semantics is not None: raise ValueError("ScalarSubcoreMesh does not support dimension_semantics=") - sa_avals = [a for a in in_avals if isinstance(a, jax_core.ShapedArray)] - if sa_avals: - raise NotImplementedError( - f"Cannot close over values in core_map: {sa_avals}" - ) - - return pallas_core.default_mesh_discharge_rule( + jaxpr, in_avals, out_avals, args, is_scalar_const = tpu_core.pass_scalars_as_refs( + jaxpr, args, in_avals, out_avals, mesh, + # TODO(sharadmv): Delete this once we can pass into SMEM directly on + # SparseCore. 
+ copy_to_smem=True, + ) + refs_out, out = pallas_core.default_mesh_discharge_rule( in_avals, out_avals, *args, @@ -225,6 +225,11 @@ def _scalar_subcore_mesh_discharge_rule( name=name, metadata=metadata, ) + refs_out = [ + a if not is_scalar else None + for is_scalar, a in zip(is_scalar_const, refs_out) + ] + return refs_out, out pallas_core._core_map_mesh_rules[ScalarSubcoreMesh] = ( diff --git a/tests/pallas/tpu_pallas_state_test.py b/tests/pallas/tpu_pallas_state_test.py index f05e00723ebd..163c0a8c6eff 100644 --- a/tests/pallas/tpu_pallas_state_test.py +++ b/tests/pallas/tpu_pallas_state_test.py @@ -20,6 +20,7 @@ from jax._src.state.primitives import pin, unpin from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas import tpu_sc as plsc import jax.numpy as jnp import numpy as np @@ -348,11 +349,30 @@ def kernel(x_ref, out_ref, tmp_ref): ): f(x) - def test_capture_scalar(self): + @parameterized.product( + core_type=[*pltpu.CoreType], use_tc_tiling_on_sc=[True, False] + ) + def test_capture_scalar(self, core_type, use_tc_tiling_on_sc): + match core_type: + case pltpu.CoreType.TC: + mesh = pltpu.create_tensorcore_mesh("x", num_cores=1) + use_tc_tiling_on_sc = None + case pltpu.CoreType.SC_SCALAR_SUBCORE: + if pltpu.get_tpu_info().sparse_core is None: + self.skipTest("Sparsecore not supported on this device.") + mesh = plsc.ScalarSubcoreMesh(axis_name="x", num_cores=1) + case pltpu.CoreType.SC_VECTOR_SUBCORE: + self.skipTest("Copies to SMEM on TEC not supported.") @jax.jit def f(x, i): - @pl.kernel(out_shape=jax.ShapeDtypeStruct((1, *x.shape[1:]), jnp.int32), - mesh=pltpu.create_tensorcore_mesh("x", num_cores=1)) + + @pl.kernel( + out_shape=jax.ShapeDtypeStruct((1, *x.shape[1:]), jnp.int32), + mesh=mesh, + compiler_params=pltpu.CompilerParams( + use_tc_tiling_on_sc=use_tc_tiling_on_sc, + ), + ) def kernel(x_ref, out_ref): idx = jax.lax.axis_index("x") # this is always 0 pltpu.sync_copy(x_ref.at[i], 
out_ref.at[idx]) @@ -366,7 +386,7 @@ def kernel(x_ref, out_ref): @jax.jit def g(x, i): @pl.kernel(out_shape=jax.ShapeDtypeStruct((2, *x.shape[1:]), jnp.int32), - mesh=pltpu.create_tensorcore_mesh("x", num_cores=1)) + mesh=mesh) def kernel(x_ref, out_ref): pltpu.sync_copy(x_ref.at[pl.ds(i, 2)], out_ref) return kernel(x) From 5bb5f777e029eb1c8b361f66c9f176cd858735c9 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Mon, 2 Mar 2026 22:22:23 -0800 Subject: [PATCH 024/100] [mgpu] Fix race condition in `GetOrCreateKernel`. The mutex is released during kernel compilation, which allows two compilations to happen concurrently. The first to complete will add a result into the map, then return the pointer. The second to complete overwrote the entry, causing the original entry to be freed, invalidating the first pointer. This has been changed so the second result is discarded, and the second thread will get a pointer to the first result. PiperOrigin-RevId: 877748235 --- jaxlib/mosaic/gpu/custom_call.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 0012488f07aa..7d8236be3bd2 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -822,17 +822,18 @@ absl::StatusOr GetOrCreateKernel( factory) { auto& cache = GetKernelCache(); { - absl::MutexLock lock(&cache.mutex); + absl::MutexLock lock(cache.mutex); auto it = cache.kernels.find(kernel_hash); if (it != cache.kernels.end()) { return it->second.get(); } } - // Release the lock while compiling the kernel. + // Release the lock while compiling the kernel. It is possible that multiple + // threads compile the same kernel concurrently. In that case, we will discard + // all but the first result. 
TF_ASSIGN_OR_RETURN(auto kernel, factory()); - absl::MutexLock lock(&cache.mutex); - auto [iter, inserted] = - cache.kernels.insert_or_assign(kernel_hash, std::move(kernel)); + absl::MutexLock lock(cache.mutex); + auto [iter, _] = cache.kernels.try_emplace(kernel_hash, std::move(kernel)); return iter->second.get(); } From ea4cf33a7be8fa8dec2aac9f778d84c0d24fe2ec Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 3 Mar 2026 00:06:15 -0800 Subject: [PATCH 025/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/86c54b8eaaff7b52acad66472f6e38bcd867ff93 PiperOrigin-RevId: 877782959 --- MODULE.bazel | 6 +++--- third_party/xla/revision.bzl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index acd8df3d6e6b..4e6d43b53e7a 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -27,9 +27,9 @@ archive_override( bazel_dep(name = "xla") archive_override( module_name = "xla", - integrity = "sha256-MrdCTZVTSkHyzA8scC4dYJ/A5p2hmB+xM73ucuFfcuo=", - strip_prefix = "xla-ddf15ca00ef5693e02e2d870c8d720b7d8d060f6", - urls = ["https://github.com/openxla/xla/archive/ddf15ca00ef5693e02e2d870c8d720b7d8d060f6.tar.gz"], + integrity = "sha256-yL46qTzbd281Ygr+OLBXbNP4t9K1qrh5voGhpZMPizI=", + strip_prefix = "xla-86c54b8eaaff7b52acad66472f6e38bcd867ff93", + urls = ["https://github.com/openxla/xla/archive/86c54b8eaaff7b52acad66472f6e38bcd867ff93.tar.gz"], ) # TODO: upstream, otherwise we have to duplicate the patches in jax diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index b4e88e8d0bf3..d91a43ebc818 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. 
# buildifier: disable=module-docstring -XLA_COMMIT = "ddf15ca00ef5693e02e2d870c8d720b7d8d060f6" -XLA_SHA256 = "32b7424d95534a41f2cc0f2c702e1d609fc0e69da1981fb133bdee72e15f72ea" +XLA_COMMIT = "86c54b8eaaff7b52acad66472f6e38bcd867ff93" +XLA_SHA256 = "c8be3aa93cdb776f35620afe38b0576cd3f8b7d2b5aab879be81a1a5930f8b32" From dd5aeab0e4495de69d22aa11927536c795ba75f8 Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Tue, 3 Mar 2026 02:04:01 -0800 Subject: [PATCH 026/100] [mGPU] Fix race condition in CachedInit. Multiple threads can miss the cache and concurrently load the same CUDA module. The losing thread's module handle will be overwritten in the cache and never unloaded, leading to a memory leak PiperOrigin-RevId: 877824701 --- jaxlib/mosaic/gpu/custom_call.cc | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 7d8236be3bd2..cb06161afdf3 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -865,16 +865,13 @@ absl::StatusOr CachedInit(const CompiledKernel* absl_nonnull kernel) { CUDA_RETURN_IF_ERROR(cuCtxGetCurrent(&ctx)); CacheKey key(kernel, reinterpret_cast(ctx)); - { - absl::MutexLock lock(&cache->mutex); - auto it = cache->contexts.find(key); - if (it != cache->contexts.end()) { - VLOG(5) << "Found Mosaic GPU kernel in cache"; - return it->second; - } + absl::MutexLock lock(cache->mutex); + auto it = cache->contexts.find(key); + if (it != cache->contexts.end()) { + VLOG(5) << "Found Mosaic GPU kernel in cache"; + return it->second; } TF_ASSIGN_OR_RETURN(void* context, InitKernel(*kernel)); - absl::MutexLock lock(&cache->mutex); cache->contexts.insert_or_assign(key, context); return context; } From c1b231fcb54f517d4d9b2190610a36f78891c284 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 02:54:02 -0800 Subject: [PATCH 027/100] [pallas:triton] `debug_barrier` now has an effect to prevent it from being DCEd I just 
noticed it doesn't have one while doing an unrelated change. PiperOrigin-RevId: 877842934 --- jax/_src/pallas/triton/BUILD | 1 + jax/_src/pallas/triton/primitives.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 94c898bdb9f9..ed0d5487a300 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -45,6 +45,7 @@ pytype_strict_library( "//jax/_src:ad_util", "//jax/_src:api_util", "//jax/_src:core", + "//jax/_src:effects", "//jax/_src:lax", "//jax/_src:mlir", "//jax/_src:partial_eval", diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 8e763c3d8e6a..6381567f9f41 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -21,9 +21,11 @@ import jax from jax._src import core as jax_core +from jax._src import effects from jax._src import state from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect +from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.triton import lowering from jax.interpreters import mlir @@ -134,12 +136,23 @@ def debug_barrier() -> None: return debug_barrier_p.bind() +class BarrierEffect(jax_core.Effect): + pass + +barrier_effect = BarrierEffect() + +pallas_core.kernel_local_effects.add_type(BarrierEffect) +effects.control_flow_allowed_effects.add_type(BarrierEffect) + + debug_barrier_p = jax_core.Primitive("debug_barrier_p") debug_barrier_p.multiple_results = True -@debug_barrier_p.def_abstract_eval -def _debug_barrier_abstract_eval() -> Sequence[jax_core.ShapedArray]: - return () + +@debug_barrier_p.def_effectful_abstract_eval +def _debug_barrier_abstract_eval(): + return (), {barrier_effect} + @lowering.register_lowering(debug_barrier_p) def _debug_barrier_lowering(ctx: lowering.LoweringRuleContext): From 
ccdbf06231a11ffd74b3622be8817e4d2491bcde Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 3 Mar 2026 02:59:05 -0800 Subject: [PATCH 028/100] [Mosaic GPU] Ensure that WGStridedFragLayout picks a vector length that divides element count Previously, shapes such as (5, 256) would be problematic, since max_vec_size = 10. If bytewidth = 1, we'd pick vec_size = 8, but 8 * 128 does not divide 1280. PiperOrigin-RevId: 877844813 --- jax/experimental/mosaic/gpu/fragmented_array.py | 14 +++++++++----- tests/mosaic/gpu_test.py | 2 +- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 33180e0cf98d..70428c585c5c 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -479,12 +479,16 @@ def from_shaped_type(cls, shaped_ty: ir.Type) -> WGStridedFragLayout | None: return None bw = bitwidth // 8 assert 8 % bw == 0 and 8 // bw != 0, bw - if math.prod(shaped_ty.shape) % WARPGROUP_SIZE != 0: + size = math.prod(shaped_ty.shape) + if size % WARPGROUP_SIZE != 0: return None - max_vec_size = np.prod(shaped_ty.shape) // WARPGROUP_SIZE - return cls( - shape=tuple(shaped_ty.shape), vec_size=min(8 // bw, max_vec_size) - ) + max_vec_size = size // WARPGROUP_SIZE + vec_size = min(8 // bw, max_vec_size) + while vec_size > 0 and size % (vec_size * WARPGROUP_SIZE) != 0: + vec_size //= 2 + if vec_size == 0: + return None + return cls(shape=tuple(shaped_ty.shape), vec_size=vec_size) def registers_element_type(self, t: ir.Type) -> ir.Type: return ir.VectorType.get((self.vec_size,), t) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index bcbd3d028ccc..1c5ffba1d11d 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -3681,7 +3681,7 @@ def test_cp_async_tiled(self, swizzle, shape, dtype): self._test_cp_async(shape, dtype, swizzle=swizzle, tiling=tiling) @parameterized.product( - 
shape=((64, 128), (128, 40), (32, 384)), + shape=((64, 128), (128, 40), (5, 256)), dtype=(jnp.float32, jnp.float16), ) def test_cp_async_untiled(self, shape, dtype): From 2bbaf4cbbe6af686ed5a1e4630d9a763e7cc833b Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Tue, 3 Mar 2026 03:16:44 -0800 Subject: [PATCH 029/100] Fix a test. PiperOrigin-RevId: 877851034 --- tests/pallas/tpu_pallas_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 53a3ac140898..f289f2512bc4 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -1539,8 +1539,8 @@ def body(x_ref, y_ref, sem): np.testing.assert_allclose(y, x) def test_dma_with_regular_semaphore(self): - if not jtu.is_device_tpu_at_least(5): - self.skipTest('Regular semaphores in DMAs require TPU v5+') + if not jtu.is_device_tpu_at_least(6): + self.skipTest('Regular semaphores in DMAs require TPU v6+') if not jtu.is_cloud_tpu_at_least(2026, 3, 2): self.skipTest("Test requires a newer libtpu") From 7a3fb9db0bdc892afdecb061aabda12b9ef9370e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 03:23:40 -0800 Subject: [PATCH 030/100] [pallas:sc] Fixed how `scheduler.grid_env` is used when `emit_pipeline` is called with explicit indices I'm pretty sure I had this fix in my original version, but it somehow disappeared. PiperOrigin-RevId: 877853107 --- jax/_src/pallas/mosaic/pipeline.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index a989175a1f4d..5b21cc14dd49 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -1240,6 +1240,7 @@ def __init__( last_cycle: jax.Array | bool, init_accumulators=None, trace_scopes=True, + _explicit_indices: bool = False, ): """Initializes scheduler. 
@@ -1254,6 +1255,7 @@ def __init__( init_accumulators: do we zero-initialize accumulator state for this invocation of the pipeline. trace_scopes: whether to use named_scope to trace blocks in the pipeline. + _explicit_indices: whether the pipeline uses explicit indices. """ self.step = step self.grid = grid @@ -1263,6 +1265,7 @@ def __init__( self.last_cycle = last_cycle self.init_accumulators = init_accumulators self.trace_scopes = trace_scopes + self._explicit_indices = _explicit_indices # Total number of linear steps. self.num_steps = _grid_size(grid) @@ -1313,6 +1316,8 @@ def _named_scope(self, name): yield def grid_env(self): + if self._explicit_indices: + return contextlib.nullcontext() return pallas_core.grid_env( list(map(pallas_core.GridAxis, self.indices, self.grid))) # pyrefly: ignore[no-matching-overload] # pyrefly#2385 @@ -2059,16 +2064,13 @@ def make_scheduler(step, indices): last_cycle=last_cycle, init_accumulators=init_accumulators, trace_scopes=trace_scopes, + _explicit_indices=_explicit_indices, ) def loop_body(step, carry): unaliased_brefs, indices = carry scheduler = make_scheduler(step, indices) - grid_env_ctx = ( - contextlib.nullcontext() if _explicit_indices - else scheduler.grid_env() - ) - with grid_env_ctx: + with scheduler.grid_env(): # prepare any local VMEM aliases brefs = map_brefs(scheduler.alias_local_refs, unaliased_brefs, refs) # loop input handling phase @@ -2128,11 +2130,7 @@ def loop_body(step, carry): def _loop_body(step, carry): brefs, indices = carry scheduler = make_scheduler(step, indices) - grid_env_ctx = ( - contextlib.nullcontext() if _explicit_indices - else scheduler.grid_env() - ) - with grid_env_ctx: + with scheduler.grid_env(): # prepare any local VMEM aliases brefs = map_brefs(scheduler.alias_local_refs, brefs, refs) # loop input handling phase From 5a3accc36d5481c289daa5ec08d148ace60e48a0 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 03:37:27 -0800 Subject: [PATCH 031/100] [pallas:triton] 
Moved Triton specific primitives from Pallas core into `triton` We moved the public APIs a while back, so this change only moves the implementation, which was left in `pallas/primitives.py` until now. PiperOrigin-RevId: 877857453 --- .../mosaic/interpret/interpret_pallas_call.py | 6 - jax/_src/pallas/primitives.py | 244 ------------ jax/_src/pallas/triton/BUILD | 1 + jax/_src/pallas/triton/lowering.py | 78 ---- jax/_src/pallas/triton/primitives.py | 357 +++++++++++++++++- jax/experimental/pallas/triton.py | 18 +- 6 files changed, 366 insertions(+), 338 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py index 381774a2effd..87cbd762ea19 100644 --- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -1477,12 +1477,6 @@ def f(*args, jaxpr): ) out = [] - elif prim is primitives.atomic_rmw_p: - raise NotImplementedError('atomic_rmw_p') - - elif prim is primitives.atomic_cas_p: - raise NotImplementedError('atomic_cas_p') - else: if interpret_params.skip_floating_point_ops and all( interpret_utils.is_float(ovar.aval.dtype) for ovar in eqn.outvars diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 56e437e56d64..266906b72816 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -116,250 +116,6 @@ def _num_programs_bind_with_trace(trace, _, params): def _num_programs_abstract_eval(**_): return jax_core.ShapedArray((), jnp.int32) -class AtomicOpType(enum.Enum): - XCHG = "xchg" - ADD = "add" - MAX = "max" - MIN = "min" - AND = "and" - OR = "or" - XOR = "xor" - -atomic_rmw_p = jax_core.Primitive("atomic_rmw") - - -def _atomic_rmw_discharge_rule( - in_avals, out_avals, *args_flat, args_tree, atomic_type: AtomicOpType -): - del out_avals # Unused. 
- ref, transforms, val, mask = args_tree.unflatten(args_flat) - *prev_transforms, idx = transforms - ref = state_discharge.transform_array(ref, prev_transforms) - - if mask is not None: - raise NotImplementedError - - if atomic_type == AtomicOpType.ADD: - monoid = lambda x, y: x + y - elif atomic_type == AtomicOpType.MAX: - monoid = jnp.maximum - elif atomic_type == AtomicOpType.MIN: - monoid = jnp.minimum - else: - raise NotImplementedError(atomic_type) - - if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): - indices = idx.indices - scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] - slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] - slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) - out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes) - val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims) - val = val[val_indexer] - val = monoid(val, out_ones) - x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts) - out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims) - out = out_ones[out_indexer] - elif all(not isinstance(s, Slice) for s in idx.indices): - out = ref[idx.indices] - x_new = ref.at[idx.indices].set(monoid(out, val)) - else: - raise NotImplementedError - return (x_new,) + (None,) * (len(in_avals) - 1), out - - -state_discharge.register_discharge_rule(atomic_rmw_p)(_atomic_rmw_discharge_rule) - - -@atomic_rmw_p.def_effectful_abstract_eval -def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType): - ref, _, _, _ = args_tree.unflatten(avals_flat) - if ref.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD: - raise ValueError(f"`atomic_{atomic_type.value}` does not support f16.") - if ref.dtype in { - jnp.dtype("bool"), - jnp.dtype("int8"), - jnp.dtype("int16"), - jnp.bfloat16, - }: - raise ValueError( - f"`atomic_{atomic_type.value}` does not support 
{ref.dtype}." - ) - return _swap_abstract_eval(*avals_flat, args_tree=args_tree) - - -def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None, - atomic_type: AtomicOpType): - x_ref, transforms = sp.get_ref_and_transforms( - x_ref_or_view, idx, "atomic_rmw" - ) - args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask)) - return atomic_rmw_p.bind( - *args_flat, args_tree=args_tree, atomic_type=atomic_type - ) - -def atomic_xchg(x_ref_or_view, idx, val, *, mask: Any | None = None): - """Atomically exchanges the given value with the value at the given index. - - Args: - x_ref_or_view: The ref to operate on. - idx: The indexer to use. - mask: TO BE DOCUMENTED. - - Returns: - The value at the given index prior to the aupdate. - """ - return _atomic_rmw( - x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XCHG - ) - - -def atomic_add(x_ref_or_view, idx, val, *, mask: Any | None = None): - """Atomically computes ``x_ref_or_view[idx] += val``. - - Args: - x_ref_or_view: The ref to operate on. - idx: The indexer to use. - mask: TO BE DOCUMENTED. - - Returns: - The value at the given index prior to the atomic operation. - """ - return _atomic_rmw( - x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.ADD - ) - - -def atomic_max(x_ref_or_view, idx, val, *, mask: Any | None = None): - """Atomically computes ``x_ref_or_view[idx] = max(x_ref_or_view[idx], val)``. - - Args: - x_ref_or_view: The ref to operate on. - idx: The indexer to use. - mask: TO BE DOCUMENTED. - - Returns: - The value at the given index prior to the atomic operation. - """ - return _atomic_rmw( - x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MAX - ) - - -def atomic_min(x_ref_or_view, idx, val, *, mask: Any | None = None): - """Atomically computes ``x_ref_or_view[idx] = min(x_ref_or_view[idx], val)``. - - Args: - x_ref_or_view: The ref to operate on. - idx: The indexer to use. - mask: TO BE DOCUMENTED. 
- - Returns: - The value at the given index prior to the atomic operation. - """ - return _atomic_rmw( - x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MIN - ) - - -def atomic_and(x_ref_or_view, idx, val, *, mask: Any | None = None): - """Atomically computes ``x_ref_or_view[idx] &= val``. - - Args: - x_ref_or_view: The ref to operate on. - idx: The indexer to use. - mask: TO BE DOCUMENTED. - - Returns: - The value at the given index prior to the atomic operation. - """ - return _atomic_rmw( - x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.AND - ) - - -def atomic_or(x_ref_or_view, idx, val, *, mask: Any | None = None): - """Atomically computes ``x_ref_or_view[idx] |= val``. - - Args: - x_ref_or_view: The ref to operate on. - idx: The indexer to use. - mask: TO BE DOCUMENTED. - - Returns: - The value at the given index prior to the atomic operation. - """ - return _atomic_rmw( - x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.OR - ) - - -def atomic_xor(x_ref_or_view, idx, val, *, mask: Any | None = None): - """Atomically computes ``x_ref_or_view[idx] ^= val``. - - Args: - x_ref_or_view: The ref to operate on. - idx: The indexer to use. - mask: TO BE DOCUMENTED. - - Returns: - The value at the given index prior to the atomic operation. 
- """ - return _atomic_rmw( - x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XOR - ) - -atomic_cas_p = jax_core.Primitive("atomic_cas") - -@atomic_cas_p.def_effectful_abstract_eval -def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval): - if cmp_aval.dtype != val_aval.dtype or cmp_aval.shape != val_aval.shape: - raise ValueError("cmp and val must have identical dtypes and shapes") - if ref_aval.shape: - raise ValueError("ref must be scalar.") - if cmp_aval.shape: - raise ValueError("cmp must be scalar.") - if val_aval.shape: - raise ValueError("val must be scalar.") - return jax_core.ShapedArray(val_aval.shape, val_aval.dtype), {state.WriteEffect(0)} - - -def atomic_cas(ref, cmp, val): - """Performs an atomic compare-and-swap of the value in the ref with the - given value. - - Args: - ref: The ref to operate on. - cmp: The expected value to compare against. - val: The value to swap in. - - Returns: - The value at the given index prior to the atomic operation. - """ - return atomic_cas_p.bind(ref, cmp, val) - -@state_discharge.register_discharge_rule(atomic_cas_p) -def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val): - del in_avals, out_avals - new_val = jnp.where(ref == cmp, val, ref) - return (new_val, None, None), ref - -max_contiguous_p = jax_core.Primitive("max_contiguous") - -max_contiguous_p.def_impl(lambda x, **_: x) -mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x]) - -def max_contiguous(x, values): - """A compiler hint that asserts the ``values`` first values of ``x`` are contiguous. 
- """ - if not isinstance(values, (list, tuple)): - values = (values,) - return max_contiguous_p.bind(x, values=tuple(values)) - -@max_contiguous_p.def_abstract_eval -def _max_contiguous_abstract_eval(aval, **_): - return aval - multiple_of_p = jax_core.Primitive("multiple_of") multiple_of_p.def_impl(lambda x, **_: x) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index ed0d5487a300..f5b2a4d8bd52 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -50,6 +50,7 @@ pytype_strict_library( "//jax/_src:mlir", "//jax/_src:partial_eval", "//jax/_src:source_info_util", + "//jax/_src:tree_util", "//jax/_src:util", "//jax/_src/lib", "//jax/_src/pallas", diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 18d085e3f7d3..8332505e6261 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -503,72 +503,6 @@ def _atomic_rmw( ) -@register_lowering(primitives.atomic_rmw_p) -def _atomic_lowering_rule( - ctx: LoweringRuleContext, - *args_flat, - args_tree, - atomic_type: primitives.AtomicOpType, -): - block_info, *_ = ctx.block_infos - assert block_info is not None - ptr, indexers, val, mask = args_tree.unflatten(args_flat) - *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in) - indexers = list(indexers) - if not indexers or not isinstance(indexers[-1], indexing.NDIndexer): - ref_aval = state.transform_type(indexers, ctx.avals_in[0]) - assert isinstance(ref_aval, state.AbstractRef) - indexers.append(NDIndexer.make_trivial_indexer(ref_aval.shape)) - if len(indexers) != 1: - raise NotImplementedError("Only single indexer is supported.") - idx = indexers[0] - ptr = _compute_pointers_from_indices(ptr, block_info, idx) - val = _ensure_ir_value(val, value_aval) - if mask is not None: - mask = _ensure_ir_value(mask, mask_aval) - if atomic_type == primitives.AtomicOpType.XCHG: - op = tt_dialect.RMWOp.XCHG - elif atomic_type == primitives.AtomicOpType.ADD: 
- if isinstance(val.type, ir.IntegerType): - op = tt_dialect.RMWOp.ADD - else: - op = tt_dialect.RMWOp.FADD - elif atomic_type == primitives.AtomicOpType.MIN: - op = tt_dialect.RMWOp.MIN - elif atomic_type == primitives.AtomicOpType.MAX: - op = tt_dialect.RMWOp.MAX - elif atomic_type == primitives.AtomicOpType.AND: - op = tt_dialect.RMWOp.AND - elif atomic_type == primitives.AtomicOpType.OR: - op = tt_dialect.RMWOp.OR - elif atomic_type == primitives.AtomicOpType.XOR: - op = tt_dialect.RMWOp.XOR - else: - raise NotImplementedError(f"unsupported atomic operation: {atomic_type}") - return _atomic_rmw(op, ptr, val, mask=mask) - - -@register_lowering(primitives.atomic_cas_p) -def _atomic_cas_lowering_rule(ctx: LoweringRuleContext, ptr, cmp, val): - _, cmp_aval, val_aval = ctx.avals_in - if isinstance(ptr.type, ir.RankedTensorType): - ptr_type = ir.RankedTensorType(ptr.type) - element_type = tt_dialect.PointerType(ptr_type.element_type) - result_type = ir.RankedTensorType.get( - ptr_type.shape, element_type.pointee_type, ptr_type.encoding - ) - else: - result_type = tt_dialect.PointerType(ptr.type).pointee_type - return tt_dialect.atomic_cas( - result_type, - ptr, - _ensure_ir_value(cmp, cmp_aval), - _ensure_ir_value(val, val_aval), - sem=tt_dialect.MemSemantic.ACQUIRE_RELEASE, - scope=tt_dialect.MemSyncScope.GPU, - ) - - def _associative_scan_lowering(body, ctx: LoweringRuleContext, args, axes): flat_args = tree_util.tree_leaves(args) (axis,) = axes @@ -1406,18 +1340,6 @@ def _multiple_of_rule(ctx: LoweringRuleContext, x, values: Sequence[int]): return x -@register_lowering(primitives.max_contiguous_p) -def _max_contiguous_rule(ctx: LoweringRuleContext, x, values: Sequence[int]): - [x_aval] = ctx.avals_in - assert len(x_aval.shape) == len(values) - _set_attr( - x, - "tt.contiguity", - ir.DenseIntElementsAttr.get(np.asarray(values, dtype=np.int32)), - ) - return x - - @register_lowering(sp.broadcast_to_p) def _broadcast_to_rule(ctx: LoweringRuleContext, x, shape: 
Sequence[int]): (x_aval,) = ctx.avals_in diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 6381567f9f41..4ec52d55ff52 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -17,23 +17,33 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TypeAlias +import enum +from typing import Any, TypeAlias import jax from jax._src import core as jax_core from jax._src import effects +from jax._src import lax from jax._src import state +from jax._src import tree_util +from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives as pallas_primitives from jax._src.pallas.triton import lowering +from jax._src.state import discharge as state_discharge +from jax._src.state import indexing +from jax._src.state import primitives as state_primitives from jax.interpreters import mlir import jax.numpy as jnp +import numpy as np Ref: TypeAlias = state.AbstractRef | state.TransformedRef +Slice = indexing.Slice + def approx_tanh(x: jax.Array) -> jax.Array: r"""Elementwise approximate hyperbolic tangent: :math:`\mathrm{tanh}(x)`. @@ -214,3 +224,348 @@ def store( mask=mask, eviction_policy=eviction_policy, ) + + +class AtomicOpType(enum.Enum): + XCHG = "xchg" + ADD = "add" + MAX = "max" + MIN = "min" + AND = "and" + OR = "or" + XOR = "xor" + + +atomic_rmw_p = jax_core.Primitive("atomic_rmw") + + +def _atomic_rmw_discharge_rule( + in_avals, out_avals, *args_flat, args_tree, atomic_type: AtomicOpType +): + del out_avals # Unused. 
+ ref, transforms, val, mask = args_tree.unflatten(args_flat) + *prev_transforms, idx = transforms + ref = state_discharge.transform_array(ref, prev_transforms) + + if mask is not None: + raise NotImplementedError + + if atomic_type == AtomicOpType.ADD: + monoid = lambda x, y: x + y + elif atomic_type == AtomicOpType.MAX: + monoid = jnp.maximum + elif atomic_type == AtomicOpType.MIN: + monoid = jnp.minimum + else: + raise NotImplementedError(atomic_type) + + if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): + indices = idx.indices + scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices] + slice_starts = [s.start if isinstance(s, Slice) else s for s in indices] + slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices) + out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes) + val_indexer = tuple( + None if scalar else slice(None) for scalar in scalar_dims + ) + val = val[val_indexer] + val = monoid(val, out_ones) + x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts) + out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims) + out = out_ones[out_indexer] + elif all(not isinstance(s, Slice) for s in idx.indices): + out = ref[idx.indices] + x_new = ref.at[idx.indices].set(monoid(out, val)) + else: + raise NotImplementedError + return (x_new,) + (None,) * (len(in_avals) - 1), out + + +state_discharge.register_discharge_rule(atomic_rmw_p)( + _atomic_rmw_discharge_rule +) + + +@atomic_rmw_p.def_effectful_abstract_eval +def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType): + ref, _, _, _ = args_tree.unflatten(avals_flat) + if ref.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD: + raise ValueError(f"`atomic_{atomic_type.value}` does not support f16.") + if ref.dtype in { + jnp.dtype("bool"), + jnp.dtype("int8"), + jnp.dtype("int16"), + jnp.bfloat16, + }: + raise ValueError( + f"`atomic_{atomic_type.value}` does not 
support {ref.dtype}." + ) + return pallas_primitives._swap_abstract_eval(*avals_flat, args_tree=args_tree) + + +def _atomic_rmw( + x_ref_or_view, + idx, + val, + *, + mask: Any | None = None, + atomic_type: AtomicOpType, +): + x_ref, transforms = state_primitives.get_ref_and_transforms( + x_ref_or_view, idx, "atomic_rmw" + ) + args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask)) + return atomic_rmw_p.bind( + *args_flat, args_tree=args_tree, atomic_type=atomic_type + ) + + +@lowering.register_lowering(atomic_rmw_p) +def _atomic_lowering_rule( + ctx: lowering.LoweringRuleContext, + *args_flat, + args_tree, + atomic_type: AtomicOpType, +): + block_info, *_ = ctx.block_infos + assert block_info is not None + ptr, indexers, val, mask = args_tree.unflatten(args_flat) + *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in) + indexers = list(indexers) + if not indexers or not isinstance(indexers[-1], indexing.NDIndexer): + ref_aval = state.transform_type(indexers, ctx.avals_in[0]) + assert isinstance(ref_aval, state.AbstractRef) + indexers.append(indexing.NDIndexer.make_trivial_indexer(ref_aval.shape)) + if len(indexers) != 1: + raise NotImplementedError("Only single indexer is supported.") + idx = indexers[0] + ptr = lowering._compute_pointers_from_indices(ptr, block_info, idx) + val = lowering._ensure_ir_value(val, value_aval) + if mask is not None: + mask = lowering._ensure_ir_value(mask, mask_aval) + if atomic_type == AtomicOpType.XCHG: + op = tt_dialect.RMWOp.XCHG + elif atomic_type == AtomicOpType.ADD: + if isinstance(val.type, ir.IntegerType): + op = tt_dialect.RMWOp.ADD + else: + op = tt_dialect.RMWOp.FADD + elif atomic_type == AtomicOpType.MIN: + op = tt_dialect.RMWOp.MIN + elif atomic_type == AtomicOpType.MAX: + op = tt_dialect.RMWOp.MAX + elif atomic_type == AtomicOpType.AND: + op = tt_dialect.RMWOp.AND + elif atomic_type == AtomicOpType.OR: + op = tt_dialect.RMWOp.OR + elif atomic_type == AtomicOpType.XOR: + op = 
tt_dialect.RMWOp.XOR + else: + raise NotImplementedError(f"unsupported atomic operation: {atomic_type}") + return lowering._atomic_rmw(op, ptr, val, mask=mask) + + +def atomic_xchg(x_ref_or_view, idx, val, *, mask: Any | None = None): + """Atomically exchanges the given value with the value at the given index. + + Args: + x_ref_or_view: The ref to operate on. + idx: The indexer to use. + mask: TO BE DOCUMENTED. + + Returns: + The value at the given index prior to the aupdate. + """ + return _atomic_rmw( + x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XCHG + ) + + +def atomic_add(x_ref_or_view, idx, val, *, mask: Any | None = None): + """Atomically computes ``x_ref_or_view[idx] += val``. + + Args: + x_ref_or_view: The ref to operate on. + idx: The indexer to use. + mask: TO BE DOCUMENTED. + + Returns: + The value at the given index prior to the atomic operation. + """ + return _atomic_rmw( + x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.ADD + ) + + +def atomic_max(x_ref_or_view, idx, val, *, mask: Any | None = None): + """Atomically computes ``x_ref_or_view[idx] = max(x_ref_or_view[idx], val)``. + + Args: + x_ref_or_view: The ref to operate on. + idx: The indexer to use. + mask: TO BE DOCUMENTED. + + Returns: + The value at the given index prior to the atomic operation. + """ + return _atomic_rmw( + x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MAX + ) + + +def atomic_min(x_ref_or_view, idx, val, *, mask: Any | None = None): + """Atomically computes ``x_ref_or_view[idx] = min(x_ref_or_view[idx], val)``. + + Args: + x_ref_or_view: The ref to operate on. + idx: The indexer to use. + mask: TO BE DOCUMENTED. + + Returns: + The value at the given index prior to the atomic operation. + """ + return _atomic_rmw( + x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MIN + ) + + +def atomic_and(x_ref_or_view, idx, val, *, mask: Any | None = None): + """Atomically computes ``x_ref_or_view[idx] &= val``. 
+ + Args: + x_ref_or_view: The ref to operate on. + idx: The indexer to use. + mask: TO BE DOCUMENTED. + + Returns: + The value at the given index prior to the atomic operation. + """ + return _atomic_rmw( + x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.AND + ) + + +def atomic_or(x_ref_or_view, idx, val, *, mask: Any | None = None): + """Atomically computes ``x_ref_or_view[idx] |= val``. + + Args: + x_ref_or_view: The ref to operate on. + idx: The indexer to use. + mask: TO BE DOCUMENTED. + + Returns: + The value at the given index prior to the atomic operation. + """ + return _atomic_rmw( + x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.OR + ) + + +def atomic_xor(x_ref_or_view, idx, val, *, mask: Any | None = None): + """Atomically computes ``x_ref_or_view[idx] ^= val``. + + Args: + x_ref_or_view: The ref to operate on. + idx: The indexer to use. + mask: TO BE DOCUMENTED. + + Returns: + The value at the given index prior to the atomic operation. + """ + return _atomic_rmw( + x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XOR + ) + + +atomic_cas_p = jax_core.Primitive("atomic_cas") + + +@atomic_cas_p.def_effectful_abstract_eval +def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval): + if cmp_aval.dtype != val_aval.dtype or cmp_aval.shape != val_aval.shape: + raise ValueError("cmp and val must have identical dtypes and shapes") + if ref_aval.shape: + raise ValueError("ref must be scalar.") + if cmp_aval.shape: + raise ValueError("cmp must be scalar.") + if val_aval.shape: + raise ValueError("val must be scalar.") + return jax_core.ShapedArray(val_aval.shape, val_aval.dtype), { + state.WriteEffect(0) + } + + +def atomic_cas(ref, cmp, val): + """Performs an atomic compare-and-swap of the value in the ref with the + + given value. + + Args: + ref: The ref to operate on. + cmp: The expected value to compare against. + val: The value to swap in. + + Returns: + The value at the given index prior to the atomic operation. 
+ """ + return atomic_cas_p.bind(ref, cmp, val) + + +@state_discharge.register_discharge_rule(atomic_cas_p) +def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val): + del in_avals, out_avals + new_val = jnp.where(ref == cmp, val, ref) + return (new_val, None, None), ref + + +@lowering.register_lowering(atomic_cas_p) +def _atomic_cas_lowering_rule(ctx: lowering.LoweringRuleContext, ptr, cmp, val): + _, cmp_aval, val_aval = ctx.avals_in + if isinstance(ptr.type, ir.RankedTensorType): + ptr_type = ir.RankedTensorType(ptr.type) + element_type = tt_dialect.PointerType(ptr_type.element_type) + result_type = ir.RankedTensorType.get( + ptr_type.shape, element_type.pointee_type, ptr_type.encoding + ) + else: + result_type = tt_dialect.PointerType(ptr.type).pointee_type + return tt_dialect.atomic_cas( + result_type, + ptr, + lowering._ensure_ir_value(cmp, cmp_aval), + lowering._ensure_ir_value(val, val_aval), + sem=tt_dialect.MemSemantic.ACQUIRE_RELEASE, + scope=tt_dialect.MemSyncScope.GPU, + ) + + +max_contiguous_p = jax_core.Primitive("max_contiguous") + +max_contiguous_p.def_impl(lambda x, **_: x) +mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x]) + + +def max_contiguous(x, values): + """A compiler hint that asserts the ``values`` first values of ``x`` are contiguous.""" + if not isinstance(values, (list, tuple)): + values = (values,) + return max_contiguous_p.bind(x, values=tuple(values)) + + +@max_contiguous_p.def_abstract_eval +def _max_contiguous_abstract_eval(aval, **_): + return aval + + +@lowering.register_lowering(max_contiguous_p) +def _max_contiguous_rule( + ctx: lowering.LoweringRuleContext, x, values: Sequence[int] +): + [x_aval] = ctx.avals_in + assert len(x_aval.shape) == len(values) + lowering._set_attr( + x, + "tt.contiguity", + ir.DenseIntElementsAttr.get(np.asarray(values, dtype=np.int32)), + ) + return x diff --git a/jax/experimental/pallas/triton.py b/jax/experimental/pallas/triton.py index 3878a5a8af0c..ea3e139e8cea 100644 
--- a/jax/experimental/pallas/triton.py +++ b/jax/experimental/pallas/triton.py @@ -14,18 +14,18 @@ """Triton-specific Pallas APIs.""" -from jax._src.pallas.primitives import atomic_add as atomic_add -from jax._src.pallas.primitives import atomic_and as atomic_and -from jax._src.pallas.primitives import atomic_cas as atomic_cas -from jax._src.pallas.primitives import atomic_max as atomic_max -from jax._src.pallas.primitives import atomic_min as atomic_min -from jax._src.pallas.primitives import atomic_or as atomic_or -from jax._src.pallas.primitives import atomic_xchg as atomic_xchg -from jax._src.pallas.primitives import atomic_xor as atomic_xor -from jax._src.pallas.primitives import max_contiguous as max_contiguous from jax._src.pallas.triton.core import CompilerParams as CompilerParams from jax._src.pallas.triton.primitives import approx_tanh as approx_tanh +from jax._src.pallas.triton.primitives import atomic_add as atomic_add +from jax._src.pallas.triton.primitives import atomic_and as atomic_and +from jax._src.pallas.triton.primitives import atomic_cas as atomic_cas +from jax._src.pallas.triton.primitives import atomic_max as atomic_max +from jax._src.pallas.triton.primitives import atomic_min as atomic_min +from jax._src.pallas.triton.primitives import atomic_or as atomic_or +from jax._src.pallas.triton.primitives import atomic_xchg as atomic_xchg +from jax._src.pallas.triton.primitives import atomic_xor as atomic_xor from jax._src.pallas.triton.primitives import debug_barrier as debug_barrier from jax._src.pallas.triton.primitives import elementwise_inline_asm as elementwise_inline_asm from jax._src.pallas.triton.primitives import load as load +from jax._src.pallas.triton.primitives import max_contiguous as max_contiguous from jax._src.pallas.triton.primitives import store as store From 8f56da1b2c8bc5789659aa7b864845bf78e64c12 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 03:38:41 -0800 Subject: [PATCH 032/100] [pallas:sc] Use 
`pltpu.CoreType` instead of the derepcated `pltpu.KernelType` PiperOrigin-RevId: 877857771 --- docs/pallas/tpu/sparsecore.ipynb | 2 +- docs/pallas/tpu/sparsecore.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pallas/tpu/sparsecore.ipynb b/docs/pallas/tpu/sparsecore.ipynb index e13de9d6ae24..40aaf3aaf368 100644 --- a/docs/pallas/tpu/sparsecore.ipynb +++ b/docs/pallas/tpu/sparsecore.ipynb @@ -562,7 +562,7 @@ " ),\n", " out_specs=pl.BlockSpec((gather_window_size, value_dim), lambda i: (i, 0)),\n", " compiler_params=pltpu.CompilerParams(\n", - " kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE,\n", + " kernel_type=pltpu.CoreType.SC_VECTOR_SUBCORE,\n", " dimension_semantics=(pltpu.PARALLEL,),\n", " ),\n", " )\n", diff --git a/docs/pallas/tpu/sparsecore.md b/docs/pallas/tpu/sparsecore.md index c2a9135aff60..56b92bb5cdc7 100644 --- a/docs/pallas/tpu/sparsecore.md +++ b/docs/pallas/tpu/sparsecore.md @@ -320,7 +320,7 @@ def gather_add_one(x, indices): ), out_specs=pl.BlockSpec((gather_window_size, value_dim), lambda i: (i, 0)), compiler_params=pltpu.CompilerParams( - kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE, + kernel_type=pltpu.CoreType.SC_VECTOR_SUBCORE, dimension_semantics=(pltpu.PARALLEL,), ), ) From 5a07b03c0ac344752a5f1e6e71f41e9afcfd1d79 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 05:24:00 -0800 Subject: [PATCH 033/100] [jaxlib] Force the use of `typing_extensions.CapsuleType` in the _jax extension `types.CapsuleType` is not available in Python 3.11. 
PiperOrigin-RevId: 877892414 --- jaxlib/BUILD | 7 +++++ jaxlib/_jax/ffi.pyi | 1 + jaxlib/backport_stubs.py | 64 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) create mode 100644 jaxlib/backport_stubs.py diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 461a5b5de32f..43dd76c1c3d0 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -16,6 +16,7 @@ load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") +load("@rules_python//python:py_binary.bzl", "py_binary") load( "//jaxlib:jax.bzl", "cc_proto_library", @@ -333,6 +334,7 @@ nanobind_pywrap_extension( "//jaxlib/mlir:ir", ], enable_stub_generation = True, + postprocess_stubgen = "//jaxlib:backport_stubs", pytype_deps = py_deps(["numpy"]), pytype_srcs = glob(["_jax/*.pyi"]), stub_replacement_patterns = { @@ -445,6 +447,11 @@ nanobind_pywrap_extension( }), ) +py_binary( + name = "backport_stubs", + srcs = ["backport_stubs.py"], +) + cc_library( name = "pprof_profile_builder", srcs = ["pprof_profile_builder.cc"], diff --git a/jaxlib/_jax/ffi.pyi b/jaxlib/_jax/ffi.pyi index 2c8d9fedfdf8..3ad2e14ee005 100644 --- a/jaxlib/_jax/ffi.pyi +++ b/jaxlib/_jax/ffi.pyi @@ -15,6 +15,7 @@ """Python bindings for the XLA FFI.""" import enum + import numpy import typing_extensions diff --git a/jaxlib/backport_stubs.py b/jaxlib/backport_stubs.py new file mode 100644 index 000000000000..5009ecd8232e --- /dev/null +++ b/jaxlib/backport_stubs.py @@ -0,0 +1,64 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Backports generated .pyi stubs to support older Python versions.""" + +import argparse +import re + +_REPLACEMENTS = [ + (re.compile(r"\btypes\.CapsuleType\b"), "typing_extensions.CapsuleType"), +] + +_TYPES_USAGE = re.compile(r"\btypes\.") +_IMPORT_TYPES = re.compile(r"^import types\n", re.MULTILINE) +_TYPING_EXTENSIONS_USAGE = re.compile(r"\btyping_extensions\.") +_IMPORT_TYPING_EXTENSIONS = re.compile( + r"^import typing_extensions\b", re.MULTILINE +) +_DOCSTRING_END = re.compile(r'^""".*?"""\n', re.MULTILINE | re.DOTALL) + + +def backport(content: str) -> str: + for pattern, replacement in _REPLACEMENTS: + content = pattern.sub(replacement, content) + + if not _TYPES_USAGE.search(content): + content = _IMPORT_TYPES.sub("", content) + + if _TYPING_EXTENSIONS_USAGE.search( + content + ) and not _IMPORT_TYPING_EXTENSIONS.search(content): + content = _DOCSTRING_END.sub( + r"\g<0>\nimport typing_extensions\n", content, count=1 + ) + + return content + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("files", nargs="+", metavar="FILE") + args = parser.parse_args() + + for path in args.files: + with open(path) as f: + content = f.read() + content = backport(content) + with open(path, "w") as f: + f.write(content) + + +if __name__ == "__main__": + main() From ba162664ea99ad6ff403fe19b1800f7bf62633f1 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 06:03:20 -0800 Subject: [PATCH 034/100] [jaxlib] Fixed the type of `Traceback.__add__` PiperOrigin-RevId: 877905599 --- jaxlib/_jax/__init__.pyi | 4 ++-- jaxlib/traceback.cc | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jaxlib/_jax/__init__.pyi b/jaxlib/_jax/__init__.pyi index 3101409ed24a..712c27ee74c3 100644 --- a/jaxlib/_jax/__init__.pyi +++ b/jaxlib/_jax/__init__.pyi @@ -17,7 +17,7 @@ import enum import inspect import traceback 
import types -from typing import Annotated, Any, TypeAlias, overload +from typing import Annotated, Any, Self, TypeAlias, overload import numpy from numpy.typing import NDArray @@ -1376,7 +1376,7 @@ class Traceback: @property def frames(self) -> list[Frame]: ... - def __add__(self, other: object) -> object: + def __add__(self, other: Self) -> Self: """Concatenates two tracebacks.""" def raw_frames(self) -> tuple[list[types.CodeType], list[int]]: ... diff --git a/jaxlib/traceback.cc b/jaxlib/traceback.cc index 2d506f12c802..ce97f7c55218 100644 --- a/jaxlib/traceback.cc +++ b/jaxlib/traceback.cc @@ -426,7 +426,8 @@ void Traceback::Register(nb::module_& m) { } return traceback; }, - nb::is_method(), nb::arg("other"), "Concatenates two tracebacks."); + nb::is_method(), nb::arg("other"), "Concatenates two tracebacks.", + nb::sig("def __add__(self, other: typing.Self) -> typing.Self")); type.attr("raw_frames") = nb::cpp_function( [](const Traceback& tb) -> nb::tuple { From b41f5838dee774dbd911158ed5ee712334a1c4b4 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 14:47:56 +0000 Subject: [PATCH 035/100] Fixed all remaining Pyrefly errors in jax/_src --- jax/_src/ad_checkpoint.py | 8 ++- jax/_src/api.py | 5 +- jax/_src/checkify.py | 21 +++--- jax/_src/compilation_cache.py | 6 +- jax/_src/compiler.py | 2 +- jax/_src/compute_on.py | 6 +- jax/_src/custom_transpose.py | 5 ++ jax/_src/debugger/colab_debugger.py | 2 + jax/_src/earray.py | 8 ++- jax/_src/ffi.py | 4 +- jax/_src/internal_test_util/test_harnesses.py | 2 +- jax/_src/interpreters/mlir.py | 64 +++++++++++++------ jax/_src/interpreters/remat.py | 6 +- jax/_src/lax/ann.py | 21 ++++-- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/lax/lax.py | 25 +++++--- jax/_src/lax/parallel.py | 2 +- jax/_src/lax_reference.py | 9 +-- jax/_src/mesh.py | 5 +- jax/_src/named_sharding.py | 16 ++++- jax/_src/numpy/array_constructors.py | 2 +- jax/_src/numpy/einsum.py | 8 
++- jax/_src/pallas/mosaic/core.py | 4 +- jax/_src/pallas/mosaic/lowering.py | 2 +- jax/_src/pallas/mosaic/sc_lowering.py | 5 +- jax/_src/pjit.py | 2 +- jax/_src/pmap.py | 2 +- jax/_src/prng.py | 14 ++-- jax/_src/shard_map.py | 19 +++--- jax/_src/sharding_impls.py | 2 +- jax/_src/tpu/linalg/eigh.py | 2 +- jax/_src/tpu_custom_call.py | 14 ++-- 33 files changed, 191 insertions(+), 106 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 12069572e453..c49dfed4b5c0 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -1037,21 +1037,23 @@ def lin(self, nzs_in, *primals): primals_out, f_lin = api.linearize(self.traced, *primals) return primals_out, primals - def linearized(self, primals, *tangents): + def linearized(self, primals, *tangents): # pyrefly: ignore[bad-param-name-override] _, f_lin = api.linearize(self.traced, *primals) return f_lin(*tangents) class CheckpointName(VJPHiPrimitive): + name: str + def __init__(self, name, aval): self.in_avals = aval, self.out_aval = aval self.params = dict(name=name) super().__init__() - def expand(self, x): + def expand(self, x): # pyrefly: ignore[bad-override] return x - def remat(self, policy, x): + def remat(self, policy, x): # pyrefly: ignore[bad-override] saveable = self.name in policy rem = partial(primal_left_tangent_right, x) if saveable else lambda x: x return x, rem diff --git a/jax/_src/api.py b/jax/_src/api.py index 3cff0f5f0ce8..a32a480c659b 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1845,8 +1845,8 @@ def cache_miss(*args, **kwargs): in_handler=in_handler, out_handler=out_handler, out_pytree_def=out_pytree_def, - input_devices=in_handler.local_devices, - input_indices=in_handler.input_indices, + input_devices=in_handler.local_devices, # pyrefly: ignore[bad-argument-type] + input_indices=in_handler.input_indices, # pyrefly: ignore[bad-argument-type] input_array_shardings=in_handler.in_shardings, out_avals=out_handler.out_avals, 
out_array_shardings=out_array_shardings, @@ -2087,6 +2087,7 @@ def linearize(fun: Callable, *primals, has_aux: bool = False (in_tree, out_tree), out_pvals), consts) if has_aux: [aux] = maybe_aux + assert aux_tree is not None return out_primal_py, lifted_jvp, tree_unflatten(aux_tree, aux) else: [] = maybe_aux diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 92b879150922..eb964d23d94c 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -87,12 +87,12 @@ def __init__(self, traceback_info): def __init_subclass__(cls): jtu.register_pytree_node_class(cls) - def tree_flatten(self): + def tree_flatten(self, /): return ([], self.traceback_info) @classmethod - def tree_unflatten(cls, metadata, payload): - del payload + def tree_unflatten(cls, metadata, payload, /): + del payload # Unused. return cls(metadata) def get_effect_type(self) -> ErrorEffect: @@ -134,7 +134,8 @@ def tree_flatten(self): return ([], (self.traceback_info, self.prim)) @classmethod - def tree_unflatten(cls, metadata, _): + def tree_unflatten(cls, metadata, payload): + del payload return cls(*metadata) def get_effect_type(self): @@ -156,7 +157,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, metadata, payload): - return cls(*metadata, payload[0]) + return cls(*metadata, payload=payload[0]) def __str__(self): return (f'out-of-bounds indexing for array of ' @@ -227,11 +228,11 @@ def get(self) -> str | None: def get_exception(self) -> JaxException | None: """Returns Python exception if error happened, None if no error happened.""" - if any(map(np.shape, self._pred.values())): + if any(np.shape(v) for v in self._pred.values()): return self._get_batched_exception() else: - min_code = None - cur_effect = None + min_code: Int | None = None + cur_effect: ErrorEffect | None = None for error_effect, code in self._code.items(): if self._pred[error_effect]: if min_code is None or code < min_code: @@ -255,8 +256,8 @@ def _get_batched_exception(self) -> BatchedError | None: 
shape = np.shape(list(self._pred.values())[0]) error_mapping = {} for idx in np.ndindex(*shape): - min_code = None - cur_effect = None + min_code: Int | None = None + cur_effect: ErrorEffect | None = None for error_effect, code in self._code.items(): if self._pred[error_effect][idx]: # type: ignore if min_code is None or code[idx] < min_code: # type: ignore[index] diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 4ff52bb4f974..3bd4c6f748e4 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -123,9 +123,13 @@ def __init__(self, base_cache: CacheInterface): self._verified_keys: set[str] = set() @property - def _path(self): + def _path(self): # pyrefly: ignore[bad-override] return self._base_cache._path + @_path.setter + def _path(self, value): + self._base_cache._path = value + def get(self, key: str) -> bytes | None: if key not in self._verified_keys: # Force a recompile the first time we see a key. diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 1fbdc189b27b..144970f58787 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -334,7 +334,7 @@ def backend_compile_and_load( # TODO(dsuo): Simplify this logic once we delete _jax.CompileOnlyPyClient. 
if isinstance(backend, _jax.CompileOnlyPyClient): if host_callbacks: - return backend.compile( + return backend.compile( # type: ignore module, executable_devices=executable_devices, # type: ignore compile_options=options, diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 74d7edb3ea1f..2ad8a9a713a1 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -140,8 +140,10 @@ def _compute_on_lowering(ctx, *args, jaxpr, compute_type, out_memory_spaces): tokens, out_nodes = split_list(out_nodes, [len(effects)]) tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens))) ctx.set_tokens_out(tokens_out) - return [mlir.wrap_with_memory_kind(on, core.mem_space_to_kind(oms), out_aval) - for on, out_aval, oms in zip(out_nodes, ctx.avals_out, out_memory_spaces)] + return [ + mlir.wrap_with_memory_kind(on, core.mem_space_to_kind(oms), out_aval) # pyrefly: ignore[bad-argument-type] + for on, out_aval, oms in zip(out_nodes, ctx.avals_out, out_memory_spaces) + ] mlir.register_lowering(compute_on_p, _compute_on_lowering) diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 2d2c0f4285f2..b8ea0d19a360 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -83,6 +83,11 @@ def def_transpose(self, transpose: Callable): @traceback_util.api_boundary def __call__(self, out_types, res_arg, lin_arg): + if self.transpose is None: + raise ValueError( + "Missing a transpose function. Use @def_transpose to define one." 
+ ) + _, res_tree = tree_flatten(res_arg) _, lin_tree = tree_flatten(lin_arg) args_flat, in_tree = tree_flatten((res_arg, lin_arg)) diff --git a/jax/_src/debugger/colab_debugger.py b/jax/_src/debugger/colab_debugger.py index 57d785a87613..df576057fede 100644 --- a/jax/_src/debugger/colab_debugger.py +++ b/jax/_src/debugger/colab_debugger.py @@ -30,6 +30,8 @@ from google.colab import output try: import pygments + import pygments.lexers + import pygments.formatters IS_PYGMENTS_ENABLED = True except ImportError: IS_PYGMENTS_ENABLED = False diff --git a/jax/_src/earray.py b/jax/_src/earray.py index b45d371057e4..ca8b284ba5d6 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -29,14 +29,18 @@ # EArray is an Array that can contain extended dtypes. class EArray(basearray.Array): - __slots__ = ['aval', '_data'] + __slots__ = ['_aval', '_data'] __hash__ = None # type: ignore[assignment] __array_priority__ = 100 def __init__(self, aval, data): - self.aval = aval + self._aval = aval self._data = data + @property + def aval(self): + return self._aval + def block_until_ready(self): _ = self._data.block_until_ready() return self diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index cca0ae121477..739f56e2b19a 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -177,7 +177,7 @@ def include_dir() -> str: def _aval_shape(aval: core.AbstractValue) -> Shape: - return () if aval is core.abstract_token else core.physical_aval(aval).shape # pytype: disable=attribute-error + return () if aval is core.abstract_token else core.physical_aval(aval).shape # pytype: disable=attribute-error # pyrefly: ignore[missing-attribute] def _convert_layout_for_lowering( @@ -676,7 +676,7 @@ def ffi_call_lowering( operand_output_aliases=dict(input_output_aliases), api_version=custom_call_api_version, backend_config=legacy_backend_config) - return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes)) + return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes)) # pyrefly: 
ignore[bad-return] def ffi_batching_rule( diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 6e26e336b87c..a313fcea4283 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -2483,7 +2483,7 @@ def _make_reduce_harness(name, *, dtype=np.float32): # The dtype of first operand def reducer(*args): init_val = np.array(init_value, dtype=dtype) - init_values = [init_val] + init_values: list[np.ndarray] = [init_val] if nr_operands == 2: init_values.append(np.array(0, dtype=np.int32)) return lax.reduce(args[0:nr_operands], tuple(init_values), diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index ee305e09365d..8f9540590bea 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -86,7 +86,7 @@ def _is_not_block_argument(x: IrValues) -> bool: return not isinstance(x, ir.BlockArgument) -def dense_int_elements(xs) -> ir.DenseIntElementsAttr: +def dense_int_elements(xs) -> ir.DenseElementsAttr: return ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)) dense_int_array = ir.DenseI64ArrayAttr.get @@ -94,8 +94,7 @@ def dense_int_elements(xs) -> ir.DenseIntElementsAttr: def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i) def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i) -def shape_tensor(sizes: Sequence[int | ir.RankedTensorType] - ) -> ir.RankedTensorType: +def shape_tensor(sizes: Sequence[int | ir.RankedTensorType]) -> IrValues: int1d = aval_to_ir_type(core.ShapedArray((1,), np.int32)) i32_type = aval_to_ir_type(core.ShapedArray((), np.int32)) def lower_dim(d): @@ -107,7 +106,7 @@ def lower_dim(d): return hlo.reshape(int1d, d) ds = map(lower_dim, sizes) if not ds: - return type_cast(ir.RankedTensorType, ir_constant(np.array([], np.int32))) + return ir_constant(np.array([], np.int32)) elif len(ds) == 1: # pyrefly: ignore[bad-argument-type] # pyrefly#2385 
return ds[0] # pyrefly: ignore[bad-index] # pyrefly#2385 else: @@ -259,6 +258,7 @@ def ir_constant( A representation of the constant as an IR value or sequence of IR values. """ if const_lowering is not None: + # pyrefly: ignore[no-matching-overload] if np.shape(val) and (c_val := const_lowering.get((id(val), aval))) is not None: return c_val for t in type(val).__mro__: @@ -1334,7 +1334,7 @@ def lower_jaxpr_to_module( raise ValueError( "Cannot lower jaxpr with verifier errors. " + dump_module_message(ctx.module, "verification")) - except ir.MLIRError as e: + except ir.MLIRError as e: # pyrefly: ignore[missing-attribute] msg_lines = ["Cannot lower jaxpr with verifier errors:"] def emit_diagnostic_info(d): msg_lines.append(f"\t{d.message}") @@ -1365,7 +1365,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts, result_shardings): if input_output_aliases is None: - input_output_aliases = [None] * len(avals_in) + input_output_aliases: list[int | None] = [None] * len(avals_in) else: input_output_aliases = list(input_output_aliases) # To match-up in-avals to out-avals we only care about the number of @@ -1438,7 +1438,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, results_not_matched = collections.defaultdict(collections.deque) for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)): if i not in aliased_output_ids and aval is not core.abstract_token: - results_not_matched[(aval.size, rm)].append(i) + results_not_matched[(aval.size, rm)].append(i) # pyrefly: ignore[missing-attribute] # For each donated argument that hasn't been aliased or donated to XLA, try to # find an output array with matching size ignoring shapes. If a matching @@ -1451,7 +1451,11 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, # then try to find an output array with matching size. 
if (out_donated_args[input_idx] and avals_in[input_idx] is not core.abstract_token): # pyrefly: ignore[bad-index] # pyrefly#2385 - key = (avals_in[input_idx].size, arg_memory_kinds[input_idx]) # pyrefly: ignore[bad-index] # pyrefly#2385 + key = ( + # pyrefly: ignore[missing-attribute] + avals_in[input_idx].size, # pyrefly: ignore[bad-index] # pyrefly#2385 + arg_memory_kinds[input_idx], + ) if results_not_matched.get(key, ()): # XLA donate the argument because there's a matching output array. results_not_matched[key].popleft() @@ -2067,6 +2071,7 @@ def write(v: core.Var, node: IrValues): eqn_name_stack = name_stack + eqn.source_info.name_stack if jaxlib_extension_version >= 409: + assert outer_traceback is not None traceback = (eqn.source_info.traceback or xc.Traceback()) + outer_traceback else: traceback = eqn.source_info.traceback @@ -2103,6 +2108,7 @@ def write(v: core.Var, node: IrValues): assert len(out_nodes) == len(eqn.outvars), (out_nodes, eqn) if ordered_effects: + assert tokens_out is not None tokens = tokens.update_tokens(tokens_out) foreach(write, eqn.outvars, out_nodes) @@ -2110,6 +2116,16 @@ def write(v: core.Var, node: IrValues): return tuple(read(v) for v in jaxpr.outvars), tokens +class CachedLoweringRule(Protocol): + def __call__( + self, + ctx: LoweringRuleContext, + *args: ir.Value | Sequence[ir.Value], + **kwargs: Any, + ) -> tuple[Sequence[ir.Value | Sequence[ir.Value]], bool]: + ... 
+ + def _cached_lowering( ctx: ModuleContext, eqn: core.JaxprEqn, @@ -2145,9 +2161,9 @@ def _cached_lowering( avals_out = map(lambda v: v.aval, eqn.outvars) cache_entry = _emit_lowering_rule_as_fun( partial(_uncached_lowering, eqn.primitive, eqn.ctx, eqn.effects), - ctx, eqn.ctx, eqn.primitive, ordered_effects, avals_in, avals_out, + ctx, eqn.ctx, eqn.primitive, ordered_effects, avals_in, avals_out, # pyrefly: ignore[bad-argument-type] # pyrefly#2385 **params, - ) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 + ) ctx.lowering_cache[cache_key] = cache_entry tokens_in_args = tuple(tokens_in.get(eff) for eff in ordered_effects) @@ -2172,7 +2188,7 @@ def _cached_lowering( def _emit_lowering_rule_as_fun( - lowering_rule: LoweringRule, + lowering_rule: CachedLoweringRule, ctx: ModuleContext, eqn_ctx: core.JaxprEqnContext, primitive: core.Primitive, @@ -2220,15 +2236,19 @@ def _emit_lowering_rule_as_fun( traceback=None, avals_in=avals_in, avals_out=avals_out, tokens_in=TokenSet(zip(ordered_effects, token_args)), - tokens_out=None, jaxpr_eqn_ctx=eqn_ctx, dim_var_values=dim_var_values, + tokens_out=None, jaxpr_eqn_ctx=eqn_ctx, + dim_var_values=flatten_ir_values(dim_var_values), const_lowering=const_lowering) with source_info_to_location( ctx, primitive, source_info_util.new_name_stack(), None ): outs, inline = lowering_rule(sub_ctx, *unflattened_args, **params) if sub_ctx.tokens_out: - outs = [*[sub_ctx.tokens_out.get(eff) for eff in ordered_effects], *outs] - outs = flatten_ir_values(outs) + outs = [ + *(sub_ctx.tokens_out.get(eff) for eff in ordered_effects), + *outs # pyrefly: ignore[not-iterable] + ] + outs = flatten_ir_values(outs) # pyrefly: ignore[bad-argument-type] func_dialect.return_(outs) return LoweringCacheValue(func_op, output_types, const_args, const_arg_avals, inline) @@ -2394,16 +2414,18 @@ def lower_per_platform(ctx: LoweringRuleContext, assert kept_rules # If there is a single rule left just apply the rule, without conditionals. 
if len(kept_rules) == 1: - output = kept_rules[0](ctx, *rule_args, **rule_kwargs) + output = type_cast( + Sequence[IrValues], kept_rules[0](ctx, *rule_args, **rule_kwargs) + ) + flat_output = flatten_ir_values(output) foreach( lambda o: wrap_compute_type_in_place(ctx, _get_owner(o)), - filter(_is_not_block_argument, flatten_ir_values(output)), + filter(_is_not_block_argument, flat_output), ) foreach( - lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)), - flatten_ir_values(output), + lambda o: wrap_xla_metadata_in_place(ctx, _get_owner(o)), flat_output ) - return output + return flat_output assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules) assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable" @@ -2434,7 +2456,9 @@ def lower_per_platform(ctx: LoweringRuleContext, inner_ctx = ctx.replace(platforms=platforms_for_this_rule) branch = case_op.regions[i].blocks.append() with ir.InsertionPoint(branch): - output = rule(inner_ctx, *rule_args, **rule_kwargs) + output = type_cast( + Sequence[IrValues], rule(inner_ctx, *rule_args, **rule_kwargs) + ) try: out_nodes = flatten_ir_values(output) except TypeError as e: diff --git a/jax/_src/interpreters/remat.py b/jax/_src/interpreters/remat.py index 2310d6333caa..dc7528123940 100644 --- a/jax/_src/interpreters/remat.py +++ b/jax/_src/interpreters/remat.py @@ -59,8 +59,10 @@ def f_rem(rs, *args): return out_ft.unflatten(), Partial(f_rem, map(reduce_precision, rs)) class RematTracer(core.Tracer): + _trace: RematTrace # pyrefly: ignore[bad-override] + def __init__(self, trace, x, jaxpr_tracer): - self._trace = trace # type: ignore + self._trace = trace self.val = x self.tracer = jaxpr_tracer @@ -139,4 +141,4 @@ def new_arg(a): [*out_primals, *rem_consts], dbg.with_unknown_names(), src) fwd_trace.invalidate() fwd_jaxpr = core.ClosedJaxpr(fwd_jaxpr_, fwd_consts) - return fwd_jaxpr, rem_jaxpr, len(rem_consts) + return fwd_jaxpr, rem_jaxpr, len(rem_consts) # pyrefly: 
ignore[bad-argument-type] # pyrefly#2385 diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 117bed2eae2c..9c4c522c7b3b 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -299,7 +299,9 @@ def _approx_top_k_lowering(ctx, operand, *, k, init_arg = hlo.constant(ir.DenseElementsAttr.get(np.int32(-1))) init_val_array = _get_init_val_literal(ctx.avals_in[0].dtype, is_max_k) - init_val = mlir.ir_constant(init_val_array.reshape(())) + init_vals = mlir.flatten_ir_values( + [mlir.ir_constant(init_val_array.reshape(())) + ]) backend_config = { "reduction_dim" : mlir.i64_attr(reduction_dimension), @@ -313,16 +315,19 @@ def _approx_top_k_lowering(ctx, operand, *, k, if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out): result_shapes = None else: - result_shapes = [ + result_shapes = mlir.flatten_ir_values( mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape)) - for aval_out in ctx.avals_out] + for aval_out in ctx.avals_out + ) if core.is_constant_dim(k): backend_config["top_k"] = mlir.i64_attr(k) out = mlir.custom_call( "ApproxTopK", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=[operand, iota, init_val, init_arg], + result_types=mlir.flatten_ir_types( + mlir.aval_to_ir_type(aval) for aval in ctx.avals_out + ), + operands=[operand, iota, *init_vals, init_arg], called_computations=[comparator.name.value], backend_config=backend_config, result_shapes=result_shapes) @@ -330,8 +335,10 @@ def _approx_top_k_lowering(ctx, operand, *, k, k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,)) out = mlir.custom_call( "stablehlo.dynamic_approx_top_k", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=[operand, iota, init_val, init_arg, k_value], + result_types=mlir.flatten_ir_types( + mlir.aval_to_ir_type(aval) for aval in ctx.avals_out + ), + operands=[operand, iota, *init_vals, init_arg, k_value], called_computations=[comparator.name.value], 
backend_config=backend_config, result_shapes=result_shapes) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 6907210bbe36..15f29899f771 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -1200,7 +1200,7 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext, def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> Sequence[ir.Value]: v = mlir.ir_constant(np.int32(i)) - return [v] + return mlir.flatten_ir_values([v]) platform_rules: dict[str, mlir.LoweringRule] = {} default_rule = None diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 1a0339d6115c..aa53b862714a 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -2048,7 +2048,7 @@ def fun(*args): hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y), - *mlir.flatten_ir_values(new_z)]) + *mlir.flatten_ir_values(new_z)]) # pyrefly: ignore[bad-argument-type] outputs = mlir.unflatten_ir_values_like_types(while_op.results, loop_carry_types) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 0cc0a8a4cdec..a4c62eb38133 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1636,7 +1636,8 @@ def _convert_element_type( if new_dtype == old_dtype: if sharding is None: return operand - if isinstance(operand, core.Tracer) and operand.aval.sharding == sharding: + if (isinstance(operand, core.Tracer) and + operand.aval.sharding == sharding): # pyrefly: ignore[missing-attribute] return operand if sharding is not None or weak_type: raise NotImplementedError @@ -3584,7 +3585,7 @@ def full_like(x: ArrayLike | DuckTypedArray, return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] if sharding is None and shape is None and isinstance(x, core.Tracer): - 
sharding = x.aval.sharding + sharding = x.aval.sharding # pyrefly: ignore[missing-attribute] else: # If `x` has a sharding but no `_committed` attribute # (in case of ShapeDtypeStruct), default it to True. @@ -8228,9 +8229,12 @@ def _top_k_lower(ctx, operand, k, axis): out_values_aval, out_indices_aval, = ctx.avals_out results = mlir.custom_call( "stablehlo.dynamic_top_k", - result_types=[mlir.aval_to_ir_type(out_values_aval), - mlir.aval_to_ir_type(out_indices_aval)], - operands=[operand, k_value]).results + result_types=mlir.flatten_ir_types([ + mlir.aval_to_ir_type(out_values_aval), + mlir.aval_to_ir_type(out_indices_aval) + ]), + operands=[operand, k_value], + ).results # Move last dimension back into place if perm is not None: @@ -8419,10 +8423,13 @@ def _rng_bit_generator_lowering( mlir.eval_dynamic_shape(ctx, out_vals_aval.shape)) out_key, out_vals = mlir.custom_call( "stablehlo.dynamic_rng_bit_generator", - result_types=[key.type, - mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype))], - operands=[key, output_shape], - extra_attributes=dict(rng_algorithm=algorithm_attr)).results + result_types=mlir.flatten_ir_types([ + key.type, + mlir.aval_to_ir_type(core.ShapedArray(shape, rbg_dtype)) + ]), + operands=mlir.flatten_ir_values([key, output_shape]), + extra_attributes=dict(rng_algorithm=algorithm_attr), + ).results else: out_key, out_vals = hlo.RngBitGeneratorOp( key.type, diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c9d2cfe7fb94..89e9a1e27f70 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -1477,7 +1477,7 @@ def _ragged_all_to_all_lowering( if not all(split_count == len(g) for g in replica_groups): raise ValueError('Replica groups must be equally sized') - ragged_all_to_all_attrs = { + ragged_all_to_all_attrs: dict[str, ir.Attribute] = { "replica_groups": _replica_groups_hlo(replica_groups) } is_spmd = isinstance( diff --git a/jax/_src/lax_reference.py b/jax/_src/lax_reference.py index 
ab035fecc31a..92b6d5e5d39f 100644 --- a/jax/_src/lax_reference.py +++ b/jax/_src/lax_reference.py @@ -241,8 +241,8 @@ def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, def dot_general(lhs, rhs, dimension_numbers): (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers new_id = itertools.count() - lhs_axis_ids = [next(new_id) for _ in lhs.shape] - rhs_axis_ids = [next(new_id) for _ in rhs.shape] + lhs_axis_ids: list[int | None] = [next(new_id) for _ in lhs.shape] + rhs_axis_ids: list[int | None] = [next(new_id) for _ in rhs.shape] lhs_out_axis_ids = lhs_axis_ids[:] rhs_out_axis_ids = rhs_axis_ids[:] @@ -267,8 +267,9 @@ def dot_general(lhs, rhs, dimension_numbers): batch_ids + lhs_out_axis_ids + rhs_out_axis_ids) assert lhs.dtype == rhs.dtype dtype = np.float32 if lhs.dtype == dtypes.bfloat16 else None - out = np.einsum(lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids, - dtype=dtype) + out = np.einsum( # pyrefly: ignore[no-matching-overload] + lhs, lhs_axis_ids, rhs, rhs_axis_ids, out_axis_ids, dtype=dtype + ) return out.astype(dtypes.bfloat16) if lhs.dtype == dtypes.bfloat16 else out def ragged_dot( diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index f0210d5d9b82..8ebf84d625f7 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -251,6 +251,7 @@ class Mesh(BaseMesh, contextlib.ContextDecorator): devices: np.ndarray axis_names: tuple[MeshAxisName, ...] 
+ _size: int def __new__(cls, devices: np.ndarray | Sequence[xc.Device], axis_names: str | Sequence[MeshAxisName], @@ -354,7 +355,7 @@ def shape(self): for name, size in safe_zip(self.axis_names, self.devices.shape)) @functools.cached_property - def shape_tuple(self): + def shape_tuple(self): # pyrefly: ignore[bad-override] return tuple( (name, size) for name, size in safe_zip(self.axis_names, self.devices.shape)) @@ -538,7 +539,7 @@ def shape(self): return collections.OrderedDict(self.shape_tuple) @functools.cached_property - def shape_tuple(self): + def shape_tuple(self): # pyrefly: ignore[bad-override] return tuple( (name, size) for name, size in safe_zip(self.axis_names, self.axis_sizes)) diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 5581e6fd5f2e..f3bf25d39abb 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -18,7 +18,7 @@ import collections import dataclasses import functools -from typing import Any, Union +from typing import Any, Union, overload from jax._src.util import use_cpp_class, cache, use_cpp_method from jax._src.lib import xla_client as xc @@ -286,6 +286,18 @@ def flatten_spec(spec): return out +@overload +def get_array_mapping(axis_resources: PartitionSpec) -> ArrayMapping: + ... + +@overload +def get_array_mapping(axis_resources: AUTO) -> AUTO: + ... + +@overload +def get_array_mapping(axis_resources: UnspecifiedValue) -> UnspecifiedValue: + ... 
+ def get_array_mapping( axis_resources: PartitionSpec | AUTO | UnspecifiedValue ) -> ArrayMappingOrAutoOrUnspecified: @@ -468,7 +480,7 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping): reverse_map[index].append(axis) if index > max_index: max_index = index - partitions = [] + partitions: list[MeshAxisName | None] = [] for i in range(max_index + 1): axis = reverse_map[i] if axis: diff --git a/jax/_src/numpy/array_constructors.py b/jax/_src/numpy/array_constructors.py index 43c6f647249f..e92054a91f89 100644 --- a/jax/_src/numpy/array_constructors.py +++ b/jax/_src/numpy/array_constructors.py @@ -191,7 +191,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, weak_type = dtype is None and dtypes.is_weakly_typed(object) if device is None and out_sharding is None and isinstance(object, core.Tracer): - sharding = object.aval.sharding + sharding = object.aval.sharding # pyrefly: ignore[missing-attribute] sharding = None if sharding.mesh.empty else sharding else: sharding = util.choose_device_or_out_sharding(device, out_sharding, "jnp.array") diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index b8f72081df62..49c1bdfe4ef1 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -299,7 +299,7 @@ def einsum( for d in np.shape(op) if not core.is_constant_dim(d) } if not non_constant_dim_types: - contract_path = opt_einsum.contract_path + contract_path: Any = opt_einsum.contract_path else: ty = next(iter(non_constant_dim_types)) contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler) @@ -419,8 +419,10 @@ def einsum_path( .. 
_opt_einsum: https://github.com/dgasmith/opt_einsum """ if isinstance(optimize, bool): - optimize = 'optimal' if optimize else Unoptimized() - return opt_einsum.contract_path(subscripts, *operands, optimize=optimize) + optimize2: Any = 'optimal' if optimize else Unoptimized() + else: + optimize2 = optimize + return opt_einsum.contract_path(subscripts, *operands, optimize=optimize2) def _removechars(s, chars): return s.translate(str.maketrans(dict.fromkeys(chars))) diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 711a2f56c74d..e0ccf6c65355 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -386,7 +386,7 @@ def allowed_aval(aval): if copy_to_smem: smem_alloc = [ state.AbstractRef( - jax_core.ShapedArray((1,), aval.dtype), + jax_core.ShapedArray((1,), aval.dtype), # pyrefly: ignore[missing-attribute] memory_space=MemorySpace.SMEM, ) for aval in scalar_const_avals @@ -416,7 +416,7 @@ def new_body(*args): # TODO(sharadmv): Remove this once Mosaic support passing scalars as values. 
scalar_const_trace_avals = [ state.AbstractRef( - jax_core.ShapedArray((1,), aval.dtype), + jax_core.ShapedArray((1,), aval.dtype), # pyrefly: ignore[missing-attribute] memory_space=MemorySpace.HBM if copy_to_smem else MemorySpace.SMEM, ) for aval in scalar_const_avals diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index e869b46579c7..278f0a7162e4 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -464,7 +464,7 @@ def _canonicalize_dimension_semantic( dimension_semantic: tpu_core.DimensionSemantics, ) -> tpu_core.LiteralDimensionSemantics: if isinstance(dimension_semantic, tpu_core.GridDimensionSemantics): - return dimension_semantic.value + return cast(tpu_core.LiteralDimensionSemantics, dimension_semantic.value) return dimension_semantic diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py index f6eb06683cb3..5810c4ef4705 100644 --- a/jax/_src/pallas/mosaic/sc_lowering.py +++ b/jax/_src/pallas/mosaic/sc_lowering.py @@ -257,7 +257,7 @@ def body_fn(indices, *refs): ) pipeline.emit_pipeline( body_fn, - grid=sequential_grid, + grid=sequential_grid, # pyrefly: ignore[bad-argument-type] in_specs=map(make_block_spec, in_block_mappings), out_specs=map(make_block_spec, out_block_mappings), tiling=tiling, @@ -345,7 +345,6 @@ def body_fn(indices, *refs): kernel_type=kernel_type, mesh=mesh, ) - return module def lower_jaxpr_into_module( @@ -359,7 +358,7 @@ def lower_jaxpr_into_module( kernel_type: tpu_core.CoreType, mesh: mesh_lib.Mesh | None = None, dynamic_shape_replacement_enabled: bool = False, -) -> ir.Module: +): """Lowers a Jaxpr to a Mosaic SparseCore module.""" if dynamic_shape_replacement_enabled: raise NotImplementedError( diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 1d71255780b4..0335333aa13e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -2275,7 +2275,7 @@ def _reshard_transpose_fancy(ct, x, *, dst_sharding, concrete_mesh): 
assert isinstance(x, ad.GradAccum) if type(ct) is ad.Zero: return - out_sharding = x.aval.to_cotangent_aval().sharding + out_sharding = x.aval.to_cotangent_aval().sharding # pyrefly: ignore[missing-attribute] with mesh_lib.use_abstract_mesh(out_sharding.mesh): x_bar = reshard_p.bind(ct, dst_sharding=out_sharding, concrete_mesh=concrete_mesh) diff --git a/jax/_src/pmap.py b/jax/_src/pmap.py index d44fa4619efd..5f3401ab9d29 100644 --- a/jax/_src/pmap.py +++ b/jax/_src/pmap.py @@ -119,7 +119,7 @@ def lower(*args, **kwargs): no_kwargs=lowered._no_kwargs, # pylint: disable=protected-access ) - wrapped.lower = lower + wrapped.lower = lower # pyrefly: ignore[missing-attribute] return wrapped diff --git a/jax/_src/prng.py b/jax/_src/prng.py index cfa4f664cc5e..3bccd8099089 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -177,7 +177,7 @@ def __init__(self, impl, key_data: Any): self._base_array = key_data def _replace_with(self, value: PRNGKeyArray): - self._base_array._replace_with(value._base_array) + self._base_array._replace_with(value._base_array) # pyrefly: ignore[missing-attribute] def block_until_ready(self): _ = self._base_array.block_until_ready() @@ -217,7 +217,7 @@ def itemsize(self): _device = property(op.attrgetter('_base_array._device')) _committed = property(op.attrgetter('_base_array._committed')) - device = property(op.attrgetter('_base_array.device')) + device = property(op.attrgetter('_base_array.device')) # pyrefly: ignore[bad-override] devices = property(op.attrgetter('_base_array.devices')) # type: ignore[assignment] is_fully_addressable = property(op.attrgetter('_base_array.is_fully_addressable')) # type: ignore[assignment] is_fully_replicated = property(op.attrgetter('_base_array.is_fully_replicated')) # type: ignore[assignment] @@ -319,7 +319,7 @@ def __getstate__(self): def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override] @property def T(self) -> PRNGKeyArray: assert False - def __getitem__(self, _) -> PRNGKeyArray: 
assert False + def __getitem__(self, key) -> PRNGKeyArray: assert False def flatten(self, *_, **__) -> PRNGKeyArray: assert False def ravel(self, *_, **__) -> PRNGKeyArray: assert False def reshape(self, *_, **__) -> PRNGKeyArray: assert False @@ -504,8 +504,8 @@ def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts, pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler -def key_array_constant_handler(x, aval): - arr = x._base_array +def key_array_constant_handler(val, aval): + arr = val._base_array return mlir.get_constant_handler(type(arr))(arr, aval) mlir.register_constant_handler(PRNGKeyArray, key_array_constant_handler) @@ -1114,13 +1114,13 @@ def threefry_2x32(keypair, count): odd_size = flat_count.shape[0] % 2 if core.is_constant_dim(odd_size): if odd_size: - x = list(jnp.split(jnp.concatenate([flat_count, np.uint32([0])]), 2)) + x = list(jnp.split(jnp.concatenate([flat_count, jnp.uint32([0])]), 2)) else: x = list(jnp.split(flat_count, 2)) else: # With symbolic shapes we cannot always tell statically if odd_size is true # or false, so we rewrite this without a conditional. 
- flat_count_padded = jnp.concatenate([flat_count, np.uint32([0])]) + flat_count_padded = jnp.concatenate([flat_count, jnp.uint32([0])]) flat_count_padded_half_size = flat_count_padded.shape[0] // 2 x = [ lax_slicing.dynamic_slice(flat_count_padded, (0,), diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 90b619447a1d..b930e91c2360 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -17,6 +17,7 @@ import enum from functools import partial import inspect +import itertools from math import prod import operator as op from typing import Any, TypeVar, Union, cast, overload @@ -931,7 +932,7 @@ def _shardy_shard_map_sharding( def _get_token_sharding( ctx: mlir.LoweringRuleContext, mesh - ) -> ir.Attribute: + ) -> sharding_impls.SdyArray: ns = _make_scoped_manual_sharding(ctx, mesh, P()) return ns._to_sdy_sharding(0) @@ -1024,19 +1025,21 @@ def _shard_map_lowering_shardy( config._check_vma(check_vma)): dim_var_values, token_arg_values, const_arg_values, in_args = util.split_list( # type: ignore block.arguments, [num_dim_vars, num_tokens, num_const_args]) - block_const_lowering = { - (id(c), aval): ca - for c, aval, ca in zip(const_args, const_avals, const_arg_values) - } out_nodes_, tokens_out = mlir.jaxpr_subcomp( sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(zip(ctx.tokens_in.effects(), token_arg_values)), (), *in_args, dim_var_values=dim_var_values, - const_lowering=block_const_lowering, + const_lowering={ + (id(c), aval): ca + for c, aval, ca in zip(const_args, const_avals, const_arg_values) + }, outer_traceback=_jax.Traceback()) - sdy.ReturnOp([ir.Value(x) for x in (*[v for _, v in tokens_out.items()], - *out_nodes_)]) + sdy.ReturnOp( + mlir.flatten_ir_values( + itertools.chain((v for _, v in tokens_out.items()), out_nodes_) + ) + ) num_tokens = len(tokens_out.effects()) tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip( ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens]))) diff --git a/jax/_src/sharding_impls.py 
b/jax/_src/sharding_impls.py index 01b68bd953fe..71ed893a825d 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -480,7 +480,7 @@ def prepare_axis_resources(axis_resources, arg_name, axis_resources, is_leaf=lambda x: x is None) what = f"{arg_name} leaf specifications" - new_entries = [] + new_entries: list[Any] = [] for entry in entries: if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None: new_entries.append(entry) diff --git a/jax/_src/tpu/linalg/eigh.py b/jax/_src/tpu/linalg/eigh.py index 9cad146c4f87..e80c8d5c20f5 100644 --- a/jax/_src/tpu/linalg/eigh.py +++ b/jax/_src/tpu/linalg/eigh.py @@ -666,7 +666,7 @@ def _eigh_tpu_lowering( v_aval, w_aval = ctx.avals_out eigvecs_type = mlir.aval_to_ir_type(v_aval) eigvals_type = mlir.aval_to_ir_type(w_aval) - result_types = [eigvecs_type, eigvals_type] + result_types = mlir.flatten_ir_types([eigvecs_type, eigvals_type]) backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6" diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 788b41ab2877..10177b9b6a7c 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -197,7 +197,8 @@ def downgrade_lowered_module_asm( self.lowered_module_asm_version is None or self.lowered_module_asm_version > version ) - with mlir.make_ir_context() as ctx, ir.Location.unknown(): + ctx = mlir.make_ir_context() + with ctx, ir.Location.unknown(): ctx.allow_unregistered_dialects = True module = ir.Module.parse(self.lowered_module_asm) pipeline = PassManager.parse( @@ -397,9 +398,10 @@ def _tpu_custom_call_lowering( if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out): result_shapes = None else: - result_shapes = [ + result_shapes = mlir.flatten_ir_values( mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape)) - for aval_out in ctx.avals_out] + for aval_out in ctx.avals_out + ) extra_attributes: dict[str, ir.Attribute] | None = None # Add kernel_name and kernel_metadata as 
attributes to the custom call op. # This is because we do not want to pollute the backend_config with this @@ -408,7 +410,11 @@ def _tpu_custom_call_lowering( extra_attributes = dict(kernel_name=ir.StringAttr.get(kernel_name)) # If the IR version we originally generated the ASM string with is not the # same as the one we should have used, we need to downgrade the ASM string. - if (ir_version := get_ir_version(ctx)) != config.lowered_module_asm_version: + ir_version = get_ir_version(ctx) + if ( + ir_version is not None and + ir_version != config.lowered_module_asm_version + ): config = config.downgrade_lowered_module_asm(ir_version) call = mlir.custom_call( "tpu_custom_call", From 108a11ac1de1811cb2b55da270bbda4ef196fe86 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 3 Mar 2026 06:58:23 -0800 Subject: [PATCH 036/100] Run build_cleaner on jaxlib. PiperOrigin-RevId: 877926467 --- jaxlib/BUILD | 1 - jaxlib/jax.bzl | 3 +++ jaxlib/mosaic/BUILD | 16 ---------------- jaxlib/mosaic/dialect/gpu/BUILD | 9 +-------- jaxlib/mosaic/gpu/BUILD | 2 +- jaxlib/rocm/BUILD | 7 +++++++ 6 files changed, 12 insertions(+), 26 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 43dd76c1c3d0..115a6dc3ef7a 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -550,7 +550,6 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@nanobind", diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index c937cafc699e..74753d7e9cfe 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -928,3 +928,6 @@ def compare_srcs_and_test_deps_test(name, srcs, tests, ignored_init_py_files, ro tags = tags, testonly = True, ) + +def jax_bzl_library(name, deps = [], **kwargs): # buildifier: disable=unused-variable + pass diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 5b11754a0020..bc0899148af6 100644 
--- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -92,18 +92,11 @@ cc_library( ":pass_boilerplate", ":serde", ":tpu_dialect", - ":tpu_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:DataLayoutInterfaces", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:VectorDialect", ], @@ -116,25 +109,16 @@ cc_library( # compatible with libtpu deps = [ ":pass_boilerplate", - ":serde", ":tpu_dialect", - ":tpu_inc_gen", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", - "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:CommonFolders", - "@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:DataLayoutInterfaces", - "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index 6035a29846b1..af9d7594ba64 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -189,8 +189,6 @@ cc_library( hdrs = DIALECT_CAPI_HEADERS, deps = [ ":mosaic_gpu", - ":mosaic_gpu_inc_gen", - "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", @@ -203,10 +201,7 @@ cc_library( cc_library( name = "gpu_dialect_capi_headers", hdrs = DIALECT_CAPI_HEADERS, - 
deps = [ - ":mosaic_gpu_inc_gen", - "@llvm-project//mlir:CAPIIRHeaders", - ], + deps = ["@llvm-project//mlir:CAPIIRHeaders"], ) # Alwayslink target, used when exporting the C API from a shared library. @@ -216,8 +211,6 @@ cc_library( hdrs = DIALECT_CAPI_HEADERS, deps = [ ":mosaic_gpu", - ":mosaic_gpu_inc_gen", - "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIRObjects", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index e7fee8ef3526..24e7eba294e4 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -235,7 +235,7 @@ cc_library( srcs = CAPI_SOURCES, hdrs = CAPI_HEADERS, deps = [ - ":passes", + ":serde", "@llvm-project//mlir:CAPIIRObjects", ], alwayslink = True, diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index d85ab86e02bf..7d523afc3f84 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -19,6 +19,7 @@ load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_rocm_is_configured", + "jax_bzl_library", "rocm_library", ) load("//jaxlib/rocm:rocm_rpath.bzl", "rocm_nanobind_extension") @@ -550,3 +551,9 @@ py_library( ":rocm_plugin_extension", ]), ) + +jax_bzl_library( + name = "rocm_rpath_bzl", + srcs = ["rocm_rpath.bzl"], + visibility = ["//visibility:private"], +) From 0b31d9c36dab7acae102f38c785a07e3709fcdeb Mon Sep 17 00:00:00 2001 From: Levon Ter-Grigoryan Date: Tue, 3 Mar 2026 07:01:38 -0800 Subject: [PATCH 037/100] [Mosaic:GPU] Enable cuda-graphs with collective metadata. 
PiperOrigin-RevId: 877927668 --- tests/pallas/gpu_pallas_distributed_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/pallas/gpu_pallas_distributed_test.py b/tests/pallas/gpu_pallas_distributed_test.py index a653a38b08b2..119bbfc0cecd 100644 --- a/tests/pallas/gpu_pallas_distributed_test.py +++ b/tests/pallas/gpu_pallas_distributed_test.py @@ -1009,15 +1009,15 @@ def test_all_gather_different_axes(self, axis): # allocator is used, setUp will skip the test. os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.01' os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'default' - # TODO(b/483671897) re-enable once command buffers supported with collectives. - additional_xla_flags = "--xla_gpu_enable_command_buffer=''" - if "XLA_FLAGS" in os.environ: - os.environ["XLA_FLAGS"] = ( - f"{os.environ['XLA_FLAGS']} {additional_xla_flags}" - ) - else: - os.environ["XLA_FLAGS"] = additional_xla_flags if is_nvshmem_used(): + # TODO(b/483671897) re-enable once command buffers supported with collectives. 
+ additional_xla_flags = "--xla_gpu_enable_command_buffer=''" + if "XLA_FLAGS" in os.environ: + os.environ["XLA_FLAGS"] = ( + f"{os.environ['XLA_FLAGS']} {additional_xla_flags}" + ) + else: + os.environ["XLA_FLAGS"] = additional_xla_flags jt_multiprocess.main() else: config.config_with_absl() From 5c3035c0c63d23e467ae8314df9349bb9ab060a6 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 3 Mar 2026 08:34:21 -0800 Subject: [PATCH 038/100] Reverts c5bf3aea3148c25e753f20b512d2efb5a741ffe9 PiperOrigin-RevId: 877963040 --- jax/_src/pallas/mosaic/lowering.py | 10 +--------- tests/pallas/tpu_ops_test.py | 16 ---------------- 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index e869b46579c7..caba6bb3f6be 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -341,7 +341,7 @@ def _dtype_to_ir_type(dtype: DTypeLike, dtype = BOOL_MEMREF_TYPE # TODO(justinfu): Remove after mosaic supports unsigned types. # This conversion makes mosaic interpret all unsigned types as signed types. - type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) + type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) if isinstance(type, ir.IntegerType): return ir.IntegerType.get_signless(type.width) else: @@ -2229,14 +2229,6 @@ def _dot_general_lowering_rule( preferred_element_type, **_, ): - for aval in ctx.avals_in: - if jnp.issubdtype(aval.dtype, jnp.unsignedinteger): - raise NotImplementedError( - f"Unsigned integer dtype {aval.dtype} is not supported for" - " dot_general (matmul) on the Pallas Mosaic TPU backend because" - " dot_general interprets all integer inputs as signed. Consider" - " casting to a signed type before the dot operation." 
- ) (lhs_dims, rhs_dims), _ = dimension_numbers (aval_out,) = ctx.avals_out out_type = ctx.aval_to_ir_type(aval_out) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index b5e1672fd73e..0ef112cdfcd7 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -912,22 +912,6 @@ def kernel(x_ref, o_ref): jax.nn.sigmoid(x), ) - @parameterized.parameters(jnp.uint4, jnp.uint8, jnp.uint16, jnp.uint32) - def test_unsigned_dtype_dot_raises(self, dtype): - k = 256 - packing = 32 // jnp.iinfo(dtype).bits - lhs = jnp.zeros((8 * packing, k), dtype=dtype) - rhs = jnp.zeros((k, 128), dtype=dtype) - - def kernel(lhs_ref, rhs_ref, o_ref): - o_ref[...] = pl.dot(lhs_ref[...], rhs_ref[...]) - - out_shape = jax.ShapeDtypeStruct((8 * packing, 128), dtype) - with self.assertRaisesRegex( - NotImplementedError, "Unsigned integer dtype.*dot_general.*matmul" - ): - self.pallas_call(kernel, out_shape=out_shape)(lhs, rhs) - if __name__ == "__main__": absltest.main() From 1d8fba44f84c89e3bf2da6cbdfe1f9e3b59c5a26 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 3 Mar 2026 08:37:09 -0800 Subject: [PATCH 039/100] [pallas:mgpu] Use two barriers for try-cancel barriers in `dynamic_scheduling_loop`. The recent race condition fix prevented any threads from running ahead (as all threads waited at the `cancel_user_barrier` until all threads had completed the previous iteration). 
PiperOrigin-RevId: 877964436 --- jax/_src/pallas/mosaic_gpu/helpers.py | 74 +++++++++++++++------------ 1 file changed, 40 insertions(+), 34 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/helpers.py b/jax/_src/pallas/mosaic_gpu/helpers.py index 3f8b4e1d5db0..bb0331a6386f 100644 --- a/jax/_src/pallas/mosaic_gpu/helpers.py +++ b/jax/_src/pallas/mosaic_gpu/helpers.py @@ -21,12 +21,13 @@ from typing import TypeVar, overload import jax -from jax import numpy as jnp from jax import lax +from jax import numpy as jnp from jax._src import dtypes -from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives -from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas import helpers as pallas_helpers from jax._src.pallas import primitives as pallas_primitives +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives import numpy as np _T = TypeVar("_T") @@ -335,56 +336,61 @@ def body(loop_info): body function should expect a ``carry`` keyword argument and return the next carry value. 
""" - if thread_axis is not None: - num_threads = lax.axis_size(thread_axis) - else: - num_threads = 1 + num_slots = 2 + num_threads = 1 if thread_axis is None else lax.axis_size(thread_axis) user_carry = init_carry def decorator(body): - grid_idx = tuple(lax.axis_index(axis_name) for axis_name in grid_names) - success = True def _scoped(try_cancel_buffer, try_cancel_barrier, cancel_used_barrier): - gpu_primitives.barrier_arrive(cancel_used_barrier) def try_cancel_cond(carry): _, success, _, _ = carry return success + def try_cancel_body(carry): grid_idx, _, wave_step, user_carry = carry - slot = lax.rem(wave_step, jnp.int32(2)) - gpu_primitives.barrier_wait(cancel_used_barrier) - gpu_primitives.try_cluster_cancel( - try_cancel_buffer.at[slot], try_cancel_barrier - ) loop_info = NDLoopInfo( - index=grid_idx, - local_index=wave_step, - num_local_steps=None, + index=grid_idx, local_index=wave_step, num_local_steps=None ) + slot = lax.rem(wave_step, jnp.int32(num_slots)) + + @pallas_helpers.when(wave_step >= num_slots) + def wait_until_slot_available(): + gpu_primitives.barrier_wait(cancel_used_barrier.at[slot]) + + gpu_primitives.try_cluster_cancel( + try_cancel_buffer.at[slot], try_cancel_barrier.at[slot] + ) + if user_carry is None: body(loop_info) else: user_carry = body(loop_info, carry=user_carry) - gpu_primitives.barrier_wait(try_cancel_barrier) + + gpu_primitives.barrier_wait(try_cancel_barrier.at[slot]) grid_idx, success = gpu_primitives.query_cluster_cancel( - try_cancel_buffer.at[slot], - grid_names=grid_names) - gpu_primitives.barrier_arrive(cancel_used_barrier) + try_cancel_buffer.at[slot], grid_names=grid_names + ) + gpu_primitives.barrier_arrive(cancel_used_barrier.at[slot]) return (grid_idx, success, wave_step + jnp.int32(1), user_carry) - init_carry = (grid_idx, success, jnp.int32(0), user_carry) - final_carry = lax.while_loop( - try_cancel_cond, - try_cancel_body, - init_carry, - ) - gpu_primitives.barrier_wait(cancel_used_barrier) - if user_carry is 
not None: - return final_carry[-1] + + grid_idx = tuple(map(lax.axis_index, grid_names)) + init_carry = (grid_idx, True, jnp.int32(0), user_carry) + final_carry = lax.while_loop(try_cancel_cond, try_cancel_body, init_carry) + _, _, num_steps, final_user_carry = final_carry + num_barriers_to_reset = lax.min(num_steps, jnp.int32(num_slots)) + + @pallas_helpers.loop(jnp.int32(0), num_barriers_to_reset) + def reset_cancel_barrier(slot): + gpu_primitives.barrier_wait(cancel_used_barrier.at[slot]) + + return None if user_carry is None else final_user_carry + + barrier = gpu_core.Barrier(num_arrivals=num_threads, num_barriers=num_slots) return pallas_primitives.run_scoped( _scoped, - try_cancel_buffer=gpu_core.TryClusterCancelResult(2), - try_cancel_barrier=gpu_core.Barrier(num_arrivals=num_threads), - cancel_used_barrier=gpu_core.Barrier(num_arrivals=num_threads), + try_cancel_buffer=gpu_core.TryClusterCancelResult(num_slots), + try_cancel_barrier=barrier, + cancel_used_barrier=barrier, collective_axes=thread_axis, ) return decorator From f45c878fe28db59fce13975e7d0a2c3765c18f67 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 26 Feb 2026 19:49:14 +0000 Subject: [PATCH 040/100] [hijax] enable VJPHiPrimitive to define transpose rules Co-authored-by: Yash Katariya --- jax/_src/ad_util.py | 3 ++- jax/_src/core.py | 4 ++++ jax/_src/hijax.py | 12 ++++++++++++ tests/hijax_test.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 1 deletion(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 4e3dd2341818..e71d63e2f54b 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -32,8 +32,9 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: + from jax._src.hijax import HiType # type: ignore ty = typeof(x) - if hasattr(ty, 'vspace_add'): # TODO(mattjj,dougalm): revise away hasattr + if isinstance(ty, HiType): return ty.vspace_add(x, y) x, y = core.standard_insert_pvary(x, y) return add_jaxvals_p.bind(x, 
y) diff --git a/jax/_src/core.py b/jax/_src/core.py index 9ecfdbe3b9f1..9b0a16363373 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1752,6 +1752,10 @@ def shard(self, mesh, manual_axes, check_vma, spec): def unshard(self, mesh, check_vma, spec): return unshard_aval(mesh, check_vma, spec, self) + def vspace_add(self, x, y): + from jax._src.ad_util import add_jaxvals # type: ignore + return add_jaxvals(x, y) + InputType = tuple[AbstractValue, ...] OutputType = tuple[AbstractValue, ...] diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py index 6253e1fc486b..2519c3df5ce6 100644 --- a/jax/_src/hijax.py +++ b/jax/_src/hijax.py @@ -414,6 +414,11 @@ def linearized(self, residuals, *tangents): raise NotImplementedError(f"for linearize support, subclass {type(self)} " "must implement `lin` and `linearized`") + # optional transpose rule, for primitives that are linear in some inputs + def transpose(self, out_ct, *maybe_accums): + raise NotImplementedError(f"for transpose support, subclass {type(self)} " + "must implement `transpose`") + # vmap interface def batch(self, axis_data, args, dims): out_dim = self.batch_dim_rule(axis_data, dims) @@ -614,6 +619,13 @@ def _call_hi_primitive_jvp(primals, tangents, *, _prim): return out_primals_flat, out_tangents_flat ad.primitive_jvps[call_hi_primitive_p] = _call_hi_primitive_jvp +def _call_hi_primitive_transpose(cts_flat, *primals_flat, _prim): + cts = tree_unflatten(_prim.out_tree, cts_flat) + primals = tree_unflatten(_prim.in_tree, primals_flat) + none = _prim.transpose(cts, *primals) + assert none is None +ad.fancy_transposes[call_hi_primitive_p] = _call_hi_primitive_transpose + def _call_hi_primitive_dce(used_outs_flat, eqn): _prim = eqn.params['_prim'] used_out = tree_unflatten(_prim.out_tree, used_outs_flat) diff --git a/tests/hijax_test.py b/tests/hijax_test.py index eff1ec5b0d0b..8b11734c8920 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -222,6 +222,13 @@ def unshard(self, mesh, check_vma, spec): 
return TupTy(tuple(ty.unshard(mesh, check_vma, s) for ty, s in zip(self.tys, spec.val))) + def vspace_add(self, x_tup, y_tup): + n = len(self.tys) + x_elts = [get_tuple_element(x_tup, i) for i in range(n)] + y_elts = [get_tuple_element(y_tup, i) for i in range(n)] + return make_tup(*(ty.vspace_add(x, y) for ty, x, y + in zip(self.tys, x_elts, y_elts))) + register_hitype(HiTup, lambda t: TupTy(tuple(map(typeof, t.elts)))) @dataclass(frozen=True) @@ -246,6 +253,16 @@ def __init__(self, in_avals): def expand(self, *elts): return HiTup(elts) + def jvp(self, primals, tangents): + tangents = map(ad.instantiate_zeros, tangents) + return make_tup(*primals), make_tup(*tangents) + + def transpose(self, ct, *maybe_accums): + cts = [get_tuple_element(ct, i) for i in range(len(self.out_aval.tys))] + for ct_, accum in zip(cts, maybe_accums): + if isinstance(accum, ad.GradAccum): + accum.accum(ct_) + def batch(self, _axis_data, args, in_dims): return make_tup(*args), TupSpec(in_dims) @@ -259,6 +276,16 @@ def __init__(self, in_aval, idx): def expand(self, tup): return tup.elts[self.idx] + def jvp(self, primals, tangents): + (tup,), (tup_dot,) = primals, tangents + return get_tuple_element(tup, self.idx), get_tuple_element(tup_dot, self.idx) + + def transpose(self, g, tup_accum): + tup_ty, = self.in_avals + elts = map(ad.zeros_like_aval, tup_ty.tys) + elts[self.idx] = g + tup_accum.accum(make_tup(*elts)) + def vjp_fwd(self, tup): return get_tuple_element(tup, self.idx), None @@ -1174,6 +1201,21 @@ def f(x): self.assertEqual(f(2.0), 8.0) self.assertEqual(jax.linearize(f, 2.0)[1](1.0), 12.0) + @jtu.with_explicit_mesh((2, 2), ('i', 'j')) + def test_oh_hi_mat(self, mesh): + x = jnp.ones(4) + y = jnp.ones(2) + + @jax.remat + def f(x, y): + tup = make_tup(x, y) + x_ = get_tuple_element(tup, 0) + y_ = get_tuple_element(tup, 1) + return jnp.sum(x_ + jnp.concatenate((y_, y_))) + + f(x, y) + jax.jit(jax.grad(f))(x, y) + class BoxTest(jtu.JaxTestCase): From 
158c20abcc8915f26712557c7f6d05be98005bf4 Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 3 Mar 2026 09:29:22 -0800 Subject: [PATCH 041/100] Fix a pallas dot test on TPU using unsupported unsigned int dot. PiperOrigin-RevId: 877987980 --- tests/pallas/pallas_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 1a232993cff6..bbdbb53c8560 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -719,6 +719,8 @@ def dot_kernel(x_ref, y_ref, o_ref): def test_integer_dot(self, dtype): if jtu.test_device_matches(["tpu"]) and not jtu.is_device_tpu_at_least(5): self.skipTest("`int8` dot is only supported on v5 TPUs and newer.") + if jnp.issubdtype(dtype, jnp.unsignedinteger): + self.skipTest("Not currently supported.") @functools.partial( self.pallas_call, @@ -728,8 +730,8 @@ def dot_kernel(x_ref, y_ref, o_ref): o_ref[()] = pl.dot(x_ref[()], y_ref[()]) key0, key1 = random.split(random.key(0)) - # FIXME(cjfj): TPU fails with `uint8` values >= 128. - kwargs = dict(minval=jnp.iinfo(dtype).min, maxval=128, dtype=dtype) + kwargs = dict(minval=jnp.iinfo(dtype).min, maxval=jnp.iinfo(dtype).max + 1, + dtype=dtype) # TODO(cjfj): Investigate why this fails on GPU with `k == 16`. 
x = random.randint(key0, (32, 128), **kwargs) y = random.randint(key1, (128, 64), **kwargs) From b0efa207cdbbbd3c2fbed4f5c50c95ce779f4d23 Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Tue, 3 Mar 2026 10:02:37 -0800 Subject: [PATCH 042/100] Rely on IFRT types implementing AbslStringify rather than calling the deprecated DebugString method PiperOrigin-RevId: 878002429 --- jaxlib/py_array.cc | 4 ++-- jaxlib/py_executable.cc | 2 +- jaxlib/py_socket_transfer.cc | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index 5582b377b32d..cb712ba36471 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -1962,7 +1962,7 @@ absl::Status PyHostValue::CopyStringArrayToHostAsync( auto transfer_guard_formatter = [ifrt_array] { return absl::StrCat( "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), - "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + "), dtype=", ifrt_array->dtype(), ", device=", ifrt_array->sharding().devices()->devices().front()->DebugString()); }; TF_RETURN_IF_ERROR( @@ -2009,7 +2009,7 @@ absl::Status PyHostValue::CopyToHostAsync( auto transfer_guard_formatter = [ifrt_array] { return absl::StrCat( "shape=(", absl::StrJoin(ifrt_array->shape().dims(), ","), - "), dtype=", ifrt_array->dtype().DebugString(), ", device=", + "), dtype=", ifrt_array->dtype(), ", device=", ifrt_array->sharding().devices()->devices().front()->DebugString()); }; TF_RETURN_IF_ERROR( diff --git a/jaxlib/py_executable.cc b/jaxlib/py_executable.cc index 660bd9e8a2c3..fe3879f49145 100644 --- a/jaxlib/py_executable.cc +++ b/jaxlib/py_executable.cc @@ -549,7 +549,7 @@ int32_t PyLoadedExecutable::GetNextLaunchId() { << " with launch ID: " << launch_id << " key: " << launch_id_key_; VLOG(2) << "Executable devices for launch ID " << launch_id << ": " << (ifrt_loaded_executable_->devices().has_value() - ? (*ifrt_loaded_executable_->devices())->DebugString() + ? 
absl::StrCat(**ifrt_loaded_executable_->devices()) : ""); return launch_id; } diff --git a/jaxlib/py_socket_transfer.cc b/jaxlib/py_socket_transfer.cc index 7349db2a0b6a..7e1eac87c4c2 100644 --- a/jaxlib/py_socket_transfer.cc +++ b/jaxlib/py_socket_transfer.cc @@ -83,8 +83,8 @@ absl::StatusOr MemorySpaceFromSharding( const xla::ifrt::Sharding& sharding) { if (sharding.devices()->devices().size() != 1) { return xla::InvalidArgument( - "Can only convert SingleDeviceSharding to MemorySpace not %s", - sharding.DebugString()); + "Can only convert SingleDeviceSharding to MemorySpace not %v", + sharding); } auto* device = sharding.devices()->devices()[0]; if (sharding.memory_kind().memory_kind().has_value()) { From f08bb9996e93c56b71a53c6a1dfc887fdb58b253 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 3 Mar 2026 10:02:55 -0800 Subject: [PATCH 043/100] Add support for multiple weak keys to WeakrefLRUCache. We have no direct use for this feature, but it is a stepping stone on the path to implementing multi_weakref_lru_cache in C++. * Add a num_weak_keys optional argument to the WeakrefLRUCache constructor stating the number of weak positional arguments to use, defaulting to 1. * Change WeakKey to be variadic. * Add a reverse index data structure that maps weak ref objects to the CacheEntrys that reference them. Use this to clean up stale entries when a weak reference becomes dead. This also has a slight benefit in that we now don't have to construct a fresh callback object each time we construct a weakref: we can just use one that uses the map. 
PiperOrigin-RevId: 878002619 --- jaxlib/BUILD | 3 + jaxlib/weakref_lru_cache.cc | 477 ++++++++++++++++++++++--------- jaxlib/weakref_lru_cache.pyi | 2 +- jaxlib/weakref_lru_cache_test.py | 69 ++++- 4 files changed, 417 insertions(+), 134 deletions(-) diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 115a6dc3ef7a..3bf1d559e5a9 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -281,10 +281,13 @@ nanobind_pywrap_extension( srcs = ["weakref_lru_cache.cc"], pytype_srcs = ["weakref_lru_cache.pyi"], deps = [ + ":nb_class_ptr", ":reentrant_hash_map", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/hash", "@com_google_absl//absl/log:check", diff --git a/jaxlib/weakref_lru_cache.cc b/jaxlib/weakref_lru_cache.cc index f9fc0b00e0c5..a76e12704982 100644 --- a/jaxlib/weakref_lru_cache.cc +++ b/jaxlib/weakref_lru_cache.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -27,6 +26,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/log/check.h" @@ -38,6 +39,7 @@ limitations under the License. 
#include "nanobind/stl/shared_ptr.h" // IWYU pragma: keep #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "jaxlib/nb_class_ptr.h" #include "jaxlib/reentrant_hash_map.h" namespace nb = nanobind; @@ -91,16 +93,16 @@ inline size_t StrongPythonHash(nb::handle h) { // PointerWeakKey is a WeakKey that compares by pointer identity, used to find // a specific weakref object as part of a heterogeneous lookup. struct PointerWeakKey { - nb::weakref ref; + absl::InlinedVector refs; size_t cached_hash; }; // WeakKey is the key to the first level of the table. struct WeakKey { - WeakKey(nb::weakref ref, size_t cached_hash) - : ref(std::move(ref)), cached_hash(cached_hash) {} + WeakKey(absl::InlinedVector refs, size_t cached_hash) + : refs(std::move(refs)), cached_hash(cached_hash) {} - nb::weakref ref; + absl::InlinedVector refs; // The contract of ReentrantHashMap does not allow hash functions to release // locks, and hence we cannot call back into Python during our hash function. @@ -111,9 +113,27 @@ struct WeakKey { // It is important that we take the keys by value not by reference because // equal() may release locks, and per the contract of our hash map this may // invalidate references. 
- bool operator()(WeakKey a, WeakKey b) const { return a.ref.equal(b.ref); } + bool operator()(WeakKey a, WeakKey b) const { + if (a.refs.size() != b.refs.size()) { + return false; + } + for (size_t i = 0; i < a.refs.size(); ++i) { + if (!a.refs[i].equal(b.refs[i])) { + return false; + } + } + return true; + } bool operator()(WeakKey a, PointerWeakKey b) const { - return a.ref.ptr() == b.ref.ptr(); + if (a.refs.size() != b.refs.size()) { + return false; + } + for (size_t i = 0; i < a.refs.size(); ++i) { + if (a.refs[i].ptr() != b.refs[i].ptr()) { + return false; + } + } + return true; } }; @@ -347,28 +367,24 @@ struct CacheEntry { } // namespace -class WeakrefLRUCache : public std::enable_shared_from_this { +class WeakrefLRUCache { public: - WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, - int64_t maxsize, std::optional explain) - : cache_context_fn_(cache_context_fn), - fn_(fn), - explain_(explain), - lru_maxsize_(maxsize) { - lru_head_.next = nullptr; - lru_head_.prev = &lru_head_; - } + // num_weak_keys is the number of leading arguments to a call to the cache + // that should be treated as weak arguments. + static nb_class_ptr Create( + nb::callable cache_context_fn, nb::callable fn, int64_t maxsize, + std::optional explain, int num_weak_keys); ~WeakrefLRUCache() { Clear(); } + // Entry point used by the Python vectorcall protocol. static PyObject* VectorCall(PyObject* self_obj, PyObject* const* args, Py_ssize_t nargsf, PyObject* kwnames); - PyObject* Call(PyObject* self_obj, absl::Span args, - Py_ssize_t nargsf, PyObject* kwnames); - - void EvictWeakref(const WeakKey& search_key); + // Evicts a particular weak key from the cache. + void EvictWeakKey(const WeakKey& search_key); + // Returns a list of the keys in the cache. 
std::vector GetKeys(); struct CacheInfo { @@ -381,95 +397,123 @@ class WeakrefLRUCache : public std::enable_shared_from_this { void Clear(); + int num_weak_keys() const { return num_weak_keys_; } + WeakKey MakeWeakrefKey(absl::Span weakref_args); + static PyType_Slot slots_[]; + // Do not call directly. Use Create() instead. + WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, + int64_t maxsize, std::optional explain, + int num_weak_keys); + private: friend struct CacheEntry; - using Cache = ReentrantHashMap, - StrongKey::CachedHash, StrongKey::SafeEqual>; - - using WeakrefCacheValue = std::shared_ptr; - - WeakKey MakeWeakrefKey(const nb::object& weakref_key); + // Python callable, called each time the cache is invoked, whose return value + // is used to augment the strong key with implicit context. nb::callable cache_context_fn_; + + // Function called on cache miss. nb::callable fn_; + std::optional explain_; + + const int num_weak_keys_; + + using Cache = ReentrantHashMap, + StrongKey::CachedHash, StrongKey::SafeEqual>; + using WeakrefCacheValue = std::shared_ptr; + + // Map WeakKeys to a Cache, which contains a map of the strong key/value + // pairs. ReentrantHashMap entries_; + + // Maps weakref objects to the strong CacheEntry objects that reference them. + // Used to evict entries when a weak reference becomes dead. + absl::flat_hash_map> + reverse_index_; + + // LRU list used for eviction + int64_t lru_maxsize_; // Maximum size of the cache in entries. + int64_t lru_size_{0}; // Current size of the cache in entries. + LRUNode lru_head_; // Root of the LRU list. + + // Cache statistics. int64_t misses_ = 0; int64_t total_queries_ = 0; - int64_t lru_maxsize_; - int64_t lru_size_{0}; - LRUNode lru_head_; + // Callback invoked when a weak reference is cleared. Constructed by Create(). + nb::object weakref_callback_; - void MoveToFront(CacheEntry* node) { - node->Unlink(); - PushFront(node); - } + // Helper used by VectorCall. 
+ PyObject* Call(PyObject* self_obj, absl::Span args, + Py_ssize_t nargsf, PyObject* kwnames); - void PushFront(CacheEntry* node) { - CHECK(!node->IsLinked()); - node->lru_node.next = lru_head_.next; - node->lru_node.prev = &lru_head_; - if (lru_head_.next) { - lru_head_.next->lru_node.prev = &node->lru_node; - } else { - lru_head_.prev = &node->lru_node; - } - lru_head_.next = node; - ++lru_size_; - } + // Evict all references to `dying_weakref_ptr` from the cache. + void EvictWeakref(PyObject* dying_weakref_ptr); - void EvictLeastRecentlyUsed() { - CacheEntry* tail = lru_head_.prev->prev->next; - if (tail->IsLinked()) { - tail->Unlink(); - } + nb::object WeakrefKeyToPython(absl::Span weakref_args) const; + nb::object WeakrefKeyToPython( + absl::Span weakref_args) const; - // Use heterogeneous lookups so we compare objects by pointer identity. - // This avoids calling Python __eq__ methods which might release the lock. - PointerWeakKey ptr_wr_key{tail->wr_key.ref, tail->wr_key.cached_hash}; - auto cache_it = entries_.find(ptr_wr_key); - if (cache_it == entries_.end()) { - return; - } - std::shared_ptr cache_ptr = cache_it->second; - - PointerStrongKey ptr_strong_key{ - tail->key.context(), absl::MakeConstSpan(tail->key.kwnames()), - absl::MakeConstSpan(tail->key.args_span()), tail->key.cached_hash()}; - auto inner_it = cache_ptr->find(ptr_strong_key); - if (inner_it == cache_ptr->end()) { - return; - } + void RemoveEntryFromReverseIndex(CacheEntry* entry); - // To prevent Python object destructors from running (and potentially - // dropping the lock) *during* the erase operation, we grab an extra - // reference to the keys and values here. They will be destroyed at the - // end of this function block, after the map operations are complete. - WeakKey wr_key_copy = tail->wr_key; - StrongKey strong_key_copy = tail->key; - std::shared_ptr value_copy = inner_it->second; - - // Now erase from the map. Because we hold references, no Python - // destruction happens here. 
- cache_ptr->erase(inner_it); - if (cache_ptr->empty()) { - entries_.erase(cache_it); - } - } + // Moves 'node' to the front of the LRU list. Assumes `node` is already + // linked in the LRU list. + void MoveToFront(CacheEntry* node); + + // Adds 'node' to the front of the LRU list. Assumes `node` is not already + // linked in the LRU list. + void PushFront(CacheEntry* node); + + // Removes the least recently used entry from the cache. + void EvictLeastRecentlyUsed(); static int tp_traverse(PyObject* self, visitproc visit, void* arg); static int tp_clear(PyObject* self); }; +/*static*/ nb_class_ptr WeakrefLRUCache::Create( + nb::callable cache_context_fn, nb::callable fn, int64_t maxsize, + std::optional explain, int num_weak_keys) { + auto self = make_nb_class(cache_context_fn, fn, maxsize, + explain, num_weak_keys); + + // weak_callback captures a weak reference to self otherwise we would have a + // reference count cycle. + self->weakref_callback_ = nb::cpp_function( + [this_weak = nb::weakref(self)](nb::handle dying_weakref) { + nb::object py_cache = this_weak(); + if (py_cache.is_none()) { + return; + } + nb::ft_object_guard lock(py_cache); + nb::cast(py_cache)->EvictWeakref(dying_weakref.ptr()); + }); + return self; +} + +WeakrefLRUCache::WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, + int64_t maxsize, + std::optional explain, + int num_weak_keys) + : cache_context_fn_(cache_context_fn), + fn_(fn), + explain_(explain), + num_weak_keys_(num_weak_keys), + lru_maxsize_(maxsize) { + lru_head_.next = nullptr; + lru_head_.prev = &lru_head_; +} + CacheEntry::~CacheEntry() { if (IsLinked()) { Unlink(); } + parent->RemoveEntryFromReverseIndex(this); } void CacheEntry::Unlink() { @@ -485,49 +529,177 @@ void CacheEntry::Unlink() { parent->lru_size_--; } -WeakKey WeakrefLRUCache::MakeWeakrefKey(const nb::object& weakref_key) { - size_t wrcache_hash = StrongPythonHash(weakref_key); - - auto weakref_gc_callback = nb::cpp_function( - [this_weak = 
weak_from_this(), wrcache_hash](nb::handle weakref) { - // We are careful to use a weak reference to the cache object here to - // avoid the following reference cycle: the cache holds weakref objects - // as its keys, and weakrefs, despite having "weak" in their name, - // hold a strong reference to their callbacks. This would be a strong - // reference cycle. - auto cache = this_weak.lock(); - if (cache == nullptr) { - return; - } - auto py_cache = nb::find(cache); - // This should never happen as python cache should always be found - CHECK(py_cache.ptr() != nullptr); - nb::ft_object_guard lock(py_cache); +void WeakrefLRUCache::RemoveEntryFromReverseIndex(CacheEntry* entry) { + for (const nb::weakref& wref : entry->wr_key.refs) { + PyObject* weakref_ptr = wref.ptr(); + auto rev_it = reverse_index_.find(weakref_ptr); + if (rev_it != reverse_index_.end()) { + rev_it->second.erase(entry); + if (rev_it->second.empty()) { + reverse_index_.erase(rev_it); + } + } + } +} - // The object the reference referred to is now in the process of being - // destroyed, so we cannot refer to its contents. Python weakref - // objects compare based on identity if the object they refer to is - // gone, so the hash lookup will work fine. 
- WeakKey search_key(nb::borrow(weakref), wrcache_hash); - cache->EvictWeakref(search_key); - }); - return WeakKey(nb::weakref(weakref_key, weakref_gc_callback), wrcache_hash); +void WeakrefLRUCache::MoveToFront(CacheEntry* node) { + if (node->IsLinked()) { + node->Unlink(); + } + PushFront(node); +} + +void WeakrefLRUCache::PushFront(CacheEntry* node) { + CHECK(!node->IsLinked()); + node->lru_node.next = lru_head_.next; + node->lru_node.prev = &lru_head_; + if (lru_head_.next) { + lru_head_.next->lru_node.prev = &node->lru_node; + } else { + lru_head_.prev = &node->lru_node; + } + lru_head_.next = node; + ++lru_size_; +} + +void WeakrefLRUCache::EvictLeastRecentlyUsed() { + CacheEntry* tail = lru_head_.prev->prev->next; + if (tail->IsLinked()) { + tail->Unlink(); + } + + // Use heterogeneous lookups so we compare objects by pointer identity. + // This avoids calling Python __eq__ methods which might release the lock. + PointerWeakKey ptr_wr_key{tail->wr_key.refs, tail->wr_key.cached_hash}; + auto cache_it = entries_.find(ptr_wr_key); + if (cache_it == entries_.end()) { + return; + } + std::shared_ptr cache_ptr = cache_it->second; + + PointerStrongKey ptr_strong_key{ + tail->key.context(), absl::MakeConstSpan(tail->key.kwnames()), + absl::MakeConstSpan(tail->key.args_span()), tail->key.cached_hash()}; + auto inner_it = cache_ptr->find(ptr_strong_key); + if (inner_it == cache_ptr->end()) { + return; + } + + // To prevent Python object destructors from running (and potentially + // dropping the lock) *during* the erase operation, we grab an extra + // reference to the keys and values here. They will be destroyed at the + // end of this function block, after the map operations are complete. + WeakKey wr_key_copy = tail->wr_key; + StrongKey strong_key_copy = tail->key; + std::shared_ptr value_copy = inner_it->second; + + // Now erase from the map. Because we hold references, no Python + // destruction happens here. 
+ cache_ptr->erase(inner_it); + if (cache_ptr->empty()) { + entries_.erase(cache_it); + } +} + +WeakKey WeakrefLRUCache::MakeWeakrefKey( + absl::Span weakref_args) { + size_t combined_hash = 0; + absl::InlinedVector refs; + refs.reserve(num_weak_keys_); + for (int i = 0; i < num_weak_keys_; ++i) { + nb::object obj = nb::borrow(weakref_args[i]); + size_t h = StrongPythonHash(obj); + combined_hash = absl::HashOf(std::make_pair(combined_hash, h)); + refs.push_back(nb::weakref(obj, weakref_callback_)); + } + return WeakKey(std::move(refs), combined_hash); } -void WeakrefLRUCache::EvictWeakref(const WeakKey& search_key) { +nb::object WeakrefLRUCache::WeakrefKeyToPython( + absl::Span weakref_args) const { + if (num_weak_keys_ == 1) { + return nb::borrow(weakref_args[0]); + } + nb::tuple keys = nb::steal(PyTuple_New(num_weak_keys_)); + for (int i = 0; i < num_weak_keys_; ++i) { + PyTuple_SET_ITEM(keys.ptr(), i, weakref_args[i]); + Py_INCREF(weakref_args[i]); + } + return keys; +} + +nb::object WeakrefLRUCache::WeakrefKeyToPython( + absl::Span weakref_args) const { + if (num_weak_keys_ == 1) { + return nb::cast(weakref_args[0]); + } + nb::tuple keys = nb::steal(PyTuple_New(num_weak_keys_)); + for (int i = 0; i < num_weak_keys_; ++i) { + nb::object obj = nb::cast(weakref_args[i]); + PyTuple_SET_ITEM(keys.ptr(), i, obj.inc_ref().ptr()); + } + return keys; +} + +void WeakrefLRUCache::EvictWeakKey(const WeakKey& search_key) { auto it = entries_.find(search_key); if (it != entries_.end()) { - for (auto& inner_kv : *(it->second)) { - if (inner_kv.second->IsLinked()) { - inner_kv.second->Unlink(); + auto& [wr_key, cache_ptr] = *it; + std::vector> deferred_deletes; + deferred_deletes.reserve(cache_ptr->size()); + for (auto& [strong_key, entry_ptr] : *cache_ptr) { + if (entry_ptr->IsLinked()) { + entry_ptr->Unlink(); } + + deferred_deletes.push_back(std::move(entry_ptr)); } + cache_ptr->clear(); // Create temp-var to avoid re-entrant erase. 
auto tmp = std::move(*it); entries_.erase(it); } } +void WeakrefLRUCache::EvictWeakref(PyObject* dying_weakref_ptr) { + auto rev_it = reverse_index_.find(dying_weakref_ptr); + if (rev_it == reverse_index_.end()) return; + + // We need to move the set because Unlink and modifying reverse_index_ + // will change the collections. + absl::flat_hash_set entries_to_evict = std::move(rev_it->second); + reverse_index_.erase(rev_it); + + std::vector> deferred_deletes; + deferred_deletes.reserve(entries_to_evict.size()); + + for (CacheEntry* entry : entries_to_evict) { + if (entry->IsLinked()) { + entry->Unlink(); + } + + PointerWeakKey ptr_wr_key{entry->wr_key.refs, entry->wr_key.cached_hash}; + auto cache_it = entries_.find(ptr_wr_key); + if (cache_it != entries_.end()) { + auto& [wr_key, cache_ptr] = *cache_it; + PointerStrongKey ptr_strong_key{ + entry->key.context(), absl::MakeConstSpan(entry->key.kwnames()), + absl::MakeConstSpan(entry->key.args_span()), + entry->key.cached_hash()}; + auto inner_it = cache_ptr->find(ptr_strong_key); + if (inner_it != cache_ptr->end()) { + auto& [strong_key, entry_ptr] = *inner_it; + deferred_deletes.push_back(std::move(entry_ptr)); + cache_ptr->erase(inner_it); + } + if (cache_ptr->empty()) { + auto tmp = std::move(*cache_it); + entries_.erase(cache_it); + } + } + } +} + PyObject* WeakrefLRUCache::VectorCall(PyObject* self_obj, PyObject* const* args, Py_ssize_t nargsf, PyObject* kwnames) { WeakrefLRUCache* self = nb::inst_ptr(self_obj); @@ -554,16 +726,18 @@ PyObject* WeakrefLRUCache::Call(PyObject* self_obj, absl::Span args, Py_ssize_t nargsf, PyObject* kwnames) { Py_ssize_t nargs_positional = PyVectorcall_NARGS(nargsf); - if (nargs_positional < 1) { - PyErr_SetString(PyExc_TypeError, "Missing weakref_key argument"); + if (nargs_positional < num_weak_keys_) { + PyErr_SetString(PyExc_TypeError, + absl::StrCat("Missing weakref_key argument(s). 
Expected ", + num_weak_keys_) + .c_str()); return nullptr; } nb::object context = cache_context_fn_(); - nb::object weakref_key = nb::borrow(args[0]); - WeakKey wrcache_key = MakeWeakrefKey(weakref_key); - StrongKey key(context, args.subspan(1), + WeakKey wrcache_key = MakeWeakrefKey(args.subspan(0, num_weak_keys_)); + StrongKey key(context, args.subspan(num_weak_keys_), kwnames ? nb::borrow(kwnames) : nb::tuple()); bool inserted = false; @@ -582,10 +756,12 @@ PyObject* WeakrefLRUCache::Call(PyObject* self_obj, if (weak_inserted) { it_weak->second = std::make_shared(); + } else { + wrcache_key = it_weak->first; } - // We need to make sure the Cache remains alive as long as this code block. - // Also, we must drop it safely under the lock because its destruction - // destroys CacheEntries which call Unlink() on the LRU list. + // We need to make sure the Cache remains alive as long as this code + // block. Also, we must drop it safely under the lock because its + // destruction destroys CacheEntries which call Unlink() on the LRU list. cache_ptr = it_weak->second; Cache& cache = *cache_ptr; @@ -599,6 +775,10 @@ PyObject* WeakrefLRUCache::Call(PyObject* self_obj, it_strong->second = entry; PushFront(entry.get()); + for (const nb::weakref& wref : wrcache_key.refs) { + reverse_index_[wref.ptr()].insert(entry.get()); + } + if (lru_maxsize_ > 0 && lru_size_ > lru_maxsize_) { // Note: EvictLeastRecentlyUsed may release the lock and may throw // exceptions. 
@@ -667,9 +847,10 @@ PyObject* WeakrefLRUCache::Call(PyObject* self_obj, entry->has_result = true; } else { if (entry->thread_id == std::this_thread::get_id()) { - auto error_string = - absl::StrCat("Recursively calling ", - nb::cast(nb::repr(weakref_key))); + nb::object repr_obj = + WeakrefKeyToPython(args.subspan(0, num_weak_keys_)); + auto error_string = absl::StrCat( + "Recursively calling ", nb::cast(nb::repr(repr_obj))); PyErr_SetString(PyExc_RecursionError, error_string.c_str()); return nullptr; } @@ -697,8 +878,11 @@ std::vector WeakrefLRUCache::GetKeys() { if (!value->completed.HasBeenNotified()) { continue; } + + nb::object wr_key_obj = WeakrefKeyToPython(wr_key.refs); + nb::tuple result = - nb::make_tuple(*wr_key.ref, key.context(), key.args(), key.kwargs()); + nb::make_tuple(wr_key_obj, key.context(), key.args(), key.kwargs()); results.push_back(std::move(result)); } } @@ -721,6 +905,9 @@ void WeakrefLRUCache::Clear() { for (auto& kv : entries_) { for (auto& inner_kv : *(kv.second)) { + if (inner_kv.second->IsLinked()) { + inner_kv.second->Unlink(); + } if (inner_kv.second->IsLinked()) { inner_kv.second->Unlink(); } @@ -730,6 +917,7 @@ void WeakrefLRUCache::Clear() { } entries_.clear(); deferred_deletes.clear(); + reverse_index_.clear(); total_queries_ = misses_ = 0; } @@ -746,12 +934,15 @@ void WeakrefLRUCache::Clear() { if (cache->explain_) { Py_VISIT(cache->explain_->ptr()); } + Py_VISIT(cache->weakref_callback_.ptr()); for (const auto& kv : cache->entries_) { const WeakKey& wr_key = kv.first; const WeakrefCacheValue& wr_value = kv.second; - Py_VISIT(wr_key.ref.ptr()); + for (const nb::weakref& wref : wr_key.refs) { + Py_VISIT(wref.ptr()); + } for (const auto& inner_kv : *wr_value) { const StrongKey& key = inner_kv.first; @@ -775,6 +966,7 @@ void WeakrefLRUCache::Clear() { cache->cache_context_fn_.reset(); cache->fn_.reset(); cache->explain_ = std::nullopt; + cache->weakref_callback_.reset(); return 0; } @@ -796,8 +988,28 @@ 
NB_MODULE(weakref_lru_cache, m) { .def( "evict_weakref", [](WeakrefLRUCache& cache, nb::object weakref_key) { - cache.EvictWeakref(WeakKey(nb::weakref(weakref_key), - StrongPythonHash(weakref_key))); + if (cache.num_weak_keys() == 1) { + PyObject* ptr = weakref_key.ptr(); + cache.EvictWeakKey(cache.MakeWeakrefKey({&ptr, 1})); + } else { + if (!nb::isinstance(weakref_key)) { + PyErr_SetString(PyExc_TypeError, + "evict_weakref expects a tuple of weak " + "keys for multi-weakref cache"); + return; + } + nb::tuple t = nb::cast(weakref_key); + if (t.size() != cache.num_weak_keys()) { + PyErr_SetString(PyExc_ValueError, + "Incorrect number of weak keys"); + return; + } + absl::InlinedVector ptrs; + for (auto item : t) { + ptrs.push_back(item.ptr()); + } + cache.EvictWeakKey(cache.MakeWeakrefKey(ptrs)); + } }, nb::lock_self()) .def("cache_keys", &WeakrefLRUCache::GetKeys, nb::lock_self()) @@ -822,14 +1034,15 @@ NB_MODULE(weakref_lru_cache, m) { m.def( "weakref_lru_cache", [](nb::callable cache_context_fn, nb::callable fn, - std::optional maxsize, std::optional explain) { - return std::make_shared( - cache_context_fn, fn, - maxsize.value_or(std::numeric_limits::max()), explain); + std::optional maxsize, std::optional explain, + int num_weakrefs) { + return WeakrefLRUCache::Create( + cache_context_fn, fn, maxsize.value_or(-1), explain, num_weakrefs); }, nb::arg("cache_context_fn"), nb::arg("fn"), nb::arg("maxsize").none() = 2048, - nb::arg("explain") = std::optional()); + nb::arg("explain") = std::optional(), + nb::arg("num_weakrefs") = 1); } } // namespace jax diff --git a/jaxlib/weakref_lru_cache.pyi b/jaxlib/weakref_lru_cache.pyi index fcd30a8d0cdc..f6be9f2d38ba 100644 --- a/jaxlib/weakref_lru_cache.pyi +++ b/jaxlib/weakref_lru_cache.pyi @@ -36,5 +36,5 @@ class WeakrefLRUCache: def weakref_lru_cache( cache_context_fn: Callable, fn: Callable, maxsize: int | None = 2048, - explain: Callable | None = None + explain: Callable | None = None, num_weakrefs: int = 1 ) -> 
WeakrefLRUCache: ... diff --git a/jaxlib/weakref_lru_cache_test.py b/jaxlib/weakref_lru_cache_test.py index 42d09c74055f..3c3d2b5a71ff 100644 --- a/jaxlib/weakref_lru_cache_test.py +++ b/jaxlib/weakref_lru_cache_test.py @@ -99,7 +99,7 @@ def __hash__(self): def WorkerAddToCache(): barrier.wait() - for i in range(10000): + for i in range(1000): cache(WRKey(i), i) def WorkerCleanCache(): @@ -447,6 +447,73 @@ class WRKey: self.assertEqual(info.misses, 4) self.assertEqual(info.hits, 1) + def testMultiWeakref(self): + class WRKey: + pass + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y, z: z, 2048, num_weakrefs=2 + ) + + k1 = WRKey() + k2 = WRKey() + + cache(k1, k2, 1) + + info = cache.cache_info() + self.assertEqual(info.misses, 1) + self.assertEqual(info.hits, 0) + + cache(k1, k2, 1) + info = cache.cache_info() + self.assertEqual(info.misses, 1) + self.assertEqual(info.hits, 1) + + # Delete k1, the entry should be evicted + del k1 + + k1 = WRKey() + cache(k1, k2, 1) + + info = cache.cache_info() + self.assertEqual(info.misses, 2) + self.assertEqual(info.hits, 1) + + def testMemoryLeakWithMultipleStrongKeys(self): + class WRKey: + pass + + cache = weakref_lru_cache.weakref_lru_cache( + lambda: None, lambda x, y: y, 2048, num_weakrefs=1 + ) + + wk = WRKey() + + class ObjectWithDestructor: + + def __init__(self, val, l): + self.val = val + self.l = l + + def __del__(self): + self.l.append(self.val) + + deleted_strong_keys = [] + + # Insert multiple strong keys with the same weak key. + for i in range(10): + cache(wk, ObjectWithDestructor(i, deleted_strong_keys)) + + info = cache.cache_info() + self.assertEqual(info.misses, 10) + self.assertEqual(len(deleted_strong_keys), 0) + + # Delete the weak key. All cache entries associated with it should be + # dropped. 
+ del wk + self.assertEqual(len(deleted_strong_keys), 10) + self.assertEqual(set(deleted_strong_keys), set(range(10))) + if __name__ == "__main__": absltest.main() From 3ec221c209a519c6abb82e6af69c51bf352d9735 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 11:15:42 -0800 Subject: [PATCH 044/100] PR #35567: Bumped Pyrefly to 0.55 Imported from GitHub PR https://github.com/jax-ml/jax/pull/35567 I also removed unnecessary suppressions which were necessary due to the now fixed bugs in Pyrefly. Copybara import of the project: -- f46bd088c21b4abe37f271a3a331d859769e5e82 by Sergei Lebedev : Bumped Pyrefly to 0.55 I also removed unnecessary suppressions which were necessary due to the now fixed bugs in Pyrefly. Merging this change closes #35567 COPYBARA_INTEGRATE_REVIEW=https://github.com/jax-ml/jax/pull/35567 from superbobry:pyrefly f46bd088c21b4abe37f271a3a331d859769e5e82 PiperOrigin-RevId: 878039302 --- .pre-commit-config.yaml | 2 +- jax/_src/interpreters/ad.py | 10 +++++----- jax/_src/interpreters/partial_eval.py | 6 +++--- jax/_src/pallas/mosaic_gpu/lowering.py | 8 ++++---- jax/_src/pallas/pallas_call.py | 12 ++++++------ jax/_src/pallas/utils.py | 2 -- jax/_src/pjit.py | 2 +- jax/_src/state/primitives.py | 2 +- 8 files changed, 21 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78e9ce552d4d..e41792758b37 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,7 +54,7 @@ repos: # This is a manual-only pre-commit hook to run pyrefly type checks. 
To run it: # $ pre-commit run --hook-stage manual pyrefly-check --all-files - repo: https://github.com/facebook/pyrefly-pre-commit - rev: 0ed71f5d10c035e02f24a220058b39070d165142 # frozen: v0.54.0 + rev: 30778c6e83a71508a62b7297f8b22660ce4496fc # frozen: v0.55.0 hooks: - id: pyrefly-check name: Pyrefly (type checking) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 47591d696a84..5f11f1e85dbd 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -154,12 +154,12 @@ def linearize_jaxpr( ) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]: if type(allow_fwds) is bool: allow_fwds = (allow_fwds,) * (len(jaxpr.consts) + len(jaxpr.jaxpr.invars)) - assert len(allow_fwds) == (len(jaxpr.consts) + len(jaxpr.jaxpr.invars)) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + assert len(allow_fwds) == (len(jaxpr.consts) + len(jaxpr.jaxpr.invars)) if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.jaxpr.outvars) - assert len(instantiate) == len(jaxpr.jaxpr.outvars) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 - return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate), # pyrefly: ignore[bad-argument-type] # pyrefly#2530 - tuple(allow_fwds), is_vjp) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + assert len(instantiate) == len(jaxpr.jaxpr.outvars) + return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate), + tuple(allow_fwds), is_vjp) @weakref_lru_cache @source_info_util.reset_name_stack() @@ -1331,7 +1331,7 @@ def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool], ) -> tuple[core.ClosedJaxpr, list[bool]]: if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.out_avals) - return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate)) # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate)) @weakref_lru_cache def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, diff 
--git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index bbc8916e5180..48d3ac4ebe0f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1012,9 +1012,9 @@ def partial_eval_jaxpr_stateful( saveable = everything_saveable jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \ _partial_eval_jaxpr_custom_cached( - # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns), - # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + tuple(ensure_out_inst), saveable) return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref @@ -1410,7 +1410,7 @@ def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool], """ if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.invars) - # pyrefly: ignore[bad-argument-type] # pyrefly#2530 + return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate)) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index deadd6317e03..dd97ea5752b2 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -2060,8 +2060,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): [out_aval] = ctx.avals_out if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: pred = _ensure_fa(pred, pred_aval.dtype) - # pyrefly: ignore[bad-argument-count] # pyrefly#2487 - cases = _bcast(*cases, *cases_avals, out_aval) + cases = _bcast(*cases, *cases_avals, out_aval=out_aval) # ``select`` expects the first case to be the true branch, but ``select_n`` # orders the cases in reverse. 
return pred.select(*reversed(cases)) @@ -3119,8 +3118,9 @@ def _run_scoped_lowering_rule( dtype = mlir.dtype_to_ir_type(aval.dtype) if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane: input_refs.append( - # pyrefly: ignore[bad-argument-count] # pyrefly#2487 - mgpu.WGMMAAccumulator.zero(*aval.shape, dtype, is_signed=is_signed) + mgpu.WGMMAAccumulator.zero( + *aval.shape, dtype=dtype, is_signed=is_signed + ) ) else: if isinstance(dtype, ir.IntegerType): diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 9d6dcc97a241..55af90fb031b 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -364,7 +364,7 @@ def _block_map_function(new_idx, *args): unflat_indices = (unflat_indices,) unflat_indices = list(unflat_indices) if dim is not batching.not_mapped: - unflat_indices.insert(dim, new_idx) # pyrefly: ignore[bad-argument-type] # pyrefly#2499 + unflat_indices.insert(dim, new_idx) return tuple(unflat_indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] @@ -382,11 +382,11 @@ def _block_map_function(new_idx, *args): new_block_shape = shape new_array_aval = block_mapping.array_aval else: - # pyrefly: ignore[bad-argument-type] # pyrefly#2499 + new_block_shape = tuple_insert(shape, dim, pallas_core.squeezed) array_shape = block_mapping.array_aval.shape - # pyrefly: ignore[bad-argument-type] # pyrefly#2499 + array_shape = tuple_insert(array_shape, dim, axis_size) new_array_aval = jax_core.ShapedArray( @@ -425,7 +425,7 @@ def _broadcast_input_output_aliases( args_[input_index], axis_size, 0, None) elif dim != 0: # TODO(cjfj): Change output batching axis instead? 
- # pyrefly: ignore[bad-argument-type] # pyrefly#2499 + args_[input_index] = jnp.moveaxis(args[input_index], dim, 0) return tuple(args_), tuple(dims_) @@ -462,7 +462,7 @@ def _batch_with_explicit_loop( raise NotImplementedError("vmapping pallas_call with no arguments.") (axis_size,) = { - arg.shape[dim] # pyrefly: ignore[bad-index] # pyrefly#2499 + arg.shape[dim] for arg, dim in zip(args, dims) if dim is not batching.not_mapped } @@ -498,7 +498,7 @@ def body(batch_index: jax_typing.Array, state: list[jax_typing.Array]) -> list[j operand=arg, start_index=batch_index, slice_size=1, - axis=dim, # pyrefly: ignore[bad-argument-type] # pyrefly#2499 + axis=dim, ), axis=dim, ) diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index e34011696f92..57879d449159 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -397,7 +397,6 @@ def nextafter_lowering_helper(x, y): x_magnitude_larger_than_y = x_abs > y_abs result_has_smaller_magnitude = x_magnitude_larger_than_y | signs_disagree minus_one = jnp.full_like(x_as_int, np_int(-1).view(np_uint)) - # pyrefly: ignore[no-matching-overload] # pyrefly#2498 magnitude_adjustment = jnp.where(result_has_smaller_magnitude, minus_one, one) result = x_as_int + magnitude_adjustment @@ -412,7 +411,6 @@ def nextafter_lowering_helper(x, y): result = jnp.where(x_and_y_are_equal, result_for_equal, result) # Handle isnan(x) || isnan(y). - # pyrefly: ignore[no-matching-overload] # pyrefly#2498 result = jnp.where(nan_input, result_for_nan, result) # Cast back to the original type. 
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0335333aa13e..02f52e973f75 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1429,7 +1429,7 @@ def _pjit_batcher(axis_data, vals_in, if axis_in is not None else i for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals)) out_shardings = tuple( - # pyrefly: ignore[bad-argument-type] # pyrefly#2499 + _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, ctx_mesh, aval.ndim) if axis_out is not None else o diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 091b399b6ee7..ad48cf00ef0b 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -785,7 +785,7 @@ def _batch_indexer( else: batch_idx = indexing.Slice(0, axis_size) # type: ignore new_integer_indexer_shape = () - # pyrefly: ignore[bad-argument-type] # pyrefly#2499 + new_indices.insert(ref_dim, batch_idx) return indexing.NDIndexer( tuple(new_indices), ref_shape, new_integer_indexer_shape, validate=True From 031c1e919d98d386c49bde38800683c022ba5c99 Mon Sep 17 00:00:00 2001 From: Maxim Ermilov Date: Tue, 3 Mar 2026 12:23:35 -0800 Subject: [PATCH 045/100] Use driver linking as default provider. 
Disable nvjitlink when parallel compilation enabled due to memory leak PiperOrigin-RevId: 878071715 --- jaxlib/mosaic/gpu/custom_call.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index cb06161afdf3..aa6d051bd527 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -381,7 +381,7 @@ GetAssemblyToBinaryCompilationProvider() { nvjitlink_mode = se::cuda::CompilationProviderOptions::NvJitLinkMode::kAuto; constexpr bool enable_llvm_module_compilation_parallelism = false; - constexpr bool enable_driver_compilation = false; + constexpr bool enable_driver_compilation = true; bool enable_libnvptxcompiler = se::IsLibNvPtxCompilerSupported(); se::cuda::CompilationProviderOptions opts( From b275aec2f74de74f1911745ba7a199bc918a29dc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 3 Mar 2026 19:55:40 +0000 Subject: [PATCH 046/100] Use Pyrefly instead of mypy A few caveats * We do not type check anything under jax/experimental at the moment. This will be addressed separately in follow up PRs. * We need a nightly jaxlib for accurate type stubs for MLIR dialects. --- .github/workflows/pyrefly.yml | 33 ----------- .pre-commit-config.yaml | 33 +++++------ pyproject.toml | 104 +++++++++++++++------------------- pyrefly.toml | 46 --------------- 4 files changed, 62 insertions(+), 154 deletions(-) delete mode 100644 .github/workflows/pyrefly.yml delete mode 100644 pyrefly.toml diff --git a/.github/workflows/pyrefly.yml b/.github/workflows/pyrefly.yml deleted file mode 100644 index 2287751ea066..000000000000 --- a/.github/workflows/pyrefly.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Pyrefly type check (non-blocking) - -on: - push: - branches: - - main - pull_request: - branches: - - main - -permissions: {} - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} - # Don't cancel in-progress jobs for main branches. 
- cancel-in-progress: ${{ github.ref != 'main' }} - -jobs: - pyrefly: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - persist-credentials: false - - name: Set up Python 3.12 - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 - with: - python-version: 3.12 - - run: python -m pip install pre-commit - - name: Run pyrefly check - run: pre-commit run pyrefly-check --hook-stage=manual --show-diff-on-failure --color=always --all-files - # This is expected to fail; we set continue-on-error so the workflow will be marked green. - continue-on-error: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e41792758b37..5e4d9259562a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,14 +35,22 @@ repos: hooks: - id: ruff -- repo: https://github.com/pre-commit/mirrors-mypy - rev: a66e98df7b4aeeb3724184b332785976d062b92e # frozen: v1.19.1 +- repo: https://github.com/facebook/pyrefly-pre-commit + rev: 30778c6e83a71508a62b7297f8b22660ce4496fc # frozen: v0.55.0 hooks: - - id: mypy - files: (jax/|tests/typing_test\.py) - exclude: jax/_src/basearray.py|jax/numpy/__init__.py|jax/nn/__init__.py|jaxlib/_jax/.* # Use pyi instead - additional_dependencies: [types-requests==2.31.0, numpy~=2.3.0, scipy-stubs] - args: [--config=pyproject.toml] + - id: pyrefly-check + name: Pyrefly (type checking) + pass_filenames: false + additional_dependencies: + - absl-py==2.4.0 + - types-requests~=2.32.0 + - numpy~=2.4.0 + - ml_dtypes~=0.5.0 + - scipy-stubs + - --pre + - --extra-index-url + - https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ + - jaxlib - repo: https://github.com/mwouts/jupytext rev: 8ed836db64ad5d304f2315e6bfd9049c9142e190 # frozen: v1.16.4 @@ -51,17 +59,6 @@ repos: files: docs/ args: [--sync] -# This is a manual-only pre-commit hook to run pyrefly type checks. 
To run it: -# $ pre-commit run --hook-stage manual pyrefly-check --all-files -- repo: https://github.com/facebook/pyrefly-pre-commit - rev: 30778c6e83a71508a62b7297f8b22660ce4496fc # frozen: v0.55.0 - hooks: - - id: pyrefly-check - name: Pyrefly (type checking) - pass_filenames: false - additional_dependencies: [absl-py==2.4.0, types-requests~=2.32.0, numpy~=2.4.0, ml_dtypes~=0.5.0, scipy-stubs] - stages: [manual] - - repo: local hooks: - id: check-copyright diff --git a/pyproject.toml b/pyproject.toml index 1341e55ae827..6b6593520c69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,63 +2,6 @@ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" -[tool.mypy] -show_error_codes = true -disable_error_code = "attr-defined, name-defined, annotation-unchecked" -no_implicit_optional = true -warn_redundant_casts = true -allow_redefinition = true - -[[tool.mypy.overrides]] -module = [ - "IPython.*", - "absl.*", - "compression.*", - "etils.*", - "filelock.*", - "flatbuffers.*", - "flax.*", - "google.colab.*", - "hypothesis.*", - "jax.experimental.jax2tf.tests.back_compat_testdata", - "jax.experimental.jax2tf.tests.flax_models", - "jax_cuda12_plugin.*", - "jax_cuda13_plugin.*", - "jaxlib.cpu_feature_guard", - "jaxlib.cuda.*", - "jaxlib.mlir.*", - "jaxlib.mosaic.dialect.gpu.*", - "jaxlib.mosaic.python._tpu_gen", - "jaxlib.triton.*", - "jaxlib.utils", - "jaxlib.version", - "jaxlib._jax.utils", - "jaxlib._pretty_printer", - "jraph.*", - "libtpu.*", - "matplotlib.*", - "mlir.*", - "ml_dtypes.*", - "nvidia.*", - "numpy.*", - "opt_einsum.*", - "optax.*", - "portpicker.*", - "pygments.*", - "pytest.*", - "rich.*", - "setuptools.*", - "xprof.convert.*", - "tensorflow.*", - "tensorflow.io.*", - "tensorflowjs.*", - "tensorstore.*", - "web_pdb.*", - "zstandard.*", - "kubernetes.*" -] -ignore_missing_imports = true - [tool.pytest.ini_options] markers = [ "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators", @@ 
-97,6 +40,53 @@ doctest_optionflags = [ ] addopts = "--doctest-glob='*.rst' --ignore='examples/ffi' --import-mode=importlib" +[tool.pyrefly] +# TODO(slebedev): Use 3.11 here. +python-version = "3.12" +project-includes = [ + "jax/**/*.py*", + # "tests/**/*.py", +] +project-excludes = [ + "jax/example_libraries/**/*.py", + # TODO(slebedev): Opt in jax/experimental and remove this exclude. + "jax/experimental/**/*.py", +] +ignore-missing-imports = [ + "absl.*", + "flatbuffers.*", + "flax.*", + "google.colab.*", + "hypothesis.*", + "IPython.*", + "jax_cuda12_plugin.*", + "jax_cuda13_plugin.*", + "jaxlib.*", + "jraph.*", + "kubernetes.*", + "libtpu.*", + "matplotlib.*", + "mpi4py.*", + "nvidia.*", + "optax.*", + "opt_einsum.*", + "portpicker.*", + "pygments.*", + "pytest.*", + "rich.*", + "setuptools.*", + "tensorflow.*", + "tensorflowjs.*", + "tensorflow_serving.*", + "tensorstore.*", + "web_pdb.*", + "xprof.*", + "zstandard.*", +] +permissive-ignores = true +# TODO(slebedev): Change this to "check-and-infer-return-type". 
+untyped-def-behavior = "check-and-infer-return-any" + [tool.ruff] preview = true exclude = [ diff --git a/pyrefly.toml b/pyrefly.toml deleted file mode 100644 index e8d2b1b432dc..000000000000 --- a/pyrefly.toml +++ /dev/null @@ -1,46 +0,0 @@ -python-version = "3.12" - -project-includes = [ - "jax/**/*.py*", - # "tests/**/*.py", -] - -project-excludes = [ - "jax/example_libraries/**/*.py", - "jax/experimental/**/*.py", -] - -ignore-missing-imports = [ - "absl.*", - "flatbuffers.*", - "flax.*", - "google.colab.*", - "hypothesis.*", - "IPython.*", - "jax_cuda12_plugin.*", - "jax_cuda13_plugin.*", - "jaxlib.*", - "jraph.*", - "kubernetes.*", - "libtpu.*", - "matplotlib.*", - "mpi4py.*", - "nvidia.*", - "optax.*", - "opt_einsum.*", - "portpicker.*", - "pygments.*", - "pytest.*", - "rich.*", - "setuptools.*", - "tensorflow.*", - "tensorflowjs.*", - "tensorflow_serving.*", - "tensorstore.*", - "web_pdb.*", - "xprof.*", - "zstandard.*", -] - -permissive-ignores = true -untyped-def-behavior = "check-and-infer-return-any" From d777b52f4d6d83217d45b74eeeb1375a857d61bd Mon Sep 17 00:00:00 2001 From: Jincheng Chen Date: Tue, 3 Mar 2026 12:57:56 -0800 Subject: [PATCH 047/100] Integrate TorchTPU SDPA with optimized flash attention kernel. 
PiperOrigin-RevId: 878086730 --- jax/experimental/pallas/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 72a81cc2bbc3..674a73b53314 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -32,6 +32,7 @@ from jax._src.pallas.core import MemoryRef as MemoryRef from jax._src.pallas.core import MemorySpace as MemorySpace from jax._src.pallas.core import no_block_spec as no_block_spec +from jax._src.pallas.core import pallas_export_experimental as pallas_export_experimental from jax._src.pallas.core import semaphore as semaphore from jax._src.pallas.core import Squeezed as Squeezed from jax._src.pallas.core import squeezed as squeezed From 8a73c61bf5dde0c8ccbc3523eae03e13e3e199e1 Mon Sep 17 00:00:00 2001 From: Yue Sheng Date: Tue, 3 Mar 2026 13:13:10 -0800 Subject: [PATCH 048/100] [Mosaic TPU] Support reshape which unfolds the minormost dim into two dims when the minormost dim is not divisible by 128, also speed up the reshape for already supported cases where the minormost dim is a multiple of 128. 
PiperOrigin-RevId: 878092970 --- tests/pallas/tpu_pallas_test.py | 89 ++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index f289f2512bc4..357c16f78629 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -3866,6 +3866,94 @@ def kernel(x_ref, o_ref): result, np.transpose(x.reshape(mid_shape), axes=(1, 0, 2)) ) + # (q, m*n) -> (q, m, n) + @parameterized.parameters( + (q, m, n, dtype) + for (q, m, n), dtype in itertools.product( + [ + # n % 128 == 0 + (32, 16, 512), + (20, 19, 512), + (5, 3, 256), + (9, 15, 256), + (3, 2, 256), + (4, 2, 1024), + (8, 4, 1024), + # n % 128 != 0 + (32, 16, 500), + (20, 19, 500), + (5, 3, 200), + (9, 15, 200), + (3, 2, 200), + (5, 1, 300), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_unfold_minor_dim_to_R3(self, q, m, n, dtype): + if not jtu.is_cloud_tpu_at_least(2026, 3, 8): + self.skipTest('Test requires a newer libTPU.') + if n % 128 != 0 and ( + (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) + or (dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5)) + ): + self.skipTest('Operation not supported on this TPU version.') + + def kernel(x_ref, y_ref): + y_ref[...] 
= x_ref[...].reshape(y_ref.shape) + + x = np.arange(q * m * n, dtype=dtype).reshape(q, m * n) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n])) + + # (q, m, n*k) -> (q, m, n, k) + @parameterized.parameters( + (q, m, n, k, dtype) + for (q, m, n, k), dtype in itertools.product( + [ + # k % 128 == 0 + (3, 8, 17, 512), + (1, 8, 9, 256), + (1, 8, 3, 256), + (10, 1, 4, 256), + (1, 2, 2, 256), + (1, 9, 3, 256), + (3, 4, 2, 1024), + (5, 8, 4, 1024), + # k % 128 != 0 + (3, 8, 17, 500), + (1, 8, 9, 200), + (1, 8, 3, 200), + (10, 1, 4, 200), + (1, 2, 2, 200), + (1, 9, 3, 200), + (4, 7, 1, 300), + ], + [jnp.float32, jnp.uint32, jnp.bfloat16, jnp.int8], + ) + ) + def test_reshape_unfold_minor_dim_to_R4(self, q, m, n, k, dtype): + if not jtu.is_cloud_tpu_at_least(2026, 3, 8): + self.skipTest('Test requires a newer libTPU.') + if k % 128 != 0 and ( + (dtype == jnp.bfloat16 and not jtu.is_device_tpu_at_least(4)) + or (dtype == jnp.int8 and not jtu.is_device_tpu_at_least(5)) + ): + self.skipTest('Operation not supported on this TPU version.') + + def kernel(x_ref, y_ref): + y_ref[...] 
= x_ref[...].reshape(y_ref.shape) + + x = np.arange(q * m * n * k, dtype=dtype).reshape(q, m, n * k) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((q, m, n, k), dtype), + )(x) + np.testing.assert_array_equal(out, x.reshape([q, m, n, k])) + # (q, m, n) -> (q, m * n) where n % 128 == 0 @parameterized.parameters( (q, m, n, dtype) @@ -3891,7 +3979,6 @@ def kernel(x_ref, y_ref): kernel, out_shape=jax.ShapeDtypeStruct((q, m * n), dtype), )(x) - jax.numpy.set_printoptions(threshold=jax.numpy.inf) np.testing.assert_array_equal(out, x.reshape([q, m * n])) # (q, m, n, k) -> (q, m, n * k) where k % 128 == 0 From 844bdee4870ca422a14ac645f512415fb43166af Mon Sep 17 00:00:00 2001 From: Robert Dyro Date: Tue, 3 Mar 2026 13:22:47 -0800 Subject: [PATCH 049/100] Re-apply pallas dot_general check for disallowed unsigned integer inputs. Reverts 5c3035c0c63d23e467ae8314df9349bb9ab060a6 PiperOrigin-RevId: 878097392 --- jax/_src/pallas/mosaic/lowering.py | 10 +++++++++- tests/pallas/tpu_ops_test.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 99b96cf96d1d..278f0a7162e4 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -341,7 +341,7 @@ def _dtype_to_ir_type(dtype: DTypeLike, dtype = BOOL_MEMREF_TYPE # TODO(justinfu): Remove after mosaic supports unsigned types. # This conversion makes mosaic interpret all unsigned types as signed types. 
- type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) + type = mlir.dtype_to_ir_type(jnp.dtype(dtype)) if isinstance(type, ir.IntegerType): return ir.IntegerType.get_signless(type.width) else: @@ -2229,6 +2229,14 @@ def _dot_general_lowering_rule( preferred_element_type, **_, ): + for aval in ctx.avals_in: + if jnp.issubdtype(aval.dtype, jnp.unsignedinteger): + raise NotImplementedError( + f"Unsigned integer dtype {aval.dtype} is not supported for" + " dot_general (matmul) on the Pallas Mosaic TPU backend because" + " dot_general interprets all integer inputs as signed. Consider" + " casting to a signed type before the dot operation." + ) (lhs_dims, rhs_dims), _ = dimension_numbers (aval_out,) = ctx.avals_out out_type = ctx.aval_to_ir_type(aval_out) diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py index 0ef112cdfcd7..b5e1672fd73e 100644 --- a/tests/pallas/tpu_ops_test.py +++ b/tests/pallas/tpu_ops_test.py @@ -912,6 +912,22 @@ def kernel(x_ref, o_ref): jax.nn.sigmoid(x), ) + @parameterized.parameters(jnp.uint4, jnp.uint8, jnp.uint16, jnp.uint32) + def test_unsigned_dtype_dot_raises(self, dtype): + k = 256 + packing = 32 // jnp.iinfo(dtype).bits + lhs = jnp.zeros((8 * packing, k), dtype=dtype) + rhs = jnp.zeros((k, 128), dtype=dtype) + + def kernel(lhs_ref, rhs_ref, o_ref): + o_ref[...] 
= pl.dot(lhs_ref[...], rhs_ref[...]) + + out_shape = jax.ShapeDtypeStruct((8 * packing, 128), dtype) + with self.assertRaisesRegex( + NotImplementedError, "Unsigned integer dtype.*dot_general.*matmul" + ): + self.pallas_call(kernel, out_shape=out_shape)(lhs, rhs) + if __name__ == "__main__": absltest.main() From a1328dfbc589bb9b62d6c0dd28ca9ac2c39d9e5f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Mar 2026 13:32:01 -0800 Subject: [PATCH 050/100] [doc] update developer docs for pyrefly --- docs/contributing.md | 2 +- docs/developer.md | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/docs/contributing.md b/docs/contributing.md index 40334bb9599a..4408b388432b 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -186,7 +186,7 @@ possible. The `git rebase -i` command might be useful to this end. ### Linting and type-checking -JAX uses [mypy](https://mypy.readthedocs.io/) and +JAX uses [Pyrefly](https://pyrefly.org/) and [ruff](https://docs.astral.sh/ruff/) to statically test code quality; the easiest way to run these checks locally is via the [pre-commit](https://pre-commit.com/) framework: diff --git a/docs/developer.md b/docs/developer.md index 8244a299187a..75063ad8232c 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -678,20 +678,12 @@ JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 p ## Type checking -We use `mypy` to check the type hints. To run `mypy` with the same configuration as the +We use `pyrefly` to check the type hints. To run `pyrefly` with the same configuration as the github CI checks, you can use the [pre-commit](https://pre-commit.com/) framework: ``` pip install pre-commit -pre-commit run mypy --all-files -``` - -Because `mypy` can be somewhat slow when checking all files, it may be convenient to -only check files you have modified. To do this, first stage the changes (i.e. 
`git add` -the changed files) and then run this before committing the changes: - -``` -pre-commit run mypy +pre-commit run pyrefly-check --all-files ``` ## Linting From 0711c0804d0139170df3394197746a32bb0f3fd3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 3 Mar 2026 13:52:04 -0800 Subject: [PATCH 051/100] [pyrefly] test with opt-einsum --- .pre-commit-config.yaml | 1 + jax/_src/numpy/einsum.py | 1 + 2 files changed, 2 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e4d9259562a..aee5bc1c8429 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,6 +46,7 @@ repos: - types-requests~=2.32.0 - numpy~=2.4.0 - ml_dtypes~=0.5.0 + - opt-einsum~=3.4.0 - scipy-stubs - --pre - --extra-index-url diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 49c1bdfe4ef1..3917a605b8fd 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -422,6 +422,7 @@ def einsum_path( optimize2: Any = 'optimal' if optimize else Unoptimized() else: optimize2 = optimize + # pyrefly: ignore[no-matching-overload] return opt_einsum.contract_path(subscripts, *operands, optimize=optimize2) def _removechars(s, chars): From c0b2687098239f19c53a2d55e2f8ec42eafa3e15 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 3 Mar 2026 15:39:54 -0800 Subject: [PATCH 052/100] Migrate to SafeStatic instead of SafeStaticInit. Change SafeStatic so its destructor is trivial. 
PiperOrigin-RevId: 878156890 --- jaxlib/pmap_lib.cc | 3 ++- jaxlib/py_array.cc | 3 ++- jaxlib/py_values.cc | 10 ++++++---- jaxlib/sharding.cc | 3 ++- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/jaxlib/pmap_lib.cc b/jaxlib/pmap_lib.cc index 10fae91f6f03..ec4435686ca4 100644 --- a/jaxlib/pmap_lib.cc +++ b/jaxlib/pmap_lib.cc @@ -296,7 +296,8 @@ class PmapFunction { size_t nargs, PyObject* kwnames); nb::object PythonSignature() { - const nb::module_& inspect = xla::SafeStaticInit([]() { + static xla::SafeStatic inspect_init; + const nb::module_& inspect = inspect_init.Get([]() { return std::make_unique(nb::module_::import_("inspect")); }); return inspect.attr("signature")(fun_); diff --git a/jaxlib/py_array.cc b/jaxlib/py_array.cc index cb712ba36471..6b7943812d38 100644 --- a/jaxlib/py_array.cc +++ b/jaxlib/py_array.cc @@ -427,7 +427,8 @@ nb::object MakeShapedArrayCached(const ShapedArrayCacheKey& key) { static auto* lru_list = new CacheT::LRUList(4096); static auto* cache = new CacheT(lru_list); - const nb::object& shaped_array = xla::SafeStaticInit([]() { + static xla::SafeStatic shaped_array_init; + const nb::object& shaped_array = shaped_array_init.Get([]() { nb::object jax_core; try { jax_core = nb::module_::import_("jax.core"); diff --git a/jaxlib/py_values.cc b/jaxlib/py_values.cc index de666b266272..a507ed4add6c 100644 --- a/jaxlib/py_values.cc +++ b/jaxlib/py_values.cc @@ -818,8 +818,8 @@ absl::StatusOr MakeShardFn(nb::handle arg, ifrt::Client* client, (*p)[dtypes.np_intc.ptr()] = HandleNumpyScalar; return p; }; - const PyObjectDeviceHandlerMap& handlers = - xla::SafeStaticInit(init_fn); + static xla::SafeStatic handlers_init; + const PyObjectDeviceHandlerMap& handlers = handlers_init.Get(init_fn); if (arg.type().ptr() == PyArray::type().ptr()) { auto array = nb::borrow(arg); @@ -866,9 +866,11 @@ using ToPyArgSignatureHandler = absl::StatusOr PyArgSignatureOfValue(nb::handle arg, bool jax_enable_x64) { + static xla::SafeStatic< + 
absl::flat_hash_map> + handlers_init; const absl::flat_hash_map& handlers = - xla::SafeStaticInit< - absl::flat_hash_map>([] { + handlers_init.Get([] { auto p = std::make_unique< absl::flat_hash_map>(); diff --git a/jaxlib/sharding.cc b/jaxlib/sharding.cc index 30a6f99e6ff0..63b12d5fe284 100644 --- a/jaxlib/sharding.cc +++ b/jaxlib/sharding.cc @@ -173,7 +173,8 @@ NamedSharding::NamedSharding(nb::object mesh, nb_class_ptr spec, nb::module_ si = nb::module_::import_("jax._src.named_sharding"); return std::make_unique(si.attr("check_pspec")); }; - nb::object& check_pspec = xla::SafeStaticInit(init_fn); + static xla::SafeStatic check_pspec_init; + nb::object& check_pspec = check_pspec_init.Get(init_fn); check_pspec(mesh_, spec_); } From 718ac012f61418ee9f08bd66708f365eb7bed889 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Mar 2026 16:50:52 -0800 Subject: [PATCH 053/100] Couple of changes in this PR: * Add a test for grad(shmap) where we have array-like HiVal, HiPrimitive, HiType, HipSpec. * Add `to_cotangent_spec` to `HipSpec`. * Add `nospec` to `HiType` (part of shard_map interface) so that we can get a `HipSpec` from the underlying vmas. There's more subtlety here like under check_vma you will get all_manual_names passed in so you can give us a HipSpec wrt to that. 
Co-authored-by: Matthew Johnson PiperOrigin-RevId: 878185697 --- jax/_src/ad_util.py | 5 +- jax/_src/core.py | 7 ++ jax/_src/hijax.py | 8 +- jax/_src/interpreters/ad.py | 6 +- jax/_src/interpreters/partial_eval.py | 19 ++--- jax/_src/partition_spec.py | 3 +- jax/_src/shard_map.py | 24 +++--- jax/_src/util.py | 4 +- tests/hijax_test.py | 105 +++++++++++++++++++++++++- 9 files changed, 143 insertions(+), 38 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index e71d63e2f54b..75c7e1c015cf 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -53,8 +53,9 @@ def add_abstract(x, y): return x def zeros_like_aval(aval: core.AbstractValue) -> Array: - if hasattr(aval, 'vspace_zero'): # TODO(mattjj,dougalm): revise away hasattr - return aval.vspace_zero() + from jax._src.hijax import HiType # type: ignore + if isinstance(aval, HiType): + return aval.vspace_zero() # type: ignore return aval_zeros_likers[type(aval)](aval) aval_zeros_likers: dict[type, Callable[[Any], Array]] = {} diff --git a/jax/_src/core.py b/jax/_src/core.py index 9b0a16363373..8c43f9aae105 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2348,6 +2348,10 @@ def update_vma(self, vma): def update_weak_type(self, weak_type): return self.update(weak_type=weak_type) + def nospec(self, mesh, check_vma, all_names) -> P: + # TODO(mattjj, yashkatariya): should use newly all_names in check_vma path? 
+ return P(order_wrt_mesh(mesh, self.vma)) if check_vma else P(all_names) + _bool = concretization_function_error(bool) _int = concretization_function_error(int, True) _float = concretization_function_error(float, True) @@ -2415,6 +2419,9 @@ def primal_dtype_to_tangent_dtype(primal_dtype): return primal_dtype def primal_spec_to_cotangent_spec(spec): + from jax._src.hijax import HipSpec # type: ignore + if isinstance(spec, HipSpec): + return spec.to_cotangent_spec() return P(*spec, unreduced=spec.reduced, reduced=spec.unreduced) def primal_sharding_to_cotangent_sharding(sharding): diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py index 2519c3df5ce6..b6242de89d46 100644 --- a/jax/_src/hijax.py +++ b/jax/_src/hijax.py @@ -83,6 +83,7 @@ def jvp(self, primals, tangents, **params): def transpose(self, *args, **params): assert False, "must override" +AxisName = Any class HiType(core.AbstractValue): is_high = True @@ -129,6 +130,10 @@ def shard(self, mesh, manual_axes: frozenset, check_vma: bool, spec: HipSpec assert False, "must override" def unshard(self, mesh, check_vma: bool, spec: HipSpec) -> HiType: assert False, "must override" + def nospec(self, mesh, check_vma: bool, all_names: tuple[AxisName, ...] 
+ ) -> HipSpec: + assert False, "must override" + class MutableHiType(core.AbstractValue): is_high = True @@ -857,4 +862,5 @@ class Static: class MappingSpec: pass class HipSpec: - def to_lo(self): assert False, "must override" + def to_lo(self) -> HipSpec: assert False, "must override" + def to_cotangent_spec(self) -> HipSpec: assert False, "must override" diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 5f11f1e85dbd..5e9479387042 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -103,8 +103,10 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _is_vjp: bool, del linearize_trace, ans, tracers nzs_out = tuple(type(t) is not Zero for t in out_tangents) out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz) - out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents) # type: ignore[assignment] - jaxpr, consts = tangent_trace.to_jaxpr(out_tangents, debug_info.with_unknown_names(), source_info) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 + out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), # type: ignore + out_tangents) + jaxpr, consts = tangent_trace.to_jaxpr( + out_tangents, debug_info.with_unknown_names(), source_info) # type: ignore which_env = [(isinstance(c, pe.DynamicJaxprTracer) and getattr(c._trace, 'tag', None) is _tag) for c in consts] jaxpr = pe.move_envvars(jaxpr, tuple(which_env)) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 48d3ac4ebe0f..51e517ae4d2f 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2456,13 +2456,9 @@ def trace_to_jaxpr( # TODO(dougalm): remove in favor of `trace_to_jaxpr` @profiler.annotate_function def trace_to_jaxpr_dynamic( - fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue | core.AvalQDD], - *, - keep_inputs: list[bool] | None = None, - lower: bool = False, - auto_dce: 
bool = False, -) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: + fun: lu.WrappedFun, in_avals: Sequence[AbstractValue | core.AvalQDD], + *, keep_inputs: list[bool] | None = None, lower: bool = False, + auto_dce: bool = False) -> tuple[Jaxpr, list[AbstractValue], list[Any]]: config.enable_checks.value and fun.debug_info.assert_arg_names(len(in_avals)) keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs parent_trace = core.trace_ctx.trace @@ -2472,15 +2468,11 @@ def trace_to_jaxpr_dynamic( # equations should be rooted at the enclosing jaxpr and not contain any # context from the callsite. Otherwise metadata from one caller would bleed # into metadata from a different caller if we, e.g., inline. - with ( - core.ensure_no_leaks(trace), - source_info_util.reset_name_stack(), - TracebackScope(), - ): + with (core.ensure_no_leaks(trace), source_info_util.reset_name_stack(), + TracebackScope()): source_info = source_info_util.current() in_tracers = map(partial(trace.new_arg, source_info=source_info), in_avals) in_tracers = [t for t, keep in zip(in_tracers, keep_inputs) if keep] - with core.set_current_trace(trace): ans = fun.call_wrapped(*in_tracers) _check_returned_jaxtypes(fun.debug_info, ans) @@ -2489,7 +2481,6 @@ def trace_to_jaxpr_dynamic( jaxpr, consts = trace.frame.to_jaxpr(trace, out_tracers, fun.debug_info, # pyrefly: ignore[bad-argument-type] # pyrefly#2385 source_info) del trace, fun, in_tracers, out_tracers, ans - config.enable_checks.value and core.check_jaxpr(jaxpr) return jaxpr, [v.aval for v in jaxpr.outvars], consts diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 11b003304ffc..747bde390b0e 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -154,7 +154,8 @@ def update(self, **kwargs): unreduced=kwargs.pop("unreduced", self.unreduced), reduced=kwargs.pop("reduced", self.reduced)) - def to_lo(self): return [self] + def to_lo(self): + return [self] def 
_normalized_spec_for_aval(self, ndim: int) -> P: out = [None if p is _UNCONSTRAINED_PARTITION else p diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index b930e91c2360..bd6950783a0c 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -773,7 +773,8 @@ def _shard_map_staging( hi_avals_in = [typeof(x) for x in args] in_specs = [lo_spec for hi_spec in in_specs for lo_spec in hi_spec.to_lo()] args = [lo_val for x in args for lo_val in typeof(x).lower_val(x)] - out_specs_thunk = (lambda t: lambda: [x for s in t() for x in s.to_lo()])(out_specs_thunk) + out_specs_thunk = (lambda t: lambda: [x for s in t() for x in s.to_lo()] + )(out_specs_thunk) f, hi_avals_out = _lojax_traceable(f, hi_avals_in, unk_names=True) else: hi_avals_out = None @@ -781,10 +782,12 @@ def _shard_map_staging( in_tracers = map(to_jaxpr_tracer, args) # pyrefly: ignore[bad-assignment] # pyrefly#2385 inner_mesh = _as_manual_mesh(mesh, manual_axes) in_avals = [t.aval for t in in_tracers] - in_avals_ = map(partial(shard_aval, mesh, manual_axes, check_vma), in_specs, in_avals) + in_avals_ = map(partial(shard_aval, mesh, manual_axes, check_vma), + in_specs, in_avals) with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): - jaxpr, out_avals_, consts = pe.trace_to_jaxpr_dynamic(f, in_avals_, lower=trace.requires_low) + jaxpr, out_avals_, consts = pe.trace_to_jaxpr_dynamic( + f, in_avals_, lower=trace.requires_low) _check_names(out_specs_thunk(), out_avals_) if check_vma: @@ -1666,10 +1669,7 @@ def fwd_out_specs_thunk(): res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) if f1 is None and f2 is None] out_specs = out_specs_thunk() - if check_vma: - res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals] - else: - res_specs = [P(all_names)] * len(res_avals) + res_specs = [a.nospec(mesh, check_vma, all_names) for a in res_avals] return (*res_specs, *out_specs) fwd_params = dict( @@ -1700,15 +1700,11 @@ def 
fwd_out_specs_thunk(): elif f2 is not None: res_specs.append(out_specs[f2]) else: - if check_vma: - res_vma = next(res_avals_iter).vma - res_specs.append(P(order_wrt_mesh(mesh, res_vma))) - else: - res_specs.append(P(all_names)) + raval = next(res_avals_iter) + res_specs.append(raval.nospec(mesh, check_vma, all_names)) new_in_specs = (*res_specs, *(P(),) * len(env), *(ax for ax, nz in zip(in_specs, nzs_in) if nz)) - tangent_out_specs = tuple(ax for ax, nz in zip(out_specs_thunk(), nzs_out) - if nz) + tangent_out_specs = tuple(ax for ax, nz in zip(out_specs, nzs_out) if nz) @as_hashable_function(closure=tangent_out_specs) def tangent_out_specs_thunk(): return tangent_out_specs diff --git a/jax/_src/util.py b/jax/_src/util.py index dabdd3b4debc..16fd056f4ce7 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -184,9 +184,7 @@ def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T] lists[b].append(x) return lists -def merge_lists(bs: Sequence[bool], - l0: Sequence[T1], - l1: Sequence[T2] +def merge_lists(bs: Sequence[bool], l0: Sequence[T1], l1: Sequence[T2] ) -> list[T1 | T2]: """Merge the elements of two lists based on a mask.""" assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0) diff --git a/tests/hijax_test.py b/tests/hijax_test.py index 8b11734c8920..3b2d22afa439 100644 --- a/tests/hijax_test.py +++ b/tests/hijax_test.py @@ -1202,7 +1202,7 @@ def f(x): self.assertEqual(jax.linearize(f, 2.0)[1](1.0), 12.0) @jtu.with_explicit_mesh((2, 2), ('i', 'j')) - def test_oh_hi_mat(self, mesh): + def test_grad_remat_hitype(self, mesh): x = jnp.ones(4) y = jnp.ones(2) @@ -1216,6 +1216,109 @@ def f(x, y): f(x, y) jax.jit(jax.grad(f))(x, y) + @jtu.with_explicit_mesh((2,), 'x') + def test_shmap_grad_hitype(self, mesh): + class Mul(VJPHiPrimitive): + def __init__(self, aval): + self.in_avals = (aval, aval) + self.out_aval = aval + self.params = {} + super().__init__() + + def expand(self, x, y): + return MulH(x.val * y.val) + + def 
vjp_fwd(self, nzs_in, x, y): + return my_mul(x, y), (x, y) + + def vjp_bwd_retval(self, res, g): + x, y = res + return (my_mul(g, y), my_mul(g, x)) + + @dataclass + class MulH: + val: Any + + @dataclass(frozen=True) + class MulTy(HiType): + ty: Ty + + def __repr__(self): + return f"MulTy({self.ty})" + + def __hash__(self): + return hash((self.ty,)) + + def __eq__(self, other): + if not isinstance(other, MulTy): + return False + return self.ty == other.ty + + def lo_ty(self): + return [self.ty] + + def lower_val(self, hi_val: MulH): + return [hi_val.val] + + def raise_val(self, lo_val): + return MulH(lo_val) + + def to_tangent_aval(self) -> HiType: + return MulTy(self.ty.to_tangent_aval()) + + def vspace_zero(self): + return MulHZero(self)() + + def to_cotangent_aval(self) -> HiType: + return MulTy(self.ty.to_cotangent_aval()) + + def shard(self, mesh, manual_axes, check_vma, spec): + return MulTy(self.ty.shard(mesh, manual_axes, check_vma, spec.val)) + + def unshard(self, mesh, check_vma, spec): + return MulTy(self.ty.unshard(mesh, check_vma, spec.val)) + + register_hitype(MulH, lambda m: MulTy(jax.typeof(m.val))) + + class MulHZero(VJPHiPrimitive): + def __init__(self, mul_ty): + self.in_avals = () + self.out_aval = mul_ty + self.params = {} + super().__init__() + + def expand(self): + return MulH(ad.zeros_like_aval(self.out_aval.ty)) + + @dataclass(frozen=True) + class MulSpec(HipSpec): + val: Any + + def to_lo(self): + return [self.val] + + def to_cotangent_spec(self): + return MulSpec(self.val) + + def __repr__(self): + return f"MulSpec({self.val})" + + def my_mul(x, y): + return Mul(jax.typeof(x))(x, y) + + arr1 = jax.device_put(jnp.arange(8, dtype=jnp.float32), jax.P('x')) + arr2 = jax.device_put(jnp.arange(8, dtype=jnp.float32), jax.P('x')) + + @jax.jit + @jax.shard_map(in_specs=(MulSpec(jax.P('x')), MulSpec(jax.P('x'))), + out_specs=MulSpec(jax.P('x'))) + def f(x, y): + return my_mul(x, y) + + _, f_vjp = jax.vjp(f, MulH(arr1), MulH(arr2)) + x = 
jax.device_put(jnp.ones((8,), dtype=jnp.float32), jax.P('x')) + f_vjp(MulH(x)) # doesn't crash + class BoxTest(jtu.JaxTestCase): From 165e54c7b6af8dae0eebe9dc34be0b2f0e58dbef Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 3 Mar 2026 18:21:33 -0800 Subject: [PATCH 054/100] Fix shard_map_partial_eval and partial_eval_custom to use `nospec` instead of creating raw Pspecs PiperOrigin-RevId: 878214120 --- jax/_src/shard_map.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index bd6950783a0c..31e78f2219df 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -1598,10 +1598,7 @@ def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p, def known_out_specs(): _, _, out_knowns, res_avals, _, _ = aux() _, out_known_specs = pe.partition_list(out_knowns, out_specs_thunk()) - if check_vma: - res_specs = [P(order_wrt_mesh(mesh, a.vma)) for a in res_avals] - else: - res_specs = [P(all_names)] * len(res_avals) + res_specs = [a.nospec(mesh, check_vma, all_names) for a in res_avals] return (*out_known_specs, *res_specs) known_params = dict(mesh=mesh, in_specs=(*known_in_specs,), @@ -1626,11 +1623,8 @@ def known_out_specs(): elif f2 is not None: res_specs.append(known_out_specs_[f2]) else: - if check_vma: - res_vma = next(res_avals_iter).vma - res_specs.append(P(order_wrt_mesh(mesh, res_vma))) - else: - res_specs.append(P(all_names)) + raval = next(res_avals_iter) + res_specs.append(raval.nospec(mesh, check_vma, all_names)) unk_in_specs = (*res_specs,) + (P(),) * len(env) + (*unk_in_specs,) # type: ignore[assignment] const_tracers = map(trace.new_instantiated_const, res) env_tracers = map(trace.to_jaxpr_tracer, env) @@ -1878,20 +1872,14 @@ def _partial_eval_jaxpr_custom_rule( out_binders_known, _ = partition_list(unks_out, eqn.outvars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) - 
newvar = core.gensym() - residuals, staged_in_res_specs = [], [] - for var, w in zip(jaxpr_staged.invars[:num_res], which): - if w: - rn = (P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore - if check_vma else P(_all_newly_manual_mesh_names(mesh, manual_axes))) - residuals.append(newvar(unshard_aval(mesh, check_vma, rn, var.aval))) - staged_in_res_specs.append(rn) - if check_vma: - out_res_specs_known = [P(order_wrt_mesh(mesh, var.aval.vma)) # type: ignore - for var, w in zip(res_vars, which) if w] - else: - out_res_specs_known = [ - P(_all_newly_manual_mesh_names(mesh, manual_axes))] * sum(which) + nv = core.gensym() + all_names = _all_newly_manual_mesh_names(mesh, manual_axes) + lns = lambda a: a.nospec(mesh, check_vma, all_names) # type: ignore + residuals, staged_in_res_specs = unzip2( + [(nv(unshard_aval(mesh, check_vma, (rn := lns(var.aval)), var.aval)), rn) + for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]) + out_res_specs_known = [var.aval.nospec(mesh, check_vma, all_names) # type: ignore + for var, w in zip(res_vars, which) if w] params_known, params_staged = _pe_custom_params( unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs, @@ -1907,7 +1895,7 @@ def _partial_eval_jaxpr_custom_rule( new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is core.Var and not inst] new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}] - return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals + return eqn_known, eqn_staged, unks_out, inst_out, new_inst + list(residuals) pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ _partial_eval_jaxpr_custom_rule From a5ed9d2506d49ad78e63aa5342ea0637046a59af Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 3 Mar 2026 13:50:41 +0200 Subject: [PATCH 055/100] [export] Add a test for shape polymorphism with invalid constraints --- jax/_src/export/shape_poly.py | 3 ++- tests/export_test.py | 6 ++++++ 2 
files changed, 8 insertions(+), 1 deletion(-) diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index e811e7d38469..575bfc752591 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1402,7 +1402,8 @@ def symbolic_shape(shape_spec: str | None, scope: optionally, you can specify that the parsed symbolic expressions be created in the given scope. If this is missing, then a new `SymbolicScope` is created with the given `constraints`. - You cannot specify both a `scope` and `constraints`. + You cannot specify both a `scope` and `constraints` (cannot add new + constraints to a `scope`). See [the documentation](https://docs.jax.dev/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for usage. like: when `shape_spec` contains placeholders ("_", "..."), use this diff --git a/tests/export_test.py b/tests/export_test.py index c787121c9502..022c276c6eaa 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1051,6 +1051,12 @@ def f_jax(x): # x: f32[a + 2*b, a, a + b + c] jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype)) exp.call(x) + def test_constraint_invalid_var_in_constraints(self): + m, = jax.export.symbolic_shape("m", constraints=["m >= m1"]) + with self.assertRaisesRegex(Exception, "Encountered dimension variable 'm1'"): + get_exported(jax.jit(lambda x: x))( + jax.ShapeDtypeStruct((m,), np.float32)) + def test_poly_booleans(self): # For booleans we use a special case ConvertOp to cast to and from # dynamic shapes arguments. 
From e05dc3c4b8f2193ac892230e8cd66e2674563955 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Mar 2026 00:06:44 -0800 Subject: [PATCH 056/100] Update XLA dependency to use revision http://github.com/openxla/xla/commit/b0037055f03bf51d364cba09e94278963aec5bcf PiperOrigin-RevId: 878326348 --- MODULE.bazel | 6 +++--- third_party/xla/revision.bzl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 4e6d43b53e7a..7605e8783428 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -27,9 +27,9 @@ archive_override( bazel_dep(name = "xla") archive_override( module_name = "xla", - integrity = "sha256-yL46qTzbd281Ygr+OLBXbNP4t9K1qrh5voGhpZMPizI=", - strip_prefix = "xla-86c54b8eaaff7b52acad66472f6e38bcd867ff93", - urls = ["https://github.com/openxla/xla/archive/86c54b8eaaff7b52acad66472f6e38bcd867ff93.tar.gz"], + integrity = "sha256-Tw4m9BWT5Dgsnr6SEQHpDOfMSn95O6gBvYJswJJdGCI=", + strip_prefix = "xla-b0037055f03bf51d364cba09e94278963aec5bcf", + urls = ["https://github.com/openxla/xla/archive/b0037055f03bf51d364cba09e94278963aec5bcf.tar.gz"], ) # TODO: upstream, otherwise we have to duplicate the patches in jax diff --git a/third_party/xla/revision.bzl b/third_party/xla/revision.bzl index d91a43ebc818..ad26f4aa31dd 100644 --- a/third_party/xla/revision.bzl +++ b/third_party/xla/revision.bzl @@ -21,5 +21,5 @@ # and update XLA_SHA256 with the result. 
# buildifier: disable=module-docstring -XLA_COMMIT = "86c54b8eaaff7b52acad66472f6e38bcd867ff93" -XLA_SHA256 = "c8be3aa93cdb776f35620afe38b0576cd3f8b7d2b5aab879be81a1a5930f8b32" +XLA_COMMIT = "b0037055f03bf51d364cba09e94278963aec5bcf" +XLA_SHA256 = "4f0e26f41593e4382c9ebe921101e90ce7cc4a7f793ba801bd826cc0925d1822" From c307523febdf8e74fcf4cdae179f0206ebf0c5d0 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 4 Mar 2026 08:08:30 +0000 Subject: [PATCH 057/100] Removed unnecessary Pyrefly suppressions --- jax/_src/array.py | 4 ++-- jax/_src/config.py | 2 +- jax/_src/cudnn/scaled_matmul_stablehlo.py | 2 +- jax/_src/export/_export.py | 8 ++++---- jax/_src/hijax.py | 18 +++++++++--------- jax/_src/interpreters/mlir.py | 2 +- jax/_src/lax/convolution.py | 2 +- jax/_src/lax/fft.py | 1 - jax/_src/lax/linalg.py | 3 +-- jax/_src/pallas/mosaic/lowering.py | 2 +- jax/_src/pallas/mosaic/tpu_info.py | 4 ++-- jax/_src/pallas/mosaic_gpu/lowering.py | 6 +++--- jax/_src/pallas/mosaic_gpu/pipeline.py | 2 -- jax/_src/scipy/stats/bernoulli.py | 4 ++-- jax/_src/shard_map.py | 10 +++++----- jax/_src/stages.py | 2 +- jax/_src/state/discharge.py | 2 +- jax/_src/test_util.py | 1 - jax/_src/tpu_custom_call.py | 2 +- jax/version.py | 2 +- 20 files changed, 37 insertions(+), 42 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index f2a0c2888928..ea18a848eb7c 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -486,7 +486,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: else: raise BufferError( "__dlpack__ device only supported for CPU and GPU, got platform: " - f"{self.platform()}" # pyrefly: ignore[missing-attribute] + f"{self.platform()}" ) def __reduce__(self): @@ -587,7 +587,7 @@ def delete(self): for buf in self._arrays: buf.delete() self._arrays = None # pyrefly: ignore[bad-assignment] - self._npy_value = None # pyrefly: ignore[bad-assignment] + self._npy_value = None @use_cpp_method() def is_deleted(self): diff --git a/jax/_src/config.py 
b/jax/_src/config.py index f67e3f975b1f..27056bfcb51b 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -906,7 +906,7 @@ def __bool__(self) -> NoReturn: raise TypeError( "bool() not supported for instances of type '{0}' " "(did you mean to use '{0}.value' instead?)".format( - type(self).__name__)) # pyrefly: ignore[missing-attribute] # pyrefly#2444 + type(self).__name__)) def _set(self, value: _T) -> None: self.value = value diff --git a/jax/_src/cudnn/scaled_matmul_stablehlo.py b/jax/_src/cudnn/scaled_matmul_stablehlo.py index c7dcdbd2ac6b..01e8a681e899 100644 --- a/jax/_src/cudnn/scaled_matmul_stablehlo.py +++ b/jax/_src/cudnn/scaled_matmul_stablehlo.py @@ -684,7 +684,7 @@ def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g): "configs": [configs[2], configs[0]] } grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args) - grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) # pyrefly: ignore[bad-argument-type] + grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args) # We apply a Straight-Through Estimator (STE) with zero-out behavior: if # inputs are clipped during quantization in fprop, their corresponding gradients diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index e42ad942c42c..98a4acec90e9 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -903,8 +903,8 @@ def _module_to_bytecode(module: ir.Module) -> bytes: # Note that this does not verify any JAX custom calls, which are only # guaranteed 3w of forward compatibility, and only prevents use of new # StableHLO features from failing on older hardware. 
- target_version = hlo.get_version_from_compatibility_requirement( # pyrefly: ignore[missing-attribute] - hlo.StablehloCompatibilityRequirement.WEEK_4) # pyrefly: ignore[missing-attribute] + target_version = hlo.get_version_from_compatibility_requirement( + hlo.StablehloCompatibilityRequirement.WEEK_4) module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore mlir_str, target_version, xb.get_backend().serialize_with_sdy) return module_serialized @@ -1550,8 +1550,8 @@ def _call_exported_impl(*args, exported: Exported): def get_mesh_from_symbol(symtab: ir.SymbolTable) -> mesh_lib.AbstractMesh: if "mesh" not in symtab: return mesh_lib.empty_abstract_mesh - mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh) # pyrefly: ignore[missing-attribute] - axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes] # pyrefly: ignore[missing-attribute] + mesh_attr = sdy.MeshAttr(symtab["mesh"].mesh) + axes = [sdy.MeshAxisAttr(a) for a in mesh_attr.axes] if not axes: return mesh_lib.empty_abstract_mesh axes_sizes = tuple(a.size for a in axes) diff --git a/jax/_src/hijax.py b/jax/_src/hijax.py index b6242de89d46..3420a32640f4 100644 --- a/jax/_src/hijax.py +++ b/jax/_src/hijax.py @@ -302,18 +302,18 @@ class BoxEffect(effects.Effect): ... 
class NewBox(HiPrimitive): def is_high(self, *, treedef) -> bool: return True # type: ignore - def abstract_eval(self, *, treedef): # pyrefly: ignore[bad-override] + def abstract_eval(self, *, treedef): leaves, treedef = tree_flatten(None) qdd = BoxTypeState(tuple(leaves), treedef) return core.AvalQDD(BoxTy(), qdd), {box_effect} - def to_lojax(_, *, treedef): # pyrefly: ignore[bad-override] + def to_lojax(_, *, treedef): return Box._new(None) def jvp(_, primals, tangents, *, treedef): # pyrefly: ignore[bad-override] assert False # TODO - def transpose(_, *args, treedef): # pyrefly: ignore[bad-override] + def transpose(_, *args, treedef): assert False # TODO new_box_p = NewBox('new_box') @@ -322,11 +322,11 @@ class BoxSet(HiPrimitive): def is_high(self, *leaf_avals, treedef) -> bool: return True # type: ignore - def abstract_eval(self, box_ty, *leaf_avals, treedef): # pyrefly: ignore[bad-override] + def abstract_eval(self, box_ty, *leaf_avals, treedef): box_ty.mutable_qdd.update(BoxTypeState(leaf_avals, treedef)) return [], {box_effect} # TODO better typechecking... 
- def to_lojax(_, box, *leaves, treedef): # pyrefly: ignore[bad-override] + def to_lojax(_, box, *leaves, treedef): box._val = tree_unflatten(treedef, leaves) return [] @@ -340,7 +340,7 @@ def jvp(_, primals, tangents, *, treedef): # pyrefly: ignore[bad-override] box_set_p.bind(box_dot, *val_dots, treedef=treedef) return [], [] - def transpose(_, *args, treedef): # pyrefly: ignore[bad-override] + def transpose(_, *args, treedef): assert False # TODO box_set_p = BoxSet('box_set') @@ -348,10 +348,10 @@ def transpose(_, *args, treedef): # pyrefly: ignore[bad-override] class BoxGet(HiPrimitive): multiple_results = True - def abstract_eval(self, box_ty, *, avals): # pyrefly: ignore[bad-override] + def abstract_eval(self, box_ty, *, avals): return avals, {box_effect} - def to_lojax(_, box, *, avals): # pyrefly: ignore[bad-override] + def to_lojax(_, box, *, avals): return tree_leaves(box._val) def jvp(_, primals, tangents, *, avals): # pyrefly: ignore[bad-override] @@ -361,7 +361,7 @@ def jvp(_, primals, tangents, *, avals): # pyrefly: ignore[bad-override] box_get_p.bind(box_dot, avals=tuple(a.to_tangent_aval() for a in avals)) ) - def transpose(_, *args): # pyrefly: ignore[bad-override] + def transpose(_, *args): assert False # TODO box_get_p = BoxGet('box_get') diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 8f9540590bea..8d8227d907ea 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2246,7 +2246,7 @@ def _emit_lowering_rule_as_fun( if sub_ctx.tokens_out: outs = [ *(sub_ctx.tokens_out.get(eff) for eff in ordered_effects), - *outs # pyrefly: ignore[not-iterable] + *outs ] outs = flatten_ir_values(outs) # pyrefly: ignore[bad-argument-type] func_dialect.return_(outs) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 350b90f7adc7..9fc72b19aaec 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -791,7 +791,7 @@ def _conv_general_dilated_lower( return 
complex_conv(ctx, lhs, rhs) lhs_spec, rhs_spec, out_spec = dimension_numbers - dnums = hlo.ConvDimensionNumbers.get( # pyrefly: ignore[missing-attribute] + dnums = hlo.ConvDimensionNumbers.get( input_batch_dimension=lhs_spec[0], input_feature_dimension=lhs_spec[1], input_spatial_dimensions=list(lhs_spec[2:]), diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 56364a0fb8c6..dd169e5b6452 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -130,7 +130,6 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths): # TODO: https://github.com/openxla/stablehlo/issues/1366 raise NotImplementedError("Shape polymorphism for FFT with non-constant fft_length is not implemented for TPU and GPU") return [ - # pyrefly: ignore[missing-attribute] hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name), mlir.dense_int_array(fft_lengths)).result ] diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 16f5a46fdedc..479380e7dbb7 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -2423,7 +2423,7 @@ def _triangular_solve_lowering( out = hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side), ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - hlo.TransposeAttr.get(transpose)) # pyrefly: ignore[missing-attribute] + hlo.TransposeAttr.get(transpose)) return [mlir.lower_with_sharding_in_types(ctx, out, out_aval)] @@ -2460,7 +2460,6 @@ def _triangular_solve_cpu_lower( return [hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side), ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - # pyrefly: ignore[missing-attribute] hlo.TransposeAttr.get(transpose))] triangular_solve_p = linalg_primitive( diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 278f0a7162e4..37b2aa54c348 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -651,7 +651,7 @@ def _get_semantics(s: str | None) -> str: # pyrefly: ignore[bad-argument-type] # pyrefly#2385 map( ir.Attribute.parse, - 
map(_get_semantics, self._dimension_semantics), # pyrefly: ignore[no-matching-overload] # pyrefly#2385 + map(_get_semantics, self._dimension_semantics), ) ) diff --git a/jax/_src/pallas/mosaic/tpu_info.py b/jax/_src/pallas/mosaic/tpu_info.py index 636ff9a16edb..2ea709939ab6 100644 --- a/jax/_src/pallas/mosaic/tpu_info.py +++ b/jax/_src/pallas/mosaic/tpu_info.py @@ -514,7 +514,7 @@ class Tiling(enum.Enum): @property def shape(self) -> tuple[int, ...]: # TODO(slebedev): Use ``get_tpu_info()`` instead of hardcoding the values. - match self: # pyrefly: ignore[non-exhaustive-match] # pyrefly#2080 + match self: case Tiling.COMPACT: return (8, 128) case Tiling.SPARSE_CORE: @@ -570,7 +570,7 @@ def infer_tiling( ) leading_dims, final_dims = shape[:-tiling_rank], shape[-tiling_rank:] - match tiling: # pyrefly: ignore[non-exhaustive-match] # pyrefly#2080 + match tiling: case Tiling.COMPACT: second_minor, _ = final_dims factor = _get_tiling_factor(second_minor, tiling.shape[0], packing) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index dd97ea5752b2..e9db30e220c9 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -902,7 +902,7 @@ def lower_jaxpr_to_module( # We reverse the order because Pallas prefers row-major iteration while the # CUDA runtime prefers column-major iteration. 
parallel_grid = parallel_grid[::-1] - cluster = cluster[::-1] # pyrefly: ignore[bad-assignment] + cluster = cluster[::-1] squashed_dims = squashed_dims[::-1] axis_names = axis_names.reverse() @@ -1447,7 +1447,7 @@ def _extract_aliased_ref( ir.MemRefType(ref.type).element_type, mgpu_utils.dtype_to_ir_type(dtype), ) - ref = mgpu.memref_reshape(ref, transformed_shape) # pyrefly: ignore[bad-assignment] + ref = mgpu.memref_reshape(ref, transformed_shape) return ( ref, ref_aval, @@ -1676,7 +1676,7 @@ def _handle_transforms( if is_multicast: transformed_ref = ctx.launch_ctx.to_remote_multicast(transformed_ref) # pyrefly: ignore[bad-argument-type] assert isinstance(ref_aval, state_types.AbstractRef) - return transformed_ref, ref_aval, new_transforms # pyrefly: ignore[bad-return] + return transformed_ref, ref_aval, new_transforms def _ndindexer_indices( diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 2b502b6ff415..906f46f639e7 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -124,7 +124,6 @@ def copy_in(self, slot, grid_indices, barrier_ref, barrier_slot=None): gpu_primitives.copy_gmem_to_smem( # pyrefly: ignore[bad-index] self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands - # pyrefly: ignore[bad-index] self.smem_ref.at[slot], # pytype: disable=unsupported-operands barrier_ref.at[barrier_slot if barrier_slot is not None else slot], collective_axes=getattr(self.spec, "collective_axes", ()), @@ -136,7 +135,6 @@ def copy_out(self, slot, grid_indices, predicate=None): assert self.smem_ref is not None gmem_slices = self.compute_gmem_slice(grid_indices) gpu_primitives.copy_smem_to_gmem( - # pyrefly: ignore[bad-index] self.smem_ref.at[slot], # pytype: disable=unsupported-operands # pyrefly: ignore[bad-index] self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands diff --git a/jax/_src/scipy/stats/bernoulli.py b/jax/_src/scipy/stats/bernoulli.py 
index ef73ea554a41..b10803baa081 100644 --- a/jax/_src/scipy/stats/bernoulli.py +++ b/jax/_src/scipy/stats/bernoulli.py @@ -126,7 +126,7 @@ def cdf(k: ArrayLike, p: ArrayLike) -> Array: lax.ge(k, one) ] vals = [jnp.nan, zero, one - p, one] - return jnp.select(conds, vals) # pyrefly: ignore[bad-argument-type] + return jnp.select(conds, vals) def ppf(q: ArrayLike, p: ArrayLike) -> Array: @@ -152,7 +152,7 @@ def ppf(q: ArrayLike, p: ArrayLike) -> Array: """ q, p = promote_args_inexact('bernoulli.ppf', q, p) zero, one = _lax_const(q, 0), _lax_const(q, 1) - return jnp.where( # pyrefly: ignore[no-matching-overload] + return jnp.where( jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one), jnp.nan, jnp.where(lax.le(q, one - p), zero, one) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 31e78f2219df..992d9122ba9e 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -779,7 +779,7 @@ def _shard_map_staging( else: hi_avals_out = None to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info) - in_tracers = map(to_jaxpr_tracer, args) # pyrefly: ignore[bad-assignment] # pyrefly#2385 + in_tracers = map(to_jaxpr_tracer, args) inner_mesh = _as_manual_mesh(mesh, manual_axes) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(shard_aval, mesh, manual_axes, check_vma), @@ -1209,7 +1209,7 @@ def _unmatch2(mesh, prev_manual, spec, x): src = P(order_wrt_mesh(mesh, prev_manual), *spec) newly_manual = _spec_to_vma(spec) dst = P(order_wrt_mesh(mesh, prev_manual | newly_manual)) - return shard_map(lambda x: x, in_specs=src, out_specs=dst, # pyrefly: ignore[no-matching-overload] + return shard_map(lambda x: x, in_specs=src, out_specs=dst, axis_names=prev_manual | newly_manual)(x) def _match_spec2(mesh, prev_manual, spec, x) -> JaxType: @@ -1221,7 +1221,7 @@ def _match2(mesh, prev_manual, spec, x): newly_manual = _spec_to_vma(spec) src = P(order_wrt_mesh(mesh, prev_manual | newly_manual)) dst = 
P(order_wrt_mesh(mesh, prev_manual), *spec) - return shard_map(lambda x: x, in_specs=src, out_specs=dst, # pyrefly: ignore[no-matching-overload] + return shard_map(lambda x: x, in_specs=src, out_specs=dst, axis_names=prev_manual | newly_manual)(x) @@ -1436,9 +1436,9 @@ def to_concrete_value(self): def __str__(self) -> str: pb_names = set(self._trace.mesh.axis_names) - self.vma # pyrefly: ignore[missing-attribute] self = pvary(self, tuple(pb_names)) - with core.eval_context(), use_abstract_mesh(self._trace.amesh): # pyrefly: ignore[missing-attribute] + with core.eval_context(), use_abstract_mesh(self._trace.amesh): blocks = list(self.val) - mesh = self._trace.mesh # pyrefly: ignore[missing-attribute] + mesh = self._trace.mesh axis_names = f"({', '.join(map(str, mesh.axis_names))},)" return '\n'.join( f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n" diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 5c0003de3c47..a4bc2bb220c3 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -826,7 +826,7 @@ def call(*args, **kwargs): _in_hi_tree, final_qdds = params.in_types # TODO(jakevdp): remove pyrefly ignore when https://github.com/facebook/pyrefly/issues/2382 is fixed. 
args_flat = [a.read_loval(core.cur_qdd(x), x) if (a := typeof(x)).has_qdd - else a.lower_val(x) for x in hi_args_flat] # pyrefly: ignore[unbound-name] + else a.lower_val(x) for x in hi_args_flat] args_flat, in_tree = \ tree_util.tree_flatten(tree_util.tree_unflatten(in_hi_tree, args_flat)) else: diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index a0649cfc64e8..bd7273e02390 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -693,7 +693,7 @@ def _run_state_to_lojax(*args, jaxpr, is_initialized, **params): out_mut, lo_outs = split_list(all_outs, [pe.num_himuts_out(jaxpr)]) pe.apply_himut(jaxpr, args, out_mut) return pe.raise_lo_outs(arg_avals, lo_outs) -run_state_p.to_lojax = _run_state_to_lojax # pyrefly: ignore[bad-assignment] +run_state_p.to_lojax = _run_state_to_lojax def _default_initialization(x): diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index f89636439c56..11b5d42b1020 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -450,7 +450,6 @@ def stablehlo_version_at_least(required_version: str): plugin_version = xla_bridge.backend_stablehlo_version() if plugin_version is None: return True - # pyrefly: ignore[missing-attribute] return hlo.get_smaller_version( ".".join(map(str, plugin_version)), required_version ) == plugin_version diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 10177b9b6a7c..fa4ee83f6464 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -451,7 +451,7 @@ def _lower_mosaic_module_to_asm( *, ir_version: int | None = None, ) -> tuple[bytes, tuple[bool, bool]]: - has_communication, has_custom_barrier = tpu.private_has_communication( # pyrefly: ignore[missing-attribute] + has_communication, has_custom_barrier = tpu.private_has_communication( module.operation ) # We'll mutate the module, so clone it diff --git a/jax/version.py b/jax/version.py index 8e0591731fac..ccacec4dd1b3 100644 --- a/jax/version.py +++ 
b/jax/version.py @@ -129,7 +129,7 @@ class _build_py(build_py_orig): def run(self): if _release_version is None: this_file_in_build_dir = os.path.join( - self.build_lib, # pyrefly: ignore[missing-attribute] + self.build_lib, pkg_source_path, os.path.basename(__file__)) # super().run() only copies files from source -> build if they are From 1184dbaab8522f49cefcd730db9d55c8e8ee7376 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Mar 2026 00:38:40 -0800 Subject: [PATCH 058/100] Fix TensorStore implementation in Jax to use typed initialization. PiperOrigin-RevId: 878337514 --- jax/experimental/array_serialization/tensorstore_impl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/experimental/array_serialization/tensorstore_impl.py b/jax/experimental/array_serialization/tensorstore_impl.py index b2aa69ad1955..151a205e630b 100644 --- a/jax/experimental/array_serialization/tensorstore_impl.py +++ b/jax/experimental/array_serialization/tensorstore_impl.py @@ -39,9 +39,9 @@ 'cache_pool#remote': {'total_bytes_limit': 10_000_000_000}, 'data_copy_concurrency': {'limit': 128} }) -_TS_CHUNK_LAYOUT = ts.ChunkLayout({ - "chunk": {"elements": 100_000_000}, # 100M (800MB for float64) file size -}) +_TS_CHUNK_LAYOUT = ts.ChunkLayout( + chunk=ts.ChunkLayout.Grid(elements=100_000_000), # 100M (800MB for float64) file size +) _DEFAULT_BASE_DRIVER = 'file' _PROCESS_DIR_FORMAT = "process_{}" From 41f6727c78f8bc0538bf955fea0dbb3c58ddde7e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 4 Mar 2026 09:08:44 +0000 Subject: [PATCH 059/100] Removed a few `type: ignore`s which were only necessary for mypy --- jax/_src/ad_util.py | 6 ++-- jax/_src/api.py | 10 +++---- jax/_src/array.py | 6 ++-- jax/_src/basearray.pyi | 14 ++++----- jax/_src/callback.py | 8 ++--- jax/_src/checkify.py | 2 +- jax/_src/clusters/cloud_tpu_cluster.py | 4 +-- jax/_src/cudnn/fused_attention_stablehlo.py | 6 ++-- jax/_src/custom_derivatives.py | 6 ++-- jax/_src/dispatch.py | 2 
+- jax/_src/dtypes.py | 6 ++-- jax/_src/export/_export.py | 24 +++++++-------- jax/_src/export/serialization.py | 12 ++++---- jax/_src/export/shape_poly.py | 8 ++--- jax/_src/ffi.py | 2 +- jax/_src/internal_test_util/test_harnesses.py | 3 -- jax/_src/interpreters/batching.py | 4 +-- jax/_src/interpreters/mlir.py | 29 +++++++++---------- jax/_src/interpreters/partial_eval.py | 22 ++++++++------ jax/_src/interpreters/remat.py | 2 +- jax/_src/lax/control_flow/loops.py | 8 ++--- jax/_src/lax/parallel.py | 4 +-- jax/_src/lax/slicing.py | 2 +- jax/_src/lax/windowed_reductions.py | 4 +-- jax/_src/layout.py | 2 +- jax/_src/lru_cache.py | 2 +- jax/_src/named_sharding.py | 8 ++--- jax/_src/numpy/array_methods.py | 1 - jax/_src/numpy/einsum.py | 4 +-- jax/_src/numpy/lax_numpy.py | 8 ++--- jax/_src/numpy/linalg.py | 2 +- jax/_src/numpy/polynomial.py | 2 +- jax/_src/numpy/util.py | 4 +-- jax/_src/ops/scatter.py | 2 +- jax/_src/pallas/core.py | 12 ++++---- jax/_src/pallas/fuser/block_spec.py | 1 - jax/_src/pallas/fuser/custom_fusion_lib.py | 6 ++-- jax/_src/pallas/fuser/fusible_dtype.py | 1 - jax/_src/pallas/helpers.py | 4 +-- jax/_src/pallas/hlo_interpreter.py | 2 +- .../mosaic/interpret/interpret_pallas_call.py | 6 ++-- jax/_src/pallas/mosaic/lowering.py | 1 - .../pallas/mosaic/pallas_call_registration.py | 4 +-- jax/_src/pallas/mosaic/pipeline.py | 6 ++-- jax/_src/pallas/mosaic/sc_core.py | 16 ++++++---- jax/_src/pallas/mosaic/sc_lowering.py | 4 +-- jax/_src/pallas/mosaic/tpu_info.py | 2 +- jax/_src/pallas/mosaic_gpu/core.py | 2 +- jax/_src/pallas/mosaic_gpu/lowering.py | 3 +- .../mosaic_gpu/pallas_call_registration.py | 2 +- jax/_src/pallas/mosaic_gpu/pipeline.py | 6 ++-- jax/_src/pallas/pipelining/schedule_api.py | 1 - jax/_src/pallas/pipelining/schedulers.py | 1 - jax/_src/pallas/triton/lowering.py | 1 - .../pallas/triton/pallas_call_registration.py | 2 +- jax/_src/pjit.py | 18 ++++++------ jax/_src/prng.py | 2 +- jax/_src/random.py | 2 +- jax/_src/shard_map.py | 12 
++++---- jax/_src/sharding.py | 6 ++-- jax/_src/sharding_specs.py | 2 +- jax/_src/state/discharge.py | 2 +- jax/_src/state/indexing.py | 2 +- jax/_src/state/primitives.py | 6 ++-- jax/_src/test_util.py | 4 +-- jax/_src/tpu_custom_call.py | 1 - jax/_src/traceback_util.py | 4 +-- jax/_src/tree_util.py | 6 ++-- jax/_src/util.py | 6 ++-- jax/experimental/key_reuse/_core.py | 2 +- jax/experimental/pallas/ops/gpu/attention.py | 8 ++--- 71 files changed, 195 insertions(+), 200 deletions(-) diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 75c7e1c015cf..5b32af53b782 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -32,7 +32,7 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: - from jax._src.hijax import HiType # type: ignore + from jax._src.hijax import HiType ty = typeof(x) if isinstance(ty, HiType): return ty.vspace_add(x, y) @@ -53,9 +53,9 @@ def add_abstract(x, y): return x def zeros_like_aval(aval: core.AbstractValue) -> Array: - from jax._src.hijax import HiType # type: ignore + from jax._src.hijax import HiType if isinstance(aval, HiType): - return aval.vspace_zero() # type: ignore + return aval.vspace_zero() return aval_zeros_likers[type(aval)](aval) aval_zeros_likers: dict[type, Callable[[Any], Array]] = {} diff --git a/jax/_src/api.py b/jax/_src/api.py index a32a480c659b..11520bcd87f7 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -675,7 +675,7 @@ def fwd(*args, **kwargs): f = lu.wrap_init(fun, params=kwargs, debug_info=dbg) f_partial, dyn_args = argnums_partial( f, argnums, args, require_static_args_hashable=False) - return _vjp(f_partial, *dyn_args, has_aux=has_aux) # type: ignore + return _vjp(f_partial, *dyn_args, has_aux=has_aux) def bwd(f_vjp, outgrad): g = f_vjp(outgrad) g = g[0] if isinstance(argnums, int) else g @@ -1160,7 +1160,7 @@ def vmap(fun: F, # rather than raising an error. 
https://github.com/jax-ml/jax/issues/2367 in_axes = tuple(in_axes) - from jax._src import hijax # type: ignore + from jax._src import hijax if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types} or isinstance(in_axes, hijax.MappingSpec)): raise TypeError("vmap in_axes must be an int, None, or a tuple of entries corresponding " @@ -2220,8 +2220,8 @@ def _vjp(fun, *primals, has_aux=False): out_known = [pval.is_known() for pval in out_pvals] id_map = {id(x): i for i, x in enumerate(primals_flat)} used, opaque_residuals = set(), [] - spec = [used.add(id(r)) or RSpec(id_map[id(r)], True) if id(r) in id_map else # type: ignore - RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) # type: ignore + spec = [used.add(id(r)) or RSpec(id_map[id(r)], True) if id(r) in id_map else + RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False) for r in residuals] args_res = tuptree_map(lambda x: x if id(x) in used else NotNeeded(), in_tree, primals_flat) @@ -2352,7 +2352,7 @@ class VJP: out_tree: PyTreeDef args_res: list[Any] opaque_residuals: list[Any] - jaxpr = property(lambda self: self.fun.args[2]) # type: ignore + jaxpr = property(lambda self: self.fun.args[2]) def __call__(self, out_ct, *extra_args): if extra_args: diff --git a/jax/_src/array.py b/jax/_src/array.py index ea18a848eb7c..03c374069b71 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -790,7 +790,7 @@ def get_data( return r if sharding.is_fully_replicated: - devices = list(sharding._internal_device_list.addressable_device_list) # type: ignore + devices = list(sharding._internal_device_list.addressable_device_list) # Only compute data once. 
per_device_values = [get_data((slice(None),) * len(shape))] * len(devices) else: @@ -831,7 +831,7 @@ def get_data( ) if dll is not None: - devices = [Format(dll, SingleDeviceSharding(d)) for d in devices] # type: ignore + devices = [Format(dll, SingleDeviceSharding(d)) for d in devices] # pxla.batched_device_put doesn't support Layout... Take the slow route arrays = api.device_put(per_device_values, devices) return ArrayImpl(aval, sharding, arrays, committed=True) @@ -1334,5 +1334,5 @@ def _token_global_result_handler(global_aval, out_sharding, committed): core.get_token_aval(), out_sharding, committed) def wrapper(array): return core.Token(array) - return array_handler.wrap(wrapper) # type: ignore + return array_handler.wrap(wrapper) pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 0ad9fb5a11fe..3ceb7ac6b2b1 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -83,8 +83,8 @@ class Array: # these return bool for object, so ignore override errors. def __lt__(self, other: ArrayLike) -> Array: ... def __le__(self, other: ArrayLike) -> Array: ... - def __eq__(self, other: ArrayLike) -> Array: ... # type: ignore[override] - def __ne__(self, other: ArrayLike) -> Array: ... # type: ignore[override] + def __eq__(self, other: ArrayLike) -> Array: ... # pyrefly: ignore[bad-override] + def __ne__(self, other: ArrayLike) -> Array: ... # pyrefly: ignore[bad-override] def __gt__(self, other: ArrayLike) -> Array: ... def __ge__(self, other: ArrayLike) -> Array: ... @@ -112,15 +112,15 @@ class Array: def __xor__(self, other: ArrayLike) -> Array: ... def __or__(self, other: ArrayLike) -> Array: ... - def __radd__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] - def __rsub__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] - def __rmul__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __radd__(self, other: ArrayLike) -> Array: ... 
+ def __rsub__(self, other: ArrayLike) -> Array: ... + def __rmul__(self, other: ArrayLike) -> Array: ... def __rmatmul__(self, other: ArrayLike) -> Array: ... - def __rtruediv__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rtruediv__(self, other: ArrayLike) -> Array: ... def __rfloordiv__(self, other: ArrayLike) -> Array: ... def __rmod__(self, other: ArrayLike) -> Array: ... def __rdivmod__(self, other: ArrayLike) -> Array: ... - def __rpow__(self, other: ArrayLike) -> Array: ... # type: ignore[misc] + def __rpow__(self, other: ArrayLike) -> Array: ... def __rlshift__(self, other: ArrayLike) -> Array: ... def __rrshift__(self, other: ArrayLike) -> Array: ... def __rand__(self, other: ArrayLike) -> Array: ... diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 1980ab5066be..71c499db8e0c 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -895,7 +895,7 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function ifrt_callback = _wrapped_callback ctx.module_context.add_host_callback(ifrt_callback) index = np.uint64(len(ctx.module_context.host_callbacks) - 1) - result = ffi.build_ffi_lowering_function( # type: ignore + result = ffi.build_ffi_lowering_function( call_target_name, has_side_effect=has_side_effect, )(ctx, *operands, index=np.uint64(index)) @@ -903,9 +903,9 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function if sharding is not None: mlir.set_sharding(result, sharding) - results = result.results # type: ignore + results = result.results if token: - token, *results = results # type: ignore + token, *results = results - return results, token, ifrt_callback # type: ignore + return results, token, ifrt_callback diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index eb964d23d94c..41f6713c72f2 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -984,7 +984,7 @@ def shard_map_error_check( in_avals[i] = sharder(mesh, manual_axes, check_vma, 
new_in_specs[i], v) with (jshmap._extend_axis_env(mesh, manual_axes), - mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), # type: ignore[arg-type] + mesh_lib.use_abstract_mesh(jshmap._as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): # jaxpr to checked_jaxpr checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index 65ce1151533d..82e6a45686c1 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -154,7 +154,7 @@ def _get_num_slices() -> int: num_slices = get_tpu_env_value('MEGASCALE_NUM_SLICES') if not num_slices: return 1 - return int(num_slices) # type: ignore + return int(num_slices) @staticmethod @@ -162,7 +162,7 @@ def _get_slice_id() -> int: slice_id = get_tpu_env_value('MEGASCALE_SLICE_ID') if not slice_id: return 0 - return int(slice_id) # type: ignore + return int(slice_id) @staticmethod def _get_process_id_in_slice() -> int: diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index f4c2e648921b..cd75c07d2724 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -1817,7 +1817,7 @@ def combine_bias_and_mask(bias, mask, dtype): large_negative_number = get_large_negative_number(dtype) mask = jnp.where(mask, jnp.asarray(0, dtype), large_negative_number) # reshape mask to have 4D shape - mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] + mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # combine bias and mask if bias is None: @@ -1905,7 +1905,7 @@ def paged_attention( page_table_k, page_table_v, layout) has_bias = bias is not None has_dbias = has_bias and \ - should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] + should_export_dbias(bias.shape, query.shape, layout) variadic_args = (has_bias, 
has_dbias) _not_used = jnp.zeros(0, dtype=query.dtype) @@ -2042,7 +2042,7 @@ def dot_product_attention( None, None, layout) has_bias = bias is not None has_dbias = has_bias and \ - should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] + should_export_dbias(bias.shape, query.shape, layout) variadic_args = (has_bias, has_dbias) _not_used = jnp.zeros(0, dtype=query.dtype) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 0ceaf3ae1c8c..7ed1bc41a8bb 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -393,7 +393,7 @@ def bind_with_trace(self, trace, args, params, /): fun, jvp, tracers = args[0], args[1], args[2:] return trace.process_custom_jvp_call(self, fun, jvp, tracers, **params) - def impl(self, fun, _, *args): # type: ignore[bad-override] + def impl(self, fun, _, *args): raise NotImplementedError def get_bind_params(self, params): @@ -556,7 +556,7 @@ def f_bwd(res, g): def __new__(cls, fun, nondiff_argnums=(), nondiff_argnames=()): if config.custom_vjp3.value: - from jax._src.hijax import custom_vjp3 # type: ignore + from jax._src.hijax import custom_vjp3 return custom_vjp3(fun, nondiff_argnums, nondiff_argnames) else: return super().__new__(cls) @@ -1001,7 +1001,7 @@ def bind_with_trace(self, trace, args, params, /): fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:] return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params) - def impl(self, fun, fwd, bwd, *args): # type: ignore[bad-override] + def impl(self, fun, fwd, bwd, *args): raise NotImplementedError def get_bind_params(self, params): diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 82b1ff67b57d..b842586b291a 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -343,7 +343,7 @@ def _different_device_order_reshard( new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind, _logical_device_ids=(None if permute_order is None else 
tuple(permute_order.tolist()))) - new_x = xc.reorder_shards(x, new_s, ArrayCopySemantics.REUSE_INPUT) # type: ignore + new_x = xc.reorder_shards(x, new_s, ArrayCopySemantics.REUSE_INPUT) return api.jit(_identity_fn, out_shardings=target_sharding, donate_argnums=donate_argnums)(new_x) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 1fcb9943b334..231189e1936c 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -506,8 +506,8 @@ def issubdtype(a: DTypeLike | ExtendedDType | None, # unhashable (e.g. custom objects with a dtype attribute). The following check is # fast and covers the majority of calls to this function within JAX library code. return _issubdtype_cached( - a if isinstance(a, _types_for_issubdtype) else np.dtype(a), # type: ignore[arg-type] - b if isinstance(b, _types_for_issubdtype) else np.dtype(b), # type: ignore[arg-type] + a if isinstance(a, _types_for_issubdtype) else np.dtype(a), + b if isinstance(b, _types_for_issubdtype) else np.dtype(b), ) @@ -1084,7 +1084,7 @@ def result_type(*args: Any, return_weak_type_flag: bool = False) -> DType | tupl if weak_type: dtype = default_types['f' if dtype in _custom_float_dtypes else dtype.kind]() # TODO(jakevdp): fix return type annotation and remove this ignore. - return (dtype, weak_type) if return_weak_type_flag else dtype # type: ignore[return-value] + return (dtype, weak_type) if return_weak_type_flag else dtype def check_and_canonicalize_user_dtype(dtype, fun_name=None) -> DType: """Checks validity of a user-provided dtype, and returns its canonical form. 
diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 98a4acec90e9..c21e8c300ff9 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -854,7 +854,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: apply_jit=True, flat_primal_fun=True, mesh=cur_mesh) # type: ignore[arg-type] - return export(fun_vjp_jax, # type: ignore[arg-type] + return export(fun_vjp_jax, platforms=exp_primal.platforms, disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) @@ -905,7 +905,7 @@ def _module_to_bytecode(module: ir.Module) -> bytes: # StableHLO features from failing on older hardware. target_version = hlo.get_version_from_compatibility_requirement( hlo.StablehloCompatibilityRequirement.WEEK_4) - module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore + module_serialized = xla_client._xla.mlir.serialize_portable_artifact( mlir_str, target_version, xb.get_backend().serialize_with_sdy) return module_serialized @@ -950,8 +950,8 @@ def _wrap_main_func( def is_token(typ, attrs): return (typ == mlir.token_type()) - orig_input_types = orig_main.type.inputs # type: ignore - arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) # type: ignore + orig_input_types = orig_main.type.inputs + arg_attrs = list(ir.ArrayAttr(orig_main.arg_attrs)) # The order of args: platform_index_arg, dim args, token args, array args. 
nr_platform_index_args = 1 if has_platform_index_argument else 0 nr_dim_args = len(dim_vars) @@ -973,8 +973,8 @@ def is_token(typ, attrs): orig_input_types, [nr_platform_index_args, nr_dim_args, nr_token_args]) # The order of results: tokens, array results - orig_output_types = orig_main.type.results # type: ignore - result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) # type: ignore + orig_output_types = orig_main.type.results + result_attrs = list(ir.ArrayAttr(orig_main.result_attrs)) token_result_idxs = [i for i, (typ, attrs) in enumerate(zip(orig_output_types, result_attrs)) if is_token(typ, attrs)] @@ -1375,7 +1375,7 @@ def flattened_primal_fun_jax(*args_flat): if apply_jit: if has_named_shardings or mesh: vjp_in_shardings = tuple( - _get_named_sharding(has_named_shardings, named_sharding, # type: ignore + _get_named_sharding(has_named_shardings, named_sharding, hlo_sharding, aval, mesh) # type: ignore[arg-type] for named_sharding, hlo_sharding, aval in zip( itertools.chain(in_named_shardings, out_named_shardings), @@ -1517,7 +1517,7 @@ def pp_arg_dim(dim_idx: int | None) -> str: # it would be ambiguous whether we should continue tracing with a result # of type `f32[c]` or `f32[d]`. 
shape_constraints.check_statically(synthetic_eval) - exported_dim_values = [synthetic_eval.evaluate(solution[var]) # type: ignore[arg-type] + exported_dim_values = [synthetic_eval.evaluate(solution[var]) for var in exported_dim_vars] def make_aval(out_aval_idx: int): @@ -1626,7 +1626,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, ctx, x, x_aval, _get_named_sharding(exported._has_named_shardings, named_sharding, None, x_aval, None), - use_shardy=True) # type: ignore[arg-type] + use_shardy=True) for x, named_sharding, x_aval in zip( args, exported._in_named_shardings, exported.in_avals)) elif mesh: @@ -1635,7 +1635,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, wrap_with_sharding( ctx, x, x_aval, _get_named_sharding(False, None, hlo_sharding, x_aval, mesh), - use_shardy=True) # type: ignore[arg-type] + use_shardy=True) for x, hlo_sharding, x_aval in zip( args, exported.in_shardings_hlo, exported.in_avals)) else: @@ -1738,7 +1738,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, wrap_with_sharding( ctx, x, x_aval, _get_named_sharding(True, x_sharding, None, x_aval, None), - use_shardy=True) # type: ignore[arg-type] + use_shardy=True) for x, x_aval, x_sharding in \ zip(results, ctx.avals_out, exported._out_named_shardings)) elif mesh: @@ -1746,7 +1746,7 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, wrap_with_sharding( ctx, x, x_aval, _get_named_sharding(False, None, x_sharding, x_aval, mesh), - use_shardy=True) # type: ignore[arg-type] + use_shardy=True) for x, x_aval, x_sharding in \ zip(results, ctx.avals_out, exported.out_shardings_hlo)) else: diff --git a/jax/_src/export/serialization.py b/jax/_src/export/serialization.py index 579552924be7..33502e88f9a6 100644 --- a/jax/_src/export/serialization.py +++ b/jax/_src/export/serialization.py @@ -98,11 +98,11 @@ def _serialize_exported( # _has_named_shardings. 
in_shardings = _serialize_array( builder, partial(_serialize_sharding, has_named_sharding=exp._has_named_shardings), - zip(exp._in_named_shardings, exp.in_shardings_hlo) # type: ignore + zip(exp._in_named_shardings, exp.in_shardings_hlo) ) out_shardings = _serialize_array( builder, partial(_serialize_sharding, has_named_sharding=exp._has_named_shardings), - zip(exp._out_named_shardings, exp.out_shardings_hlo) # type: ignore + zip(exp._out_named_shardings, exp.out_shardings_hlo) ) ordered_effects = _serialize_array( builder, _serialize_effect, exp.ordered_effects @@ -218,9 +218,9 @@ def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported: out_avals = _deserialize_tuple(exp.OutAvalsLength, exp.OutAvals, partial(_deserialize_aval, scope=scope, sharding=None)) in_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], in_shardings) - in_shardings = (None,) * len(in_shardings) # type: ignore + in_shardings = (None,) * len(in_shardings) out_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], out_shardings) - out_shardings = (None,) * len(out_shardings) # type: ignore + out_shardings = (None,) * len(out_shardings) platforms = _deserialize_tuple( exp.PlatformsLength, exp.Platforms, @@ -502,10 +502,10 @@ def _serialize_partition_spec(builder: flatbuffers.Builder, spec: partition_spec.PartitionSpec) -> int: partitions = _serialize_array(builder, _serialize_partition_spec_one_axis, spec._partitions) # pyrefly: ignore[bad-argument-type] - reduced = _serialize_array(builder, # type: ignore + reduced = _serialize_array(builder, lambda builder, ps: builder.CreateString(ps), spec.reduced) - unreduced = _serialize_array(builder, # type: ignore + unreduced = _serialize_array(builder, lambda builder, ps: builder.CreateString(ps), spec.unreduced) diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 575bfc752591..c611eac99bc0 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -911,7 +911,7 @@ def 
_divmod(self, divisor: DimSize) -> tuple[DimSize, DimSize]: return quotient, remainder except InconclusiveDimensionOperation: return (_DimExpr._from_operation(_DimFactor.FLOORDIV, self, divisor, - scope=self.scope), # type: ignore + scope=self.scope), _DimExpr._from_operation(_DimFactor.MOD, self, divisor, scope=self.scope)) @@ -1633,7 +1633,7 @@ def expr(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: t: Any t, tok = self.term(tok) t_sign = - t if next_t_negated else t - acc = acc + t_sign if acc is not None else t_sign # type: ignore[operator] + acc = acc + t_sign if acc is not None else t_sign if tok.exact_type in self.FOLLOW_EXPR: return acc, tok next_t_negated = (tok.exact_type == tokenize.MINUS) @@ -1655,7 +1655,7 @@ def term(self, tok: tokenize.TokenInfo) -> tuple[DimSize, tokenize.TokenInfo]: power, tok = self.integer(tok) f = f ** power - acc = acc * f if acc is not None else f # type: ignore[operator] + acc = acc * f if acc is not None else f if tok.exact_type in self.FOLLOW_TERM: return acc, tok # type: ignore[bad-return-type,unused-ignore] tok = self.consume_token(tok, tokenize.STAR) @@ -2030,7 +2030,7 @@ def compute_dim_vars_from_arg_shapes( } synthetic_eval = ShapeEvaluator(synthetic_env) shape_constraints.shape_assertions(synthetic_eval) - return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) # type: ignore[arg-type] + return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars) def _solve_dim_equations( eqns: list[_DimEquation], diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index 739f56e2b19a..eec29be1a4be 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -326,7 +326,7 @@ def _lowering( **lowering_args, )(ctx, *operands, **params) - return result.results # type: ignore + return result.results return _lowering diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index a313fcea4283..5686763ad960 100644 --- 
a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -64,9 +64,6 @@ from jax._src.numpy import linalg as jnp_linalg from jax._src import random as jax_random -# mypy generates a lot of false positive due to re-assigned variables. -# mypy: disable-error-code="assignment, no-redef" - # The code in this file relies on the values of some flags that are defined by # jtu. Note that the following can not always be moved to a test file since # then the test file has to import jtu first (to define the flags) which is not diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 46271a03564d..70f9bced572c 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -52,7 +52,7 @@ MakeIotaHandler = Callable[[AxisSize], Array] def to_elt(trace: BatchTrace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: - from jax._src import hijax # type: ignore + from jax._src import hijax handler = to_elt_handlers.get(type(x)) if handler: return handler(partial(to_elt, trace, get_idx), get_idx, x, spec) @@ -152,7 +152,7 @@ def _short_repr(self): @property def aval(self): - from jax._src import hijax # type: ignore + from jax._src import hijax aval = core.get_aval(self.val) if self._trace.axis_data.spmd_name is not None: if config._check_vma.value: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 8d8227d907ea..2ab2752d1108 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -66,7 +66,6 @@ import numpy as np -# mypy: ignore-errors map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -191,10 +190,10 @@ def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type: return ir_type_factory() def _array_ir_types(aval: core.ShapedArray) -> ir.Type: - aval = core.physical_aval(aval) # type: ignore + aval = core.physical_aval(aval) if not core.is_constant_shape(aval.shape): - return 
_dynamic_array_ir_types(aval) # type: ignore - return ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)) # type: ignore + return _dynamic_array_ir_types(aval) + return ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)) def _dynamic_array_ir_types(aval: core.ShapedArray) -> ir.Type: dyn_size = ir.ShapedType.get_dynamic_size() @@ -315,7 +314,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic, for ax in range(val.ndim))] out = hlo.broadcast_in_dim( ir.RankedTensorType.get( - val.shape, dtype_to_ir_type(collapsed_val.dtype)), # type: ignore + val.shape, dtype_to_ir_type(collapsed_val.dtype)), _numpy_array_constant(collapsed_val), dense_int_array(other_axes)) # type: ignore return out @@ -330,7 +329,7 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic, np.float16, np.float32, np.float64, np.complex64, np.complex128, np.bool_, np.longlong, dtypes.bfloat16]: - register_constant_handler(_scalar_type, _ndarray_constant_handler) # type: ignore + register_constant_handler(_scalar_type, _ndarray_constant_handler) def _python_scalar_handler(val, aval: core.AbstractValue | None): assert isinstance(aval, core.ShapedArray), aval @@ -398,7 +397,7 @@ def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute np.float16, np.float32, np.float64, np.complex64, np.complex128, np.bool_, np.longlong, dtypes.bfloat16]: - register_attribute_handler(_scalar_type, _numpy_array_attribute_handler) # type: ignore + register_attribute_handler(_scalar_type, _numpy_array_attribute_handler) def _dtype_attribute_handler(dtype: np.dtype | np.generic) -> ir.Attribute: return ir.TypeAttr.get(dtype_to_ir_type(dtype)) @@ -980,7 +979,7 @@ def sharded_aval(aval: core.AbstractValue, return aval if not isinstance(aval, core.ShapedArray): raise NotImplementedError - return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore + return aval.update(sharding.shard_shape(aval.shape), sharding=None) def 
eval_dynamic_shape(ctx: LoweringRuleContext, @@ -1092,7 +1091,7 @@ def _to_physical_op_sharding( axis_ctx.manual_axes): sharding = add_manual_axes(axis_ctx, sharding, aval.ndim) if config.use_shardy_partitioner.value: - return sharding._to_sdy_sharding(aval.ndim) # type: ignore + return sharding._to_sdy_sharding(aval.ndim) return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore @@ -1602,7 +1601,7 @@ def lower_jaxpr_to_fun( token_types = [token_type() for _ in effects] token_avals = [core.abstract_token] * num_tokens # Order of arguments: dim vars, tokens, const_args, array inputs - input_avals = dim_var_avals + token_avals + list(in_avals) # type: ignore + input_avals = dim_var_avals + token_avals + list(in_avals) input_types = [*dim_var_types, *token_types, *input_types] output_avals = [core.abstract_token] * num_tokens + jaxpr.out_avals output_types = [*token_types, *output_types] @@ -2508,7 +2507,7 @@ def f_lowered(ctx: LoweringRuleContext, *args, **params): wrapped_fun, ctx.avals_in, lower=True) if any(isinstance(e, core.InternalMutableArrayEffect) for e in jaxpr.effects): - from jax._src.interpreters import pxla # type: ignore + from jax._src.interpreters import pxla closed_jaxpr = core.ClosedJaxpr(jaxpr, consts_for_constvars) closed_jaxpr = pxla._discharge_internal_refs(closed_jaxpr) jaxpr, consts_for_constvars = closed_jaxpr.jaxpr, closed_jaxpr.consts @@ -2999,15 +2998,15 @@ def get_sharding_attr( sharding: xc.OpSharding | SdyArray | SdyArrayList ) -> ir.Attribute: if isinstance(sharding, (SdyArray, SdyArrayList)): - return sharding.build() # type: ignore + return sharding.build() else: # If there are very large numbers of devices, use the proto representation. # The MHLO to HLO conversion supports both, and the proto representation is # more compact. 
- if len(sharding.tile_assignment_devices) > 100: # type: ignore - return ir.StringAttr.get(sharding.SerializeToString()) # type: ignore + if len(sharding.tile_assignment_devices) > 100: + return ir.StringAttr.get(sharding.SerializeToString()) else: - return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding))) # type: ignore[arg-type] + return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding))) def wrap_with_layout_op(ctx: LoweringRuleContext, diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 51e517ae4d2f..35d37d0af3fb 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -264,7 +264,7 @@ def process_call(self, primitive, f: lu.WrappedFun, tracers, params, /): num_new_args = len(res_tracers) + len(env_tracers) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 new_jaxpr = convert_constvars_jaxpr(jaxpr) if isinstance(primitive, core.ClosedCallPrimitive): - new_jaxpr = close_jaxpr(new_jaxpr) # type: ignore + new_jaxpr = close_jaxpr(new_jaxpr) staged_params = dict(params, call_jaxpr=new_jaxpr) staged_params = update_params(staged_params, map(op.not_, in_knowns), num_new_args) @@ -794,7 +794,7 @@ def sort_key(t): env_vars, env_vals = unzip2(env.items()) invars = [*env_vars, *map(get_atom, in_tracers)] const_vars, const_vals = unzip2(consts.items()) - outvars = map(get_atom, out_tracers) # type: ignore[arg-type] + outvars = map(get_atom, out_tracers) jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns) is_high |= any(x.aval.is_high for x in it.chain(const_vars, invars, outvars)) # pyrefly: ignore[bad-argument-type] # pyrefly#2385 jaxpr = Jaxpr(const_vars, invars, # type: ignore[arg-type] @@ -1079,7 +1079,7 @@ def has_effects(effects) -> bool: foreach(partial(write, False, False), eqn.outvars) elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. 
- from jax._src.dispatch import device_put_p, ArrayCopySemantics # type: ignore + from jax._src.dispatch import device_put_p, ArrayCopySemantics resvars = [Var(v.aval.update(memory_space=core.mem_kind_to_space(policy.dst))) for v in eqn.outvars] offload_eqn = core.JaxprEqn( @@ -1604,7 +1604,7 @@ def move_binders_to_back(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] class DynamicJaxprTracer(core.Tracer): - __slots__ = ['aval', 'val', 'mutable_qdd', 'parent', '_debug_info'] + __slots__ = ['_aval', 'val', 'mutable_qdd', 'parent', '_debug_info'] _trace: DynamicJaxprTrace # pyrefly: ignore[bad-override] @@ -1623,7 +1623,7 @@ def __init__(self, trace: DynamicJaxprTrace, self._trace = trace self._line_info = line_info self._debug_info = self._trace.frame.debug_info # for UnexpectedTracerError - self.aval = aval # type: ignore[misc] + self._aval = aval self.val = val self.mutable_qdd = core.MutableQuasiDynamicData(qdd) self.parent = parent @@ -1634,6 +1634,10 @@ def _short_repr(self): def cur_qdd(self): return self.mutable_qdd.cur_val + @property + def aval(self): + return self._aval + @property def aval_mutable_qdd(self): aval = self.aval @@ -1724,7 +1728,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: f"`JaxprInputEffect` {eff} is invalid." 
f"\n Equation: {eqn}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") eqn_invar = eqn.invars[eff.input_index] if type(eqn_invar) is core.Literal or eqn_invar in mut_arrays: continue @@ -1740,7 +1744,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects: f"\n Equation: {eqn}\n" f"\n Effects: {eqn.effects}\n" "\n Jaxpr: " - f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") # type: ignore + f"{core.Jaxpr(constvars, invars, outvars, eqns, set(), dbg)}") eff = eff.replace(input_index=input_index) jaxpr_effects.add(eff) return jaxpr_effects @@ -2118,7 +2122,7 @@ def process_call(self, call_primitive, f: lu.WrappedFun, in_tracers, new_jaxpr = convert_constvars_jaxpr(jaxpr) if isinstance(call_primitive, core.ClosedCallPrimitive): - new_jaxpr = close_jaxpr(new_jaxpr) # type: ignore + new_jaxpr = close_jaxpr(new_jaxpr) new_params = dict(params, call_jaxpr=new_jaxpr) update_params = call_param_updaters.get(call_primitive) if update_params: @@ -2238,7 +2242,7 @@ def out_trees_(): fun_jaxpr.effects, source_info=source_info) - def process_custom_transpose(self, prim: core.Primitive, # type: ignore[override] + def process_custom_transpose(self, prim: core.Primitive, # pyrefly: ignore[bad-override] call: lu.WrappedFun, tracers, *, transpose: lu.WrappedFun, out_types, diff --git a/jax/_src/interpreters/remat.py b/jax/_src/interpreters/remat.py index 71970f205099..7abe8c7cfe4a 100644 --- a/jax/_src/interpreters/remat.py +++ b/jax/_src/interpreters/remat.py @@ -125,7 +125,7 @@ def _remat_jaxpr(jaxpr, policy): src = source_info_util.current() def new_arg(a): - return RematTracer(trace, fwd_trace.new_arg(a, src), rem_trace.new_arg(a, src)) # type: ignore # noqa: F821 + return RematTracer(trace, fwd_trace.new_arg(a, src), rem_trace.new_arg(a, src)) # noqa: F821 tracers = map(new_arg, jaxpr.in_aval_qdds) with core.set_current_trace(trace, 
check_leaks=True): diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index aa53b862714a..4b2364928fb8 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -214,7 +214,7 @@ def scan(f, init, xs, length=None): args_avals = args.map(core.get_aval) init_avals, xs_avals = args_avals.unpack() - from jax._src.hijax import HiType # type: ignore + from jax._src.hijax import HiType if any(isinstance(a, HiType) for a in xs_avals): if length is None: raise ValueError("must provide `length` to `scan`") @@ -2274,7 +2274,7 @@ def _while_to_lojax(*hi_args, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts out_mut, lo_outs = split_list(all_outs, [pe.num_himuts_out(body_jaxpr)]) pe.apply_himut(body_jaxpr, [*hi_bconsts, *hi_carry], out_mut) return pe.raise_lo_outs(body_jaxpr.out_avals, lo_outs) -while_p.to_lojax = _while_to_lojax # type: ignore +while_p.to_lojax = _while_to_lojax def _insert_binders(jaxpr, n_after, vals): avals = _map(typeof, vals) @@ -2463,9 +2463,9 @@ def fori_loop(lower, upper, body_fun, init_val): "are statically known.") if lower_dtype != dtype: - lower = lax.convert_element_type(lower, dtype) # type: ignore + lower = lax.convert_element_type(lower, dtype) if upper_dtype != dtype: - upper = lax.convert_element_type(upper, dtype) # type: ignore + upper = lax.convert_element_type(upper, dtype) while_body_fun = _fori_body_fun(body_fun, body_fun_dbg) _, _, result = while_loop(_fori_cond_fun, while_body_fun, (lower, upper, init_val)) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 89e9a1e27f70..b97a22b630af 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -50,8 +50,8 @@ unzip2) import numpy as np -unsafe_map, map = map, safe_map # type: ignore -unsafe_zip, zip = zip, safe_zip # type: ignore +unsafe_map, map = map, safe_map +unsafe_zip, zip = zip, safe_zip ### parallel traceables diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 
6686b0f379ee..1b30a197bf6f 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -173,7 +173,7 @@ def dynamic_slice( """ start_indices = _dynamic_slice_indices( operand, start_indices, allow_negative_indices) - sizes = core.canonicalize_shape(slice_sizes) # type: ignore + sizes = core.canonicalize_shape(slice_sizes) operand, *start_indices = core.standard_insert_pvary(operand, *start_indices) return dynamic_slice_p.bind(operand, *start_indices, slice_sizes=tuple(sizes)) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 931684ef8523..43a714f46a2d 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -487,7 +487,7 @@ def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]: if jaxpr.effects: raise NotImplementedError('Cannot lower effectful `reduce_window`.') out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, ctx.name_stack, - mlir.TokenSet(), consts, *reducer.arguments, # type: ignore[misc] + mlir.TokenSet(), consts, *reducer.arguments, dim_var_values=ctx.dim_var_values, const_lowering=ctx.const_lowering, outer_traceback=ctx.traceback) return mlir.flatten_ir_values(out_nodes) @@ -997,7 +997,7 @@ def snd(t, t_aval): def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]: x: ir.Value y: ir.Value - x, y = reducer.arguments # type: ignore + x, y = reducer.arguments assert select_prim is lax.ge_p or select_prim is lax.le_p cmp_op = "GE" if select_prim is lax.ge_p else "LE" out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), cmp_op), x, y) diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 2d3a23941b76..8e9bf8eb4f42 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -49,7 +49,7 @@ def __init__(self, major_to_minor: tuple[int, ...], def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): xla_layout = pjrt_layout._xla_layout() return Layout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types - xla_layout.tiling(), # type: 
ignore[arg-type] + xla_layout.tiling(), # pyrefly: ignore[bad-argument-type] xla_layout.element_size_in_bits()) def __repr__(self): diff --git a/jax/_src/lru_cache.py b/jax/_src/lru_cache.py index 7a09dc92a17b..c08ffb10fcf9 100644 --- a/jax/_src/lru_cache.py +++ b/jax/_src/lru_cache.py @@ -22,7 +22,7 @@ filelock: Any | None = None try: - import filelock # type: ignore[no-redef] + import filelock # pyrefly: ignore[missing-import] except ImportError: pass diff --git a/jax/_src/named_sharding.py b/jax/_src/named_sharding.py index 20ff7c9ad91e..d1215c561596 100644 --- a/jax/_src/named_sharding.py +++ b/jax/_src/named_sharding.py @@ -204,7 +204,7 @@ def is_fully_addressable(self) -> bool: raise ValueError('is_fully_addressable is not implemented for ' '`jax.sharding.AbstractMesh`.') # return False if addressable_device_list is empty. - return self._internal_device_list.is_fully_addressable # type: ignore + return self._internal_device_list.is_fully_addressable @property def _is_concrete(self) -> bool: @@ -228,7 +228,7 @@ def is_fully_replicated(self) -> bool: array_mapping = get_array_mapping(self.spec) mesh_shape = self.mesh.shape num_partitions = 1 - for name in array_mapping: # type: ignore + for name in array_mapping: num_partitions *= mesh_shape[name] return num_partitions == 1 @@ -416,7 +416,7 @@ def named_sharding_to_xla_hlo_sharding( replicated_mesh_axes = [] for i, (axis_name, axis_val) in enumerate(mesh_shape.items()): - if axis_name not in array_mapping: # type: ignore + if axis_name not in array_mapping: replicated_mesh_axes.append((i, axis_val)) if len(replicated_mesh_axes) == len(mesh_shape) and not special_axes: @@ -424,7 +424,7 @@ def named_sharding_to_xla_hlo_sharding( mesh_permutation = [] new_mesh_shape = [1] * num_dimensions - for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]): # type: ignore + for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]): new_mesh_shape[pos] *= mesh_shape[name] 
mesh_permutation.append(mesh_axis_pos[name]) diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 8f10f236bb7d..2ebf7ae65c9c 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -13,7 +13,6 @@ # limitations under the License. # pytype: skip-file -# mypy: disable-error-code=has-type """Define methods which are dynamically added to JAX's Arrays and Tracers. This is done dynamically in order to avoid circular imports. diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 3917a605b8fd..c3a4c244e72f 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -339,7 +339,7 @@ def einsum( # Enable other modules to override einsum_contact_path. # Indexed by the type of the non constant dimension -_poly_einsum_handlers = {} # type: ignore +_poly_einsum_handlers = {} def _default_poly_einsum_handler(*operands, **kwargs): dummy = collections.namedtuple('dummy', ['shape', 'dtype']) @@ -572,7 +572,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names): out_sharding = (_get_inverse_sharding(out_sharding, names, result_names) if out_sharding is not None and names != result_names else out_sharding) - dot_out_sharding = ({} if out_sharding is None else # type: ignore + dot_out_sharding = ({} if out_sharding is None else {'out_sharding': out_sharding}) operand = _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type=preferred_element_type, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 1b06de17eb58..bf7e55d5ab98 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3404,7 +3404,7 @@ def clip( if min is not None: arr = ufuncs.maximum(min, arr) if max is not None: - arr = ufuncs.minimum(max, arr) # type: ignore + arr = ufuncs.minimum(max, arr) return asarray(arr) @@ -6422,13 +6422,13 @@ def _auto_repeat(fun, a, repeats, axis, total_repeat_length, out_sharding): return auto_axes(partial(fun, repeats, 
axis=axis, total_repeat_length=total_repeat_length), out_sharding=out_sharding, - axes=out_sharding.mesh.explicit_axes # type: ignore + axes=out_sharding.mesh.explicit_axes )(a) else: return auto_axes( partial(fun, axis=axis, total_repeat_length=total_repeat_length), out_sharding=out_sharding, - axes=out_sharding.mesh.explicit_axes # type: ignore + axes=out_sharding.mesh.explicit_axes )(repeats, a) def _repeat(repeats, arr, *, axis: int, @@ -6454,7 +6454,7 @@ def _repeat(repeats, arr, *, axis: int, axis = _canonicalize_axis(axis, len(input_shape)) aux_axis = axis + 1 aux_shape: list[DimSize] = list(input_shape) - aux_shape.insert(aux_axis, operator.index(repeats) if core.is_constant_dim(repeats) else repeats) # type: ignore + aux_shape.insert(aux_axis, operator.index(repeats) if core.is_constant_dim(repeats) else repeats) arr = lax.broadcast_in_dim( arr, aux_shape, [i for i in range(len(aux_shape)) if i != aux_axis]) result_shape: list[DimSize] = list(input_shape) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index eb522aa56125..3bd966e2ca23 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -403,7 +403,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array: z: Array | None = None result: Array | None = None while n > 0: - z = arr if z is None else (z @ z) # type: ignore[operator] + z = arr if z is None else (z @ z) n, bit = divmod(n, 2) if bit: result = z if result is None else (result @ z) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index da054665b3b0..ca9857b62099 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -446,7 +446,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: del p, x shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape) y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype) - y, _ = control_flow.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) # type: ignore[misc] + y, _ = control_flow.scan(lambda 
y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll) return y diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 53a78cf6c8ac..4ce903e3f0e3 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -91,7 +91,7 @@ def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]: Promotes arguments to an inexact type.""" to_dtype, weak_type = dtypes.lattice_result_type(*args) - to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) # type: ignore[arg-type] + to_dtype_inexact = dtypes.to_inexact_dtype(to_dtype) return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] @@ -408,7 +408,7 @@ def shape(a: ArrayLike | SupportsShape) -> tuple[int, ...]: if hasattr(a, "__jax_array__"): a = a.__jax_array__() # NumPy dispatches to a.shape if available. - return np.shape(a) # type: ignore[arg-type] + return np.shape(a) @export diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index f120d386b5fd..ecd885c4df00 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -83,7 +83,7 @@ def _scatter_update(x: ArrayLike, idx: Index | tuple[Index, ...], normalize_indices=normalize_indices) if out_sharding is not None: return auto_axes(internal_scatter, out_sharding=out_sharding, - axes=out_sharding.mesh.explicit_axes # type: ignore + axes=out_sharding.mesh.explicit_axes )(x, y, dynamic_idx) return internal_scatter(x, y, tuple(dynamic_idx)) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index 8cd166bdbed7..4bdd1fb979ba 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -1171,9 +1171,9 @@ def get_grid_mapping( debug: bool = False, ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]: if dynamic_shapes_export_enabled(): - dim_check : Any = jax_core.is_dim + dim_check: Any = jax_core.is_dim else: - dim_check : Any = jax_core.is_constant_dim # type: ignore[no-redef] + dim_check: Any = jax_core.is_constant_dim assert all(i is None or dim_check(i) for i in grid_spec.grid) 
grid_mapping_grid = tuple( dynamic_grid_dim if ( @@ -1235,7 +1235,7 @@ def get_grid_mapping( _convert_block_spec_to_block_mapping, index_map_avals=index_map_avals, index_map_tree=index_map_tree, - grid=grid_mapping_grid, # type: ignore[arg-type] + grid=grid_mapping_grid, vmapped_dims=(), debug=debug, ), @@ -1258,7 +1258,7 @@ def get_grid_mapping( _convert_block_spec_to_block_mapping, index_map_avals=index_map_avals, index_map_tree=index_map_tree, - grid=grid_mapping_grid, # type: ignore[arg-type] + grid=grid_mapping_grid, vmapped_dims=(), debug=debug, ), @@ -1295,9 +1295,9 @@ def get_grid_mapping( def unzip_dynamic_grid_bounds( grid_spec: GridSpec) -> tuple[GridSpec, tuple[Any, ...]]: if dynamic_shapes_export_enabled(): - new_grid : Any = grid_spec.grid + new_grid: Any = grid_spec.grid else: - new_grid : Any = tuple(d if isinstance(d, int) else None for d in grid_spec.grid) # type: ignore[no-redef] + new_grid: Any = tuple(d if isinstance(d, int) else None for d in grid_spec.grid) dynamic_bounds = tuple(d for d in grid_spec.grid if not isinstance(d, int)) # We can't use dataclasses.replace, because our fields are incompatible # with __init__'s signature. diff --git a/jax/_src/pallas/fuser/block_spec.py b/jax/_src/pallas/fuser/block_spec.py index d19b3cf1e951..7241c5e80046 100644 --- a/jax/_src/pallas/fuser/block_spec.py +++ b/jax/_src/pallas/fuser/block_spec.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# mypy: ignore-errors # pyrefly: ignore-errors # TODO(sharadmv): Enable type checking. 
diff --git a/jax/_src/pallas/fuser/custom_fusion_lib.py b/jax/_src/pallas/fuser/custom_fusion_lib.py index 9b705d86260d..8e1cc3b094c6 100644 --- a/jax/_src/pallas/fuser/custom_fusion_lib.py +++ b/jax/_src/pallas/fuser/custom_fusion_lib.py @@ -226,7 +226,7 @@ def _custom_fusion_mosaic_lowering_rule( lowering_context, pallas_jaxpr, *pallas_consts, *args) -@block_spec_lib.register_pull_block_spec_rule(custom_fusion_p) # type: ignore[arg-type] +@block_spec_lib.register_pull_block_spec_rule(custom_fusion_p) def _custom_fusion_pull_block_spec_rule( ctx : block_spec_lib.PullRuleContext, out_block_transforms : tuple[block_spec_lib.BlockIndexTransform, ...], @@ -238,7 +238,7 @@ def _custom_fusion_pull_block_spec_rule( return pull_block_spec_rule(out_block_transforms) -@block_spec_lib.register_push_block_spec_rule(custom_fusion_p) # type: ignore[arg-type] +@block_spec_lib.register_push_block_spec_rule(custom_fusion_p) def _custom_fusion_push_block_spec_rule( ctx : block_spec_lib.PushRuleContext, *block_specs : pallas_core.BlockSpec, @@ -250,7 +250,7 @@ def _custom_fusion_push_block_spec_rule( return push_block_spec_rule(block_specs) -@block_spec_lib.register_usage_rule(custom_fusion_p) # type: ignore[arg-type] +@block_spec_lib.register_usage_rule(custom_fusion_p) def _custom_fusion_usage_rule( ctx : block_spec_lib.UsageRuleContext, used_out: Sequence[set[block_spec_lib.Usage]], diff --git a/jax/_src/pallas/fuser/fusible_dtype.py b/jax/_src/pallas/fuser/fusible_dtype.py index 8a1684bb0466..f4d2ac6a22ca 100644 --- a/jax/_src/pallas/fuser/fusible_dtype.py +++ b/jax/_src/pallas/fuser/fusible_dtype.py @@ -43,7 +43,6 @@ from jax._src.util import foreach # TODO(sharadmv): Enable type checking. 
-# mypy: ignore-errors map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip diff --git a/jax/_src/pallas/helpers.py b/jax/_src/pallas/helpers.py index 496e43e6fc0d..3b6af7f7214d 100644 --- a/jax/_src/pallas/helpers.py +++ b/jax/_src/pallas/helpers.py @@ -256,9 +256,9 @@ def kernel(in_ref, out_ref): name=name, metadata=metadata) if isinstance(body, api.NotSpecified): - return lambda fun: _make_kernel(fun, **kwds) # type: ignore[arg-type] + return lambda fun: _make_kernel(fun, **kwds) else: - return _make_kernel(body, **kwds) # type: ignore[arg-type] + return _make_kernel(body, **kwds) def with_scoped( diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index 643879bc9a44..7180468c69f0 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -97,7 +97,7 @@ def _dynamic_slice( output = slicing.dynamic_slice(value, start_idx, slice_sizes=block_shape) squeeze_dims = tuple(np.arange(len(is_squeeze))[np.array(is_squeeze, dtype=np.bool_)]) - return lax.squeeze(output, squeeze_dims) # type: ignore[arg-type] + return lax.squeeze(output, squeeze_dims) def _dynamic_update_slice(start_idx, block_shape, value, update, is_squeeze): diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py index 87cbd762ea19..298f0bdfc2a8 100644 --- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -140,7 +140,7 @@ def force_tpu_interpret_mode(params: InterpretParams = InterpretParams()): config.pallas_tpu_interpret_mode_context_manager.set_local(prev) def set_tpu_interpret_mode(params: InterpretParams = InterpretParams()): - config.pallas_tpu_interpret_mode_context_manager.set_global(params) # type: ignore[arg-type] + config.pallas_tpu_interpret_mode_context_manager.set_global(params) # TODO(jburnim): Do we want to support multiple instances of SharedMemory? 
@@ -1708,7 +1708,7 @@ def interpret_pallas_call( mosaic_params = mosaic_core.CompilerParams() else: assert isinstance(compiler_params, mosaic_core.CompilerParams) - mosaic_params = compiler_params # type: ignore[assignment] + mosaic_params = compiler_params del compiler_params args = [remove_memory_space_p.bind(a) for a in args] @@ -2150,7 +2150,7 @@ def _store_to_output_buffer(index, output_var, transform): assert len(next_start_indices[num_inputs + j].shape) == 1 transform = indexing.NDIndexer( indices=tuple( - indexing.ds(st, sz) if not iid else st # type: ignore[misc] + indexing.ds(st, sz) if not iid else st # pyrefly: ignore[bad-argument-type] for st, sz, iid in zip( cur_start_indices[num_inputs + j], block_shapes[num_inputs + j], diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 37b2aa54c348..144a9342b16c 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -82,7 +82,6 @@ import numpy as np # TODO(sharadmv): enable type checking -# mypy: ignore-errors NDIndexer = indexing.NDIndexer TPUMemorySpace = tpu_core.MemorySpace diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index 521bcb473563..d24a943a23d8 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -365,7 +365,7 @@ def pallas_call_tpu_lowering_rule( mosaic_params = tpu_core.CompilerParams() else: assert isinstance(compiler_params, tpu_core.CompilerParams) - mosaic_params = compiler_params # type: ignore[assignment] + mosaic_params = compiler_params del mesh jax_mesh = None @@ -454,7 +454,7 @@ def mpmd_map_tpu_lowering_rule( mosaic_params = tpu_core.CompilerParams() else: assert isinstance(compiler_params, tpu_core.CompilerParams) - mosaic_params = compiler_params # type: ignore[assignment] + mosaic_params = compiler_params # TODO(slebedev): Check kernel type and raise if it is set. 
if mosaic_params.dimension_semantics is not None: diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 5b21cc14dd49..0c37e35b63e6 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -1840,7 +1840,7 @@ def _partition_grid( # TODO(sharadmv): take the product of many nondivisible dimensions to # potentially divide it more evenly largest_parallel_dimension = max(grid[i] for i in parallel_dimensions - if isinstance(grid[i], int)) # type: ignore + if isinstance(grid[i], int)) partition_dimension, *_ = ( i for i, d in enumerate(grid) @@ -1894,9 +1894,9 @@ def sync_copy(src: REF | BufferedRef, dst: REF | BufferedRef, indices): ) if copy_in: tpu_helpers.sync_copy(hbm_ref.at[hbm_slice], - bref.current_ref.at[bref_slice]) # type: ignore[union-attr] + bref.current_ref.at[bref_slice]) else: - tpu_helpers.sync_copy(bref.current_ref.at[bref_slice], # type: ignore[union-attr] + tpu_helpers.sync_copy(bref.current_ref.at[bref_slice], hbm_ref.at[hbm_slice]) diff --git a/jax/_src/pallas/mosaic/sc_core.py b/jax/_src/pallas/mosaic/sc_core.py index 228ca1ac55b6..0fc1ecd8bd0d 100644 --- a/jax/_src/pallas/mosaic/sc_core.py +++ b/jax/_src/pallas/mosaic/sc_core.py @@ -59,34 +59,38 @@ def __init__( def get_ref_aval(self) -> state.TransformedRef | state.AbstractRef: # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we # try to apply JAX ops to it. 
- return AbstractRef(self.inner_aval, self.memory_space, self.tiling) + return AbstractRef(self.inner_aval, self.memory_space, tiling=self.tiling) class AbstractRef(state.AbstractRef): """An AbstractRef for SparseCore.""" - tiling: Tiling | None = None + tiling: Tiling | None def __init__( self, aval: jax_core.AbstractValue, memory_space: tpu_core.MemorySpace, - tiling: Tiling | None, + *, + kind: Any | None = None, + tiling: Tiling | None = None, ): - super().__init__(aval, memory_space) + super().__init__(aval, memory_space, kind) self.tiling = tiling - def update( # type: ignore[override] + def update( self, inner_aval: Any | None = None, memory_space: Any | None = None, + kind: Any | None = None, tiling: Tiling | None = None, ) -> AbstractRef: return AbstractRef( inner_aval if inner_aval is not None else self.inner_aval, memory_space if memory_space is not None else self.memory_space, - tiling if tiling is not None else self.tiling, + kind=kind if kind is not None else self.kind, + tiling=tiling if tiling is not None else self.tiling, ) diff --git a/jax/_src/pallas/mosaic/sc_lowering.py b/jax/_src/pallas/mosaic/sc_lowering.py index 5810c4ef4705..cf3a02fa510a 100644 --- a/jax/_src/pallas/mosaic/sc_lowering.py +++ b/jax/_src/pallas/mosaic/sc_lowering.py @@ -178,7 +178,7 @@ def lower_pipelined_jaxpr_into_module( if dimension_semantics is None: dimension_semantics = ("arbitrary",) * len(grid) # type: ignore dimension_semantics: Sequence[tpu_core.LiteralDimensionSemantics] = tuple( # pyrefly: ignore[redefinition] # pytype: disable=annotation-type-mismatch - map(tc_lowering._canonicalize_dimension_semantic, dimension_semantics) # type: ignore[arg-type] + map(tc_lowering._canonicalize_dimension_semantic, dimension_semantics) ) is_semaphore = [] @@ -1117,7 +1117,7 @@ def _jaxpr_call_lowering_rule( program_ids[axis] = user_grid_indices[axis] new_lowering_ctx = dataclasses.replace( ctx.lowering_context, - block_shapes=tuple(ref_block_shapes), # type: ignore + 
block_shapes=tuple(ref_block_shapes), user_grid_indices=program_ids, ) return tc_lowering.jaxpr_subcomp(new_lowering_ctx, jaxpr, *args) diff --git a/jax/_src/pallas/mosaic/tpu_info.py b/jax/_src/pallas/mosaic/tpu_info.py index 2ea709939ab6..e8b279bef563 100644 --- a/jax/_src/pallas/mosaic/tpu_info.py +++ b/jax/_src/pallas/mosaic/tpu_info.py @@ -89,7 +89,7 @@ def _num_physical_tensor_cores_per_chip(self) -> int: # pyrefly: ignore[bad-ret @property def num_physical_tensor_cores_per_chip(self) -> int: # TODO(slebedev): Remove this wrapper once pyrefly#2080 is fixed. - return cast(int, self._num_physical_tensor_cores_per_chip) # type: ignore[redundant-cast] + return cast(int, self._num_physical_tensor_cores_per_chip) @property def supports_megacore(self) -> bool: diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index 4be24c0ce8f5..6b646a3df51c 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -1078,7 +1078,7 @@ def to_block_mapping( ) block_inner_aval = bm.block_aval.inner_aval for t in self.transforms: - block_inner_aval = t.transform_type(block_inner_aval) # type: ignore[arg-type] + block_inner_aval = t.transform_type(block_inner_aval) return bm.replace( transformed_block_aval=bm.block_aval.update( inner_aval=block_inner_aval diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index e9db30e220c9..7b452c61dd19 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -76,7 +76,6 @@ # TODO(slebedev): Enable type checking. 
-# mypy: ignore-errors map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip @@ -4148,7 +4147,7 @@ def _jaxpr_call_lowering_rule( ref_aval = treedef.unflatten(ref_aval) if isinstance(ref, tuple): ref, transforms = ref - ref_aval, transform_avals = ref_aval # type: ignore + ref_aval, transform_avals = ref_aval # We ignore other transforms here, because they are already embedded # in the jaxpr. assert isinstance(ref_aval, state_types.AbstractRef) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index a9623d922565..e6a877801c0a 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -74,7 +74,7 @@ def pallas_call_lowering( gpu_params = gpu_core.CompilerParams() else: assert isinstance(compiler_params, gpu_core.CompilerParams) - gpu_params = compiler_params # type: ignore[assignment] + gpu_params = compiler_params jax_mesh = None axis_context = ctx.module_context.axis_context diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 906f46f639e7..58078f46fbbf 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -273,7 +273,7 @@ def pipeline(*gmem_refs: state.AbstractRef): in_smem_refs, out_smem_refs = util.split_list( [ gpu_core.SMEM( - (max_concurrent_steps, *_get_block_shape(spec)), # type: ignore + (max_concurrent_steps, *_get_block_shape(spec)), ref.dtype, transforms=tuple( gpu_core.batch_transform(t, 1) @@ -695,7 +695,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree): slots = max_concurrent_steps if has_seq_dim else 1 smem_allocs.append( gpu_core.SMEM( - (slots, *_get_block_shape(spec)), # type: ignore + (slots, *_get_block_shape(spec)), gmem_ref.dtype, transforms=getattr(spec, "transforms", ()), ) @@ -710,7 +710,7 @@ def _get_scoped_allocs(*gmem_refs: AbstractRefPytree): consumed_barrier_type: Any 
if collective_axes: consumed_barrier_type = functools.partial( - gpu_core.ClusterBarrier, collective_axes=collective_axes # type: ignore + gpu_core.ClusterBarrier, collective_axes=collective_axes ) else: consumed_barrier_type = gpu_core.Barrier diff --git a/jax/_src/pallas/pipelining/schedule_api.py b/jax/_src/pallas/pipelining/schedule_api.py index cebd0e3b4e60..e30abc9726c1 100644 --- a/jax/_src/pallas/pipelining/schedule_api.py +++ b/jax/_src/pallas/pipelining/schedule_api.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# mypy: ignore-errors # pyrefly: ignore-errors # pylint: disable=missing-function-docstring # pylint: disable=g-doc-args diff --git a/jax/_src/pallas/pipelining/schedulers.py b/jax/_src/pallas/pipelining/schedulers.py index f156c8bbd45d..92a554fa779f 100644 --- a/jax/_src/pallas/pipelining/schedulers.py +++ b/jax/_src/pallas/pipelining/schedulers.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# mypy: ignore-errors # pyrefly: ignore-errors # pytype: disable=invalid-annotation # pytype: disable=wrong-arg-types diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 8332505e6261..456e61906863 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -56,7 +56,6 @@ import numpy as np # TODO(sharadmv): Enable type checking. 
-# mypy: ignore-errors # pytype: skip-file _T = TypeVar("_T") diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index b1e772130e4e..ef5fe0017e7a 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -78,7 +78,7 @@ def pallas_call_lowering( triton_params = triton_core.CompilerParams() else: assert isinstance(compiler_params, triton_core.CompilerParams) - triton_params = compiler_params # type: ignore[assignment] + triton_params = compiler_params num_warps = 4 if triton_params.num_warps is None else triton_params.num_warps num_stages = triton_params.num_stages diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 02f52e973f75..7fd944df9491 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -502,7 +502,7 @@ def _trace_for_jit( assert None not in out_shardings_leaves in_type = avals_ft.map2( - lambda a, x: core.AvalQDD(a, cur_qdd(x)) if a.has_qdd else a, # type: ignore + lambda a, x: core.AvalQDD(a, cur_qdd(x)) if a.has_qdd else a, args_ft) assert avals_ft is not None @@ -754,7 +754,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, "pjit arguments", allow_uneven_sharding=False) check_aval_layout_compatibility( in_layouts_flat, in_avals, - dbg.safe_arg_names(len(in_avals)), "jit arguments") # type: ignore[arg-type] + dbg.safe_arg_names(len(in_avals)), "jit arguments") return in_shardings_flat, in_layouts_flat @util.cache(max_size=4096, trace_context_in_key=False) @@ -1090,7 +1090,7 @@ class MetaTy: committed: bool is_np_array: bool - replace = replace # type: ignore + replace = replace @property def shape(self): @@ -1370,9 +1370,9 @@ def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str, const_args, const_arg_avals = util.unzip2(const_args_and_avals) in_avals = (*const_arg_avals, *jaxpr.in_avals) ca_shardings = const_args_shardings(const_args) - in_shardings = ca_shardings + in_shardings 
# type: ignore + in_shardings = ca_shardings + in_shardings ca_layouts = const_args_layouts(const_args, const_arg_avals, ca_shardings) - in_layouts = ca_layouts + in_layouts # type: ignore + in_layouts = ca_layouts + in_layouts func = _pjit_cached_lower_jaxpr_to_fun( ctx, name, jaxpr, len(const_args), in_avals, @@ -1473,7 +1473,7 @@ def _pjit_batcher_for_sharding( s.mesh, pxla.batch_spec(s.spec, dim, PartitionSpec.UNCONSTRAINED)) new_op = hlo_s.to_proto().clone() tad = list(new_op.tile_assignment_dimensions) - tad.insert(dim, 1) # type: ignore + tad.insert(dim, 1) new_op.tile_assignment_dimensions = tad new_gs = GSPMDSharding(s._internal_device_list, new_op) return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0] @@ -1773,7 +1773,7 @@ def keep_where(l, should_keep): source_info_util.current()) for t in unknown_tracers_out: t.recipe = eqn if effects.partial_eval_kept_effects.filter_in(unknown_jaxpr.effects): - trace.effect_handles.append(pe.EffectHandle(unknown_tracers_in, eqn)) # type: ignore + trace.effect_handles.append(pe.EffectHandle(unknown_tracers_in, eqn)) return merge_lists(unknown_outs, known_out_vals, unknown_tracers_out) pe.custom_partial_eval_rules[jit_p] = _pjit_partial_eval @@ -1836,7 +1836,7 @@ def _pjit_transpose_fancy( compiler_options_kvs): primals_ctrefs, specs = ad.project_accums(args) in_flat, in_tree = tree_flatten((primals_ctrefs, cts_in)) - in_avals = [core.AvalQDD(a, cur_qdd(x)) if (a := typeof(x)).has_qdd # type: ignore + in_avals = [core.AvalQDD(a, cur_qdd(x)) if (a := typeof(x)).has_qdd else a for x in in_flat] trans_jaxpr, out_tree = _transpose_jaxpr_fancy(jaxpr, in_tree, (*in_avals,), specs) @@ -2449,7 +2449,7 @@ def _layout_constraint_impl(x, *, layout): raise ValueError( 'with_layout_constraint in eager mode can only be applied to' f' jax.Arrays. 
Got {type(x)}') - if x.format.layout == layout: # type: ignore + if x.format.layout == layout: return x return api.jit(_identity_fn, out_shardings=Format(layout, x.sharding))(x) layout_constraint_p.def_impl(_layout_constraint_impl) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 3bccd8099089..dde6bdde16e9 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -316,7 +316,7 @@ def __getstate__(self): # Overwritten immediately below @property - def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override] + def at(self) -> _IndexUpdateHelper: assert False # pyrefly: ignore[bad-override] @property def T(self) -> PRNGKeyArray: assert False def __getitem__(self, key) -> PRNGKeyArray: assert False diff --git a/jax/_src/random.py b/jax/_src/random.py index 5da171c269af..4f3669413ef7 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -692,7 +692,7 @@ def permutation(key: ArrayLike, def _permutation(key, x, axis, independent): if independent or np.ndim(x) == 1: return _shuffle(key, x, axis) - ind = _shuffle(key, jnp.arange(x.shape[axis]), 0) # type: ignore[union-attr] + ind = _shuffle(key, jnp.arange(x.shape[axis]), 0) return jnp.take(x, ind, axis, unique_indices=True) diff --git a/jax/_src/shard_map.py b/jax/_src/shard_map.py index 992d9122ba9e..2b3115b4a27e 100644 --- a/jax/_src/shard_map.py +++ b/jax/_src/shard_map.py @@ -404,7 +404,7 @@ def _shmap_checks(mesh, axis_names, in_specs, out_specs, _smap): def _manual_spec(manual_axes, spec: P, mesh) -> P: - out: list[str | tuple[str, ...] | None] = [] # type: ignore + out: list[str | tuple[str, ...] | None] = [] s: str | None | tuple[str, ...] 
for s in spec: if s is None: @@ -427,7 +427,7 @@ def _manual_spec(manual_axes, spec: P, mesh) -> P: SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) def _check_unreduced(error_type, mesh, manual_axes, specs): - from jax._src.hijax import HipSpec # type: ignore + from jax._src.hijax import HipSpec prefix = 'in' if error_type == SpecErrorType.input else 'out' full_manual = frozenset(mesh.axis_names) == manual_axes specs_flat, _ = tree_flatten(specs) @@ -455,7 +455,7 @@ def _check_unreduced(error_type, mesh, manual_axes, specs): def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None: - from jax._src.hijax import HipSpec # type: ignore + from jax._src.hijax import HipSpec if error_type == SpecErrorType.input and specs is None: raise TypeError( "shard_map in_specs argument must be a pytree of " @@ -794,7 +794,7 @@ def _shard_map_staging( _check_vmas(mesh, out_specs_thunk(), [v.aval for v in jaxpr.outvars]) out_avals = [unshard_aval(mesh, check_vma, spec, aval) for spec, aval in zip(out_specs_thunk(), out_avals_)] - in_specs_staged = (P(),) * len(consts) + tuple(in_specs) # type: ignore + in_specs_staged = (P(),) * len(consts) + tuple(in_specs) with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh), config._check_vma(check_vma)): jaxpr = pe.convert_constvars_jaxpr(jaxpr) @@ -1252,7 +1252,7 @@ class _SpecError(Exception): pass def _check_vmas(mesh, specs, avals): - fail = [a.vma if isinstance(sp, P) and not _valid_repeats(mesh, a.vma, sp) # type: ignore + fail = [a.vma if isinstance(sp, P) and not _valid_repeats(mesh, a.vma, sp) else no_fail for sp, a in zip(specs, avals)] if any(f is not no_fail for f in fail): raise _RepError(fail) @@ -1874,7 +1874,7 @@ def _partial_eval_jaxpr_custom_rule( _, out_binders_staged = partition_list(inst_out, eqn.outvars) nv = core.gensym() all_names = _all_newly_manual_mesh_names(mesh, manual_axes) - lns = lambda a: a.nospec(mesh, check_vma, all_names) # type: ignore + lns = lambda a: 
a.nospec(mesh, check_vma, all_names) residuals, staged_in_res_specs = unzip2( [(nv(unshard_aval(mesh, check_vma, (rn := lns(var.aval)), var.aval)), rn) for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]) diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index deb8a68773fb..4326e3f54c2f 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -39,7 +39,7 @@ def _addressable_devices_indices_map( if sharding.is_fully_addressable: return global_map return {d: global_map[d] - for d in sharding._internal_device_list.addressable_device_list} # type: ignore + for d in sharding._internal_device_list.addressable_device_list} @cache(max_size=4096, trace_context_in_key=False) def common_devices_indices_map( @@ -184,7 +184,7 @@ def has_addressable_devices(self) -> bool: def _addressable_device_assignment(self) -> XLADeviceAssignment: if self.is_fully_addressable: return self._device_assignment - return tuple(self._internal_device_list.addressable_device_list) # type: ignore + return tuple(self._internal_device_list.addressable_device_list) def shard_shape(self, global_shape: Shape) -> Shape: """Returns the shape of the data on each device. @@ -203,7 +203,7 @@ def is_equivalent_to(self: Sharding, other: Sharding, ndim: int) -> bool: try: return (are_hlo_shardings_equal(self._to_xla_hlo_sharding(ndim), other._to_xla_hlo_sharding(ndim)) - and self._internal_device_list == other._internal_device_list and # type: ignore + and self._internal_device_list == other._internal_device_list and self.memory_kind == other.memory_kind) # NotImplementedError is raised by PmapSharding because it can't lower # to OpSharding. 
So if `other` is a PmapSharding, default to a strict diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py index e498923fa4b6..1c2d7bddd488 100644 --- a/jax/_src/sharding_specs.py +++ b/jax/_src/sharding_specs.py @@ -143,7 +143,7 @@ def spec_to_indices(shape: Sequence[int], int, a slice object with step=1, or a tuple thereof, to be treated as an index into the full logical array. """ - return tuple(spec.indices(shape).flat) # type: ignore + return tuple(spec.indices(shape).flat) def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int], diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index bd7273e02390..837320c1afcc 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -252,7 +252,7 @@ def _eval_jaxpr_discharge_state( f"Did not ask for inval to be discharged but it was. ({invar=}," f" {new_inval=})" ) - env.write(invar, new_inval) # type: ignore[arg-type] + env.write(invar, new_inval) # pyrefly: ignore[bad-argument-type] else: # Default primitive rule, similar to `core.eval_jaxpr`. Note that here # we assume any higher-order primitives inside of the jaxpr are *not* diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 8295b9a998d3..a497c27f439a 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -232,7 +232,7 @@ def from_indices_shape(cls, indices, shape) -> NDIndexer: # and this module. 
from jax._src.state import primitives as sp # pytype: disable=import-error other_indexers = [ - sp.broadcast_to(i, indexer_shape) for i in other_indexers # type: ignore[arg-type] + sp.broadcast_to(i, indexer_shape) for i in other_indexers # pyrefly: ignore[bad-argument-type] ] indices = tuple( merge_lists(is_slice_indexing, other_indexers, slice_indexers) diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index ad48cf00ef0b..146c0ef6aa68 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -82,7 +82,7 @@ def _get_to_lojax(ref, *idx, tree): idx = transforms[-1] return val_ty.ref_get_to_lojax(ref, idx) return val_ty.raise_val(*map(ref_get, val_ty.lower_val(ref._refs))) -get_p.to_lojax = _get_to_lojax # type: ignore +get_p.to_lojax = _get_to_lojax Indexer = Union[int, slice, Array, types.EllipsisType] @@ -190,7 +190,7 @@ def _swap_to_lojax(ref, val, *idx, tree): outs = [ref_swap(lo_ref, idx, lo_val) for lo_ref, lo_val in zip(lo_refs, lo_vals)] return val_ty.raise_val(*outs) -swap_p.to_lojax = _swap_to_lojax # type: ignore +swap_p.to_lojax = _swap_to_lojax @partial(traceback_util.api_boundary, repro_api_name="jax.ref.swap") @@ -767,7 +767,7 @@ def _batch_indexer( idx = lax.broadcast_in_dim(idx, new_integer_indexer_shape, bcast_dims) else: - idx = batching.moveaxis(idx, dim, 0) # type: ignore[arg-type] + idx = batching.moveaxis(idx, dim, 0) new_indices.append(idx) else: if ref_dim is not batching.not_mapped: diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 11b5d42b1020..cbaadddc238d 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -315,7 +315,7 @@ def count_subjaxpr_to_hlo_conversion(fun_name): counts = collections.Counter() thread_local_state.lower_jaxpr_to_fun_counts = counts def get(): - key, *others = {k for k in counts if fun_name in k} # type: ignore + key, *others = {k for k in counts if fun_name in k} if others: raise Exception(f"ambiguous name: {fun_name}") return counts[key] try: 
@@ -1540,7 +1540,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]: local_devices = list(xla_bridge.local_devices()) if len(local_devices) < size: raise unittest.SkipTest(f"Test requires {size} local devices") - mesh_devices = np.array(local_devices[:size]).reshape(shape) # type: ignore + mesh_devices = np.array(local_devices[:size]).reshape(shape) with mesh_lib.Mesh(mesh_devices, axis_names): yield diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index fa4ee83f6464..231aab9f0c0d 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -14,7 +14,6 @@ """JAX bindings for Mosaic.""" -# mypy: ignore-errors from __future__ import annotations import base64 diff --git a/jax/_src/traceback_util.py b/jax/_src/traceback_util.py index e26b1d0e0ec8..1306a4d1d9ab 100644 --- a/jax/_src/traceback_util.py +++ b/jax/_src/traceback_util.py @@ -240,5 +240,5 @@ def reraise_with_filtered_traceback(*args, **kwargs): repro_is_enabled = repro.is_enabled except ImportError: - repro = None # type: ignore - def repro_is_enabled(): return False # type: ignore + repro = None + def repro_is_enabled(): return False diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index b5c9be44e110..95c19d6c3fcc 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -1550,13 +1550,13 @@ def filter_statics_from_treedef(registry, treedef, statics): filtered = tuple( filter_statics_from_treedef(registry, td, s) for td, s in zip(treedef.children(), statics) if s is not True) - return treedef.from_node_data_and_children(registry, treedef.node_data(), filtered) # type: ignore + return treedef.from_node_data_and_children(registry, treedef.node_data(), filtered) elif isinstance(statics, dict): - ty, keys = treedef.node_data() # type: ignore + ty, keys = treedef.node_data() filtered_keys, filtered_subtrees = unzip2( (k, filter_statics_from_treedef(registry, td, statics[k])) for td, k in zip(treedef.children(), keys) if statics[k] is 
not True) - return treedef.from_node_data_and_children(registry, (ty, filtered_keys), filtered_subtrees) # type: ignore + return treedef.from_node_data_and_children(registry, (ty, filtered_keys), filtered_subtrees) else: assert False, "unreachable" diff --git a/jax/_src/util.py b/jax/_src/util.py index 16fd056f4ce7..ac3fb431d2da 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -320,9 +320,9 @@ def weakref_lru_cache( return _weakref_lru_cache(f, **kwargs) def _weakref_lru_cache(f, maxsize, trace_context_in_key, explain): - cached_f = lib_weakref_lru_cache.weakref_lru_cache( # type: ignore - config.trace_context if trace_context_in_key else _ignore, f, maxsize, # type: ignore - explain = lambda: explain if config.explain_cache_misses.value else None) # type: ignore + cached_f = lib_weakref_lru_cache.weakref_lru_cache( + config.trace_context if trace_context_in_key else _ignore, f, maxsize, + explain = lambda: explain if config.explain_cache_misses.value else None) register_cache(cached_f, str(f)) return cached_f diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index cc18ff13d183..fdf852bb836b 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -392,7 +392,7 @@ def is_consumed(var: core.Atom): if is_key(v) and np.any(consumed.get(v, False))), *(Source(i) for i, v in enumerate(jaxpr.outvars) if is_key(v) and resolve_forwards(v) not in all_inputs and not consumed.get(v, False)), - *(Forward(all_inputs.index(resolve_forwards(outvar)), idx_out) # type: ignore[arg-type] + *(Forward(all_inputs.index(resolve_forwards(outvar)), idx_out) for idx_out, outvar in enumerate(jaxpr.outvars) if is_key(outvar) and resolve_forwards(outvar) in all_inputs) ) diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index dea524e42ac3..f8a2ca1288f7 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py 
@@ -259,7 +259,7 @@ def mha( block_q=block_q, block_k=block_k, head_dim=head_dim, causal=causal) - in_specs = [ + in_specs: list[pl.BlockSpec | None] = [ pl.BlockSpec((None, block_q, None, head_dim_padded), lambda i, j, k: (j, i, k, 0)), pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), @@ -268,7 +268,7 @@ def mha( lambda _, j, k: (j, 0, k, 0)), ] in_specs.append( - None # type: ignore[arg-type] + None if segment_ids is None else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) ) @@ -567,7 +567,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, jax.ShapeDtypeStruct(v.shape, v.dtype), ] - in_specs = [ + in_specs: list[pl.BlockSpec | None] = [ pl.BlockSpec((None, q_seq_len, None, head_dim_padded), lambda i, j, _: (i, 0, j, 0)), pl.BlockSpec((None, kv_seq_len, None, head_dim_padded), @@ -582,7 +582,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes, pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: - in_specs.insert(3, None) # type: ignore[arg-type] + in_specs.insert(3, None) else: in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0))) From d36ae3ba7decdcf6cecc745ea9b2c878b64f46ef Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Wed, 4 Mar 2026 02:46:30 -0800 Subject: [PATCH 060/100] [Mosaic GPU][NFC] Add a `is_wg_semantics` method to `mosaic_gpu_test`. 
PiperOrigin-RevId: 878385486 --- tests/pallas/mosaic_gpu_test.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 812a6ad2bc69..03d37368474f 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -143,8 +143,11 @@ def setUp(self, *, artificial_shared_memory_limit=jtu._SMEM_SIZE_BOUND_FOR_TESTS super().setUp() self.enter_context(mgpu.core.artificial_shared_memory_limit(artificial_shared_memory_limit)) + def is_wg_semantics(self): + return self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup + def skip_if_wg_semantics(self): - if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + if self.is_wg_semantics(): self.skipTest("Not supported under WG semantics") def kernel(self, *args, **kwargs): @@ -175,7 +178,7 @@ def capture_stdout(self): def default_transforms( self, *, swizzle: int = 128, dtype: jnp.dtype ) -> Sequence[plgpu.Transform]: - if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + if self.is_wg_semantics(): return () swizzle_elems = 8 * swizzle // dtypes.itemsize_bits(dtype) return ( @@ -598,7 +601,7 @@ def test_inline_mgpu(self, jnp_type): plgpu.SwizzleTransform(128), ) - if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + if self.is_wg_semantics(): pallas_call_transforms = () else: pallas_call_transforms = transforms @@ -1863,7 +1866,7 @@ def body(acc): ) _ = jax.lax.while_loop(cond, body, strided_input) - if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Warpgroup: + if self.is_wg_semantics(): with self.assertRaisesRegex( ValueError, "Failed to infer a possible set of layouts", ): From a9a9ce4472cb53743069833987f2c9a5924e156b Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 4 Mar 2026 05:33:33 -0800 Subject: [PATCH 061/100] [Mosaic GPU] Add basic support for atomic reductions while storing PiperOrigin-RevId: 878440428 --- .../mosaic/gpu/fragmented_array.py | 58 
++++++++++++++++++- tests/mosaic/gpu_test.py | 21 +++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 70428c585c5c..fb2c391c25a5 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -21,7 +21,7 @@ import functools import itertools import math -from typing import Any, Protocol, TypeAlias, TypeVar, cast, overload, runtime_checkable +from typing import Any, Literal, Protocol, TypeAlias, TypeVar, cast, overload, runtime_checkable import jax import jax.experimental.mosaic.gpu as mgpu @@ -3021,7 +3021,12 @@ def _(val, idx): utils.debug_print(fmt_str, *idx, val, uniform=False) def store_untiled( - self, ref: ir.Value | utils.MultimemRef, *, swizzle: int = 16, optimized: bool = True + self, + ref: ir.Value | utils.MultimemRef, + *, + swizzle: int = 16, + optimized: bool = True, + atomic: Literal["add"] | None = None, ) -> None: if not isinstance(ref.type, ir.MemRefType): raise ValueError(ref) @@ -3029,12 +3034,18 @@ def store_untiled( case WGSplatFragLayout(): if isinstance(ref, utils.MultimemRef): raise NotImplementedError("Splat layout does not support multimem") + if atomic is not None: + raise NotImplementedError( + "Atomic stores not supported for splat layout" + ) # All values are the same so swizzle does not affect anything here. 
self._store_untiled_splat(ref) case WGStridedFragLayout(): if swizzle != 16: raise ValueError("Only TiledLayouts support swizzling") assert isinstance(self.layout, WGStridedFragLayout) + if atomic is not None: + raise NotImplementedError("Atomic stores not supported for warpgroup strided layouts") for get, _update, ref, idx in self.transfer_strided(ref, self.layout.vec_size): if isinstance(ref, utils.MultimemRef): ref.store(get(self.registers), idx) @@ -3043,7 +3054,7 @@ def store_untiled( case TiledLayout(): ref_shape = ir.MemRefType(ref.type).shape ref = utils.memref_reshape(ref, (*(1 for _ in ref_shape), *ref_shape)) - self.store_tiled(ref, swizzle=swizzle, optimized=optimized) + self.store_tiled(ref, swizzle=swizzle, optimized=optimized, atomic=atomic) case _: raise NotImplementedError(self.layout) @@ -3193,10 +3204,51 @@ def store_tiled( swizzle: int | None, optimized: bool = True, tiling_rank: int | None = None, + atomic: Literal["add"] | None = None, ): if not isinstance(self.layout, TiledLayout): raise NotImplementedError(self.layout) layout, shape = self.layout, self.shape + if atomic is not None: + if isinstance(ref, utils.MultimemRef): + raise NotImplementedError("Multimem refs do not support atomic stores") + if any(isinstance(d, Replicated) for d in layout.warp_dims + layout.lane_dims): + raise NotImplementedError( + "Atomic stores not supported for layouts with replicated dims" + ) + is_smem = utils.is_smem_ref(ref) + scope = "cta" if is_smem else "gpu" + space = ".shared::cta" if is_smem else "" + ptr_constraint = "r" if is_smem else "l" + stores = self.transfer_tiled( + ref, swizzle, layout, shape, optimized, ref_tiling_rank=tiling_rank + ) + i32 = ir.IntegerType.get_signless(32) + element_type = self.mlir_dtype + element_bitwidth = utils.bitwidth(element_type) + if isinstance(element_type, ir.F32Type): + ptx_type = "f32" + elif isinstance(element_type, ir.IntegerType) and element_bitwidth == 32: + ptx_type = "s32" if self.is_signed else "u32" + 
else: + raise NotImplementedError( + f"Unsupported element type for atomic stores: {element_type}" + ) + for get, _update, _idx, base_ptr in stores: + vreg = get(self.registers) + [vec_len] = vreg.type.shape + for i in range(vec_len): + assert element_bitwidth == 32 # Not implemented otherwise + reg = llvm.extractelement(vreg, arith.constant(i32, i)) + ptr = utils.getelementptr(base_ptr, [i], element_type) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [ptr, reg], + f"red{space}.relaxed.{scope}.{atomic}.{ptx_type} [$0], $1;", + f"{ptr_constraint},r", + has_side_effects=True, + ) + return # Note that the loop below will "race" for layouts that replicate data. # However, in that case all of the racing writes store the same data, which # is ok in the CUDA memory model. diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1c5ffba1d11d..a8052eed081e 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -580,6 +580,27 @@ def kernel(ctx, out, _): )() np.testing.assert_array_equal(iota, expected) + @parameterized.product(dtype=[jnp.float32, jnp.uint32, jnp.int32]) + def test_atomic_store(self, dtype): + m, n = 64, 64 + def kernel(ctx, out, smem): + del ctx + mlir_dtype = utils.dtype_to_ir_type(dtype) + mgpu.FragmentedArray.splat( + c(0, mlir_dtype), (m, n), is_signed=utils.is_signed(dtype) + ).store_untiled(smem) + gpu.barrier() + iota_tensor(m, n, dtype).store_untiled( + smem, optimized=False, atomic="add", + ) + gpu.barrier() + copy(smem, out) + x = np.arange(m * n, dtype=dtype).reshape(m, n) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (256, 1, 1), (), x, x + )() + np.testing.assert_array_equal(result, 2 * x) + @parameterized.product( dtype=[jnp.float8_e5m2fnuz, jnp.float8_e5m2, jnp.float8_e4m3b11fnuz, jnp.float8_e4m3fn, jnp.float8_e4m3fnuz], From 1c48067683735a73612dc7d770b77373a5b687f6 Mon Sep 17 00:00:00 2001 From: Alexandros Theodoridis Date: Tue, 3 Mar 2026 11:00:28 +0000 Subject: [PATCH 062/100] Add ROCm wheel build 
and test pipeline to continuous CI - Add jax-rocm-plugin and jax-rocm-pjrt to allowed artifacts in build_artifacts.sh with ROCm version flag passthrough. - Create build_rocm_artifacts.yml reusable workflow that builds ROCm wheels in a ROCm container and uploads them to S3 via OIDC. - Extend wheel_tests_continuous.yml with build-rocm-artifacts, run-pytest-rocm, and run-bazel-test-rocm jobs. --- .github/workflows/build_rocm_artifacts.yml | 168 +++++++++++++++++++ .github/workflows/wheel_tests_continuous.yml | 85 +++++++++- build/rocm/rocm.bazelrc | 3 + ci/utilities/run_auditwheel.sh | 2 +- jaxlib/jax.bzl | 21 +++ jaxlib/rocm/rocm_version.bzl | 29 ++++ jaxlib/tools/BUILD.bazel | 30 +++- 7 files changed, 326 insertions(+), 12 deletions(-) create mode 100644 .github/workflows/build_rocm_artifacts.yml create mode 100644 jaxlib/rocm/rocm_version.bzl diff --git a/.github/workflows/build_rocm_artifacts.yml b/.github/workflows/build_rocm_artifacts.yml new file mode 100644 index 000000000000..f407d35c339a --- /dev/null +++ b/.github/workflows/build_rocm_artifacts.yml @@ -0,0 +1,168 @@ +# CI - Build ROCm Artifacts +# This workflow builds ROCm wheels (jax-rocm-plugin, jax-rocm-pjrt) in a ROCm container +# and uploads them to an S3 bucket. It can be triggered manually via workflow_dispatch or +# called by other workflows via workflow_call. +name: CI - Build ROCm Artifacts + +on: + workflow_dispatch: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: choice + default: "linux-x86-64-1gpu-amd" + options: + - "linux-x86-64-1gpu-amd" + artifact: + description: "Which ROCm artifact to build?" + type: choice + default: "jax-rocm-plugin" + options: + - "jax-rocm-plugin" + - "jax-rocm-pjrt" + python: + description: "Which python version should the artifact be built for?" + type: choice + default: "3.12" + options: + - "3.11" + - "3.12" + - "3.13" + - "3.14" + rocm-version: + description: "Which ROCm version to build for?" 
+ type: string + default: "7" + clone_main_xla: + description: "Should latest XLA be used?" + type: choice + default: "1" + options: + - "1" + - "0" + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + default: 'no' + options: + - 'yes' + - 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + default: "linux-x86-64-1gpu-amd" + artifact: + description: "Which ROCm artifact to build?" + type: string + default: "jax-rocm-plugin" + python: + description: "Which python version should the artifact be built for?" + type: string + default: "3.12" + rocm-version: + description: "Which ROCm version to build for?" + type: string + default: "7" + clone_main_xla: + description: "Should latest XLA be used?" + type: string + default: "1" + upload_artifacts_to_s3: + description: "Should the artifacts be uploaded to S3?" + default: true + type: boolean + s3_upload_uri: + description: "S3 location prefix to where the artifacts should be uploaded" + default: 's3://jax-ci-amd/rocm-wheels' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' 
+ type: string + default: 'no' + outputs: + s3_upload_uri: + description: "S3 location prefix to where the artifacts were uploaded" + value: ${{ jobs.build-artifacts.outputs.s3_upload_uri }} + +permissions: + id-token: write + contents: read + actions: read + +jobs: + build-artifacts: + defaults: + run: + shell: bash + runs-on: ${{ inputs.runner }} + container: + image: "ghcr.io/rocm/jax-manylinux_2_28-rocm-${{ inputs.rocm-version }}:latest" + volumes: + - /data:/data + options: >- + --device=/dev/kfd + --device=/dev/dri + --security-opt seccomp=unconfined + --group-add video + --shm-size 64G + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + + name: "${{ inputs.artifact }}, py ${{ inputs.python }}, ROCm ${{ inputs.rocm-version }}" + + outputs: + s3_upload_uri: ${{ steps.store-s3-upload-uri.outputs.s3_upload_uri }} + + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - name: Install Bazelisk + run: | + curl -fSsL -o /usr/local/bin/bazel https://github.com/bazelbuild/bazelisk/releases/latest/download/bazelisk-linux-amd64 + chmod +x /usr/local/bin/bazel + - name: ROCm Info + run: rocminfo + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Build ${{ inputs.artifact }} + timeout-minutes: 120 + run: | + bazel --bazelrc=build/rocm/rocm.bazelrc run \ + --config=rocm_release_wheel \ + --config=rocm_rbe \ + --repo_env=HERMETIC_PYTHON_VERSION="${{ inputs.python }}" \ + $DEPLOY_TARGET -- $(pwd)/dist/ + env: + DEPLOY_TARGET: ${{ inputs.artifact == 'jax-rocm-plugin' && '//jaxlib/tools:deploy_rocm_plugin_wheel' || '//jaxlib/tools:deploy_rocm_pjrt_wheel' }} + - name: Verify manylinux compliance + run: ./ci/utilities/run_auditwheel.sh + env: + JAXCI_OUTPUT_DIR: ${{ github.workspace }}/dist + - name: Configure AWS Credentials + if: 
${{ inputs.upload_artifacts_to_s3 }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::661452401056:role/jax-ci-amd-s3-oidc + aws-region: us-east-1 + - name: Upload artifacts to S3 + if: ${{ inputs.upload_artifacts_to_s3 }} + run: | + echo "Uploading wheels to S3..." + ls -lh $(pwd)/dist/*.whl + aws s3 cp --only-show-errors --recursive $(pwd)/dist/ "${INPUTS_S3_UPLOAD_URI}"/ + echo "Upload complete." + env: + INPUTS_S3_UPLOAD_URI: ${{ inputs.s3_upload_uri }} + - name: Store the S3 upload URI as an output + id: store-s3-upload-uri + if: ${{ inputs.upload_artifacts_to_s3 }} + run: echo "s3_upload_uri=${INPUTS_S3_UPLOAD_URI}" >> "$GITHUB_OUTPUT" + env: + INPUTS_S3_UPLOAD_URI: ${{ inputs.s3_upload_uri }} diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml index e9b1f174f062..2d7aa2486519 100644 --- a/.github/workflows/wheel_tests_continuous.yml +++ b/.github/workflows/wheel_tests_continuous.yml @@ -1,6 +1,6 @@ # CI - Wheel Tests (Continuous) # -# This workflow builds JAX artifacts and runs CPU/TPU/CUDA tests. +# This workflow builds JAX artifacts and runs CPU/TPU/CUDA/ROCm tests. # # It orchestrates the following: # 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and @@ -22,15 +22,26 @@ # that was built in the previous step and runs TPU tests. # 9. run-bazel-test-tpu: Calls the `bazel_test_tpu.yml` workflow which # runs Bazel TPU tests with py_import. +# 10. build-rocm-artifacts: Calls the `build_rocm_artifacts.yml` workflow to build ROCm plugin/pjrt +# wheels and uploads them to an S3 bucket. +# 11. run-pytest-rocm: Calls the `pytest_rocm.yml` workflow which downloads the jaxlib and +# ROCm artifacts and runs the ROCm tests. +# 12. run-bazel-test-rocm: Calls the `bazel_rocm.yml` workflow which runs the ROCm Bazel tests. 
name: CI - Wheel Tests (Continuous) permissions: - contents: read + id-token: write + contents: read + actions: read on: schedule: - cron: "0 */3 * * *" # Run once every 3 hours workflow_dispatch: # allows triggering the workflow run manually + pull_request: + paths: + - '.github/workflows/wheel_tests_continuous.yml' + - '.github/workflows/build_rocm_artifacts.yml' concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} @@ -285,4 +296,72 @@ jobs: gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} build_jaxlib: "wheel" build_jax: "wheel" - clone_main_xla: 1 \ No newline at end of file + clone_main_xla: 1 + + build-rocm-artifacts: + uses: ./.github/workflows/build_rocm_artifacts.yml + permissions: + id-token: write + contents: read + actions: read + strategy: + fail-fast: false + matrix: + runner: ["linux-x86-64-1gpu-amd"] + artifact: ["jax-rocm-plugin", "jax-rocm-pjrt"] + python: ["3.11"] + rocm-version: ["7"] + name: "Build ${{ format('{0}', 'ROCm') }} artifacts" + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm-version }} + clone_main_xla: 1 + upload_artifacts_to_s3: true + s3_upload_uri: 's3://jax-ci-amd/rocm-wheels/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + run-pytest-rocm: + if: ${{ !cancelled() }} + needs: [build-jax-artifact, build-jaxlib-artifact, build-rocm-artifacts] + uses: ./.github/workflows/pytest_rocm.yml + strategy: + fail-fast: false + matrix: + runner: ["linux-x86-64-1gpu-amd", "linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"] + python: ["3.11"] + rocm: [ + {version: "7.2.0", tag: "rocm720"}, + ] + name: "Pytest ROCm (JAX artifacts version = ${{ format('{0}', 'head') }})" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm.version }} + rocm-tag: ${{ matrix.rocm.tag }} + jaxlib-version: "head" + gcs_download_uri: ${{ 
needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + + run-bazel-test-rocm: + if: ${{ !cancelled() }} + needs: [build-jax-artifact, build-jaxlib-artifact, build-rocm-artifacts] + uses: ./.github/workflows/bazel_rocm.yml + strategy: + fail-fast: false + matrix: + runner: ["linux-x86-64-4gpu-amd", "linux-x86-64-8gpu-amd"] + python: ["3.11"] + rocm-version: ["7"] + enable-x64: [0] + name: "Bazel ROCm tests (JAX artifacts version = ${{ format('{0}', 'head') }})" + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + rocm-version: ${{ matrix.rocm-version }} + enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + build_jaxlib: "false" + build_jax: "false" + jaxlib-version: "head" + run_multiaccelerator_tests: "false" + clone_main_xla: 1 diff --git a/build/rocm/rocm.bazelrc b/build/rocm/rocm.bazelrc index 9935cdce944c..50f5894baf33 100644 --- a/build/rocm/rocm.bazelrc +++ b/build/rocm/rocm.bazelrc @@ -20,6 +20,9 @@ common:rocm --copt=-Qunused-arguments # Used for @xla//build_tools/rocm:parallel_gpu_execute common:rocm --legacy_external_runfiles=true +build:rocm_release_wheel --config=rocm +build:rocm_release_wheel --@local_config_rocm//rocm:rocm_path_type=link_only + test:rocm --test_timeout=920,2400,7200,9600 test:rocm --flaky_test_attempts=3 test:rocm --test_verbose_timeout_warnings diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh index 304dd1ab1792..296e2c422e02 100755 --- a/ci/utilities/run_auditwheel.sh +++ b/ci/utilities/run_auditwheel.sh @@ -18,7 +18,7 @@ # Get a list of all the wheels in the output directory. Only look for wheels # that need to be verified for manylinux compliance. 
-WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" \)) +WHEELS=$(find "$JAXCI_OUTPUT_DIR/" -type f \( -name "*jaxlib*whl" -o -name "*jax*cuda*pjrt*whl" -o -name "*jax*cuda*plugin*whl" -o -name "*jax*rocm*pjrt*whl" -o -name "*jax*rocm*plugin*whl" \)) if [[ -z "$WHEELS" ]]; then echo "ERROR: No wheels found under $JAXCI_OUTPUT_DIR" diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 6b6f220ea733..be2bdbe7e2d8 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -585,6 +585,27 @@ def jax_wheel( source_files = source_files, ) +def deploy_wheel(name, wheel): + """Creates a runnable target that copies a wheel to a given directory. + + Usage: bazel run -- /output/dir + + Args: + name: the target name + wheel: the wheel target to deploy + """ + native.genrule( + name = name + "_gen", + srcs = [wheel], + outs = [name + ".sh"], + cmd = "echo '#!/bin/bash\ncp $(rootpath {wheel}) $$1' > $@".format(wheel = wheel), + ) + native.sh_binary( + name = name, + srcs = [name + ".sh"], + data = [wheel], + ) + def jax_source_package( name, source_package_binary, diff --git a/jaxlib/rocm/rocm_version.bzl b/jaxlib/rocm/rocm_version.bzl new file mode 100644 index 000000000000..662c783a4fa4 --- /dev/null +++ b/jaxlib/rocm/rocm_version.bzl @@ -0,0 +1,29 @@ +# Copyright 2026 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ROCm version constants derived from @local_config_rocm. 
+ +The rocm_version_number() function from build_defs.bzl returns an integer +encoded as: major * 10000 + minor * 100 + patch (e.g. 70101 for 7.1.1). +""" + +load( + "@local_config_rocm//rocm:build_defs.bzl", + "rocm_version_number", +) + +_version = rocm_version_number() +ROCM_MAJOR_VERSION = str(_version // 10000) +ROCM_MINOR_VERSION = str((_version % 10000) // 100) +ROCM_PATCH_VERSION = str(_version % 100) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index b09c34536df0..05aec252efa8 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -32,6 +32,7 @@ load( "//jaxlib:jax.bzl", "PLATFORM_TAGS_DICT", "compare_srcs_and_test_deps_test", + "deploy_wheel", "get_test_suite_list", "if_pypi_cuda_wheel_deps", "jax_wheel", @@ -39,6 +40,7 @@ load( "pytype_test", "wheel_sources", ) +load("//jaxlib/rocm:rocm_version.bzl", "ROCM_MAJOR_VERSION") licenses(["notice"]) # Apache 2 @@ -285,21 +287,21 @@ jax_wheel( name = "jax_rocm_plugin_wheel", enable_rocm = True, no_abi = False, - platform_version = "7", + platform_version = ROCM_MAJOR_VERSION, reproducible = True, source_files = [":jax_plugin_sources"], wheel_binary = ":build_gpu_kernels_wheel_tool", - wheel_name = "jax_rocm7_plugin", + wheel_name = "jax_rocm{}_plugin".format(ROCM_MAJOR_VERSION), ) jax_wheel( name = "jax_rocm_plugin_wheel_editable", editable = True, enable_rocm = True, - platform_version = "7", + platform_version = ROCM_MAJOR_VERSION, source_files = [":jax_plugin_sources"], wheel_binary = ":build_gpu_kernels_wheel_tool", - wheel_name = "jax_rocm7_plugin", + wheel_name = "jax_rocm{}_plugin".format(ROCM_MAJOR_VERSION), ) # JAX PJRT wheel targets. 
@@ -395,21 +397,33 @@ jax_wheel( name = "jax_rocm_pjrt_wheel", enable_rocm = True, no_abi = True, - platform_version = "7", + platform_version = ROCM_MAJOR_VERSION, reproducible = True, source_files = [":jax_pjrt_sources"], wheel_binary = ":build_gpu_plugin_wheel_tool", - wheel_name = "jax_rocm7_pjrt", + wheel_name = "jax_rocm{}_pjrt".format(ROCM_MAJOR_VERSION), ) jax_wheel( name = "jax_rocm_pjrt_wheel_editable", editable = True, enable_rocm = True, - platform_version = "7", + platform_version = ROCM_MAJOR_VERSION, source_files = [":jax_pjrt_sources"], wheel_binary = ":build_gpu_plugin_wheel_tool", - wheel_name = "jax_rocm7_pjrt", + wheel_name = "jax_rocm{}_pjrt".format(ROCM_MAJOR_VERSION), +) + +# ROCm wheel deploy targets. +# Usage: bazel run //jaxlib/tools:deploy_rocm_plugin_wheel -- /output/dir +deploy_wheel( + name = "deploy_rocm_plugin_wheel", + wheel = ":jax_rocm_plugin_wheel", +) + +deploy_wheel( + name = "deploy_rocm_pjrt_wheel", + wheel = ":jax_rocm_pjrt_wheel", ) # Py_import targets. From 4ec06d48fbc29ad3fff9ab96cda553911c9ac29e Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 4 Mar 2026 06:01:28 -0800 Subject: [PATCH 063/100] [Pallas][TPU kernel interpreter] Add optional logging to memory operations. 
PiperOrigin-RevId: 878449137 --- jax/_src/pallas/mosaic/interpret/BUILD | 9 +- .../mosaic/interpret/interpret_pallas_call.py | 111 +++++++++-- .../pallas/mosaic/interpret/shared_memory.py | 184 ++++++++++++++++-- jax/_src/pallas/mosaic/interpret/utils.py | 74 +++++-- .../mosaic_gpu/interpret/shared_memory.py | 9 - 5 files changed, 338 insertions(+), 49 deletions(-) diff --git a/jax/_src/pallas/mosaic/interpret/BUILD b/jax/_src/pallas/mosaic/interpret/BUILD index 2a86f2258032..9fcde71ee532 100644 --- a/jax/_src/pallas/mosaic/interpret/BUILD +++ b/jax/_src/pallas/mosaic/interpret/BUILD @@ -44,7 +44,9 @@ py_library( "//jax/_src:frozen_dict", "//jax/_src:lax", "//jax/_src:mlir", + "//jax/_src:partial_eval", "//jax/_src:source_info_util", + "//jax/_src:tree_util", "//jax/_src:typing", "//jax/_src:util", "//jax/_src/pallas", @@ -64,13 +66,17 @@ pytype_strict_library( srcs = ["shared_memory.py"], deps = [ ":race_detection_state", + ":utils", ":vector_clock", "//jax", "//jax/_src:source_info_util", "//jax/_src:typing", "//jax/_src/pallas", "//jax/_src/pallas/mosaic:core", - ] + py_deps("numpy"), + ] + py_deps([ + "absl/logging", + "numpy", + ]), ) pytype_strict_library( @@ -97,6 +103,7 @@ pytype_strict_library( deps = [ "//jax", "//jax/_src:core", + "//jax/_src:source_info_util", "//jax/_src:util", "//jax/_src/pallas", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py index 298f0bdfc2a8..ea15251c79db 100644 --- a/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py +++ b/jax/_src/pallas/mosaic/interpret/interpret_pallas_call.py @@ -212,6 +212,7 @@ def _initialize_shared_memory( clean_up_barrier=threading.Barrier( num_devices, action=_clear_shared_memory ), + logging_mode=interpret_params.logging_mode, ) assert _shared_memory.num_cores == num_cores @@ -303,6 +304,7 @@ def _allocate_buffer( local_core_id: Array | None, memory_space: Array, val: Array, + source_info: 
source_info_util.SourceInfo | None = None, ): """Allocates a memory buffer on the device with id `device_id` and core with id `local_core_id`. @@ -316,6 +318,7 @@ def _allocate_buffer( buffer in. If the corresponding memory space is "any" (i.e. HBM), at most one buffer will be allocated and it will belong to (local) core id 0. val: Array of values to initialize the allocated buffer with. + source_info: Information about the source code location of the allocation. Returns: Integer id for the allocated buffer. @@ -356,7 +359,14 @@ def _allocate_buffer( val = val.copy() shared_memory.allocate_buffer( - key, ref_count=ref_count, value=np.array(val) + key, + ref_count=ref_count, + value=np.array(val), + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=lci, + source_info=source_info, + ), ) local_core_id_to_buffer_id[lci] = buffer_id @@ -375,7 +385,9 @@ def _local_core_id_or_zero_if_hbm(local_core_id: int, memory_space: str) -> int: return local_core_id -def _deallocate_buffer(device_id, local_core_id, memory_space, buffer_id): +def _deallocate_buffer( + device_id, local_core_id, memory_space, buffer_id, source_info=None +): device_id = int(device_id) local_core_id = int(local_core_id) memory_space = TPU_MEMORY_SPACE_NAMES[int(memory_space)] @@ -385,7 +397,14 @@ def _deallocate_buffer(device_id, local_core_id, memory_space, buffer_id): shared_memory = _get_shared_memory() key = (memory_space, buffer_id, device_id, local_core_id) - shared_memory.deallocate_buffer(key) + shared_memory.deallocate_buffer( + key, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), + ) def _allocate_semaphores( @@ -532,7 +551,14 @@ def get( key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) read_range = interpret_utils.to_range(transforms) ret, (shape, dtype), clock_ = shared_memory.get_buffer_content( - key, read_range, global_core_id + key, + read_range, + 
global_core_id, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), ) clock = clock if clock is not None else clock_ @@ -657,7 +683,15 @@ def store( key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) write_range = interpret_utils.to_range(transforms) in_bounds, (shape, _), clock_ = shared_memory.store_buffer_content( - key, write_range, val, global_core_id + key, + write_range, + val, + global_core_id, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), ) clock = clock if clock is not None else clock_ @@ -728,7 +762,16 @@ def swap( key = (memory_space, buffer_id, device_id, local_core_id_for_buffer) read_write_range = interpret_utils.to_range(transforms) ret, (shape, _), clock = shared_memory.swap_buffer_content( - key, read_write_range, val, mask, global_core_id + key, + read_write_range, + val, + mask, + global_core_id, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), ) if ret is None: @@ -847,7 +890,12 @@ def execute_read(self): # Signal the send semaphore. 
if self.src_sem is not None: self.src_sem.signal( - self.data_size, self.src_global_core_id, clock=self.clock + self.data_size, self.src_global_core_id, clock=self.clock, + logging_info=interpret_utils.LoggingInfo( + device_id=self.src_device_id, + local_core_id=self.src_local_core_id, + source_info=self.source_info, + ), ) self.state = DmaState.READ @@ -887,7 +935,12 @@ def execute_write(self): vc.inc_vector_clock(self.clock, self.virtual_device_id) self.dst_sem.signal( - self.data_size, self.dst_global_core_id, clock=self.clock + self.data_size, self.dst_global_core_id, clock=self.clock, + logging_info=interpret_utils.LoggingInfo( + device_id=self.dst_device_id, + local_core_id=self.dst_local_core_id, + source_info=self.source_info, + ), ) self.data = None @@ -993,7 +1046,7 @@ def dma_start( dma.execute_read_and_write() -def dma_wait(device_id, local_core_id, sem_id, size): +def dma_wait(device_id, local_core_id, sem_id, size, source_info=None): shared_memory = _get_shared_memory() device_id = int(device_id) @@ -1007,7 +1060,15 @@ def dma_wait(device_id, local_core_id, sem_id, size): [sem_id], global_core_id ) assert sem is not None - sem.wait(size, global_core_id, has_tasks=True) + sem.wait( + size, + global_core_id, + has_tasks=True, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, local_core_id=local_core_id, + source_info=source_info, + ), + ) def semaphore_signal( @@ -1017,6 +1078,7 @@ def semaphore_signal( inc, target_device_id, target_local_core_id, + source_info=None, ): shared_memory = _get_shared_memory() @@ -1042,10 +1104,15 @@ def semaphore_signal( inc, shared_memory.get_global_core_id(target_device_id, target_local_core_id), clock, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), ) -def semaphore_wait(device_id, local_core_id, sem_id, value): +def semaphore_wait(device_id, local_core_id, sem_id, value, source_info=None): shared_memory = 
_get_shared_memory() device_id = int(device_id) @@ -1058,7 +1125,15 @@ def semaphore_wait(device_id, local_core_id, sem_id, value): [sem_id], global_core_id ) assert sem is not None - sem.wait(value, global_core_id) + sem.wait( + value, + global_core_id, + logging_info=interpret_utils.LoggingInfo( + device_id=device_id, + local_core_id=local_core_id, + source_info=source_info, + ), + ) _SEMAPHORE = mosaic_core.MemorySpace.SEMAPHORE @@ -1271,7 +1346,9 @@ def f(*args, jaxpr): memory_space = _forward_any_to_hbm(v.aval.memory_space) allocs.append( callback.io_callback( - _allocate_buffer, + functools.partial( + _allocate_buffer, source_info=eqn.source_info + ), jax.ShapeDtypeStruct((), jnp.int16), device_id, local_core_id, @@ -1297,7 +1374,9 @@ def f(*args, jaxpr): pass else: callback.io_callback( - _deallocate_buffer, + functools.partial( + _deallocate_buffer, source_info=eqn.source_info + ), None, device_id, local_core_id, @@ -1423,7 +1502,7 @@ def f(*args, jaxpr): read_shape = src_ref_aval.shape read_dtype = src_ref_aval.dtype callback.io_callback( - dma_wait, + functools.partial(dma_wait, source_info=eqn.source_info), (), device_id, local_core_id, @@ -1449,7 +1528,7 @@ def f(*args, jaxpr): target_device_id, eqn.params['device_id_type'], axis_sizes, axis_indices) callback.io_callback( - semaphore_signal, + functools.partial(semaphore_signal, source_info=eqn.source_info), (), device_id, local_core_id, diff --git a/jax/_src/pallas/mosaic/interpret/shared_memory.py b/jax/_src/pallas/mosaic/interpret/shared_memory.py index d61250fa16c2..9233f2f1a4f3 100644 --- a/jax/_src/pallas/mosaic/interpret/shared_memory.py +++ b/jax/_src/pallas/mosaic/interpret/shared_memory.py @@ -21,7 +21,9 @@ import threading from typing import Any, Callable, Literal +from absl import logging from jax._src.pallas.mosaic.interpret import vector_clock as vc +import jax._src.pallas.mosaic.interpret.utils as interpret_utils import numpy as np @@ -31,9 +33,11 @@ def __init__( self, shared_memory: 
SharedMemory, semaphore_id: int, + enable_logging: bool = False, ): self.shared_memory = shared_memory self.id: int = semaphore_id + self.enable_logging: bool = enable_logging # TODO(jburnim): Use one Condition variable per device. (Which will be # easier to do when we're using single integer device IDs.) @@ -66,10 +70,28 @@ def detect_races(self) -> bool: def dma_execution_mode(self) -> str: return self.shared_memory.dma_execution_mode + def _log(self, message: str): + """Logs a message to `absl.logging`. To be called while holding the lock on `self.cv`.""" + # Log every line separately to make sure `absl.logging` adds the correct + # prefix (i.e. I***