From 2f62bbaef50afaed35453f4cfdbbd40a58673637 Mon Sep 17 00:00:00 2001 From: Zachary Garrett Date: Tue, 30 Sep 2025 07:55:53 -0700 Subject: [PATCH] Extend DrJax to be compatible with Explicit Sharding. Explicit Sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html This updates `drjax.broadcast` and `drjax.map_fn` to be compatible with meshes that have all `AxisType.Explicit` axes. Namely, for such meshes this uses `jax.sharding.reshard` on the outputs of these methods and avoids `jax.lax.with_sharding_constraint`. DrJax still requires homogeneous axis types (all Auto or all Explicit). PiperOrigin-RevId: 813265584 --- drjax/_src/api_test.py | 204 +++++++++++++++++++++--------- drjax/_src/impls.py | 178 +++++++++++++++++--------- drjax/_src/impls_sharding_test.py | 172 ++++++++++++++++++------- 3 files changed, 391 insertions(+), 163 deletions(-) diff --git a/drjax/_src/api_test.py b/drjax/_src/api_test.py index 7d1ec75..3a9298a 100644 --- a/drjax/_src/api_test.py +++ b/drjax/_src/api_test.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import itertools from absl.testing import absltest from absl.testing import parameterized @@ -28,18 +29,39 @@ def drjax_program(*, placements): return api.drjax_program(placements=placements, self_module=api) -@parameterized.named_parameters( - ("clients_placed", "clients"), ("XY_placed", "XY") +@parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], ) class ApiTest(absltest.TestCase): - def test_sharded_broadcast(self, placement_name): + def assertShardingEqual(self, arr, sharding): + canonical_array_sharding = jax.sharding.NamedSharding( + arr.sharding.mesh, + # Canonicalize with trailing `None`s to the rank of the input array. + # This canonicalizes across Auto and Explicit axis types, the former + # which may not include trailing `None`s. 
+ jax.sharding.PartitionSpec(*( + axis + for axis, _ in itertools.zip_longest(arr.sharding.spec, arr.shape) + )), + ) + self.assertEqual(canonical_array_sharding, sharding) + + def test_broadcast_with_placement_in_mesh(self, placement_name, axes_type): @drjax_program(placements={placement_name: 100}) def broadcast_val(val): return api.broadcast(val) - mesh = jax.sharding.Mesh(np.array(jax.devices()), ("some_axis",)) + mesh = jax.sharding.Mesh( + np.array(jax.devices()), + axis_names=("some_axis",), + axis_types=(axes_type,), + ) arg_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec("some_axis") ) @@ -52,14 +74,19 @@ def broadcast_val(val): # No clients dimension in the mesh, we don't lay out the clients along that # nonexistent dimension, but rather replicate them. Notice that we don't # need to specify the sharding to DrJAX; it should be inferred by GSPMD. - expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis") - self.assertEqual( - result.sharding, jax.sharding.NamedSharding(mesh, expected_result_pspec) + expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None) + self.assertShardingEqual( + result, jax.sharding.NamedSharding(mesh, expected_result_pspec) ) - def test_sharded_broadcast_mesh_arg(self, placement_name): - - mesh = jax.sharding.Mesh(np.array(jax.devices()), ("some_axis",)) + def test_broadcast_mesh_arg_without_placement( + self, placement_name, axes_type + ): + mesh = jax.sharding.Mesh( + np.array(jax.devices()), + axis_names=("some_axis",), + axis_types=(axes_type,), + ) @drjax_program(placements={placement_name: 100}) def broadcast_val(val): @@ -74,15 +101,16 @@ def broadcast_val(val): # No clients dimension in the mesh, we don't lay out the clients along that # nonexistent dimension, but rather replicate them. Notice that we don't # need to specify the sharding to DrJAX; it should be inferred by GSPMD. 
- expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis") - self.assertEqual( - result.sharding, jax.sharding.NamedSharding(mesh, expected_result_pspec) + expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None) + self.assertShardingEqual( + result, jax.sharding.NamedSharding(mesh, expected_result_pspec) ) - def test_fully_sharded_broadcast_mesh_arg(self, placement_name): - + def test_fully_sharded_broadcast_mesh_arg(self, placement_name, axes_type): mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape([4, 2]), (placement_name, "some_axis") + np.array(jax.devices()).reshape([4, 2]), + axis_names=(placement_name, "some_axis"), + axis_types=(axes_type, axes_type), ) @drjax_program(placements={placement_name: 8}) @@ -98,58 +126,97 @@ def broadcast_val(val): chex.assert_trees_all_close(result, jnp.ones(shape=[8, 8, 8])) # The result should be sharded across the placement_name axis. expected_result_pspec = jax.sharding.PartitionSpec( - placement_name, "some_axis" + placement_name, "some_axis", None ) - self.assertEqual( - result.sharding, jax.sharding.NamedSharding(mesh, expected_result_pspec) + self.assertShardingEqual( + result, jax.sharding.NamedSharding(mesh, expected_result_pspec) ) - def test_temp_sens_example(self, placement_name): + def test_temperature_sensors_example(self, placement_name, axes_type): def one_if_over(threshold, value): - return jax.lax.cond(value > threshold, lambda: 1.0, lambda: 0.0) + return jax.lax.cond( + value > threshold, + lambda: jnp.ones_like(value), + lambda: jnp.zeros_like(value), + ) placement_dim = 100 + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape([4, 2]), + axis_names=(placement_name, "some_axis"), + axis_types=(axes_type, axes_type), + ) + jax.set_mesh(mesh) @drjax_program(placements={placement_name: placement_dim}) - def temp_sens_example(threshold, values): + def temperature_sensors_example(threshold, values): threshold_at_clients = api.broadcast(threshold) values_over = 
api.map_fn(one_if_over, (threshold_at_clients, values)) return api.reduce_mean(values_over) - measurements = jnp.arange(placement_dim) + measurements = jax.device_put( + jnp.arange(placement_dim), + jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(placement_name) + ), + ) + + self.assertEqual(temperature_sensors_example(24, measurements), 0.75) - self.assertEqual(temp_sens_example(24, measurements), 0.75) + def test_temperature_sensors_example_multiple_placement_values( + self, placement_name, axes_type + ): - def test_temp_sens_example_multiple_placement_values(self, placement_name): def one_if_over(threshold, value): - return jax.lax.cond(value > threshold, lambda: 1.0, lambda: 0.0) + return jax.lax.cond( + value > threshold, + lambda: jnp.ones_like(value), + lambda: jnp.zeros_like(value), + ) + + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape([4, 2]), + axis_names=(placement_name, "some_axis"), + axis_types=(axes_type, axes_type), + ) + jax.set_mesh(mesh) @drjax_program(placements={placement_name: 100}) - def temp_sens_example_100_clients(threshold, values): + def temperature_sensors_example_100_clients(threshold, values): threshold_at_clients = api.broadcast(threshold) values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) - return api.reduce_mean(values_over) - @drjax_program(placements={placement_name: 10}) - def temp_sens_example_10_clients(threshold, values): + @drjax_program(placements={placement_name: 20}) + def temperature_sensors_example_20_clients(threshold, values): threshold_at_clients = api.broadcast(threshold) values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) return api.reduce_mean(values_over) - measurements_100 = jnp.arange(100) - measurements_10 = jnp.arange(10) + placement_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(placement_name) + ) + measurements_100 = jax.device_put(jnp.arange(100), placement_sharding) + measurements_20 = 
jax.device_put(jnp.arange(20), placement_sharding) - self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75) self.assertEqual( - temp_sens_example_10_clients(3, measurements_10), - 0.6, + temperature_sensors_example_100_clients(24, measurements_100), 0.75 + ) + self.assertEqual( + temperature_sensors_example_20_clients(3, measurements_20), + 0.8, ) # We should be able to recover the original result flipping back to the # original function. - self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75) + self.assertEqual( + temperature_sensors_example_100_clients(24, measurements_100), 0.75 + ) + - def test_multiple_placements_raises(self, placement_name): +class ApiErrorsTest(absltest.TestCase): + + def test_multiple_placements_raises(self): + placement_name = "XY" with self.assertRaises(ValueError): @@ -157,13 +224,13 @@ def test_multiple_placements_raises(self, placement_name): def _(values): return api.reduce_mean(values) - def test_raises_outside_program_context(self, placement_name): + def test_raises_outside_program_context(self): with self.assertRaises(api.OperatorUndefinedError): api.broadcast(jnp.array(0.5)) num_clients = 10 - @drjax_program(placements={placement_name: num_clients}) + @drjax_program(placements={"xy": num_clients}) def test(values): return api.reduce_mean(values) @@ -174,11 +241,9 @@ def test(values): with self.assertRaises(api.OperatorUndefinedError): api.broadcast(jnp.array(0.5)) - def test_broadcast_raises_type_error_within_program_context( - self, placement_name - ): + def test_broadcast_raises_type_error_within_program_context(self): - @drjax_program(placements={placement_name: 1}) + @drjax_program(placements={"xy": 1}) def test(*args): return api.broadcast(*args) @@ -187,11 +252,9 @@ def test(*args): ): test(jnp.array(0.5), jnp.array(0.5)) - def test_map_fn_raises_type_error_within_program_context( - self, placement_name - ): + def test_map_fn_raises_type_error_within_program_context(self): - 
@drjax_program(placements={placement_name: 1}) + @drjax_program(placements={"xy": 1}) def test(*args): return api.map_fn(lambda x: x, *args) @@ -200,10 +263,9 @@ def test(*args): ): test(jnp.array(0.5), jnp.array(0.5)) - def test_reduce_sum_raises_type_error_within_program_context( - self, placement_name - ): - @drjax_program(placements={placement_name: 1}) + def test_reduce_sum_raises_type_error_within_program_context(self): + + @drjax_program(placements={"xy": 1}) def test(*args): return api.reduce_sum(*args) @@ -213,10 +275,9 @@ def test(*args): ): test(jnp.array(0.5), jnp.array(0.5)) - def test_reduce_mean_raises_type_error_within_program_context( - self, placement_name - ): - @drjax_program(placements={placement_name: 1}) + def test_reduce_mean_raises_type_error_within_program_context(self): + + @drjax_program(placements={"xy": 1}) def test(*args): return api.reduce_mean(*args) @@ -226,18 +287,47 @@ def test(*args): ): test(jnp.array(0.5), jnp.array(0.5)) - def test_map_fn_error_propagates(self, placement_name): + def test_map_fn_error_propagates(self): + test_msg = "This is a test value error." 
def foo(_): raise ValueError(test_msg) - @drjax_program(placements={placement_name: 1}) + @drjax_program(placements={"clients": 1}) def trigger_error(x): return api.map_fn(foo, x) with self.assertRaisesRegex(ValueError, test_msg): trigger_error(jnp.asarray([0])) + def test_apis_with_mixed_mode_mesh_axes_raise_error(self): + + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape([4, 2]), + axis_names=("xy", "some_axis"), + axis_types=(jax.sharding.AxisType.Explicit, jax.sharding.AxisType.Auto), + ) + with jax.set_mesh(mesh): + with self.subTest("map"), self.assertRaisesRegex( + ValueError, "Mesh axis types must all be either auto or manual" + ): + + @drjax_program(placements={"xy": 1}) + def test_map(x): + return api.map_fn(lambda arr: arr, x) + + test_map(jnp.asarray([0])) + + with self.subTest("broadcast"), self.assertRaisesRegex( + ValueError, "Mesh axis types must all be either auto or manual" + ): + + @drjax_program(placements={"xy": 1}) + def test_broadcast(x): + return api.broadcast(x) + + test_broadcast(jnp.asarray([0])) + # This allows us to test sharding behavior across multiple devices. def setUpModule(): diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py index 4a9a56e..d2df1ac 100644 --- a/drjax/_src/impls.py +++ b/drjax/_src/impls.py @@ -14,6 +14,7 @@ """Implementations of MapReduce primitives in JAX.""" from collections.abc import Mapping +import functools from typing import Any from absl import logging @@ -45,6 +46,18 @@ def call_jaxpr(fn, arg): return fn(arg) +def _is_all_auto_axis(mesh: jax.sharding.Mesh): + if mesh.axis_types is None: + return True + return all(a == jax.sharding.AxisType.Auto for a in mesh.axis_types) + + +def _is_all_explicit_axis(mesh: jax.sharding.Mesh): + if mesh.axis_types is None: + return False + return all(a == jax.sharding.AxisType.Explicit for a in mesh.axis_types) + + # TODO(b/366437841): Remove use of pxla.thread_resources.env.physical_mesh, # which is a JAX internal API. 
def _global_mesh( @@ -91,6 +104,26 @@ def _constrain_if_mesh( ) +def _constrain_alike_if_mesh( + mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None, + x: PlacedArray, + y: UnplacedArray, + pspec: jax.sharding.PartitionSpec, +) -> PlacedArray: + """Constrains the non-leading dimensions of `x` to be sharded like `y`.""" + if mesh is None: + return x + + def _shard_slice_like_arg(s): + s_sharded, _ = shard_alike(s, y) + return s_sharded + + original_dims_constrained = jax.vmap(_shard_slice_like_arg, in_axes=0)(x) + return jax.lax.with_sharding_constraint( + original_dims_constrained, jax.sharding.NamedSharding(mesh, pspec) + ) + + class PlacedComputations: """Concrete implementations of MapReduce primitives in JAX.""" @@ -102,6 +135,7 @@ def __init__( self._placements_to_n_elements = placements_to_n_elements self._use_abstract_mesh = use_abstract_mesh + @functools.partial(jax.named_call, name='drjax_broadcast') def broadcast_to_placement( self, arg: UnplacedArray, @@ -141,16 +175,6 @@ def broadcast_to_placement( arg = jnp.array(arg) n_elements = self._placements_to_n_elements[placement] - # Note that this pspec will only result in a sharding constraint defined if - # a mesh is installed at tracing time. - if _placement_axis_in_mesh(mesh, placement): - pspec = P(placement, *([P.UNCONSTRAINED] * len(arg.shape))) - else: - # Without a placement axis in the mesh, we simply explicitly tell the - # compiler that there are no constraints on this tensor. This will leave - # the choices in the hands of the compiler. 
- pspec = P(*([P.UNCONSTRAINED] * (len(arg.shape) + 1))) - def single_arg_broadcast(x): unconstrained_tensor = jnp.tile(x, reps=[n_elements] + [1] * len(x.shape)) if mesh is None: @@ -162,18 +186,35 @@ def single_arg_broadcast(x): ) return unconstrained_tensor else: - - def _shard_slice_like_arg(s): - s_sharded, _ = shard_alike(s, x) - return s_sharded - - original_dims_constrained = jax.vmap(_shard_slice_like_arg, in_axes=0)( - unconstrained_tensor - ) - fully_constrained = _constrain_if_mesh( - mesh, original_dims_constrained, pspec - ) - return fully_constrained + if _is_all_auto_axis(mesh): + if _placement_axis_in_mesh(mesh, placement): + pspec = P(placement, *([P.UNCONSTRAINED] * len(arg.shape))) + else: + # Without a placement axis in the mesh, we simply explicitly tell + # the compiler that there are no constraints on this tensor. This + # will leave the choices in the hands of the compiler. + pspec = P(*([P.UNCONSTRAINED] * (len(arg.shape) + 1))) + return _constrain_alike_if_mesh(mesh, unconstrained_tensor, x, pspec) + elif _is_all_explicit_axis(mesh): + input_sharding = jax.typeof(x).sharding + if _placement_axis_in_mesh(mesh, placement): + out_sharding = jax.sharding.NamedSharding( + input_sharding.mesh, P(placement, *input_sharding.spec) + ) + else: + # With explicit axes, when a placement axis is not in the mesh, + # we must ask for replication (`None` sharding). + out_sharding = jax.sharding.NamedSharding( + input_sharding.mesh, P(None, *input_sharding.spec) + ) + return jax.sharding.reshard( + unconstrained_tensor, out_shardings=out_sharding + ) + else: + raise ValueError( + 'Mesh axis types must all be either auto or manual, but got' + f' {mesh.axis_types}.' 
+ ) return jax.jit(single_arg_broadcast)(arg) @@ -202,6 +243,7 @@ def sum_from_placement(self, arg: PlacedArray) -> UnplacedArray: placement_idx = 0 return jnp.sum(arg, axis=[placement_idx]) + @functools.partial(jax.named_call, name='drjax_map') def map_to_placement( self, fn, @@ -248,45 +290,64 @@ def map_to_placement( if mesh is None: mesh = _global_mesh(self._use_abstract_mesh) - def _constrain_at_placement_with_slices_like(x, y): - pspec = P(placement, *([P.UNCONSTRAINED] * (len(x.shape) - 1))) - placement_constrained = _constrain_if_mesh(mesh, x, pspec) - - def _shard_slice(s): - s_sharded, _ = shard_alike(s, y[0]) - return s_sharded - - return jax.vmap(_shard_slice, in_axes=0)(placement_constrained) - - # `spmd_axis_name`` causes any internal with_sharding_constraints or - # shard_map calls inside the `vmapped` function to respect the - # sharding along this axis name. But it doesn't enrich annotations on - # input / output tensors. Since we have a very limited mapping semantic - # here, adding these annotations is always safe for us, as long as - # `placement` is in the mesh. if _placement_axis_in_mesh(mesh, placement): - arg = jax.tree_util.tree_map( - _constrain_at_placement_with_slices_like, arg, arg - ) - mapped_fn = jax.vmap( - # We must not have an `axis_name` argument here in order to work - # with any potential `shard_map` inside of `fn`. - fn, - in_axes=0, - out_axes=0, - spmd_axis_name=placement, - ) + if _is_all_auto_axis(mesh): + # `vmap(..., spmd_axis_name=)` causes any internal + # with_sharding_constraints or shard_map calls inside the `vmapped` + # function to respect the sharding along this axis name. But it doesn't + # enrich annotations on input / output tensors. Since we have a very + # limited mapping semantic here, adding these annotations is always safe + # for us, as long as `placement` is in the mesh. 
+ def _constrain_at_placement_with_slices_like(x): + pspec = P(placement, *([P.UNCONSTRAINED] * (len(x.shape) - 1))) + return _constrain_alike_if_mesh(mesh, x, x[0], pspec) + + arg = jax.tree_util.tree_map( + _constrain_at_placement_with_slices_like, arg + ) + mapped_fn = jax.vmap( + # We must not have an `axis_name` argument here in order to work + # with any potential `shard_map` inside of `fn`. + fn, + in_axes=0, + out_axes=0, + spmd_axis_name=placement, + ) - # In some cases, vmap may prevent placement sharding from propagating. We - # ensure placement sharding on the output just in case. - result = call_jaxpr(mapped_fn, arg) - return jax.tree_util.tree_map( - _constrain_at_placement_with_slices_like, result, result - ) - else: + # In some cases, vmap may prevent placement sharding from propagating. + # We ensure placement sharding on the output just in case. + result = call_jaxpr(mapped_fn, arg) + return jax.tree_util.tree_map( + _constrain_at_placement_with_slices_like, result + ) + elif _is_all_explicit_axis(mesh): + mapped_fn = jax.vmap( + fn, + axis_name=placement, + in_axes=0, + out_axes=0, + ) + result = call_jaxpr(mapped_fn, arg) + # Ensure the result is sharded along the placement axis when using + # explicit axes. + return jax.tree_util.tree_map( + lambda arr: jax.sharding.reshard( + arr, + jax.sharding.NamedSharding( + mesh, P(placement, *arr.sharding.spec[1:]) + ), + ), + result, + ) + else: + raise ValueError( + 'Mesh axis types must all be either auto or manual, but got' + f' {mesh!r}.' + ) + else: # Placement is not in the mesh. logging.warning( - 'No mesh containing axis name %s found; defaulting to standard vmap.' - ' Mesh contains names: %s', + 'No mesh containing axis name %s found; defaulting to standard' + ' vmap. 
Mesh contains names: %s', placement, mesh.axis_names if mesh is not None else 'None', ) @@ -297,5 +358,6 @@ def _shard_slice(s): fn, axis_name=placement, in_axes=0, + out_axes=0, ) return call_jaxpr(mapped_fn, arg) diff --git a/drjax/_src/impls_sharding_test.py b/drjax/_src/impls_sharding_test.py index 1f78bfc..1437e81 100644 --- a/drjax/_src/impls_sharding_test.py +++ b/drjax/_src/impls_sharding_test.py @@ -25,6 +25,7 @@ import jax from jax import numpy as jnp from jax.experimental.shard_map import shard_map +from jax.sharding import AxisType from jax.sharding import PartitionSpec as PSpec import numpy as np @@ -43,14 +44,15 @@ # Inline a helper function for creating and manipulating test meshes from # JAX's internals. -def create_global_mesh(mesh_shape, axis_names): +def create_mesh(mesh_shape, axis_names, axis_types): size = math.prod(mesh_shape) if len(jax.devices()) < size: raise unittest.SkipTest(f'Test requires {size} global devices.') devices = sorted(jax.devices(), key=lambda d: d.id) mesh_devices = np.array(devices[:size]).reshape(mesh_shape) - global_mesh = jax.sharding.Mesh(mesh_devices, axis_names) - return global_mesh + return jax.sharding.Mesh( + mesh_devices, axis_names=axis_names, axis_types=axis_types + ) @contextlib.contextmanager @@ -69,11 +71,14 @@ def mesh_context(mesh: jax.sharding.Mesh, use_as_context: bool): Yields: The mesh to pass to the DrJax function under test. 
""" - if use_as_context: - with mesh: + with contextlib.ExitStack() as exit_stack: + if all(a == jax.sharding.AxisType.Explicit for a in mesh.axis_types): + exit_stack.enter_context(jax.sharding.set_mesh(mesh)) + if use_as_context: + exit_stack.enter_context(mesh) yield None - else: - yield mesh + else: + yield mesh class BroadcastShardingBehaviorTest(parameterized.TestCase): @@ -85,9 +90,18 @@ def setUp(self): placements_to_n_elements=self._placements, ) - @parameterized.parameters(True, False) - def test_broadcast_with_1x1_fully_replicates(self, mesh_as_context): - global_mesh = create_global_mesh([1, 1], [_CLIENTS_AXIS, _DATA_AXIS]) + @parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) + def test_broadcast_with_1x1_fully_replicates( + self, mesh_as_context, mesh_axes_type + ): + global_mesh = create_mesh( + [1, 1], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(mesh_axes_type, mesh_axes_type), + ) arg = jnp.zeros(shape=[_DATA_SIZE]) with mesh_context(global_mesh, mesh_as_context) as mesh: result = self._comp_factory.broadcast_to_placement( @@ -100,8 +114,10 @@ def test_broadcast_with_1x1_fully_replicates(self, mesh_as_context): self.assertTrue(sharding.is_fully_replicated) def test_broadcast_clients_with_jax_use_mesh(self): - global_mesh = create_global_mesh( - [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], [_CLIENTS_AXIS, _DATA_AXIS] + global_mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + [_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(AxisType.Auto, AxisType.Auto), ) arg = jnp.zeros(shape=[_DATA_SIZE]) with jax.set_mesh(global_mesh): @@ -122,10 +138,17 @@ def test_broadcast_clients_with_jax_use_mesh(self): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) - @parameterized.parameters(True, False) - def test_broadcast_clients_shards_along_clients(self, mesh_as_context): - global_mesh = create_global_mesh( - [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], [_CLIENTS_AXIS, _DATA_AXIS] + 
@parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) + def test_broadcast_clients_shards_along_clients( + self, mesh_as_context, mesh_axes_type + ): + global_mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(mesh_axes_type, mesh_axes_type), ) arg = jnp.zeros(shape=[_DATA_SIZE]) with mesh_context(global_mesh, mesh_as_context) as mesh: @@ -145,11 +168,16 @@ def test_broadcast_clients_shards_along_clients(self, mesh_as_context): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) - @parameterized.parameters(True, False) + @parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) def test_broadcast_preserves_sharding_with_no_clients_mesh( - self, mesh_as_context + self, mesh_as_context, mesh_axes_type ): - global_mesh = create_global_mesh([_DATA_AXIS_SIZE], [_DATA_AXIS]) + global_mesh = create_mesh( + [_DATA_AXIS_SIZE], axis_names=[_DATA_AXIS], axis_types=(mesh_axes_type,) + ) arg = jnp.zeros(shape=[_DATA_SIZE]) # Replicating a situation in which the caller's mesh has no clients axis; in # this case, we should preserve the sharding of any broadcast tensors, but @@ -182,13 +210,19 @@ def test_broadcast_preserves_sharding_with_no_clients_mesh( (_NUM_CLIENTS, _DATA_SIZE // _DATA_AXIS_SIZE), ) - @parameterized.parameters(True, False) + @parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) def test_broadcast_preserves_arg_sharding_with_clients_mesh( self, mesh_as_context, + mesh_axes_type, ): - global_mesh = create_global_mesh( - [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], [_CLIENTS_AXIS, _DATA_AXIS] + global_mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(mesh_axes_type, mesh_axes_type), ) arg = jnp.zeros(shape=[_DATA_SIZE]) arg_spec = PSpec(_DATA_AXIS) @@ -243,18 
+277,23 @@ def setUp(self): self._comp_factory = impls.PlacedComputations( placements_to_n_elements=self._placements, ) - self._global_mesh = create_global_mesh( - [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], [_CLIENTS_AXIS, _DATA_AXIS] - ) - @parameterized.parameters(True, False) - def test_map_respects_clients_sharding(self, mesh_as_context): + @parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) + def test_map_respects_clients_sharding(self, mesh_as_context, mesh_axes_type): arg1_at_c, arg2_at_c = _place_args_at_clients( jnp.zeros(shape=[_DATA_SIZE]), jnp.ones(shape=[_DATA_SIZE]), comp_factory=self._comp_factory, ) - with mesh_context(self._global_mesh, mesh_as_context) as mesh: + mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(mesh_axes_type, mesh_axes_type), + ) + with mesh_context(mesh, mesh_as_context) as mesh: result = self._comp_factory.map_to_placement( add, (arg1_at_c, arg2_at_c), _CLIENTS_AXIS, mesh ) @@ -269,13 +308,23 @@ def test_map_respects_clients_sharding(self, mesh_as_context): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) - @parameterized.parameters(True, False) - def test_map_zeros_like_respects_clients_sharding(self, mesh_as_context): + @parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) + def test_map_zeros_like_respects_clients_sharding( + self, mesh_as_context, mesh_axes_type + ): arg_at_c = _place_args_at_clients( jnp.ones(shape=[_DATA_SIZE]), comp_factory=self._comp_factory, ) - with mesh_context(self._global_mesh, mesh_as_context) as mesh: + mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(mesh_axes_type, mesh_axes_type), + ) + with mesh_context(mesh, mesh_as_context) as mesh: result = self._comp_factory.map_to_placement( jnp.zeros_like, arg_at_c, _CLIENTS_AXIS, mesh ) @@ -290,23 
+339,33 @@ def test_map_zeros_like_respects_clients_sharding(self, mesh_as_context): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) - @parameterized.parameters(True, False) - def test_map_respects_non_clients_sharding(self, mesh_as_context): + @parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) + def test_map_respects_non_clients_sharding( + self, mesh_as_context, mesh_axes_type + ): + mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(mesh_axes_type, mesh_axes_type), + ) arg_spec = PSpec(_DATA_AXIS) sharded_arg1 = jax.device_put( jnp.zeros(shape=[_DATA_SIZE]), - device=jax.sharding.NamedSharding(self._global_mesh, arg_spec), + device=jax.sharding.NamedSharding(mesh, arg_spec), ) sharded_arg2 = jax.device_put( jnp.ones(shape=[_DATA_SIZE]), - device=jax.sharding.NamedSharding(self._global_mesh, arg_spec), + device=jax.sharding.NamedSharding(mesh, arg_spec), ) arg1_at_c, arg2_at_c = _place_args_at_clients( sharded_arg1, sharded_arg2, comp_factory=self._comp_factory, ) - with mesh_context(self._global_mesh, mesh_as_context) as mesh: + with mesh_context(mesh, mesh_as_context) as mesh: result = self._comp_factory.map_to_placement( add, (arg1_at_c, arg2_at_c), _CLIENTS_AXIS, mesh ) @@ -328,23 +387,30 @@ def test_map_respects_non_clients_sharding(self, mesh_as_context): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), ) - @parameterized.parameters(True, False) + @parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) def test_map_forces_clients_sharding_with_model_parallelism( - self, - mesh_as_context, + self, mesh_as_context, mesh_axes_type ): + mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(mesh_axes_type, mesh_axes_type), + ) arg_spec = PSpec(_DATA_AXIS) sharded_arg1 = jax.device_put( 
jnp.zeros(shape=[_DATA_SIZE]), - device=jax.sharding.NamedSharding(self._global_mesh, arg_spec), + device=jax.sharding.NamedSharding(mesh, arg_spec), ) sharded_arg2 = jax.device_put( jnp.ones(shape=[_DATA_SIZE]), - device=jax.sharding.NamedSharding(self._global_mesh, arg_spec), + device=jax.sharding.NamedSharding(mesh, arg_spec), ) sharded_arg1 = jnp.tile(sharded_arg1, reps=[_NUM_CLIENTS, 1]) sharded_arg2 = jnp.tile(sharded_arg2, reps=[_NUM_CLIENTS, 1]) - with mesh_context(self._global_mesh, mesh_as_context) as mesh: + with mesh_context(mesh, mesh_as_context) as mesh: result = self._comp_factory.map_to_placement( add, (sharded_arg1, sharded_arg2), _CLIENTS_AXIS, mesh ) @@ -377,13 +443,23 @@ def test_map_forces_clients_sharding_with_model_parallelism( (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), ) - @parameterized.parameters(True, False) - def test_map_of_shard_map_fully_shards_result(self, mesh_as_context): + @parameterized.product( + mesh_as_context=[True, False], + mesh_axes_type=[AxisType.Auto, AxisType.Explicit], + ) + def test_map_of_shard_map_fully_shards_result( + self, mesh_as_context, mesh_axes_type + ): + mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(mesh_axes_type, mesh_axes_type), + ) arg_spec = PSpec(_DATA_AXIS) @functools.partial( shard_map, - mesh=self._global_mesh, + mesh=mesh, in_specs=(arg_spec, arg_spec), out_specs=arg_spec, ) @@ -392,15 +468,15 @@ def shard_map_add(x, y): sharded_arg1 = jax.device_put( jnp.zeros(shape=[_DATA_SIZE]), - device=jax.sharding.NamedSharding(self._global_mesh, arg_spec), + device=jax.sharding.NamedSharding(mesh, arg_spec), ) sharded_arg2 = jax.device_put( jnp.ones(shape=[_DATA_SIZE]), - device=jax.sharding.NamedSharding(self._global_mesh, arg_spec), + device=jax.sharding.NamedSharding(mesh, arg_spec), ) sharded_arg1 = jnp.tile(sharded_arg1, reps=[_NUM_CLIENTS, 1]) sharded_arg2 = jnp.tile(sharded_arg2, reps=[_NUM_CLIENTS, 1]) 
- with mesh_context(self._global_mesh, mesh_as_context) as mesh: + with mesh_context(mesh, mesh_as_context) as mesh: result = self._comp_factory.map_to_placement( shard_map_add, (sharded_arg1, sharded_arg2), _CLIENTS_AXIS, mesh )