Issue152#154

Draft
jsv1206 wants to merge 9 commits into main from issue152

jsv1206 commented Sep 10, 2025

This PR adds Hierarchy model functionality: reading in the data and model, and plotting the parity, CDFRank, and rank plots.

The PR also includes tests for these plots, and they run without errors.

jsv1206 (Author) commented Sep 10, 2025

Here are the plots that are currently produced for the Hierarchy Model:

Parity plots:
3a131e033a_global_parity
3a131e033a_local_parity

Rank plots:
3bf062ccc3_global_ranks
3bf062ccc3_local_ranks

CDFRank plots:
9d8c9faf47_global_cdf_ranks
9d8c9faf47_local_cdf_ranks

Contributor

Is there a way to avoid directly including this whole bit of source code? It feels a little dishonest because it is Rachel's code and not ours. Does she have a repo with it somewhere we can pull from?


self.prior_dist = self.load_prior(prior, prior_kwargs)
self.n_dims = self.thetas.shape[1]
# Uncomment this for NPE
voetberg (Contributor) commented Sep 10, 2025

Maybe the best way to do this is to check the type of self.thetas and only set n_dims if it's an array or something similar.
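One possible reading of this suggestion, as a minimal sketch: infer `n_dims` only when `thetas` exposes a two-dimensional `shape`, and fall back to `None` otherwise (for example, if `thetas` turns out to be a dict). The helper name `infer_n_dims` and the `FakeArray` stand-in are hypothetical and not part of the PR.

```python
from typing import Any, Optional


def infer_n_dims(thetas: Any) -> Optional[int]:
    """Return the parameter dimension if `thetas` looks like a 2D array, else None."""
    shape = getattr(thetas, "shape", None)
    # NumPy arrays and torch tensors both expose a tuple-like `shape`.
    if shape is not None and len(shape) == 2:
        return int(shape[1])
    return None


class FakeArray:
    """Hypothetical stand-in for an array type, used only for illustration."""

    def __init__(self, shape):
        self.shape = shape
```

With this approach the caller can skip any logic that depends on `n_dims` when it is `None`, instead of failing on `self.thetas.shape[1]`.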

context array
"""
try:
    return self.data["context"][0] # has to be made more general because in pendulum time is fixed. So we can do this
Contributor

Flagging this as a TODO

y_global = y_global[:, 0, :]
return y_local, y_global

except KeyError:
Contributor

Can we check for the local and global parameters before this? I am afraid this KeyError will cause confusion. Maybe in the init, after the load, we verify they're in the keys and otherwise throw a ValueError.
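A minimal sketch of that up-front check, assuming the loaded data behaves like a mapping. The key names `local_thetas` and `global_thetas` and the helper `verify_required_keys` are hypothetical; the real field names in the data file may differ.

```python
from typing import Iterable, Mapping

# Hypothetical key names; the real field names in the data file may differ.
REQUIRED_KEYS = ("local_thetas", "global_thetas")


def verify_required_keys(data: Mapping, required: Iterable[str] = REQUIRED_KEYS) -> None:
    """Raise a ValueError naming every required field missing from `data`."""
    missing = [key for key in required if key not in data]
    if missing:
        raise ValueError(
            f"Data is missing required field(s): {', '.join(missing)}. "
            f"Available fields: {', '.join(sorted(data))}."
        )
```

Calling this once at the end of `__init__` would surface a descriptive error at load time rather than a bare `KeyError` deep inside a plotting call.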

    return y_local, y_global

except KeyError:
    raise NotImplementedError("Data does not have a `thetas` field.")
Contributor

I don't think a NotImplementedError is the right error here - I'd make it a ValueError or something like that
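As a sketch of what that change could look like, the lookup below re-raises the `KeyError` as a `ValueError` and chains the original exception so the traceback is preserved. The function name `get_thetas` is hypothetical; only the error message is taken from the diff.

```python
from typing import Any, Mapping


def get_thetas(data: Mapping[str, Any]) -> Any:
    """Return the `thetas` field, raising ValueError if it is absent."""
    try:
        return data["thetas"]
    except KeyError as err:
        # ValueError signals bad input; NotImplementedError would suggest
        # a missing feature rather than a malformed data file.
        raise ValueError("Data does not have a `thetas` field.") from err
```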

Returns:
torch.Tensor: Samples from the local thetas
"""
deep_set = self.model.deep_set
Contributor

I'm worried about this line. Not every model will have a deep_set option. Is there a way we can set this up with a different model?

If we're importing the model class anyways, can we try to assume the model itself has a sample method we can work with?
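A rough sketch of that idea: instead of reaching into `model.deep_set`, the sampling code could rely on a generic `sample` method and fail clearly when it is absent. The `sample(context, n_samples)` signature, the `draw_samples` helper, and `ToyModel` are assumptions made for illustration; the project's `Model` class may expose a different interface.

```python
from typing import Any


def draw_samples(model: Any, context: Any, n_samples: int) -> Any:
    """Draw posterior samples through a generic `sample` method.

    Assumes (hypothetically) a `sample(context, n_samples)` signature; the
    actual signature of the project's Model class may differ.
    """
    sample = getattr(model, "sample", None)
    if not callable(sample):
        raise TypeError(
            f"{type(model).__name__} does not provide a callable `sample` method."
        )
    return sample(context, n_samples)


class ToyModel:
    """Hypothetical model used only to exercise the helper."""

    def sample(self, context, n_samples):
        return [context] * n_samples
```

Any architecture-specific detail, such as a deep set embedding, would then live inside the model's own `sample` implementation rather than in the diagnostics code.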


from deepdiagnostics.models.model import Model

# import sys
Contributor

Reminder to remove this


def _data_setup(self, **kwargs) -> DataDisplay:
    gs = kwargs.pop("global_samples", self.global_samples)
    gs = bool(gs)
Contributor

Why not just use the self.global_samples option? This option will never be run without the init

gs = bool(gs)

# support attribute or callable access
x = self.data.simulator_outcome() if callable(self.data.simulator_outcome) else self.data.simulator_outcome
Contributor

I'm a little confused by this. Why would the simulator outcome be a callable?


true_samples[index] = self.data.thetas[sample, :]

# print shape of arrays
Contributor

Reminder to remove this

else:
parity_plot = subplots[theta_dimension]
subplots[0].set_ylabel("Parity")
print('theta_dimension', theta_dimension)
Contributor

Reminder to remove printing here


"""

def __init__(self, model, data, global_samples: bool = True, **kwargs):
Contributor

include_residual and global_samples should probably be kwargs for the plot method. The docstrings should list the args for plot but not for init.
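A hedged sketch of how the arguments might move from `__init__` to `plot`. The class name `ParityDisplaySketch`, the default values, and the meaning given to `include_residual` are assumptions for illustration only; the method returns its settings instead of drawing so the example has no plotting dependency.

```python
class ParityDisplaySketch:
    """Hypothetical display class; only the argument flow is illustrated."""

    def __init__(self, model, data, **kwargs):
        # Construction keeps only what every plot needs.
        self.model = model
        self.data = data

    def plot(self, include_residual: bool = False, global_samples: bool = True, **kwargs):
        """Describe the requested plot.

        Args:
            include_residual: Whether to add a residual panel (assumed meaning).
            global_samples: Use global parameters if True, local otherwise.
        """
        # A real implementation would draw figures; this returns the settings
        # so the example stays dependency-free.
        return {
            "include_residual": include_residual,
            "parameters": "global" if global_samples else "local",
        }
```

With the options on `plot`, one display instance can produce both the global and the local figures, for example `display.plot(global_samples=False)`.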

2 participants