Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 38 additions & 25 deletions abcmb/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jax import jit, config, lax
from jax import jit, config, lax, tree_util
import jax.numpy as jnp
from jaxtyping import Array
import numpy as np
Expand Down Expand Up @@ -168,8 +168,7 @@ def __call__(self, params : dict = {}):


full_params = self.add_derived_parameters(params)
output, aux = self.run_cosmology_abbr(full_params)
return output, aux
return self.run_cosmology_abbr(full_params)

### JITTED OR JITTABLE FUNCTIONS ###

Expand Down Expand Up @@ -205,30 +204,24 @@ def run_cosmology_abbr(self, params : dict):
print('\\_____/ ')
print("")

# Compute background and linear perturbations
PT, BG = self.get_PTBG(params)
output = ()
aux = ()

if self.specs["output_Cl"]:
Cls = self.SS.get_Cl(PT, BG, params)
ells = self.SS.ells
output += Cls
aux += (ells,)

if self.specs["output_Pk"]:
Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params)
output += (Pk,)
aux += (self.SS.k_axis_Pk_output,)

aux += (params,)

if self.specs["output_perturbations"]:
aux += (PT,)

if self.specs["output_background"]:
aux += (BG,)
# Compute CMB power spectra
Cls = self.SS.get_Cl(PT, BG, params)
l = self.SS.ells

# Compute linear matter power spectrum
Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params)
k = self.SS.k_axis_Pk_output

# Package
output = Output(
Cls[0], Cls[1], Cls[2], Pk,
l, k, BG, PT, params
)

return output, aux
return output

@eqx.filter_jit
def get_PTBG(self, params : dict):
Expand Down Expand Up @@ -528,4 +521,24 @@ def add_derived_parameters(self, param_in : dict) -> dict:
if key not in expected_keys:
params[key] = jnp.array(value)

return params
return params

class Output(eqx.Module):
"""
Object containing final and intermediate results from one cosmological simulation.
Contains the power spectra (CMB & P(k)) as well as auxillary fields including
the multipoles l for the Cls, wavenumbers k for P(k), background BG, perturbations PT, and
a full list of parameters (input + derived) in the params dictionary.
"""

# Power spectra
ClTT : jnp.array
ClTE : jnp.array
ClEE : jnp.array
Pk : jnp.array

l : jnp.array
k : jnp.array
BG : background.Background
PT : perturbations.PerturbationTable
params : dict
19 changes: 5 additions & 14 deletions abcmb/model_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,15 @@ def load_specs(input_specs):
specs["input_tau_reion"] = input_specs.get("input_tau_reion", True)

### OUTPUT RELATED specs PARAMS ###
specs["output_Cl"] = input_specs.get("output_Cl", True)
specs["l_min"] = input_specs.get("l_min", 2)
specs["l_max"] = input_specs.get("l_max", 2500)
specs["lensing"] = input_specs.get("lensing", False)

specs["output_Pk"] = input_specs.get("output_Pk", True)
specs["output_k_max"] = input_specs.get("output_k_max", 0.5)

specs["output_background"] = input_specs.get("output_background", False)
specs["output_perturbations"] = input_specs.get("output_perturbations", False)
specs["k_max"] = input_specs.get("k_max", 0.5)

### BBN ###
specs["bbn_type"] = input_specs.get("bbn_type", "")
specs["linx_reaction_net"] = input_specs.get("linx_reaction_net", "key_PRIMAT_2023")

### TODO: HYREX RELATED specs PARAMS ###

### Boltzmann Hierarchy Cutoffs ###
specs["l_max_g"] = input_specs.get("l_max_g", 12)
specs["l_max_pol_g"] = input_specs.get("l_max_pol_g", 10)
Expand All @@ -52,7 +44,6 @@ def load_specs(input_specs):
specs["tau0_fid"] = input_specs.get("tau0_fid",1.418668e+04)
specs["rs_rec_fid"] = input_specs.get("rs_rec_fid", 1.446279e+02)


### Transfer integration k-grid resolution ###
specs["k_transfer_linstep"] = input_specs.get("k_transfer_linstep", 4.5e-1)
specs["k_transfer_logstep"] = input_specs.get("k_transfer_logstep", 170.)
Expand Down Expand Up @@ -167,9 +158,9 @@ def get_k_axis_perturbations(specs):
i += 1
ks[i] = k

# If the user wants P(k) and specified a k_max above the current, we should add these as well.
if specs["output_Pk"] and k < specs["output_k_max"]:
k_max = specs["output_k_max"]
# If the user specified a k_max above the current, we should add these as well.
if k < specs["k_max"]:
k_max = specs["k_max"]

while k < k_max:
step = 0.005
Expand All @@ -179,7 +170,7 @@ def get_k_axis_perturbations(specs):
ks[i] = k

ks = ks[np.where(ks>0)]
k_axis_Pk_output = ks[np.where(ks<=specs["output_k_max"])]
k_axis_Pk_output = ks[np.where(ks<=specs["k_max"])]

return jnp.array(ks), jnp.array(k_axis_Pk_output)

Expand Down
14 changes: 7 additions & 7 deletions pytests/accuracy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ def test_accuracy_checker(h = 0.6762):

# ABCMB

data, label = model(params)
ells = label[0]
output = model(params)
ells = output.l

ABC_tt = data[0]
ABC_te = data[1]
ABC_ee = data[2]
ABC_tt = output.ClTT
ABC_te = output.ClTE
ABC_ee = output.ClEE

# Compare Cltt
err_tt = abs(cltt-ABC_tt)/cltt
Expand All @@ -126,8 +126,8 @@ def test_accuracy_checker(h = 0.6762):
print(err_ee.max())

# Compare P(k)
ABC_Pk = data[3]
ABC_k = label[1]
ABC_Pk = output.Pk
ABC_k = output.k
CLA_Pk = np.vectorize(CLASS_Model.pk)(ABC_k, 0.)
err_Pk = abs(CLA_Pk-ABC_Pk)/CLA_Pk
print(err_Pk.max())
Expand Down