From 8f5fc77610a2dbe1299f7a2b3bd014b7b38aa63c Mon Sep 17 00:00:00 2001 From: epaillas Date: Mon, 9 Feb 2026 12:11:49 -0800 Subject: [PATCH 1/6] allow model predictions in transformed space --- sunbird/emulators/models/fcn.py | 50 ++++++++++++++++++++++++++------- sunbird/inference/base.py | 5 ++++ sunbird/inference/pocomc.py | 4 +-- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/sunbird/emulators/models/fcn.py b/sunbird/emulators/models/fcn.py index 3e7d02c..87605c2 100644 --- a/sunbird/emulators/models/fcn.py +++ b/sunbird/emulators/models/fcn.py @@ -40,8 +40,8 @@ def __init__( std_output: Optional[torch.Tensor] = None, standarize_input: bool = True, standarize_output: bool = True, - transform_input: Optional[callable] = None, - transform_output: Optional[callable] = None, + input_transform: Optional[callable] = None, + output_transform: Optional[callable] = None, coordinates: Optional[dict] = None, compression_matrix: Optional[torch.Tensor] = None, *args, @@ -66,8 +66,8 @@ def __init__( self.register_parameter('std_input', std_input, n_input) self.register_parameter('mean_output', mean_output, n_output) self.register_parameter('std_output', std_output, n_output) - self.transform_input = transform_input - self.transform_output = transform_output + self.input_transform = input_transform + self.output_transform = output_transform self.loss = loss self.data_dim = self.n_output if self.loss == "learned_gaussian": @@ -91,6 +91,24 @@ def __init__( ], ) self.compression_matrix = compression_matrix + + def __setattr__(self, name, value): + """Override to provide backward compatibility for renamed attributes""" + # Map old attribute names to new ones for backward compatibility + if name == 'transform_input': + name = 'input_transform' + elif name == 'transform_output': + name = 'output_transform' + super().__setattr__(name, value) + + def __getattr__(self, name): + """Override to provide backward compatibility for renamed attributes""" + # Map old attribute names to new ones for backward compatibility + if name == 'transform_input': + return self.input_transform + elif name == 'transform_output': + return self.output_transform + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") @staticmethod def add_model_specific_args(parent_parser): @@ -157,7 +175,7 @@ def flax_attributes(self,): 'act_fn': self.act_fn_str, 'n_output': self.n_output, 'predict_errors': True if self.loss == "learned_gaussian" else False, - 'transform_output': self.transform_output, + 'output_transform': self.output_transform, 'coordinates': self.coordinates, 'compression_matrix': None, } @@ -277,17 +295,29 @@ def forward(self, x: Tensor) -> Tensor: y_var = torch.zeros_like(y_pred) return y_pred, y_var - def get_prediction(self, x: Tensor, filters: Optional[dict] = None) -> Tensor: + def get_prediction(self, x: Tensor, filters: Optional[dict] = None, skip_output_inverse_transform: bool = False) -> Tensor: + """Get prediction from the model. + + Args: + x (Tensor): Input tensor + filters (dict, optional): Filters to apply. Defaults to None. + skip_output_inverse_transform (bool, optional): If True, skip the output inverse transformation, + keeping predictions in the transformed space. Useful when performing inference in transformed + space (requires transforming observations and covariance to match). Defaults to False. 
+ + Returns: + Tensor: Model prediction + """ x = torch.Tensor(x) - if self.transform_input: - x = self.transform_input.transform(x) + if self.input_transform is not None: + x = self.input_transform.transform(x) y, _ = self.forward(x) if self.standarize_output: std_output = self.std_output.to(x.device) mean_output = self.mean_output.to(x.device) y = y * std_output + mean_output - if self.transform_output: - y = self.transform_output.inverse_transform(y) + if self.output_transform is not None and not skip_output_inverse_transform: + y = self.output_transform.inverse_transform(y) if self.compression_matrix is not None: y = y @ self.compression_matrix return y diff --git a/sunbird/inference/base.py b/sunbird/inference/base.py index e7844a4..9747966 100644 --- a/sunbird/inference/base.py +++ b/sunbird/inference/base.py @@ -19,6 +19,7 @@ def __init__(self, coordinates: list = [], ellipsoid: bool = False, markers: dict = {}, + sample_in_transformed_space: bool = False, **kwargs, ): self.logger = logging.getLogger(self.__class__.__name__) @@ -33,11 +34,15 @@ def __init__(self, self.precision_matrix = precision_matrix self.ellipsoid = ellipsoid self.markers = markers + self.sample_in_transformed_space = sample_in_transformed_space if self.ellipsoid: self.abacus_ellipsoid = AbacusSummitEllipsoid() self.ndim = len(self.priors.keys()) - len(self.fixed_parameters.keys()) self.logger.info(f'Free parameters: {[key for key in priors.keys() if key not in fixed_parameters.keys()]}') self.logger.info(f'Fixed parameters: {[key for key in priors.keys() if key in fixed_parameters.keys()]}') + if self.sample_in_transformed_space: + self.logger.warning('Sampling in transformed space (skip_output_inverse_transform=True). ' + 'Ensure observations and covariance matrix are also transformed to match!') def save_chain(self, save_fn, metadata=None): """Save the chain to a file diff --git a/sunbird/inference/pocomc.py b/sunbird/inference/pocomc.py index e7e21b7..ec29ea0 100644 --- a/sunbird/inference/pocomc.py +++ b/sunbird/inference/pocomc.py @@ -55,9 +55,7 @@ def get_model_prediction(self, theta): Returns: np.array: model prediction """ - # pred = self.theory_model.get_prediction(x=theta) - pred = self.theory_model(x=theta) - # detach if using torch + pred = self.theory_model(x=theta, skip_output_inverse_transform=self.sample_in_transformed_space) if isinstance(pred, torch.Tensor): pred = pred.detach().numpy() return pred From b8134be4a94fa098b5e3927635260e8cb4c415c9 Mon Sep 17 00:00:00 2001 From: epaillas Date: Mon, 9 Feb 2026 14:21:38 -0800 Subject: [PATCH 2/6] allow transformation of data and covariance for inference --- sunbird/emulators/models/fcn.py | 1 + sunbird/inference/base.py | 34 ++++++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/sunbird/emulators/models/fcn.py b/sunbird/emulators/models/fcn.py index 87605c2..219a2c0 100644 --- a/sunbird/emulators/models/fcn.py +++ b/sunbird/emulators/models/fcn.py @@ -176,6 +176,7 @@ def flax_attributes(self,): 'n_output': self.n_output, 'predict_errors': True if self.loss == "learned_gaussian" else False, 'output_transform': self.output_transform, + # 'transform_output': self.output_transform, 'coordinates': self.coordinates, 'compression_matrix': None, } diff --git a/sunbird/inference/base.py b/sunbird/inference/base.py index 9747966..27a0e34 100644 --- a/sunbird/inference/base.py +++ b/sunbird/inference/base.py @@ -20,6 +20,7 @@ def __init__(self, ellipsoid: bool = False, markers: dict = {}, sample_in_transformed_space: 
bool = False, + observable = None, **kwargs, ): self.logger = logging.getLogger(self.__class__.__name__) @@ -27,22 +28,45 @@ def __init__(self, if fixed_parameters is None: fixed_parameters = {} self.fixed_parameters = fixed_parameters - self.observation = observation self.priors = priors self.ranges = ranges self.labels = labels - self.precision_matrix = precision_matrix self.ellipsoid = ellipsoid self.markers = markers self.sample_in_transformed_space = sample_in_transformed_space + self.observable = observable + + # Handle transformation of observations and covariance + if sample_in_transformed_space: + if observable is None: + self.logger.warning('Sampling in transformed space (skip_output_inverse_transform=True). ' + 'Ensure observations and covariance matrix are also transformed to match!') + self.observation = observation + self.precision_matrix = precision_matrix + else: + # Validate that the observable has an output transform + if not hasattr(observable, 'output_transform') or observable.output_transform is None: + raise ValueError('Cannot sample in transformed space: observable does not have an output_transform. ' + 'Either set sample_in_transformed_space=False or use an observable with output_transform.') + + # Auto-transform the observation and covariance + self.logger.info('Auto-transforming observation and covariance matrix to transformed space.') + self.observation = observable.get_transformed_y() + + # Get transformed covariance and invert it to get precision matrix + import numpy as np + transformed_cov = observable.get_transformed_covariance_matrix() + self.precision_matrix = np.linalg.inv(transformed_cov) + self.logger.info('Successfully transformed observation and precision matrix.') + else: + self.observation = observation + self.precision_matrix = precision_matrix + if self.ellipsoid: self.abacus_ellipsoid = AbacusSummitEllipsoid() self.ndim = len(self.priors.keys()) - len(self.fixed_parameters.keys()) self.logger.info(f'Free parameters: {[key for key in priors.keys() if key not in fixed_parameters.keys()]}') self.logger.info(f'Fixed parameters: {[key for key in priors.keys() if key in fixed_parameters.keys()]}') - if self.sample_in_transformed_space: - self.logger.warning('Sampling in transformed space (skip_output_inverse_transform=True). ' - 'Ensure observations and covariance matrix are also transformed to match!') def save_chain(self, save_fn, metadata=None): """Save the chain to a file From 21689b57cb38796498c3b06efff3deac2dfb2f53 Mon Sep 17 00:00:00 2001 From: epaillas Date: Mon, 9 Feb 2026 14:40:37 -0800 Subject: [PATCH 3/6] add jacobian to data transforms --- sunbird/data/transforms_array.py | 105 +++++++++++++++++++++++++++ sunbird/emulators/models/fcn_flax.py | 6 +- 2 files changed, 108 insertions(+), 3 deletions(-) diff --git a/sunbird/data/transforms_array.py b/sunbird/data/transforms_array.py index 1b09fed..eaf91db 100644 --- a/sunbird/data/transforms_array.py +++ b/sunbird/data/transforms_array.py @@ -12,6 +12,27 @@ def transform(self, x): @abstractmethod def inverse_transform(self, x): pass + + @abstractmethod + def get_jacobian_diagonal(self, y): + """ + Get the diagonal of the Jacobian matrix df/dy for transforming covariance matrices. + + For an element-wise transformation f(y), the transformed covariance is: + Cov_transformed = diag(J) @ Cov @ diag(J) + where J = df/dy is the Jacobian diagonal. + + Parameters + ---------- + y : array_like + Data vector in the original (untransformed) space. 
+ + Returns + ------- + array_like + Diagonal of the Jacobian matrix, same shape as y. + """ + pass class LogTransform(BaseTransform): @@ -23,6 +44,27 @@ def transform(self, x): def inverse_transform(self, x): return 10**x + + def get_jacobian_diagonal(self, y): + """ + Get Jacobian diagonal for log10 transform: d(log10(y))/dy = 1/(y * ln(10)) + + Parameters + ---------- + y : array_like + Data vector in the original (untransformed) space. + + Returns + ------- + array_like + Jacobian diagonal: 1/(y * ln(10)) + """ + if type(y) == torch.Tensor: + return 1.0 / (y * torch.log(torch.tensor(10.0))) + elif type(y) == np.ndarray: + return 1.0 / (y * np.log(10.0)) + else: + return 1.0 / (y * jnp.log(10.0)) class ArcsinhTransform(BaseTransform): @@ -41,6 +83,27 @@ def inverse_transform(self, x): return np.sinh(x) else: return jnp.sinh(x) + + def get_jacobian_diagonal(self, y): + """ + Get Jacobian diagonal for arcsinh transform: d(arcsinh(y))/dy = 1/sqrt(1 + y^2) + + Parameters + ---------- + y : array_like + Data vector in the original (untransformed) space. + + Returns + ------- + array_like + Jacobian diagonal: 1/sqrt(1 + y^2) + """ + if type(y) == torch.Tensor: + return 1.0 / torch.sqrt(1.0 + y**2) + elif type(y) == np.ndarray: + return 1.0 / np.sqrt(1.0 + y**2) + else: + return 1.0 / jnp.sqrt(1.0 + y**2) class WeiLiuOutputTransForm(BaseTransform): """Class to reconcile output the Minkowski functionals model @@ -57,6 +120,27 @@ def transform(self, x): def inverse_transform(self, x): return x * self.std + self.mean + + def get_jacobian_diagonal(self, y): + """ + Get Jacobian diagonal for affine transform: d(y * std + mean)/dy = std + + Parameters + ---------- + y : array_like + Data vector in the original (untransformed) space. + + Returns + ------- + array_like + Jacobian diagonal: std (broadcast to match y shape) + """ + if type(y) == torch.Tensor: + return torch.ones_like(y) * self.std + elif type(y) == np.ndarray: + return np.ones_like(y) * self.std.numpy() + else: + return jnp.ones_like(y) * self.std.numpy() class WeiLiuInputTransform(BaseTransform): """Class to reconcile input of the Minkowski functionals model @@ -73,4 +157,25 @@ def transform(self, x): def inverse_transform(self, x): return x + + def get_jacobian_diagonal(self, y): + """ + Get Jacobian diagonal for standardization: d((y - mean) / std)/dy = 1/std + + Parameters + ---------- + y : array_like + Data vector in the original (untransformed) space. 
+
+        Returns
+        -------
+        array_like
+            Jacobian diagonal: 1/std (broadcast to match y shape)
+        """
+        if type(y) == torch.Tensor:
+            return torch.ones_like(y) / self.std
+        elif type(y) == np.ndarray:
+            return np.ones_like(y) / self.std.numpy()
+        else:
+            return jnp.ones_like(y) / self.std.numpy()
\ No newline at end of file
diff --git a/sunbird/emulators/models/fcn_flax.py b/sunbird/emulators/models/fcn_flax.py
index ad0fbb6..ba3a185 100644
--- a/sunbird/emulators/models/fcn_flax.py
+++ b/sunbird/emulators/models/fcn_flax.py
@@ -19,7 +19,7 @@ class FlaxFCN(nn.Module):
     act_fn: str
     n_output: int
     predict_errors: False
-    transform_output: None
+    output_transform: None
     coordinates: None
     compression_matrix: None
 
@@ -97,8 +97,8 @@ def __call__(self, x: jnp.array, filters=None) -> jnp.array:
         else:
             y_var = jnp.zeros_like(y_pred)
         y_pred = y_pred * std_output + mean_output
-        if self.transform_output is not None:
-            y_pred = self.transform_output.inverse_transform(y_pred)
+        if self.output_transform is not None:
+            y_pred = self.output_transform.inverse_transform(y_pred)
         if filters is not None:
             y_pred = y_pred[~filters.reshape(-1)]
         if self.compression_matrix is not None:

From dbead36038681fcb7425fd8efb2d7187dc07d6a1 Mon Sep 17 00:00:00 2001
From: epaillas
Date: Mon, 9 Feb 2026 14:42:45 -0800
Subject: [PATCH 4/6] name change to output_transform

---
 sunbird/emulators/models/transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sunbird/emulators/models/transformer.py b/sunbird/emulators/models/transformer.py
index a583c9f..7d8353d 100644
--- a/sunbird/emulators/models/transformer.py
+++ b/sunbird/emulators/models/transformer.py
@@ -108,7 +108,7 @@ def flax_attributes(self,):
         'act_fn': self.act_fn_str,
         'n_output': self.n_output,
         'predict_errors': True if self.loss == "learned_gaussian" else False,
-        'transform_output': self.transform_output,
+        'output_transform': self.output_transform,
         'coordinates': self.coordinates,
     }
 

