Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 141 additions & 84 deletions linx/abundances.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,51 @@
from linx.const import ma, me, mn, mp
import linx.weak_rates as wr
import linx.thermo as thermo
from linx.thermo import rho_EM_std_v, p_EM_std_v, nB
from linx.thermo import rho_EM_std_v, p_EM_std_v, nB, ThermoResult
from linx.special_funcs import zeta_3

class AbundanceModel(eqx.Module):
class AbundanceModel(eqx.Module):
"""
Abundance model and BBN abundance prediction.
Abundance model and BBN abundance prediction.

Attributes
----------
nuclear_net : NuclearRates
Nuclear network to be used for BBN prediction.
Nuclear network to be used for BBN prediction.
weak_rates : WeakRates
Weak rates for neutron-proton interconversion.
Weak rates for neutron-proton interconversion.
species_dict : dict
Dictionary of species considered in LINX.
Dictionary of species considered in LINX.
species_Z : list
Number of protons in each species.
Number of protons in each species.
species_N : list
Number of neutrons in each species.
Number of neutrons in each species.
species_A : list
Atomic mass number of each species.
Atomic mass number of each species.
species_excess_mass : list
Excess mass (mass - A*amu) of each species.
Excess mass (mass - A*amu) of each species.
species_spin : list
Spin of each species.
Spin of each species.
species_binding_energy : list
Binding energy of each species.
Binding energy of each species.
species_mass : list
Mass of each species.
Mass of each species.
throw : bool
Whether to raise exceptions on solver failure.
"""
nuclear_net : nucl.NuclearRates
weak_rates : wr.WeakRates
nuclear_net : nucl.NuclearRates
weak_rates : wr.WeakRates
species_dict : dict
species_Z : list
species_N : list
species_A : list
species_A : list
species_excess_mass : dict
species_spin : list
species_binding_energy : list
species_mass : list
throw : bool

def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):
def __init__(self, nuclear_net, weak_rates=wr.WeakRates(), throw=True):
"""
Initialize the AbundanceModel with nuclear and weak rate networks.

Expand All @@ -65,6 +68,9 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):
weak_rates : WeakRates, optional
Weak interaction rates for neutron-proton interconversion.
Defaults to standard WeakRates instance.
throw : bool, optional
If True, raise exceptions on solver failure. Default is True.
Set to False for parameter scans where some combinations may fail.

Notes
-----
Expand All @@ -76,6 +82,7 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):

self.nuclear_net = nuclear_net
self.weak_rates = weak_rates
self.throw = throw

self.species_dict = {
0:'n', 1:'p', 2:'d', 3:'t', 4:'He3', 5:'a', 6:'Li7', 7:'Be7',
Expand Down Expand Up @@ -114,101 +121,146 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):

@eqx.filter_jit
def __call__(
self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec,
a_vec=None, t_vec=None,
eta_fac=jnp.asarray(1.), tau_n_fac = jnp.asarray(1.),
nuclear_rates_q=None,
Y_i=None, T_start=None, T_end=None, sampling_nTOp=150,
self, rho_g_vec_or_thermo, rho_nu_vec=None, rho_NP_vec=None, P_NP_vec=None,
a_vec=None, t_vec=None,
eta_fac=jnp.asarray(1.), tau_n_fac=jnp.asarray(1.),
nuclear_rates_q=None,
Y_i=None, T_start=None, T_end=None, sampling_nTOp=150,
rtol=1e-6, atol=1e-9, solver=Kvaerno3(),
max_steps=4096,
save_history=False
):
"""
Calculate BBN abundance.
Calculate BBN abundance.

Parameters
----------
rho_g_vec : array
Energy density of photons in MeV^4.
rho_nu_vec : array
Energy density of a single neutrino species in MeV^4
(all neutrinos assumed to have the same temperature).
rho_NP_vec : array
Energy density of all new physics particles in MeV^4.
P_NP_vec : array
Pressure of all new physics particles in MeV^4.
rho_g_vec_or_thermo : array or ThermoResult
Either energy density of photons in MeV^4, or a ThermoResult object
from BackgroundModel. If ThermoResult, other energy density/pressure
arguments and T_start/T_end are extracted automatically.
rho_nu_vec : array, optional
Energy density of a single neutrino species in MeV^4
(all neutrinos assumed to have the same temperature).
Required if rho_g_vec_or_thermo is an array.
rho_NP_vec : array, optional
Energy density of all new physics particles in MeV^4.
Required if rho_g_vec_or_thermo is an array.
P_NP_vec : array, optional
Pressure of all new physics particles in MeV^4.
Required if rho_g_vec_or_thermo is an array.
a_vec : array, optional
Scale factor. If `None`, will be computed in function.
Scale factor. If `None`, will be computed in function.
t_vec : array, optional
Time in seconds. If `None`, will be computed in function.
Time in seconds. If `None`, will be computed in function.
eta_fac : float, optional
Rescaling factor for baryon-to-photon ratio, 1 for fiducial value
in `const.eta0` (or `const.Omegabh2`).
Rescaling factor for baryon-to-photon ratio, 1 for fiducial value
in `const.eta0` (or `const.Omegabh2`).
tau_n_fac : float, optional
Rescaling factor for neutron decay lifetime, 1 for fiducial value
in `const.eta0` (or `const.Omegabh2`).
Rescaling factor for neutron decay lifetime, 1 for fiducial value
in `const.eta0` (or `const.Omegabh2`).
nuclear_rates_q : array, optional
q ~ N(0,1) specifies the nuclear rate in its log-normal
distribution. If not specified, will be taken to be `q = 0`.
q ~ N(0,1) specifies the nuclear rate in its log-normal
distribution. If not specified, will be taken to be `q = 0`.
Y_i : tuple of float, optional
Initial abundances :math:`n_i/n_b` for species. Length must be equal to
`self.nuclear_net.max_i_species`. Must specify `T_start` and `T_end` if not `None`.
T_start : float
Temperature in MeV to start integration. Must specify `Y_i` and `T_end` if not `None`, otherwise `const.T_start` used.
T_end : float
Temperature in MeV to end integration.
Initial abundances :math:`n_i/n_b` for species. Length must be equal to
`self.nuclear_net.max_i_species`. Must specify `T_start` and `T_end` if not `None`.
T_start : float, optional
Temperature in MeV to start integration. If not specified and
ThermoResult is provided, uses T_start from ThermoResult.
Otherwise uses `const.T_start`.
T_end : float, optional
Temperature in MeV to end integration. If not specified and
ThermoResult is provided, uses T_end from ThermoResult.
Otherwise uses `const.T_end`.
sampling_nTOp : int
Number of points to subdivide (`T_end`, `T_start`) for neutron-proton interconversion rate interpolation table.
Number of points to subdivide (`T_end`, `T_start`) for neutron-proton interconversion rate interpolation table.
rtol : float, optional
Relative tolerance of the abundance solver. Default is `1e-4`.
Relative tolerance of the abundance solver. Default is `1e-6`.
atol : float, optional
Absolute tolerance of the abundance solver. Default is `1e-9`.
Absolute tolerance of the abundance solver. Default is `1e-9`.
max_steps : int, optional
Maximum number of steps taken by the solver. Default is `4096`.
Increasing this slows down the code, while decreasing this could
mean that the solver cannot complete the solution.
solver : Diffrax ODE solver
The Diffrax ODE solver to use. A stiff solver is recommended.
Default is the 3rd order Kvaerno solver.
Maximum number of steps taken by the solver. Default is `4096`.
Increasing this slows down the code, while decreasing this could
mean that the solver cannot complete the solution.
solver : Diffrax ODE solver
The Diffrax ODE solver to use. A stiff solver is recommended.
Default is the 3rd order Kvaerno solver.
save_history : bool
If `True`, full solution is returned with temperature and time
If `True`, full solution is returned with temperature and time
abscissa.

Returns
-------
tuple of array or array
If `save_history` is set to `True`, a tuple containing an array of EM temperatures, an array of times, and a Diffrax `Solution` instance, which can be called as a function of time. Otherwise, returns yields of all species considered in `self.nuclear_net`.
If `save_history` is set to `True`, a tuple containing an array of EM temperatures, an array of times, and a Diffrax `Solution` instance, which can be called as a function of time. Otherwise, returns yields of all species considered in `self.nuclear_net`.

Examples
--------
>>> # New streamlined usage with ThermoResult
>>> thermo_result = BackgroundModel()(0.)
>>> abundances = AbundanceModel(nuclear_net)(thermo_result)

>>> # Backward compatible usage with arrays
>>> t, a, rho_g, rho_nu, rho_extra, P_extra, Neff = thermo_model(0.)[:7]
>>> abundances = abundance_model(rho_g, rho_nu, rho_extra, P_extra, t_vec=t, a_vec=a)
"""

print('Compiling abundance model...')

if Y_i is not None:
if T_start is None:
# Handle ThermoResult input for streamlined API
if isinstance(rho_g_vec_or_thermo, ThermoResult):
thermo_input = rho_g_vec_or_thermo
rho_g_vec = thermo_input.rho_g_vec
if rho_nu_vec is None:
rho_nu_vec = thermo_input.rho_nu_vec
if rho_NP_vec is None:
rho_NP_vec = thermo_input.rho_extra_vec
if P_NP_vec is None:
P_NP_vec = thermo_input.P_extra_vec
if a_vec is None:
a_vec = thermo_input.a_vec
if t_vec is None:
t_vec = thermo_input.t_vec
# Auto-detect temperature range from ThermoResult if not specified
if T_start is None:
T_start = thermo_input.T_start
if T_end is None:
T_end = thermo_input.T_end
else:
rho_g_vec = rho_g_vec_or_thermo
# Validate required arguments for array input
if rho_nu_vec is None or rho_NP_vec is None or P_NP_vec is None:
raise TypeError(
"When passing arrays directly, rho_nu_vec, rho_NP_vec, and P_NP_vec are required"
)

if Y_i is not None:
if T_start is None:
raise TypeError('Specifying Y_i requires specifying a T_start')
if T_start is not None:
if Y_i is None:
if T_start is not None:
if Y_i is None:
raise TypeError('Specifying T_start requires specifying Y_i')

if nuclear_rates_q is None:
if nuclear_rates_q is None:

nuclear_rates_q = jnp.array(
[0. for _ in self.nuclear_net.reactions]
)

if t_vec is None:
if t_vec is None:
t_vec = self.get_t(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec)

if a_vec is None:
if a_vec is None:
a_vec = self.get_a(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec)

if T_start is None:
if T_start is None:

T_start = const.T_start
T_start = const.T_start

if T_end is None:
if T_end is None:

T_end = const.T_end
T_end = const.T_end

# These are in MeV
T_g_vec = thermo.T_g(rho_g_vec)
Expand Down Expand Up @@ -291,15 +343,18 @@ def __call__(
saveat = SaveAt(t1=True)

sol = diffeqsolve(
ODETerm(self.Y_prime), solver,
t0=t_start, t1=t_end, dt0=None, y0=Y_i,
args = (
a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd,
ODETerm(self.Y_prime), solver,
t0=t_start, t1=t_end, dt0=None, y0=Y_i,
args=(
a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd,
nTOp_bkwrd, eta_fac, tau_n_fac, nuclear_rates_q
), saveat=saveat, stepsize_controller = PIDController(
),
saveat=saveat,
stepsize_controller=PIDController(
rtol=rtol, atol=atol,
),
max_steps=max_steps
),
max_steps=max_steps,
throw=self.throw
)

if save_history:
Expand Down Expand Up @@ -369,12 +424,13 @@ def dt_prime(rho_tot, t, args):
rho_tot_fin = rho_tot_vec[-1]

sol_t = diffeqsolve(
ODETerm(dt_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
y0=1. / (2 * thermo.Hubble(rho_tot_init)),
ODETerm(dt_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
y0=1. / (2 * thermo.Hubble(rho_tot_init)),
dt0=None, max_steps=4096,
saveat=SaveAt(ts=rho_tot_vec),
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10)
saveat=SaveAt(ts=rho_tot_vec),
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10),
throw=self.throw
)

return sol_t.ys
Expand Down Expand Up @@ -441,13 +497,14 @@ def dlna_prime(rho_tot, t, args):
rho_tot_init = rho_tot_vec[0]
rho_tot_fin = rho_tot_vec[-1]

# a_0 = 1 arbitrarily, will rescale later.
# a_0 = 1 arbitrarily, will rescale later.
sol_lna = diffeqsolve(
ODETerm(dlna_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
ODETerm(dlna_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
y0=0., dt0=None, max_steps=4096,
saveat=SaveAt(ts=rho_tot_vec),
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10)
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10),
throw=self.throw
)

a_fin = const.T0CMB / T_g_vec[-1]
Expand Down
Loading