Skip to content

Conversation

@KaiChen9909
Copy link

Support relaxed projection and neural network generators.

Two simple examples are provided in mechanisms/NN.py and mechanisms/RP.py

@google-cla
Copy link

google-cla bot commented Aug 2, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Copy link
Owner

@ryan112358 ryan112358 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the hard work on this, looks like a great start! Will provide another round of feedback after my initial comments are resolved


import jax
import jax.numpy as jnp
from jax import random, jit, value_and_grad, vmap, lax
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's stick to the convention of just "import jax" and "jax.numpy as jnp, and not import these additional members

@@ -0,0 +1,593 @@
import numpy as np
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a unit test to test_estimation.py to ensure this works as expected when no noise is added?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A non-DP example is added

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

examples won't necessarily be run by pytest or github actions, can you port over a simplified version of the example to test_estimation.py?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

finished

Copy link
Owner

@ryan112358 ryan112358 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Appreciate all the work you've put into this so far!

D_start = _initialize_synthetic_dataset(key, num_generated_points=D_size, data_dimension=np.sum(domain.shape))

stat_dim = _obtain_dim(measurements = measurements)
statistics = [MarginalStatistics(domain, dim) for dim in stat_dim]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the MarginalStatistics class should not be necessary if you have a MarginalLossFn. The pattern could be:

marginals = { cl : D.project(cl) for cl in cliques }
loss = loss_fn(marginals)

And this whole computation should be differentiable by jax

from .domain import Domain
from .marginal_loss import LinearMeasurement
from .clique_vector import CliqueVector
from .estimation import mirror_descent
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need this dependency here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean mirror_descent or all of these four dependencies? Domain and LinearMeasurement are imported for signatures of function, while others, except mirror_descent, are used in the function.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just mirror descent

@KaiChen9909
Copy link
Author

I pushed the most recent code to this branch. I went through previous comments and addressed most of them. One exception is the two-case loss function in RP, for which I left some replies.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants