diff --git a/drjax/_src/api_test.py b/drjax/_src/api_test.py index 36eca96..0685034 100644 --- a/drjax/_src/api_test.py +++ b/drjax/_src/api_test.py @@ -61,18 +61,17 @@ def test_temp_sens_example(self, placement_name): def one_if_over(threshold, value): return jax.lax.cond(value > threshold, lambda: 1.0, lambda: 0.0) - @drjax_program(placements={placement_name: 100}) + placement_dim = 100 + + @drjax_program(placements={placement_name: placement_dim}) def temp_sens_example(threshold, values): threshold_at_clients = api.broadcast(threshold) values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) return api.reduce_mean(values_over) - key = jax.random.PRNGKey(2) - random_measurements = jax.random.uniform(key, shape=[100]) + measurements = jnp.arange(placement_dim) - self.assertEqual( - temp_sens_example(jnp.array(0.5), random_measurements), 0.53 - ) + self.assertEqual(temp_sens_example(24, measurements), 0.75) def test_temp_sens_example_multiple_placement_values(self, placement_name): def one_if_over(threshold, value): @@ -91,31 +90,24 @@ def temp_sens_example_10_clients(threshold, values): values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) return api.reduce_mean(values_over) - key = jax.random.PRNGKey(2) - random_measurements_100 = jax.random.uniform(key, shape=[100]) - random_measurements_10 = jax.random.uniform(key, shape=[10]) + measurements_100 = jnp.arange(100) + measurements_10 = jnp.arange(10) + self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75) self.assertEqual( - temp_sens_example_100_clients(jnp.array(0.5), random_measurements_100), - 0.53, - ) - self.assertEqual( - temp_sens_example_10_clients(jnp.array(0.5), random_measurements_10), - 0.4, + temp_sens_example_10_clients(3, measurements_10), + 0.6, ) # We should be able to recover the original result flipping back to the # original function. - self.assertEqual( - temp_sens_example_100_clients(jnp.array(0.5), random_measurements_100), - 0.53, - ) + self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75) def test_multiple_placements_raises(self, placement_name): with self.assertRaises(ValueError): @drjax_program(placements={placement_name: 1, placement_name + "x": 1}) - def test(values): + def _(values): return api.reduce_mean(values) def test_raises_outside_program_context(self, placement_name):