From 6a72eeed1ce48ab37afdc4c398de9e5d97b151fb Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Thu, 22 Jan 2026 17:41:13 -0500 Subject: [PATCH] Use interpax.interp1d to support non-monotonic temperature evolution Replace jnp.interp with interpax.interp1d in abundances.py, nuclear.py, and weak_rates.py to support reheating scenarios where photon temperature may not be monotonically decreasing. jnp.interp requires monotonically increasing x-coordinates, which breaks when T_g_vec is non-monotonic (e.g., during reheating where temperature can increase before decreasing again). The fix sorts arrays by the x-coordinate before interpolation. Changes: - Add interpax dependency to requirements.txt - abundances.py: Sort by log(T_g) for a_start, t_start, t_end interpolations and by log(rho_tot) for P_tot interpolations in get_t/get_a - weak_rates.py: Sort by Tg for neutrino-to-photon temperature ratio lookup - nuclear.py: Sort by T_interval for weak rate interpolations Note: thermo.py and reactions.py are unchanged as they use pre-tabulated data that is guaranteed to be monotonic. Closes #15 Co-Authored-By: Claude Opus 4.5 --- linx/abundances.py | 88 +++++++++++++++++++++++++++++++--------------- linx/nuclear.py | 27 +++++++++----- linx/weak_rates.py | 21 +++++++---- requirements.txt | 1 + 4 files changed, 95 insertions(+), 42 deletions(-) diff --git a/linx/abundances.py b/linx/abundances.py index 1e70b47..466910e 100644 --- a/linx/abundances.py +++ b/linx/abundances.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import equinox as eqx from diffrax import diffeqsolve, ODETerm, Tsit5, Kvaerno3, PIDController, SaveAt +import interpax import linx.nuclear as nucl import linx.const as const @@ -192,32 +193,44 @@ def __call__( # These are in MeV T_g_vec = thermo.T_g(rho_g_vec) - T_nu_vec = thermo.T_nu(rho_nu_vec) - - a_start = jnp.exp( - jnp.interp( - jnp.log(T_start), - jnp.flip(jnp.log(T_g_vec)), - jnp.flip(jnp.log(a_vec)), - left=jnp.log(a_vec[-1]), right=jnp.log(a_vec[0]) + T_nu_vec = thermo.T_nu(rho_nu_vec) + + # Sort by log(T_g) to handle non-monotonic temperature evolution + # (e.g., in reheating scenarios). interpax.interp1d requires + # monotonically increasing x coordinates. + log_T_g_vec = jnp.log(T_g_vec) + sort_idx = jnp.argsort(log_T_g_vec) + log_T_g_sorted = log_T_g_vec[sort_idx] + log_a_sorted = jnp.log(a_vec)[sort_idx] + log_t_sorted = jnp.log(t_vec)[sort_idx] + + a_start = jnp.exp( + interpax.interp1d( + jnp.log(T_start), + log_T_g_sorted, + log_a_sorted, + method='linear', + extrap=True ) ) t_start = jnp.exp( - jnp.interp( - jnp.log(T_start), - jnp.flip(jnp.log(T_g_vec)), - jnp.flip(jnp.log(t_vec)), - left=jnp.log(t_vec[-1]), right=jnp.log(t_vec[0]) + interpax.interp1d( + jnp.log(T_start), + log_T_g_sorted, + log_t_sorted, + method='linear', + extrap=True ) ) t_end = jnp.exp( - jnp.interp( - jnp.log(T_end), - jnp.flip(jnp.log(T_g_vec)), - jnp.flip(jnp.log(t_vec)), - left=jnp.log(t_vec[-1]), right=jnp.log(t_vec[0]) + interpax.interp1d( + jnp.log(T_end), + log_T_g_sorted, + log_t_sorted, + method='linear', + extrap=True ) ) @@ -308,13 +321,22 @@ def get_t(self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec): P_tot_vec = p_EM_std_v(T_g_vec) + 3 * (rho_nu_vec/3) + P_NP_vec - def P_tot(rho_tot): + # Sort by log(rho_tot) to handle non-monotonic energy density evolution + # (e.g., in reheating scenarios) + log_rho_tot_vec = jnp.log(rho_tot_vec) + sort_idx_rho = jnp.argsort(log_rho_tot_vec) + log_rho_sorted = log_rho_tot_vec[sort_idx_rho] + log_P_sorted = jnp.log(P_tot_vec)[sort_idx_rho] + + def P_tot(rho_tot): return jnp.exp( - jnp.interp( - jnp.log(rho_tot), - jnp.flip(jnp.log(rho_tot_vec)), - jnp.flip(jnp.log(P_tot_vec)) + interpax.interp1d( + jnp.log(rho_tot), + log_rho_sorted, + log_P_sorted, + method='linear', + extrap=True ) ) @@ -374,14 +396,24 @@ def get_a(self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec): P_tot_vec = p_EM_std_v(T_g_vec) + 3 * (rho_nu_vec/3) + P_NP_vec - def P_tot(rho_tot): + # Sort by log(rho_tot) to handle non-monotonic energy density evolution + # (e.g., in reheating scenarios) + log_rho_tot_vec = jnp.log(rho_tot_vec) + sort_idx_rho = jnp.argsort(log_rho_tot_vec) + log_rho_sorted = log_rho_tot_vec[sort_idx_rho] + log_P_sorted = jnp.log(P_tot_vec)[sort_idx_rho] + + def P_tot(rho_tot): return jnp.exp( - jnp.interp( - jnp.log(rho_tot), jnp.flip(jnp.log(rho_tot_vec)), - jnp.flip(jnp.log(P_tot_vec)) + interpax.interp1d( + jnp.log(rho_tot), + log_rho_sorted, + log_P_sorted, + method='linear', + extrap=True ) - ) + ) def dlna_prime(rho_tot, t, args): diff --git a/linx/nuclear.py b/linx/nuclear.py index 9c6c111..f1b71a1 100644 --- a/linx/nuclear.py +++ b/linx/nuclear.py @@ -1,7 +1,8 @@ import jax.numpy as jnp import equinox as eqx +import interpax -import linx.const as const +import linx.const as const from linx.reactions import Reaction class NuclearRates(eqx.Module): @@ -179,15 +180,25 @@ def __call__( nuclear_rates_q = jnp.array([0. for _ in self.reactions]) dYdt_vec = jnp.zeros(len(Y)) - - _nTOp_frwrd = jnp.interp( - T_t, jnp.flip(T_interval), jnp.flip(nTOp_frwrd_vec), - left=nTOp_frwrd_vec[-1],right=nTOp_frwrd_vec[0] + + # Sort by T_interval to handle non-monotonic temperature evolution + # (e.g., in reheating scenarios). Note: T_interval is typically + # monotonically decreasing from logspace, but we sort to be safe. + sort_idx = jnp.argsort(T_interval) + T_sorted = T_interval[sort_idx] + nTOp_frwrd_sorted = nTOp_frwrd_vec[sort_idx] + nTOp_bkwrd_sorted = nTOp_bkwrd_vec[sort_idx] + + _nTOp_frwrd = interpax.interp1d( + T_t, T_sorted, nTOp_frwrd_sorted, + method='linear', + extrap=True ) / (const.tau_n * tau_n_fac) - _nTOp_bkwrd = jnp.interp( - T_t, jnp.flip(T_interval), jnp.flip(nTOp_bkwrd_vec), - left=nTOp_bkwrd_vec[-1],right=nTOp_bkwrd_vec[0] + _nTOp_bkwrd = interpax.interp1d( + T_t, T_sorted, nTOp_bkwrd_sorted, + method='linear', + extrap=True ) / (const.tau_n * tau_n_fac) # These functions take temperature in K. diff --git a/linx/weak_rates.py b/linx/weak_rates.py index 9d487f4..71940c7 100644 --- a/linx/weak_rates.py +++ b/linx/weak_rates.py @@ -6,12 +6,13 @@ import jax.numpy as jnp # As of jax v0.4.24, jax.numpy.trapz has been deprecated, and replaced # by scipy.integrate.trapezoid. -try: - import jax.numpy.trapz as trapz -except ImportError: +try: + import jax.numpy.trapz as trapz +except ImportError: from jax.scipy.integrate import trapezoid as trapz import equinox as eqx +import interpax import linx.const as const from linx.special_funcs import gamma @@ -201,11 +202,19 @@ def nTOp_rates(self, Tg, T_vec_ref): Tg_vec_ref, Tnu_vec_ref = T_vec_ref Tnu_of_Tg_ref = Tnu_vec_ref / Tg_vec_ref - + + # Sort by Tg to handle non-monotonic temperature evolution + # (e.g., in reheating scenarios) + sort_idx = jnp.argsort(Tg_vec_ref) + Tg_sorted = Tg_vec_ref[sort_idx] + Tnu_of_Tg_sorted = Tnu_of_Tg_ref[sort_idx] + x = me / Tg xnu = me / ( - Tg * jnp.interp( - Tg, jnp.flip(Tg_vec_ref), jnp.flip(Tnu_of_Tg_ref), left=Tnu_of_Tg_ref[-1], right=Tnu_of_Tg_ref[0] + Tg * interpax.interp1d( + Tg, Tg_sorted, Tnu_of_Tg_sorted, + method='linear', + extrap=True ) ) diff --git a/requirements.txt b/requirements.txt index c1d681f..e9cdbf1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ diffrax>=0.6.2 dynesty emcee equinox +interpax jax>=0.4.38 jaxlib>=0.4.38 jaxopt