Skip to content

Conversation

@JasperMartins
Copy link
Contributor

@JasperMartins JasperMartins commented Dec 12, 2025

This PR addresses the issues in ConditionalPriorDict.rescale described in the first item in #1026

Just to reiterate, the addressed issues are:

  • Currently, if any prior in ConditionalPriorDict depends on a JointPrior, the correct ordering of the keys is not ensured for rescaling because all keys associated with the JointPrior need to be rescaled first.
  • Even if the order of keys was correct, rescaling would fail because the least_recently_sampled property of the joint priors is not updated before it is accessed by get_required_variables.

Note: The safe_flatten function introduced in #979 to ConditionalPriorDict.rescale has a bug that if any of the values was an array, the returned flattened array takes result[key].flatten() where it should use value.flatten(). However, I believe the flattening operation is redundant now anyway. It was previously only used to flatten the output of JointPriors, I believe, which is now handled separately, and the flattening is never triggered by any of the test cases. (Note that if one attempts to rescale many samples at once, which works in many cases and is explicitly supported by some priors' doc-strings, this bug would lead to a wrong output. However, flattening would also not be necessary here.)

Copy link
Collaborator

@ColmTalbot ColmTalbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for breaking this down, I think the logic makes sense. I just had a couple of minor questions to make sure we aren't doing unnecessary conversions between lists/sets.

It's great to see the end of safe_flatten!

if isinstance(self[key], JointPrior):
# if joint prior, keep track if all names have been rescaled
distname = self[key].dist.distname
names = set(self[key].dist.names)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like it shouldn't be needed, do we not impose that parameter names are unique?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cast to set is made because it does not matter in which order the keys occur in dist.names or joint[distname]. This could also be solved with a loop, if that would be preferred.

distname = self[key].dist.distname
names = set(self[key].dist.names)
if distname not in joint:
joint[distname] = [key]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is only used as a set, why not just construct this directly as a set?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, and there were other redundant casts. Fixed with a new commit!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Noticed there was one more simplification possible

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants