diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py index b33069a..59562d3 100644 --- a/drjax/_src/impls.py +++ b/drjax/_src/impls.py @@ -100,23 +100,23 @@ def broadcast_to_placement( placement: str, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None, ) -> PlacedArray: - """Broadcasts (replicates) to the specified placement. + """Broadcasts (tiles) to the specified placement. That is, given an `arg` of shape `[a, ... b]`, and a `placement` with `n` elements, the result of this function should be an array of shape `[n, a, ... b]`, each of whose slices on the zeroth axis is identical to `arg`. - This function shards the resulting replicated array along a mesh axis - corresponding to 'placement', if a mesh is available and contains a - 'placement' axis. If `mesh` is provided then this defines the mesh to be - used. Else JAX's global mesh is used, if one is installed. If no mesh is - available or if the available mesh does not contain a 'placement' axis then - the result is fully replicated. + This function shards the resulting array along a mesh axis corresponding to + 'placement', if a mesh is available and contains a 'placement' axis. If + `mesh` is provided then this defines the mesh to be used. Else JAX's global + mesh is used, if one is installed. If no mesh is available or if the + available mesh does not contain a 'placement' axis then the result is + unconstrained and the GSPMD partitioner may do whatever it wants. - This function must also direct the GSPMD compiler to shard the - zeroth-axis slices of this replicated array in a similar manner to the - argument. + When a mesh is available and 'placement' is in the mesh, this function also + directs the GSPMD compiler to shard the zeroth-axis slices of this tiled + array in a similar manner to the argument. Args: arg: An array to be broadcast. @@ -125,7 +125,7 @@ def broadcast_to_placement( one is installed. Returns: - A logically replicated array along the zeroth axis, as described above. + A logically tiled array along the zeroth axis, as described above. """ if mesh is None: mesh = _global_mesh() @@ -144,15 +144,15 @@ def broadcast_to_placement( pspec = P(*([P.UNCONSTRAINED] * (len(arg.shape) + 1))) def single_arg_broadcast(x): - replicated_tensor = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape)) + unconstrained_tensor = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape)) if mesh is None: logging.warning( - 'No mesh found; defaulting to fully replicated broadcast and' + 'No mesh found; defaulting to fully unconstrained broadcast and' ' *NOT* adding sharding constraints over the requested placement' ' axis %s.', placement, ) - return replicated_tensor + return unconstrained_tensor else: def _shard_slice_like_arg(s): @@ -160,7 +160,7 @@ def _shard_slice_like_arg(s): return s_sharded original_dims_constrained = jax.vmap(_shard_slice_like_arg, in_axes=0)( - replicated_tensor + unconstrained_tensor ) fully_constrained = _constrain_if_mesh( mesh, original_dims_constrained, pspec