-
Notifications
You must be signed in to change notification settings - Fork 50
Feature/nn #48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Feature/nn #48
Conversation
|
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
ryan112358
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the hard work on this, looks like a great start! Will provide another round of feedback after my initial comments are resolved
src/mbi/relaxed_projection.py
Outdated
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import random, jit, value_and_grad, vmap, lax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's stick to the convention of just "import jax" and "jax.numpy as jnp, and not import these additional members
| @@ -0,0 +1,593 @@ | |||
| import numpy as np | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a unit test to test_estimation.py to ensure this works as expected when no noise is added?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A non-DP example is added
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
examples won't necessarily be run by pytest or github actions, can you port over a simplified version of the example to test_estimation.py?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
finished
ryan112358
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Appreciate all the work you've put into this so far!
src/mbi/relaxed_projection.py
Outdated
| D_start = _initialize_synthetic_dataset(key, num_generated_points=D_size, data_dimension=np.sum(domain.shape)) | ||
|
|
||
| stat_dim = _obtain_dim(measurements = measurements) | ||
| statistics = [MarginalStatistics(domain, dim) for dim in stat_dim] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the MarginalStatistics class should not be necessary if you have a MarginalLossFn. The pattern could be:
marginals = { cl : D.project(cl) for cl in cliques }
loss = loss_fn(marginals)
And this whole computation should be differentiable by jax
src/mbi/relaxed_projection.py
Outdated
| from .domain import Domain | ||
| from .marginal_loss import LinearMeasurement | ||
| from .clique_vector import CliqueVector | ||
| from .estimation import mirror_descent |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You shouldn't need this dependency here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean mirror_descent or all of these four dependencies? Domain and LinearMeasurement are imported for signatures of function, while others, except mirror_descent, are used in the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just mirror descent
|
I pushed the most recent code to this branch. I went through previous comments and addressed most of them. One exception is the two-case loss function in RP, for which I left some replies. |
Support relaxed projection and neural network generators.
Two simple examples are provided in
mechanisms/NN.pyandmechanisms/RP.py