diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py index b86e24e..4a9a56e 100644 --- a/drjax/_src/impls.py +++ b/drjax/_src/impls.py @@ -63,7 +63,7 @@ def _placement_axis_in_mesh( mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None, placement: str, ) -> bool: - """Checks if a clients axis is present in the mesh.""" + """Checks if a placement axis is present in the mesh.""" if mesh is None: return False placement_is_in_mesh = placement in mesh.axis_names @@ -146,7 +146,7 @@ def broadcast_to_placement( if _placement_axis_in_mesh(mesh, placement): pspec = P(placement, *([P.UNCONSTRAINED] * len(arg.shape))) else: - # Without a clients axis in the mesh, we simply explicitly tell the + # Without a placement axis in the mesh, we simply explicitly tell the # compiler that there are no constraints on this tensor. This will leave # the choices in the hands of the compiler. pspec = P(*([P.UNCONSTRAINED] * (len(arg.shape) + 1)))