From fb32e751d0a2fae22516b23cb9c47b9d4af950a9 Mon Sep 17 00:00:00 2001 From: Zilu Zhou Date: Thu, 12 Feb 2026 17:29:29 -0500 Subject: [PATCH 1/4] implemented simple eqx.Module output, aux fields still get gradient applied --- abcmb/main.py | 116 +++++++++++++++++++++++++++++++++---------- abcmb/model_specs.py | 19 ++----- 2 files changed, 96 insertions(+), 39 deletions(-) diff --git a/abcmb/main.py b/abcmb/main.py index bf617af..194a0a7 100644 --- a/abcmb/main.py +++ b/abcmb/main.py @@ -1,4 +1,4 @@ -from jax import jit, config, lax +from jax import jit, config, lax, tree_util import jax.numpy as jnp from jaxtyping import Array import numpy as np @@ -168,8 +168,7 @@ def run_cosmology(self, params : dict = {}): full_params = self.add_derived_parameters(params) - output, aux = self.run_cosmology_abbr(full_params) - return output, aux + return self.run_cosmology_abbr(full_params) ### JITTED OR JITTABLE FUNCTIONS ### @@ -205,30 +204,24 @@ def run_cosmology_abbr(self, params : dict): print('\\_____/ ') print("") + # Compute background and linear perturbations PT, BG = self.get_PTBG(params) - output = () - aux = () - - if self.specs["output_Cl"]: - Cls = self.SS.get_Cl(PT, BG, params) - ells = self.SS.ells - output += Cls - aux += (ells,) - - if self.specs["output_Pk"]: - Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params) - output += (Pk,) - aux += (self.SS.k_axis_Pk_output,) - - aux += (params,) - if self.specs["output_perturbations"]: - aux += (PT,) - - if self.specs["output_background"]: - aux += (BG,) + # Compute CMB power spectra + Cls = self.SS.get_Cl(PT, BG, params) + l = self.SS.ells + + # Compute linear matter power spectrum + Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params) + k = self.SS.k_axis_Pk_output + + # Package + output = Output( + Cls[0], Cls[1], Cls[2], Pk, + l, k, BG, PT, params + ) - return output, aux + return output @eqx.filter_jit def get_PTBG(self, params : dict): @@ -512,4 +505,77 @@ def add_derived_parameters(self, param_in : dict) -> dict: # Having inferred correct omega_m and omega_r, compute correct omega_Lambda params['omega_Lambda'] = params['h']**2 - params['omega_r'] - params['omega_m'] - return params \ No newline at end of file + return params + +class Output(eqx.Module): + """ + Object containing final and intermediate results from one cosmological simulation. + Contains the power spectra (CMB & P(k)) whose derivatives can be taken. + Also contains auxillary data such as l, k, background, perturbations and full params which are static + and cannot be taken gradient on. + """ + + # Power spectra + ClTT : jnp.array + ClTE : jnp.array + ClEE : jnp.array + Pk : jnp.array + + l : jnp.array + k : jnp.array + BG : background.Background + PT : perturbations.PerturbationTable + params : dict + +@tree_util.register_pytree_node_class +class Output_tree: + """ + Object containing final and intermediate results from one cosmological simulation. + Contains the power spectra (CMB & P(k)) whose derivatives can be taken. + Also contains auxillary data such as l, k, background, perturbations and full params which are static + and cannot be taken gradient on. + """ + + # Power spectra + ClTT : jnp.array + ClTE : jnp.array + ClEE : jnp.array + Pk : jnp.array + + # Auxillary data + # l : jnp.array = eqx.field(static=True) # Force static to avoid pytree registration. + # k : jnp.array = eqx.field(static=True) + # BG : background.Background = eqx.field(static=True) + # PT : perturbations.PerturbationTable = eqx.field(static=True) + # params : dict = eqx.field(static=True) + + l : jnp.array + k : jnp.array + BG : background.Background + PT : perturbations.PerturbationTable + params : dict + + def __init__(self, *, ClTT, ClTE, ClEE, Pk, l, k, BG, PT, params): + self.ClTT = ClTT + self.ClTE = ClTE + self.ClEE = ClEE + self.Pk = Pk + self.l = l + self.k = k + self.BG = BG + self.PT = PT + self.params = params + + # PyTree interface: + # children = things JAX traces/differentiates + # aux_data = static stuff carried along (not traced/differentiated) + def tree_flatten(self): + children = (self.ClTT, self.ClTE, self.ClEE, self.Pk) # <-- ONLY these get grads + aux_data = (self.l, self.k, self.BG, self.PT, self.params) # <-- static metadata + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + ClTT, ClTE, ClEE, Pk = children + l, k, BG, PT, params = aux_data + return cls(ClTT=ClTT, ClTE=ClTE, ClEE=ClEE, Pk=Pk, l=l, k=k, BG=BG, PT=PT, params=params) \ No newline at end of file diff --git a/abcmb/model_specs.py b/abcmb/model_specs.py index 3a9cfea..ef4a085 100644 --- a/abcmb/model_specs.py +++ b/abcmb/model_specs.py @@ -18,23 +18,15 @@ def load_specs(input_specs): specs["input_tau_reion"] = input_specs.get("input_tau_reion", True) ### OUTPUT RELATED specs PARAMS ### - specs["output_Cl"] = input_specs.get("output_Cl", True) specs["l_min"] = input_specs.get("l_min", 2) specs["l_max"] = input_specs.get("l_max", 2500) specs["lensing"] = input_specs.get("lensing", False) - - specs["output_Pk"] = input_specs.get("output_Pk", True) - specs["output_k_max"] = input_specs.get("output_k_max", 0.5) - - specs["output_background"] = input_specs.get("output_background", False) - specs["output_perturbations"] = input_specs.get("output_perturbations", False) + specs["k_max"] = input_specs.get("k_max", 0.5) ### BBN ### specs["bbn_type"] = input_specs.get("bbn_type", "") specs["linx_reaction_net"] = input_specs.get("linx_reaction_net", "key_PRIMAT_2023") - ### TODO: HYREX RELATED specs PARAMS ### - ### Boltzmann Hierarchy Cutoffs ### specs["l_max_g"] = input_specs.get("l_max_g", 12) specs["l_max_pol_g"] = input_specs.get("l_max_pol_g", 10) @@ -52,7 +44,6 @@ def load_specs(input_specs): specs["tau0_fid"] = input_specs.get("tau0_fid",1.418668e+04) specs["rs_rec_fid"] = input_specs.get("rs_rec_fid", 1.446279e+02) - ### Transfer integration k-grid resolution ### specs["k_transfer_linstep"] = input_specs.get("k_transfer_linstep", 4.5e-1) specs["k_transfer_logstep"] = input_specs.get("k_transfer_logstep", 170.) @@ -167,9 +158,9 @@ def get_k_axis_perturbations(specs): i += 1 ks[i] = k - # If the user wants P(k) and specified a k_max above the current, we should add these as well. - if specs["output_Pk"] and k < specs["output_k_max"]: - k_max = specs["output_k_max"] + # If the user specified a k_max above the current, we should add these as well. + if k < specs["k_max"]: + k_max = specs["k_max"] while k < k_max: step = 0.005 @@ -179,7 +170,7 @@ def get_k_axis_perturbations(specs): ks[i] = k ks = ks[np.where(ks>0)] - k_axis_Pk_output = ks[np.where(ks<=specs["output_k_max"])] + k_axis_Pk_output = ks[np.where(ks<=specs["k_max"])] return jnp.array(ks), jnp.array(k_axis_Pk_output) From e0e0293493d02cf1108fc465c4d230ef7a952094 Mon Sep 17 00:00:00 2001 From: Zilu Zhou Date: Fri, 13 Feb 2026 14:22:29 -0500 Subject: [PATCH 2/4] add docs --- abcmb/main.py | 61 ++++----------------------------------------------- 1 file changed, 4 insertions(+), 57 deletions(-) diff --git a/abcmb/main.py b/abcmb/main.py index 194a0a7..9772976 100644 --- a/abcmb/main.py +++ b/abcmb/main.py @@ -510,9 +510,9 @@ def add_derived_parameters(self, param_in : dict) -> dict: class Output(eqx.Module): """ Object containing final and intermediate results from one cosmological simulation. - Contains the power spectra (CMB & P(k)) whose derivatives can be taken. - Also contains auxillary data such as l, k, background, perturbations and full params which are static - and cannot be taken gradient on. + Contains the power spectra (CMB & P(k)) as well as auxillary fields including + the multipoles l for the Cls, wavenumbers k for P(k), background BG, perturbations PT, and + a full list of parameters (input + derived) in the params dictionary. """ # Power spectra @@ -525,57 +525,4 @@ class Output(eqx.Module): k : jnp.array BG : background.Background PT : perturbations.PerturbationTable - params : dict - -@tree_util.register_pytree_node_class -class Output_tree: - """ - Object containing final and intermediate results from one cosmological simulation. - Contains the power spectra (CMB & P(k)) whose derivatives can be taken. - Also contains auxillary data such as l, k, background, perturbations and full params which are static - and cannot be taken gradient on. - """ - - # Power spectra - ClTT : jnp.array - ClTE : jnp.array - ClEE : jnp.array - Pk : jnp.array - - # Auxillary data - # l : jnp.array = eqx.field(static=True) # Force static to avoid pytree registration. - # k : jnp.array = eqx.field(static=True) - # BG : background.Background = eqx.field(static=True) - # PT : perturbations.PerturbationTable = eqx.field(static=True) - # params : dict = eqx.field(static=True) - - l : jnp.array - k : jnp.array - BG : background.Background - PT : perturbations.PerturbationTable - params : dict - - def __init__(self, *, ClTT, ClTE, ClEE, Pk, l, k, BG, PT, params): - self.ClTT = ClTT - self.ClTE = ClTE - self.ClEE = ClEE - self.Pk = Pk - self.l = l - self.k = k - self.BG = BG - self.PT = PT - self.params = params - - # PyTree interface: - # children = things JAX traces/differentiates - # aux_data = static stuff carried along (not traced/differentiated) - def tree_flatten(self): - children = (self.ClTT, self.ClTE, self.ClEE, self.Pk) # <-- ONLY these get grads - aux_data = (self.l, self.k, self.BG, self.PT, self.params) # <-- static metadata - return children, aux_data - - @classmethod - def tree_unflatten(cls, aux_data, children): - ClTT, ClTE, ClEE, Pk = children - l, k, BG, PT, params = aux_data - return cls(ClTT=ClTT, ClTE=ClTE, ClEE=ClEE, Pk=Pk, l=l, k=k, BG=BG, PT=PT, params=params) \ No newline at end of file + params : dict \ No newline at end of file From eff4f8ff84d861e2c716c2145dc2c8fe1b0e1ff8 Mon Sep 17 00:00:00 2001 From: Zilu Zhou Date: Fri, 13 Feb 2026 14:24:04 -0500 Subject: [PATCH 3/4] updated accuracy test --- pytests/accuracy_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytests/accuracy_test.py b/pytests/accuracy_test.py index da491d4..8e886a0 100644 --- a/pytests/accuracy_test.py +++ b/pytests/accuracy_test.py @@ -110,12 +110,12 @@ def test_accuracy_checker(h = 0.6762): # ABCMB - data, label = model.run_cosmology(params) - ells = label[0] + output = model.run_cosmology(params) + ells = output.l - ABC_tt = data[0] - ABC_te = data[1] - ABC_ee = data[2] + ABC_tt = output.ClTT + ABC_te = output.ClTE + ABC_ee = output.ClEE # Compare Cltt err_tt = abs(cltt-ABC_tt)/cltt @@ -126,8 +126,8 @@ def test_accuracy_checker(h = 0.6762): print(err_ee.max()) # Compare P(k) - ABC_Pk = data[3] - ABC_k = label[1] + ABC_Pk = output.Pk + ABC_k = output.k CLA_Pk = np.vectorize(CLASS_Model.pk)(ABC_k, 0.) err_Pk = abs(CLA_Pk-ABC_Pk)/CLA_Pk print(err_Pk.max()) From 0522722e78e9b09fd5b0f907d4e2b70e7da16bc7 Mon Sep 17 00:00:00 2001 From: Zilu Zhou Date: Fri, 13 Feb 2026 14:40:52 -0500 Subject: [PATCH 4/4] new syntax for pytest --- pytests/accuracy_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytests/accuracy_test.py b/pytests/accuracy_test.py index 8e886a0..45ab664 100644 --- a/pytests/accuracy_test.py +++ b/pytests/accuracy_test.py @@ -110,7 +110,7 @@ def test_accuracy_checker(h = 0.6762): # ABCMB - output = model.run_cosmology(params) + output = model(params) ells = output.l ABC_tt = output.ClTT