diff --git a/docs/source/plots.rst b/docs/source/plots.rst index 1069c2f..5669cf5 100644 --- a/docs/source/plots.rst +++ b/docs/source/plots.rst @@ -30,4 +30,7 @@ Plots .. autoclass:: deepdiagnostics.plots.Parity :members: plot +.. autoclass:: deepdiagnostics.plots.CDFParity + :members: plot + .. bibliography:: \ No newline at end of file diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 31f98c9..c432cde 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -20,14 +20,76 @@ Installation git clone https://github.com/deepskies/DeepDiagnostics/ pip install poetry - poetry shell - poetry install + poetry install + poetry run diagnose --help +Pre-requisites +------------- + +DeepDiagnostics does not train models or generate data, they must be provided. +Possible model formats are listed in :ref:`models` and data formats in :ref:`data`. +If you are using a simulator, it must be registered by using `deepdiagnostics.utils.register.register_simulator`. +More information can be found in :ref:`custom_simulations`. + +Output directories are automatically created, and if a run ID is not specified, one is generated. +Only if a run ID is specified will previous runs be overwritten. + Configuration ----- +------------- + +Description of the configuration file, including defaults, can be found in :ref:`configuration`. +Below is a minimal example. 
+ +.. code-block:: yaml + + common: + out_dir: "./deepdiagnostics_results/" + random_seed: 42 + data: + data_engine: "H5Data" + data_path: "./data/my_data.h5" + simulator: "MySimulator" + simulator_kwargs: # Any arguments used to initialize the simulator + foo: bar + model: + model_engine: "SBIModel" + model_path: "./models/my_model.pkl" + plots_common: # Used across all plots + parameter_labels: # Can either be plain strings or rendered LaTeX strings + - "My favorite parameter" + - "My least favorite parameter" + - "My most mid parameter" + parameter_colors: # Any color recognized by matplotlib + - "#264a95" + - "#ed9561" + - "#89b7bb" + line_style_cycle: # Any line type recognized by matplotlib + - solid + - dashed + - dotted + figure_size: # Approximate size, it can be scaled when adding additional subfigures + - 6 # x length + - 6 # y length + metrics_common: # Used across all metrics (and plots if the plots have a calculation step) + samples_per_inference: 1000 + number_simulations: 100 + percentiles: + - 68 + - 95 + plots: + CoverageFraction: # Arguments supplied to {plottype}.plot() + include_coverage_std: True + include_ideal_range: True + reference_line_label: "Ideal Coverage" + TARP: + coverage_sigma: 4 + title: "TARP of My Model" + metrics: + AllSBC: + Ranks: + num_bins: 3 -Description of the configuration file, including defaults, can be found in :ref:`configuration`. Pipeline --------- diff --git a/resources/saveddata/data_test.h5 b/resources/saveddata/data_test.h5 new file mode 100644 index 0000000..6b5fd74 Binary files /dev/null and b/resources/saveddata/data_test.h5 differ diff --git a/src/deepdiagnostics/data/data.py b/src/deepdiagnostics/data/data.py index 656baa9..ad9722b 100644 --- a/src/deepdiagnostics/data/data.py +++ b/src/deepdiagnostics/data/data.py @@ -46,99 +46,32 @@ def __init__( msg = f"Could not load the lookup table simulator - {e}. You cannot use generative diagnostics." 
print(msg) + self.context = self._context() + self.thetas = self._thetas() + self.prior_dist = self.load_prior(prior, prior_kwargs) - self.n_dims = self.get_theta_true().shape[1] + self.n_dims = self.thetas.shape[1] self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False) - def get_simulator_output_shape(self) -> tuple[Sequence[int]]: - """ - Run a single sample of the simulator to verify the out-shape. - - Returns: - tuple[Sequence[int]]: Output shape of a single sample of the simulator. - """ - context_shape = self.true_context().shape - sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1]) - return sim_out.shape + self.simulator_outcome = self._simulator_outcome() def _load(self, path: str): raise NotImplementedError - def true_context(self): - """ - True data x values, if supplied by the data method. - """ - # From Data + def _simulator_outcome(self): + raise NotImplementedError + + def _context(self): + raise NotImplementedError + + def _thetas(self): raise NotImplementedError - def true_simulator_outcome(self) -> np.ndarray: - """ - Run the simulator on all true theta and true x values. - - Returns: - np.ndarray: array of (n samples, simulator shape) showing output of the simulator on all true samples in data. - """ - return self.simulator(self.get_theta_true(), self.true_context()) - - def sample_prior(self, n_samples: int) -> np.ndarray: - """ - Draw samples from the simulator - - Args: - n_samples (int): Number of samples to draw - - Returns: - np.ndarray: - """ - return self.prior_dist(size=(n_samples, self.n_dims)) - - def simulator_outcome(self, theta:np.ndarray, condition_context:np.ndarray=None, n_samples:int=None): - """_summary_ - - Args: - theta (np.ndarray): Theta value of shape (n_samples, theta_dimensions) - condition_context (np.ndarray, optional): If x values for theta are known, use them. Defaults to None. 
- n_samples (int, optional): If x values are not known for theta, draw them randomly. Defaults to None. - - Raises: - ValueError: If either n samples or content samples is supplied. - - Returns: - np.ndarray: Simulator output of shape (n samples, simulator_dimensions) - """ - if condition_context is None: - if n_samples is None: - raise ValueError( - "Samples required if condition context is not specified" - ) - return self.simulator(theta, n_samples) - else: - return self.simulator.simulate(theta, condition_context) - - def simulated_context(self, n_samples:int) -> np.ndarray: - """ - Call the simulator's `generate_context` method. - - Args: - n_samples (int): Number of samples to draw. - - Returns: - np.ndarray: context (x values), as defined by the simulator. - """ - return self.simulator.generate_context(n_samples) - - def get_theta_true(self) -> Union[Any, float, int, np.ndarray]: - """ - Look for the true theta given by data. If supplied in the method, use that, other look in the configuration file. - If neither are supplied, return None. + def save(self, data, path: str): + raise NotImplementedError - Returns: - Any: Theta value selected by the search. - """ - if hasattr(self, "theta_true"): - return self.theta_true - else: - return get_item("data", "theta_true", raise_exception=True) + def read_prior(self): + raise NotImplementedError def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]: """ @@ -153,12 +86,6 @@ def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]: else: return get_item("data", "sigma_true", raise_exception=True) - def save(self, data, path: str): - raise NotImplementedError - - def read_prior(self): - raise NotImplementedError - def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable: """ Load the prior. 
@@ -201,3 +128,21 @@ def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable: except KeyError as e: raise RuntimeError(f"Data missing a prior specification - {e}") + + def sample_prior(self, n_samples: int) -> np.ndarray: + """ + Sample from the prior. + + Args: + n_samples (int): Number of samples to draw. + + Returns: + np.ndarray: Samples drawn from the prior. + """ + + if self.prior_dist is None: + prior = self.read_prior() + sample = self.rng.randint(0, len(prior), size=n_samples) + return prior[sample] + else: + return self.prior_dist(size=(n_samples, self.n_dims)) \ No newline at end of file diff --git a/src/deepdiagnostics/data/h5_data.py b/src/deepdiagnostics/data/h5_data.py index c5e50ff..924418c 100644 --- a/src/deepdiagnostics/data/h5_data.py +++ b/src/deepdiagnostics/data/h5_data.py @@ -11,6 +11,12 @@ class H5Data(Data): """ Load data that has been saved in a h5 format. + If you cast your problem to be y = mx + b, these are the fields required and what they represent: + + simulator_outcome - y + thetas - parameters of the model - m, b + context - xs + .. attribute:: Data Parameters :xs: [REQUIRED] The context, the x values. The data that was used to train a model on what conditions produce what posterior. @@ -30,6 +36,7 @@ def __init__(self, ): super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions) + def _load(self, path): assert path.split(".")[-1] == "h5", "File extension must be h5" loaded_data = {} @@ -49,17 +56,29 @@ def save(self, data: dict[str, Any], path: str): # Todo typing for data dict for key, value in data_arrays.items(): file.create_dataset(key, data=value) - def true_context(self): - """ - Try to get the `xs` field of the loaded data. - - Raises: - NotImplementedError: The data does not have a `xs` field. - """ + def _simulator_outcome(self): try: - return self.data["xs"] - except KeyError: - raise NotImplementedError("Cannot find `xs` in data. 
Please supply it.") + return self.data["simulator_outcome"] + except KeyError: + try: + sim_outcome = np.zeros((self.simulator_dimensions, len(self.thetas))) + for index, theta in enumerate(self.thetas): + sim_out = self.simulator(theta=theta.unsqueeze(0), n_samples=1) + sim_outcome[:, index] = sim_out + return sim_outcome + + except Exception as e: + e = f"Data does not have a `simulator_outcome` field and could not generate it from a simulator: {e}" + raise ValueError(e) + + def _context(self): + try: + context = self.data["context"] + if context.ndim == 1: + context = context.unsqueeze(1) + return context + except KeyError: + raise NotImplementedError("Data does not have a `context` field.") def prior(self): """ @@ -68,12 +87,14 @@ def prior(self): Raises: NotImplementedError: The data does not have a `prior` field. """ - try: + if 'prior' in self.data: return self.data['prior'] - except KeyError: - raise NotImplementedError("Data does not have a `prior` field.") - - def get_theta_true(self): + elif self.prior_dist is not None: + return self.prior_dist + else: + raise ValueError("Data does not have a `prior` field.") + + def _thetas(self): """ Get stored theta used to train the model. 
Returns: diff --git a/src/deepdiagnostics/data/lookup_table_simulator.py b/src/deepdiagnostics/data/lookup_table_simulator.py index 17c3342..1c6b2a1 100644 --- a/src/deepdiagnostics/data/lookup_table_simulator.py +++ b/src/deepdiagnostics/data/lookup_table_simulator.py @@ -11,11 +11,11 @@ class LookupTableSimulator(Simulator): Does not need to be registered, it is automatically available as the default simulator - Assumes your has the following fields accessible as data["xs"], data["thetas"], data["ys"] + Assumes your data has the following fields accessible as data["context"], data["thetas"], data["simulator_outcome"], where context is the x values, thetas are the parameters, and simulator_outcome is the outcomes """ - def __init__(self, data: torch.tensor, random_state: np.random.Generator, outside_range_limit: float = 2.0) -> None: + def __init__(self, data: torch.tensor, random_state: np.random.Generator, outside_range_limit: float = 2.0, hash_precision: int = 10) -> None: """ Parameters @@ -28,7 +28,9 @@ def __init__(self, data: torch.tensor, random_state: np.random.Generator, outsid When values of theta and x are passed that are not in the dataset, if they are greater than the max by this threshold, (where value > outside_range_limit*value_max), a value error is raised instead of taking the nearest neighbor; by default 1.5 - + hash_precision : int, optional + Number of decimal places to round to when creating hash keys, by default 10 + Raises ------ ValueError @@ -36,41 +38,48 @@ def __init__(self, data: torch.tensor, random_state: np.random.Generator, outsid """ super().__init__() # Normalizing for finding nearest neighbors - self.threshold = outside_range_limit + for key in ["simulator_outcome", "thetas", "context"]: + if key not in data.keys(): + msg = f"Data must have a field `{key}` - found {data.keys()}" + raise ValueError(msg) + + self.precision = hash_precision + + self.max_theta = torch.max(data["thetas"], axis=0).values self.min_theta = torch.min(data["thetas"], 
axis=0).values - self.max_x = torch.max(data["xs"], axis=0).values - self.min_x = torch.min(data["xs"], axis=0).values + self.max_x = torch.max(data["context"], axis=0).values + self.min_x = torch.min(data["context"], axis=0).values + + self.threshold = outside_range_limit * torch.max(self.max_theta) self.table = self._build_table(data) self.rng = random_state - - for key in ["xs", "thetas", "ys"]: - if key not in data.keys(): - msg = f"Data must have a field `{key}` - found {data.keys()}" - raise ValueError(msg) def _build_table(self, data): "Takes all the theta, context and outcome data and builds a lookup table" table = { - self._build_hash(theta, context): { - "y": outcome, + self._build_hash(theta, simulator_outcome): { + "y": simulator_outcome, "loc": self._calc_hash_distance(theta, context), "theta": theta, "x": context, } - for theta, context, outcome in zip(data["thetas"], data["xs"], data["ys"]) + for theta, simulator_outcome, context in zip(data["thetas"], data["simulator_outcome"], data['context']) } return table def _build_hash(self, theta, context): "Take a theta and context, and build a hashable key for the lookup table" - return hash(tuple(torch.concat([theta, context], dim=-1))) + hashval = torch.concat([theta, context], dim=-1).flatten() + rounded_values = [round(val.item(), self.precision) for val in hashval] + return hash(tuple(rounded_values)) def _calc_hash_distance(self, theta: Union[torch.Tensor, float], context: Union[torch.Tensor, float]) -> float: "Create a distance (as the norm) metric between pairs of theta and context" theta = (theta - self.min_theta) / (self.max_theta - self.min_theta) context = (context - self.min_x) / (self.max_x - self.min_x) + theta = theta.unsqueeze(0) if theta.dim() == 0 else theta + context = context.unsqueeze(0) if context.dim() == 0 else context return torch.linalg.norm(torch.concat([theta, context], dim=-1)) def generate_context(self, n_samples): @@ -100,11 +109,13 @@ def simulate(self, theta: 
Union[torch.Tensor, float], context_samples: Union[tor context_samples = torch.Tensor([context_samples]) for t, x in zip(theta, context_samples): + t = t.unsqueeze(0) if t.dim() == 0 else t + x = x.unsqueeze(0) if x.dim() == 0 else x key = self._build_hash(t, x) try: results.append(self.table[key]["y"]) except KeyError: - print(f"Could not match theta {t} and x {x} to a result - taking the nearest neighbor") + print(f"Could not match theta {t} and context {x} to a result - taking the nearest neighbor") space_hit = self._calc_hash_distance(t, x) nearest_key = min(self.table.keys(), key=lambda k: abs(self.table[k]["loc"] - space_hit)) diff --git a/src/deepdiagnostics/data/simulator.py b/src/deepdiagnostics/data/simulator.py index 6b77d77..adb51f5 100644 --- a/src/deepdiagnostics/data/simulator.py +++ b/src/deepdiagnostics/data/simulator.py @@ -50,7 +50,7 @@ def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray Specify a simulation S such that y_{theta} = S(context_samples|theta) - Example: + Example: .. 
code-block:: python # Generate from a random distribution diff --git a/src/deepdiagnostics/metrics/__init__.py b/src/deepdiagnostics/metrics/__init__.py index a713120..7777472 100644 --- a/src/deepdiagnostics/metrics/__init__.py +++ b/src/deepdiagnostics/metrics/__init__.py @@ -1,6 +1,7 @@ from deepdiagnostics.metrics.all_sbc import AllSBC from deepdiagnostics.metrics.coverage_fraction import CoverageFraction from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST +from deepdiagnostics.metrics.under_cdf_parity import CDFParityAreaUnderCurve as CDFParityAreaUnderCurve def void(*args, **kwargs): def void2(*args, **kwargs): diff --git a/src/deepdiagnostics/metrics/all_sbc.py b/src/deepdiagnostics/metrics/all_sbc.py index 8de9d54..35a1df6 100644 --- a/src/deepdiagnostics/metrics/all_sbc.py +++ b/src/deepdiagnostics/metrics/all_sbc.py @@ -39,8 +39,8 @@ def __init__( number_simulations) def _collect_data_params(self): - self.thetas = tensor(self.data.get_theta_true()) - self.context = tensor(self.data.true_context()) + self.thetas = tensor(self.data.thetas) + self.simulator_outcome = tensor(self.data.simulator_outcome) def calculate(self) -> dict[str, Sequence]: """ @@ -51,7 +51,7 @@ def calculate(self) -> dict[str, Sequence]: """ ranks, dap_samples = run_sbc( self.thetas, - self.context, + self.simulator_outcome, self.model.posterior, num_posterior_samples=self.samples_per_inference, ) diff --git a/src/deepdiagnostics/metrics/coverage_fraction.py b/src/deepdiagnostics/metrics/coverage_fraction.py index a114b6e..98ecf74 100644 --- a/src/deepdiagnostics/metrics/coverage_fraction.py +++ b/src/deepdiagnostics/metrics/coverage_fraction.py @@ -37,12 +37,8 @@ def __init__( self._collect_data_params() def _collect_data_params(self): - self.thetas = self.data.get_theta_true() - self.context = self.data.true_context() - - def _run_model_inference(self, samples_per_inference, y_inference): - samples = self.model.sample_posterior(samples_per_inference, 
y_inference) - return samples.numpy() + self.thetas = self.data.thetas + self.simulator_outcome = self.data.simulator_outcome def calculate(self) -> tuple[Sequence, Sequence]: """ @@ -55,6 +51,7 @@ def calculate(self) -> tuple[Sequence, Sequence]: all_samples = np.empty( (self.number_simulations, self.samples_per_inference, np.shape(self.thetas)[1]) ) + iterator = range(self.number_simulations) if self.use_progress_bar: iterator = tqdm( @@ -62,12 +59,13 @@ def calculate(self) -> tuple[Sequence, Sequence]: desc="Sampling from the posterior for each observation", unit=" observation", ) + n_theta_samples = self.thetas.shape[0] count_array = np.zeros((self.number_simulations, len(self.percentiles), self.thetas.shape[1])) for sample_index in iterator: - context_sample = self.context[self.data.rng.integers(0, len(self.context))] - samples = self._run_model_inference(self.samples_per_inference, context_sample) + context_sample = self.simulator_outcome[self.data.rng.integers(0, len(self.simulator_outcome))] + samples = self.model.sample_posterior(self.samples_per_inference, context_sample).numpy() all_samples[sample_index] = samples @@ -81,7 +79,6 @@ def calculate(self) -> tuple[Sequence, Sequence]: # the units are in parameter space confidence_lower = np.percentile(samples, percentile_lower, axis=0) confidence_upper = np.percentile(samples, percentile_upper, axis=0) - # this is asking if the true parameter value # is contained between the diff --git a/src/deepdiagnostics/metrics/local_two_sample.py b/src/deepdiagnostics/metrics/local_two_sample.py index 4487909..ec39320 100644 --- a/src/deepdiagnostics/metrics/local_two_sample.py +++ b/src/deepdiagnostics/metrics/local_two_sample.py @@ -60,13 +60,13 @@ def _collect_data_params(self): # P is the prior and x_P is generated via the simulator from the parameters P. 
self.p = self.data.sample_prior(self.number_simulations) self.q = np.zeros_like(self.p) - context_size = self.data.true_context().shape[-1] + context_size = self.data.simulator_outcome.shape[-1] remove_first_dim = False if self.data.simulator_dimensions == 1: self.outcome_given_p = np.zeros((self.number_simulations, context_size)) elif self.data.simulator_dimensions == 2: - sim_out_shape = self.data.get_simulator_output_shape() + sim_out_shape = self.data.simulator_outcome[0].shape if len(sim_out_shape) != 2: # TODO Debug log with a warning sim_out_shape = (sim_out_shape[1], sim_out_shape[2]) @@ -169,7 +169,7 @@ def _cross_eval_score( self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])) elif self.data.simulator_dimensions == 2: - sim_out_shape = self.data.get_simulator_output_shape() + sim_out_shape = self.data.simulator_outcome[0].shape if len(sim_out_shape) != 2: # TODO Debug log with a warning sim_out_shape = (sim_out_shape[1], sim_out_shape[2]) diff --git a/src/deepdiagnostics/metrics/under_cdf_parity.py b/src/deepdiagnostics/metrics/under_cdf_parity.py new file mode 100644 index 0000000..ff3250b --- /dev/null +++ b/src/deepdiagnostics/metrics/under_cdf_parity.py @@ -0,0 +1,102 @@ +from typing import Union, TYPE_CHECKING, Any, Optional, Sequence + +import json +import os + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import patches as mpatches +from scipy.stats import ecdf, binom + +from deepdiagnostics.data import data +from deepdiagnostics.models import model +from deepdiagnostics.utils.config import get_item + +import numpy as np +from tqdm import tqdm +from typing import Any, Sequence + +from torch import tensor + +from deepdiagnostics.metrics.metric import Metric + + +class CDFParityAreaUnderCurve(Metric): + def __init__( + self, + model, + data, + run_id, + out_dir= None, + save = True, + use_progress_bar = None, + samples_per_inference = None, + percentiles = None, + 
number_simulations = None, + ) -> None: + super().__init__(model, data, run_id, out_dir, + save, + use_progress_bar, + samples_per_inference, + percentiles, + number_simulations) + self._collect_data_params() + + def _collect_data_params(self): + self.n_dims = self.data.n_dims + theta_true = self.data.thetas#self.data.get_theta_true()#self.data.thetas + self.posterior_samples = np.zeros( + (self.number_simulations, self.samples_per_inference, self.n_dims) + ) + thetas = np.zeros((self.number_simulations, self.samples_per_inference, self.n_dims)) + + for n in range(self.number_simulations): + sample_index = self.data.rng.integers(0, len(theta_true)) + + theta = theta_true[sample_index, :] + x = self.data.context[sample_index, :] + self.posterior_samples[n] = self.model.sample_posterior( + self.samples_per_inference, x + ) + + thetas[n] = np.array([theta for _ in range(self.samples_per_inference)]) + + thetas = thetas.reshape( + (self.number_simulations * self.samples_per_inference, self.n_dims) + ) + + """ + Compute the ECDF for post posteriors samples against the true parameter values. 
+ Uses [scipy.stats.ecdf](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ecdf.html) to compute the distributions from both given data and samples from the posterior + + ..code-block:: python + + from deepdiagnostics.plots import CDFParityPlot + + CDFParityPlot(model, data, save=False, show=True)() + + """ + def calculate(self) -> dict[str, float]: + #one dimensional + #ecdf_sample = ecdf(self.posterior_samples[:, 0].ravel()) + #print("ecdf_sample ", ecdf_sample) + results = {} + # Loop through each parameter dimension + for d in range(self.n_dims): + # Flatten posterior samples for this dimension + ecdf_sample = ecdf(self.posterior_samples[:, :, d].ravel()) + # Calculate the area under the ECDF curve + # Compute the ECDF + x = ecdf_sample.cdf.quantiles + y = ecdf_sample.cdf.probabilities + #print("x",x," y ", y) + area_under_ecdf = np.trapezoid(y, x) + #print(f"Area under the ECDF: {area_under_ecdf:.4f}") + #auc="Area_Under_Curve" + results[f"Area_Under_Curve_dim{d}"] = area_under_ecdf + self.output = results#{auc: area_under_ecdf} ##Need to run calculate + return self.output + def __call__(self, **kwds: Any) -> Any: + self._collect_data_params() + self.calculate() + self._finish() diff --git a/src/deepdiagnostics/models/model.py b/src/deepdiagnostics/models/model.py index 46e2c81..02d0a33 100644 --- a/src/deepdiagnostics/models/model.py +++ b/src/deepdiagnostics/models/model.py @@ -1,18 +1,19 @@ class Model: """ - Load a pre-trained model for analysis. - + Load a pre-trained model for analysis. + Args: - model_path (str): relative path to a model. + model_path (str): relative path to a model. 
""" + def __init__(self, model_path: str) -> None: self.model = self._load(model_path) def _load(self, path: str) -> None: - return NotImplementedError + raise NotImplementedError def sample_posterior(self): - return NotImplementedError + raise NotImplementedError def sample_simulation(self, data): raise NotImplementedError diff --git a/src/deepdiagnostics/models/sbi_model.py b/src/deepdiagnostics/models/sbi_model.py index dedba51..6b12fa9 100644 --- a/src/deepdiagnostics/models/sbi_model.py +++ b/src/deepdiagnostics/models/sbi_model.py @@ -1,31 +1,65 @@ import os import pickle +from sbi.inference.posteriors.base_posterior import NeuralPosterior + from deepdiagnostics.models.model import Model class SBIModel(Model): """ - Load a trained model that was generated with Mackelab SBI :cite:p:`centero2020sbi`. - `Read more about saving and loading requirements here `_. + Load a trained model that was generated with Mackelab SBI :cite:p:`centero2020sbi`. + `Read more about saving and loading requirements here `_. Args: - model_path (str): relative path to a model - must be a .pkl file. + model_path (str): Relative path to a model - must be a .pkl file. """ + def __init__(self, model_path): super().__init__(model_path) def _load(self, path: str) -> None: - assert os.path.exists(path), f"Cannot find model file at location {path}" - assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'" + if not os.path.exists(path): + raise ValueError(f"Cannot find model file at location {path}") + if path.split(".")[-1] != "pkl": + raise ValueError("File extension must be 'pkl'") with open(path, "rb") as file: posterior = pickle.load(file) self.posterior = posterior + @staticmethod + def save_posterior( + neural_posterior: NeuralPosterior, path: str, allow_overwrite: bool = False + ) -> None: + """ + Save an SBI posterior to a pickle file. + + Args: + neural_posterior (NeuralPosterior): A neural posterior object. 
+ Must be an instance of the base class 'NeuralPosterior' + from the sbi package. + path (str): Relative path to a model - must be a .pkl file. + allow_overwrite (bool, optional): Controls whether an attempt to + overwrite succeeds or results in an error. Defaults to False. + """ + if not isinstance(neural_posterior, NeuralPosterior): + raise ValueError( + f"'neural_posterior' must be an instance of the base class 'NeuralPosterior' from the 'sbi' package." + ) + if os.path.exists(path) and (not allow_overwrite): + raise ValueError( + f"The path {path} already exists. To overwrite, use 'save_posterior(..., allow_overwrite=True)'" + ) + if path.split(".")[-1] != "pkl": + raise ValueError("File extension must be 'pkl'") + + with open(path, "wb") as file: + pickle.dump(neural_posterior, file) + def sample_posterior(self, n_samples: int, x_true): """ - Sample the posterior + Sample the posterior Args: n_samples (int): Number of samples to draw @@ -40,14 +74,14 @@ def sample_posterior(self, n_samples: int, x_true): def predict_posterior(self, data, context_samples): """ - Sample the posterior and then + Sample the posterior and then run the simulator on the posterior samples Args: data (deepdiagnostics.data.Data): Data module with the loaded simulation - context_samples (np.ndarray): X values to test the posterior over. + context_samples (np.ndarray): X values to test the posterior over. 
Returns: - np.ndarray: Simulator output + np.ndarray: Simulator output """ posterior_samples = self.sample_posterior(context_samples) posterior_predictive_samples = data.simulator( diff --git a/src/deepdiagnostics/plots/__init__.py b/src/deepdiagnostics/plots/__init__.py index e83274c..f221674 100644 --- a/src/deepdiagnostics/plots/__init__.py +++ b/src/deepdiagnostics/plots/__init__.py @@ -6,6 +6,7 @@ from deepdiagnostics.plots.predictive_posterior_check import PPC from deepdiagnostics.plots.parity import Parity from deepdiagnostics.plots.predictive_prior_check import PriorPC +from deepdiagnostics.plots.cdf_parity import CDFParityPlot def void(*args, **kwargs): @@ -22,5 +23,6 @@ def void2(*args, **kwargs): "LC2ST": LC2ST, PPC.__name__: PPC, "Parity": Parity, - PriorPC.__name__: PriorPC + PriorPC.__name__: PriorPC, + CDFParityPlot.__name__: CDFParityPlot } diff --git a/src/deepdiagnostics/plots/cdf_parity.py b/src/deepdiagnostics/plots/cdf_parity.py new file mode 100644 index 0000000..6d36a2f --- /dev/null +++ b/src/deepdiagnostics/plots/cdf_parity.py @@ -0,0 +1,408 @@ +from typing import Union, TYPE_CHECKING + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import patches as mpatches +from scipy.stats import ecdf, binom + +from deepdiagnostics.plots.plot import Display +from deepdiagnostics.utils.config import get_item +from deepdiagnostics.utils.utils import DataDisplay + +if TYPE_CHECKING: + from matplotlib.figure import Figure as fig + from matplotlib.axes import Axes as ax + +class CDFParityPlot(Display): + def __init__( + self, + model, + data, + run_id, + save, + show, + out_dir=None, + percentiles = None, + use_progress_bar= None, + samples_per_inference = None, + number_simulations= None, + parameter_names = None, + parameter_colors = None, + colorway =None): + """ + Compute the ECDF for post posteriors samples against the true parameter values. 
+ Uses [scipy.stats.ecdf](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ecdf.html) to compute the distributions from both given data and samples from the posterior + + ..code-block:: python + + from deepdiagnostics.plots import CDFParityPlot + + CDFParityPlot(model, data, save=False, show=True)() + + """ + super().__init__(model, data, run_id, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway) + self.line_cycle = tuple(get_item("plots_common", "line_style_cycle", raise_exception=False)) + self.labels_dict = {} + self.theory_alpha = 0.2 + + def plot_name(self): + return "cdf_parity.png" + + def _calculate_theory_cdf(self, probability:float, num_bins:int=100) -> tuple[np.array, np.array]: + """ + Calculate the theoretical limits for the CDF of `distribution` with the percentile `probability`. + Assumes the distribution is a binomial distribution + """ + + n_dims = self.data.n_dims + bounds = np.zeros((num_bins, 2, n_dims)) + cdf = np.zeros((num_bins, n_dims)) + + for dim in range(n_dims): + + # Construct uniform histogram. + uni_bins = binom(self.samples_per_inference, p=1 / num_bins).ppf(0.5) * np.ones(num_bins) + uni_bins_cdf = uni_bins.cumsum() / uni_bins.sum() + # Decrease value one in last entry by epsilon to find valid + # confidence intervals. 
+ uni_bins_cdf[-1] -= 1e-9 + lower = [binom(self.samples_per_inference, p=p).ppf(1-probability) for p in uni_bins_cdf] + upper = [binom(self.samples_per_inference, p=p).ppf(probability) for p in uni_bins_cdf] + + bounds[:, 0, dim] = lower/np.max(lower) + bounds[:, 1, dim] = upper/np.max(upper) + + cdf[:, dim] = uni_bins_cdf + + return bounds, cdf + + def _data_setup(self, num_bins:int=100, **kwargs) -> DataDisplay: + if all([p >= 1 for p in self.percentiles]): + percentiles = [p/100 for p in self.percentiles] + else: + percentiles = self.percentiles or [0.95] + + n_dims = self.data.n_dims + theta_true = self.data.thetas + posterior_samples = np.zeros( + (self.number_simulations, self.samples_per_inference, n_dims) + ) + thetas = np.zeros((self.number_simulations, self.samples_per_inference, n_dims)) + + + for n in range(self.number_simulations): + sample_index = self.data.rng.integers(0, len(theta_true)) + + theta = theta_true[sample_index, :] + x = self.data.context[sample_index, :] + posterior_samples[n] = self.model.sample_posterior( + self.samples_per_inference, x + ) + + thetas[n] = np.array([theta for _ in range(self.samples_per_inference)]) + + thetas = thetas.reshape( + (self.number_simulations * self.samples_per_inference, n_dims) + ) + + calculated_ecdf = {} + theory_cdf = {} + # sample_quartiles based off the first dimension + # Not always perfect, but it ensures that the quantiles are consistent across all dimensions - required for the residuals + ecdf_sample = ecdf(posterior_samples[:, 0].ravel()) + + all_bands = {} + for interval in percentiles: + bands, cdf = self._calculate_theory_cdf(interval, num_bins) + all_bands[interval] = bands + + for dim, name in zip(range(n_dims), self.parameter_names): + parameter_quantiles = np.linspace( + np.min(thetas[:, dim]), + np.max(thetas[:, dim]), + num=num_bins + ) + ecdf_sample = ecdf(posterior_samples[:, dim].ravel()) + sample_probs_common = ecdf_sample.cdf.evaluate(parameter_quantiles) + for interval in 
percentiles: + all_bands[f"low_theory_probability_{interval}_{name}"] = all_bands[interval][:, 0, dim] + all_bands[f"high_theory_probability_{interval}_{name}"] = all_bands[interval][:, 1, dim] + + theory_cdf[f"theory_probability_{name}"] = cdf[:, dim] + calculated_ecdf[f"quantiles_{name}"] = parameter_quantiles + calculated_ecdf[f"sample_probability_{name}"] = sample_probs_common + + display_data = DataDisplay({ + **calculated_ecdf, + **all_bands, + **theory_cdf, # CDF Isn't calculated differently for percentiles, it's fine to use the last one + "percentiles": np.array(percentiles), + }) + return display_data + + def _plot_base_plot(self, data_display, ax, parameter_name, sample_label, line_style, color, theory_color, theory_line_style): + "Just plot the CDF of the posterior ECDF" + ax.plot( + data_display[f"quantiles_{parameter_name}"], + data_display[f"theory_probability_{parameter_name}"], + color=theory_color, + linestyle=theory_line_style + ) + + ax.plot( + data_display[f"quantiles_{parameter_name}"], + data_display[f"sample_probability_{parameter_name}"], + ls=line_style, + color=color, + ) + + def _plot_theory_intervals(self, data_display, ax, parameter_name, theory_label, color, interval): + lower, upper = ( + data_display[f"low_theory_probability_{interval}_{parameter_name}"], + data_display[f"high_theory_probability_{interval}_{parameter_name}"] + ) + + ax.fill_between( + data_display[f"quantiles_{parameter_name}"], + lower, + upper, + alpha=self.theory_alpha, + color=color + ) + + def _plot_theory_intervals_residual(self, data_display, ax, parameter_name, theory_label, color, interval): + if data_display[f"low_theory_probability_{interval}_{parameter_name}"] is None: + bound_low, bound_high = self._compute_intervals( + data_display[f"theory_probability_{parameter_name}"], + interval, + self.parameter_names.index(parameter_name) + ) + low = bound_low - data_display[f"theory_probability_{parameter_name}"] + high = bound_high - 
data_display[f"theory_probability_{parameter_name}"] + else: + # Use the precomputed values for the fill-between + low = data_display[f"low_theory_probability_{interval}_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"] + high = data_display[f"high_theory_probability_{interval}_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"] + + ax.fill_between( + data_display[f"quantiles_{parameter_name}"], + low, + high, + alpha=self.theory_alpha, + color=color + ) + + def _compute_intervals(self, cdf: np.ndarray, probability: float, dimension:int) -> tuple[np.ndarray, np.ndarray]: + "Use the Dvoretzky-Kiefer-Wolfowitz confidence bands as an approximation for plotting purposes." + + bound = np.sqrt(np.log(2.0 / (1 - probability)) / (2.0 * float(cdf.shape[0]))) + lower = cdf[:, dimension] - bound + upper = cdf[:, dimension] + bound + return lower, upper + + + def plot( + self, + data_display: Union[DataDisplay, str], + include_residuals: bool = False, + include_theory_intervals: bool = True, + display_parameters_separate: bool = False, + x_label: str = "Quantiles", + y_label: str = "CDF", + title: str = "CDF Parity Plot", + samples_label = "Posterior Samples", + theory_label = "Theory", + theory_color = "gray", + theory_line_style = "--", + normalize_view: bool = True, + theory_alpha: float = 0.2, + **kwargs + ) -> tuple["fig", "ax"]: + """ + Compute the ECDF for posterior samples against the true parameter values. + Uses [scipy.stats.ecdf](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ecdf.html) to compute the distributions for sampled posterior samples. + + To show all the distributions on one plot - set `display_parameters_separate` to `False` and verify `normalize_view` is set to `True`, this will ensure the x & y axes are normalized to [0, 1] for all parameters. + + Args: + data_display (DataDisplay or str): The data to plot. If a string, it is assumed to be the path to an HDF5 file.
+ include_residuals (bool): Whether to include the residuals between the theory and sample distributions. + include_theory_intervals (bool): Whether to include the theory intervals (percentiles given in the 'percentiles' field of the config) in the plot + display_parameters_separate (bool): Whether to display each parameter in a separate subplot. + x_label (str): Label for the x-axis. + y_label (str): Label for the y-axis. + title (str): Title of the plot. + samples_label (str): Label for the samples in the plot. + theory_label (str): Label for the theory in the plot. + theory_color (str): Color for the center theory line. + theory_line_style (str): Line style for center theory line. + normalize_view (bool): Whether to normalize the x axis of the plot to [0, 1] for all parameters. + theory_alpha (float): Alpha (transparency) value for the fill between the theory intervals. Between 0 and 1. + """ + + self.theory_alpha = theory_alpha + color_cycler = iter(plt.cycler("color", self.parameter_colors)) + line_style_cycler = iter(plt.cycler("line_style", self.line_cycle)) + + if not isinstance(data_display, DataDisplay): + data_display = DataDisplay().from_h5(data_display, self.plot_name) + + # Used if theory intervals are included + theory_color_cycle = self._get_hex_sigma_colors(len(data_display["percentiles"])) + + + if include_residuals: + row_len = self.figure_size[0] * .8*len(self.parameter_names) if display_parameters_separate else self.figure_size[0] + figsize = (row_len, 1.5*self.figure_size[1]) + fig, ax = plt.subplots( + 2, len(self.parameter_names) if display_parameters_separate else 1, + figsize=figsize, + height_ratios=[3, 1], + sharex='col', + sharey='row' + ) + plt.subplots_adjust(hspace=0.01) + + else: + row_len = self.figure_size[0] * .8*len(self.parameter_names) if display_parameters_separate else self.figure_size[0] + fig, ax = plt.subplots( + 1, len(self.parameter_names) if display_parameters_separate else 1, + figsize=(row_len, self.figure_size[1]), 
+ sharey='row') + + if normalize_view: + for parameter_name in self.parameter_names: + data_display[f"quantiles_{parameter_name}"] = np.linspace(0, 1, num=len(data_display[f"quantiles_{parameter_name}"])) + + if include_theory_intervals: # Each plot needs to iterate over the percentiles in the main plot and the residuals + if display_parameters_separate: + for index, parameter_name in enumerate(self.parameter_names): + plot_ax = ax[index] if not include_residuals else ax[0][index] + plot_ax.plot(data_display[f"quantiles_{parameter_name}"], data_display[f"theory_probability_{parameter_name}"], color=theory_color, linestyle=theory_line_style) + + plot_ax.set_title(f"{samples_label} {parameter_name}") + color = next(color_cycler)["color"] + line_style = next(line_style_cycler)["line_style"] + self._plot_base_plot(data_display, plot_ax, parameter_name, samples_label, line_style, color, theory_color, theory_line_style) + + for interval_index, interval in enumerate(data_display["percentiles"]): + self._plot_theory_intervals( + data_display, plot_ax, parameter_name, theory_label, + theory_color_cycle[interval_index], interval + ) + + if include_residuals: + self._plot_theory_intervals_residual( + data_display, ax[1, index], parameter_name, theory_label, + theory_color_cycle[interval_index], interval + ) + + if include_residuals: + # Plot the residuals between the theory and sample + residual = data_display[f"sample_probability_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"] + + ax[1, index].plot( + data_display[f"quantiles_{parameter_name}"], + residual, + linestyle=line_style, + color=color, + ) + ax[1, index].axhline(0, color=theory_color, linestyle=theory_line_style) + + + else: # The plot_ax is the same for all parameters + plot_ax = ax if not include_residuals else ax[0] + for parameter_name in self.parameter_names: + color = next(color_cycler)["color"] + line_style = next(line_style_cycler)["line_style"] + self._plot_base_plot(data_display, 
plot_ax, parameter_name, samples_label, line_style, color, theory_color, theory_line_style) + for interval_index, interval in enumerate(data_display["percentiles"]): # iterate over the percentiles + if self.parameter_names.index(parameter_name) == 0: # Only plot for the first theory interval when not displaying parameters separately + self._plot_theory_intervals( + data_display, plot_ax, parameter_name, theory_label, + theory_color_cycle[interval_index], interval + ) + + if include_residuals: + self._plot_theory_intervals_residual( + data_display, ax[1], parameter_name, theory_label, + theory_color_cycle[interval_index], interval + ) + + if include_residuals: + # Plot the residuals between the theory and sample + residual = data_display[f"sample_probability_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"] + + ax[1].plot( + data_display[f"quantiles_{parameter_name}"], + residual, + linestyle=line_style, + color=color, + ) + ax[1].axhline(0, color=theory_color, linestyle=theory_line_style) + + + else: # Do not include the theory intervals - no fill-betweens here! 
+ if display_parameters_separate: + for index, parameter_name in enumerate(self.parameter_names): + # Each parameter gets it's own subplot + plot_ax = ax[index] if not include_residuals else ax[0][index] + plot_ax.set_title(f"{samples_label} {parameter_name}") + color = next(color_cycler)["color"] + line_style = next(line_style_cycler)["line_style"] + self._plot_base_plot(data_display, plot_ax, parameter_name, samples_label, line_style, color, theory_color, theory_line_style) + + if include_residuals: + residual = data_display[f"sample_probability_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"] + + ax[1, index].plot( + data_display[f"quantiles_{parameter_name}"], + residual, + linestyle=line_style, + color=color, + ) + ax[1, index].axhline(0, color=theory_color, linestyle=theory_line_style) + + else: # Everything goes on the one column + plot_ax = ax if not include_residuals else ax[0] + for index, parameter_name in enumerate(self.parameter_names): + color = next(color_cycler)["color"] + line_style = next(line_style_cycler)["line_style"] + + self._plot_base_plot(data_display, plot_ax, parameter_name, samples_label, line_style, color, theory_color, theory_line_style) + + if include_residuals: + residual = data_display[f"sample_probability_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"] + ax[1].plot( + data_display[f"quantiles_{parameter_name}"], + residual, + linestyle=line_style, + color=color, + ) + ax[1].axhline(0, color=theory_color, linestyle=theory_line_style) + + + handles = [ + plt.Line2D([0], [0], color=theory_color, linestyle=theory_line_style, label=theory_label) + ] + if include_theory_intervals: + handles += [ + mpatches.Rectangle((0, 0), 0, 0, facecolor=theory_color_cycle[i], alpha=self.theory_alpha, edgecolor='none', label=f"CDF {int(data_display['percentiles'][i]*100)}% CI {theory_label}") + for i in range(len(data_display["percentiles"])) + ] + + if not display_parameters_separate: + # reset the 
color and line style cyclers for the handles + color_cycler = iter(plt.cycler("color", self.parameter_colors)) + line_style_cycler = iter(plt.cycler("line_style", self.line_cycle)) + handles += [ + plt.Line2D([0], [0], color=color['color'], linestyle=line_style['line_style'], label=f"{samples_label} {parameter_name}") + for parameter_name, color, line_style in zip(self.parameter_names, color_cycler, line_style_cycler) + ] + + fig.legend(handles=handles) + fig.suptitle(title) + fig.supxlabel(x_label) + fig.supylabel(y_label) + + return fig, ax diff --git a/src/deepdiagnostics/plots/cdf_ranks.py b/src/deepdiagnostics/plots/cdf_ranks.py index 19c2da6..99d1bb2 100644 --- a/src/deepdiagnostics/plots/cdf_ranks.py +++ b/src/deepdiagnostics/plots/cdf_ranks.py @@ -46,8 +46,8 @@ def plot_name(self): return "cdf_ranks.png" def _data_setup(self) -> DataDisplay: - thetas = tensor(self.data.get_theta_true()) - context = tensor(self.data.true_context()) + thetas = tensor(self.data.thetas) + context = tensor(self.data.simulator_outcome) ranks, _ = run_sbc( thetas, context, self.model.posterior, num_posterior_samples=self.samples_per_inference diff --git a/src/deepdiagnostics/plots/parity.py b/src/deepdiagnostics/plots/parity.py index db4c7ba..67d7a32 100644 --- a/src/deepdiagnostics/plots/parity.py +++ b/src/deepdiagnostics/plots/parity.py @@ -48,7 +48,7 @@ def plot_name(self): def _data_setup(self, n_samples: int = 80, **kwargs) -> DataDisplay: - context_shape = self.data.true_context().shape + context_shape = self.data.simulator_outcome.shape posterior_sample_mean = np.zeros((n_samples, self.data.n_dims)) posterior_sample_std = np.zeros_like(posterior_sample_mean) true_samples = np.zeros_like(posterior_sample_mean) @@ -56,11 +56,11 @@ def _data_setup(self, n_samples: int = 80, **kwargs) -> DataDisplay: random_context_indices = self.data.rng.integers(0, context_shape[0], n_samples) for index, sample in enumerate(random_context_indices): - posterior_sample = 
self.model.sample_posterior(self.samples_per_inference, self.data.true_context()[sample, :]).numpy() + posterior_sample = self.model.sample_posterior(self.samples_per_inference, self.data.simulator_outcome[sample, :]).numpy() posterior_sample_mean[index] = np.mean(posterior_sample, axis=0) posterior_sample_std[index] = np.std(posterior_sample, axis=0) - true_samples[index] = self.data.get_theta_true()[sample, :] + true_samples[index] = self.data.thetas[sample, :] return DataDisplay( n_dims=self.data.n_dims, diff --git a/src/deepdiagnostics/plots/plot.py b/src/deepdiagnostics/plots/plot.py index 71b7c82..71fe989 100644 --- a/src/deepdiagnostics/plots/plot.py +++ b/src/deepdiagnostics/plots/plot.py @@ -4,6 +4,8 @@ from typing import Optional, Sequence, TYPE_CHECKING, Union import matplotlib.pyplot as plt from matplotlib import rcParams +import matplotlib.colors as plt_colors +import numpy as np from deepdiagnostics.utils.config import get_item from deepdiagnostics.utils.utils import DataDisplay @@ -89,6 +91,17 @@ def _data_setup(self, **kwargs) -> Optional[DataDisplay]: "Return all the data required for plotting" raise NotImplementedError + def _get_hex_sigma_colors(self, n_colors): + + cmap = plt.get_cmap(self.colorway) + hex_colors = [] + arr = np.linspace(0, 1, n_colors) + for hit in arr: + hex_colors.append(plt_colors.rgb2hex(cmap(hit))) + + return hex_colors + + @abstractmethod def plot(self, data_display: Union[dict, "data_display"], **kwrgs) -> tuple["figure", "axes"]: """ diff --git a/src/deepdiagnostics/plots/predictive_posterior_check.py b/src/deepdiagnostics/plots/predictive_posterior_check.py index 342e287..4df4614 100644 --- a/src/deepdiagnostics/plots/predictive_posterior_check.py +++ b/src/deepdiagnostics/plots/predictive_posterior_check.py @@ -51,69 +51,60 @@ def plot_name(self): return "predictive_posterior_check.png" def _get_posterior_2d(self, n_simulator_draws): - context_shape = self.data.true_context().shape - sim_out_shape = 
self.data.get_simulator_output_shape() + sim_out_shape = self.data.simulator( + theta=self.data.thetas[0].unsqueeze(0), + n_samples=1 + )[0].shape + remove_first_dim = False - if len(sim_out_shape) != 2: + if len(sim_out_shape) > 2: # TODO Debug log with a warning sim_out_shape = (sim_out_shape[1], sim_out_shape[2]) remove_first_dim = True posterior_predictive_samples = np.zeros((n_simulator_draws, *sim_out_shape)) - posterior_true_samples = np.zeros_like(posterior_predictive_samples) - random_context_indices = self.data.rng.integers(0, context_shape[0], n_simulator_draws) - for index, sample in enumerate(random_context_indices): - context_sample = self.data.true_context()[sample, :] - posterior_sample = self.model.sample_posterior(1, context_sample) + random_context_indices = self.data.rng.integers(0, self.data.simulator_outcome.shape[0], n_simulator_draws) + simulator_true = self.data.simulator_outcome[random_context_indices, :].numpy() + for index, sample in enumerate(simulator_true): + posterior_sample = self.model.sample_posterior(1, sample) # get the posterior samples for that context - sim_out_posterior = self.data.simulator.simulate( - theta=posterior_sample, context_samples = context_sample - ) - sim_out_true = self.data.simulator.simulate( - theta=self.data.get_theta_true()[sample, :], context_samples=context_sample + sim_out_posterior = self.data.simulator(n_samples=1, + theta=posterior_sample ) if remove_first_dim: sim_out_posterior = sim_out_posterior[0] - sim_out_true = sim_out_true[0] posterior_predictive_samples[index] = sim_out_posterior - posterior_true_samples[index] = sim_out_true - return posterior_predictive_samples, posterior_true_samples + return posterior_predictive_samples, simulator_true def _get_posterior_1d(self, n_simulator_draws): - context_shape = self.data.true_context().shape - posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, context_shape[-1])) - posterior_true_samples = 
np.zeros_like(posterior_predictive_samples) - context = np.zeros((n_simulator_draws, context_shape[-1])) + simulator_outcome_shape = self.data.simulator_dimensions - random_context_indices = self.data.rng.integers(0, context_shape[0], n_simulator_draws) - for index, sample in enumerate(random_context_indices): - context_sample = self.data.true_context()[sample, :] - context[index] = context_sample + posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, simulator_outcome_shape)) - posterior_sample = self.model.sample_posterior(self.samples_per_inference, context_sample) + # Sample one random simulator output for each draw + random_context_indices = self.data.rng.integers(0, self.data.simulator_outcome.shape[0], n_simulator_draws) + simulator_samples = self.data.simulator_outcome[random_context_indices, :].numpy() + posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, *simulator_samples[0].shape)) - # get the posterior samples for that context - posterior_predictive_samples[index] = self.data.simulator.simulate( - theta=posterior_sample, context_samples = context_sample - ) - posterior_true_samples[index] = self.data.simulator.simulate( - theta=self.data.get_theta_true()[sample, :], context_samples=context_sample + for index, sample in enumerate(simulator_samples): + posterior_sample = self.model.sample_posterior(self.samples_per_inference, sample) + posterior_predictive_samples[index] = self.data.simulator(n_samples=len(sample), + theta=posterior_sample ) - return posterior_predictive_samples, posterior_true_samples, context + + return posterior_predictive_samples, simulator_samples def _data_setup(self, n_unique_plots: Optional[int] = 3, **kwargs) -> DataDisplay: + true_sigma = None if self.data.simulator_dimensions == 1: - n_dims = 1 - - posterior_predictive_samples, posterior_true_samples, context = self._get_posterior_1d(n_unique_plots) + posterior_predictive_samples, posterior_true_samples = 
self._get_posterior_1d(n_unique_plots) true_sigma = self.data.get_sigma_true() elif self.data.simulator_dimensions == 2: - n_dims = 2 posterior_predictive_samples, posterior_true_samples = self._get_posterior_2d(n_unique_plots) else: @@ -122,10 +113,9 @@ def _data_setup(self, n_unique_plots: Optional[int] = 3, **kwargs) -> DataDispla return DataDisplay( n_unique_plots=n_unique_plots, posterior_predictive_samples=posterior_predictive_samples, - n_dims=n_dims, + n_dims=self.data.simulator_dimensions, posterior_true_samples=posterior_true_samples, - context=context if n_dims == 1 else None, - true_sigma=true_sigma if n_dims == 1 else None + true_sigma=true_sigma ) @@ -167,9 +157,6 @@ def plot( for plot_index in range(data_display.n_unique_plots): if data_display.n_dims == 1: - if data_display.context is None: - raise ValueError("Display Data is malformed. Missing `context` for 1D simulation. Please rerun data setup stage.") - dimension_y_simulation = data_display.posterior_predictive_samples[plot_index] y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel() @@ -177,7 +164,7 @@ def plot( for sigma, color in zip(range(n_coverage_sigma), self.colors): subplots[0, plot_index].fill_between( - data_display.context[plot_index].ravel(), + range(len(y_simulation_mean)), y_simulation_mean - sigma * y_simulation_std, y_simulation_mean + sigma * y_simulation_std, color=color, @@ -186,22 +173,22 @@ def plot( ) subplots[0, plot_index].plot( - data_display.context[plot_index], + range(len(y_simulation_mean)), y_simulation_mean - data_display.true_sigma, color="black", linestyle="dashdot", label="True Input Error" ) subplots[0, plot_index].plot( - data_display.context[plot_index], + range(len(y_simulation_mean)), y_simulation_mean + data_display.true_sigma, color="black", linestyle="dashdot", ) - true_y = np.mean(data_display.posterior_true_samples[plot_index, :, :], axis=0).ravel() + true_y = data_display.posterior_true_samples[plot_index, :].ravel() subplots[1, 
plot_index].scatter( - data_display.context[plot_index], + range(len(true_y)), true_y, marker=theta_true_marker, label='Theta True' diff --git a/src/deepdiagnostics/plots/predictive_prior_check.py b/src/deepdiagnostics/plots/predictive_prior_check.py index 3dabc78..a124cf6 100644 --- a/src/deepdiagnostics/plots/predictive_prior_check.py +++ b/src/deepdiagnostics/plots/predictive_prior_check.py @@ -54,7 +54,7 @@ def plot_name(self): return "predictive_prior_check.png" def _data_setup(self, n_rows: int = 3, n_columns: int = 3, **kwargs) -> DataDisplay: - context_shape = self.data.true_context().shape + sim_out_shape = self.data.simulator_outcome[0].shape remove_first_dim = False if self.data.simulator_dimensions == 1: @@ -65,45 +65,37 @@ def _data_setup(self, n_rows: int = 3, n_columns: int = 3, **kwargs) -> DataDisp if plot_image: - sim_out_shape = self.data.get_simulator_output_shape() if len(sim_out_shape) != 2: # TODO Debug log with a warning sim_out_shape = (sim_out_shape[1], sim_out_shape[2]) remove_first_dim = True - prior_predictive_samples = np.zeros((n_rows, n_columns, *sim_out_shape)) - - else: - prior_predictive_samples = np.zeros((n_rows, n_columns, context_shape[-1])) + prior_predictive_samples = np.zeros((n_rows, n_columns, *sim_out_shape)) prior_true_sample = np.zeros((n_rows, n_columns, self.data.n_dims)) - context = np.zeros((n_rows, n_columns, context_shape[-1])) - random_context_indices = self.data.rng.integers(0, context_shape[0], (n_rows, n_columns)) + + print(prior_predictive_samples.shape) + print(prior_true_sample.shape) for row_index in range(n_rows): for column_index in range(n_columns): - sample = random_context_indices[row_index, column_index] - context_sample = self.data.true_context()[sample, :] - prior_sample = self.data.sample_prior(1)[0] # get the posterior samples for that context - simulation_sample = self.data.simulator.simulate( - theta=prior_sample, context_samples = context_sample + simulation_sample = self.data.simulator( + 
theta=prior_sample, n_samples=np.prod(sim_out_shape) ) if remove_first_dim: simulation_sample = simulation_sample[0] prior_predictive_samples[row_index, column_index] = simulation_sample prior_true_sample[row_index, column_index] = prior_sample - context[row_index, column_index] = context_sample return DataDisplay( plot_image=plot_image, n_rows=n_rows, n_columns=n_columns, prior_predictive_samples=prior_predictive_samples, - prior_true_sample=prior_true_sample, - context=context) + prior_true_sample=prior_true_sample) def plot( self, @@ -192,7 +184,7 @@ def plot( else: subplots[plot_row_index, plot_column_index].plot( - data_display.context[column_index, row_index], + range(len(data_display.prior_predictive_samples[column_index, row_index])), data_display.prior_predictive_samples[column_index, row_index] ) diff --git a/src/deepdiagnostics/plots/ranks.py b/src/deepdiagnostics/plots/ranks.py index b592aa7..8ff86af 100644 --- a/src/deepdiagnostics/plots/ranks.py +++ b/src/deepdiagnostics/plots/ranks.py @@ -46,8 +46,8 @@ def plot_name(self): return "ranks.png" def _data_setup(self, **kwargs) -> DataDisplay: - thetas = tensor(self.data.get_theta_true()) - context = tensor(self.data.true_context()) + thetas = tensor(self.data.thetas) + context = tensor(self.data.simulator_outcome) ranks, _ = run_sbc( thetas, context, self.model.posterior, num_posterior_samples=self.samples_per_inference ) diff --git a/src/deepdiagnostics/plots/tarp.py b/src/deepdiagnostics/plots/tarp.py index 94abfbc..8e6c3ef 100644 --- a/src/deepdiagnostics/plots/tarp.py +++ b/src/deepdiagnostics/plots/tarp.py @@ -4,7 +4,6 @@ import tarp import matplotlib.pyplot as plt -import matplotlib.colors as plt_colors from matplotlib.axes import Axes as ax from matplotlib.figure import Figure as fig @@ -52,7 +51,7 @@ def plot_name(self): return "tarp.png" def _data_setup(self, **kwargs) -> DataDisplay: - self.theta_true = self.data.get_theta_true() + self.theta_true = self.data.thetas.numpy() n_dims = 
self.theta_true.shape[1] posterior_samples = np.zeros( (self.number_simulations, self.samples_per_inference, n_dims) @@ -62,7 +61,7 @@ def _data_setup(self, **kwargs) -> DataDisplay: sample_index = self.data.rng.integers(0, len(self.theta_true)) theta = self.theta_true[sample_index, :] - x = self.data.true_context()[sample_index, :] + x = self.data.simulator_outcome[sample_index, :] posterior_samples[n] = self.model.sample_posterior( self.samples_per_inference, x ) @@ -78,16 +77,6 @@ def plot_settings(self): "plots_common", "line_style_cycle", raise_exception=False ) - def _get_hex_sigma_colors(self, n_colors): - - cmap = plt.get_cmap(self.colorway) - hex_colors = [] - arr = np.linspace(0, 1, n_colors) - for hit in arr: - hex_colors.append(plt_colors.rgb2hex(cmap(hit))) - - return hex_colors - def plot( self, data_display: Union[DataDisplay, dict] = None, diff --git a/src/deepdiagnostics/plots/under_cdf_parity.py b/src/deepdiagnostics/plots/under_cdf_parity.py new file mode 100644 index 0000000..9fa82fe --- /dev/null +++ b/src/deepdiagnostics/plots/under_cdf_parity.py @@ -0,0 +1,10 @@ +from typing import Union, TYPE_CHECKING + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import patches as mpatches +from scipy.stats import ecdf + +from deepdiagnostics.plots.plot import Display +from deepdiagnostics.utils.config import get_item +from deepdiagnostics.utils.utils import DataDisplay diff --git a/src/deepdiagnostics/utils/defaults.py b/src/deepdiagnostics/utils/defaults.py index 2412a55..d82f619 100644 --- a/src/deepdiagnostics/utils/defaults.py +++ b/src/deepdiagnostics/utils/defaults.py @@ -33,7 +33,8 @@ "LC2ST": {}, "Parity":{}, "PPC": {}, - "PriorPC":{} + "PriorPC":{}, + "CDFParityPlot": {} }, "metrics_common": { "use_progress_bar": False, diff --git a/tests/conftest.py b/tests/conftest.py index 70874d6..ee4af93 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,7 +57,7 @@ def simulate(self, theta, context_samples: np.ndarray): for 
sample_index, t in enumerate(theta): mock_data = np.random.normal( - loc=t[0], scale=abs(t[1]), size=(len(context_samples), 2) + loc=t[0], scale=abs(t[1]), size=(len(context_samples), len(context_samples)) ) generated_stars.append( np.column_stack((context_samples, mock_data)) @@ -84,7 +84,7 @@ def model_path(): @pytest.fixture def data_path(): - return "resources/saveddata/data_validation.h5" + return "resources/saveddata/data_test.h5" @pytest.fixture def result_output(): diff --git a/tests/test_client.py b/tests/test_client.py index 78a955d..e4f8980 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -81,6 +81,5 @@ def test_missing_simulator(model_path, data_path): process = subprocess.run(command, capture_output=True) exit_code = process.returncode stdout = process.stdout.decode("utf-8") - assert exit_code == 0 - plot_name = "PPC" - assert f"Cannot run {plot_name} - simulator missing." in stdout + assert exit_code == 0, process.stderr.decode("utf-8") + assert "Warning: Simulator not loaded. Using a lookup table simulator." 
in stdout diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 96ef9cf..428307d 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -5,7 +5,8 @@ from deepdiagnostics.metrics import ( CoverageFraction, AllSBC, - LC2ST + LC2ST, + CDFParityAreaUnderCurve ) @pytest.fixture @@ -44,3 +45,11 @@ def test_lc2st(metric_config, mock_model, mock_data, mock_run_id): assert lc2st.output is not None assert os.path.exists(f"{lc2st.out_dir}/{mock_run_id}_diagnostic_metrics.json") +def test_CDFParity_Area_Under_Curve(metric_config, mock_model, mock_data, mock_run_id): + Config(metric_config) + areaunderecdf=CDFParityAreaUnderCurve( mock_model, mock_data, mock_run_id, save=True) + areaunderecdf() + assert areaunderecdf.output is not None + #print(areaunderecdf.output) #used for testing to check output + #assert False #testing + assert os.path.exists(f"{areaunderecdf.out_dir}/{mock_run_id}_diagnostic_metrics.json") diff --git a/tests/test_plots.py b/tests/test_plots.py index dfcda61..1d043cd 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -20,7 +20,7 @@ def plot_config(config_factory): metrics_settings = { "use_progress_bar": False, "samples_per_inference": 10, - "percentiles": [95, 75, 50], + "percentiles": [68, 95], } config = config_factory(metrics_settings=metrics_settings) return config @@ -71,6 +71,10 @@ def test_ppc(plot_config, mock_model, mock_data, mock_2d_data, result_output, mo plot(**get_item("plots", "PPC", raise_exception=False)) assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}") + +@pytest.mark.xfail(reason="2D dataset needs to be generated - the simulator will not run automatically.") +def test_ppc_2d(plot_config, mock_model, mock_2d_data, result_output, mock_run_id): + Config(plot_config) plot = PPC( mock_model, mock_2d_data, mock_run_id, save=True, show=False, @@ -79,11 +83,16 @@ def test_ppc(plot_config, mock_model, mock_data, mock_2d_data, result_output, mo plot(**get_item("plots", "PPC", 
raise_exception=False)) -def test_prior_pc(plot_config, mock_model, mock_2d_data, mock_data, mock_run_id, result_output): +def test_prior_pc(plot_config, mock_model, mock_data, mock_run_id): Config(plot_config) plot = PriorPC(mock_model, mock_data, mock_run_id, save=True, show=False) plot(**get_item("plots", "PriorPC", raise_exception=False)) assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}") + + +@pytest.mark.xfail(reason="2D dataset needs to be generated - the simulator will not run automatically.") +def plot_prior_pc_2d(plot_config, mock_model, mock_2d_data, result_output, mock_run_id): + Config(plot_config) plot = PriorPC( mock_model, mock_2d_data, mock_run_id, save=True, show=False, @@ -135,3 +144,9 @@ def test_rerun_plot(plot_type, plot_config, mock_model, mock_data, mock_run_id): assert os.path.exists(f"{plot.out_dir}rerun_plot.png") +def test_cdf_parity(plot_config, mock_model, mock_data, mock_run_id): + from deepdiagnostics.plots import CDFParityPlot + Config(plot_config) + plot = CDFParityPlot(mock_model, mock_data, mock_run_id, save=True, show=False) + plot(**get_item("plots", "CDFParityPlot", raise_exception=False)) + assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}") \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 711112b..a6e2286 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,9 +8,9 @@ class TestLookupTableSimulator: def test_lookup_table_simulator(): # fake data data = { - "xs": torch.Tensor([[0.1], [0.2], [0.3], [0.4]]), # context + "context": torch.Tensor([[0.1], [0.2], [0.3], [0.4]]), # context "thetas": torch.Tensor([[1.0], [2.0], [3.0], [4.0]]), # parameters - "ys": torch.Tensor([[10.0], [20.0], [30.0], [40.0]]), # outcomes + "simulator_outcome": torch.Tensor([[10.0], [20.0], [30.0], [40.0]]), # outcomes } rng = np.random.default_rng(42) sim = LookupTableSimulator(data, rng) @@ -20,11 +20,11 @@ def test_lookup_table_simulator(): # Test 
generate_context contexts = sim.generate_context(2) assert contexts.shape == (2, 1) - assert all(context in data["xs"].tolist() for context in contexts) # Only getting contexts from data + assert all(context in data["context"].tolist() for context in contexts) # Only getting contexts from data # Test exact match outcome - theta = torch.Tensor([[2.0]]) - context = torch.Tensor([[0.2]]) + theta = torch.Tensor([2.0]) + context = torch.Tensor([0.2]) outcome = sim.simulate(theta, context) assert outcome.shape == (1, 1) assert outcome[0] == 20.0 @@ -48,9 +48,9 @@ def test_lookup_table_simulator_multidim_params(): rng = np.random.default_rng(42) data = { - "xs": torch.tensor(rng.random((10, 2))), # context + "context": torch.tensor(rng.random((10, 2))), # context "thetas": torch.tensor(rng.random((10, 3))), # parameters - "ys": torch.tensor(rng.random((10, 1))), # outcomes + "simulator_outcome": torch.tensor(rng.random((10, 1))), # outcomes } rng = np.random.default_rng(42) @@ -58,11 +58,11 @@ def test_lookup_table_simulator_multidim_params(): # Test exact match outcome theta = data["thetas"][:2, :] - context = data["xs"][:2, :] + context = data["context"][:2, :] outcome = sim.simulate(theta, context) assert outcome.shape == (2, 1) - assert outcome[0] == data["ys"][0] - assert outcome[1] == data["ys"][1] + assert outcome[0] == data["simulator_outcome"][0] + assert outcome[1] == data["simulator_outcome"][1] # Test nearest neighbor outcome theta = rng.random((1, 3))