Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
6ee6718
Adding the under the ecdf parity code to git
scarletnorberg Sep 12, 2025
a76abbe
Really adding the under the curve cdf parity to the correct place
scarletnorberg Sep 12, 2025
9e0c097
Adding in the under the curve metric to the test metric code
scarletnorberg Sep 12, 2025
c04ffe1
Corrected names of simulator output, thetas
voetberg Aug 6, 2025
662126b
Minimal example config
voetberg Jul 24, 2025
32bb8ad
Include a CDF parity plot by using scipy's ecdf #34
voetberg Jul 16, 2025
0bb14c1
Remove hspace between plots
voetberg Jul 17, 2025
2428e6f
Put all parameters on the same axis
voetberg Jul 17, 2025
04d740d
Change CDF calculation
voetberg Jul 22, 2025
6811f58
Correction to pdf calculation assuming uniform theorical pdf
voetberg Aug 19, 2025
aa99abd
Additional documentation for ecdf
voetberg Aug 21, 2025
6c018e9
Minor debug after rebase
voetberg Aug 21, 2025
baf4b7a
Add save feature for models. Misc typo and format fixes (via pre-comm…
prasanthcakewalk Aug 27, 2025
27b9a78
Fix sbi posterior save feature
prasanthcakewalk Aug 28, 2025
f12bdce
Rename save_posterior method
prasanthcakewalk Aug 28, 2025
36b5bf8
Adding the under the cdf function
scarletnorberg Sep 12, 2025
8c7c7ea
Adding changes that have posterior
scarletnorberg Sep 12, 2025
0722d84
Fixed a varaible definition issue
scarletnorberg Sep 12, 2025
aabbb72
Adding in n D for area under the cdf curve
scarletnorberg Sep 12, 2025
b7e0147
Removing AMC was used in an older iteration
scarletnorberg Sep 18, 2025
8f7f47e
Changing this to include as the function in import
scarletnorberg Sep 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,7 @@ Plots
.. autoclass:: deepdiagnostics.plots.Parity
:members: plot

.. autoclass:: deepdiagnostics.plots.CDFParity
:members: plot

.. bibliography::
70 changes: 66 additions & 4 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,76 @@ Installation

git clone https://github.com/deepskies/DeepDiagnostics/
pip install poetry
poetry shell
poetry install
poetry install
poetry run diagnose --help


Pre-requisites
--------------

DeepDiagnostics does not train models or generate data; both must be provided.
Possible model formats are listed in :ref:`models` and data formats in :ref:`data`.
If you are using a simulator, it must be registered by using `deepdiagnostics.utils.register.register_simulator`.
More information can be found in :ref:`custom_simulations`.

Output directories are automatically created, and if a run ID is not specified, one is generated.
Only if a run ID is specified will previous runs be overwritten.

Configuration
----
-------------

Description of the configuration file, including defaults, can be found in :ref:`configuration`.
Below is a minimal example.

.. code-block:: yaml

common:
out_dir: "./deepdiagnostics_results/"
random_seed: 42
data:
data_engine: "H5Data"
data_path: "./data/my_data.h5"
simulator: "MySimulator"
simulator_kwargs: # Any arguments used to initialize the simulator
foo: bar
model:
model_engine: "SBIModel"
model_path: "./models/my_model.pkl"
plots_common: # Used across all plots
parameter_labels: # Can either be plain strings or rendered LaTeX strings
- "My favorite parameter"
- "My least favorite parameter"
- "My most mid parameter"
parameter_colors: # Any color recognized by matplotlib
- "#264a95"
- "#ed9561"
- "#89b7bb"
line_style_cycle: # Any line type recognized by matplotlib
- solid
- dashed
- dotted
figure_size: # Approximate size, it can be scaled when adding additional subfigures
- 6 # x length
- 6 # y length
metrics_common: # Used across all metrics (and plots if the plots have a calculation step)
samples_per_inference: 1000
number_simulations: 100
percentiles:
- 68
- 95
plots:
CoverageFraction: # Arguments supplied to {plottype}.plot()
include_coverage_std: True
include_ideal_range: True
reference_line_label: "Ideal Coverage"
TARP:
coverage_sigma: 4
title: "TARP of My Model"
metrics:
AllSBC:
Ranks:
num_bins: 3

Description of the configuration file, including defaults, can be found in :ref:`configuration`.

Pipeline
---------
Expand Down
Binary file added resources/saveddata/data_test.h5
Binary file not shown.
123 changes: 34 additions & 89 deletions src/deepdiagnostics/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,99 +46,32 @@ def __init__(
msg = f"Could not load the lookup table simulator - {e}. You cannot use generative diagnostics."
print(msg)

self.context = self._context()
self.thetas = self._thetas()

self.prior_dist = self.load_prior(prior, prior_kwargs)
self.n_dims = self.get_theta_true().shape[1]
self.n_dims = self.thetas.shape[1]
self.simulator_dimensions = simulation_dimensions if simulation_dimensions is not None else get_item("data", "simulator_dimensions", raise_exception=False)

def get_simulator_output_shape(self) -> tuple[Sequence[int]]:
"""
Run a single sample of the simulator to verify the out-shape.

Returns:
tuple[Sequence[int]]: Output shape of a single sample of the simulator.
"""
context_shape = self.true_context().shape
sim_out = self.simulator(theta=self.get_theta_true()[0:1, :], n_samples=context_shape[-1])
return sim_out.shape
self.simulator_outcome = self._simulator_outcome()

def _load(self, path: str):
raise NotImplementedError

def true_context(self):
"""
True data x values, if supplied by the data method.
"""
# From Data
def _simulator_outcome(self):
raise NotImplementedError

def _context(self):
raise NotImplementedError

def _thetas(self):
raise NotImplementedError

def true_simulator_outcome(self) -> np.ndarray:
"""
Run the simulator on all true theta and true x values.

Returns:
np.ndarray: array of (n samples, simulator shape) showing output of the simulator on all true samples in data.
"""
return self.simulator(self.get_theta_true(), self.true_context())

def sample_prior(self, n_samples: int) -> np.ndarray:
"""
Draw samples from the simulator

Args:
n_samples (int): Number of samples to draw

Returns:
np.ndarray:
"""
return self.prior_dist(size=(n_samples, self.n_dims))

def simulator_outcome(self, theta:np.ndarray, condition_context:np.ndarray=None, n_samples:int=None):
"""_summary_

Args:
theta (np.ndarray): Theta value of shape (n_samples, theta_dimensions)
condition_context (np.ndarray, optional): If x values for theta are known, use them. Defaults to None.
n_samples (int, optional): If x values are not known for theta, draw them randomly. Defaults to None.

Raises:
ValueError: If either n samples or content samples is supplied.

Returns:
np.ndarray: Simulator output of shape (n samples, simulator_dimensions)
"""
if condition_context is None:
if n_samples is None:
raise ValueError(
"Samples required if condition context is not specified"
)
return self.simulator(theta, n_samples)
else:
return self.simulator.simulate(theta, condition_context)

def simulated_context(self, n_samples:int) -> np.ndarray:
"""
Call the simulator's `generate_context` method.

Args:
n_samples (int): Number of samples to draw.

Returns:
np.ndarray: context (x values), as defined by the simulator.
"""
return self.simulator.generate_context(n_samples)

def get_theta_true(self) -> Union[Any, float, int, np.ndarray]:
"""
Look for the true theta given by data. If supplied in the method, use that, other look in the configuration file.
If neither are supplied, return None.
def save(self, data, path: str):
raise NotImplementedError

Returns:
Any: Theta value selected by the search.
"""
if hasattr(self, "theta_true"):
return self.theta_true
else:
return get_item("data", "theta_true", raise_exception=True)
def read_prior(self):
raise NotImplementedError

def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]:
"""
Expand All @@ -153,12 +86,6 @@ def get_sigma_true(self) -> Union[Any, float, int, np.ndarray]:
else:
return get_item("data", "sigma_true", raise_exception=True)

def save(self, data, path: str):
raise NotImplementedError

def read_prior(self):
raise NotImplementedError

def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable:
"""
Load the prior.
Expand Down Expand Up @@ -201,3 +128,21 @@ def load_prior(self, prior:str, prior_kwargs:dict[str, any]) -> callable:

except KeyError as e:
raise RuntimeError(f"Data missing a prior specification - {e}")

def sample_prior(self, n_samples: int) -> np.ndarray:
    """
    Draw samples from the prior.

    If a prior distribution is available (`self.prior_dist` is not None), it is
    called with `size=(n_samples, self.n_dims)`. Otherwise the prior returned by
    `self.read_prior()` is indexed at `n_samples` randomly chosen positions
    drawn from `self.rng`.

    Args:
        n_samples (int): Number of samples to draw.

    Returns:
        np.ndarray: Samples drawn from the prior.
    """
    # Preferred path: a callable prior distribution was loaded.
    if self.prior_dist is not None:
        return self.prior_dist(size=(n_samples, self.n_dims))

    # Fallback: pick random entries from the prior stored alongside the data.
    stored_prior = self.read_prior()
    picks = self.rng.randint(0, len(stored_prior), size=n_samples)
    return stored_prior[picks]
51 changes: 36 additions & 15 deletions src/deepdiagnostics/data/h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ class H5Data(Data):
"""
Load data that has been saved in a h5 format.

If you cast your problem to be y = mx + b, these are the fields required and what they represent:

simulator_outcome - y
thetas - parameters of the model - m, b
context - xs

.. attribute:: Data Parameters

:xs: [REQUIRED] The context, the x values. The data that was used to train a model on what conditions produce what posterior.
Expand All @@ -30,6 +36,7 @@ def __init__(self,
):
super().__init__(path, simulator, simulator_kwargs, prior, prior_kwargs, simulation_dimensions)


def _load(self, path):
assert path.split(".")[-1] == "h5", "File extension must be h5"
loaded_data = {}
Expand All @@ -49,17 +56,29 @@ def save(self, data: dict[str, Any], path: str): # Todo typing for data dict
for key, value in data_arrays.items():
file.create_dataset(key, data=value)

def true_context(self):
"""
Try to get the `xs` field of the loaded data.

Raises:
NotImplementedError: The data does not have a `xs` field.
"""
def _simulator_outcome(self):
    """
    Return the simulator outcome for the stored thetas.

    The `simulator_outcome` field of the loaded data is used when present.
    Otherwise the simulator is called once per stored theta (with
    `n_samples=1`) and each result is written into one column of the output.

    Returns:
        The stored `simulator_outcome` field, or an array of shape
        (simulator_dimensions, number of thetas) filled from the simulator.

    Raises:
        ValueError: The data has no `simulator_outcome` field and it could
            not be generated from the simulator.
    """
    try:
        return self.data["simulator_outcome"]
    except KeyError:
        pass  # Field not stored; fall back to running the simulator.

    try:
        # Allocate the full (simulator_dimensions, n_thetas) buffer.
        # np.array(shape) would only build a 2-element array holding the
        # shape itself, which makes the column assignment below fail.
        sim_outcome = np.zeros((self.simulator_dimensions, len(self.thetas)))
        for index, theta in enumerate(self.thetas):
            # The simulator is called on a single theta with a leading batch axis.
            sim_out = self.simulator(theta=theta.unsqueeze(0), n_samples=1)
            sim_outcome[:, index] = sim_out
        return sim_outcome

    except Exception as e:
        msg = f"Data does not have a `simulator_outcome` field and could not generate it from a simulator: {e}"
        raise ValueError(msg) from e

def _context(self):
    """
    Return the `context` field of the loaded data.

    A one-dimensional context gets an extra axis at position 1 (via
    `unsqueeze(1)`); any other context is returned unchanged.

    Raises:
        NotImplementedError: The data does not have a `context` field.
    """
    try:
        stored = self.data["context"]
        # Only 1-D contexts are expanded; higher-rank contexts pass through.
        return stored.unsqueeze(1) if stored.ndim == 1 else stored
    except KeyError:
        raise NotImplementedError("Data does not have a `context` field.")

def prior(self):
"""
Expand All @@ -68,12 +87,14 @@ def prior(self):
Raises:
NotImplementedError: The data does not have a `prior` field.
"""
try:
if 'prior' in self.data:
return self.data['prior']
except KeyError:
raise NotImplementedError("Data does not have a `prior` field.")

def get_theta_true(self):
elif self.prior_dist is not None:
return self.prior_dist
else:
raise ValueError("Data does not have a `prior` field.")

def _thetas(self):
""" Get stored theta used to train the model.

Returns:
Expand Down
Loading
Loading