From 550b267bee28810688c11f33eff2124d60a9c8c3 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Thu, 11 Dec 2025 09:06:52 +0100 Subject: [PATCH 1/6] added test case --- test/core/prior/conditional_test.py | 53 ++++++++++++++++++----------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda93..17e0a03dd 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -218,6 +218,32 @@ def condition_func_3(reference_parameters, var_1, var_2): ).items(): self.conditional_priors_manually_set_items[key] = value + names = ["mvgvar_a", "mvgvar_b"] + mu = [[0.79, -0.83]] + cov = [[[0.03, 0.0], [0.0, 0.04]]] + mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) + + def condition_func_4(reference_parameters, mvgvar_a): + return dict(minimum=reference_parameters["minimum"], maximum=mvgvar_a) + + prior_4 = bilby.core.prior.ConditionalUniform( + condition_func=condition_func_4, minimum=self.minimum, maximum=self.maximum + ) + + self.conditional_priors_with_joint_prior = ( + bilby.core.prior.ConditionalPriorDict( + dict( + var_4=prior_4, + var_3=self.prior_3, + var_2=self.prior_2, + var_0=self.prior_0, + var_1=self.prior_1, + mvgvar_a=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_a"), + mvgvar_b=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_b"), + ) + ) + ) + def tearDown(self): del self.minimum del self.maximum @@ -227,6 +253,7 @@ def tearDown(self): del self.prior_3 del self.conditional_priors del self.conditional_priors_manually_set_items + del self.conditional_priors_with_joint_prior del self.test_sample def test_conditions_resolved_upon_instantiation(self): @@ -333,35 +360,23 @@ def test_rescale_with_joint_prior(self): """ # set multivariate Gaussian distribution - names = ["mvgvar_0", "mvgvar_1"] - mu = [[0.79, -0.83]] - cov = [[[0.03, 0.], [0., 0.04]]] - mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) - priordict = bilby.core.prior.ConditionalPriorDict( - dict( - var_3=self.prior_3, - var_2=self.prior_2, - var_0=self.prior_0, - var_1=self.prior_1, - mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"), - mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"), - ) - ) + priordict = self.conditional_priors_with_joint_prior + names = ["mvgvar_a", "mvgvar_b"] - ref_variables = list(self.test_sample.values()) + [0.4, 0.1] - keys = list(self.test_sample.keys()) + names + ref_variables = [0.1] + list(self.test_sample.values()) + [0.4, 0.1] + keys = ["var_4"] + list(self.test_sample.keys()) + names res = priordict.rescale(keys=keys, theta=ref_variables) self.assertIsInstance(res, list) - self.assertEqual(np.shape(res), (6,)) - self.assertListEqual([isinstance(r, float) for r in res], 6 * [True]) + self.assertEqual(np.shape(res), (7,)) + self.assertListEqual([isinstance(r, float) for r in res], 7 * [True]) # check conditional values are still as expected expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) - self.assertListEqual(expected, res[0:4]) + self.assertListEqual(expected, res[1:5]) def test_cdf(self): """ From d3770ac646733ff1a9920178e85d0ee6cf34af60 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Thu, 11 Dec 2025 09:08:26 +0100 Subject: [PATCH 2/6] ensure correct ordering with joint priors --- bilby/core/prior/dict.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 3ac54622e..d2bd660a1 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -731,11 +731,14 @@ def _resolve_conditions(self): def _check_conditions_resolved(self, key, sampled_keys): """Checks if all required variables have already been sampled so we can sample this key""" - conditions_resolved = True for k in self[key].required_variables: if k not in sampled_keys: - conditions_resolved = False - return conditions_resolved + return False + elif isinstance(self[k], JointPrior): + for name in self[k].dist.names: + if name not in sampled_keys and name != key: + return False + return True def sample_subset(self, keys=iter([]), size=None): self.convert_floats_to_delta_functions() From 791786c87602570e05d378de24edc322c04c78f6 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Thu, 11 Dec 2025 09:13:21 +0100 Subject: [PATCH 3/6] ensure correct setting of least_recently_sampled for joint priors --- bilby/core/prior/dict.py | 51 ++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index d2bd660a1..171e22197 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -877,36 +877,31 @@ def rescale(self, keys, theta): result[key] = self[key].rescale( theta[index], **self.get_required_variables(key) ) - self[key].least_recently_sampled = result[key] - if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint: - joint[self[key].dist.distname] = [key] - elif isinstance(self[key], JointPrior): - joint[self[key].dist.distname].append(key) - for names in joint.values(): - # this is needed to unpack how joint prior rescaling works - # as an example of a joint prior over {a, b, c, d} we might - # get the following based on the order within the joint prior - # {a: [], b: [], c: [1, 2, 3, 4], d: []} - # -> [1, 2, 3, 4] - # -> {a: 1, b: 2, c: 3, d: 4} - values = list() - for key in names: - values = np.concatenate([values, result[key]]) - for key, value in zip(names, values): - result[key] = value - - def safe_flatten(value): - """ - this is gross but can be removed whenever we switch to returning - arrays, flatten converts 0-d arrays to 1-d so has to be special - cased - """ - if isinstance(value, (float, int, np.int64)): - return value + + # if any requested key depends on some joint prior `jp_key` + # self[jp_key].least_recently_sampled needs to be set before + # rescaling those requested keys. + # Thus we keep track of joint priors here + if isinstance(self[key], JointPrior): + # if joint prior, keep track if all names have been rescaled + distname = self[key].dist.distname + names = set(self[key].dist.names) + if distname not in joint: + joint[distname] = [key] + elif isinstance(self[key], JointPrior): + joint[distname].append(key) + # only when all names have been rescaled, we can set the values + if set(names) == set(joint[distname]): + for name, value in zip(names, result[key]): + result[name] = value + self[name].least_recently_sampled = value + joint.pop(distname) else: - return result[key].flatten() + # if not joint prior, set value immediately + self[key].least_recently_sampled = result[key] - return [safe_flatten(result[key]) for key in keys] + # finally return results in the order requested + return [result[key] for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: From 5f41d973c84d581e53d424621b91ebff0be485d7 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Fri, 12 Dec 2025 21:21:22 +0100 Subject: [PATCH 4/6] remove redundant casts to set --- bilby/core/prior/dict.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 171e22197..a2186776b 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -887,11 +887,12 @@ def rescale(self, keys, theta): distname = self[key].dist.distname names = set(self[key].dist.names) if distname not in joint: - joint[distname] = [key] + joint[distname] = {key} elif isinstance(self[key], JointPrior): - joint[distname].append(key) + joint[distname].add(key) # only when all names have been rescaled, we can set the values - if set(names) == set(joint[distname]): + # we use sets because the order does not matter here + if names == joint[distname]: for name, value in zip(names, result[key]): result[name] = value self[name].least_recently_sampled = value From fae16eab331070254bcb613b3c4da263cab6064f Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Mon, 15 Dec 2025 12:17:26 +0100 Subject: [PATCH 5/6] remove unnessary if-condition --- bilby/core/prior/dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index a2186776b..b05eb3e76 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -888,7 +888,7 @@ def rescale(self, keys, theta): names = set(self[key].dist.names) if distname not in joint: joint[distname] = {key} - elif isinstance(self[key], JointPrior): + else: joint[distname].add(key) # only when all names have been rescaled, we can set the values # we use sets because the order does not matter here From d3fe950411d05ef23366a2ce6f40e494e5a83706 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Wed, 8 Apr 2026 08:47:01 +0200 Subject: [PATCH 6/6] fix joint prior name order --- bilby/core/prior/dict.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index b05eb3e76..25b970852 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -885,14 +885,16 @@ def rescale(self, keys, theta): if isinstance(self[key], JointPrior): # if joint prior, keep track if all names have been rescaled distname = self[key].dist.distname - names = set(self[key].dist.names) + # maintain order of names as in the dist as this is the order + # in which they will be rescaled + names = self[key].dist.names if distname not in joint: joint[distname] = {key} else: joint[distname].add(key) # only when all names have been rescaled, we can set the values # we use sets because the order does not matter here - if names == joint[distname]: + if set(names) == joint[distname]: for name, value in zip(names, result[key]): result[name] = value self[name].least_recently_sampled = value