Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions sunbird/cosmology/growth_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
class Growth:
def __init__(
self,
theta_star: float = AbacusSummit(0).theta_star,
theta_MC_100: float = AbacusSummit(0)['theta_MC_100'],
emulate=False,
emulator_data_dir=DEFAULT_PATH / "data/hemu/",
):
self.theta_star = theta_star
self.theta_MC_100 = theta_MC_100
self.emulate = emulate
self.emulator_data_dir = emulator_data_dir
if self.emulate:
Expand Down Expand Up @@ -59,10 +59,13 @@ def generate_emulator_training_data(
h_values, sample_parameters = [], []
for i, sample in enumerate(samples_matrix):
try:
cosmology = self.get_cosmology_fixed_theta_star(
# print every 1000th sample
if i % 1000 == 0:
print(i)
cosmology = self.get_cosmology_fixed_theta_MC_100(
DESI(engine="class"),
dict(
theta_star=self.theta_star,
theta_MC_100=self.theta_MC_100,
omega_b=sample[0],
omega_cdm=sample[1],
sigma8=sample[2],
Expand Down Expand Up @@ -274,29 +277,30 @@ def get_emulated_h(self, omega_b, omega_cdm, sigma8, N_ur, n_s, w0_fld, wa_fld):
x = jnp.vstack([omega_b, omega_cdm, sigma8, N_ur, n_s, w0_fld, wa_fld]).T
return self.model.apply(self.params, x)

def get_cosmology_fixed_theta_star(
def get_cosmology_fixed_theta_MC_100(
self,
fiducial,
params,
h_limits=[0.4, 1.0],
xtol=1.0e-6,
):
theta = params.pop("theta_star", None)
theta = params.pop("theta_MC_100", None)
fiducial = fiducial.clone(base="input", **params)
if theta is not None:
if "h" in params:
raise ValueError("Cannot provide both theta_star and h")
raise ValueError("Cannot provide both theta_MC_100 and h")

def f(h):
cosmo = fiducial.clone(base="input", h=h)
return 100.0 * (theta - cosmo.get_thermodynamics().theta_star)
# return 100.0 * (theta - cosmo.get_thermodynamics().theta_MC_100)
return 100.0 * (theta - cosmo['theta_MC_100'])

rtol = xtol
try:
h = optimize.bisect(f, *h_limits, xtol=xtol, rtol=rtol, disp=True)
except ValueError as exc:
raise ValueError(
"Could not find proper h value in the interval that matches theta_star = {:.4f} with [f({:.3f}), f({:.3f})] = [{:.4f}, {:.4f}]".format(
"Could not find proper h value in the interval that matches theta_MC_100 = {:.4f} with [f({:.3f}), f({:.3f})] = [{:.4f}, {:.4f}]".format(
theta, *h_limits, *list(map(f, h_limits))
)
) from exc
Expand Down Expand Up @@ -335,10 +339,10 @@ def get_growth(
z=z,
)
else:
cosmology = self.get_cosmology_fixed_theta_star(
cosmology = self.get_cosmology_fixed_theta_MC_100(
DESI(engine="class"),
dict(
theta_star=self.theta_star,
theta_MC_100=self.theta_MC_100,
omega_b=omega_b,
omega_cdm=omega_cdm,
sigma8=sigma8,
Expand Down Expand Up @@ -393,10 +397,10 @@ def get_fsigma8(
)
return growth_rate * sigma8_z
else:
cosmology = self.get_cosmology_fixed_theta_star(
cosmology = self.get_cosmology_fixed_theta_MC_100(
DESI(engine="class"),
dict(
theta_star=self.theta_star,
theta_MC_100=self.theta_MC_100,
omega_b=omega_b,
omega_cdm=omega_cdm,
sigma8=sigma8,
Expand All @@ -414,6 +418,6 @@ def get_fsigma8(

t0 = time.time()
growth = Growth()
# growth.generate_emulator_training_data()
growth.generate_emulator_training_data(n_samples=100_000)
growth.train_emulator()
print(f"It took {time.time() - t0} seconds")
105 changes: 105 additions & 0 deletions sunbird/data/transforms_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,27 @@ def transform(self, x):
@abstractmethod
def inverse_transform(self, x):
pass

@abstractmethod
def get_jacobian_diagonal(self, y):
"""
Get the diagonal of the Jacobian matrix df/dy for transforming covariance matrices.

For an element-wise transformation f(y), the transformed covariance is:
Cov_transformed = diag(J) @ Cov @ diag(J)
where J = df/dy is the Jacobian diagonal.

Parameters
----------
y : array_like
Data vector in the original (untransformed) space.

Returns
-------
array_like
Diagonal of the Jacobian matrix, same shape as y.
"""
pass


class LogTransform(BaseTransform):
Expand All @@ -22,6 +43,27 @@ def transform(self, x):

def inverse_transform(self, x):
return 10**x

def get_jacobian_diagonal(self, y):
"""
Get Jacobian diagonal for log10 transform: d(log10(y))/dy = 1/(y * ln(10))

Parameters
----------
y : array_like
Data vector in the original (untransformed) space.

Returns
-------
array_like
Jacobian diagonal: 1/(y * ln(10))
"""
if type(y) == torch.Tensor:
return 1.0 / (y * torch.log(torch.tensor(10.0)))
elif type(y) == np.ndarray:
return 1.0 / (y * np.log(10.0))
else:
return 1.0 / (y * jnp.log(10.0))


class ArcsinhTransform(BaseTransform):
Expand All @@ -40,6 +82,27 @@ def inverse_transform(self, x):
return np.sinh(x)
else:
return jnp.sinh(x)

def get_jacobian_diagonal(self, y):
"""
Get Jacobian diagonal for arcsinh transform: d(arcsinh(y))/dy = 1/sqrt(1 + y^2)

Parameters
----------
y : array_like
Data vector in the original (untransformed) space.

Returns
-------
array_like
Jacobian diagonal: 1/sqrt(1 + y^2)
"""
if type(y) == torch.Tensor:
return 1.0 / torch.sqrt(1.0 + y**2)
elif type(y) == np.ndarray:
return 1.0 / np.sqrt(1.0 + y**2)
else:
return 1.0 / jnp.sqrt(1.0 + y**2)

class WeiLiuOutputTransForm(BaseTransform):
"""Class to reconcile output the Minkowski functionals model
Expand All @@ -56,6 +119,27 @@ def transform(self, x):

def inverse_transform(self, x):
return x * self.std + self.mean

def get_jacobian_diagonal(self, y):
"""
Get Jacobian diagonal for affine transform: d(y * std + mean)/dy = std

Parameters
----------
y : array_like
Data vector in the original (untransformed) space.

Returns
-------
array_like
Jacobian diagonal: std (broadcast to match y shape)
"""
if type(y) == torch.Tensor:
return torch.ones_like(y) * self.std
elif type(y) == np.ndarray:
return np.ones_like(y) * self.std.numpy()
else:
return jnp.ones_like(y) * self.std.numpy()

class WeiLiuInputTransform(BaseTransform):
"""Class to reconcile input of the Minkowski functionals model
Expand All @@ -72,4 +156,25 @@ def transform(self, x):

def inverse_transform(self, x):
return x

def get_jacobian_diagonal(self, y):
"""
Get Jacobian diagonal for standardization: d((y - mean) / std)/dy = 1/std

Parameters
----------
y : array_like
Data vector in the original (untransformed) space.

Returns
-------
array_like
Jacobian diagonal: 1/std (broadcast to match y shape)
"""
if type(y) == torch.Tensor:
return torch.ones_like(y) / self.std
elif type(y) == np.ndarray:
return np.ones_like(y) / self.std.numpy()
else:
return jnp.ones_like(y) / self.std.numpy()

20 changes: 16 additions & 4 deletions sunbird/emulators/models/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
],
)
self.compression_matrix = compression_matrix

@staticmethod
def add_model_specific_args(parent_parser):
"""Model arguments that could vary
Expand Down Expand Up @@ -275,16 +275,28 @@ def forward(self, x: Tensor) -> Tensor:
y_var = torch.zeros_like(y_pred)
return y_pred, y_var

def get_prediction(self, x: Tensor, filters: Optional[dict] = None) -> Tensor:
def get_prediction(self, x: Tensor, filters: Optional[dict] = None, skip_output_inverse_transform: bool = False) -> Tensor:
"""Get prediction from the model.

Args:
x (Tensor): Input tensor
filters (dict, optional): Filters to apply. Defaults to None.
skip_output_inverse_transform (bool, optional): If True, skip the output inverse transformation,
keeping predictions in the transformed space. Useful when performing inference in transformed
space (requires transforming observations and covariance to match). Defaults to False.

Returns:
Tensor: Model prediction
"""
x = torch.Tensor(x)
if self.transform_input:
if self.transform_input is not None:
x = self.transform_input.transform(x)
y, _ = self.forward(x)
if self.standarize_output:
std_output = self.std_output.to(x.device)
mean_output = self.mean_output.to(x.device)
y = y * std_output + mean_output
if self.transform_output:
if self.transform_output is not None and not skip_output_inverse_transform:
y = self.transform_output.inverse_transform(y)
if self.compression_matrix is not None:
y = y @ self.compression_matrix
Expand Down
71 changes: 64 additions & 7 deletions sunbird/inference/base.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,97 @@
"""Base classes and utilities for inference samplers."""

import logging
import numpy as np
from tabulate import tabulate
from typing import Dict, Optional
from sunbird.inference.priors import AbacusSummitEllipsoid

class BaseSampler:
def __init__(self,
"""Base class for inference samplers.

Handles parameter bookkeeping, optional transformed-space sampling, and
convenience utilities for saving chains and summary tables.
"""

def __init__(
self,
observation,
precision_matrix,
theory_model,
priors,
ranges: Optional[Dict[str, tuple]] = {},
labels: Dict[str, str] = {},
fixed_parameters: Dict[str, float] = {},
slice_filters: Dict = {},
select_filters: Dict = {},
coordinates: list = [],
ellipsoid: bool = False,
markers: dict = {},
sample_in_transformed_space: bool = False,
**kwargs,
):
"""Initialize the sampler base.

Args:
observation: Observed data vector.
precision_matrix: Inverse covariance matrix.
theory_model: Callable model that maps parameters to predictions.
priors: Mapping of parameter names to prior objects.
ranges: Optional plotting or reporting ranges by parameter.
labels: Optional labels by parameter.
fixed_parameters: Mapping of parameter names to fixed values.
ellipsoid: Whether to include the AbacusSummit ellipsoid prior.
markers: Optional marker styling for plots.
sample_in_transformed_space: If True, use transformed outputs.
**kwargs: Extra arguments for subclasses.
"""
self.logger = logging.getLogger(self.__class__.__name__)
self.theory_model = theory_model
if fixed_parameters is None:
fixed_parameters = {}
self.fixed_parameters = fixed_parameters
self.observation = observation
self.priors = priors
self.ranges = ranges
self.labels = labels
self.precision_matrix = precision_matrix
self.ellipsoid = ellipsoid
self.markers = markers
self.sample_in_transformed_space = sample_in_transformed_space

# Handle transformation of observations and covariance
if sample_in_transformed_space:
# Validate that the observable has an output transform
if not hasattr(theory_model.__self__.model, 'transform_output'):
raise ValueError('Cannot sample in transformed space: observable does not have a transform_output. '
'Either set sample_in_transformed_space=False or use an observable with transform_output.')

# Check if transform_output is valid (not None or empty list)
transform = theory_model.__self__.model.transform_output
if transform is None:
raise ValueError('Cannot sample in transformed space: transform_output is None. '
'Either set sample_in_transformed_space=False or use an observable with transform_output.')

# For combined observables, transform_output is a list
if isinstance(transform, list):
if all(t is None for t in transform):
raise ValueError('Cannot sample in transformed space: all transforms in combined observable are None. '
'Either set sample_in_transformed_space=False or use observables with transform_output.')

self.logger.warning('Sampling in transformed space (skip_output_inverse_transform=True). '
'Ensure observations and covariance matrix are also transformed to match!')

self.observation = observation
self.precision_matrix = precision_matrix

if self.ellipsoid:
self.abacus_ellipsoid = AbacusSummitEllipsoid()

self.ndim = len(self.priors.keys()) - len(self.fixed_parameters.keys())
self.logger.info(f'Free parameters: {[key for key in priors.keys() if key not in fixed_parameters.keys()]}')
self.logger.info(f'Fixed parameters: {[key for key in priors.keys() if key in fixed_parameters.keys()]}')

def save_chain(self, save_fn, metadata=None):
"""Save the chain to a file
"""Save a chain dictionary to a NumPy file.

Args:
save_fn: Output filename for the NumPy archive.
metadata: Optional extra metadata to include.
"""
data = self.get_chain(flat=True)
names = [param for param in self.priors.keys() if param not in self.fixed_parameters]
Expand All @@ -66,6 +118,11 @@ def save_chain(self, save_fn, metadata=None):
np.save(save_fn, cout)

def save_table(self, save_fn):
"""Write a summary table with MAP/mean/std values.

Args:
save_fn: Output filename for the text table.
"""
chain = self.get_chain(flat=True)
maxp = chain['samples'][chain['log_posterior'].argmax()]
mean = chain['samples'].mean(axis=0)
Expand Down
Loading