Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 60 additions & 28 deletions linx/abundances.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
import equinox as eqx
from diffrax import diffeqsolve, ODETerm, Tsit5, Kvaerno3, PIDController, SaveAt
import interpax

import linx.nuclear as nucl
import linx.const as const
Expand Down Expand Up @@ -192,32 +193,44 @@ def __call__(

# These are in MeV
T_g_vec = thermo.T_g(rho_g_vec)
T_nu_vec = thermo.T_nu(rho_nu_vec)

a_start = jnp.exp(
jnp.interp(
jnp.log(T_start),
jnp.flip(jnp.log(T_g_vec)),
jnp.flip(jnp.log(a_vec)),
left=jnp.log(a_vec[-1]), right=jnp.log(a_vec[0])
T_nu_vec = thermo.T_nu(rho_nu_vec)

# Sort by log(T_g) to handle non-monotonic temperature evolution
# (e.g., in reheating scenarios). interpax.interp1d requires
# monotonically increasing x coordinates.
log_T_g_vec = jnp.log(T_g_vec)
sort_idx = jnp.argsort(log_T_g_vec)
log_T_g_sorted = log_T_g_vec[sort_idx]
log_a_sorted = jnp.log(a_vec)[sort_idx]
log_t_sorted = jnp.log(t_vec)[sort_idx]

a_start = jnp.exp(
interpax.interp1d(
jnp.log(T_start),
log_T_g_sorted,
log_a_sorted,
method='linear',
extrap=True
)
)

t_start = jnp.exp(
jnp.interp(
jnp.log(T_start),
jnp.flip(jnp.log(T_g_vec)),
jnp.flip(jnp.log(t_vec)),
left=jnp.log(t_vec[-1]), right=jnp.log(t_vec[0])
interpax.interp1d(
jnp.log(T_start),
log_T_g_sorted,
log_t_sorted,
method='linear',
extrap=True
)
)

t_end = jnp.exp(
jnp.interp(
jnp.log(T_end),
jnp.flip(jnp.log(T_g_vec)),
jnp.flip(jnp.log(t_vec)),
left=jnp.log(t_vec[-1]), right=jnp.log(t_vec[0])
interpax.interp1d(
jnp.log(T_end),
log_T_g_sorted,
log_t_sorted,
method='linear',
extrap=True
)
)

Expand Down Expand Up @@ -308,13 +321,22 @@ def get_t(self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec):

P_tot_vec = p_EM_std_v(T_g_vec) + 3 * (rho_nu_vec/3) + P_NP_vec

def P_tot(rho_tot):
# Sort by log(rho_tot) to handle non-monotonic energy density evolution
# (e.g., in reheating scenarios)
log_rho_tot_vec = jnp.log(rho_tot_vec)
sort_idx_rho = jnp.argsort(log_rho_tot_vec)
log_rho_sorted = log_rho_tot_vec[sort_idx_rho]
log_P_sorted = jnp.log(P_tot_vec)[sort_idx_rho]

def P_tot(rho_tot):

return jnp.exp(
jnp.interp(
jnp.log(rho_tot),
jnp.flip(jnp.log(rho_tot_vec)),
jnp.flip(jnp.log(P_tot_vec))
interpax.interp1d(
jnp.log(rho_tot),
log_rho_sorted,
log_P_sorted,
method='linear',
extrap=True
)
)

Expand Down Expand Up @@ -374,14 +396,24 @@ def get_a(self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec):

P_tot_vec = p_EM_std_v(T_g_vec) + 3 * (rho_nu_vec/3) + P_NP_vec

def P_tot(rho_tot):
# Sort by log(rho_tot) to handle non-monotonic energy density evolution
# (e.g., in reheating scenarios)
log_rho_tot_vec = jnp.log(rho_tot_vec)
sort_idx_rho = jnp.argsort(log_rho_tot_vec)
log_rho_sorted = log_rho_tot_vec[sort_idx_rho]
log_P_sorted = jnp.log(P_tot_vec)[sort_idx_rho]

def P_tot(rho_tot):

return jnp.exp(
jnp.interp(
jnp.log(rho_tot), jnp.flip(jnp.log(rho_tot_vec)),
jnp.flip(jnp.log(P_tot_vec))
interpax.interp1d(
jnp.log(rho_tot),
log_rho_sorted,
log_P_sorted,
method='linear',
extrap=True
)
)
)

def dlna_prime(rho_tot, t, args):

Expand Down
27 changes: 19 additions & 8 deletions linx/nuclear.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import jax.numpy as jnp
import equinox as eqx
import interpax

import linx.const as const
import linx.const as const
from linx.reactions import Reaction

class NuclearRates(eqx.Module):
Expand Down Expand Up @@ -179,15 +180,25 @@ def __call__(
nuclear_rates_q = jnp.array([0. for _ in self.reactions])

dYdt_vec = jnp.zeros(len(Y))

_nTOp_frwrd = jnp.interp(
T_t, jnp.flip(T_interval), jnp.flip(nTOp_frwrd_vec),
left=nTOp_frwrd_vec[-1],right=nTOp_frwrd_vec[0]

# Sort by T_interval to handle non-monotonic temperature evolution
# (e.g., in reheating scenarios). Note: T_interval is typically
# monotonically decreasing from logspace, but we sort to be safe.
sort_idx = jnp.argsort(T_interval)
T_sorted = T_interval[sort_idx]
nTOp_frwrd_sorted = nTOp_frwrd_vec[sort_idx]
nTOp_bkwrd_sorted = nTOp_bkwrd_vec[sort_idx]

_nTOp_frwrd = interpax.interp1d(
T_t, T_sorted, nTOp_frwrd_sorted,
method='linear',
extrap=True
) / (const.tau_n * tau_n_fac)

_nTOp_bkwrd = jnp.interp(
T_t, jnp.flip(T_interval), jnp.flip(nTOp_bkwrd_vec),
left=nTOp_bkwrd_vec[-1],right=nTOp_bkwrd_vec[0]
_nTOp_bkwrd = interpax.interp1d(
T_t, T_sorted, nTOp_bkwrd_sorted,
method='linear',
extrap=True
) / (const.tau_n * tau_n_fac)

# These functions take temperature in K.
Expand Down
21 changes: 15 additions & 6 deletions linx/weak_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import jax.numpy as jnp
# As of jax v0.4.24, jax.numpy.trapz has been deprecated, and replaced
# by scipy.integrate.trapezoid.
try:
import jax.numpy.trapz as trapz
except ImportError:
try:
import jax.numpy.trapz as trapz
except ImportError:
from jax.scipy.integrate import trapezoid as trapz

import equinox as eqx
import interpax

import linx.const as const
from linx.special_funcs import gamma
Expand Down Expand Up @@ -201,11 +202,19 @@ def nTOp_rates(self, Tg, T_vec_ref):

Tg_vec_ref, Tnu_vec_ref = T_vec_ref
Tnu_of_Tg_ref = Tnu_vec_ref / Tg_vec_ref


# Sort by Tg to handle non-monotonic temperature evolution
# (e.g., in reheating scenarios)
sort_idx = jnp.argsort(Tg_vec_ref)
Tg_sorted = Tg_vec_ref[sort_idx]
Tnu_of_Tg_sorted = Tnu_of_Tg_ref[sort_idx]

x = me / Tg
xnu = me / (
Tg * jnp.interp(
Tg, jnp.flip(Tg_vec_ref), jnp.flip(Tnu_of_Tg_ref), left=Tnu_of_Tg_ref[-1], right=Tnu_of_Tg_ref[0]
Tg * interpax.interp1d(
Tg, Tg_sorted, Tnu_of_Tg_sorted,
method='linear',
extrap=True
)
)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ diffrax>=0.6.2
dynesty
emcee
equinox
interpax
jax>=0.4.38
jaxlib>=0.4.38
jaxopt
Expand Down