diff --git a/drjax/_src/primitives.py b/drjax/_src/primitives.py index ac7ff0f..fdb01d8 100644 --- a/drjax/_src/primitives.py +++ b/drjax/_src/primitives.py @@ -56,7 +56,6 @@ def _register_broadcast_impls( broadcast_prim_fn: BroadcastType, broadcast_array_eval: BroadcastType, sum_prim_fn: AggType, - placement_str: str, n_elements: int, ) -> None: """Registers implementations for the broadcast primitive. @@ -76,27 +75,13 @@ def _register_broadcast_impls( sum_prim_fn: A callable which binds its arguments to the summation primitive from the placement inserted by this broadcast. Similar to `broadcast_prim_fn`. - placement_str: The name of the placement which this broadcast targets. n_elements: The number of elements present at the placement which this broadcast targets. """ def broadcast_abstract_eval(xs, *, mesh): - sharding_axis = ( - placement_str - if impls._placement_axis_in_mesh(mesh, placement_str) # pylint: disable=protected-access - else None - ) - new_sharding = xs.sharding.update( - spec=jax.sharding.PartitionSpec(sharding_axis, *xs.sharding.spec) - ) - return core.ShapedArray( - shape=(n_elements,) + xs.shape, - dtype=xs.dtype, - weak_type=xs.weak_type, - sharding=new_sharding, - memory_space=xs.memory_space, - ) + del mesh + return core.ShapedArray((n_elements,) + xs.shape, xs.dtype) # Abstract eval rule. broadcast_p.def_abstract_eval(broadcast_abstract_eval) @@ -116,11 +101,11 @@ def broadcast_jvp(primals_in, tangents_in, mesh): ad.primitive_jvps[broadcast_p] = broadcast_jvp def broadcast_vjp(cotangents_out, primals_in, mesh): - del mesh # Unused. + del mesh if isinstance(cotangents_out, jax.interpreters.ad.Zero): # We are differerentiating back through a broadcast; the incoming value, # therefore, has the right shape and dtype for the Zero we generate. - return (jax.interpreters.ad.Zero.from_primal_value(primals_in.aval),) + return (jax.interpreters.ad.Zero(primals_in.aval),) # This implementation *must* use the sum_prim_fn, rather than the array # implementation of summation, to result in a reduce_sum in the Jaxpr. return (sum_prim_fn(cotangents_out),) @@ -172,21 +157,11 @@ def _register_single_arg_agg_impls( """ def agg_abstract_eval(xs): - - def aval_with_new_sharding(x): - # We slice away the first dimension in doing the reduction; its gone! - new_sharding = x.sharding.update( - spec=jax.sharding.PartitionSpec(*x.sharding.spec[1:]) - ) - return core.ShapedArray( - shape=x.shape[1:], - dtype=x.dtype, - weak_type=x.weak_type, - sharding=new_sharding, - memory_space=x.memory_space, - ) - - return jax.tree.map(aval_with_new_sharding, xs) + return jax.tree_util.tree_map( + # We slice away the first dimension in doing the reduction; its gone! + lambda x: core.ShapedArray(x.shape[1:], x.dtype), + xs, + ) # Abstract eval rule agg_p.def_abstract_eval(agg_abstract_eval) @@ -219,7 +194,7 @@ def agg_vjp(cotangents_out, primals_in): # generate. This is always correct if jax's symbolic Zero is a static # concept, depending on data flow in the program (rather than e.g. runtime # values). - return (jax.interpreters.ad.Zero.from_primal_value(primals_in),) + return (jax.interpreters.ad.Zero(primals_in.aval),) return (vjp_impl(cotangents_out),) ad.primitive_transposes[agg_p] = agg_vjp @@ -282,7 +257,6 @@ def broadcast_array_eval(x, *, mesh): broadcast_prim_fn, broadcast_array_eval, sum_prim_fn, - placement_str, n_elements, ) diff --git a/drjax/_src/primitives_test.py b/drjax/_src/primitives_test.py index 99cb407..daa1eca 100644 --- a/drjax/_src/primitives_test.py +++ b/drjax/_src/primitives_test.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence -import functools - from absl.testing import absltest from absl.testing import parameterized import chex @@ -22,8 +19,6 @@ from drjax._src import primitives import jax from jax import numpy as jnp -from jax.sharding import AxisType # pylint: disable=g-importing-member -import numpy as np def _jaxpr_has_primitive(jaxpr, prim_name: str): @@ -37,33 +32,6 @@ def _jaxpr_has_primitive(jaxpr, prim_name: str): return False -def create_mesh( - axis_type: AxisType, -) -> jax.sharding.Mesh: - return jax.sharding.Mesh( - np.asarray(jax.devices()).reshape(2, 4), - axis_names=('clients', 'data'), - axis_types=(axis_type, axis_type), - ) - - -def run_in_mesh(mesh_axes_types: Sequence[AxisType]): - - def _decorator(fn): - - @functools.wraps(fn) - def _wrapped(self, *args, **kwargs): - for mesh_axes_type in mesh_axes_types: - with self.subTest(f'{mesh_axes_type=}'): - mesh = create_mesh(mesh_axes_type) - with jax.set_mesh(mesh), mesh: - fn(self, *args, **kwargs) - - return _wrapped - - return _decorator - - class PrimitivesActingOnArraysTest(parameterized.TestCase): def setUp(self): @@ -76,7 +44,6 @@ def setUp(self): {'clients': self._n_clients}, ) - @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_clients_evaluation(self): fn = self._primdefs['broadcast_clients'] # Check that this function is callable. @@ -91,13 +58,11 @@ def test_broadcast_clients_evaluation(self): chex.assert_trees_all_close( jax.jacfwd(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) ) - # Also that it's reverse-diffable. chex.assert_trees_all_close( jax.jacrev(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) ) - @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_clients_closure_under_fad(self): fn = self._primdefs['broadcast_clients'] # Check that the forward and reverse-mode derivatives generate the expected @@ -107,7 +72,6 @@ def test_broadcast_clients_closure_under_fad(self): rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(jnp.array(1.0)) self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'sum_from_clients')) - @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_sum_from_clients_evaluation(self): fn = self._primdefs['sum_from_clients'] clients_ones = jnp.ones(shape=[self._n_clients, 1]) @@ -128,7 +92,6 @@ def test_sum_from_clients_evaluation(self): jax.jacrev(fn)(clients_ones), jnp.ones(shape=[1, self._n_clients, 1]) ) - @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_and_sum_from_clients_eval(self): fn = self._primdefs['sum_from_clients'] @@ -148,7 +111,6 @@ def _broadcast_then_sum(x): jnp.array([[1.0 * self._n_clients]]), ) - @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_sum_from_clients_closure_under_fad(self): # Check that the forward and reverse-mode derivatives generate the expected # primitives. @@ -159,7 +121,6 @@ def test_sum_from_clients_closure_under_fad(self): rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(clients_ones) self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'broadcast_clients')) - @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_mean_from_clients_eval(self): fn = self._primdefs['mean_from_clients'] clients_ones = jnp.ones(shape=[self._n_clients, 1]) @@ -173,7 +134,6 @@ def test_mean_from_clients_eval(self): 1 / self._n_clients * jnp.ones(shape=[1, self._n_clients, 1]), ) - @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_then_mean_from_clients_eval(self): fn = self._primdefs['mean_from_clients'] @@ -191,7 +151,6 @@ def _broadcast_then_sum(x): jnp.array([[1.0]]), ) - @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_mean_from_clients_closure_under_fad(self): # Check that the forward and reverse-mode derivatives generate the expected # primitives. @@ -244,10 +203,5 @@ def ignore_prim_result(x): ) -# This allows us to test sharding behavior across multiple devices. -def setUpModule(): - chex.set_n_cpu_devices(8) - - if __name__ == '__main__': absltest.main()