Describe the bug
For the application use case, see the moscot test run: https://github.com/theislab/moscot/actions/runs/8709537760/job/23889450330?pr=677
Unbalanced FGW is unstable, especially when marginals are provided. I tried different values of epsilon and tau, but it still doesn't converge. I think this started after commit 41906a2.
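To make the epsilon/tau interaction concrete: assuming OTT's convention that the unbalancedness parameter is `tau = rho / (rho + epsilon)`, where `rho` weights the soft marginal penalty, the sketch below converts between the two. The helper names `tau_to_rho` and `rho_to_tau` are hypothetical and not part of OTT.

```python
def tau_to_rho(tau: float, epsilon: float) -> float:
    """Invert the assumed convention tau = rho / (rho + epsilon) to recover rho."""
    if not 0.0 < tau <= 1.0:
        raise ValueError("tau must lie in (0, 1]")
    if tau == 1.0:
        # tau = 1 corresponds to hard (balanced) marginal constraints
        return float("inf")
    return epsilon * tau / (1.0 - tau)


def rho_to_tau(rho: float, epsilon: float) -> float:
    """Forward map of the assumed convention: tau = rho / (rho + epsilon)."""
    return rho / (rho + epsilon)


# Values from the reproducer: epsilon=1.0, tau_a=tau_b=0.9
print(tau_to_rho(0.9, epsilon=1.0))
# Same tau with a smaller epsilon implies a much weaker marginal penalty
print(tau_to_rho(0.9, epsilon=0.1))
```

Under this assumption, keeping `tau_a = tau_b = 0.9` fixed while sweeping `epsilon` also rescales `rho` (about 9 for `epsilon=1.0`, about 0.9 for `epsilon=0.1`), so the two knobs are not independent when tuning.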
To Reproduce
```python
import numpy as np
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.solvers.quadratic import solve

# Fix the seed so the same random data is generated on every run
np.random.seed(0)

# Generate random data for x and y
x = np.random.rand(96, 2)  # 96 points in 2D
y = np.random.rand(96, 2)  # another 96 points in 2D

# Create PointCloud instances: intra-domain geometries for the quadratic (GW) term,
# cross-domain geometry for the fused linear term
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, y)

# a and b are (unnormalized) vectors of ones with lengths matching the number of points in x and y
a = jnp.ones(x.shape[0])
b = jnp.ones(y.shape[0])

# Unbalanced fused GW: tau_a, tau_b < 1 relax the marginal constraints
out = solve(
    geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy,
    tau_a=0.9, tau_b=0.9, fused_penalty=1.0, epsilon=1.0, a=a, b=b,
)
print(out.converged)  # check whether the solver reports convergence
```
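With `tau_a, tau_b < 1`, the hard marginal constraints are replaced by soft penalties. As a sketch, assuming the standard KL-penalized unbalanced formulation (penalty `rho * KL(P 1 | a)` using the generalized KL divergence), the snippet below evaluates that term for the all-ones marginals from the reproducer, which are unnormalized and each sum to 96. The functions `generalized_kl` and `marginal_penalty` are hypothetical illustrations, not OTT APIs; `rho=9.0` assumes the convention `tau = rho / (rho + epsilon)` with `tau=0.9`, `epsilon=1.0`.

```python
import math
from typing import Sequence


def generalized_kl(p: Sequence[float], q: Sequence[float]) -> float:
    """Generalized KL(p | q) = sum p*log(p/q) - p + q for nonnegative vectors.

    Unlike the usual KL divergence, p and q need not sum to 1 (or to the
    same total), which is what an unbalanced marginal penalty relies on.
    """
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0.0:  # convention: 0 * log(0 / q) = 0
            total += pi * math.log(pi / qi)
        total += qi - pi
    return total


def marginal_penalty(row_sums: Sequence[float], a: Sequence[float], rho: float) -> float:
    """Assumed soft-marginal term rho * KL(P 1 | a), replacing the constraint P 1 = a."""
    return rho * generalized_kl(row_sums, a)


n = 96
a = [1.0] * n                 # unnormalized marginal from the reproducer (total mass 96)
a_normalized = [1.0 / n] * n  # the same marginal rescaled to total mass 1

# A plan whose row sums match a exactly pays no penalty
print(marginal_penalty(a, a, rho=9.0))
# A plan that carries only total mass 1 pays a positive penalty against a
print(marginal_penalty(a_normalized, a, rho=9.0))
```

This is only meant to show how the scale of `a` and `b` enters the relaxed objective; it is not OTT's implementation.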