From 9f26f6276713dc8fc497c35ac0f8d5b916b2e325 Mon Sep 17 00:00:00 2001 From: Zachary Garrett Date: Wed, 16 Apr 2025 14:27:00 -0700 Subject: [PATCH] Update error descriptions for when result sharding is unlikely to be as expected. PiperOrigin-RevId: 748424909 --- drjax/_src/impls.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py index fb827d9..b33069a 100644 --- a/drjax/_src/impls.py +++ b/drjax/_src/impls.py @@ -146,7 +146,12 @@ def broadcast_to_placement( def single_arg_broadcast(x): replicated_tensor = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape)) if mesh is None: - # No sharding expected, don't worry about it. + logging.warning( + 'No mesh found; defaulting to fully replicated broadcast and' + ' *NOT* adding sharding constraints over the requested placement' + ' axis %s.', + placement, + ) return replicated_tensor else: @@ -275,7 +280,7 @@ def _shard_slice(s): 'No mesh containing axis name %s found; defaulting to standard vmap.' ' Mesh contains names: %s', placement, - mesh.axis_names if mesh is not None else '', + mesh.axis_names if mesh is not None else 'None', ) # Users should be free to use whatever mesh their model needs without # _necessarily_ registering a mesh-dimension for every placement with