Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example_notebooks/Schramm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1191,7 +1191,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.12.11"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion example_notebooks/background_evolution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.11"
}
},
"nbformat": 4,
Expand Down
518 changes: 518 additions & 0 deletions example_notebooks/scratch.ipynb

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions linx/P_QED.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os

import numpy as np

import jax.numpy as jnp
import jax.lax as lax
from jax import grad, vmap

import linx.const as const
import equinox as eqx


# high temperature behavior is not correct, probably because bounds need to
# change with T.
# Low temperature behavior is perfect.


def explicit_P0(T, me): # not needed, computed in thermo
# 4.5 of https://arxiv.org/pdf/1911.04504

prefac = T/jnp.pi**2

p = jnp.linspace(0,50*T,num=3000) # this integral peaks at p close to T, integrating to 50*T is fine as long as resolution is very good
Ee = jnp.sqrt(p**2 + me**2)
integrand = p**2 * jnp.log( (1 + jnp.exp(-Ee/T))**2 / (1 - jnp.exp(-p/T)) )

res = jnp.trapezoid(jnp.nan_to_num(integrand,nan=0),p) # p = 0 gives a nan, should be 0

return prefac * res


def explicit_P2(T, me):
# first compute 4.7 of https://arxiv.org/pdf/1911.04504, but ignoring the
# last term because it's itty bitty according to them

e = jnp.sqrt(const.aFS * 4 * jnp.pi)
prefac1 = - e**2 * T**2 / (12 * jnp.pi**2)
prefac2 = - e**2/(8 * jnp.pi**4)

p = jnp.linspace(0,50*T,num=2000) # this integral peaks at p close to T, integrating to 50*T is fine as long as resolution is good
Ep = jnp.sqrt(p**2 + me**2)
integrand = p**2/Ep * 2/(jnp.exp(Ep/T) + 1)

res = jnp.trapezoid(integrand,p)

return prefac1 * res + prefac2 * res**2

def explicit_P3(T, me):
# compute 4.24 of https://arxiv.org/pdf/1911.04504

e = jnp.sqrt(const.aFS * 4 * jnp.pi)
prefac = e**3 * T/(12 * jnp.pi**4)

p = jnp.linspace(0,50*T,num=2000) # this integral also peaks at p close to T, integrating to 50 is fine as long as resolution is good
Ep = jnp.sqrt(p**2 + me**2)
integrand = (p**2 + Ep**2)/Ep * 2/(jnp.exp(Ep/T) + 1)

res = jnp.trapezoid(integrand,p)

return prefac * res**(3./2)

def P_QED(T,me): # not needed, sums in thermo
return explicit_P0(T, me) + explicit_P2(T, me) + explicit_P3(T, me)


dPdTQED_2 = grad(explicit_P2,argnums=0)
dPdTQED_3 = grad(explicit_P3,argnums=0)

# we don't need these computed explicitly, actually
d2PdT2QED_2 = grad(dPdTQED_2,argnums=0)
d2PdT2QED_3 = grad(dPdTQED_3,argnums=0)
23 changes: 23 additions & 0 deletions linx/Untitled.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "a0074e73-6ca5-47ee-a7bc-87a25e8fb3a5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "",
"name": ""
},
"language_info": {
"name": ""
}
},
"nbformat": 4,
"nbformat_minor": 5
}
36 changes: 24 additions & 12 deletions linx/abundances.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import linx.thermo as thermo
from linx.thermo import rho_EM_std_v, p_EM_std_v, nB
from linx.special_funcs import zeta_3
from linx.tau_n_vary_me import tau_n_fac_vary_me

class AbundanceModel(eqx.Module):
"""
Expand All @@ -39,8 +40,6 @@ class AbundanceModel(eqx.Module):
Spin of each species.
species_binding_energy : list
Binding energy of each species.
species_mass : list
Mass of each species.
throw : bool
Whether to raise exceptions on solver failure.
"""
Expand All @@ -53,7 +52,6 @@ class AbundanceModel(eqx.Module):
species_excess_mass : dict
species_spin : list
species_binding_energy : list
species_mass : list
throw : bool

def __init__(self, nuclear_net, weak_rates=wr.WeakRates(), throw=True):
Expand Down Expand Up @@ -115,16 +113,17 @@ def __init__(self, nuclear_net, weak_rates=wr.WeakRates(), throw=True):
)

# in MeV
self.species_mass = (
self.species_A * ma + self.species_excess_mass - self.species_Z * me
)
# requires recompilation for each me--moved to YNSE method
# self.species_mass = (
# self.species_A * ma + self.species_excess_mass - self.species_Z * me
# )

@eqx.filter_jit
def __call__(
self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec,
a_vec=None, t_vec=None,
eta_fac=jnp.asarray(1.), tau_n_fac = jnp.asarray(1.),
nuclear_rates_q=None,
nuclear_rates_q=None, me = const.me,
Y_i=None, T_start=None, T_end=None, sampling_nTOp=150,
rtol=1e-6, atol=1e-9, solver=Kvaerno3(),
max_steps=4096,
Expand Down Expand Up @@ -153,10 +152,12 @@ def __call__(
in `const.eta0` (or `const.Omegabh2`).
tau_n_fac : float, optional
Rescaling factor for neutron decay lifetime, 1 for fiducial value
in `const.eta0` (or `const.Omegabh2`).
in `const.tau_n`.
nuclear_rates_q : array, optional
q ~ N(0,1) specifies the nuclear rate in its log-normal
distribution. If not specified, will be taken to be `q = 0`.
me : float, optional
Electron mass in MeV. Defaults to `const.me`.
Y_i : tuple of float, optional
Initial abundances :math:`n_i/n_b` for species. Length must be equal to
`self.nuclear_net.max_i_species`. Must specify `T_start` and `T_end` if not `None`.
Expand Down Expand Up @@ -217,6 +218,11 @@ def __call__(

T_end = const.T_end

# check if the user has varied me, and adjust the neutron lifetime if so
diff = jnp.abs(me - const.me)/const.me
tau_n_fac = jnp.where(diff > 1e-5, tau_n_fac_vary_me(me), 1.) * tau_n_fac


# These are in MeV
T_g_vec = thermo.T_g(rho_g_vec)
T_nu_vec = thermo.T_nu(rho_nu_vec)
Expand Down Expand Up @@ -266,7 +272,7 @@ def __call__(

T_interval_nTOp, nTOp_frwrd, nTOp_bkwrd = self.weak_rates(
jnp.array([T_g_vec, T_nu_vec]),
T_start=T_start, T_end=T_end, sampling_nTOp=sampling_nTOp
T_start=T_start, T_end=T_end, sampling_nTOp=sampling_nTOp, me=me
)

##################################
Expand All @@ -285,7 +291,7 @@ def __call__(
n_CMB_start = thermo.n_massless_BE(T_start, 0., 2.)
eta_T_start = nB(a_start, eta_fac=eta_fac) / n_CMB_start

Y_YNSE = self.YNSE(Yn_i, Yp_i, const.T_start, eta_T_start)
Y_YNSE = self.YNSE(Yn_i, Yp_i, const.T_start, eta_T_start, me)

Y_others_i = Y_YNSE[2:self.nuclear_net.max_i_species]

Expand Down Expand Up @@ -527,7 +533,7 @@ def Y_prime(self, t, Y, args):

return dY

def YNSE(self, Yn, Yp, T, eta):
def YNSE(self, Yn, Yp, T, eta, me=const.me):
"""
Nuclear statistical equilibrium yields for all species.

Expand All @@ -541,15 +547,21 @@ def YNSE(self, Yn, Yp, T, eta):
The temperature of the baryons in MeV.
eta : float
The baryon-to-photon ratio.
me: float, optional
Electron mass in MeV. Defaults to const.me

Returns
-------
array
Yields for all species considered in LINX (13 of them).
"""

species_mass = (
self.species_A * ma + self.species_excess_mass - self.species_Z * me
)

A32Overmn = (
self.species_mass / (
species_mass / (
mn**(self.species_A - self.species_Z)
* mp**self.species_Z
)
Expand Down
40 changes: 25 additions & 15 deletions linx/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ class BackgroundModel(eqx.Module):
collision_me : bool
LO : bool
NLO : bool
max_steps : int
throw : bool

def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO=True, throw=True):
def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO = True, throw=True, max_steps=512):

"""
Initialize the BackgroundModel with thermodynamic options.

Expand All @@ -70,25 +72,28 @@ def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO
self.collision_me = collision_me
self.LO = LO
self.NLO = NLO
self.max_steps = max_steps
self.throw = throw

@eqx.filter_jit
def __call__(
self, Delt_Neff_init, T_start=const.T_start,
T_end=const.T_end, rtol=1e-8, atol=1e-10,
solver=Tsit5(), max_steps=512
T_end=const.T_end, me=const.me, rtol=1e-8, atol=1e-10,
solver=Tsit5(),
):
""" Calculate thermodynamics given an initial :math:`\\Delta N_\\mathrm{eff}`.

Parameters
----------
Delt_Neff_init : float
Initial :math:`\\Delta N_\\mathrm{eff}`. Can be positive or negative.
T_EM_init : float
T_EM_init : float, optional
Initial EM (and neutrino) temperature. Default is `const.T_start`.
T_EM_end : float
T_EM_end : float, optional
Final EM temperature to terminate integration at. Default is
`const.T_end`.
me : float, optional
Electron mass in MeV. Defaults to `const.me`.
rtol : float, optional
Relative tolerance of the abundance solver. Default is `1e-8`.
atol : float, optional
Expand Down Expand Up @@ -133,18 +138,23 @@ def __call__(
) * Delt_Neff_init

Y0 = (lna_init, T_EM_init, T_nu_init)

# use parametric form to estimate correct start
# time given T_start, assuming T ~ t^(-1/2) and
# initial g_* is 10.75
t0 = (1.5/T_start * 10.75**(-1./4))**2

def T_EM_check(t, y, args, **kwargs):
return y[1] < T_end

sol = diffeqsolve(
ODETerm(self.dY), solver, args=(lna_init, rho_extra_init),
t0=0., t1=jnp.inf, dt0=None, y0=Y0,
ODETerm(self.dY), solver, args=(lna_init, rho_extra_init, me),
t0 = t0, t1=jnp.inf, dt0=None, y0=Y0,
saveat=SaveAt(steps=True), event=Event(T_EM_check),
stepsize_controller=PIDController(
rtol=rtol, atol=atol
),
max_steps=max_steps,
),
max_steps=self.max_steps,
throw=self.throw
)

Expand All @@ -159,7 +169,7 @@ def T_EM_check(t, y, args, **kwargs):
last_step_ind = jnp.max(
jnp.argwhere(
sol.ys[1] < T_end,
size=512
size=self.max_steps
)[:,0]
)

Expand Down Expand Up @@ -218,11 +228,11 @@ def dY(self, t, Y, args):
"""

lna, T_g, T_nu = Y
lna_init, rho_extra_init = args
lna_init, rho_extra_init, me = args

rho_EM = thermo.rho_EM_std(T_g, LO=self.LO, NLO=self.NLO)
rho_plus_p_EM = thermo.rho_plus_p_EM_std(T_g, LO=self.LO, NLO=self.NLO)
drho_EM_dT_g = thermo.drho_EM_dT_g_std(T_g, LO=self.LO, NLO=self.NLO)
rho_EM = thermo.rho_EM_std(T_g, me=me, LO=self.LO, NLO=self.NLO)
rho_plus_p_EM = thermo.rho_plus_p_EM_std(T_g, me=me, LO=self.LO, NLO=self.NLO)
drho_EM_dT_g = thermo.drho_EM_dT_g_std(T_g, me=me, LO=self.LO, NLO=self.NLO)

rho_nu = 3*thermo.rho_nue_std(T_nu)
rho_plus_p_nu = (4/3) * rho_nu
Expand All @@ -233,7 +243,7 @@ def dY(self, t, Y, args):
H = thermo.Hubble(rho_EM + rho_nu + rho_extra)

C_rho_nue, C_rho_numu, _, _ = thermo.collision_terms_std(
T_g, T_nu, T_nu, decoupled=self.decoupled, use_FD=self.use_FD, collision_me=self.collision_me
T_g, T_nu, T_nu, me=me, decoupled=self.decoupled, use_FD=self.use_FD, collision_me=self.collision_me
)

drho_EM_dt = -3 * H * rho_plus_p_EM - C_rho_nue - 2*C_rho_numu
Expand Down
Loading