From a9a0fc673d80aa939c7e1907ecb9f913cde769eb Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Tue, 9 Nov 2021 17:26:32 +0100
Subject: [PATCH 1/6] unit test for two moons, basis for testing other tasks
---
tests/tasks/test_two_moons.py | 193 ++++++++++++++++++++++++++++++++++
1 file changed, 193 insertions(+)
create mode 100644 tests/tasks/test_two_moons.py
diff --git a/tests/tasks/test_two_moons.py b/tests/tasks/test_two_moons.py
new file mode 100644
index 00000000..cf6f094f
--- /dev/null
+++ b/tests/tasks/test_two_moons.py
@@ -0,0 +1,193 @@
+import pytest
+import torch
+
+import sbibm
+from sbibm.tasks.two_moons.task import TwoMoons
+
+torch.manual_seed(47)
+
+
+def test_task_constructs():
+
+ t = TwoMoons()
+
+ assert t
+
+
+def test_obtain_task():
+
+ task = sbibm.get_task("two_moons")
+
+ assert task is not None
+
+
+def test_obtain_prior():
+
+ task = sbibm.get_task("two_moons") # See sbibm.get_available_tasks() for all tasks
+ prior = task.get_prior()
+
+ assert prior is not None
+
+
+def test_obtain_simulator():
+
+ task = sbibm.get_task("two_moons")
+
+ simulator = task.get_simulator()
+
+ assert simulator is not None
+
+
+def test_observe_once():
+
+ task = sbibm.get_task("two_moons")
+
+ x_o = task.get_observation(num_observation=1)
+
+ assert x_o is not None
+ assert hasattr(x_o, "shape")
+
+
+def test_obtain_prior_samples():
+
+ task = sbibm.get_task("two_moons")
+ prior = task.get_prior()
+ nsamples = 10
+
+ thetas = prior(num_samples=nsamples)
+
+ assert thetas.shape == (nsamples, 2)
+
+
+def test_simulate_from_thetas():
+
+ task = sbibm.get_task("two_moons")
+ prior = task.get_prior()
+ sim = task.get_simulator()
+ nsamples = 10
+
+ thetas = prior(num_samples=nsamples)
+ xs = sim(thetas)
+
+ assert xs.shape == (nsamples, 2)
+
+
+def test_reference_posterior_exists():
+
+ task = sbibm.get_task("two_moons")
+
+ reference_samples = task.get_reference_posterior_samples(num_observation=1)
+
+ assert hasattr(reference_samples, "shape")
+ assert len(reference_samples.shape) == 2
+ assert reference_samples.shape == (10_000, 2)
+
+
+# @pytest.fixture
+# def vanilla_samples():
+
+# task = sbibm.get_task("two_moons")
+# prior = task.get_prior()
+# sim = task.get_simulator()
+# nsamples = 1_000
+
+# thetas = prior(num_samples=nsamples)
+# xs = sim(thetas)
+
+# return task, thetas, xs
+
+
+def test_quick_demo_rej_abc():
+
+ from sbibm.algorithms import rej_abc # See help(rej_abc) for keywords
+
+ task = sbibm.get_task("two_moons")
+ posterior_samples, _, _ = rej_abc(
+ task=task, num_samples=50, num_observation=1, num_simulations=500
+ )
+
+ assert posterior_samples != None
+ assert posterior_samples.shape[0] == 50
+
+
+def test_quick_demo_c2st():
+
+ from sbibm.algorithms import rej_abc # See help(rej_abc) for keywords
+
+ task = sbibm.get_task("two_moons")
+ posterior_samples, _, _ = rej_abc(
+ task=task, num_samples=50, num_observation=1, num_simulations=500
+ )
+
+ from sbibm.metrics import c2st
+
+ reference_samples = task.get_reference_posterior_samples(num_observation=1)
+ c2st_accuracy = c2st(reference_samples, posterior_samples)
+
+ assert c2st_accuracy > 0.0
+ assert c2st_accuracy < 1.0
+
+
+################################################
+## demonstrate on how to run a minimal benchmark
+## see https://github.com/sbi-benchmark/results/blob/main/benchmarking_sbi/run.py
+
+
+def test_benchmark_metrics_selfobserved():
+
+ from sbibm.algorithms.sbi.snpe import run
+ from sbibm.metrics.ppc import median_distance
+
+ task = sbibm.get_task("two_moons")
+
+ nobs = 1
+ theta_o = task.get_prior()(num_samples=nobs)
+ sim = task.get_simulator()
+ x_o = sim(theta_o)
+
+ outputs, nsim, logprob_truep = run(
+ task,
+ observation=x_o,
+ num_samples=16,
+ num_simulations=64,
+ neural_net="mdn",
+ num_rounds=1, # let's do NPE not SNPE (to avoid MCMC)
+ )
+
+ assert outputs.shape
+ assert outputs.shape[0] > 0
+ assert logprob_truep == None
+
+ predictive_samples = sim(outputs)
+ value = median_distance(predictive_samples, x_o)
+
+ assert value > 0
+ assert value > 0.5
+
+
+def test_benchmark_metrics():
+
+ from sbibm.algorithms.sbi.snpe import run
+ from sbibm.metrics.ppc import median_distance
+
+ task = sbibm.get_task("two_moons")
+ sim = task.get_simulator()
+
+ outputs, nsim, logprob_truep = run(
+ task,
+ num_observation=7,
+ num_samples=64,
+ num_simulations=100,
+ neural_net="mdn",
+ num_rounds=1, # let's do NPE not SNPE (to avoid MCMC)
+ )
+
+ assert outputs.shape
+ assert outputs.shape[0] > 0
+ assert logprob_truep == None
+
+ predictive_samples = sim(outputs)
+ x_o = task.get_observation(7)
+ value = median_distance(predictive_samples, x_o)
+
+ assert value > 0
From b7b78f7d7234f17fd0f80d2ac50118702f795da7 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 11 Nov 2021 17:57:06 +0100
Subject: [PATCH 2/6] unit tests for 'all' tests
- 3 types of test suites added
+ test_task_interface.py for mere API tests
+ test_task_rej_abc_demo.py for testing the API demonstrated on the
landing page (README.md)
+ test_task_benchmark.py to see/document if the benchmarks work
- added some "noref" sentinels for tasks which do not have a reference posterior
- using sets to better work with list of tasks to run tests for
- as of now, the tests exclude julia based tests
---
tests/tasks/test_task_benchmarks.py | 75 ++++++++++
tests/tasks/test_task_interface.py | 103 ++++++++++++++
tests/tasks/test_task_rej_abc_demo.py | 40 ++++++
tests/tasks/test_two_moons.py | 193 --------------------------
4 files changed, 218 insertions(+), 193 deletions(-)
create mode 100644 tests/tasks/test_task_benchmarks.py
create mode 100644 tests/tasks/test_task_interface.py
create mode 100644 tests/tasks/test_task_rej_abc_demo.py
delete mode 100644 tests/tasks/test_two_moons.py
diff --git a/tests/tasks/test_task_benchmarks.py b/tests/tasks/test_task_benchmarks.py
new file mode 100644
index 00000000..7814e293
--- /dev/null
+++ b/tests/tasks/test_task_benchmarks.py
@@ -0,0 +1,75 @@
+import re
+
+import pytest
+import torch
+
+from sbibm import get_available_tasks, get_task
+from sbibm.algorithms.sbi.snpe import run
+from sbibm.metrics.ppc import median_distance
+
+# maybe use the pyro facilities
+torch.manual_seed(47)
+
+# ################################################
+# ## demonstrate on how to run a minimal benchmark
+# ## see https://github.com/sbi-benchmark/results/blob/main/benchmarking_sbi/run.py
+
+
+@pytest.mark.parametrize(
+ "task_name",
+ [tn for tn in get_available_tasks() if not re.search("noref|lotka|sir", tn)],
+)
+def test_benchmark_metrics_selfobserved(task_name):
+
+ task = get_task(task_name)
+
+ nobs = 1
+ theta_o = task.get_prior()(num_samples=nobs)
+ sim = task.get_simulator()
+ x_o = sim(theta_o)
+
+ outputs, nsim, logprob_truep = run(
+ task,
+ observation=x_o,
+ num_samples=16,
+ num_simulations=64,
+ neural_net="mdn",
+ num_rounds=1, # let's do NPE not SNPE (to avoid MCMC)
+ )
+
+ assert outputs.shape
+ assert outputs.shape[0] > 0
+    assert logprob_truep is None
+
+ predictive_samples = sim(outputs)
+ value = median_distance(predictive_samples, x_o)
+
+ assert value > 0
+
+
+# def test_benchmark_metrics():
+
+# from sbibm.algorithms.sbi.snpe import run
+# from sbibm.metrics.ppc import median_distance
+
+# task = get_task("two_moons")
+# sim = task.get_simulator()
+
+# outputs, nsim, logprob_truep = run(
+# task,
+# num_observation=7,
+# num_samples=64,
+# num_simulations=100,
+# neural_net="mdn",
+# num_rounds=1, # let's do NPE not SNPE (to avoid MCMC)
+# )
+
+# assert outputs.shape
+# assert outputs.shape[0] > 0
+# assert logprob_truep == None
+
+# predictive_samples = sim(outputs)
+# x_o = task.get_observation(7)
+# value = median_distance(predictive_samples, x_o)
+
+# assert value > 0
diff --git a/tests/tasks/test_task_interface.py b/tests/tasks/test_task_interface.py
new file mode 100644
index 00000000..43cc2e30
--- /dev/null
+++ b/tests/tasks/test_task_interface.py
@@ -0,0 +1,103 @@
+import re
+
+import pytest
+import torch
+
+from sbibm import get_available_tasks, get_task
+
+# maybe use the pyro facilities
+torch.manual_seed(47)
+
+all_tasks = set(get_available_tasks())
+julia_tasks = set([tn for tn in get_available_tasks() if re.search("lotka|sir", tn)])
+noref_tasks = set([tn for tn in get_available_tasks() if "noref" in tn])
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_task_can_be_obtained(task_name):
+
+ task = get_task(task_name)
+
+ assert task is not None
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_obtain_prior_from_task(task_name):
+
+ task = get_task(task_name)
+ prior = task.get_prior()
+
+ assert prior is not None
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_obtain_simulator_from_task(task_name):
+
+ task = get_task(task_name)
+
+ simulator = task.get_simulator()
+
+ assert simulator is not None
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_retrieve_observation_from_task(task_name):
+
+ task = get_task(task_name)
+
+ x_o = task.get_observation(num_observation=1)
+
+ assert x_o is not None
+ assert hasattr(x_o, "shape")
+ assert len(x_o.shape) > 1
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_obtain_prior_samples_from_task(task_name):
+
+ task = get_task(task_name)
+ prior = task.get_prior()
+ nsamples = 10
+
+ thetas = prior(num_samples=nsamples)
+
+ assert thetas.shape[0] == nsamples
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in (all_tasks - julia_tasks)])
+def test_simulate_from_thetas(task_name):
+
+ task = get_task(task_name)
+ prior = task.get_prior()
+ sim = task.get_simulator()
+ nsamples = 10
+
+ thetas = prior(num_samples=nsamples)
+ xs = sim(thetas)
+
+ assert xs.shape[0] == nsamples
+
+
+@pytest.mark.parametrize(
+ "task_name", [tn for tn in (all_tasks - julia_tasks - noref_tasks)]
+)
+def test_reference_posterior_exists(task_name):
+
+ task = get_task(task_name)
+
+ reference_samples = task.get_reference_posterior_samples(num_observation=1)
+
+ assert hasattr(reference_samples, "shape")
+ assert len(reference_samples.shape) == 2
+ assert reference_samples.shape[0] > 0
+
+
+@pytest.mark.parametrize("task_name", [tn for tn in noref_tasks])
+def test_reference_posterior_not_called(task_name):
+
+ task = get_task(task_name)
+
+ with pytest.raises(NotImplementedError):
+ reference_samples = task.get_reference_posterior_samples(num_observation=1)
+
+ assert task is not None
diff --git a/tests/tasks/test_task_rej_abc_demo.py b/tests/tasks/test_task_rej_abc_demo.py
new file mode 100644
index 00000000..7c5c7bc0
--- /dev/null
+++ b/tests/tasks/test_task_rej_abc_demo.py
@@ -0,0 +1,40 @@
+import re
+
+import pytest
+import torch
+
+from sbibm import get_available_tasks, get_task
+from sbibm.algorithms import rej_abc
+from sbibm.metrics import c2st
+
+# maybe use the pyro facilities
+torch.manual_seed(47)
+
+task_list = [tn for tn in get_available_tasks() if not re.search("noref|lotka|sir", tn)]
+
+
+@pytest.mark.parametrize("task_name", task_list)
+def test_quick_demo_rej_abc(task_name):
+
+ task = get_task(task_name)
+ posterior_samples, _, _ = rej_abc(
+ task=task, num_samples=12, num_observation=1, num_simulations=500
+ )
+
+    assert posterior_samples is not None
+ assert posterior_samples.shape[0] == 12
+
+
+@pytest.mark.parametrize("task_name", task_list)
+def test_quick_demo_c2st(task_name):
+
+ task = get_task(task_name)
+ posterior_samples, _, _ = rej_abc(
+ task=task, num_samples=50, num_observation=1, num_simulations=500
+ )
+
+ reference_samples = task.get_reference_posterior_samples(num_observation=1)
+ c2st_accuracy = c2st(reference_samples, posterior_samples)
+
+    assert c2st_accuracy > 0.0
+    assert c2st_accuracy <= 1.0
diff --git a/tests/tasks/test_two_moons.py b/tests/tasks/test_two_moons.py
deleted file mode 100644
index cf6f094f..00000000
--- a/tests/tasks/test_two_moons.py
+++ /dev/null
@@ -1,193 +0,0 @@
-import pytest
-import torch
-
-import sbibm
-from sbibm.tasks.two_moons.task import TwoMoons
-
-torch.manual_seed(47)
-
-
-def test_task_constructs():
-
- t = TwoMoons()
-
- assert t
-
-
-def test_obtain_task():
-
- task = sbibm.get_task("two_moons")
-
- assert task is not None
-
-
-def test_obtain_prior():
-
- task = sbibm.get_task("two_moons") # See sbibm.get_available_tasks() for all tasks
- prior = task.get_prior()
-
- assert prior is not None
-
-
-def test_obtain_simulator():
-
- task = sbibm.get_task("two_moons")
-
- simulator = task.get_simulator()
-
- assert simulator is not None
-
-
-def test_observe_once():
-
- task = sbibm.get_task("two_moons")
-
- x_o = task.get_observation(num_observation=1)
-
- assert x_o is not None
- assert hasattr(x_o, "shape")
-
-
-def test_obtain_prior_samples():
-
- task = sbibm.get_task("two_moons")
- prior = task.get_prior()
- nsamples = 10
-
- thetas = prior(num_samples=nsamples)
-
- assert thetas.shape == (nsamples, 2)
-
-
-def test_simulate_from_thetas():
-
- task = sbibm.get_task("two_moons")
- prior = task.get_prior()
- sim = task.get_simulator()
- nsamples = 10
-
- thetas = prior(num_samples=nsamples)
- xs = sim(thetas)
-
- assert xs.shape == (nsamples, 2)
-
-
-def test_reference_posterior_exists():
-
- task = sbibm.get_task("two_moons")
-
- reference_samples = task.get_reference_posterior_samples(num_observation=1)
-
- assert hasattr(reference_samples, "shape")
- assert len(reference_samples.shape) == 2
- assert reference_samples.shape == (10_000, 2)
-
-
-# @pytest.fixture
-# def vanilla_samples():
-
-# task = sbibm.get_task("two_moons")
-# prior = task.get_prior()
-# sim = task.get_simulator()
-# nsamples = 1_000
-
-# thetas = prior(num_samples=nsamples)
-# xs = sim(thetas)
-
-# return task, thetas, xs
-
-
-def test_quick_demo_rej_abc():
-
- from sbibm.algorithms import rej_abc # See help(rej_abc) for keywords
-
- task = sbibm.get_task("two_moons")
- posterior_samples, _, _ = rej_abc(
- task=task, num_samples=50, num_observation=1, num_simulations=500
- )
-
- assert posterior_samples != None
- assert posterior_samples.shape[0] == 50
-
-
-def test_quick_demo_c2st():
-
- from sbibm.algorithms import rej_abc # See help(rej_abc) for keywords
-
- task = sbibm.get_task("two_moons")
- posterior_samples, _, _ = rej_abc(
- task=task, num_samples=50, num_observation=1, num_simulations=500
- )
-
- from sbibm.metrics import c2st
-
- reference_samples = task.get_reference_posterior_samples(num_observation=1)
- c2st_accuracy = c2st(reference_samples, posterior_samples)
-
- assert c2st_accuracy > 0.0
- assert c2st_accuracy < 1.0
-
-
-################################################
-## demonstrate on how to run a minimal benchmark
-## see https://github.com/sbi-benchmark/results/blob/main/benchmarking_sbi/run.py
-
-
-def test_benchmark_metrics_selfobserved():
-
- from sbibm.algorithms.sbi.snpe import run
- from sbibm.metrics.ppc import median_distance
-
- task = sbibm.get_task("two_moons")
-
- nobs = 1
- theta_o = task.get_prior()(num_samples=nobs)
- sim = task.get_simulator()
- x_o = sim(theta_o)
-
- outputs, nsim, logprob_truep = run(
- task,
- observation=x_o,
- num_samples=16,
- num_simulations=64,
- neural_net="mdn",
- num_rounds=1, # let's do NPE not SNPE (to avoid MCMC)
- )
-
- assert outputs.shape
- assert outputs.shape[0] > 0
- assert logprob_truep == None
-
- predictive_samples = sim(outputs)
- value = median_distance(predictive_samples, x_o)
-
- assert value > 0
- assert value > 0.5
-
-
-def test_benchmark_metrics():
-
- from sbibm.algorithms.sbi.snpe import run
- from sbibm.metrics.ppc import median_distance
-
- task = sbibm.get_task("two_moons")
- sim = task.get_simulator()
-
- outputs, nsim, logprob_truep = run(
- task,
- num_observation=7,
- num_samples=64,
- num_simulations=100,
- neural_net="mdn",
- num_rounds=1, # let's do NPE not SNPE (to avoid MCMC)
- )
-
- assert outputs.shape
- assert outputs.shape[0] > 0
- assert logprob_truep == None
-
- predictive_samples = sim(outputs)
- x_o = task.get_observation(7)
- value = median_distance(predictive_samples, x_o)
-
- assert value > 0
From c224755fe4cf62e366b09bb9eedbbfd37f17b965 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 11 Nov 2021 18:04:14 +0100
Subject: [PATCH 3/6] placeholder test for task-only code
---
tests/tasks/two_moons/test_task.py | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
create mode 100644 tests/tasks/two_moons/test_task.py
diff --git a/tests/tasks/two_moons/test_task.py b/tests/tasks/two_moons/test_task.py
new file mode 100644
index 00000000..7b2a4a33
--- /dev/null
+++ b/tests/tasks/two_moons/test_task.py
@@ -0,0 +1,16 @@
+import pytest
+import torch
+
+from sbibm.tasks.two_moons.task import TwoMoons
+
+torch.manual_seed(47)
+
+## a test suite that can be used for task internal code
+
+
+def test_task_constructs():
+ """this test demonstrates how to test internal task code"""
+
+ t = TwoMoons()
+
+ assert t
From 26d92f55012f28fbe086807c771ed1a46df56141 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Fri, 12 Nov 2021 10:50:55 +0100
Subject: [PATCH 4/6] removed superfluous code
- include noref tasks as we don't use the reference posterior
- add TODO for later
---
tests/tasks/test_task_benchmarks.py | 32 ++---------------------------
1 file changed, 2 insertions(+), 30 deletions(-)
diff --git a/tests/tasks/test_task_benchmarks.py b/tests/tasks/test_task_benchmarks.py
index 7814e293..c259b7c6 100644
--- a/tests/tasks/test_task_benchmarks.py
+++ b/tests/tasks/test_task_benchmarks.py
@@ -17,13 +17,13 @@
@pytest.mark.parametrize(
"task_name",
- [tn for tn in get_available_tasks() if not re.search("noref|lotka|sir", tn)],
+ [tn for tn in get_available_tasks() if not re.search("lotka|sir", tn)],
)
def test_benchmark_metrics_selfobserved(task_name):
task = get_task(task_name)
- nobs = 1
+ nobs = 1 # maybe randomly dice this?
theta_o = task.get_prior()(num_samples=nobs)
sim = task.get_simulator()
x_o = sim(theta_o)
@@ -45,31 +45,3 @@ def test_benchmark_metrics_selfobserved(task_name):
value = median_distance(predictive_samples, x_o)
assert value > 0
-
-
-# def test_benchmark_metrics():
-
-# from sbibm.algorithms.sbi.snpe import run
-# from sbibm.metrics.ppc import median_distance
-
-# task = get_task("two_moons")
-# sim = task.get_simulator()
-
-# outputs, nsim, logprob_truep = run(
-# task,
-# num_observation=7,
-# num_samples=64,
-# num_simulations=100,
-# neural_net="mdn",
-# num_rounds=1, # let's do NPE not SNPE (to avoid MCMC)
-# )
-
-# assert outputs.shape
-# assert outputs.shape[0] > 0
-# assert logprob_truep == None
-
-# predictive_samples = sim(outputs)
-# x_o = task.get_observation(7)
-# value = median_distance(predictive_samples, x_o)
-
-# assert value > 0
From 9e9181825ff38d2cd3ab0234632a387eeeff1761 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Fri, 12 Nov 2021 10:57:59 +0100
Subject: [PATCH 5/6] using pyro set_rng_seed utility to fix seed in tests
---
tests/tasks/test_task_benchmarks.py | 4 ++--
tests/tasks/test_task_interface.py | 4 ++--
tests/tasks/test_task_rej_abc_demo.py | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/tests/tasks/test_task_benchmarks.py b/tests/tasks/test_task_benchmarks.py
index c259b7c6..f6960a84 100644
--- a/tests/tasks/test_task_benchmarks.py
+++ b/tests/tasks/test_task_benchmarks.py
@@ -1,5 +1,6 @@
import re
+import pyro
import pytest
import torch
@@ -7,8 +8,7 @@
from sbibm.algorithms.sbi.snpe import run
from sbibm.metrics.ppc import median_distance
-# maybe use the pyro facilities
-torch.manual_seed(47)
+pyro.util.set_rng_seed(47)
# ################################################
# ## demonstrate on how to run a minimal benchmark
diff --git a/tests/tasks/test_task_interface.py b/tests/tasks/test_task_interface.py
index 43cc2e30..b5540e7d 100644
--- a/tests/tasks/test_task_interface.py
+++ b/tests/tasks/test_task_interface.py
@@ -1,12 +1,12 @@
import re
+import pyro
import pytest
import torch
from sbibm import get_available_tasks, get_task
-# maybe use the pyro facilities
-torch.manual_seed(47)
+pyro.util.set_rng_seed(47)
all_tasks = set(get_available_tasks())
julia_tasks = set([tn for tn in get_available_tasks() if re.search("lotka|sir", tn)])
diff --git a/tests/tasks/test_task_rej_abc_demo.py b/tests/tasks/test_task_rej_abc_demo.py
index 7c5c7bc0..6199ee54 100644
--- a/tests/tasks/test_task_rej_abc_demo.py
+++ b/tests/tasks/test_task_rej_abc_demo.py
@@ -1,5 +1,6 @@
import re
+import pyro
import pytest
import torch
@@ -7,8 +8,7 @@
from sbibm.algorithms import rej_abc
from sbibm.metrics import c2st
-# maybe use the pyro facilities
-torch.manual_seed(47)
+pyro.util.set_rng_seed(47)
task_list = [tn for tn in get_available_tasks() if not re.search("noref|lotka|sir", tn)]
From abc4a852b00641546cb6591cc72e34ee4da81b61 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Fri, 12 Nov 2021 17:04:23 +0100
Subject: [PATCH 6/6] using pyro utils to set seed
---
tests/tasks/two_moons/test_task.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/tests/tasks/two_moons/test_task.py b/tests/tasks/two_moons/test_task.py
index 7b2a4a33..14e0a456 100644
--- a/tests/tasks/two_moons/test_task.py
+++ b/tests/tasks/two_moons/test_task.py
@@ -1,9 +1,10 @@
+import pyro
import pytest
import torch
from sbibm.tasks.two_moons.task import TwoMoons
-torch.manual_seed(47)
+pyro.util.set_rng_seed(47)
## a test suite that can be used for task internal code