diff --git a/linx/abundances.py b/linx/abundances.py index bc0fdb7..1fb62b7 100644 --- a/linx/abundances.py +++ b/linx/abundances.py @@ -12,7 +12,7 @@ from linx.const import ma, me, mn, mp import linx.weak_rates as wr import linx.thermo as thermo -from linx.thermo import rho_EM_std_v, p_EM_std_v, nB +from linx.thermo import rho_EM_std_v, p_EM_std_v, nB, ThermoResult from linx.special_funcs import zeta_3 class AbundanceModel(eqx.Module): @@ -114,101 +114,146 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()): @eqx.filter_jit def __call__( - self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, - a_vec=None, t_vec=None, - eta_fac=jnp.asarray(1.), tau_n_fac = jnp.asarray(1.), - nuclear_rates_q=None, - Y_i=None, T_start=None, T_end=None, sampling_nTOp=150, + self, rho_g_vec_or_thermo, rho_nu_vec=None, rho_NP_vec=None, P_NP_vec=None, + a_vec=None, t_vec=None, + eta_fac=jnp.asarray(1.), tau_n_fac=jnp.asarray(1.), + nuclear_rates_q=None, + Y_i=None, T_start=None, T_end=None, sampling_nTOp=150, rtol=1e-6, atol=1e-9, solver=Kvaerno3(), max_steps=4096, save_history=False ): """ - Calculate BBN abundance. + Calculate BBN abundance. Parameters ---------- - rho_g_vec : array - Energy density of photons in MeV^4. - rho_nu_vec : array - Energy density of a single neutrino species in MeV^4 - (all neutrinos assumed to have the same temperature). - rho_NP_vec : array - Energy density of all new physics particles in MeV^4. - P_NP_vec : array - Pressure of all new physics particles in MeV^4. + rho_g_vec_or_thermo : array or ThermoResult + Either energy density of photons in MeV^4, or a ThermoResult object + from BackgroundModel. If ThermoResult, other energy density/pressure + arguments and T_start/T_end are extracted automatically. + rho_nu_vec : array, optional + Energy density of a single neutrino species in MeV^4 + (all neutrinos assumed to have the same temperature). + Required if rho_g_vec_or_thermo is an array. + rho_NP_vec : array, optional + Energy density of all new physics particles in MeV^4. + Required if rho_g_vec_or_thermo is an array. + P_NP_vec : array, optional + Pressure of all new physics particles in MeV^4. + Required if rho_g_vec_or_thermo is an array. a_vec : array, optional - Scale factor. If `None`, will be computed in function. + Scale factor. If `None`, will be computed in function. t_vec : array, optional - Time in seconds. If `None`, will be computed in function. + Time in seconds. If `None`, will be computed in function. eta_fac : float, optional - Rescaling factor for baryon-to-photon ratio, 1 for fiducial value - in `const.eta0` (or `const.Omegabh2`). + Rescaling factor for baryon-to-photon ratio, 1 for fiducial value + in `const.eta0` (or `const.Omegabh2`). tau_n_fac : float, optional - Rescaling factor for neutron decay lifetime, 1 for fiducial value - in `const.eta0` (or `const.Omegabh2`). + Rescaling factor for neutron decay lifetime, 1 for fiducial value + in `const.eta0` (or `const.Omegabh2`). nuclear_rates_q : array, optional - q ~ N(0,1) specifies the nuclear rate in its log-normal - distribution. If not specified, will be taken to be `q = 0`. + q ~ N(0,1) specifies the nuclear rate in its log-normal + distribution. If not specified, will be taken to be `q = 0`. Y_i : tuple of float, optional - Initial abundances :math:`n_i/n_b` for species. Length must be equal to - `self.nuclear_net.max_i_species`. Must specify `T_start` and `T_end` if not `None`. - T_start : float - Temperature in MeV to start integration. Must specify `Y_i` and `T_end` if not `None`, otherwise `const.T_start` used. - T_end : float - Temperature in MeV to end integration. + Initial abundances :math:`n_i/n_b` for species. Length must be equal to + `self.nuclear_net.max_i_species`. Must specify `T_start` and `T_end` if not `None`. + T_start : float, optional + Temperature in MeV to start integration. If not specified and + ThermoResult is provided, uses T_start from ThermoResult. + Otherwise uses `const.T_start`. + T_end : float, optional + Temperature in MeV to end integration. If not specified and + ThermoResult is provided, uses T_end from ThermoResult. + Otherwise uses `const.T_end`. sampling_nTOp : int - Number of points to subdivide (`T_end`, `T_start`) for neutron-proton interconversion rate interpolation table. + Number of points to subdivide (`T_end`, `T_start`) for neutron-proton interconversion rate interpolation table. rtol : float, optional - Relative tolerance of the abundance solver. Default is `1e-4`. + Relative tolerance of the abundance solver. Default is `1e-6`. atol : float, optional - Absolute tolerance of the abundance solver. Default is `1e-9`. + Absolute tolerance of the abundance solver. Default is `1e-9`. max_steps : int, optional - Maximum number of steps taken by the solver. Default is `4096`. - Increasing this slows down the code, while decreasing this could - mean that the solver cannot complete the solution. - solver : Diffrax ODE solver - The Diffrax ODE solver to use. A stiff solver is recommended. - Default is the 3rd order Kvaerno solver. + Maximum number of steps taken by the solver. Default is `4096`. + Increasing this slows down the code, while decreasing this could + mean that the solver cannot complete the solution. + solver : Diffrax ODE solver + The Diffrax ODE solver to use. A stiff solver is recommended. + Default is the 3rd order Kvaerno solver. save_history : bool - If `True`, full solution is returned with temperature and time + If `True`, full solution is returned with temperature and time abscissa. Returns ------- tuple of array or array - If `save_history` is set to `True`, a tuple containing an array of EM temperatures, an array of times, and a Diffrax `Solution` instance, which can be called as a function of time. Otherwise, returns yields of all species considered in `self.nuclear_net`. + If `save_history` is set to `True`, a tuple containing an array of EM temperatures, an array of times, and a Diffrax `Solution` instance, which can be called as a function of time. Otherwise, returns yields of all species considered in `self.nuclear_net`. + Examples + -------- + >>> # New streamlined usage with ThermoResult + >>> thermo_result = BackgroundModel()(0.) + >>> abundances = AbundanceModel(nuclear_net)(thermo_result) + + >>> # Backward compatible usage with arrays + >>> t, a, rho_g, rho_nu, rho_extra, P_extra, Neff = thermo_model(0.)[:7] + >>> abundances = abundance_model(rho_g, rho_nu, rho_extra, P_extra, t_vec=t, a_vec=a) """ print('Compiling abundance model...') - if Y_i is not None: - if T_start is None: + # Handle ThermoResult input for streamlined API + if isinstance(rho_g_vec_or_thermo, ThermoResult): + thermo_input = rho_g_vec_or_thermo + rho_g_vec = thermo_input.rho_g_vec + if rho_nu_vec is None: + rho_nu_vec = thermo_input.rho_nu_vec + if rho_NP_vec is None: + rho_NP_vec = thermo_input.rho_extra_vec + if P_NP_vec is None: + P_NP_vec = thermo_input.P_extra_vec + if a_vec is None: + a_vec = thermo_input.a_vec + if t_vec is None: + t_vec = thermo_input.t_vec + # Auto-detect temperature range from ThermoResult if not specified + if T_start is None: + T_start = thermo_input.T_start + if T_end is None: + T_end = thermo_input.T_end + else: + rho_g_vec = rho_g_vec_or_thermo + # Validate required arguments for array input + if rho_nu_vec is None or rho_NP_vec is None or P_NP_vec is None: + raise TypeError( + "When passing arrays directly, rho_nu_vec, rho_NP_vec, and P_NP_vec are required" + ) + + if Y_i is not None: + if T_start is None: raise TypeError('Specifying Y_i requires specifying a T_start') - if T_start is not None: - if Y_i is None: + if T_start is not None: + if Y_i is None: raise TypeError('Specifying T_start requires specifying Y_i') - if nuclear_rates_q is None: + if nuclear_rates_q is None: nuclear_rates_q = jnp.array( [0. for _ in self.nuclear_net.reactions] ) - if t_vec is None: + if t_vec is None: t_vec = self.get_t(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec) - if a_vec is None: + if a_vec is None: a_vec = self.get_a(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec) - if T_start is None: + if T_start is None: - T_start = const.T_start + T_start = const.T_start - if T_end is None: + if T_end is None: - T_end = const.T_end + T_end = const.T_end # These are in MeV T_g_vec = thermo.T_g(rho_g_vec) diff --git a/linx/background.py b/linx/background.py index e3081ab..3b6210a 100644 --- a/linx/background.py +++ b/linx/background.py @@ -7,6 +7,7 @@ from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, Event import linx.thermo as thermo +from linx.thermo import ThermoResult import linx.const as const rho_massless_BE_v = vmap( @@ -185,9 +186,16 @@ def T_EM_check(t, y, args, **kwargs): Neff_vec = thermo.N_eff(rho_tot_vec, rho_g_vec) - return ( - t_vec, a_vec, rho_g_vec, rho_nu_vec, - rho_extra_vec, P_extra_vec, Neff_vec + return ThermoResult( + t_vec=t_vec, + a_vec=a_vec, + rho_g_vec=rho_g_vec, + rho_nu_vec=rho_nu_vec, + rho_extra_vec=rho_extra_vec, + P_extra_vec=P_extra_vec, + Neff_vec=Neff_vec, + T_start=T_start, + T_end=T_end ) @eqx.filter_jit diff --git a/linx/thermo.py b/linx/thermo.py index 85cff9d..45d6fcd 100644 --- a/linx/thermo.py +++ b/linx/thermo.py @@ -1,14 +1,63 @@ -import os +import os +from typing import NamedTuple import numpy as np -import jax.numpy as jnp +import jax.numpy as jnp import jax.lax as lax from jax import grad, vmap, device_put, devices -import linx.const as const +import linx.const as const from linx.special_funcs import Li, K1, K2 + +class ThermoResult(NamedTuple): + """Result from BackgroundModel containing thermodynamic evolution. + + This object encapsulates all outputs from BackgroundModel and can be + passed directly to AbundanceModel. It supports tuple unpacking for + backward compatibility. + + Attributes + ---------- + t_vec : jnp.ndarray + Times in seconds at which thermodynamics are saved. + a_vec : jnp.ndarray + Scale factor at each point in time. + rho_g_vec : jnp.ndarray + Energy density of photons in MeV^4 at each point in time. + rho_nu_vec : jnp.ndarray + Energy density of one species of neutrinos in MeV^4. + rho_extra_vec : jnp.ndarray + Energy density in MeV^4 of extra species at each point in time. + P_extra_vec : jnp.ndarray + Pressure in MeV^4 of extra species at each point in time. + Neff_vec : jnp.ndarray + Effective number of neutrino species at each point in time. + T_start : float + Starting temperature in MeV used for integration. + T_end : float + Ending temperature in MeV used for integration. + + Examples + -------- + >>> # New usage pattern + >>> thermo_result = BackgroundModel()(0.) + >>> abundances = AbundanceModel(nuclear_net)(thermo_result) + + >>> # Backward compatible tuple unpacking + >>> t, a, rho_g, rho_nu, rho_extra, P_extra, Neff = thermo_result[:7] + """ + t_vec: jnp.ndarray + a_vec: jnp.ndarray + rho_g_vec: jnp.ndarray + rho_nu_vec: jnp.ndarray + rho_extra_vec: jnp.ndarray + P_extra_vec: jnp.ndarray + Neff_vec: jnp.ndarray + T_start: float + T_end: float + ########################################### # Cosmology # ############################################ diff --git a/pytest/test_abundances.py b/pytest/test_abundances.py index c006dcc..b610b72 100644 --- a/pytest/test_abundances.py +++ b/pytest/test_abundances.py @@ -27,8 +27,9 @@ @pytest.fixture def sample_inputs(): - t_vec_ref, a_vec_ref, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, Neff_vec = thermo_model_DNeff(0.) - + # Use [:7] slice for backward compatibility with ThermoResult (which has 9 fields) + t_vec_ref, a_vec_ref, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, Neff_vec = thermo_model_DNeff(0.)[:7] + return t_vec_ref, a_vec_ref, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, Neff_vec @pytest.fixture