Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions drjax/_src/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _placement_axis_in_mesh(
mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None,
placement: str,
) -> bool:
"""Checks if a clients axis is present in the mesh."""
"""Checks if a placement axis is present in the mesh."""
if mesh is None:
return False
placement_is_in_mesh = placement in mesh.axis_names
Expand Down Expand Up @@ -146,7 +146,7 @@ def broadcast_to_placement(
if _placement_axis_in_mesh(mesh, placement):
pspec = P(placement, *([P.UNCONSTRAINED] * len(arg.shape)))
else:
# Without a clients axis in the mesh, we simply explicitly tell the
# Without a placement axis in the mesh, we simply explicitly tell the
# compiler that there are no constraints on this tensor. This will leave
# the choices in the hands of the compiler.
pspec = P(*([P.UNCONSTRAINED] * (len(arg.shape) + 1)))
Expand Down
Loading