Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
988 changes: 110 additions & 878 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/deepdiagnostics/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def main():

for metrics_name, metrics_args in metrics.items():
try:
Metrics[metrics_name](model, data, save=True)(**metrics_args)
Metrics[metrics_name](model, data, run_id, save=True)(**metrics_args)
except SimulatorMissingError:
print(f"Cannot run {metrics_name} - simulator missing.")

Expand Down
4 changes: 2 additions & 2 deletions src/deepdiagnostics/data/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Union
from typing import Any, Optional, Union
import numpy as np

from deepdiagnostics.utils.config import get_item
Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(
try:
self.simulator = LookupTableSimulator(self.data, self.rng)
except ValueError as e:
msg = f"Could not load the lookup table simulator - {e}. You cannot use generative diagnostics."
msg = f"Could not load the lookup table simulator - {e}. You cannot use online diagnostics."
print(msg)

self.context = self._context()
Expand Down
30 changes: 21 additions & 9 deletions src/deepdiagnostics/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
from deepdiagnostics.metrics.all_sbc import AllSBC
from deepdiagnostics.metrics.coverage_fraction import CoverageFraction
from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST
import importlib
import inspect
from pathlib import Path
from deepdiagnostics.metrics.metric import Metric

# Void is included as a placeholder for empty metrics
def void(*args, **kwargs):
def void2(*args, **kwargs):
return None
return void2

Metrics = {
"": void,
CoverageFraction.__name__: CoverageFraction,
AllSBC.__name__: AllSBC,
"LC2ST": LC2ST
}
Metrics = {"": void}
__all__ = []
for file in Path(__file__).parent.glob("*.py"):
if file.name.startswith("__") or file.name == "metric.py":
continue
module = importlib.import_module(f"deepdiagnostics.metrics.{file.stem}")
for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Metric) and obj != Metric:
Metrics[obj.__name__] = obj
globals()[obj.__name__] = obj
__all__.append(obj.__name__)

if 'LocalTwoSampleTest' in Metrics:
Metrics['LC2ST'] = Metrics['LocalTwoSampleTest']
globals()['lc2st'] = Metrics['LocalTwoSampleTest']
__all__.append('lc2st')
42 changes: 21 additions & 21 deletions src/deepdiagnostics/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from deepdiagnostics.plots.cdf_ranks import CDFRanks
from deepdiagnostics.plots.coverage_fraction import CoverageFraction
from deepdiagnostics.plots.ranks import Ranks
from deepdiagnostics.plots.tarp import TARP
from deepdiagnostics.plots.local_two_sample import LocalTwoSampleTest as LC2ST
from deepdiagnostics.plots.predictive_posterior_check import PPC
from deepdiagnostics.plots.parity import Parity
from deepdiagnostics.plots.predictive_prior_check import PriorPC
from deepdiagnostics.plots.cdf_parity import CDFParityPlot
import importlib
import inspect
from pathlib import Path

from deepdiagnostics.plots.plot import Display

# Void is included as a placeholder for empty metrics
def void(*args, **kwargs):
def void2(*args, **kwargs):
return None
return void2

Plots = {
"": void,
CDFRanks.__name__: CDFRanks,
CoverageFraction.__name__: CoverageFraction,
Ranks.__name__: Ranks,
TARP.__name__: TARP,
"LC2ST": LC2ST,
PPC.__name__: PPC,
"Parity": Parity,
PriorPC.__name__: PriorPC,
CDFParityPlot.__name__: CDFParityPlot
}
Plots = {"": void}
__all__ = []
for file in Path(__file__).parent.glob("*.py"):
if file.name.startswith("__") or file.name == "plot.py":
continue
module = importlib.import_module(f"deepdiagnostics.plots.{file.stem}")
for name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Display) and obj != Display:
Plots[obj.__name__] = obj
globals()[obj.__name__] = obj
__all__.append(obj.__name__)

if 'LocalTwoSampleTest' in Plots:
Plots['LC2ST'] = Plots['LocalTwoSampleTest']
globals()['lc2st'] = Plots['LocalTwoSampleTest']
__all__.append('lc2st')
9 changes: 6 additions & 3 deletions src/deepdiagnostics/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_section(section, raise_exception=True):

class Config:
ENV_VAR_PATH = "DeepDiagnostics_Config"

_printed_warning = False
def __init__(self, config_path: Optional[str] = None) -> None:
if config_path is not None:
# Add it to the env vars in case we need to get it later.
Expand All @@ -31,7 +31,9 @@ def __init__(self, config_path: Optional[str] = None) -> None:
self._validate_config()

except KeyError:
print("Warning: Cannot load config from environment. Hint: Have you set the config path by passing a str path to Config?")
if not Config._printed_warning:
Config._printed_warning = True
print("Warning: Cannot load config from environment. Hint: Have you set the config path by passing a str path to Config?", flush=True)
self.config = Defaults

def _validate_config(self):
Expand All @@ -40,7 +42,8 @@ def _validate_config(self):
pass

def _read_config(self, path):
assert os.path.exists(path), f"Config path at {path} does not exist."
if not os.path.exists(path):
raise FileNotFoundError(f"Config path at {path} does not exist.")
with open(path, "r") as f:
config = yaml.safe_load(f)
return config
Expand Down
4 changes: 2 additions & 2 deletions src/deepdiagnostics/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
},
"metrics_common": {
"use_progress_bar": False,
"samples_per_inference": 1000,
"samples_per_inference": 100,
"percentiles": [75, 85, 95],
"number_simulations": 50,
"number_simulations": 10,
},
"metrics": {
"AllSBC": {},
Expand Down
Loading