
Adding option to sample posterior in transformed (model) space #160

Open

epaillas wants to merge 4 commits into cosmodesi from inference

Conversation

@epaillas (Owner) commented Feb 10, 2026

When the emulator is trained on a transformed representation of a summary statistic (e.g. log P(k)), but posterior inference is performed in the original, untransformed space, the implied noise model becomes inconsistent: the training loss encourages approximately Gaussian residuals in the transformed space, while the likelihood in linear space typically assumes additive Gaussian errors. This mismatch can lead to subtle biases (e.g. from exponentiation and Jensen’s inequality) and to an incorrect treatment of emulator uncertainty, which is intrinsically multiplicative in linear space.
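For intuition, here is a toy numpy illustration of the exponentiation bias (all numbers are made up for illustration, not taken from this codebase):

```python
import numpy as np

rng = np.random.default_rng(42)

# Pretend the emulator scatter is Gaussian in log10 P(k), as the training loss encourages
log_pk_true = 3.0   # log10 P(k) at some wavenumber (made-up value)
sigma_log = 0.05    # emulator scatter in log10 space (made-up value)
log_pk_draws = rng.normal(log_pk_true, sigma_log, size=100_000)

# Exponentiating back to linear space skews the error distribution:
# E[10^X] = 10^mu * exp((sigma * ln 10)^2 / 2) > 10^mu  (Jensen's inequality)
pk_true = 10.0 ** log_pk_true
pk_draws = 10.0 ** log_pk_draws
print(pk_draws.mean() / pk_true)         # ~1.007: a systematic offset
print(pk_draws.std() / pk_draws.mean())  # the scatter is multiplicative in linear space
```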

While the above effects are expected to be small if the emulator uncertainty is subdominant with respect to the data uncertainty, we would like to be able to test this at the parameter inference level. This PR addresses the issue by allowing posterior sampling directly in the transformed space, ensuring that the likelihood, covariance, and emulator error model are defined consistently with the representation learned by the network.

@SBouchard01 (Collaborator) left a comment

While the point of this PR is very interesting, overall I feel like the transformation logic should be handled at the model level (in sunbird), not in the Observable classes. I think this will overcomplicate the base code for the observables.

The transformed combined functions should also be defined in child classes of BaseObservable & CombinedObservable, not in the main ones, since they are for specific use cases.

The main interesting point is the implementation of a transform function for the data, but we should define a child class of the Observable class that could handle this by overloading the required functions (and try to put that directly in __getattr__ instead of all the call functions, if the transform needs to be applied at the output of the dataset elements?). Adding a transform element to the output of __getattr__ can be interesting, but the current implementation has a lot of moving pieces that repeat themselves (see the sketch below).
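A minimal sketch of that child-class idea (the class and attribute names here are illustrative stand-ins, not the actual sunbird API):

```python
import numpy as np

class ObservableStub:
    """Stand-in for the real Observable base class; just enough for the sketch."""
    def __init__(self, y):
        self._y = np.asarray(y)

    @property
    def y(self):
        return self._y

class TransformedObservable(ObservableStub):
    """Child class that applies the model's output transform on access,
    instead of threading transform flags through every get_* method."""
    def __init__(self, y, transform=np.log10):
        super().__init__(y)
        self._transform = transform

    @property
    def y(self):
        return self._transform(super().y)

obs = TransformedObservable([1.0, 10.0, 100.0])
print(obs.y)  # [0. 1. 2.] -- downstream code sees transformed data transparently
```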


```diff
 with torch.no_grad():
-    pred = model.get_prediction(torch.Tensor(x))
+    pred = model.get_prediction(torch.Tensor(x), skip_output_inverse_transform=skip_output_inverse_transform)
```
Collaborator:

Will this be compatible with previous models?

Owner Author:

Yes, no modifications of the trained models are needed at this point.
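Presumably the new keyword just defaults to the old behaviour, along these lines (stub class, illustrative only):

```python
import numpy as np

class EmulatorStub:
    """Illustrative stand-in for the trained model; not the real class."""
    def _forward(self, x):
        return np.log10(np.asarray(x))  # pretend the network predicts in log10 space

    def get_prediction(self, x, skip_output_inverse_transform=False):
        pred = self._forward(x)
        if skip_output_inverse_transform:
            return pred          # new path: stay in the transformed space
        return 10.0 ** pred      # default False reproduces the old behaviour

model = EmulatorStub()
print(model.get_prediction([1.0, 10.0]))  # old call sites are unchanged
print(model.get_prediction([1.0, 10.0], skip_output_inverse_transform=True))
```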

Comment on lines 248 to 259
```python
if transform_output:
    # Transform each observable's samples, then concatenate
    cov_y_list = []
    for observable in self.observables:
        # Get samples and transform them
        observable._validate_output_transform()
        cov_y_obs = observable.get_covariance_y(nofilters=False)
        cov_y_transformed = observable.apply_output_transform(cov_y_obs)
        cov_y_list.append(cov_y_transformed)
    cov_y = np.concatenate(cov_y_list, axis=-1)
else:
    cov_y = self.covariance_y
```
Collaborator:

This somewhat defeats the goal of the class, which is to handle the observables separately and concatenate the outputs. Here the transform_output argument only works if all models have a transform_output property defined.

Owner Author:

I added a patch for this in the latest commit, which skips the transform for statistics that have no transform defined. See if that works better.
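Presumably the patch amounts to a guard of this kind (the attribute check is a guess at the implementation, not the actual commit):

```python
cov_y_list = []
for observable in self.observables:
    # Hypothetical guard: pass samples through untouched when the
    # observable has no output transform defined
    if getattr(observable, 'output_transform', None) is None:
        cov_y_list.append(observable.get_covariance_y(nofilters=False))
        continue
    observable._validate_output_transform()
    cov_y_obs = observable.get_covariance_y(nofilters=False)
    cov_y_list.append(observable.apply_output_transform(cov_y_obs))
cov_y = np.concatenate(cov_y_list, axis=-1)
```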

Comment on lines 314 to 417
```python
def get_transformed_y(self) -> np.ndarray:
    """
    Get the transformed combined observational data.

    This method applies the output transform to each observable's data
    and concatenates them into a single vector.

    Returns
    -------
    np.ndarray
        The concatenated transformed observational data vector.
    """
    transformed_y = []
    for observable in self.observables:
        y_transformed = observable.get_transformed_y()
        transformed_y.append(y_transformed)

    return np.concatenate(transformed_y)

def get_transformed_covariance_matrix(self, **kwargs) -> np.ndarray:
    """
    Get the transformed combined covariance matrix.

    This method collects each observable's transformed covariance matrix
    (computed from transformed samples) and combines them into a
    block-diagonal matrix.

    Parameters
    ----------
    **kwargs : dict
        Additional keyword arguments passed to check_covariance_matrix.

    Returns
    -------
    np.ndarray
        The block-diagonal transformed covariance matrix.
    """
    covs = []
    for observable in self.observables:
        cov_transformed = observable.get_transformed_covariance_matrix(**kwargs)
        covs.append(cov_transformed)

    cov = linalg.block_diag(*covs)

    # Perform sanity checks on the covariance matrix
    check_covariance_matrix(cov, name="combined transformed covariance", **kwargs)

    return cov

def get_transformed_emulator_error(self) -> np.ndarray:
    """
    Get the transformed combined emulator error.

    This method applies the output transform to each observable's emulator
    error and concatenates them into a single vector.

    Returns
    -------
    np.ndarray
        The concatenated transformed emulator error vector.
    """
    transformed_errors = []
    for observable in self.observables:
        error_transformed = observable.get_transformed_emulator_error()
        transformed_errors.append(error_transformed)

    return np.concatenate(transformed_errors)

def get_transformed_emulator_covariance_matrix(
    self,
    prefactor: float = 1.0,
    method: str = 'bootstrap',
    diag: bool = False,
    **kwargs
) -> np.ndarray:
    """
    Get the transformed combined emulator covariance matrix.

    This method computes the emulator covariance from residuals transformed
    to the model's output space (via transform_output=True).

    Parameters
    ----------
    prefactor : float, optional
        Prefactor to multiply the emulator error by. Default is 1.0.
    method : str, optional
        Method to estimate the emulator error from. Default is 'bootstrap'.
    diag : bool, optional
        If True, only the diagonal elements are returned. Default is False.
    **kwargs : dict
        Additional keyword arguments passed to check_covariance_matrix.

    Returns
    -------
    np.ndarray
        The transformed combined emulator covariance matrix.
    """
    return self.get_emulator_covariance_matrix(
        prefactor=prefactor,
        method=method,
        diag=diag,
        transform_output=True,
        **kwargs
    )
```

Collaborator:

Should these be defined in the CombinedObservable class?

Owner Author:

Hmm sorry, what do you mean? I think that's where they are at the moment.

Comment on lines +437 to +565
```python
def get_transformed_y(self) -> xarray.DataArray | np.ndarray:
    """
    Get the data vector transformed to the model's output space.

    For example, if the model was trained with a log10 transform, this returns log10(y).
    Useful for performing inference in the transformed space.

    Returns
    -------
    xarray.DataArray or np.ndarray
        Transformed data vector, with the same shape and filters as self.y.

    Raises
    ------
    ValueError
        If no output transform is available on the model.
    """
    self._validate_output_transform()

    # Get unfiltered y to apply transform
    y_unfiltered = self._dataset.y

    # Apply the transform
    y_transformed = self.apply_output_transform(y_unfiltered)

    # Now apply filters and formatting
    y_transformed = self.apply_filters(y_transformed)
    y_transformed = self.flatten_output(y_transformed, self.flat_output_dims)
    y_transformed = self.apply_indices_selection(y_transformed)
    if self.squeeze_output:
        y_transformed = y_transformed.squeeze()
    if self.numpy_output:
        y_transformed = y_transformed.values

    return y_transformed

@temporary_class_state(numpy_output=False)
def get_transformed_covariance_matrix(self, volume_factor: float = 64, prefactor: float = 1, **kwargs) -> np.ndarray:
    """
    Get the covariance matrix transformed to the model's output space.

    This method transforms individual samples first, then computes the covariance
    from the transformed samples. This is more accurate than using the Jacobian
    approximation.

    Parameters
    ----------
    volume_factor : float
        Volume correction factor for the boxes. Default is 64.
    prefactor : float
        Prefactor to apply to the covariance matrix (e.g. Hartlap or Percival).
    **kwargs
        Additional arguments for the covariance matrix checker.

    Returns
    -------
    np.ndarray
        Transformed covariance matrix, shape (n_features, n_features).

    Raises
    ------
    ValueError
        If no output transform is available on the model.
    """
    return self.get_covariance_matrix(
        volume_factor=volume_factor,
        prefactor=prefactor,
        transform_output=True,
        **kwargs
    )

def get_transformed_emulator_error(self) -> xarray.DataArray | np.ndarray:
    """
    Get the emulator error transformed to the model's output space.

    The emulator error is computed in the transformed space by:
    1. Getting model predictions in transformed space (skip_output_inverse_transform=True)
    2. Transforming the test data to the same space
    3. Computing the median absolute difference

    Returns
    -------
    xarray.DataArray or np.ndarray
        Transformed emulator error, with shape (n_features,).

    Raises
    ------
    ValueError
        If no output transform is available on the model.
    """
    return self.get_emulator_error(transform_output=True)

@temporary_class_state(numpy_output=False)
def get_transformed_emulator_covariance_matrix(self, prefactor: float = 1, method: str = 'median', diag: bool = False, **kwargs) -> np.ndarray:
    """
    Get the emulator covariance matrix transformed to the model's output space.

    This method transforms individual emulator residuals first, then computes the
    covariance from the transformed residuals. This is more accurate than using
    the Jacobian approximation.

    Parameters
    ----------
    prefactor : float
        Prefactor to apply to the covariance matrix (e.g. Hartlap or Percival). Defaults to 1.
    method : str
        Method to compute the covariance matrix from the emulator residuals.
        Options include 'median', 'mean', or 'stdev'. Defaults to 'median'.
    diag : bool
        If True, only the diagonal of the covariance matrix is computed. Defaults to False.
    **kwargs
        Additional arguments for the covariance matrix checker.

    Returns
    -------
    np.ndarray
        Transformed emulator covariance matrix, shape (n_features, n_features).

    Raises
    ------
    ValueError
        If no output transform is available on the model.
    """
    return self.get_emulator_covariance_matrix(
        prefactor=prefactor,
        method=method,
        diag=diag,
        transform_output=True,
        **kwargs
    )
```
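On the "more accurate than using the Jacobian approximation" remark in the docstrings above, a toy comparison of the two approaches (mock lognormal draws, not the actual pipeline data):

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.lognormal(mean=0.0, sigma=0.3, size=(100_000, 3))  # mock linear-space draws

# This PR's approach: transform each sample, then estimate the covariance
cov_from_samples = np.cov(np.log10(samples), rowvar=False)

# Delta-method alternative: C_t ~ J C J^T with J = diag(d log10(y) / dy) at the mean
mean = samples.mean(axis=0)
J = np.diag(1.0 / (mean * np.log(10.0)))
cov_jacobian = J @ np.cov(samples, rowvar=False) @ J.T

# The two agree to first order but differ when the scatter is large
print(np.round(cov_from_samples, 4))
print(np.round(cov_jacobian, 4))
```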

Collaborator:

Should these be defined in the base class?

Owner Author:

Same here, are you suggesting we put them somewhere else?

@epaillas (Owner Author) commented

The main challenge is that sampling in the transformed space requires transforming everything that goes into the likelihood: data, covariance, and model.

The model works in the transformed space natively, so by default an inverse transform is applied at every MCMC step. The data and covariance need to be transformed appropriately, and that's what the observable class methods in this branch are doing. We could outsource those functions to keep the classes clean, or just keep this as a separate branch for Nathan to test things. I don't have a strong preference, other than wanting to enable Nathan to run these tests as soon as possible so we can see whether we should worry about it.
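Schematically, sampling in the transformed space then looks like this (a sketch assuming a Gaussian likelihood; the glue code and variable names are illustrative, with observable and model standing in for the objects this branch provides):

```python
import numpy as np
import torch

# Assemble everything in the transformed space once, before sampling
y_t = observable.get_transformed_y()
cov_t = (observable.get_transformed_covariance_matrix()
         + observable.get_transformed_emulator_covariance_matrix())
icov_t = np.linalg.inv(cov_t)

def log_likelihood(params):
    # Keep the prediction in the training (transformed) space, so no
    # inverse transform is needed at each MCMC step
    with torch.no_grad():
        pred_t = model.get_prediction(
            torch.Tensor(params), skip_output_inverse_transform=True
        )
    diff = y_t - np.asarray(pred_t)
    return -0.5 * diff @ icov_t @ diff
```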
