Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions linx/thermo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import os

import numpy as np

import jax.numpy as jnp
import jax.numpy as jnp
import jax.lax as lax
from jax import grad, vmap, device_put, devices
import interpax

import linx.const as const
from linx.special_funcs import Li, K1, K2
Expand Down Expand Up @@ -633,10 +634,10 @@ def p_massive_MB(T, mu, m, g):

file_dir = os.path.dirname(__file__)

# QED Corrections
P_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_P_int.txt")
dPdT_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_dP_intdT.txt")
d2PdT2_QED_tab = np.loadtxt(file_dir+"/data/background/"+"QED_d2P_intdT2.txt")
# QED Corrections - flip to ensure monotonically increasing T for interpax.interp1d
P_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_P_int.txt"), axis=0)
dPdT_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_dP_intdT.txt"), axis=0)
d2PdT2_QED_tab = np.flip(np.loadtxt(file_dir+"/data/background/"+"QED_d2P_intdT2.txt"), axis=0)

# Effect of standard value of electron mass in scattering matrix elements
f_nue_scat_tab = np.loadtxt(file_dir+"/data/background/"+"nue_scatt.txt")
Expand Down Expand Up @@ -704,13 +705,13 @@ def rho_EM_std(T_g, mu=0, LO=True, NLO=True):
"""

corr_QED = (
-jnp.interp(
T_g, jnp.flip(P_QED_tab[:,0]),
jnp.flip(LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2])
)
+ T_g*jnp.interp(
T_g, jnp.flip(dPdT_QED_tab[:,0]),
jnp.flip(LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2])
-interpax.interp1d(
T_g, P_QED_tab[:,0],
LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2]
)
+ T_g*interpax.interp1d(
T_g, dPdT_QED_tab[:,0],
LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2]
)
)

Expand Down Expand Up @@ -745,9 +746,9 @@ def p_EM_std(T_g, mu=0, LO=True, NLO=True):
Units of MeV^4.
"""

corr_QED = jnp.interp(
T_g, jnp.flip(P_QED_tab[:,0]),
jnp.flip(LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2])
corr_QED = interpax.interp1d(
T_g, P_QED_tab[:,0],
LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2]
)

return (
Expand Down Expand Up @@ -781,9 +782,9 @@ def rho_plus_p_EM_std(T_g, mu=0, LO=True, NLO=True):
Units of MeV^4.
"""

corr_QED = T_g * jnp.interp(
T_g, jnp.flip(dPdT_QED_tab[:,0]),
jnp.flip(LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2])
corr_QED = T_g * interpax.interp1d(
T_g, dPdT_QED_tab[:,0],
LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2]
)

return (
Expand Down Expand Up @@ -1018,10 +1019,10 @@ def G(T_1, mu_1, T_2, mu_2):

def G_nue_with_me(T_1, mu_1, T_2, mu_2):

def interp_f(f_tab):

return jnp.interp(
T_1, f_tab[:,0], f_tab[:,1], left=f_tab[0,1], right=f_tab[-1,1]
def interp_f(f_tab):
# Tables have boundary values 0.0 (low T) and 1.0 (high T)
return interpax.interp1d(
T_1, f_tab[:,0], f_tab[:,1], extrap=(0.0, 1.0)
)

f_nue_ann = lax.cond(
Expand All @@ -1042,12 +1043,12 @@ def interp_f(f_tab):
)
)

def G_numt_with_me(T_1, mu_1, T_2, mu_2):

def interp_f(f_tab):
def G_numt_with_me(T_1, mu_1, T_2, mu_2):

return jnp.interp(
T_1, f_tab[:,0], f_tab[:,1], left=f_tab[0,1], right=f_tab[-1,1]
def interp_f(f_tab):
# Tables have boundary values 0.0 (low T) and 1.0 (high T)
return interpax.interp1d(
T_1, f_tab[:,0], f_tab[:,1], extrap=(0.0, 1.0)
)

f_numt_ann = lax.cond(
Expand Down