From 8a6d9f8009a2bb82035f43bcf286736dbd0d2ca6 Mon Sep 17 00:00:00 2001
From: Keith Rush
Date: Tue, 27 May 2025 12:04:18 -0700
Subject: [PATCH] Introduce preference for jax abstract mesh in drjax.

This will enable callers to use the jax.sharding.use_mesh(...) pattern, which
does not alter the thread resources drjax was previously relying on.

PiperOrigin-RevId: 763892547
---
 drjax/_src/impls.py               |  6 ++++--
 drjax/_src/impls_sharding_test.py | 23 +++++++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py
index 59562d3..f4462ae 100644
--- a/drjax/_src/impls.py
+++ b/drjax/_src/impls.py
@@ -47,9 +47,11 @@ def call_jaxpr(fn, arg):
 
 # TODO(b/366437841): Remove use of pxla.thread_resources.env.physical_mesh,
 # which is a JAX internal API.
-def _global_mesh() -> jax.sharding.Mesh | None:
+def _global_mesh() -> jax.sharding.Mesh | jax.sharding.AbstractMesh | None:
   """Returns the JAX global mesh if installed, or `None` otherwise."""
-  jax_global_mesh = pxla.thread_resources.env.physical_mesh
+  jax_global_mesh = jax.sharding.get_abstract_mesh()
+  if jax_global_mesh is None or jax_global_mesh.empty:
+    jax_global_mesh = pxla.thread_resources.env.physical_mesh
   return None if jax_global_mesh.empty else jax_global_mesh
 
 
diff --git a/drjax/_src/impls_sharding_test.py b/drjax/_src/impls_sharding_test.py
index 0f5f6c8..36b2b62 100644
--- a/drjax/_src/impls_sharding_test.py
+++ b/drjax/_src/impls_sharding_test.py
@@ -99,6 +99,29 @@ def test_broadcast_with_1x1_fully_replicates(self, mesh_as_context):
     # replicated.
     self.assertTrue(sharding.is_fully_replicated)
 
+  def test_broadcast_clients_with_jax_use_mesh(self):
+    global_mesh = create_global_mesh(
+        [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], [_CLIENTS_AXIS, _DATA_AXIS]
+    )
+    arg = jnp.zeros(shape=[_DATA_SIZE])
+    with jax.sharding.use_mesh(global_mesh):
+      result = self._comp_factory.broadcast_to_placement(
+          arg,
+          _CLIENTS_AXIS,
+      )
+    self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE))
+    sharding = result.sharding
+    # If this sharding were fully replicated, we would be *replicating* the data
+    # on each chip, rather than putting half of the clients' broadcasted arrays
+    # on one set of client chips and half on the other.
+    self.assertFalse(sharding.is_fully_replicated)
+    # Each shard should host half the clients, but the arg's original dimension
+    # should be replicated.
+    self.assertEqual(
+        sharding.shard_shape(result.shape),
+        (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE),
+    )
+
   @parameterized.parameters(True, False)
   def test_broadcast_clients_shards_along_clients(self, mesh_as_context):
     global_mesh = create_global_mesh(
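
Note (not part of the patch): a minimal caller-side sketch of the
jax.sharding.use_mesh(...) pattern this change supports. Only
jax.sharding.use_mesh and jax.sharding.get_abstract_mesh are taken from the
diff above; the mesh construction and the axis name 'clients' are illustrative
assumptions, and the drjax computation itself is elided because no public
entry point appears in this patch.

    import jax
    import numpy as np

    # Build a one-axis mesh over all local devices. The axis name 'clients'
    # is an assumption; it should match the placement axis drjax expects.
    mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=('clients',))

    with jax.sharding.use_mesh(mesh):
      # Inside this context an abstract mesh is installed, so the patched
      # _global_mesh() picks it up instead of the legacy
      # pxla.thread_resources.env.physical_mesh thread resource.
      assert not jax.sharding.get_abstract_mesh().empty
      # ... invoke drjax computations here; placed values can then be
      # sharded along the 'clients' axis of this mesh.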