diff --git a/abcmb/main.py b/abcmb/main.py index a3dfe7e..244327b 100644 --- a/abcmb/main.py +++ b/abcmb/main.py @@ -1,4 +1,4 @@ -from jax import jit, config, lax +from jax import jit, config, lax, tree_util import jax.numpy as jnp from jaxtyping import Array import numpy as np @@ -168,8 +168,7 @@ def __call__(self, params : dict = {}): full_params = self.add_derived_parameters(params) - output, aux = self.run_cosmology_abbr(full_params) - return output, aux + return self.run_cosmology_abbr(full_params) ### JITTED OR JITTABLE FUNCTIONS ### @@ -205,30 +204,24 @@ def run_cosmology_abbr(self, params : dict): print('\\_____/ ') print("") + # Compute background and linear perturbations PT, BG = self.get_PTBG(params) - output = () - aux = () - - if self.specs["output_Cl"]: - Cls = self.SS.get_Cl(PT, BG, params) - ells = self.SS.ells - output += Cls - aux += (ells,) - - if self.specs["output_Pk"]: - Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params) - output += (Pk,) - aux += (self.SS.k_axis_Pk_output,) - - aux += (params,) - - if self.specs["output_perturbations"]: - aux += (PT,) - if self.specs["output_background"]: - aux += (BG,) + # Compute CMB power spectra + Cls = self.SS.get_Cl(PT, BG, params) + l = self.SS.ells + + # Compute linear matter power spectrum + Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params) + k = self.SS.k_axis_Pk_output + + # Package + output = Output( + Cls[0], Cls[1], Cls[2], Pk, + l, k, BG, PT, params + ) - return output, aux + return output @eqx.filter_jit def get_PTBG(self, params : dict): @@ -528,4 +521,24 @@ def add_derived_parameters(self, param_in : dict) -> dict: if key not in expected_keys: params[key] = jnp.array(value) - return params \ No newline at end of file + return params + +class Output(eqx.Module): + """ + Object containing final and intermediate results from one cosmological simulation. + Contains the power spectra (CMB & P(k)) as well as auxillary fields including + the multipoles l for the Cls, wavenumbers k for P(k), background BG, perturbations PT, and + a full list of parameters (input + derived) in the params dictionary. + """ + + # Power spectra + ClTT : jnp.array + ClTE : jnp.array + ClEE : jnp.array + Pk : jnp.array + + l : jnp.array + k : jnp.array + BG : background.Background + PT : perturbations.PerturbationTable + params : dict diff --git a/abcmb/model_specs.py b/abcmb/model_specs.py index 3a9cfea..ef4a085 100644 --- a/abcmb/model_specs.py +++ b/abcmb/model_specs.py @@ -18,23 +18,15 @@ def load_specs(input_specs): specs["input_tau_reion"] = input_specs.get("input_tau_reion", True) ### OUTPUT RELATED specs PARAMS ### - specs["output_Cl"] = input_specs.get("output_Cl", True) specs["l_min"] = input_specs.get("l_min", 2) specs["l_max"] = input_specs.get("l_max", 2500) specs["lensing"] = input_specs.get("lensing", False) - - specs["output_Pk"] = input_specs.get("output_Pk", True) - specs["output_k_max"] = input_specs.get("output_k_max", 0.5) - - specs["output_background"] = input_specs.get("output_background", False) - specs["output_perturbations"] = input_specs.get("output_perturbations", False) + specs["k_max"] = input_specs.get("k_max", 0.5) ### BBN ### specs["bbn_type"] = input_specs.get("bbn_type", "") specs["linx_reaction_net"] = input_specs.get("linx_reaction_net", "key_PRIMAT_2023") - ### TODO: HYREX RELATED specs PARAMS ### - ### Boltzmann Hierarchy Cutoffs ### specs["l_max_g"] = input_specs.get("l_max_g", 12) specs["l_max_pol_g"] = input_specs.get("l_max_pol_g", 10) @@ -52,7 +44,6 @@ def load_specs(input_specs): specs["tau0_fid"] = input_specs.get("tau0_fid",1.418668e+04) specs["rs_rec_fid"] = input_specs.get("rs_rec_fid", 1.446279e+02) - ### Transfer integration k-grid resolution ### specs["k_transfer_linstep"] = input_specs.get("k_transfer_linstep", 4.5e-1) specs["k_transfer_logstep"] = input_specs.get("k_transfer_logstep", 170.) @@ -167,9 +158,9 @@ def get_k_axis_perturbations(specs): i += 1 ks[i] = k - # If the user wants P(k) and specified a k_max above the current, we should add these as well. - if specs["output_Pk"] and k < specs["output_k_max"]: - k_max = specs["output_k_max"] + # If the user specified a k_max above the current, we should add these as well. + if k < specs["k_max"]: + k_max = specs["k_max"] while k < k_max: step = 0.005 @@ -179,7 +170,7 @@ def get_k_axis_perturbations(specs): ks[i] = k ks = ks[np.where(ks>0)] - k_axis_Pk_output = ks[np.where(ks<=specs["output_k_max"])] + k_axis_Pk_output = ks[np.where(ks<=specs["k_max"])] return jnp.array(ks), jnp.array(k_axis_Pk_output) diff --git a/pytests/accuracy_test.py b/pytests/accuracy_test.py index a71a518..8e886a0 100644 --- a/pytests/accuracy_test.py +++ b/pytests/accuracy_test.py @@ -110,12 +110,12 @@ def test_accuracy_checker(h = 0.6762): # ABCMB - data, label = model(params) - ells = label[0] + output = model.run_cosmology(params) + ells = output.l - ABC_tt = data[0] - ABC_te = data[1] - ABC_ee = data[2] + ABC_tt = output.ClTT + ABC_te = output.ClTE + ABC_ee = output.ClEE # Compare Cltt err_tt = abs(cltt-ABC_tt)/cltt @@ -126,8 +126,8 @@ def test_accuracy_checker(h = 0.6762): print(err_ee.max()) # Compare P(k) - ABC_Pk = data[3] - ABC_k = label[1] + ABC_Pk = output.Pk + ABC_k = output.k CLA_Pk = np.vectorize(CLASS_Model.pk)(ABC_k, 0.) err_Pk = abs(CLA_Pk-ABC_Pk)/CLA_Pk print(err_Pk.max())