From bbf720138c2cde59102e038ce2a420536def9987 Mon Sep 17 00:00:00 2001 From: Siddharth Mishra-Sharma Date: Fri, 23 Jan 2026 16:04:37 -0500 Subject: [PATCH] Add diffrax throw option to control solver error handling Closes #17 - Add `throw` parameter to both BackgroundModel and AbundanceModel - Default is True (raises exceptions on solver failure) - Set to False for parameter scans where some combinations may fail - Passed through to all internal diffeqsolve calls Usage: # Default behavior (throw=True) thermo_model = BackgroundModel() abundance_model = AbundanceModel(nuclear_net) # For parameter scans (throw=False) thermo_model = BackgroundModel(throw=False) abundance_model = AbundanceModel(nuclear_net, throw=False) # Failures return partial results instead of raising Co-Authored-By: Claude Opus 4.5 --- linx/abundances.py | 76 +++++++++++++++++++++++++++------------------- linx/background.py | 21 +++++++++---- 2 files changed, 59 insertions(+), 38 deletions(-) diff --git a/linx/abundances.py b/linx/abundances.py index bc0fdb7..ed3ac0b 100644 --- a/linx/abundances.py +++ b/linx/abundances.py @@ -15,45 +15,48 @@ from linx.thermo import rho_EM_std_v, p_EM_std_v, nB from linx.special_funcs import zeta_3 -class AbundanceModel(eqx.Module): +class AbundanceModel(eqx.Module): """ - Abundance model and BBN abundance prediction. + Abundance model and BBN abundance prediction. Attributes ---------- nuclear_net : NuclearRates - Nuclear network to be used for BBN prediction. + Nuclear network to be used for BBN prediction. weak_rates : WeakRates - Weak rates for neutron-proton interconversion. + Weak rates for neutron-proton interconversion. species_dict : dict - Dictionary of species considered in LINX. + Dictionary of species considered in LINX. species_Z : list - Number of protons in each species. + Number of protons in each species. species_N : list - Number of neutrons in each species. + Number of neutrons in each species. species_A : list - Atomic mass number of each species. + Atomic mass number of each species. species_excess_mass : list - Excess mass (mass - A*amu) of each species. + Excess mass (mass - A*amu) of each species. species_spin : list - Spin of each species. + Spin of each species. species_binding_energy : list - Binding energy of each species. + Binding energy of each species. species_mass : list - Mass of each species. + Mass of each species. + throw : bool + Whether to raise exceptions on solver failure. """ - nuclear_net : nucl.NuclearRates - weak_rates : wr.WeakRates + nuclear_net : nucl.NuclearRates + weak_rates : wr.WeakRates species_dict : dict species_Z : list species_N : list - species_A : list + species_A : list species_excess_mass : dict species_spin : list species_binding_energy : list species_mass : list + throw : bool - def __init__(self, nuclear_net, weak_rates=wr.WeakRates()): + def __init__(self, nuclear_net, weak_rates=wr.WeakRates(), throw=True): """ Initialize the AbundanceModel with nuclear and weak rate networks. @@ -65,6 +68,9 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()): weak_rates : WeakRates, optional Weak interaction rates for neutron-proton interconversion. Defaults to standard WeakRates instance. + throw : bool, optional + If True, raise exceptions on solver failure. Default is True. + Set to False for parameter scans where some combinations may fail. Notes ----- @@ -76,6 +82,7 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates()): self.nuclear_net = nuclear_net self.weak_rates = weak_rates + self.throw = throw self.species_dict = { 0:'n', 1:'p', 2:'d', 3:'t', 4:'He3', 5:'a', 6:'Li7', 7:'Be7', @@ -291,15 +298,18 @@ def __call__( saveat = SaveAt(t1=True) sol = diffeqsolve( - ODETerm(self.Y_prime), solver, - t0=t_start, t1=t_end, dt0=None, y0=Y_i, - args = ( - a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd, + ODETerm(self.Y_prime), solver, + t0=t_start, t1=t_end, dt0=None, y0=Y_i, + args=( + a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd, nTOp_bkwrd, eta_fac, tau_n_fac, nuclear_rates_q - ), saveat=saveat, stepsize_controller = PIDController( + ), + saveat=saveat, + stepsize_controller=PIDController( rtol=rtol, atol=atol, - ), - max_steps=max_steps + ), + max_steps=max_steps, + throw=self.throw ) if save_history: @@ -369,12 +379,13 @@ def dt_prime(rho_tot, t, args): rho_tot_fin = rho_tot_vec[-1] sol_t = diffeqsolve( - ODETerm(dt_prime), Tsit5(), - t0=rho_tot_init, t1=rho_tot_fin, - y0=1. / (2 * thermo.Hubble(rho_tot_init)), + ODETerm(dt_prime), Tsit5(), + t0=rho_tot_init, t1=rho_tot_fin, + y0=1. / (2 * thermo.Hubble(rho_tot_init)), dt0=None, max_steps=4096, - saveat=SaveAt(ts=rho_tot_vec), - stepsize_controller=PIDController(rtol=1e-8, atol=1e-10) + saveat=SaveAt(ts=rho_tot_vec), + stepsize_controller=PIDController(rtol=1e-8, atol=1e-10), + throw=self.throw ) return sol_t.ys @@ -441,13 +452,14 @@ def dlna_prime(rho_tot, t, args): rho_tot_init = rho_tot_vec[0] rho_tot_fin = rho_tot_vec[-1] - # a_0 = 1 arbitrarily, will rescale later. + # a_0 = 1 arbitrarily, will rescale later. sol_lna = diffeqsolve( - ODETerm(dlna_prime), Tsit5(), - t0=rho_tot_init, t1=rho_tot_fin, + ODETerm(dlna_prime), Tsit5(), + t0=rho_tot_init, t1=rho_tot_fin, y0=0., dt0=None, max_steps=4096, saveat=SaveAt(ts=rho_tot_vec), - stepsize_controller=PIDController(rtol=1e-8, atol=1e-10) + stepsize_controller=PIDController(rtol=1e-8, atol=1e-10), + throw=self.throw ) a_fin = const.T0CMB / T_g_vec[-1] diff --git a/linx/background.py b/linx/background.py index e3081ab..4a37529 100644 --- a/linx/background.py +++ b/linx/background.py @@ -16,7 +16,7 @@ thermo.rho_massless_FD, in_axes=(0, None, None) ) -class BackgroundModel(eqx.Module): +class BackgroundModel(eqx.Module): """Background model. Attributes @@ -31,6 +31,9 @@ class BackgroundModel(eqx.Module): Whether to use leading order QED correction. Default is `True`. NLO : bool, optional Whether to use next-to-leading order QED correction. Default is True. + throw : bool, optional + Whether to raise exceptions on solver failure. Default is `True`. + Set to `False` for parameter scans where some combinations may fail. """ decoupled : bool @@ -38,8 +41,9 @@ class BackgroundModel(eqx.Module): collision_me : bool LO : bool NLO : bool + throw : bool - def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO = True): + def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO=True, throw=True): """ Initialize the BackgroundModel with thermodynamic options. @@ -56,6 +60,9 @@ def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO If True, include leading order QED corrections. Default is True. NLO : bool, optional If True, include next-to-leading order QED corrections. Default is True. + throw : bool, optional + If True, raise exceptions on solver failure. Default is True. + Set to False for parameter scans where some combinations may fail. """ self.decoupled = decoupled @@ -63,6 +70,7 @@ def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO self.collision_me = collision_me self.LO = LO self.NLO = NLO + self.throw = throw @eqx.filter_jit def __call__( @@ -131,12 +139,13 @@ def T_EM_check(t, y, args, **kwargs): sol = diffeqsolve( ODETerm(self.dY), solver, args=(lna_init, rho_extra_init), - t0=0., t1=jnp.inf, dt0=None, y0=Y0, + t0=0., t1=jnp.inf, dt0=None, y0=Y0, saveat=SaveAt(steps=True), event=Event(T_EM_check), - stepsize_controller = PIDController( + stepsize_controller=PIDController( rtol=rtol, atol=atol - ), - max_steps=max_steps + ), + max_steps=max_steps, + throw=self.throw ) a_vec = jnp.exp(sol.ys[0])