Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/tasks/test_task_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import re

import pyro
import pytest
import torch

from sbibm import get_available_tasks, get_task
from sbibm.algorithms.sbi.snpe import run
from sbibm.metrics.ppc import median_distance

pyro.util.set_rng_seed(47)

# ################################################
# ## demonstrate on how to run a minimal benchmark
# ## see https://github.com/sbi-benchmark/results/blob/main/benchmarking_sbi/run.py


@pytest.mark.parametrize(
    "task_name",
    [tn for tn in get_available_tasks() if not re.search("lotka|sir", tn)],
)
def test_benchmark_metrics_selfobserved(task_name):
    """Run a minimal NPE benchmark on a self-simulated observation.

    Draws one parameter set from the task prior, simulates an observation
    from it, runs single-round SNPE (i.e. plain NPE) on that observation,
    and checks the posterior predictive median distance is positive.
    """
    task = get_task(task_name)

    nobs = 1  # maybe randomly dice this?
    theta_o = task.get_prior()(num_samples=nobs)
    sim = task.get_simulator()
    x_o = sim(theta_o)

    outputs, nsim, logprob_truep = run(
        task,
        observation=x_o,
        num_samples=16,
        num_simulations=64,
        neural_net="mdn",
        num_rounds=1,  # let's do NPE not SNPE (to avoid MCMC)
    )

    # a non-empty batch of posterior samples must come back
    assert outputs.shape
    assert outputs.shape[0] > 0
    # run() returns no log-prob of true parameters for a custom observation
    # (fix: identity comparison with None per PEP 8, not `== None`)
    assert logprob_truep is None

    # posterior predictive check: simulate from the posterior samples and
    # measure the median distance to the observation
    predictive_samples = sim(outputs)
    value = median_distance(predictive_samples, x_o)

    assert value > 0
103 changes: 103 additions & 0 deletions tests/tasks/test_task_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import re

import pyro
import pytest
import torch

from sbibm import get_available_tasks, get_task

pyro.util.set_rng_seed(47)

# Task-name groupings used to parametrize the tests below.
all_tasks = set(get_available_tasks())
# Julia-backed simulators (lotka_volterra, sir) need an extra toolchain,
# so they are excluded from the pure-Python test runs.
julia_tasks = {tn for tn in get_available_tasks() if re.search("lotka|sir", tn)}
# Tasks that ship without reference posterior samples.
noref_tasks = {tn for tn in get_available_tasks() if "noref" in tn}


@pytest.mark.parametrize("task_name", sorted(all_tasks - julia_tasks))
def test_task_can_be_obtained(task_name):
    """Every non-Julia task can be instantiated by name."""
    # sorted(...) replaces the redundant identity comprehension and makes
    # the parametrize ordering deterministic (sets iterate in arbitrary order)
    task = get_task(task_name)

    assert task is not None


@pytest.mark.parametrize("task_name", sorted(all_tasks - julia_tasks))
def test_obtain_prior_from_task(task_name):
    """A prior callable can be obtained from every non-Julia task."""
    # sorted(...) replaces the redundant identity comprehension and makes
    # the parametrize ordering deterministic (sets iterate in arbitrary order)
    task = get_task(task_name)
    prior = task.get_prior()

    assert prior is not None


@pytest.mark.parametrize("task_name", sorted(all_tasks - julia_tasks))
def test_obtain_simulator_from_task(task_name):
    """A simulator can be obtained from every non-Julia task."""
    # sorted(...) replaces the redundant identity comprehension and makes
    # the parametrize ordering deterministic (sets iterate in arbitrary order)
    task = get_task(task_name)

    simulator = task.get_simulator()

    assert simulator is not None


@pytest.mark.parametrize("task_name", sorted(all_tasks - julia_tasks))
def test_retrieve_observation_from_task(task_name):
    """Each task provides a stored observation with a batched shape."""
    # sorted(...) replaces the redundant identity comprehension and makes
    # the parametrize ordering deterministic (sets iterate in arbitrary order)
    task = get_task(task_name)

    x_o = task.get_observation(num_observation=1)

    assert x_o is not None
    assert hasattr(x_o, "shape")
    # observations are batched, i.e. at least (1, data_dim)
    assert len(x_o.shape) > 1


