From 80bdbd79919f1705fc3e911ebbb6880e533863fd Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Thu, 27 Feb 2025 08:01:03 -0800 Subject: [PATCH 1/5] Fix diffrax deprecation warning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update diffrax interface from deprecated discrete_terminating_event to the more general event interface. 🤖 Generated with Claude Code Co-Authored-By: Claude --- linx/background.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/linx/background.py b/linx/background.py index df48f7d..caff238 100644 --- a/linx/background.py +++ b/linx/background.py @@ -4,7 +4,7 @@ import equinox as eqx -from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, DiscreteTerminatingEvent +from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, EventState, DiscreteEvent import linx.thermo as thermo import linx.const as const @@ -110,13 +110,12 @@ def __call__( Y0 = (lna_init, T_EM_init, T_nu_init) def T_EM_check(state, **kwargs): - return state.y[1] < T_end sol = diffeqsolve( ODETerm(self.dY), solver, args=(lna_init, rho_extra_init), t0=0., t1=jnp.inf, dt0=None, y0=Y0, - saveat=SaveAt(steps=True), discrete_terminating_event = DiscreteTerminatingEvent(T_EM_check), + saveat=SaveAt(steps=True), event=DiscreteEvent(T_EM_check), stepsize_controller = PIDController( rtol=rtol, atol=atol ), From 31fe6a2f9b3c3b661d5d76f15459a3918bc8af9a Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Thu, 27 Feb 2025 08:01:03 -0800 Subject: [PATCH 2/5] Fix diffrax deprecation warning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update diffrax interface from deprecated discrete_terminating_event to the more general event interface. Also fix the test_numpyro.py import path using absolute paths. 🤖 Generated with Claude Code Co-Authored-By: Claude --- linx/background.py | 8 ++++---- pytest/test_numpyro.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/linx/background.py b/linx/background.py index caff238..f9c374f 100644 --- a/linx/background.py +++ b/linx/background.py @@ -4,7 +4,7 @@ import equinox as eqx -from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, EventState, DiscreteEvent +from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, Event import linx.thermo as thermo import linx.const as const @@ -109,13 +109,13 @@ def __call__( Y0 = (lna_init, T_EM_init, T_nu_init) - def T_EM_check(state, **kwargs): - return state.y[1] < T_end + def T_EM_check(t, y, args, **kwargs): + return y[1] < T_end sol = diffeqsolve( ODETerm(self.dY), solver, args=(lna_init, rho_extra_init), t0=0., t1=jnp.inf, dt0=None, y0=Y0, - saveat=SaveAt(steps=True), event=DiscreteEvent(T_EM_check), + saveat=SaveAt(steps=True), event=Event(T_EM_check), stepsize_controller = PIDController( rtol=rtol, atol=atol ), diff --git a/pytest/test_numpyro.py b/pytest/test_numpyro.py index d1c9c93..d7fd697 100644 --- a/pytest/test_numpyro.py +++ b/pytest/test_numpyro.py @@ -1,5 +1,8 @@ import sys -sys.path.append('../scripts') +import os +import pytest +# Add absolute path to the scripts directory +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'scripts')) from run_numpyro import run From 572d04d0bc380a95f6aa1232103400c946412e0d Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Thu, 27 Feb 2025 08:58:22 -0800 Subject: [PATCH 3/5] Update diffrax dependency to >=0.6.2 to support the Event API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with Claude Code Co-Authored-By: Claude --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f65f7e4..6633a71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ cloudpickle corner cosmopower-jax Cython -diffrax==0.4.1 +diffrax>=0.6.2 dynesty emcee equinox From 53798d504566a48b41d0c259d71cf9df1ddf1d57 Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Thu, 27 Feb 2025 10:41:08 -0800 Subject: [PATCH 4/5] Update jax and jaxlib dependencies to >=0.4.38 for compatibility with diffrax 0.6.2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with Claude Code Co-Authored-By: Claude --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6633a71..c1d681f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,8 +9,8 @@ diffrax>=0.6.2 dynesty emcee equinox -jax==0.4.28 -jaxlib==0.4.28 +jax>=0.4.38 +jaxlib>=0.4.38 jaxopt jaxtyping joblib From 17b52bd5424b083689e91f83bb6d0a80e72395ed Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Thu, 27 Feb 2025 11:13:30 -0800 Subject: [PATCH 5/5] Mark numpyro test as slow and configure CI to skip slow tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added @pytest.mark.slow decorator to test_numpyro.py - Created pytest.ini to register the 'slow' marker - Updated GitHub Actions workflow to skip slow tests with '-m "not slow"' 🤖 Generated with Claude Code Co-Authored-By: Claude --- .github/workflows/test.yml | 2 +- pytest.ini | 3 +++ pytest/test_numpyro.py | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 pytest.ini diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 201498f..52c08c0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,4 +26,4 @@ jobs: - name: Run tests run: | cd pytest - pytest . + pytest -m "not slow" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..95697f2 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + slow: marks tests as slow (deselect with '-m "not slow"') \ No newline at end of file diff --git a/pytest/test_numpyro.py b/pytest/test_numpyro.py index d7fd697..be1c4a4 100644 --- a/pytest/test_numpyro.py +++ b/pytest/test_numpyro.py @@ -7,6 +7,7 @@ from run_numpyro import run +@pytest.mark.slow def test_run_numpyro(): try: run(bbn_only=True, n_steps_svi=5, n_warmup_mcmc=5, n_samples_mcmc=5, n_chains=1)