From 7653befb289d9fd709bd5a872ea1dd35bb98a766 Mon Sep 17 00:00:00 2001
From: epaillas
Date: Mon, 9 Feb 2026 14:43:34 -0800
Subject: [PATCH 5/6] theta_star to theta_MC_100

---
 sunbird/cosmology/growth_rate.py | 32 ++++++++++++++++++--------------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/sunbird/cosmology/growth_rate.py b/sunbird/cosmology/growth_rate.py
index 63ad44c..0956cc0 100644
--- a/sunbird/cosmology/growth_rate.py
+++ b/sunbird/cosmology/growth_rate.py
@@ -21,11 +21,11 @@ class Growth:
 
     def __init__(
         self,
-        theta_star: float = AbacusSummit(0).theta_star,
+        theta_MC_100: float = AbacusSummit(0)['theta_MC_100'],
         emulate=False,
         emulator_data_dir=DEFAULT_PATH / "data/hemu/",
     ):
-        self.theta_star = theta_star
+        self.theta_MC_100 = theta_MC_100
         self.emulate = emulate
         self.emulator_data_dir = emulator_data_dir
         if self.emulate:
@@ -63,10 +63,13 @@ def generate_emulator_training_data(
         h_values, sample_parameters = [], []
         for i, sample in enumerate(samples_matrix):
             try:
+                # print every 1000th sample
+                if i % 1000 == 0:
+                    print(i)
-                cosmology = self.get_cosmology_fixed_theta_star(
+                cosmology = self.get_cosmology_fixed_theta_MC_100(
                     DESI(engine="class"),
                     dict(
-                        theta_star=self.theta_star,
+                        theta_MC_100=self.theta_MC_100,
                         omega_b=sample[0],
                         omega_cdm=sample[1],
                         sigma8=sample[2],
@@ -278,29 +281,30 @@ def get_emulated_h(self, omega_b, omega_cdm, sigma8, N_ur, n_s, w0_fld, wa_fld):
         x = jnp.vstack([omega_b, omega_cdm, sigma8, N_ur, n_s, w0_fld, wa_fld]).T
         return self.model.apply(self.params, x)
 
-    def get_cosmology_fixed_theta_star(
+    def 
get_cosmology_fixed_theta_MC_100( self, fiducial, params, h_limits=[0.4, 1.0], xtol=1.0e-6, ): - theta = params.pop("theta_star", None) + theta = params.pop("theta_MC_100", None) fiducial = fiducial.clone(base="input", **params) if theta is not None: if "h" in params: - raise ValueError("Cannot provide both theta_star and h") + raise ValueError("Cannot provide both theta_MC_100 and h") def f(h): cosmo = fiducial.clone(base="input", h=h) - return 100.0 * (theta - cosmo.get_thermodynamics().theta_star) + # return 100.0 * (theta - cosmo.get_thermodynamics().theta_MC_100) + return 100.0 * (theta - cosmo['theta_MC_100']) rtol = xtol try: h = optimize.bisect(f, *h_limits, xtol=xtol, rtol=rtol, disp=True) except ValueError as exc: raise ValueError( - "Could not find proper h value in the interval that matches theta_star = {:.4f} with [f({:.3f}), f({:.3f})] = [{:.4f}, {:.4f}]".format( + "Could not find proper h value in the interval that matches theta_MC_100 = {:.4f} with [f({:.3f}), f({:.3f})] = [{:.4f}, {:.4f}]".format( theta, *h_limits, *list(map(f, h_limits)) ) ) from exc @@ -339,10 +343,10 @@ def get_growth( z=z, ) else: - cosmology = self.get_cosmology_fixed_theta_star( + cosmology = self.get_cosmology_fixed_theta_MC_100( DESI(engine="class"), dict( - theta_star=self.theta_star, + theta_MC_100=self.theta_MC_100, omega_b=omega_b, omega_cdm=omega_cdm, sigma8=sigma8, @@ -397,10 +401,10 @@ def get_fsigma8( ) return growth_rate * sigma8_z else: - cosmology = self.get_cosmology_fixed_theta_star( + cosmology = self.get_cosmology_fixed_theta_MC_100( DESI(engine="class"), dict( - theta_star=self.theta_star, + theta_MC_100=self.theta_MC_100, omega_b=omega_b, omega_cdm=omega_cdm, sigma8=sigma8, @@ -418,6 +422,6 @@ def get_fsigma8( t0 = time.time() growth = Growth() - # growth.generate_emulator_training_data() + growth.generate_emulator_training_data(n_samples=100_000) growth.train_emulator() print(f"It took {time.time() - t0} seconds") From 90765c6c2e97d557f3b14b384fdcbd9a4d938c1a Mon Sep 17 00:00:00 2001 From: epaillas Date: Tue, 10 Feb 2026 13:40:04 -0800 Subject: [PATCH 6/6] sample in transformed space for combined statistics --- sunbird/emulators/models/fcn.py | 37 +++-------- sunbird/emulators/models/fcn_flax.py | 6 +- sunbird/emulators/models/transformer.py | 2 +- sunbird/inference/base.py | 86 ++++++++++++++++--------- sunbird/inference/pocomc.py | 76 +++++++++++----------- 5 files changed, 109 insertions(+), 98 deletions(-) diff --git a/sunbird/emulators/models/fcn.py b/sunbird/emulators/models/fcn.py index b4ea5d8..055a6a9 100644 --- a/sunbird/emulators/models/fcn.py +++ b/sunbird/emulators/models/fcn.py @@ -38,8 +38,8 @@ def __init__( std_output: Optional[torch.Tensor] = None, standarize_input: bool = True, standarize_output: bool = True, - input_transform: Optional[callable] = None, - output_transform: Optional[callable] = None, + transform_input: Optional[callable] = None, + transform_output: Optional[callable] = None, coordinates: Optional[dict] = None, compression_matrix: Optional[torch.Tensor] = None, *args, @@ -64,8 +64,8 @@ def __init__( self.register_parameter('std_input', std_input, n_input) self.register_parameter('mean_output', mean_output, n_output) self.register_parameter('std_output', std_output, n_output) - self.input_transform = input_transform - self.output_transform = output_transform + self.transform_input = transform_input + self.transform_output = transform_output self.loss = loss self.data_dim = self.n_output if self.loss == "learned_gaussian": @@ -90,24 +90,6 @@ def 
__init__( ) self.compression_matrix = compression_matrix - def __setattr__(self, name, value): - """Override to provide backward compatibility for renamed attributes""" - # Map old attribute names to new ones for backward compatibility - if name == 'transform_input': - name = 'input_transform' - elif name == 'transform_output': - name = 'output_transform' - super().__setattr__(name, value) - - def __getattr__(self, name): - """Override to provide backward compatibility for renamed attributes""" - # Map old attribute names to new ones for backward compatibility - if name == 'transform_input': - return self.input_transform - elif name == 'transform_output': - return self.output_transform - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - @staticmethod def add_model_specific_args(parent_parser): """Model arguments that could vary @@ -173,8 +155,7 @@ def flax_attributes(self,): 'act_fn': self.act_fn_str, 'n_output': self.n_output, 'predict_errors': True if self.loss == "learned_gaussian" else False, - 'output_transform': self.output_transform, - # 'transform_output': self.output_transform, + 'transform_output': self.transform_output, 'coordinates': self.coordinates, 'compression_matrix': None, } @@ -308,15 +289,15 @@ def get_prediction(self, x: Tensor, filters: Optional[dict] = None, skip_output_ Tensor: Model prediction """ x = torch.Tensor(x) - if self.input_transform is not None: - x = self.input_transform.transform(x) + if self.transform_input is not None: + x = self.transform_input.transform(x) y, _ = self.forward(x) if self.standarize_output: std_output = self.std_output.to(x.device) mean_output = self.mean_output.to(x.device) y = y * std_output + mean_output - if self.output_transform is not None and not skip_output_inverse_transform: - y = self.output_transform.inverse_transform(y) + if self.transform_output is not None and not skip_output_inverse_transform: + y = self.transform_output.inverse_transform(y) if self.compression_matrix is not None: y = y @ self.compression_matrix return y diff --git a/sunbird/emulators/models/fcn_flax.py b/sunbird/emulators/models/fcn_flax.py index 763f9ea..8b83a7e 100644 --- a/sunbird/emulators/models/fcn_flax.py +++ b/sunbird/emulators/models/fcn_flax.py @@ -17,7 +17,7 @@ class FlaxFCN(nn.Module): act_fn: str n_output: int predict_errors: False - output_transform: None + transform_output: None coordinates: None compression_matrix: None @@ -95,8 +95,8 @@ def __call__(self, x: jnp.array, filters=None) -> jnp.array: else: y_var = jnp.zeros_like(y_pred) y_pred = y_pred * std_output + mean_output - if self.output_transform is not None: - y_pred = self.output_transform.inverse_transform(y_pred) + if self.transform_output is not None: + y_pred = self.transform_output.inverse_transform(y_pred) if filters is not None: y_pred = y_pred[~filters.reshape(-1)] if self.compression_matrix is not None: diff --git a/sunbird/emulators/models/transformer.py b/sunbird/emulators/models/transformer.py index 1177b4e..543e6eb 100644 --- a/sunbird/emulators/models/transformer.py +++ b/sunbird/emulators/models/transformer.py @@ -106,7 +106,7 @@ def flax_attributes(self,): 'act_fn': self.act_fn_str, 'n_output': self.n_output, 'predict_errors': True if self.loss == "learned_gaussian" else False, - 'output_transform': self.output_transform, + 'transform_output': self.transform_output, 'coordinates': self.coordinates, } diff --git a/sunbird/inference/base.py b/sunbird/inference/base.py index 6c49758..3271453 100644 --- a/sunbird/inference/base.py 
+++ b/sunbird/inference/base.py @@ -1,3 +1,5 @@ +"""Base classes and utilities for inference samplers.""" + import logging import numpy as np from tabulate import tabulate @@ -5,7 +7,14 @@ from sunbird.inference.priors import AbacusSummitEllipsoid class BaseSampler: - def __init__(self, + """Base class for inference samplers. + + Handles parameter bookkeeping, optional transformed-space sampling, and + convenience utilities for saving chains and summary tables. + """ + + def __init__( + self, observation, precision_matrix, theory_model, @@ -13,15 +22,26 @@ def __init__(self, ranges: Optional[Dict[str, tuple]] = {}, labels: Dict[str, str] = {}, fixed_parameters: Dict[str, float] = {}, - slice_filters: Dict = {}, - select_filters: Dict = {}, - coordinates: list = [], ellipsoid: bool = False, markers: dict = {}, sample_in_transformed_space: bool = False, - observable = None, **kwargs, ): + """Initialize the sampler base. + + Args: + observation: Observed data vector. + precision_matrix: Inverse covariance matrix. + theory_model: Callable model that maps parameters to predictions. + priors: Mapping of parameter names to prior objects. + ranges: Optional plotting or reporting ranges by parameter. + labels: Optional labels by parameter. + fixed_parameters: Mapping of parameter names to fixed values. + ellipsoid: Whether to include the AbacusSummit ellipsoid prior. + markers: Optional marker styling for plots. + sample_in_transformed_space: If True, use transformed outputs. + **kwargs: Extra arguments for subclasses. + """ self.logger = logging.getLogger(self.__class__.__name__) self.theory_model = theory_model if fixed_parameters is None: @@ -33,42 +53,45 @@ def __init__(self, self.ellipsoid = ellipsoid self.markers = markers self.sample_in_transformed_space = sample_in_transformed_space - self.observable = observable # Handle transformation of observations and covariance if sample_in_transformed_space: - if observable is None: - self.logger.warning('Sampling in transformed space (skip_output_inverse_transform=True). ' - 'Ensure observations and covariance matrix are also transformed to match!') - self.observation = observation - self.precision_matrix = precision_matrix - else: - # Validate that the observable has an output transform - if not hasattr(observable, 'output_transform') or observable.output_transform is None: - raise ValueError('Cannot sample in transformed space: observable does not have an output_transform. ' - 'Either set sample_in_transformed_space=False or use an observable with output_transform.') - - # Auto-transform the observation and covariance - self.logger.info('Auto-transforming observation and covariance matrix to transformed space.') - self.observation = observable.get_transformed_y() + # Validate that the observable has an output transform + if not hasattr(theory_model.__self__.model, 'transform_output'): + raise ValueError('Cannot sample in transformed space: observable does not have a transform_output. ' + 'Either set sample_in_transformed_space=False or use an observable with transform_output.') + + # Check if transform_output is valid (not None or empty list) + transform = theory_model.__self__.model.transform_output + if transform is None: + raise ValueError('Cannot sample in transformed space: transform_output is None. 
' + 'Either set sample_in_transformed_space=False or use an observable with transform_output.') + + # For combined observables, transform_output is a list + if isinstance(transform, list): + if all(t is None for t in transform): + raise ValueError('Cannot sample in transformed space: all transforms in combined observable are None. ' + 'Either set sample_in_transformed_space=False or use observables with transform_output.') + + self.logger.warning('Sampling in transformed space (skip_output_inverse_transform=True). ' + 'Ensure observations and covariance matrix are also transformed to match!') - # Get transformed covariance and invert it to get precision matrix - import numpy as np - transformed_cov = observable.get_transformed_covariance_matrix() - self.precision_matrix = np.linalg.inv(transformed_cov) - self.logger.info('Successfully transformed observation and precision matrix.') - else: - self.observation = observation - self.precision_matrix = precision_matrix + self.observation = observation + self.precision_matrix = precision_matrix if self.ellipsoid: self.abacus_ellipsoid = AbacusSummitEllipsoid() + self.ndim = len(self.priors.keys()) - len(self.fixed_parameters.keys()) self.logger.info(f'Free parameters: {[key for key in priors.keys() if key not in fixed_parameters.keys()]}') self.logger.info(f'Fixed parameters: {[key for key in priors.keys() if key in fixed_parameters.keys()]}') def save_chain(self, save_fn, metadata=None): - """Save the chain to a file + """Save a chain dictionary to a NumPy file. + + Args: + save_fn: Output filename for the NumPy archive. + metadata: Optional extra metadata to include. """ data = self.get_chain(flat=True) names = [param for param in self.priors.keys() if param not in self.fixed_parameters] @@ -95,6 +118,11 @@ def save_chain(self, save_fn, metadata=None): np.save(save_fn, cout) def save_table(self, save_fn): + """Write a summary table with MAP/mean/std values. + + Args: + save_fn: Output filename for the text table. + """ chain = self.get_chain(flat=True) maxp = chain['samples'][chain['log_posterior'].argmax()] mean = chain['samples'].mean(axis=0) diff --git a/sunbird/inference/pocomc.py b/sunbird/inference/pocomc.py index 4ba8b58..e75862a 100644 --- a/sunbird/inference/pocomc.py +++ b/sunbird/inference/pocomc.py @@ -1,3 +1,5 @@ +"""Samplers based on the `pocomc` inference engine.""" + import torch import pocomc import numpy as np @@ -5,17 +7,20 @@ class PocoMCSampler(BaseSampler): + """PoCoMC sampler wrapper with optional ellipsoid prior support.""" + def __init__(self, **kwargs): + """Initialize the PoCoMC sampler wrapper.""" super().__init__(**kwargs) def fill_params(self, theta): - """Fill the parameter vector to include fixed parameters + """Fill a parameter vector to include fixed parameters. Args: - theta (np.array): input parameters + theta: Free parameter vector. Returns: - np.array: filled parameters + Filled parameter vector with fixed values inserted. """ params = np.ones(len(self.priors.keys())) itheta = 0 @@ -28,13 +33,13 @@ def fill_params(self, theta): return params def fill_params_batch(self, thetas): - """Fill the batch of parameter vectors to include fixed parameters + """Fill a batch of parameter vectors to include fixed parameters. Args: - thetas (np.array): input parameters + thetas: Batch of free parameter vectors. Returns: - np.array: filled parameters + Filled parameter array with fixed values inserted. 
""" params = np.ones((len(thetas), len(self.priors.keys()))) for i, theta in enumerate(thetas): @@ -42,13 +47,13 @@ def fill_params_batch(self, thetas): return params def get_model_prediction(self, theta): - """Get model prediction + """Return the model prediction for the given parameters. Args: - theta (np.array): input parameters + theta: Parameter vector or batch of vectors. Returns: - np.array: model prediction + Model prediction as a NumPy array. """ pred = self.theory_model(x=theta, skip_output_inverse_transform=self.sample_in_transformed_space) if isinstance(pred, torch.Tensor): @@ -56,13 +61,13 @@ def get_model_prediction(self, theta): return pred def log_likelihood(self, theta): - """Log likelihood function + """Compute the log likelihood for a parameter vector or batch. Args: - theta (np.array): input parameters + theta: Free parameter vector or batch of vectors. Returns: - float: log likelihood + Log likelihood value(s). """ batch = len(theta.shape) > 1 params = self.fill_params_batch(theta) if batch else self.fill_params(theta) @@ -78,14 +83,24 @@ def log_likelihood(self, theta): logl += self.abacus_ellipsoid.log_likelihood(params[:8]) return logl - def __call__(self, vectorize=True, random_state=0, precondition=True, n_total=4096, progress=True, **kwargs): - """Run the sampler + def __call__( + self, + vectorize=True, + random_state=0, + precondition=True, + n_total=4096, + progress=True, + **kwargs, + ): + """Run the PoCoMC sampler. Args: - vectorize (bool, optional): Vectorize the log likelihood call. Defaults to False. - random_state (int, optional): Random seed. Defaults to 0. - precondition (bool, optional): If False, use standard MCMC without normalizing flow. Defaults to True. - kwargs: Additional arguments for the sampler + vectorize: Vectorize the log likelihood call. + random_state: Random seed for the sampler. + precondition: If False, disable normalizing flow preconditioning. + n_total: Total number of samples to draw. + progress: Whether to display progress output. + **kwargs: Additional arguments for `pocomc.Sampler`. 
""" prior = pocomc.Prior([value for key, value in self.priors.items() if key not in self.fixed_parameters.keys()]) @@ -101,27 +116,20 @@ def __call__(self, vectorize=True, random_state=0, precondition=True, n_total=40 self.sampler.run(progress=progress, n_total=n_total) def get_chain(self, **kwargs): - """Get the chain from the sampler - - Returns: - np.array: chain - """ + """Return the posterior samples and derived quantities.""" samples, weights, loglike, logprior = self.sampler.posterior() logz, logz_err = self.sampler.evidence() logposterior = loglike + logprior - logz return {'samples': samples, 'weights': weights, 'log_likelihood': loglike, 'log_prior': logprior, 'log_posterior': logposterior} - def evidence(self,): - """Get the evidence from the sampler - - Returns: - tuple: logz, logz_err - """ + def evidence(self): + """Return the evidence estimate and its error.""" return self.sampler.evidence() class PocoMCPriorSampler(PocoMCSampler): + """PoCoMC sampler that returns a flat likelihood over the prior.""" def __init__( self, observation=None, @@ -129,17 +137,11 @@ def __init__( theory_model=None, **kwargs, ): + """Initialize a prior-only sampler.""" super().__init__(observation, precision_matrix, theory_model, **kwargs) def log_likelihood(self, theta): - """Log likelihood function - - Args: - theta (np.array): input parameters - - Returns: - float: log likelihood - """ + """Return a flat log likelihood with optional ellipsoid term.""" batch = len(theta.shape) > 1 params = self.fill_params_batch(theta) if batch else self.fill_params(theta) if batch: