diff --git a/drjax/_src/impls_test.py b/drjax/_src/impls_test.py index cad507d..ff91f6a 100644 --- a/drjax/_src/impls_test.py +++ b/drjax/_src/impls_test.py @@ -48,13 +48,10 @@ def temp_sens_example(m, t): ) return comp_factory.mean_from_placement(total_over) - key = jax.random.PRNGKey(2) - random_measurements = jax.random.uniform( - key, shape=[self._placements['clients']] - ) + measurements = jnp.arange(self._placements['clients']) self.assertEqual( - temp_sens_example(random_measurements, jnp.array(0.5)), 0.53 + temp_sens_example(measurements, jnp.median(measurements)), 0.5 ) def test_runs_fake_training(self):