diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 201498f..52c08c0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,4 +26,4 @@ jobs: - name: Run tests run: | cd pytest - pytest . + pytest -m "not slow" diff --git a/linx/background.py b/linx/background.py index df48f7d..f9c374f 100644 --- a/linx/background.py +++ b/linx/background.py @@ -4,7 +4,7 @@ import equinox as eqx -from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, DiscreteTerminatingEvent +from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, Event import linx.thermo as thermo import linx.const as const @@ -109,14 +109,13 @@ def __call__( Y0 = (lna_init, T_EM_init, T_nu_init) - def T_EM_check(state, **kwargs): - - return state.y[1] < T_end + def T_EM_check(t, y, args, **kwargs): + return y[1] < T_end sol = diffeqsolve( ODETerm(self.dY), solver, args=(lna_init, rho_extra_init), t0=0., t1=jnp.inf, dt0=None, y0=Y0, - saveat=SaveAt(steps=True), discrete_terminating_event = DiscreteTerminatingEvent(T_EM_check), + saveat=SaveAt(steps=True), event=Event(T_EM_check), stepsize_controller = PIDController( rtol=rtol, atol=atol ), diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..95697f2 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + slow: marks tests as slow (deselect with '-m "not slow"') \ No newline at end of file diff --git a/pytest/test_numpyro.py b/pytest/test_numpyro.py index d1c9c93..be1c4a4 100644 --- a/pytest/test_numpyro.py +++ b/pytest/test_numpyro.py @@ -1,9 +1,13 @@ import sys -sys.path.append('../scripts') +import os +import pytest +# Add absolute path to the scripts directory +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'scripts')) from run_numpyro import run +@pytest.mark.slow def test_run_numpyro(): try: run(bbn_only=True, n_steps_svi=5, n_warmup_mcmc=5, n_samples_mcmc=5, n_chains=1) diff --git a/requirements.txt b/requirements.txt index f65f7e4..c1d681f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,12 +5,12 @@ cloudpickle corner cosmopower-jax Cython -diffrax==0.4.1 +diffrax>=0.6.2 dynesty emcee equinox -jax==0.4.28 -jaxlib==0.4.28 +jax>=0.4.38 +jaxlib>=0.4.38 jaxopt jaxtyping joblib