From 3e7ceaadbb158a7be056daade62bd72ed04a27ef Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Fri, 23 Jan 2026 23:49:16 -0500 Subject: [PATCH] Replace jnp.interp with interpax.interp1d in thermo.py Follow-up to PR #16 which replaced jnp.interp with interpax in abundances.py. This change applies the same improvement to thermo.py for consistency and better performance. Changes: - Import interpax module - Flip QED correction tables at load time (instead of at each call) for monotonically increasing x coordinates required by interpax - Replace 6 jnp.interp calls with interpax.interp1d: - rho_EM_std: 2 calls for QED corrections - p_EM_std: 1 call for QED correction - rho_plus_p_EM_std: 1 call for QED correction - G_nue_with_me: 1 call for collision factor interpolation - G_numt_with_me: 1 call for collision factor interpolation Co-Authored-By: Claude Opus 4.5 --- linx/thermo.py | 57 +++++++++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/linx/thermo.py b/linx/thermo.py index 85cff9d..8582a53 100644 --- a/linx/thermo.py +++ b/linx/thermo.py @@ -1,10 +1,11 @@ -import os +import os import numpy as np -import jax.numpy as jnp +import jax.numpy as jnp import jax.lax as lax from jax import grad, vmap, device_put, devices +import interpax import linx.const as const from linx.special_funcs import Li, K1, K2 @@ -633,10 +634,10 @@ def p_massive_MB(T, mu, m, g): file_dir = os.path.dirname(__file__) -# QED Corrections -P_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_P_int.txt") -dPdT_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_dP_intdT.txt") -d2PdT2_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_d2P_intdT2.txt") +# QED Corrections - flip to ensure monotonically increasing T for interpax.interp1d +P_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_P_int.txt"), axis=0) +dPdT_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_dP_intdT.txt"), axis=0) +d2PdT2_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_d2P_intdT2.txt"), axis=0) # Effect of standard value of electron mass in scattering matrix elements f_nue_scat_tab = np.loadtxt(file_dir+"/data/background/"+"nue_scatt.txt") @@ -704,13 +705,13 @@ def rho_EM_std(T_g, mu=0, LO=True, NLO=True): """ corr_QED = ( - -jnp.interp( - T_g, jnp.flip(P_QED_tab[:,0]), - jnp.flip(LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2]) - ) - + T_g*jnp.interp( - T_g, jnp.flip(dPdT_QED_tab[:,0]), - jnp.flip(LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2]) + -interpax.interp1d( + T_g, P_QED_tab[:,0], + LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2] + ) + + T_g*interpax.interp1d( + T_g, dPdT_QED_tab[:,0], + LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2] ) ) @@ -745,9 +746,9 @@ def p_EM_std(T_g, mu=0, LO=True, NLO=True): Units of MeV^4. """ - corr_QED = jnp.interp( - T_g, jnp.flip(P_QED_tab[:,0]), - jnp.flip(LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2]) + corr_QED = interpax.interp1d( + T_g, P_QED_tab[:,0], + LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2] ) return ( @@ -781,9 +782,9 @@ def rho_plus_p_EM_std(T_g, mu=0, LO=True, NLO=True): Units of MeV^4. """ - corr_QED = T_g * jnp.interp( - T_g, jnp.flip(dPdT_QED_tab[:,0]), - jnp.flip(LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2]) + corr_QED = T_g * interpax.interp1d( + T_g, dPdT_QED_tab[:,0], + LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2] ) return ( @@ -1018,10 +1019,10 @@ def G(T_1, mu_1, T_2, mu_2): def G_nue_with_me(T_1, mu_1, T_2, mu_2): - def interp_f(f_tab): - - return jnp.interp( - T_1, f_tab[:,0], f_tab[:,1], left=f_tab[0,1], right=f_tab[-1,1] + def interp_f(f_tab): + # Tables have boundary values 0.0 (low T) and 1.0 (high T) + return interpax.interp1d( + T_1, f_tab[:,0], f_tab[:,1], extrap=(0.0, 1.0) ) f_nue_ann = lax.cond( @@ -1042,12 +1043,12 @@ def interp_f(f_tab): ) ) - def G_numt_with_me(T_1, mu_1, T_2, mu_2): - - def interp_f(f_tab): + def G_numt_with_me(T_1, mu_1, T_2, mu_2): - return jnp.interp( - T_1, f_tab[:,0], f_tab[:,1], left=f_tab[0,1], right=f_tab[-1,1] + def interp_f(f_tab): + # Tables have boundary values 0.0 (low T) and 1.0 (high T) + return interpax.interp1d( + T_1, f_tab[:,0], f_tab[:,1], extrap=(0.0, 1.0) ) f_numt_ann = lax.cond(