Merged
30 changes: 15 additions & 15 deletions drjax/_src/impls.py
@@ -100,23 +100,23 @@ def broadcast_to_placement(
placement: str,
mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None,
) -> PlacedArray:
- """Broadcasts (replicates) to the specified placement.
+ """Broadcasts (tiles) to the specified placement.

That is, given an `arg` of shape `[a, ... b]`, and a `placement` with `n`
elements, the result of this function should be an array of shape
`[n, a, ... b]`, each of whose slices on the zeroth axis is identical to
`arg`.

- This function shards the resulting replicated array along a mesh axis
- corresponding to 'placement', if a mesh is available and contains a
- 'placement' axis. If `mesh` is provided then this defines the mesh to be
- used. Else JAX's global mesh is used, if one is installed. If no mesh is
- available or if the available mesh does not contain a 'placement' axis then
- the result is fully replicated.
+ This function shards the resulting array along a mesh axis corresponding to
+ 'placement', if a mesh is available and contains a 'placement' axis. If
+ `mesh` is provided then this defines the mesh to be used. Else JAX's global
+ mesh is used, if one is installed. If no mesh is available or if the
+ available mesh does not contain a 'placement' axis then the result is
+ unconstrained and the GSPMD partitioner may do whatever it wants.

- This function must also direct the GSPMD compiler to shard the
- zeroth-axis slices of this replicated array in a similar manner to the
- argument.
+ When a mesh is available and 'placement' is in the mesh, this function also
+ directs the GSPMD compiler to shard the zeroth-axis slices of this tiled
+ array in a similar manner to the argument.

Args:
arg: An array to be broadcast.
@@ -125,7 +125,7 @@ def broadcast_to_placement(
one is installed.

Returns:
- A logically replicated array along the zeroth axis, as described above.
+ A logically tiled array along the zeroth axis, as described above.
"""
if mesh is None:
mesh = _global_mesh()
@@ -144,23 +144,23 @@ def broadcast_to_placement(
pspec = P(*([P.UNCONSTRAINED] * (len(arg.shape) + 1)))

def single_arg_broadcast(x):
- replicated_tensor = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape))
+ unconstrained_tensor = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape))
if mesh is None:
logging.warning(
- 'No mesh found; defaulting to fully replicated broadcast and'
+ 'No mesh found; defaulting to fully unconstrained broadcast and'
' *NOT* adding sharding constraints over the requested placement'
' axis %s.',
placement,
)
- return replicated_tensor
+ return unconstrained_tensor
else:

def _shard_slice_like_arg(s):
s_sharded, _ = shard_alike(s, x)
return s_sharded

original_dims_constrained = jax.vmap(_shard_slice_like_arg, in_axes=0)(
- replicated_tensor
+ unconstrained_tensor
)
fully_constrained = _constrain_if_mesh(
mesh, original_dims_constrained, pspec
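
Note for readers of this diff: the behavior described by the updated docstring can be sketched roughly as below. This is a minimal illustration, not the drjax implementation. `tile_and_maybe_constrain` is a hypothetical helper, the one-device mesh and the axis name `'clients'` are assumptions made for the example, and the sketch omits the `shard_alike` step that the real code applies to each zeroth-axis slice.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def tile_and_maybe_constrain(x, n_elements, placement, mesh=None):
  """Hypothetical stand-in for the broadcast; not the drjax code."""
  # Prepend an axis of size n_elements; each zeroth-axis slice equals x.
  tiled = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape))
  if mesh is None or placement not in mesh.axis_names:
    # No usable mesh: add no sharding constraint, so the result is
    # unconstrained rather than explicitly replicated.
    return tiled
  # Shard only the new leading axis over the placement axis and leave the
  # original dimensions unconstrained.
  spec = P(placement, *([P.UNCONSTRAINED] * len(x.shape)))
  return jax.lax.with_sharding_constraint(tiled, NamedSharding(mesh, spec))


# Assumed setup: a mesh over a single device with one axis named 'clients'.
mesh = Mesh(np.array(jax.devices()[:1]), ('clients',))
arg = jnp.arange(6.0).reshape(2, 3)

# The constrained path runs under jit, since unconstrained dimensions are a
# hint to the partitioner rather than a concrete device placement.
constrained = jax.jit(
    lambda x: tile_and_maybe_constrain(x, 4, 'clients', mesh)
)(arg)

# Without a mesh the array is only tiled.
unconstrained = tile_and_maybe_constrain(arg, 4, 'clients')
```

Both results have shape `[n, a, ... b]` with identical slices along the zeroth axis; the two paths differ only in whether a sharding constraint is attached, which is the distinction the rename from `replicated_tensor` to `unconstrained_tensor` makes explicit.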