diff --git a/docs/source/plots.rst b/docs/source/plots.rst
index 1069c2f..5669cf5 100644
--- a/docs/source/plots.rst
+++ b/docs/source/plots.rst
@@ -30,4 +30,7 @@ Plots
.. autoclass:: deepdiagnostics.plots.Parity
:members: plot
+.. autoclass:: deepdiagnostics.plots.CDFParity
+ :members: plot
+
.. bibliography::
\ No newline at end of file
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 31f98c9..c432cde 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -20,14 +20,76 @@ Installation
git clone https://github.com/deepskies/DeepDiagnostics/
pip install poetry
- poetry shell
- poetry install
+ poetry install
+ poetry run diagnose --help
+Pre-requisites
+--------------
+
+DeepDiagnostics does not train models or generate data, they must be provided.
+Possible model formats are listed in :ref:`models` and data formats in :ref:`data`.
+If you are using a simulator, it must be registered by using `deepdiagnostics.utils.register.register_simulator`.
+More information can be found in :ref:`custom_simulations`.
+
+Output directories are automatically created, and if a run ID is not specified, one is generated.
+Only if a run ID is specified will previous runs be overwritten.
+
Configuration
-----
+-------------
+
+Description of the configuration file, including defaults, can be found in :ref:`configuration`.
+Below is a minimal example.
+
+.. code-block:: yaml
+
+ common:
+ out_dir: "./deepdiagnostics_results/"
+ random_seed: 42
+ data:
+ data_engine: "H5Data"
+ data_path: "./data/my_data.h5"
+ simulator: "MySimulator"
+ simulator_kwargs: # Any augments used to initialize the simulator
+ foo: bar
+ model:
+ model_engine: "SBIModel"
+ model_path: "./models/my_model.pkl"
+ plots_common: # Used across all plots
+ parameter_labels: # Can either be plain strings or rendered LaTeX strings
+ - "My favorite parameter"
+ - "My least favorite parameter"
+ - "My most mid parameter"
+ parameter_colors: # Any color recognized by matplotlib
+ - "#264a95"
+ - "#ed9561"
+ - "#89b7bb"
+ line_style_cycle: # Any line type recognized by matplotlib
+ - solid
+ - dashed
+ - dotted
+ figure_size: # Approximate size, it can be scaled when adding additional subfigures
+ - 6 # x length
+ - 6 # y length
+ metrics_common: # Used across all metrics (and plots if the plots have a calculation step)
+ samples_per_inference: 1000
+ number_simulations: 100
+ percentiles:
+ - 68
+ - 95
+ plots:
+ CoverageFraction: # Arguments supplied to {plottype}.plot()
+ include_coverage_std: True
+ include_ideal_range: True
+ reference_line_label: "Ideal Coverage"
+ TARP:
+ coverage_sigma: 4
+ title: "TARP of My Model"
+ metrics:
+      AllSBC:
+ Ranks:
+ num_bins: 3
-Description of the configuration file, including defaults, can be found in :ref:`configuration`.
Pipeline
---------
diff --git a/resources/saveddata/data_test.h5 b/resources/saveddata/data_test.h5
new file mode 100644
index 0000000..6b5fd74
Binary files /dev/null and b/resources/saveddata/data_test.h5 differ
diff --git a/src/deepdiagnostics/data/data.py b/src/deepdiagnostics/data/data.py
index 656baa9..ad9722b 100644
--- a/src/deepdiagnostics/data/data.py
+++ b/src/deepdiagnostics/data/data.py
@@ -46,99 +46,32 @@ def __init__(
msg = f"Could not load the lookup table simulator - {e}. You cannot use generative diagnostics."
print(msg)
+ self.context = self._context()
+ self.thetas = self._thetas()
+
self.prior_dist = self.load_prior(prior, prior_kwargs)
- self.n_dims = self.get_theta_true().shape[1]
+ self.n_dims = self.thetas.shape[1]
self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False)
- def get_simulator_output_shape(self) -> tuple[Sequence[int]]:
- """
- Run a single sample of the simulator to verify the out-shape.
-
- Returns:
- tuple[Sequence[int]]: Output shape of a single sample of the simulator.
- """
- context_shape = self.true_context().shape
- sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1])
- return sim_out.shape
+ self.simulator_outcome = self._simulator_outcome()
def _load(self, path: str):
raise NotImplementedError
- def true_context(self):
- """
- True data x values, if supplied by the data method.
- """
- # From Data
+ def _simulator_outcome(self):
+ raise NotImplementedError
+
+ def _context(self):
+ raise NotImplementedError
+
+ def _thetas(self):
raise NotImplementedError
- def true_simulator_outcome(self) -> np.ndarray:
- """
- Run the simulator on all true theta and true x values.
-
- Returns:
- np.ndarray: array of (n samples, simulator shape) showing output of the simulator on all true samples in data.
- """
- return self.simulator(self.get_theta_true(), self.true_context())
-
- def sample_prior(self, n_samples: int) -> np.ndarray:
- """
- Draw samples from the simulator
-
- Args:
- n_samples (int): Number of samples to draw
-
- Returns:
- np.ndarray:
- """
- return self.prior_dist(size=(n_samples, self.n_dims))
-
- def simulator_outcome(self, theta:np.ndarray, condition_context:np.ndarray=None, n_samples:int=None):
- """_summary_
-
- Args:
- theta (np.ndarray): Theta value of shape (n_samples, theta_dimensions)
- condition_context (np.ndarray, optional): If x values for theta are known, use them. Defaults to None.
- n_samples (int, optional): If x values are not known for theta, draw them randomly. Defaults to None.
-
- Raises:
- ValueError: If either n samples or content samples is supplied.
-
- Returns:
- np.ndarray: Simulator output of shape (n samples, simulator_dimensions)
- """
- if condition_context is None:
- if n_samples is None:
- raise ValueError(
- "Samples required if condition context is not specified"
- )
- return self.simulator(theta, n_samples)
- else:
- return self.simulator.simulate(theta, condition_context)
-
- def simulated_context(self, n_samples:int) -> np.ndarray:
- """
- Call the simulator's `generate_context` method.
-
- Args:
- n_samples (int): Number of samples to draw.
-
- Returns:
- np.ndarray: context (x values), as defined by the simulator.
- """
- return self.simulator.generate_context(n_samples)
-
- def get_theta_true(self) -> Union[Any, float, int, np.ndarray]:
- """
- Look for the true theta given by data. If supplied in the method, use that, other look in the configuration file.
- If neither are supplied, return None.
+ def save(self, data, path: str):
+ raise NotImplementedError
- Returns:
- Any: Theta value selected by the search.
- """
- if hasattr(self, "theta_true"):
- return self.theta_true
- else:
- return get_item("data", "theta_true", raise_exception=True)
+ def read_prior(self):
+ raise NotImplementedError
def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]:
"""
@@ -153,12 +86,6 @@ def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]:
else:
return get_item("data", "sigma_true", raise_exception=True)
- def save(self, data, path: str):
- raise NotImplementedError
-
- def read_prior(self):
- raise NotImplementedError
-
def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable:
"""
Load the prior.
@@ -201,3 +128,21 @@ def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable:
except KeyError as e:
raise RuntimeError(f"Data missing a prior specification - {e}")
+
+ def sample_prior(self, n_samples: int) -> np.ndarray:
+ """
+ Sample from the prior.
+
+ Args:
+ n_samples (int): Number of samples to draw.
+
+ Returns:
+ np.ndarray: Samples drawn from the prior.
+ """
+
+ if self.prior_dist is None:
+ prior = self.read_prior()
+            sample = self.rng.integers(0, len(prior), size=n_samples)
+ return prior[sample]
+ else:
+ return self.prior_dist(size=(n_samples, self.n_dims))
\ No newline at end of file
diff --git a/src/deepdiagnostics/data/h5_data.py b/src/deepdiagnostics/data/h5_data.py
index c5e50ff..924418c 100644
--- a/src/deepdiagnostics/data/h5_data.py
+++ b/src/deepdiagnostics/data/h5_data.py
@@ -11,6 +11,12 @@ class H5Data(Data):
"""
Load data that has been saved in a h5 format.
+ If you cast your problem to be y = mx + b, these are the fields required and what they represent:
+
+ simulator_outcome - y
+ thetas - parameters of the model - m, b
+ context - xs
+
.. attribute:: Data Parameters
:xs: [REQUIRED] The context, the x values. The data that was used to train a model on what conditions produce what posterior.
@@ -30,6 +36,7 @@ def __init__(self,
):
super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions)
+
def _load(self, path):
assert path.split(".")[-1] == "h5", "File extension must be h5"
loaded_data = {}
@@ -49,17 +56,29 @@ def save(self, data: dict[str, Any], path: str): # Todo typing for data dict
for key, value in data_arrays.items():
file.create_dataset(key, data=value)
- def true_context(self):
- """
- Try to get the `xs` field of the loaded data.
-
- Raises:
- NotImplementedError: The data does not have a `xs` field.
- """
+ def _simulator_outcome(self):
try:
- return self.data["xs"]
- except KeyError:
- raise NotImplementedError("Cannot find `xs` in data. Please supply it.")
+ return self.data["simulator_outcome"]
+ except KeyError:
+ try:
+                sim_outcome = np.zeros((self.simulator_dimensions, len(self.thetas)))
+ for index, theta in enumerate(self.thetas):
+ sim_out = self.simulator(theta=theta.unsqueeze(0), n_samples=1)
+ sim_outcome[:, index] = sim_out
+ return sim_outcome
+
+ except Exception as e:
+                e = f"Data does not have a `simulator_outcome` field and could not generate it from a simulator: {e}"
+ raise ValueError(e)
+
+ def _context(self):
+ try:
+ context = self.data["context"]
+ if context.ndim == 1:
+ context = context.unsqueeze(1)
+ return context
+ except KeyError:
+ raise NotImplementedError("Data does not have a `context` field.")
def prior(self):
"""
@@ -68,12 +87,14 @@ def prior(self):
Raises:
NotImplementedError: The data does not have a `prior` field.
"""
- try:
+ if 'prior' in self.data:
return self.data['prior']
- except KeyError:
- raise NotImplementedError("Data does not have a `prior` field.")
-
- def get_theta_true(self):
+ elif self.prior_dist is not None:
+ return self.prior_dist
+ else:
+ raise ValueError("Data does not have a `prior` field.")
+
+ def _thetas(self):
""" Get stored theta used to train the model.
Returns:
diff --git a/src/deepdiagnostics/data/lookup_table_simulator.py b/src/deepdiagnostics/data/lookup_table_simulator.py
index 17c3342..1c6b2a1 100644
--- a/src/deepdiagnostics/data/lookup_table_simulator.py
+++ b/src/deepdiagnostics/data/lookup_table_simulator.py
@@ -11,11 +11,11 @@ class LookupTableSimulator(Simulator):
Does not need to be registered, it is automatically available as the default simulator
- Assumes your has the following fields accessible as data["xs"], data["thetas"], data["ys"]
+    Assumes your data has the following fields accessible as data["context"], data["thetas"], data["simulator_outcome"],
where xs is the context, thetas are the parameters, and ys are the outcomes
"""
- def __init__(self, data: torch.tensor, random_state: np.random.Generator, outside_range_limit: float = 2.0) -> None:
+ def __init__(self, data: torch.tensor, random_state: np.random.Generator, outside_range_limit: float = 2.0, hash_precision: int = 10) -> None:
"""
Parameters
@@ -28,7 +28,9 @@ def __init__(self, data: torch.tensor, random_state: np.random.Generator, outsid
When values of theta and x are passed that are not in the dataset,
if they are greater than the max by this threshold, (where value > outside_range_limit*value_max),
a value error is raised instead of taking the nearest neighbor; by default 1.5
-
+ hash_precision : int, optional
+ Number of decimal places to round to when creating hash keys, by default 10
+
Raises
------
ValueError
@@ -36,41 +38,48 @@ def __init__(self, data: torch.tensor, random_state: np.random.Generator, outsid
"""
super().__init__()
# Normalizing for finding nearest neighbors
- self.threshold = outside_range_limit
+ for key in ["simulator_outcome", "thetas", "context"]:
+ if key not in data.keys():
+ msg = f"Data must have a field `{key}` - found {data.keys()}"
+ raise ValueError(msg)
+
+ self.precision = hash_precision
+
self.max_theta = torch.max(data["thetas"], axis=0).values
self.min_theta = torch.min(data["thetas"], axis=0).values
- self.max_x = torch.max(data["xs"], axis=0).values
- self.min_x = torch.min(data["xs"], axis=0).values
+ self.max_x = torch.max(data["context"], axis=0).values
+ self.min_x = torch.min(data["context"], axis=0).values
+
+ self.threshold = outside_range_limit * torch.max(self.max_theta)
self.table = self._build_table(data)
self.rng = random_state
-
- for key in ["xs", "thetas", "ys"]:
- if key not in data.keys():
- msg = f"Data must have a field `{key}` - found {data.keys()}"
- raise ValueError(msg)
def _build_table(self, data):
"Takes all the theta, context and outcome data and builds a lookup table"
table = {
- self._build_hash(theta, context): {
- "y": outcome,
+            self._build_hash(theta, context): {
+ "y": simulator_outcome,
"loc": self._calc_hash_distance(theta, context),
"theta": theta,
"x": context,
}
- for theta, context, outcome in zip(data["thetas"], data["xs"], data["ys"])
+ for theta, simulator_outcome, context in zip(data["thetas"], data["simulator_outcome"], data['context'])
}
return table
def _build_hash(self, theta, context):
"Take a theta and context, and build a hashable key for the lookup table"
- return hash(tuple(torch.concat([theta, context], dim=-1)))
+ hashval = torch.concat([theta, context], dim=-1).flatten()
+ rounded_values = [round(val.item(), self.precision) for val in hashval]
+ return hash(tuple(rounded_values))
def _calc_hash_distance(self, theta: Union[torch.Tensor, float], context: Union[torch.Tensor, float]) -> float:
"Create a distance (as the norm) metric between pairs of theta and context"
theta = (theta - self.min_theta) / (self.max_theta - self.min_theta)
context = (context - self.min_x) / (self.max_x - self.min_x)
+ theta = theta.unsqueeze(0) if theta.dim() == 0 else theta
+ context = context.unsqueeze(0) if context.dim() == 0 else context
return torch.linalg.norm(torch.concat([theta, context], dim=-1))
def generate_context(self, n_samples):
@@ -100,11 +109,13 @@ def simulate(self, theta: Union[torch.Tensor, float], context_samples: Union[tor
context_samples = torch.Tensor([context_samples])
for t, x in zip(theta, context_samples):
+ t = t.unsqueeze(0) if t.dim() == 0 else t
+ x = x.unsqueeze(0) if x.dim() == 0 else x
key = self._build_hash(t, x)
try:
results.append(self.table[key]["y"])
except KeyError:
- print(f"Could not match theta {t} and x {x} to a result - taking the nearest neighbor")
+ print(f"Could not match theta {t} and context {x} to a result - taking the nearest neighbor")
space_hit = self._calc_hash_distance(t, x)
nearest_key = min(self.table.keys(), key=lambda k: abs(self.table[k]["loc"] - space_hit))
diff --git a/src/deepdiagnostics/data/simulator.py b/src/deepdiagnostics/data/simulator.py
index 6b77d77..adb51f5 100644
--- a/src/deepdiagnostics/data/simulator.py
+++ b/src/deepdiagnostics/data/simulator.py
@@ -50,7 +50,7 @@ def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray
Specify a simulation S such that y_{theta} = S(context_samples|theta)
- Example:
+ Example:
.. code-block:: python
# Generate from a random distribution
diff --git a/src/deepdiagnostics/metrics/__init__.py b/src/deepdiagnostics/metrics/__init__.py
index a713120..7777472 100644
--- a/src/deepdiagnostics/metrics/__init__.py
+++ b/src/deepdiagnostics/metrics/__init__.py
@@ -1,6 +1,7 @@
from deepdiagnostics.metrics.all_sbc import AllSBC
from deepdiagnostics.metrics.coverage_fraction import CoverageFraction
from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST
+from deepdiagnostics.metrics.under_cdf_parity import CDFParityAreaUnderCurve as CDFParityAreaUnderCurve
def void(*args, **kwargs):
def void2(*args, **kwargs):
diff --git a/src/deepdiagnostics/metrics/all_sbc.py b/src/deepdiagnostics/metrics/all_sbc.py
index 8de9d54..35a1df6 100644
--- a/src/deepdiagnostics/metrics/all_sbc.py
+++ b/src/deepdiagnostics/metrics/all_sbc.py
@@ -39,8 +39,8 @@ def __init__(
number_simulations)
def _collect_data_params(self):
- self.thetas = tensor(self.data.get_theta_true())
- self.context = tensor(self.data.true_context())
+ self.thetas = tensor(self.data.thetas)
+ self.simulator_outcome = tensor(self.data.simulator_outcome)
def calculate(self) -> dict[str, Sequence]:
"""
@@ -51,7 +51,7 @@ def calculate(self) -> dict[str, Sequence]:
"""
ranks, dap_samples = run_sbc(
self.thetas,
- self.context,
+ self.simulator_outcome,
self.model.posterior,
num_posterior_samples=self.samples_per_inference,
)
diff --git a/src/deepdiagnostics/metrics/coverage_fraction.py b/src/deepdiagnostics/metrics/coverage_fraction.py
index a114b6e..98ecf74 100644
--- a/src/deepdiagnostics/metrics/coverage_fraction.py
+++ b/src/deepdiagnostics/metrics/coverage_fraction.py
@@ -37,12 +37,8 @@ def __init__(
self._collect_data_params()
def _collect_data_params(self):
- self.thetas = self.data.get_theta_true()
- self.context = self.data.true_context()
-
- def _run_model_inference(self, samples_per_inference, y_inference):
- samples = self.model.sample_posterior(samples_per_inference, y_inference)
- return samples.numpy()
+ self.thetas = self.data.thetas
+ self.simulator_outcome = self.data.simulator_outcome
def calculate(self) -> tuple[Sequence, Sequence]:
"""
@@ -55,6 +51,7 @@ def calculate(self) -> tuple[Sequence, Sequence]:
all_samples = np.empty(
(self.number_simulations, self.samples_per_inference, np.shape(self.thetas)[1])
)
+
iterator = range(self.number_simulations)
if self.use_progress_bar:
iterator = tqdm(
@@ -62,12 +59,13 @@ def calculate(self) -> tuple[Sequence, Sequence]:
desc="Sampling from the posterior for each observation",
unit=" observation",
)
+
n_theta_samples = self.thetas.shape[0]
count_array = np.zeros((self.number_simulations, len(self.percentiles), self.thetas.shape[1]))
for sample_index in iterator:
- context_sample = self.context[self.data.rng.integers(0, len(self.context))]
- samples = self._run_model_inference(self.samples_per_inference, context_sample)
+ context_sample = self.simulator_outcome[self.data.rng.integers(0, len(self.simulator_outcome))]
+ samples = self.model.sample_posterior(self.samples_per_inference, context_sample).numpy()
all_samples[sample_index] = samples
@@ -81,7 +79,6 @@ def calculate(self) -> tuple[Sequence, Sequence]:
# the units are in parameter space
confidence_lower = np.percentile(samples, percentile_lower, axis=0)
confidence_upper = np.percentile(samples, percentile_upper, axis=0)
-
# this is asking if the true parameter value
# is contained between the
diff --git a/src/deepdiagnostics/metrics/local_two_sample.py b/src/deepdiagnostics/metrics/local_two_sample.py
index 4487909..ec39320 100644
--- a/src/deepdiagnostics/metrics/local_two_sample.py
+++ b/src/deepdiagnostics/metrics/local_two_sample.py
@@ -60,13 +60,13 @@ def _collect_data_params(self):
# P is the prior and x_P is generated via the simulator from the parameters P.
self.p = self.data.sample_prior(self.number_simulations)
self.q = np.zeros_like(self.p)
- context_size = self.data.true_context().shape[-1]
+ context_size = self.data.simulator_outcome.shape[-1]
remove_first_dim = False
if self.data.simulator_dimensions == 1:
self.outcome_given_p = np.zeros((self.number_simulations, context_size))
elif self.data.simulator_dimensions == 2:
- sim_out_shape = self.data.get_simulator_output_shape()
+ sim_out_shape = self.data.simulator_outcome[0].shape
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
@@ -169,7 +169,7 @@ def _cross_eval_score(
self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1]))
elif self.data.simulator_dimensions == 2:
- sim_out_shape = self.data.get_simulator_output_shape()
+ sim_out_shape = self.data.simulator_outcome[0].shape
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
diff --git a/src/deepdiagnostics/metrics/under_cdf_parity.py b/src/deepdiagnostics/metrics/under_cdf_parity.py
new file mode 100644
index 0000000..ff3250b
--- /dev/null
+++ b/src/deepdiagnostics/metrics/under_cdf_parity.py
@@ -0,0 +1,102 @@
+from typing import Union, TYPE_CHECKING, Any, Optional, Sequence
+
+import json
+import os
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import patches as mpatches
+from scipy.stats import ecdf, binom
+
+from deepdiagnostics.data import data
+from deepdiagnostics.models import model
+from deepdiagnostics.utils.config import get_item
+
+import numpy as np
+from tqdm import tqdm
+from typing import Any, Sequence
+
+from torch import tensor
+
+from deepdiagnostics.metrics.metric import Metric
+
+
+class CDFParityAreaUnderCurve(Metric):
+ def __init__(
+ self,
+ model,
+ data,
+ run_id,
+ out_dir= None,
+ save = True,
+ use_progress_bar = None,
+ samples_per_inference = None,
+ percentiles = None,
+ number_simulations = None,
+ ) -> None:
+ super().__init__(model, data, run_id, out_dir,
+ save,
+ use_progress_bar,
+ samples_per_inference,
+ percentiles,
+ number_simulations)
+ self._collect_data_params()
+
+ def _collect_data_params(self):
+ self.n_dims = self.data.n_dims
+ theta_true = self.data.thetas#self.data.get_theta_true()#self.data.thetas
+ self.posterior_samples = np.zeros(
+ (self.number_simulations, self.samples_per_inference, self.n_dims)
+ )
+ thetas = np.zeros((self.number_simulations, self.samples_per_inference, self.n_dims))
+
+ for n in range(self.number_simulations):
+ sample_index = self.data.rng.integers(0, len(theta_true))
+
+ theta = theta_true[sample_index, :]
+ x = self.data.context[sample_index, :]
+ self.posterior_samples[n] = self.model.sample_posterior(
+ self.samples_per_inference, x
+ )
+
+ thetas[n] = np.array([theta for _ in range(self.samples_per_inference)])
+
+ thetas = thetas.reshape(
+ (self.number_simulations * self.samples_per_inference, self.n_dims)
+ )
+
+ """
+ Compute the ECDF for post posteriors samples against the true parameter values.
+ Uses [scipy.stats.ecdf](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ecdf.html) to compute the distributions from both given data and samples from the posterior
+
+    .. code-block:: python
+
+ from deepdiagnostics.plots import CDFParityPlot
+
+ CDFParityPlot(model, data, save=False, show=True)()
+
+ """
+ def calculate(self) -> dict[str, float]:
+ #one dimensional
+ #ecdf_sample = ecdf(self.posterior_samples[:, 0].ravel())
+ #print("ecdf_sample ", ecdf_sample)
+ results = {}
+ # Loop through each parameter dimension
+ for d in range(self.n_dims):
+ # Flatten posterior samples for this dimension
+ ecdf_sample = ecdf(self.posterior_samples[:, :, d].ravel())
+ # Calculate the area under the ECDF curve
+ # Compute the ECDF
+ x = ecdf_sample.cdf.quantiles
+ y = ecdf_sample.cdf.probabilities
+ #print("x",x," y ", y)
+ area_under_ecdf = np.trapezoid(y, x)
+ #print(f"Area under the ECDF: {area_under_ecdf:.4f}")
+ #auc="Area_Under_Curve"
+ results[f"Area_Under_Curve_dim{d}"] = area_under_ecdf
+ self.output = results#{auc: area_under_ecdf} ##Need to run calculate
+ return self.output
+ def __call__(self, **kwds: Any) -> Any:
+ self._collect_data_params()
+ self.calculate()
+ self._finish()
diff --git a/src/deepdiagnostics/models/model.py b/src/deepdiagnostics/models/model.py
index 46e2c81..02d0a33 100644
--- a/src/deepdiagnostics/models/model.py
+++ b/src/deepdiagnostics/models/model.py
@@ -1,18 +1,19 @@
class Model:
"""
- Load a pre-trained model for analysis.
-
+ Load a pre-trained model for analysis.
+
Args:
- model_path (str): relative path to a model.
+ model_path (str): relative path to a model.
"""
+
def __init__(self, model_path: str) -> None:
self.model = self._load(model_path)
def _load(self, path: str) -> None:
- return NotImplementedError
+ raise NotImplementedError
def sample_posterior(self):
- return NotImplementedError
+ raise NotImplementedError
def sample_simulation(self, data):
raise NotImplementedError
diff --git a/src/deepdiagnostics/models/sbi_model.py b/src/deepdiagnostics/models/sbi_model.py
index dedba51..6b12fa9 100644
--- a/src/deepdiagnostics/models/sbi_model.py
+++ b/src/deepdiagnostics/models/sbi_model.py
@@ -1,31 +1,65 @@
import os
import pickle
+from sbi.inference.posteriors.base_posterior import NeuralPosterior
+
from deepdiagnostics.models.model import Model
class SBIModel(Model):
"""
- Load a trained model that was generated with Mackelab SBI :cite:p:`centero2020sbi`.
- `Read more about saving and loading requirements here `_.
+ Load a trained model that was generated with Mackelab SBI :cite:p:`centero2020sbi`.
+ `Read more about saving and loading requirements here `_.
Args:
- model_path (str): relative path to a model - must be a .pkl file.
+ model_path (str): Relative path to a model - must be a .pkl file.
"""
+
def __init__(self, model_path):
super().__init__(model_path)
def _load(self, path: str) -> None:
- assert os.path.exists(path), f"Cannot find model file at location {path}"
- assert path.split(".")[-1] == "pkl", "File extension must be 'pkl'"
+ if not os.path.exists(path):
+ raise ValueError(f"Cannot find model file at location {path}")
+ if path.split(".")[-1] != "pkl":
+ raise ValueError("File extension must be 'pkl'")
with open(path, "rb") as file:
posterior = pickle.load(file)
self.posterior = posterior
+ @staticmethod
+ def save_posterior(
+ neural_posterior: NeuralPosterior, path: str, allow_overwrite: bool = False
+ ) -> None:
+ """
+ Save an SBI posterior to a pickle file.
+
+ Args:
+ neural_posterior (NeuralPosterior): A neural posterior object.
+ Must be an instance of the base class 'NeuralPosterior'
+ from the sbi package.
+ path (str): Relative path to a model - must be a .pkl file.
+ allow_overwrite (bool, optional): Controls whether an attempt to
+ overwrite succeeds or results in an error. Defaults to False.
+ """
+        if not isinstance(neural_posterior, NeuralPosterior):
+            raise ValueError(
+                "'neural_posterior' must be an instance of the base class 'NeuralPosterior' from the 'sbi' package."
+            )
+ if os.path.exists(path) and (not allow_overwrite):
+ raise ValueError(
+ f"The path {path} already exists. To overwrite, use 'save_posterior(..., allow_overwrite=True)'"
+ )
+ if path.split(".")[-1] != "pkl":
+ raise ValueError("File extension must be 'pkl'")
+
+ with open(path, "wb") as file:
+ pickle.dump(neural_posterior, file)
+
def sample_posterior(self, n_samples: int, x_true):
"""
- Sample the posterior
+ Sample the posterior
Args:
n_samples (int): Number of samples to draw
@@ -40,14 +74,14 @@ def sample_posterior(self, n_samples: int, x_true):
def predict_posterior(self, data, context_samples):
"""
- Sample the posterior and then
+ Sample the posterior and then
Args:
data (deepdiagnostics.data.Data): Data module with the loaded simulation
- context_samples (np.ndarray): X values to test the posterior over.
+ context_samples (np.ndarray): X values to test the posterior over.
Returns:
- np.ndarray: Simulator output
+ np.ndarray: Simulator output
"""
posterior_samples = self.sample_posterior(context_samples)
posterior_predictive_samples = data.simulator(
diff --git a/src/deepdiagnostics/plots/__init__.py b/src/deepdiagnostics/plots/__init__.py
index e83274c..f221674 100644
--- a/src/deepdiagnostics/plots/__init__.py
+++ b/src/deepdiagnostics/plots/__init__.py
@@ -6,6 +6,7 @@
from deepdiagnostics.plots.predictive_posterior_check import PPC
from deepdiagnostics.plots.parity import Parity
from deepdiagnostics.plots.predictive_prior_check import PriorPC
+from deepdiagnostics.plots.cdf_parity import CDFParityPlot
def void(*args, **kwargs):
@@ -22,5 +23,6 @@ def void2(*args, **kwargs):
"LC2ST": LC2ST,
PPC.__name__: PPC,
"Parity": Parity,
- PriorPC.__name__: PriorPC
+ PriorPC.__name__: PriorPC,
+ CDFParityPlot.__name__: CDFParityPlot
}
diff --git a/src/deepdiagnostics/plots/cdf_parity.py b/src/deepdiagnostics/plots/cdf_parity.py
new file mode 100644
index 0000000..6d36a2f
--- /dev/null
+++ b/src/deepdiagnostics/plots/cdf_parity.py
@@ -0,0 +1,408 @@
+from typing import Union, TYPE_CHECKING
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import patches as mpatches
+from scipy.stats import ecdf, binom
+
+from deepdiagnostics.plots.plot import Display
+from deepdiagnostics.utils.config import get_item
+from deepdiagnostics.utils.utils import DataDisplay
+
+if TYPE_CHECKING:
+ from matplotlib.figure import Figure as fig
+ from matplotlib.axes import Axes as ax
+
+class CDFParityPlot(Display):
+ def __init__(
+ self,
+ model,
+ data,
+ run_id,
+ save,
+ show,
+ out_dir=None,
+ percentiles = None,
+ use_progress_bar= None,
+ samples_per_inference = None,
+ number_simulations= None,
+ parameter_names = None,
+ parameter_colors = None,
+ colorway =None):
+ """
+ Compute the ECDF of posterior samples against the true parameter values.
+ Uses [scipy.stats.ecdf](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ecdf.html) to compute the distributions from both the given data and samples drawn from the posterior.
+
+ .. code-block:: python
+
+ from deepdiagnostics.plots import CDFParityPlot
+
+ CDFParityPlot(model, data, save=False, show=True)()
+
+ """
+ super().__init__(model, data, run_id, save, show, out_dir, percentiles, use_progress_bar, samples_per_inference, number_simulations, parameter_names, parameter_colors, colorway)
+ self.line_cycle = tuple(get_item("plots_common", "line_style_cycle", raise_exception=False))
+ self.labels_dict = {}
+ self.theory_alpha = 0.2
+
+ def plot_name(self):
+ return "cdf_parity.png"
+
+ def _calculate_theory_cdf(self, probability:float, num_bins:int=100) -> tuple[np.array, np.array]:
+ """
+ Calculate the theoretical limits for the CDF of `distribution` with the percentile `probability`.
+ Assumes the distribution is a binomial distribution
+ """
+
+ n_dims = self.data.n_dims
+ bounds = np.zeros((num_bins, 2, n_dims))
+ cdf = np.zeros((num_bins, n_dims))
+
+ for dim in range(n_dims):
+
+ # Construct uniform histogram.
+ uni_bins = binom(self.samples_per_inference, p=1 / num_bins).ppf(0.5) * np.ones(num_bins)
+ uni_bins_cdf = uni_bins.cumsum() / uni_bins.sum()
+ # Decrease the final value (which equals one) by epsilon to find valid
+ # confidence intervals.
+ uni_bins_cdf[-1] -= 1e-9
+ lower = [binom(self.samples_per_inference, p=p).ppf(1-probability) for p in uni_bins_cdf]
+ upper = [binom(self.samples_per_inference, p=p).ppf(probability) for p in uni_bins_cdf]
+
+ bounds[:, 0, dim] = lower/np.max(lower)
+ bounds[:, 1, dim] = upper/np.max(upper)
+
+ cdf[:, dim] = uni_bins_cdf
+
+ return bounds, cdf
+
+ def _data_setup(self, num_bins:int=100, **kwargs) -> DataDisplay:
+ if all([p >= 1 for p in self.percentiles]):
+ percentiles = [p/100 for p in self.percentiles]
+ else:
+ percentiles = self.percentiles or [0.95]
+
+ n_dims = self.data.n_dims
+ theta_true = self.data.thetas
+ posterior_samples = np.zeros(
+ (self.number_simulations, self.samples_per_inference, n_dims)
+ )
+ thetas = np.zeros((self.number_simulations, self.samples_per_inference, n_dims))
+
+
+ for n in range(self.number_simulations):
+ sample_index = self.data.rng.integers(0, len(theta_true))
+
+ theta = theta_true[sample_index, :]
+ x = self.data.context[sample_index, :]
+ posterior_samples[n] = self.model.sample_posterior(
+ self.samples_per_inference, x
+ )
+
+ thetas[n] = np.array([theta for _ in range(self.samples_per_inference)])
+
+ thetas = thetas.reshape(
+ (self.number_simulations * self.samples_per_inference, n_dims)
+ )
+
+ calculated_ecdf = {}
+ theory_cdf = {}
+ # Sample quantiles are based on the first dimension.
+ # Not always perfect, but it ensures that the quantiles are consistent across all dimensions - required for the residuals
+ ecdf_sample = ecdf(posterior_samples[:, 0].ravel())
+
+ all_bands = {}
+ for interval in percentiles:
+ bands, cdf = self._calculate_theory_cdf(interval, num_bins)
+ all_bands[interval] = bands
+
+ for dim, name in zip(range(n_dims), self.parameter_names):
+ parameter_quantiles = np.linspace(
+ np.min(thetas[:, dim]),
+ np.max(thetas[:, dim]),
+ num=num_bins
+ )
+ ecdf_sample = ecdf(posterior_samples[:, dim].ravel())
+ sample_probs_common = ecdf_sample.cdf.evaluate(parameter_quantiles)
+ for interval in percentiles:
+ all_bands[f"low_theory_probability_{interval}_{name}"] = all_bands[interval][:, 0, dim]
+ all_bands[f"high_theory_probability_{interval}_{name}"] = all_bands[interval][:, 1, dim]
+
+ theory_cdf[f"theory_probability_{name}"] = cdf[:, dim]
+ calculated_ecdf[f"quantiles_{name}"] = parameter_quantiles
+ calculated_ecdf[f"sample_probability_{name}"] = sample_probs_common
+
+ display_data = DataDisplay({
+ **calculated_ecdf,
+ **all_bands,
+ **theory_cdf, # The CDF isn't calculated differently per percentile; it's fine to use the last one
+ "percentiles": np.array(percentiles),
+ })
+ return display_data
+
+ def _plot_base_plot(self, data_display, ax, parameter_name, sample_label, line_style, color, theory_color, theory_line_style):
+ "Plot the theoretical CDF line and the ECDF of the posterior samples"
+ ax.plot(
+ data_display[f"quantiles_{parameter_name}"],
+ data_display[f"theory_probability_{parameter_name}"],
+ color=theory_color,
+ linestyle=theory_line_style
+ )
+
+ ax.plot(
+ data_display[f"quantiles_{parameter_name}"],
+ data_display[f"sample_probability_{parameter_name}"],
+ ls=line_style,
+ color=color,
+ )
+
+ def _plot_theory_intervals(self, data_display, ax, parameter_name, theory_label, color, interval):
+ lower, upper = (
+ data_display[f"low_theory_probability_{interval}_{parameter_name}"],
+ data_display[f"high_theory_probability_{interval}_{parameter_name}"]
+ )
+
+ ax.fill_between(
+ data_display[f"quantiles_{parameter_name}"],
+ lower,
+ upper,
+ alpha=self.theory_alpha,
+ color=color
+ )
+
+ def _plot_theory_intervals_residual(self, data_display, ax, parameter_name, theory_label, color, interval):
+ if data_display[f"low_theory_probability_{interval}_{parameter_name}"] is None:
+ bound_low, bound_high = self._compute_intervals(
+ data_display[f"theory_probability_{parameter_name}"],
+ interval,
+ self.parameter_names.index(parameter_name)
+ )
+ low = bound_low - data_display[f"theory_probability_{parameter_name}"]
+ high = bound_high - data_display[f"theory_probability_{parameter_name}"]
+ else:
+ # Use the precomputed values for the fill-between
+ low = data_display[f"low_theory_probability_{interval}_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"]
+ high = data_display[f"high_theory_probability_{interval}_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"]
+
+ ax.fill_between(
+ data_display[f"quantiles_{parameter_name}"],
+ low,
+ high,
+ alpha=self.theory_alpha,
+ color=color
+ )
+
+ def _compute_intervals(self, cdf: np.ndarray, probability: float, dimension:int) -> tuple[np.ndarray, np.ndarray]:
+ "Use the Dvoretzky-Kiefer-Wolfowitz confidence bands as an approximation for plotting purposes."
+
+ bound = np.sqrt(np.log(2.0 / (1 - probability)) / (2.0 * float(cdf.shape[0])))
+ lower = cdf[:, dimension] - bound
+ upper = cdf[:, dimension] + bound
+ return lower, upper
+
+
+ def plot(
+ self,
+ data_display: Union[DataDisplay, str],
+ include_residuals: bool = False,
+ include_theory_intervals: bool = True,
+ display_parameters_separate: bool = False,
+ x_label: str = "Quantiles",
+ y_label: str = "CDF",
+ title: str = "CDF Parity Plot",
+ samples_label = "Posterior Samples",
+ theory_label = "Theory",
+ theory_color = "gray",
+ theory_line_style = "--",
+ normalize_view: bool = True,
+ theory_alpha: float = 0.2,
+ **kwargs
+ ) -> tuple["fig", "ax"]:
+ """
+ Compute the ECDF of posterior samples against the true parameter values.
+ Uses [scipy.stats.ecdf](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ecdf.html) to compute the distribution of the sampled posterior samples.
+
+ To show all the distributions on one plot, set `display_parameters_separate` to `False` and verify `normalize_view` is set to `True`; this ensures the x & y axes are normalized to [0, 1] for all parameters.
+
+ Args:
+ data_display (DataDisplay or str): The data to plot. If a string, it is assumed to be the path to an HDF5 file.
+ include_residuals (bool): Whether to include the residuals between the theory and sample distributions.
+ include_theory_intervals (bool): Whether to include the theory intervals (percentiles given in the 'percentiles' field of the config) in the plot
+ display_parameters_separate (bool): Whether to display each parameter in a separate subplot.
+ x_label (str): Label for the x-axis.
+ y_label (str): Label for the y-axis.
+ title (str): Title of the plot.
+ samples_label (str): Label for the samples in the plot.
+ theory_label (str): Label for the theory in the plot.
+ theory_color (str): Color for the center theory line.
+ theory_line_style (str): Line style for center theory line.
+ normalize_view (bool): Whether to normalize the x axis of the plot to [0, 1] for all parameters.
+ theory_alpha (float): Alpha (transparency) value for the fill between the theory intervals. Between 0 and 1.
+ """
+
+ self.theory_alpha = theory_alpha
+ color_cycler = iter(plt.cycler("color", self.parameter_colors))
+ line_style_cycler = iter(plt.cycler("line_style", self.line_cycle))
+
+ if not isinstance(data_display, DataDisplay):
+ data_display = DataDisplay().from_h5(data_display, self.plot_name)
+
+ # Used if theory intervals are included
+ theory_color_cycle = self._get_hex_sigma_colors(len(data_display["percentiles"]))
+
+
+ if include_residuals:
+ row_len = self.figure_size[0] * .8*len(self.parameter_names) if display_parameters_separate else self.figure_size[0]
+ figsize = (row_len, 1.5*self.figure_size[1])
+ fig, ax = plt.subplots(
+ 2, len(self.parameter_names) if display_parameters_separate else 1,
+ figsize=figsize,
+ height_ratios=[3, 1],
+ sharex='col',
+ sharey='row'
+ )
+ plt.subplots_adjust(hspace=0.01)
+
+ else:
+ row_len = self.figure_size[0] * .8*len(self.parameter_names) if display_parameters_separate else self.figure_size[0]
+ fig, ax = plt.subplots(
+ 1, len(self.parameter_names) if display_parameters_separate else 1,
+ figsize=(row_len, self.figure_size[1]),
+ sharey='row')
+
+ if normalize_view:
+ for parameter_name in self.parameter_names:
+ data_display[f"quantiles_{parameter_name}"] = np.linspace(0, 1, num=len(data_display[f"quantiles_{parameter_name}"]))
+
+ if include_theory_intervals: # Each plot needs to iterate over the percentiles in the main plot and the residuals
+ if display_parameters_separate:
+ for index, parameter_name in enumerate(self.parameter_names):
+ plot_ax = ax[index] if not include_residuals else ax[0][index]
+ plot_ax.plot(data_display[f"quantiles_{parameter_name}"], data_display[f"theory_probability_{parameter_name}"], color=theory_color, linestyle=theory_line_style)
+
+ plot_ax.set_title(f"{samples_label} {parameter_name}")
+ color = next(color_cycler)["color"]
+ line_style = next(line_style_cycler)["line_style"]
+ self._plot_base_plot(data_display, plot_ax, parameter_name, samples_label, line_style, color, theory_color, theory_line_style)
+
+ for interval_index, interval in enumerate(data_display["percentiles"]):
+ self._plot_theory_intervals(
+ data_display, plot_ax, parameter_name, theory_label,
+ theory_color_cycle[interval_index], interval
+ )
+
+ if include_residuals:
+ self._plot_theory_intervals_residual(
+ data_display, ax[1, index], parameter_name, theory_label,
+ theory_color_cycle[interval_index], interval
+ )
+
+ if include_residuals:
+ # Plot the residuals between the theory and sample
+ residual = data_display[f"sample_probability_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"]
+
+ ax[1, index].plot(
+ data_display[f"quantiles_{parameter_name}"],
+ residual,
+ linestyle=line_style,
+ color=color,
+ )
+ ax[1, index].axhline(0, color=theory_color, linestyle=theory_line_style)
+
+
+ else: # The plot_ax is the same for all parameters
+ plot_ax = ax if not include_residuals else ax[0]
+ for parameter_name in self.parameter_names:
+ color = next(color_cycler)["color"]
+ line_style = next(line_style_cycler)["line_style"]
+ self._plot_base_plot(data_display, plot_ax, parameter_name, samples_label, line_style, color, theory_color, theory_line_style)
+ for interval_index, interval in enumerate(data_display["percentiles"]): # iterate over the percentiles
+ if self.parameter_names.index(parameter_name) == 0: # Only plot the theory intervals once (for the first parameter) since all parameters share one axis
+ self._plot_theory_intervals(
+ data_display, plot_ax, parameter_name, theory_label,
+ theory_color_cycle[interval_index], interval
+ )
+
+ if include_residuals:
+ self._plot_theory_intervals_residual(
+ data_display, ax[1], parameter_name, theory_label,
+ theory_color_cycle[interval_index], interval
+ )
+
+ if include_residuals:
+ # Plot the residuals between the theory and sample
+ residual = data_display[f"sample_probability_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"]
+
+ ax[1].plot(
+ data_display[f"quantiles_{parameter_name}"],
+ residual,
+ linestyle=line_style,
+ color=color,
+ )
+ ax[1].axhline(0, color=theory_color, linestyle=theory_line_style)
+
+
+ else: # Do not include the theory intervals - no fill-betweens here!
+ if display_parameters_separate:
+ for index, parameter_name in enumerate(self.parameter_names):
+ # Each parameter gets its own subplot
+ plot_ax = ax[index] if not include_residuals else ax[0][index]
+ plot_ax.set_title(f"{samples_label} {parameter_name}")
+ color = next(color_cycler)["color"]
+ line_style = next(line_style_cycler)["line_style"]
+ self._plot_base_plot(data_display, plot_ax, parameter_name, samples_label, line_style, color, theory_color, theory_line_style)
+
+ if include_residuals:
+ residual = data_display[f"sample_probability_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"]
+
+ ax[1, index].plot(
+ data_display[f"quantiles_{parameter_name}"],
+ residual,
+ linestyle=line_style,
+ color=color,
+ )
+ ax[1, index].axhline(0, color=theory_color, linestyle=theory_line_style)
+
+ else: # Everything goes on the one column
+ plot_ax = ax if not include_residuals else ax[0]
+ for index, parameter_name in enumerate(self.parameter_names):
+ color = next(color_cycler)["color"]
+ line_style = next(line_style_cycler)["line_style"]
+
+ self._plot_base_plot(data_display, plot_ax, parameter_name, samples_label, line_style, color, theory_color, theory_line_style)
+
+ if include_residuals:
+ residual = data_display[f"sample_probability_{parameter_name}"] - data_display[f"theory_probability_{parameter_name}"]
+ ax[1].plot(
+ data_display[f"quantiles_{parameter_name}"],
+ residual,
+ linestyle=line_style,
+ color=color,
+ )
+ ax[1].axhline(0, color=theory_color, linestyle=theory_line_style)
+
+
+ handles = [
+ plt.Line2D([0], [0], color=theory_color, linestyle=theory_line_style, label=theory_label)
+ ]
+ if include_theory_intervals:
+ handles += [
+ mpatches.Rectangle((0, 0), 0, 0, facecolor=theory_color_cycle[i], alpha=self.theory_alpha, edgecolor='none', label=f"CDF {int(data_display['percentiles'][i]*100)}% CI {theory_label}")
+ for i in range(len(data_display["percentiles"]))
+ ]
+
+ if not display_parameters_separate:
+ # reset the color and line style cyclers for the handles
+ color_cycler = iter(plt.cycler("color", self.parameter_colors))
+ line_style_cycler = iter(plt.cycler("line_style", self.line_cycle))
+ handles += [
+ plt.Line2D([0], [0], color=color['color'], linestyle=line_style['line_style'], label=f"{samples_label} {parameter_name}")
+ for parameter_name, color, line_style in zip(self.parameter_names, color_cycler, line_style_cycler)
+ ]
+
+ fig.legend(handles=handles)
+ fig.suptitle(title)
+ fig.supxlabel(x_label)
+ fig.supylabel(y_label)
+
+ return fig, ax
diff --git a/src/deepdiagnostics/plots/cdf_ranks.py b/src/deepdiagnostics/plots/cdf_ranks.py
index 19c2da6..99d1bb2 100644
--- a/src/deepdiagnostics/plots/cdf_ranks.py
+++ b/src/deepdiagnostics/plots/cdf_ranks.py
@@ -46,8 +46,8 @@ def plot_name(self):
return "cdf_ranks.png"
def _data_setup(self) -> DataDisplay:
- thetas = tensor(self.data.get_theta_true())
- context = tensor(self.data.true_context())
+ thetas = tensor(self.data.thetas)
+ context = tensor(self.data.simulator_outcome)
ranks, _ = run_sbc(
thetas, context, self.model.posterior, num_posterior_samples=self.samples_per_inference
diff --git a/src/deepdiagnostics/plots/parity.py b/src/deepdiagnostics/plots/parity.py
index db4c7ba..67d7a32 100644
--- a/src/deepdiagnostics/plots/parity.py
+++ b/src/deepdiagnostics/plots/parity.py
@@ -48,7 +48,7 @@ def plot_name(self):
def _data_setup(self, n_samples: int = 80, **kwargs) -> DataDisplay:
- context_shape = self.data.true_context().shape
+ context_shape = self.data.simulator_outcome.shape
posterior_sample_mean = np.zeros((n_samples, self.data.n_dims))
posterior_sample_std = np.zeros_like(posterior_sample_mean)
true_samples = np.zeros_like(posterior_sample_mean)
@@ -56,11 +56,11 @@ def _data_setup(self, n_samples: int = 80, **kwargs) -> DataDisplay:
random_context_indices = self.data.rng.integers(0, context_shape[0], n_samples)
for index, sample in enumerate(random_context_indices):
- posterior_sample = self.model.sample_posterior(self.samples_per_inference, self.data.true_context()[sample, :]).numpy()
+ posterior_sample = self.model.sample_posterior(self.samples_per_inference, self.data.simulator_outcome[sample, :]).numpy()
posterior_sample_mean[index] = np.mean(posterior_sample, axis=0)
posterior_sample_std[index] = np.std(posterior_sample, axis=0)
- true_samples[index] = self.data.get_theta_true()[sample, :]
+ true_samples[index] = self.data.thetas[sample, :]
return DataDisplay(
n_dims=self.data.n_dims,
diff --git a/src/deepdiagnostics/plots/plot.py b/src/deepdiagnostics/plots/plot.py
index 71b7c82..71fe989 100644
--- a/src/deepdiagnostics/plots/plot.py
+++ b/src/deepdiagnostics/plots/plot.py
@@ -4,6 +4,8 @@
from typing import Optional, Sequence, TYPE_CHECKING, Union
import matplotlib.pyplot as plt
from matplotlib import rcParams
+import matplotlib.colors as plt_colors
+import numpy as np
from deepdiagnostics.utils.config import get_item
from deepdiagnostics.utils.utils import DataDisplay
@@ -89,6 +91,17 @@ def _data_setup(self, **kwargs) -> Optional[DataDisplay]:
"Return all the data required for plotting"
raise NotImplementedError
+ def _get_hex_sigma_colors(self, n_colors):
+
+ cmap = plt.get_cmap(self.colorway)
+ hex_colors = []
+ arr = np.linspace(0, 1, n_colors)
+ for hit in arr:
+ hex_colors.append(plt_colors.rgb2hex(cmap(hit)))
+
+ return hex_colors
+
+
@abstractmethod
def plot(self, data_display: Union[dict, "data_display"], **kwrgs) -> tuple["figure", "axes"]:
"""
diff --git a/src/deepdiagnostics/plots/predictive_posterior_check.py b/src/deepdiagnostics/plots/predictive_posterior_check.py
index 342e287..4df4614 100644
--- a/src/deepdiagnostics/plots/predictive_posterior_check.py
+++ b/src/deepdiagnostics/plots/predictive_posterior_check.py
@@ -51,69 +51,60 @@ def plot_name(self):
return "predictive_posterior_check.png"
def _get_posterior_2d(self, n_simulator_draws):
- context_shape = self.data.true_context().shape
- sim_out_shape = self.data.get_simulator_output_shape()
+ sim_out_shape = self.data.simulator(
+ theta=self.data.thetas[0].unsqueeze(0),
+ n_samples=1
+ )[0].shape
+
remove_first_dim = False
- if len(sim_out_shape) != 2:
+ if len(sim_out_shape) > 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True
posterior_predictive_samples = np.zeros((n_simulator_draws, *sim_out_shape))
- posterior_true_samples = np.zeros_like(posterior_predictive_samples)
- random_context_indices = self.data.rng.integers(0, context_shape[0], n_simulator_draws)
- for index, sample in enumerate(random_context_indices):
- context_sample = self.data.true_context()[sample, :]
- posterior_sample = self.model.sample_posterior(1, context_sample)
+ random_context_indices = self.data.rng.integers(0, self.data.simulator_outcome.shape[0], n_simulator_draws)
+ simulator_true = self.data.simulator_outcome[random_context_indices, :].numpy()
+ for index, sample in enumerate(simulator_true):
+ posterior_sample = self.model.sample_posterior(1, sample)
# get the posterior samples for that context
- sim_out_posterior = self.data.simulator.simulate(
- theta=posterior_sample, context_samples = context_sample
- )
- sim_out_true = self.data.simulator.simulate(
- theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
+ sim_out_posterior = self.data.simulator(n_samples=1,
+ theta=posterior_sample
)
if remove_first_dim:
sim_out_posterior = sim_out_posterior[0]
- sim_out_true = sim_out_true[0]
posterior_predictive_samples[index] = sim_out_posterior
- posterior_true_samples[index] = sim_out_true
- return posterior_predictive_samples, posterior_true_samples
+ return posterior_predictive_samples, simulator_true
def _get_posterior_1d(self, n_simulator_draws):
- context_shape = self.data.true_context().shape
- posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, context_shape[-1]))
- posterior_true_samples = np.zeros_like(posterior_predictive_samples)
- context = np.zeros((n_simulator_draws, context_shape[-1]))
+ simulator_outcome_shape = self.data.simulator_dimensions
- random_context_indices = self.data.rng.integers(0, context_shape[0], n_simulator_draws)
- for index, sample in enumerate(random_context_indices):
- context_sample = self.data.true_context()[sample, :]
- context[index] = context_sample
+ posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, simulator_outcome_shape))
- posterior_sample = self.model.sample_posterior(self.samples_per_inference, context_sample)
+ # Sample one random simulator output for each draw
+ random_context_indices = self.data.rng.integers(0, self.data.simulator_outcome.shape[0], n_simulator_draws)
+ simulator_samples = self.data.simulator_outcome[random_context_indices, :].numpy()
+ posterior_predictive_samples = np.zeros((n_simulator_draws, self.samples_per_inference, *simulator_samples[0].shape))
- # get the posterior samples for that context
- posterior_predictive_samples[index] = self.data.simulator.simulate(
- theta=posterior_sample, context_samples = context_sample
- )
- posterior_true_samples[index] = self.data.simulator.simulate(
- theta=self.data.get_theta_true()[sample, :], context_samples=context_sample
+ for index, sample in enumerate(simulator_samples):
+ posterior_sample = self.model.sample_posterior(self.samples_per_inference, sample)
+ posterior_predictive_samples[index] = self.data.simulator(n_samples=len(sample),
+ theta=posterior_sample
)
- return posterior_predictive_samples, posterior_true_samples, context
+
+ return posterior_predictive_samples, simulator_samples
def _data_setup(self, n_unique_plots: Optional[int] = 3, **kwargs) -> DataDisplay:
+ true_sigma = None
if self.data.simulator_dimensions == 1:
- n_dims = 1
-
- posterior_predictive_samples, posterior_true_samples, context = self._get_posterior_1d(n_unique_plots)
+ posterior_predictive_samples, posterior_true_samples = self._get_posterior_1d(n_unique_plots)
true_sigma = self.data.get_sigma_true()
elif self.data.simulator_dimensions == 2:
- n_dims = 2
posterior_predictive_samples, posterior_true_samples = self._get_posterior_2d(n_unique_plots)
else:
@@ -122,10 +113,9 @@ def _data_setup(self, n_unique_plots: Optional[int] = 3, **kwargs) -> DataDispla
return DataDisplay(
n_unique_plots=n_unique_plots,
posterior_predictive_samples=posterior_predictive_samples,
- n_dims=n_dims,
+ n_dims=self.data.simulator_dimensions,
posterior_true_samples=posterior_true_samples,
- context=context if n_dims == 1 else None,
- true_sigma=true_sigma if n_dims == 1 else None
+ true_sigma=true_sigma
)
@@ -167,9 +157,6 @@ def plot(
for plot_index in range(data_display.n_unique_plots):
if data_display.n_dims == 1:
- if data_display.context is None:
- raise ValueError("Display Data is malformed. Missing `context` for 1D simulation. Please rerun data setup stage.")
-
dimension_y_simulation = data_display.posterior_predictive_samples[plot_index]
y_simulation_mean = np.mean(dimension_y_simulation, axis=0).ravel()
@@ -177,7 +164,7 @@ def plot(
for sigma, color in zip(range(n_coverage_sigma), self.colors):
subplots[0, plot_index].fill_between(
- data_display.context[plot_index].ravel(),
+ range(len(y_simulation_mean)),
y_simulation_mean - sigma * y_simulation_std,
y_simulation_mean + sigma * y_simulation_std,
color=color,
@@ -186,22 +173,22 @@ def plot(
)
subplots[0, plot_index].plot(
- data_display.context[plot_index],
+ range(len(y_simulation_mean)),
y_simulation_mean - data_display.true_sigma,
color="black",
linestyle="dashdot",
label="True Input Error"
)
subplots[0, plot_index].plot(
- data_display.context[plot_index],
+ range(len(y_simulation_mean)),
y_simulation_mean + data_display.true_sigma,
color="black",
linestyle="dashdot",
)
- true_y = np.mean(data_display.posterior_true_samples[plot_index, :, :], axis=0).ravel()
+ true_y = data_display.posterior_true_samples[plot_index, :].ravel()
subplots[1, plot_index].scatter(
- data_display.context[plot_index],
+ range(len(true_y)),
true_y,
marker=theta_true_marker,
label='Theta True'
diff --git a/src/deepdiagnostics/plots/predictive_prior_check.py b/src/deepdiagnostics/plots/predictive_prior_check.py
index 3dabc78..a124cf6 100644
--- a/src/deepdiagnostics/plots/predictive_prior_check.py
+++ b/src/deepdiagnostics/plots/predictive_prior_check.py
@@ -54,7 +54,7 @@ def plot_name(self):
return "predictive_prior_check.png"
def _data_setup(self, n_rows: int = 3, n_columns: int = 3, **kwargs) -> DataDisplay:
- context_shape = self.data.true_context().shape
+ sim_out_shape = self.data.simulator_outcome[0].shape
remove_first_dim = False
if self.data.simulator_dimensions == 1:
@@ -65,45 +65,37 @@ def _data_setup(self, n_rows: int = 3, n_columns: int = 3, **kwargs) -> DataDisp
if plot_image:
- sim_out_shape = self.data.get_simulator_output_shape()
if len(sim_out_shape) != 2:
# TODO Debug log with a warning
sim_out_shape = (sim_out_shape[1], sim_out_shape[2])
remove_first_dim = True
- prior_predictive_samples = np.zeros((n_rows, n_columns, *sim_out_shape))
-
- else:
- prior_predictive_samples = np.zeros((n_rows, n_columns, context_shape[-1]))
+ prior_predictive_samples = np.zeros((n_rows, n_columns, *sim_out_shape))
prior_true_sample = np.zeros((n_rows, n_columns, self.data.n_dims))
- context = np.zeros((n_rows, n_columns, context_shape[-1]))
- random_context_indices = self.data.rng.integers(0, context_shape[0], (n_rows, n_columns))
+
+ print(prior_predictive_samples.shape)
+ print(prior_true_sample.shape)
for row_index in range(n_rows):
for column_index in range(n_columns):
- sample = random_context_indices[row_index, column_index]
- context_sample = self.data.true_context()[sample, :]
-
prior_sample = self.data.sample_prior(1)[0]
# get the posterior samples for that context
- simulation_sample = self.data.simulator.simulate(
- theta=prior_sample, context_samples = context_sample
+ simulation_sample = self.data.simulator(
+ theta=prior_sample, n_samples=np.prod(sim_out_shape)
)
if remove_first_dim:
simulation_sample = simulation_sample[0]
prior_predictive_samples[row_index, column_index] = simulation_sample
prior_true_sample[row_index, column_index] = prior_sample
- context[row_index, column_index] = context_sample
return DataDisplay(
plot_image=plot_image,
n_rows=n_rows,
n_columns=n_columns,
prior_predictive_samples=prior_predictive_samples,
- prior_true_sample=prior_true_sample,
- context=context)
+ prior_true_sample=prior_true_sample)
def plot(
self,
@@ -192,7 +184,7 @@ def plot(
else:
subplots[plot_row_index, plot_column_index].plot(
- data_display.context[column_index, row_index],
+ range(len(data_display.prior_predictive_samples[column_index, row_index])),
data_display.prior_predictive_samples[column_index, row_index]
)
diff --git a/src/deepdiagnostics/plots/ranks.py b/src/deepdiagnostics/plots/ranks.py
index b592aa7..8ff86af 100644
--- a/src/deepdiagnostics/plots/ranks.py
+++ b/src/deepdiagnostics/plots/ranks.py
@@ -46,8 +46,8 @@ def plot_name(self):
return "ranks.png"
def _data_setup(self, **kwargs) -> DataDisplay:
- thetas = tensor(self.data.get_theta_true())
- context = tensor(self.data.true_context())
+ thetas = tensor(self.data.thetas)
+ context = tensor(self.data.simulator_outcome)
ranks, _ = run_sbc(
thetas, context, self.model.posterior, num_posterior_samples=self.samples_per_inference
)
diff --git a/src/deepdiagnostics/plots/tarp.py b/src/deepdiagnostics/plots/tarp.py
index 94abfbc..8e6c3ef 100644
--- a/src/deepdiagnostics/plots/tarp.py
+++ b/src/deepdiagnostics/plots/tarp.py
@@ -4,7 +4,6 @@
import tarp
import matplotlib.pyplot as plt
-import matplotlib.colors as plt_colors
from matplotlib.axes import Axes as ax
from matplotlib.figure import Figure as fig
@@ -52,7 +51,7 @@ def plot_name(self):
return "tarp.png"
def _data_setup(self, **kwargs) -> DataDisplay:
- self.theta_true = self.data.get_theta_true()
+ self.theta_true = self.data.thetas.numpy()
n_dims = self.theta_true.shape[1]
posterior_samples = np.zeros(
(self.number_simulations, self.samples_per_inference, n_dims)
@@ -62,7 +61,7 @@ def _data_setup(self, **kwargs) -> DataDisplay:
sample_index = self.data.rng.integers(0, len(self.theta_true))
theta = self.theta_true[sample_index, :]
- x = self.data.true_context()[sample_index, :]
+ x = self.data.simulator_outcome[sample_index, :]
posterior_samples[n] = self.model.sample_posterior(
self.samples_per_inference, x
)
@@ -78,16 +77,6 @@ def plot_settings(self):
"plots_common", "line_style_cycle", raise_exception=False
)
- def _get_hex_sigma_colors(self, n_colors):
-
- cmap = plt.get_cmap(self.colorway)
- hex_colors = []
- arr = np.linspace(0, 1, n_colors)
- for hit in arr:
- hex_colors.append(plt_colors.rgb2hex(cmap(hit)))
-
- return hex_colors
-
def plot(
self,
data_display: Union[DataDisplay, dict] = None,
diff --git a/src/deepdiagnostics/plots/under_cdf_parity.py b/src/deepdiagnostics/plots/under_cdf_parity.py
new file mode 100644
index 0000000..9fa82fe
--- /dev/null
+++ b/src/deepdiagnostics/plots/under_cdf_parity.py
@@ -0,0 +1,10 @@
+from typing import Union, TYPE_CHECKING
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import patches as mpatches
+from scipy.stats import ecdf
+
+from deepdiagnostics.plots.plot import Display
+from deepdiagnostics.utils.config import get_item
+from deepdiagnostics.utils.utils import DataDisplay
diff --git a/src/deepdiagnostics/utils/defaults.py b/src/deepdiagnostics/utils/defaults.py
index 2412a55..d82f619 100644
--- a/src/deepdiagnostics/utils/defaults.py
+++ b/src/deepdiagnostics/utils/defaults.py
@@ -33,7 +33,8 @@
"LC2ST": {},
"Parity":{},
"PPC": {},
- "PriorPC":{}
+ "PriorPC":{},
+ "CDFParityPlot": {}
},
"metrics_common": {
"use_progress_bar": False,
diff --git a/tests/conftest.py b/tests/conftest.py
index 70874d6..ee4af93 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -57,7 +57,7 @@ def simulate(self, theta, context_samples: np.ndarray):
for sample_index, t in enumerate(theta):
mock_data = np.random.normal(
- loc=t[0], scale=abs(t[1]), size=(len(context_samples), 2)
+ loc=t[0], scale=abs(t[1]), size=(len(context_samples), len(context_samples))
)
generated_stars.append(
np.column_stack((context_samples, mock_data))
@@ -84,7 +84,7 @@ def model_path():
@pytest.fixture
def data_path():
- return "resources/saveddata/data_validation.h5"
+ return "resources/saveddata/data_test.h5"
@pytest.fixture
def result_output():
diff --git a/tests/test_client.py b/tests/test_client.py
index 78a955d..e4f8980 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -81,6 +81,5 @@ def test_missing_simulator(model_path, data_path):
process = subprocess.run(command, capture_output=True)
exit_code = process.returncode
stdout = process.stdout.decode("utf-8")
- assert exit_code == 0
- plot_name = "PPC"
- assert f"Cannot run {plot_name} - simulator missing." in stdout
+ assert exit_code == 0, process.stderr.decode("utf-8")
+ assert "Warning: Simulator not loaded. Using a lookup table simulator." in stdout
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 96ef9cf..428307d 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -5,7 +5,8 @@
from deepdiagnostics.metrics import (
CoverageFraction,
AllSBC,
- LC2ST
+ LC2ST,
+ CDFParityAreaUnderCurve
)
@pytest.fixture
@@ -44,3 +45,11 @@ def test_lc2st(metric_config, mock_model, mock_data, mock_run_id):
assert lc2st.output is not None
assert os.path.exists(f"{lc2st.out_dir}/{mock_run_id}_diagnostic_metrics.json")
+def test_cdf_parity_area_under_curve(metric_config, mock_model, mock_data, mock_run_id):
+ Config(metric_config)
+ areaunderecdf = CDFParityAreaUnderCurve(mock_model, mock_data, mock_run_id, save=True)
+ areaunderecdf()
+ assert areaunderecdf.output is not None
+ #print(areaunderecdf.output) #used for testing to check output
+ #assert False #testing
+ assert os.path.exists(f"{areaunderecdf.out_dir}/{mock_run_id}_diagnostic_metrics.json")
diff --git a/tests/test_plots.py b/tests/test_plots.py
index dfcda61..1d043cd 100644
--- a/tests/test_plots.py
+++ b/tests/test_plots.py
@@ -20,7 +20,7 @@ def plot_config(config_factory):
metrics_settings = {
"use_progress_bar": False,
"samples_per_inference": 10,
- "percentiles": [95, 75, 50],
+ "percentiles": [68, 95],
}
config = config_factory(metrics_settings=metrics_settings)
return config
@@ -71,6 +71,10 @@ def test_ppc(plot_config, mock_model, mock_data, mock_2d_data, result_output, mo
plot(**get_item("plots", "PPC", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+@pytest.mark.xfail(reason="2D dataset needs to be generated - the simulator will not run automatically.")
+def test_ppc_2d(plot_config, mock_model, mock_2d_data, result_output, mock_run_id):
+ Config(plot_config)
plot = PPC(
mock_model,
mock_2d_data, mock_run_id, save=True, show=False,
@@ -79,11 +83,16 @@ def test_ppc(plot_config, mock_model, mock_data, mock_2d_data, result_output, mo
plot(**get_item("plots", "PPC", raise_exception=False))
-def test_prior_pc(plot_config, mock_model, mock_2d_data, mock_data, mock_run_id, result_output):
+def test_prior_pc(plot_config, mock_model, mock_data, mock_run_id):
Config(plot_config)
plot = PriorPC(mock_model, mock_data, mock_run_id, save=True, show=False)
plot(**get_item("plots", "PriorPC", raise_exception=False))
assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
+
+
+@pytest.mark.xfail(reason="2D dataset needs to be generated - the simulator will not run automatically.")
+def test_prior_pc_2d(plot_config, mock_model, mock_2d_data, result_output, mock_run_id):
+ Config(plot_config)
plot = PriorPC(
mock_model,
mock_2d_data, mock_run_id, save=True, show=False,
@@ -135,3 +144,9 @@ def test_rerun_plot(plot_type, plot_config, mock_model, mock_data, mock_run_id):
assert os.path.exists(f"{plot.out_dir}rerun_plot.png")
+def test_cdf_parity(plot_config, mock_model, mock_data, mock_run_id):
+ from deepdiagnostics.plots import CDFParityPlot
+ Config(plot_config)
+ plot = CDFParityPlot(mock_model, mock_data, mock_run_id, save=True, show=False)
+ plot(**get_item("plots", "CDFParityPlot", raise_exception=False))
+ assert os.path.exists(f"{plot.out_dir}/{mock_run_id}_{plot.plot_name}")
\ No newline at end of file
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 711112b..a6e2286 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -8,9 +8,9 @@ class TestLookupTableSimulator:
def test_lookup_table_simulator():
# fake data
data = {
- "xs": torch.Tensor([[0.1], [0.2], [0.3], [0.4]]), # context
+ "context": torch.Tensor([[0.1], [0.2], [0.3], [0.4]]), # context
"thetas": torch.Tensor([[1.0], [2.0], [3.0], [4.0]]), # parameters
- "ys": torch.Tensor([[10.0], [20.0], [30.0], [40.0]]), # outcomes
+ "simulator_outcome": torch.Tensor([[10.0], [20.0], [30.0], [40.0]]), # outcomes
}
rng = np.random.default_rng(42)
sim = LookupTableSimulator(data, rng)
@@ -20,11 +20,11 @@ def test_lookup_table_simulator():
# Test generate_context
contexts = sim.generate_context(2)
assert contexts.shape == (2, 1)
- assert all(context in data["xs"].tolist() for context in contexts) # Only getting contexts from data
+ assert all(context in data["context"].tolist() for context in contexts) # Only getting contexts from data
# Test exact match outcome
- theta = torch.Tensor([[2.0]])
- context = torch.Tensor([[0.2]])
+ theta = torch.Tensor([2.0])
+ context = torch.Tensor([0.2])
outcome = sim.simulate(theta, context)
assert outcome.shape == (1, 1)
assert outcome[0] == 20.0
@@ -48,9 +48,9 @@ def test_lookup_table_simulator_multidim_params():
rng = np.random.default_rng(42)
data = {
- "xs": torch.tensor(rng.random((10, 2))), # context
+ "context": torch.tensor(rng.random((10, 2))), # context
"thetas": torch.tensor(rng.random((10, 3))), # parameters
- "ys": torch.tensor(rng.random((10, 1))), # outcomes
+ "simulator_outcome": torch.tensor(rng.random((10, 1))), # outcomes
}
rng = np.random.default_rng(42)
@@ -58,11 +58,11 @@ def test_lookup_table_simulator_multidim_params():
# Test exact match outcome
theta = data["thetas"][:2, :]
- context = data["xs"][:2, :]
+ context = data["context"][:2, :]
outcome = sim.simulate(theta, context)
assert outcome.shape == (2, 1)
- assert outcome[0] == data["ys"][0]
- assert outcome[1] == data["ys"][1]
+ assert outcome[0] == data["simulator_outcome"][0]
+ assert outcome[1] == data["simulator_outcome"][1]
# Test nearest neighbor outcome
theta = rng.random((1, 3))