From b3119409b5879e5c2c118b2b33950b3d0147e1c0 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Thu, 11 Dec 2025 10:22:13 +0100 Subject: [PATCH 1/3] added test-case --- test/core/prior/conditional_test.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda93..b998dfe50 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -218,6 +218,32 @@ def condition_func_3(reference_parameters, var_1, var_2): ).items(): self.conditional_priors_manually_set_items[key] = value + names = ["mvgvar_a", "mvgvar_b"] + mu = [[0.79, -0.83]] + cov = [[[0.03, 0.0], [0.0, 0.04]]] + mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) + + def condition_func_4(reference_parameters, mvgvar_a): + return dict(minimum=reference_parameters["minimum"], maximum=mvgvar_a) + + prior_4 = bilby.core.prior.ConditionalUniform( + condition_func=condition_func_4, minimum=self.minimum, maximum=self.maximum + ) + + self.conditional_priors_with_joint_prior = ( + bilby.core.prior.ConditionalPriorDict( + dict( + var_4=prior_4, + var_3=self.prior_3, + var_2=self.prior_2, + var_0=self.prior_0, + var_1=self.prior_1, + mvgvar_a=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_a"), + mvgvar_b=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_b"), + ) + ) + ) + def tearDown(self): del self.minimum del self.maximum @@ -227,6 +253,7 @@ def tearDown(self): del self.prior_3 del self.conditional_priors del self.conditional_priors_manually_set_items + del self.conditional_priors_with_joint_prior del self.test_sample def test_conditions_resolved_upon_instantiation(self): @@ -292,6 +319,8 @@ def test_sample_subset_all_keys(self): def test_sample_illegal_subset(self): with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.sample_subset(keys=["var_1"]) + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors_with_joint_prior.sample_subset(keys=["mvgvar_a"]) def test_sample_multiple(self): def condition_func(reference_params, a): From 7e3bf3af94daeb9d5004706d54ad1ecaaf07e3de Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Thu, 11 Dec 2025 10:23:21 +0100 Subject: [PATCH 2/3] Check if JointPrior is complete --- bilby/core/prior/dict.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 6d244610b..cfda4bcdf 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -709,12 +709,17 @@ def _resolve_conditions(self): 4. We set the `self._resolved` flag to True if all conditional priors were added in the right order """ - self._unconditional_keys = [ - key for key in self.keys() if not hasattr(self[key], "condition_func") - ] - conditional_keys_unsorted = [ - key for key in self.keys() if hasattr(self[key], "condition_func") - ] + conditional_keys_unsorted = [] + self._unconditional_keys = [] + joint_dists = {} + for key in self.keys(): + if not hasattr(self[key], "condition_func"): + self._unconditional_keys.append(key) + else: + conditional_keys_unsorted.append(key) + if isinstance(self[key], JointPrior): + joint_dists[self[key].dist.distname] = self[key].dist.names + self._conditional_keys = [] for _ in range(len(self)): for key in conditional_keys_unsorted[:]: @@ -726,6 +731,12 @@ def _resolve_conditions(self): if len(conditional_keys_unsorted) != 0: self._resolved = False + # ensure that all joint dist names are resolved + for names in joint_dists.values(): + for name in names: + if name not in self.sorted_keys: + self._resolved = False + def _check_conditions_resolved(self, key, sampled_keys): """Checks if all required variables have already been sampled so we can sample this key""" conditions_resolved = True From 1ef8ea326ccb1e847b71cc2cf47e1b1b8551f87c Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Thu, 11 Dec 2025 22:21:57 +0100 Subject: [PATCH 3/3] make more pythonic --- bilby/core/prior/dict.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index cfda4bcdf..364ee5483 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -733,9 +733,9 @@ def _resolve_conditions(self): # ensure that all joint dist names are resolved for names in joint_dists.values(): - for name in names: - if name not in self.sorted_keys: - self._resolved = False + if not set(names).issubset(self.sorted_keys): + self._resolved = False + break def _check_conditions_resolved(self, key, sampled_keys): """Checks if all required variables have already been sampled so we can sample this key"""