
Conversation

@DriessenA

For now we've only implemented cmonge_gap_from_samples; a ConditionalMongeGapEstimator will follow.

@marcocuturi marcocuturi marked this pull request as ready for review January 29, 2025 18:03
@codecov

codecov bot commented Jan 29, 2025

Codecov Report

Attention: Patch coverage is 0% with 75 lines in your changes missing coverage. Please review.

Project coverage is 87.06%. Comparing base (e6fff13) to head (b94fef6).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...eural/networks/conditional_perturbation_network.py | 0.00% | 45 Missing ⚠️ |
| src/ott/neural/methods/conditional_monge_gap.py | 0.00% | 30 Missing ⚠️ |
Additional details and impacted files


```diff
@@            Coverage Diff             @@
##             main     #605      +/-   ##
==========================================
- Coverage   87.91%   87.06%   -0.86%     
==========================================
  Files          74       76       +2     
  Lines        7657     7731      +74     
  Branches      533      539       +6     
==========================================
- Hits         6732     6731       -1     
- Misses        783      858      +75     
  Partials      142      142              
```
| Files with missing lines | Coverage Δ |
|---|---|
| src/ott/neural/methods/conditional_monge_gap.py | 0.00% <0.00%> (ø) |
| ...eural/networks/conditional_perturbation_network.py | 0.00% <0.00%> (ø) |

... and 1 file with indirect coverage changes

@DriessenA
Author

DriessenA commented May 15, 2025

Issue with ConcretizationTypeError

in file src/ott/neural/methods/conditional_monge_gap.py, lines 70-73

The conditional Monge gap from samples currently uses a .unique() call to loop over all the unique conditions in a batch. This is fine when the function is only used to calculate the gap, without optimizing a model. However, when optimizing and jitting the functions, JAX needs this array to have a consistent size, so simply looping over the unique conditions is not trivial.

Do you have any insights on how to handle and/or circumvent this? Basically, given a batch with data points from multiple conditions, is there a way to subset the batch per condition and calculate the Monge gap per condition, while still being able to JIT the function?
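For context, a minimal standalone repro of the failure mode (hypothetical code, not the actual implementation): jnp.unique returns an array whose size depends on the values of its input, which jit tracing cannot accommodate without a static size hint.

```python
import jax
import jax.numpy as jnp

@jax.jit
def per_condition_means(x, conditions):
  # jnp.unique's output size depends on the *values* in `conditions`, so
  # tracing it under jit raises ConcretizationTypeError.
  means = []
  for c in jnp.unique(conditions):
    # Boolean masking would likewise produce a value-dependent shape.
    means.append(x[conditions == c].mean(axis=0))
  return jnp.stack(means)

x = jnp.ones((6, 2))
conditions = jnp.array([0, 0, 1, 1, 2, 2])
# per_condition_means(x, conditions)  # raises ConcretizationTypeError
```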

Possible solutions, each with drawbacks:

  • Fix the number of conditions per batch (and the number of data points per condition, otherwise the same problem arises when subsetting the batch to a condition); a minimal sketch of this layout follows the list. The drawback is that this constrains the user to correctly implement a dataloader that 1) has all conditions in each batch, 2) has the same number of data points per condition, and 3) provides sufficient data points per condition to estimate the Monge gap and the cost function. This might be difficult for imbalanced datasets and, depending on the number of conditions, might lead to huge batch sizes, further constraining use due to hardware limitations. It might also break the compatibility between the MongeGapEstimator and cmonge_gap_from_samples, which are currently compatible.
  • Use single-condition batches. This means that cmonge_gap_from_samples is not used: since there is one condition, we can simply call monge_gap_from_samples. Our contribution would then rather lie in the conditional_perturbation_network, the conditional model that is optimized using the Monge gap.
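A minimal sketch of the first option, assuming the dataloader already yields a fixed (num_conditions, n_per_condition, dim) layout and that the existing unconditional estimator lives at ott.neural.methods.monge_gap:

```python
import jax
from ott.neural.methods.monge_gap import monge_gap_from_samples

def cmonge_gap_fixed_layout(source, mapped):
  # source, mapped: (num_conditions, n_per_condition, dim). All shapes are
  # static, so the loop over conditions becomes a vmap and jits cleanly.
  gaps = jax.vmap(monge_gap_from_samples)(source, mapped)
  return gaps.mean()
```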

@marcocuturi @michalk8 what would you suggest?

@jannisborn

@michalk8 any thoughts on this?

@michalk8
Collaborator

michalk8 commented Sep 1, 2025

> @michalk8 any thoughts on this?

Hi @jannisborn, sorry for the late reply!

After consulting with @marcocuturi, the best option would be to go for a slightly modified version of your first suggestion, using the segment function's first interface:

  • segment_ids_x and segment_ids_y are the conditions
  • num_segments should be the total number of unique conditions

The trickiest part is setting max_measure_size, the maximum allowed size for a specific condition in a batch. It can either be computed on the fly for each batch (which will trigger re-jitting, but for small batches with uniformly sampled conditions this should be OK), or set high enough that no subset is (at least most of the time) truncated, at the risk of wasting compute but with no re-jitting. Personally, I prefer the former option.
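As a concrete illustration of the on-the-fly variant (a hypothetical helper, assuming integer condition labels), the bound can be computed on the host per batch, so it reaches the jitted step as a static Python int:

```python
import numpy as np

def max_measure_size_for(conditions: np.ndarray) -> int:
  # Largest number of points any single condition contributes to this batch.
  # Computed outside jit; passed as a static argument, it re-triggers
  # compilation only when the value actually changes.
  return int(np.bincount(conditions).max())
```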

Please take a look at how segment_sinkhorn is implemented; it should be pretty straightforward to define an eval_fn that computes the Monge gap (note that padded_weight_x or padded_weight_y can be all 0s).
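A rough sketch of that direction (hypothetical: the keyword names follow ott.geometry.segment.segment_point_cloud as used by segment_sinkhorn, and weighted_monge_gap_from_samples stands in for a weight-aware gap that still has to be written):

```python
import jax
from ott.geometry import segment

def segmented_cmonge_gap(source, mapped, conditions,
                         num_segments, max_measure_size):
  # Pad each condition's subset to a common size; padded points receive
  # weight 0, mirroring the segment_sinkhorn implementation.
  padded_x, padded_w_x = segment.segment_point_cloud(
      source, segment_ids=conditions, num_segments=num_segments,
      max_measure_size=max_measure_size)
  padded_y, padded_w_y = segment.segment_point_cloud(
      mapped, segment_ids=conditions, num_segments=num_segments,
      max_measure_size=max_measure_size)

  def eval_fn(x, y, w_x, w_y):
    # Hypothetical weight-aware variant of monge_gap_from_samples: it must
    # discount the zero-weight padding (the current function assumes
    # uniform weights over all rows).
    return weighted_monge_gap_from_samples(x, y, w_x, w_y)

  # One Monge gap per condition, evaluated in parallel over segments.
  return jax.vmap(eval_fn)(padded_x, padded_y, padded_w_x, padded_w_y)
```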

Hope this helps!
