Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 44 additions & 32 deletions linx/abundances.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,48 @@
from linx.thermo import rho_EM_std_v, p_EM_std_v, nB
from linx.special_funcs import zeta_3

class AbundanceModel(eqx.Module):
class AbundanceModel(eqx.Module):
"""
Abundance model and BBN abundance prediction.
Abundance model and BBN abundance prediction.

Attributes
----------
nuclear_net : NuclearRates
Nuclear network to be used for BBN prediction.
Nuclear network to be used for BBN prediction.
weak_rates : WeakRates
Weak rates for neutron-proton interconversion.
Weak rates for neutron-proton interconversion.
species_dict : dict
Dictionary of species considered in LINX.
Dictionary of species considered in LINX.
species_Z : list
Number of protons in each species.
Number of protons in each species.
species_N : list
Number of neutrons in each species.
Number of neutrons in each species.
species_A : list
Atomic mass number of each species.
Atomic mass number of each species.
species_excess_mass : list
Excess mass (mass - A*amu) of each species.
Excess mass (mass - A*amu) of each species.
species_spin : list
Spin of each species.
Spin of each species.
species_binding_energy : list
Binding energy of each species.
Binding energy of each species.
species_mass : list
Mass of each species.
Mass of each species.
throw : bool
Whether to raise exceptions on solver failure.
"""
nuclear_net : nucl.NuclearRates
weak_rates : wr.WeakRates
nuclear_net : nucl.NuclearRates
weak_rates : wr.WeakRates
species_dict : dict
species_Z : list
species_N : list
species_A : list
species_A : list
species_excess_mass : dict
species_spin : list
species_binding_energy : list
species_mass : list
throw : bool

def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):
def __init__(self, nuclear_net, weak_rates=wr.WeakRates(), throw=True):
"""
Initialize the AbundanceModel with nuclear and weak rate networks.

Expand All @@ -65,6 +68,9 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):
weak_rates : WeakRates, optional
Weak interaction rates for neutron-proton interconversion.
Defaults to standard WeakRates instance.
throw : bool, optional
If True, raise exceptions on solver failure. Default is True.
Set to False for parameter scans where some combinations may fail.

Notes
-----
Expand All @@ -76,6 +82,7 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()):

self.nuclear_net = nuclear_net
self.weak_rates = weak_rates
self.throw = throw

self.species_dict = {
0:'n', 1:'p', 2:'d', 3:'t', 4:'He3', 5:'a', 6:'Li7', 7:'Be7',
Expand Down Expand Up @@ -291,15 +298,18 @@ def __call__(
saveat = SaveAt(t1=True)

sol = diffeqsolve(
ODETerm(self.Y_prime), solver,
t0=t_start, t1=t_end, dt0=None, y0=Y_i,
args = (
a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd,
ODETerm(self.Y_prime), solver,
t0=t_start, t1=t_end, dt0=None, y0=Y_i,
args=(
a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd,
nTOp_bkwrd, eta_fac, tau_n_fac, nuclear_rates_q
), saveat=saveat, stepsize_controller = PIDController(
),
saveat=saveat,
stepsize_controller=PIDController(
rtol=rtol, atol=atol,
),
max_steps=max_steps
),
max_steps=max_steps,
throw=self.throw
)

if save_history:
Expand Down Expand Up @@ -369,12 +379,13 @@ def dt_prime(rho_tot, t, args):
rho_tot_fin = rho_tot_vec[-1]

sol_t = diffeqsolve(
ODETerm(dt_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
y0=1. / (2 * thermo.Hubble(rho_tot_init)),
ODETerm(dt_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
y0=1. / (2 * thermo.Hubble(rho_tot_init)),
dt0=None, max_steps=4096,
saveat=SaveAt(ts=rho_tot_vec),
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10)
saveat=SaveAt(ts=rho_tot_vec),
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10),
throw=self.throw
)

return sol_t.ys
Expand Down Expand Up @@ -441,13 +452,14 @@ def dlna_prime(rho_tot, t, args):
rho_tot_init = rho_tot_vec[0]
rho_tot_fin = rho_tot_vec[-1]

# a_0 = 1 arbitrarily, will rescale later.
# a_0 = 1 arbitrarily, will rescale later.
sol_lna = diffeqsolve(
ODETerm(dlna_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
ODETerm(dlna_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
y0=0., dt0=None, max_steps=4096,
saveat=SaveAt(ts=rho_tot_vec),
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10)
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10),
throw=self.throw
)

a_fin = const.T0CMB / T_g_vec[-1]
Expand Down
21 changes: 15 additions & 6 deletions linx/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
thermo.rho_massless_FD, in_axes=(0, None, None)
)

class BackgroundModel(eqx.Module):
class BackgroundModel(eqx.Module):
"""Background model.

Attributes
Expand All @@ -31,15 +31,19 @@ class BackgroundModel(eqx.Module):
Whether to use leading order QED correction. Default is `True`.
NLO : bool, optional
Whether to use next-to-leading order QED correction. Default is True.
throw : bool, optional
Whether to raise exceptions on solver failure. Default is `True`.
Set to `False` for parameter scans where some combinations may fail.
"""

decoupled : bool
use_FD : bool
collision_me : bool
LO : bool
NLO : bool
throw : bool

def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO = True):
def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO=True, throw=True):
"""
Initialize the BackgroundModel with thermodynamic options.

Expand All @@ -56,13 +60,17 @@ def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO
If True, include leading order QED corrections. Default is True.
NLO : bool, optional
If True, include next-to-leading order QED corrections. Default is True.
throw : bool, optional
If True, raise exceptions on solver failure. Default is True.
Set to False for parameter scans where some combinations may fail.
"""

self.decoupled = decoupled
self.use_FD = use_FD
self.collision_me = collision_me
self.LO = LO
self.NLO = NLO
self.throw = throw

@eqx.filter_jit
def __call__(
Expand Down Expand Up @@ -131,12 +139,13 @@ def T_EM_check(t, y, args, **kwargs):

sol = diffeqsolve(
ODETerm(self.dY), solver, args=(lna_init, rho_extra_init),
t0=0., t1=jnp.inf, dt0=None, y0=Y0,
t0=0., t1=jnp.inf, dt0=None, y0=Y0,
saveat=SaveAt(steps=True), event=Event(T_EM_check),
stepsize_controller = PIDController(
stepsize_controller=PIDController(
rtol=rtol, atol=atol
),
max_steps=max_steps
),
max_steps=max_steps,
throw=self.throw
)

a_vec = jnp.exp(sol.ys[0])
Expand Down