From 35a836276373c8d18491d69be7352c375b74cca1 Mon Sep 17 00:00:00 2001
From: Alex Kurakin
Date: Fri, 30 May 2025 16:22:23 -0700
Subject: [PATCH] Introduce preference for jax abstract mesh in drjax.

This enables callers to use the jax.sharding.use_mesh(...) pattern, which
does not alter the thread resources drjax previously relied on.

PiperOrigin-RevId: 765377136
---
 drjax/_src/impls.py               |  6 ++++--
 drjax/_src/impls_sharding_test.py | 23 +++++++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py
index 59562d3..f4462ae 100644
--- a/drjax/_src/impls.py
+++ b/drjax/_src/impls.py
@@ -47,9 +47,11 @@ def call_jaxpr(fn, arg):
 
 # TODO(b/366437841): Remove use of pxla.thread_resources.env.physical_mesh,
 # which is a JAX internal API.
-def _global_mesh() -> jax.sharding.Mesh | None:
+def _global_mesh() -> jax.sharding.Mesh | jax.sharding.AbstractMesh | None:
   """Returns the JAX global mesh if installed, or `None` otherwise."""
-  jax_global_mesh = pxla.thread_resources.env.physical_mesh
+  jax_global_mesh = jax.sharding.get_abstract_mesh()
+  if jax_global_mesh is None or jax_global_mesh.empty:
+    jax_global_mesh = pxla.thread_resources.env.physical_mesh
   return None if jax_global_mesh.empty else jax_global_mesh
 
 
diff --git a/drjax/_src/impls_sharding_test.py b/drjax/_src/impls_sharding_test.py
index 0f5f6c8..36b2b62 100644
--- a/drjax/_src/impls_sharding_test.py
+++ b/drjax/_src/impls_sharding_test.py
@@ -99,6 +99,29 @@ def test_broadcast_with_1x1_fully_replicates(self, mesh_as_context):
     # replicated.
     self.assertTrue(sharding.is_fully_replicated)
 
+  def test_broadcast_clients_with_jax_use_mesh(self):
+    global_mesh = create_global_mesh(
+        [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], [_CLIENTS_AXIS, _DATA_AXIS]
+    )
+    arg = jnp.zeros(shape=[_DATA_SIZE])
+    with jax.sharding.use_mesh(global_mesh):
+      result = self._comp_factory.broadcast_to_placement(
+          arg,
+          _CLIENTS_AXIS,
+      )
+    self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE))
+    sharding = result.sharding
+    # If this sharding were fully replicated, we would be *replicating* the
+    # data on each chip, rather than putting half of the clients' broadcasted
+    # arrays on one set of client chips and half on the other.
+    self.assertFalse(sharding.is_fully_replicated)
+    # Each shard should host half the clients, but the arg's original
+    # dimension should be replicated.
+    self.assertEqual(
+        sharding.shard_shape(result.shape),
+        (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE),
+    )
+
   @parameterized.parameters(True, False)
   def test_broadcast_clients_shards_along_clients(self, mesh_as_context):
     global_mesh = create_global_mesh(
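
Note (illustrative, not part of the patch): a minimal sketch of the
jax.sharding.use_mesh(...) calling pattern this change enables. The
single-axis mesh and the "clients" axis name below are assumptions made for
the example, not values taken from drjax.

    import jax
    import numpy as np

    # Build a Mesh over all local devices along one axis; the axis name
    # "clients" is an illustrative choice.
    mesh = jax.sharding.Mesh(np.array(jax.devices()), ("clients",))

    # use_mesh installs `mesh` for the current scope without mutating
    # pxla.thread_resources; before this patch, drjax read only the
    # thread-resources physical mesh and therefore could not see it.
    with jax.sharding.use_mesh(mesh):
        # Inside the scope, get_abstract_mesh() returns a non-empty mesh,
        # which the patched _global_mesh() now consults first.
        print(jax.sharding.get_abstract_mesh())

Outside the use_mesh scope, jax.sharding.get_abstract_mesh() yields None or
an empty mesh (hence the `is None or ... .empty` check in the patch), and
_global_mesh() falls back to pxla.thread_resources.env.physical_mesh.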