@pytest.mark.parametrize("task_name", sorted(all_tasks - julia_tasks))
def test_obtain_prior_samples_from_task(task_name):
    """Sampling the prior yields the requested number of parameter sets."""
    # sorted(...) replaces the redundant identity comprehension and makes
    # the parametrize ordering deterministic (sets iterate in arbitrary order)
    task = get_task(task_name)
    prior = task.get_prior()
    nsamples = 10

    thetas = prior(num_samples=nsamples)

    assert thetas.shape[0] == nsamples


@pytest.mark.parametrize("task_name", sorted(all_tasks - julia_tasks))
def test_simulate_from_thetas(task_name):
    """The simulator maps a batch of prior samples to an equal-sized batch."""
    # sorted(...) replaces the redundant identity comprehension and makes
    # the parametrize ordering deterministic (sets iterate in arbitrary order)
    task = get_task(task_name)
    prior = task.get_prior()
    sim = task.get_simulator()
    nsamples = 10

    thetas = prior(num_samples=nsamples)
    xs = sim(thetas)

    assert xs.shape[0] == nsamples


@pytest.mark.parametrize(
    "task_name", sorted(all_tasks - julia_tasks - noref_tasks)
)
def test_reference_posterior_exists(task_name):
    """Tasks with reference posteriors return a non-empty 2-D sample array."""
    # sorted(...) replaces the redundant identity comprehension and makes
    # the parametrize ordering deterministic (sets iterate in arbitrary order)
    task = get_task(task_name)

    reference_samples = task.get_reference_posterior_samples(num_observation=1)

    assert hasattr(reference_samples, "shape")
    # (num_samples, parameter_dim)
    assert len(reference_samples.shape) == 2
    assert reference_samples.shape[0] > 0


@pytest.mark.parametrize("task_name", sorted(noref_tasks))
def test_reference_posterior_not_called(task_name):
    """Tasks without reference samples must raise NotImplementedError."""
    # sorted(...) replaces the redundant identity comprehension and makes
    # the parametrize ordering deterministic (sets iterate in arbitrary order)
    task = get_task(task_name)

    with pytest.raises(NotImplementedError):
        # return value intentionally discarded (was an unused local)
        task.get_reference_posterior_samples(num_observation=1)

    assert task is not None
40 changes: 40 additions & 0 deletions tests/tasks/test_task_rej_abc_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import re

import pyro
import pytest
import torch

from sbibm import get_available_tasks, get_task
from sbibm.algorithms import rej_abc
from sbibm.metrics import c2st

pyro.util.set_rng_seed(47)

# All tasks that both run without Julia ("lotka", "sir") and ship
# reference posterior samples ("noref" excluded).
task_list = [tn for tn in get_available_tasks() if not re.search("noref|lotka|sir", tn)]


@pytest.mark.parametrize("task_name", task_list)
def test_quick_demo_rej_abc(task_name):
    """Rejection ABC returns the requested number of posterior samples."""
    task = get_task(task_name)
    posterior_samples, _, _ = rej_abc(
        task=task, num_samples=12, num_observation=1, num_simulations=500
    )

    # fix: identity comparison with None per PEP 8, not `!= None`
    assert posterior_samples is not None
    assert posterior_samples.shape[0] == 12


@pytest.mark.parametrize("task_name", task_list)
def test_quick_demo_c2st(task_name):
    """Compare rejection-ABC posterior samples to the task's reference
    posterior via the classifier two-sample test (C2ST)."""
    task = get_task(task_name)

    posterior_samples, _, _ = rej_abc(
        task=task, num_observation=1, num_samples=50, num_simulations=500
    )
    reference_samples = task.get_reference_posterior_samples(num_observation=1)

    c2st_accuracy = c2st(reference_samples, posterior_samples)

    # classifier accuracy must lie strictly inside the open interval (0, 1)
    assert c2st_accuracy > 0.0
    assert c2st_accuracy < 1.0
17 changes: 17 additions & 0 deletions tests/tasks/two_moons/test_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pyro
import pytest
import torch

from sbibm.tasks.two_moons.task import TwoMoons

pyro.util.set_rng_seed(47)

## a test suite that can be used for task internal code


def test_task_constructs():
    """Demonstrate testing task-internal code: TwoMoons builds directly."""
    two_moons_task = TwoMoons()

    assert two_moons_task