Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions bilby/core/prior/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,12 +709,17 @@ def _resolve_conditions(self):
4. We set the `self._resolved` flag to True if all conditional
priors were added in the right order
"""
self._unconditional_keys = [
key for key in self.keys() if not hasattr(self[key], "condition_func")
]
conditional_keys_unsorted = [
key for key in self.keys() if hasattr(self[key], "condition_func")
]
conditional_keys_unsorted = []
self._unconditional_keys = []
joint_dists = {}
for key in self.keys():
if not hasattr(self[key], "condition_func"):
self._unconditional_keys.append(key)
else:
conditional_keys_unsorted.append(key)
if isinstance(self[key], JointPrior):
joint_dists[self[key].dist.distname] = self[key].dist.names

self._conditional_keys = []
for _ in range(len(self)):
for key in conditional_keys_unsorted[:]:
Expand All @@ -726,6 +731,12 @@ def _resolve_conditions(self):
if len(conditional_keys_unsorted) != 0:
self._resolved = False

# ensure that all joint dist names are resolved
for names in joint_dists.values():
if not set(names).issubset(self.sorted_keys):
self._resolved = False
break

def _check_conditions_resolved(self, key, sampled_keys):
"""Checks if all required variables have already been sampled so we can sample this key"""
conditions_resolved = True
Expand Down
29 changes: 29 additions & 0 deletions test/core/prior/conditional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,32 @@ def condition_func_3(reference_parameters, var_1, var_2):
).items():
self.conditional_priors_manually_set_items[key] = value

names = ["mvgvar_a", "mvgvar_b"]
mu = [[0.79, -0.83]]
cov = [[[0.03, 0.0], [0.0, 0.04]]]
mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov)

def condition_func_4(reference_parameters, mvgvar_a):
return dict(minimum=reference_parameters["minimum"], maximum=mvgvar_a)

prior_4 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_4, minimum=self.minimum, maximum=self.maximum
)

self.conditional_priors_with_joint_prior = (
bilby.core.prior.ConditionalPriorDict(
dict(
var_4=prior_4,
var_3=self.prior_3,
var_2=self.prior_2,
var_0=self.prior_0,
var_1=self.prior_1,
mvgvar_a=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_a"),
mvgvar_b=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_b"),
)
)
)

def tearDown(self):
del self.minimum
del self.maximum
Expand All @@ -227,6 +253,7 @@ def tearDown(self):
del self.prior_3
del self.conditional_priors
del self.conditional_priors_manually_set_items
del self.conditional_priors_with_joint_prior
del self.test_sample

def test_conditions_resolved_upon_instantiation(self):
Expand Down Expand Up @@ -292,6 +319,8 @@ def test_sample_subset_all_keys(self):
def test_sample_illegal_subset(self):
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors.sample_subset(keys=["var_1"])
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors_with_joint_prior.sample_subset(keys=["mvgvar_a"])

def test_sample_multiple(self):
def condition_func(reference_params, a):
Expand Down
Loading