From 4ce5633d5271145db7970298f9ab168911d46326 Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Fri, 23 Jan 2026 15:57:16 -0500 Subject: [PATCH] Add ThermoResult object and diffrax throw option Implements two features: 1. ThermoResult object (closes #14): - BackgroundModel now returns a ThermoResult NamedTuple - Contains all thermo outputs plus T_start/T_end used for integration - AbundanceModel accepts ThermoResult directly with auto T_start/T_end detection - Maintains backward compatibility via tuple unpacking (use [:7] slice) 2. diffrax throw option (closes #17): - Add `throw` parameter to both BackgroundModel and AbundanceModel - Default is True (raises exceptions on solver failure) - Set to False for parameter scans where some combinations may fail - Passed through to all internal diffeqsolve calls Usage examples: # New streamlined API thermo_result = BackgroundModel()(0.) abundances = AbundanceModel(nuclear_net)(thermo_result) # Parameter scans with throw=False thermo_model = BackgroundModel(throw=False) abundance_model = AbundanceModel(nuclear_net, throw=False) # Backward compatible t, a, rho_g, rho_nu, rho_extra, P_extra, Neff = thermo_model(0.)[:7] Co-Authored-By: Claude Opus 4.5 --- linx/abundances.py | 225 ++++++++++++++++++++++++-------------- linx/background.py | 35 ++++-- linx/thermo.py | 55 +++++++++- pytest/test_abundances.py | 5 +- 4 files changed, 222 insertions(+), 98 deletions(-) diff --git a/linx/abundances.py b/linx/abundances.py index bc0fdb7..a6fb3ad 100644 --- a/linx/abundances.py +++ b/linx/abundances.py @@ -12,48 +12,51 @@ from linx.const import ma, me, mn, mp import linx.weak_rates as wr import linx.thermo as thermo -from linx.thermo import rho_EM_std_v, p_EM_std_v, nB +from linx.thermo import rho_EM_std_v, p_EM_std_v, nB, ThermoResult from linx.special_funcs import zeta_3 -class AbundanceModel(eqx.Module): +class AbundanceModel(eqx.Module): """ - Abundance model and BBN abundance prediction. + Abundance model and BBN abundance prediction. Attributes ---------- nuclear_net : NuclearRates - Nuclear network to be used for BBN prediction. + Nuclear network to be used for BBN prediction. weak_rates : WeakRates - Weak rates for neutron-proton interconversion. + Weak rates for neutron-proton interconversion. species_dict : dict - Dictionary of species considered in LINX. + Dictionary of species considered in LINX. species_Z : list - Number of protons in each species. + Number of protons in each species. species_N : list - Number of neutrons in each species. + Number of neutrons in each species. species_A : list - Atomic mass number of each species. + Atomic mass number of each species. species_excess_mass : list - Excess mass (mass - A*amu) of each species. + Excess mass (mass - A*amu) of each species. species_spin : list - Spin of each species. + Spin of each species. species_binding_energy : list - Binding energy of each species. + Binding energy of each species. species_mass : list - Mass of each species. + Mass of each species. + throw : bool + Whether to raise exceptions on solver failure. """ - nuclear_net : nucl.NuclearRates - weak_rates : wr.WeakRates + nuclear_net : nucl.NuclearRates + weak_rates : wr.WeakRates species_dict : dict species_Z : list species_N : list - species_A : list + species_A : list species_excess_mass : dict species_spin : list species_binding_energy : list species_mass : list + throw : bool - def __init__(self, nuclear_net, weak_rates=wr.WeakRates()): + def __init__(self, nuclear_net, weak_rates=wr.WeakRates(), throw=True): """ Initialize the AbundanceModel with nuclear and weak rate networks. @@ -65,6 +68,9 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()): weak_rates : WeakRates, optional Weak interaction rates for neutron-proton interconversion. Defaults to standard WeakRates instance. + throw : bool, optional + If True, raise exceptions on solver failure. Default is True. + Set to False for parameter scans where some combinations may fail. Notes ----- @@ -76,6 +82,7 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()): self.nuclear_net = nuclear_net self.weak_rates = weak_rates + self.throw = throw self.species_dict = { 0:'n', 1:'p', 2:'d', 3:'t', 4:'He3', 5:'a', 6:'Li7', 7:'Be7', @@ -114,101 +121,146 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()): @eqx.filter_jit def __call__( - self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, - a_vec=None, t_vec=None, - eta_fac=jnp.asarray(1.), tau_n_fac = jnp.asarray(1.), - nuclear_rates_q=None, - Y_i=None, T_start=None, T_end=None, sampling_nTOp=150, + self, rho_g_vec_or_thermo, rho_nu_vec=None, rho_NP_vec=None, P_NP_vec=None, + a_vec=None, t_vec=None, + eta_fac=jnp.asarray(1.), tau_n_fac=jnp.asarray(1.), + nuclear_rates_q=None, + Y_i=None, T_start=None, T_end=None, sampling_nTOp=150, rtol=1e-6, atol=1e-9, solver=Kvaerno3(), max_steps=4096, save_history=False ): """ - Calculate BBN abundance. + Calculate BBN abundance. Parameters ---------- - rho_g_vec : array - Energy density of photons in MeV^4. - rho_nu_vec : array - Energy density of a single neutrino species in MeV^4 - (all neutrinos assumed to have the same temperature). - rho_NP_vec : array - Energy density of all new physics particles in MeV^4. - P_NP_vec : array - Pressure of all new physics particles in MeV^4. + rho_g_vec_or_thermo : array or ThermoResult + Either energy density of photons in MeV^4, or a ThermoResult object + from BackgroundModel. If ThermoResult, other energy density/pressure + arguments and T_start/T_end are extracted automatically. + rho_nu_vec : array, optional + Energy density of a single neutrino species in MeV^4 + (all neutrinos assumed to have the same temperature). + Required if rho_g_vec_or_thermo is an array. + rho_NP_vec : array, optional + Energy density of all new physics particles in MeV^4. + Required if rho_g_vec_or_thermo is an array. + P_NP_vec : array, optional + Pressure of all new physics particles in MeV^4. + Required if rho_g_vec_or_thermo is an array. a_vec : array, optional - Scale factor. If `None`, will be computed in function. + Scale factor. If `None`, will be computed in function. t_vec : array, optional - Time in seconds. If `None`, will be computed in function. + Time in seconds. If `None`, will be computed in function. eta_fac : float, optional - Rescaling factor for baryon-to-photon ratio, 1 for fiducial value - in `const.eta0` (or `const.Omegabh2`). + Rescaling factor for baryon-to-photon ratio, 1 for fiducial value + in `const.eta0` (or `const.Omegabh2`). tau_n_fac : float, optional - Rescaling factor for neutron decay lifetime, 1 for fiducial value - in `const.eta0` (or `const.Omegabh2`). + Rescaling factor for neutron decay lifetime, 1 for fiducial value + in `const.eta0` (or `const.Omegabh2`). nuclear_rates_q : array, optional - q ~ N(0,1) specifies the nuclear rate in its log-normal - distribution. If not specified, will be taken to be `q = 0`. + q ~ N(0,1) specifies the nuclear rate in its log-normal + distribution. If not specified, will be taken to be `q = 0`. Y_i : tuple of float, optional - Initial abundances :math:`n_i/n_b` for species. Length must be equal to - `self.nuclear_net.max_i_species`. Must specify `T_start` and `T_end` if not `None`. - T_start : float - Temperature in MeV to start integration. Must specify `Y_i` and `T_end` if not `None`, otherwise `const.T_start` used. - T_end : float - Temperature in MeV to end integration. + Initial abundances :math:`n_i/n_b` for species. Length must be equal to + `self.nuclear_net.max_i_species`. Must specify `T_start` and `T_end` if not `None`. + T_start : float, optional + Temperature in MeV to start integration. If not specified and + ThermoResult is provided, uses T_start from ThermoResult. + Otherwise uses `const.T_start`. + T_end : float, optional + Temperature in MeV to end integration. If not specified and + ThermoResult is provided, uses T_end from ThermoResult. + Otherwise uses `const.T_end`. sampling_nTOp : int - Number of points to subdivide (`T_end`, `T_start`) for neutron-proton interconversion rate interpolation table. + Number of points to subdivide (`T_end`, `T_start`) for neutron-proton interconversion rate interpolation table. rtol : float, optional - Relative tolerance of the abundance solver. Default is `1e-4`. + Relative tolerance of the abundance solver. Default is `1e-6`. atol : float, optional - Absolute tolerance of the abundance solver. Default is `1e-9`. + Absolute tolerance of the abundance solver. Default is `1e-9`. max_steps : int, optional - Maximum number of steps taken by the solver. Default is `4096`. - Increasing this slows down the code, while decreasing this could - mean that the solver cannot complete the solution. - solver : Diffrax ODE solver - The Diffrax ODE solver to use. A stiff solver is recommended. - Default is the 3rd order Kvaerno solver. + Maximum number of steps taken by the solver. Default is `4096`. + Increasing this slows down the code, while decreasing this could + mean that the solver cannot complete the solution. + solver : Diffrax ODE solver + The Diffrax ODE solver to use. A stiff solver is recommended. + Default is the 3rd order Kvaerno solver. save_history : bool - If `True`, full solution is returned with temperature and time + If `True`, full solution is returned with temperature and time abscissa. Returns ------- tuple of array or array - If `save_history` is set to `True`, a tuple containing an array of EM temperatures, an array of times, and a Diffrax `Solution` instance, which can be called as a function of time. Otherwise, returns yields of all species considered in `self.nuclear_net`. + If `save_history` is set to `True`, a tuple containing an array of EM temperatures, an array of times, and a Diffrax `Solution` instance, which can be called as a function of time. Otherwise, returns yields of all species considered in `self.nuclear_net`. + Examples + -------- + >>> # New streamlined usage with ThermoResult + >>> thermo_result = BackgroundModel()(0.) + >>> abundances = AbundanceModel(nuclear_net)(thermo_result) + + >>> # Backward compatible usage with arrays + >>> t, a, rho_g, rho_nu, rho_extra, P_extra, Neff = thermo_model(0.)[:7] + >>> abundances = abundance_model(rho_g, rho_nu, rho_extra, P_extra, t_vec=t, a_vec=a) """ print('Compiling abundance model...') - if Y_i is not None: - if T_start is None: + # Handle ThermoResult input for streamlined API + if isinstance(rho_g_vec_or_thermo, ThermoResult): + thermo_input = rho_g_vec_or_thermo + rho_g_vec = thermo_input.rho_g_vec + if rho_nu_vec is None: + rho_nu_vec = thermo_input.rho_nu_vec + if rho_NP_vec is None: + rho_NP_vec = thermo_input.rho_extra_vec + if P_NP_vec is None: + P_NP_vec = thermo_input.P_extra_vec + if a_vec is None: + a_vec = thermo_input.a_vec + if t_vec is None: + t_vec = thermo_input.t_vec + # Auto-detect temperature range from ThermoResult if not specified + if T_start is None: + T_start = thermo_input.T_start + if T_end is None: + T_end = thermo_input.T_end + else: + rho_g_vec = rho_g_vec_or_thermo + # Validate required arguments for array input + if rho_nu_vec is None or rho_NP_vec is None or P_NP_vec is None: + raise TypeError( + "When passing arrays directly, rho_nu_vec, rho_NP_vec, and P_NP_vec are required" + ) + + if Y_i is not None: + if T_start is None: raise TypeError('Specifying Y_i requires specifying a T_start') - if T_start is not None: - if Y_i is None: + if T_start is not None: + if Y_i is None: raise TypeError('Specifying T_start requires specifying Y_i') - if nuclear_rates_q is None: + if nuclear_rates_q is None: nuclear_rates_q = jnp.array( [0. for _ in self.nuclear_net.reactions] ) - if t_vec is None: + if t_vec is None: t_vec = self.get_t(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec) - if a_vec is None: + if a_vec is None: a_vec = self.get_a(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec) - if T_start is None: + if T_start is None: - T_start = const.T_start + T_start = const.T_start - if T_end is None: + if T_end is None: - T_end = const.T_end + T_end = const.T_end # These are in MeV T_g_vec = thermo.T_g(rho_g_vec) @@ -291,15 +343,18 @@ def __call__( saveat = SaveAt(t1=True) sol = diffeqsolve( - ODETerm(self.Y_prime), solver, - t0=t_start, t1=t_end, dt0=None, y0=Y_i, - args = ( - a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd, + ODETerm(self.Y_prime), solver, + t0=t_start, t1=t_end, dt0=None, y0=Y_i, + args=( + a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd, nTOp_bkwrd, eta_fac, tau_n_fac, nuclear_rates_q - ), saveat=saveat, stepsize_controller = PIDController( + ), + saveat=saveat, + stepsize_controller=PIDController( rtol=rtol, atol=atol, - ), - max_steps=max_steps + ), + max_steps=max_steps, + throw=self.throw ) if save_history: @@ -369,12 +424,13 @@ def dt_prime(rho_tot, t, args): rho_tot_fin = rho_tot_vec[-1] sol_t = diffeqsolve( - ODETerm(dt_prime), Tsit5(), - t0=rho_tot_init, t1=rho_tot_fin, - y0=1. / (2 * thermo.Hubble(rho_tot_init)), + ODETerm(dt_prime), Tsit5(), + t0=rho_tot_init, t1=rho_tot_fin, + y0=1. / (2 * thermo.Hubble(rho_tot_init)), dt0=None, max_steps=4096, - saveat=SaveAt(ts=rho_tot_vec), - stepsize_controller=PIDController(rtol=1e-8, atol=1e-10) + saveat=SaveAt(ts=rho_tot_vec), + stepsize_controller=PIDController(rtol=1e-8, atol=1e-10), + throw=self.throw ) return sol_t.ys @@ -441,13 +497,14 @@ def dlna_prime(rho_tot, t, args): rho_tot_init = rho_tot_vec[0] rho_tot_fin = rho_tot_vec[-1] - # a_0 = 1 arbitrarily, will rescale later. + # a_0 = 1 arbitrarily, will rescale later. sol_lna = diffeqsolve( - ODETerm(dlna_prime), Tsit5(), - t0=rho_tot_init, t1=rho_tot_fin, + ODETerm(dlna_prime), Tsit5(), + t0=rho_tot_init, t1=rho_tot_fin, y0=0., dt0=None, max_steps=4096, saveat=SaveAt(ts=rho_tot_vec), - stepsize_controller=PIDController(rtol=1e-8, atol=1e-10) + stepsize_controller=PIDController(rtol=1e-8, atol=1e-10), + throw=self.throw ) a_fin = const.T0CMB / T_g_vec[-1] diff --git a/linx/background.py b/linx/background.py index e3081ab..63267b0 100644 --- a/linx/background.py +++ b/linx/background.py @@ -7,6 +7,7 @@ from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, Event import linx.thermo as thermo +from linx.thermo import ThermoResult import linx.const as const rho_massless_BE_v = vmap( @@ -16,7 +17,7 @@ thermo.rho_massless_FD, in_axes=(0, None, None) ) -class BackgroundModel(eqx.Module): +class BackgroundModel(eqx.Module): """Background model. Attributes @@ -31,6 +32,9 @@ class BackgroundModel(eqx.Module): Whether to use leading order QED correction. Default is `True`. NLO : bool, optional Whether to use next-to-leading order QED correction. Default is True. + throw : bool, optional + Whether to raise exceptions on solver failure. Default is `True`. + Set to `False` for parameter scans where some combinations may fail. """ decoupled : bool @@ -38,8 +42,9 @@ class BackgroundModel(eqx.Module): collision_me : bool LO : bool NLO : bool + throw : bool - def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO = True): + def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO=True, throw=True): """ Initialize the BackgroundModel with thermodynamic options. @@ -56,6 +61,9 @@ def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO If True, include leading order QED corrections. Default is True. NLO : bool, optional If True, include next-to-leading order QED corrections. Default is True. + throw : bool, optional + If True, raise exceptions on solver failure. Default is True. + Set to False for parameter scans where some combinations may fail. """ self.decoupled = decoupled @@ -63,6 +71,7 @@ def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO self.collision_me = collision_me self.LO = LO self.NLO = NLO + self.throw = throw @eqx.filter_jit def __call__( @@ -131,12 +140,13 @@ def T_EM_check(t, y, args, **kwargs): sol = diffeqsolve( ODETerm(self.dY), solver, args=(lna_init, rho_extra_init), - t0=0., t1=jnp.inf, dt0=None, y0=Y0, + t0=0., t1=jnp.inf, dt0=None, y0=Y0, saveat=SaveAt(steps=True), event=Event(T_EM_check), - stepsize_controller = PIDController( + stepsize_controller=PIDController( rtol=rtol, atol=atol - ), - max_steps=max_steps + ), + max_steps=max_steps, + throw=self.throw ) a_vec = jnp.exp(sol.ys[0]) @@ -185,9 +195,16 @@ def T_EM_check(t, y, args, **kwargs): Neff_vec = thermo.N_eff(rho_tot_vec, rho_g_vec) - return ( - t_vec, a_vec, rho_g_vec, rho_nu_vec, - rho_extra_vec, P_extra_vec, Neff_vec + return ThermoResult( + t_vec=t_vec, + a_vec=a_vec, + rho_g_vec=rho_g_vec, + rho_nu_vec=rho_nu_vec, + rho_extra_vec=rho_extra_vec, + P_extra_vec=P_extra_vec, + Neff_vec=Neff_vec, + T_start=T_start, + T_end=T_end ) @eqx.filter_jit diff --git a/linx/thermo.py b/linx/thermo.py index 85cff9d..45d6fcd 100644 --- a/linx/thermo.py +++ b/linx/thermo.py @@ -1,14 +1,63 @@ -import os +import os +from typing import NamedTuple import numpy as np -import jax.numpy as jnp +import jax.numpy as jnp import jax.lax as lax from jax import grad, vmap, device_put, devices -import linx.const as const +import linx.const as const from linx.special_funcs import Li, K1, K2 + +class ThermoResult(NamedTuple): + """Result from BackgroundModel containing thermodynamic evolution. + + This object encapsulates all outputs from BackgroundModel and can be + passed directly to AbundanceModel. It supports tuple unpacking for + backward compatibility. + + Attributes + ---------- + t_vec : jnp.ndarray + Times in seconds at which thermodynamics are saved. + a_vec : jnp.ndarray + Scale factor at each point in time. + rho_g_vec : jnp.ndarray + Energy density of photons in MeV^4 at each point in time. + rho_nu_vec : jnp.ndarray + Energy density of one species of neutrinos in MeV^4. + rho_extra_vec : jnp.ndarray + Energy density in MeV^4 of extra species at each point in time. + P_extra_vec : jnp.ndarray + Pressure in MeV^4 of extra species at each point in time. + Neff_vec : jnp.ndarray + Effective number of neutrino species at each point in time. + T_start : float + Starting temperature in MeV used for integration. + T_end : float + Ending temperature in MeV used for integration. + + Examples + -------- + >>> # New usage pattern + >>> thermo_result = BackgroundModel()(0.) + >>> abundances = AbundanceModel(nuclear_net)(thermo_result) + + >>> # Backward compatible tuple unpacking + >>> t, a, rho_g, rho_nu, rho_extra, P_extra, Neff = thermo_result[:7] + """ + t_vec: jnp.ndarray + a_vec: jnp.ndarray + rho_g_vec: jnp.ndarray + rho_nu_vec: jnp.ndarray + rho_extra_vec: jnp.ndarray + P_extra_vec: jnp.ndarray + Neff_vec: jnp.ndarray + T_start: float + T_end: float + ########################################### # Cosmology # ############################################ diff --git a/pytest/test_abundances.py b/pytest/test_abundances.py index c006dcc..b610b72 100644 --- a/pytest/test_abundances.py +++ b/pytest/test_abundances.py @@ -27,8 +27,9 @@ @pytest.fixture def sample_inputs(): - t_vec_ref, a_vec_ref, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, Neff_vec = thermo_model_DNeff(0.) - + # Use [:7] slice for backward compatibility with ThermoResult (which has 9 fields) + t_vec_ref, a_vec_ref, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, Neff_vec = thermo_model_DNeff(0.)[:7] + return t_vec_ref, a_vec_ref, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec, Neff_vec @pytest.fixture