Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions drjax/_src/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,17 @@ def test_temp_sens_example(self, placement_name):
def one_if_over(threshold, value):
return jax.lax.cond(value > threshold, lambda: 1.0, lambda: 0.0)

@drjax_program(placements={placement_name: 100})
placement_dim = 100

@drjax_program(placements={placement_name: placement_dim})
def temp_sens_example(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

key = jax.random.PRNGKey(2)
random_measurements = jax.random.uniform(key, shape=[100])
measurements = jnp.arange(placement_dim)

self.assertEqual(
temp_sens_example(jnp.array(0.5), random_measurements), 0.53
)
self.assertEqual(temp_sens_example(24, measurements), 0.75)

def test_temp_sens_example_multiple_placement_values(self, placement_name):
def one_if_over(threshold, value):
Expand All @@ -91,31 +90,24 @@ def temp_sens_example_10_clients(threshold, values):
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

key = jax.random.PRNGKey(2)
random_measurements_100 = jax.random.uniform(key, shape=[100])
random_measurements_10 = jax.random.uniform(key, shape=[10])
measurements_100 = jnp.arange(100)
measurements_10 = jnp.arange(10)

self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75)
self.assertEqual(
temp_sens_example_100_clients(jnp.array(0.5), random_measurements_100),
0.53,
)
self.assertEqual(
temp_sens_example_10_clients(jnp.array(0.5), random_measurements_10),
0.4,
temp_sens_example_10_clients(3, measurements_10),
0.6,
)
# We should be able to recover the original result flipping back to the
# original function.
self.assertEqual(
temp_sens_example_100_clients(jnp.array(0.5), random_measurements_100),
0.53,
)
self.assertEqual(temp_sens_example_100_clients(24, measurements_100), 0.75)

def test_multiple_placements_raises(self, placement_name):

with self.assertRaises(ValueError):

@drjax_program(placements={placement_name: 1, placement_name + "x": 1})
def test(values):
def _(values):
return api.reduce_mean(values)

def test_raises_outside_program_context(self, placement_name):
Expand Down
Loading