diff --git a/tests/tasks/test_task_benchmarks.py b/tests/tasks/test_task_benchmarks.py
new file mode 100644
index 00000000..f6960a84
--- /dev/null
+++ b/tests/tasks/test_task_benchmarks.py
@@ -0,0 +1,47 @@
+import re
+
+import pyro
+import pytest
+import torch
+
+from sbibm import get_available_tasks, get_task
+from sbibm.algorithms.sbi.snpe import run
+from sbibm.metrics.ppc import median_distance
+
+pyro.util.set_rng_seed(47)
+
+# ################################################
+# ## demonstrate on how to run a minimal benchmark
+# ## see https://github.com/sbi-benchmark/results/blob/main/benchmarking_sbi/run.py
+
+
+@pytest.mark.parametrize(
+    "task_name",
+    [tn for tn in get_available_tasks() if not re.search("lotka|sir", tn)],
+)
+def test_benchmark_metrics_selfobserved(task_name):
+
+    task = get_task(task_name)
+
+    nobs = 1  # maybe randomly dice this?
+    theta_o = task.get_prior()(num_samples=nobs)
+    sim = task.get_simulator()
+    x_o = sim(theta_o)
+
+    outputs, nsim, logprob_truep = run(
+        task,
+        observation=x_o,
+        num_samples=16,
+        num_simulations=64,
+        neural_net="mdn",
+        num_rounds=1,  # let's do NPE not SNPE (to avoid MCMC)
+    )
+
+    assert outputs.shape
+    assert outputs.shape[0] > 0
+    assert logprob_truep is None
+
+    predictive_samples = sim(outputs)
+    value = median_distance(predictive_samples, x_o)
+
+    assert value > 0
diff --git a/tests/tasks/test_task_interface.py b/tests/tasks/test_task_interface.py
new file mode 100644
index 00000000..b5540e7d
--- /dev/null
+++ b/tests/tasks/test_task_interface.py
@@ -0,0 +1,103 @@
+import re
+
+import pyro
+import pytest
+import torch
+
+from sbibm import get_available_tasks, get_task
+
+pyro.util.set_rng_seed(47)
+
+all_tasks = set(get_available_tasks())
+julia_tasks = {tn for tn in get_available_tasks() if re.search("lotka|sir", tn)}
+noref_tasks = {tn for tn in get_available_tasks() if "noref" in tn}
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_task_can_be_obtained(task_name):
+
+    task = get_task(task_name)
+
+    assert task is not None
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_obtain_prior_from_task(task_name):
+
+    task = get_task(task_name)
+    prior = task.get_prior()
+
+    assert prior is not None
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_obtain_simulator_from_task(task_name):
+
+    task = get_task(task_name)
+
+    simulator = task.get_simulator()
+
+    assert simulator is not None
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_retrieve_observation_from_task(task_name):
+
+    task = get_task(task_name)
+
+    x_o = task.get_observation(num_observation=1)
+
+    assert x_o is not None
+    assert hasattr(x_o, "shape")
+    assert len(x_o.shape) > 1
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_obtain_prior_samples_from_task(task_name):
+
+    task = get_task(task_name)
+    prior = task.get_prior()
+    nsamples = 10
+
+    thetas = prior(num_samples=nsamples)
+
+    assert thetas.shape[0] == nsamples
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_simulate_from_thetas(task_name):
+
+    task = get_task(task_name)
+    prior = task.get_prior()
+    sim = task.get_simulator()
+    nsamples = 10
+
+    thetas = prior(num_samples=nsamples)
+    xs = sim(thetas)
+
+    assert xs.shape[0] == nsamples
+
+
+@pytest.mark.parametrize(
+    "task_name", [tn for tn in (all_tasks - julia_tasks - noref_tasks)]
+)
+def test_reference_posterior_exists(task_name):
+
+    task = get_task(task_name)
+
+    reference_samples = task.get_reference_posterior_samples(num_observation=1)
+
+    assert hasattr(reference_samples, "shape")
+    assert len(reference_samples.shape) == 2
+    assert reference_samples.shape[0] > 0
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in noref_tasks])
+def test_reference_posterior_not_called(task_name):
+
+    task = get_task(task_name)
+
+    with pytest.raises(NotImplementedError):
+        reference_samples = task.get_reference_posterior_samples(num_observation=1)
+
+    assert task is not None
diff --git a/tests/tasks/test_task_rej_abc_demo.py b/tests/tasks/test_task_rej_abc_demo.py
new file mode 100644
index 00000000..6199ee54
--- /dev/null
+++ b/tests/tasks/test_task_rej_abc_demo.py
@@ -0,0 +1,40 @@
+import re
+
+import pyro
+import pytest
+import torch
+
+from sbibm import get_available_tasks, get_task
+from sbibm.algorithms import rej_abc
+from sbibm.metrics import c2st
+
+pyro.util.set_rng_seed(47)
+
+task_list = [tn for tn in get_available_tasks() if not re.search("noref|lotka|sir", tn)]
+
+
+@pytest.mark.parametrize("task_name", task_list)
+def test_quick_demo_rej_abc(task_name):
+
+    task = get_task(task_name)
+    posterior_samples, _, _ = rej_abc(
+        task=task, num_samples=12, num_observation=1, num_simulations=500
+    )
+
+    assert posterior_samples is not None
+    assert posterior_samples.shape[0] == 12
+
+
+@pytest.mark.parametrize("task_name", task_list)
+def test_quick_demo_c2st(task_name):
+
+    task = get_task(task_name)
+    posterior_samples, _, _ = rej_abc(
+        task=task, num_samples=50, num_observation=1, num_simulations=500
+    )
+
+    reference_samples = task.get_reference_posterior_samples(num_observation=1)
+    c2st_accuracy = c2st(reference_samples, posterior_samples)
+
+    assert c2st_accuracy > 0.0
+    assert c2st_accuracy < 1.0
diff --git a/tests/tasks/two_moons/test_task.py b/tests/tasks/two_moons/test_task.py
new file mode 100644
index 00000000..14e0a456
--- /dev/null
+++ b/tests/tasks/two_moons/test_task.py
@@ -0,0 +1,17 @@
+import pyro
+import pytest
+import torch
+
+from sbibm.tasks.two_moons.task import TwoMoons
+
+pyro.util.set_rng_seed(47)
+
+## a test suite that can be used for task internal code
+
+
+def test_task_constructs():
+    """this test demonstrates how to test internal task code"""
+
+    t = TwoMoons()
+
+    assert t