
Unbalanced FGW doesn't converge when margins are provided #519

@selmanozleyen

Description


Describe the bug
For the application use case, see the moscot tests in this CI run: https://github.com/theislab/moscot/actions/runs/8709537760/job/23889450330?pr=677

Unbalanced FGW is unstable, especially when the margins (`a`, `b`) are provided. I tried different values of `epsilon` and `tau_a`/`tau_b`, but it still doesn't converge. I suspect this started after commit 41906a2.
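When tuning `epsilon` and `tau_a`/`tau_b` together, it may help to keep in mind that, assuming OTT's usual convention `tau = rho / (rho + epsilon)` (where `rho` is the weight of the KL penalty on a marginal), a fixed `tau` corresponds to a different marginal penalty whenever `epsilon` changes. The helpers `tau_to_rho` and `rho_to_tau` below are hypothetical, plain-Python conversions for illustration, not OTT functions.

```python
def tau_to_rho(tau: float, epsilon: float) -> float:
    """Hypothetical helper: invert tau = rho / (rho + epsilon) to get rho."""
    if not 0.0 < tau <= 1.0:
        raise ValueError("tau must lie in (0, 1]")
    if tau == 1.0:
        return float("inf")  # tau == 1 means a hard (balanced) marginal constraint
    return epsilon * tau / (1.0 - tau)


def rho_to_tau(rho: float, epsilon: float) -> float:
    """Hypothetical helper: tau = rho / (rho + epsilon)."""
    return rho / (rho + epsilon)


# tau = 0.9 as in the reproduction, for a few values of epsilon
for eps in (1.0, 0.1, 0.01):
    print(f"epsilon={eps}: rho={tau_to_rho(0.9, eps):.4g}")
```

With the values used in the reproduction (`tau_a = tau_b = 0.9`, `epsilon = 1.0`), rho is about 9; keeping `tau = 0.9` while lowering `epsilon` to 0.1 lowers rho to about 0.9, so sweeping `epsilon` at fixed `tau` also weakens the marginal penalty.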

To Reproduce

import numpy as np
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.solvers.quadratic import solve


# Generate random data for x and y (seeded so the run is reproducible)
np.random.seed(0)
x = np.random.rand(96, 2)  # 96 points in 2D
y = np.random.rand(96, 2)  # Another 96 points in 2D

# Create PointCloud instances
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, y)

# a and b are unnormalized margins: vectors of ones whose lengths match the
# number of points in x and y, so each has total mass 96 (not 1)
a = jnp.ones(x.shape[0])
b = jnp.ones(y.shape[0])

# Solve the unbalanced fused GW problem and report whether it converged
out = solve(geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy, tau_a=0.9, tau_b=0.9,
            fused_penalty=1.0, epsilon=1.0, a=a, b=b)
print("converged:", out.converged)
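To separate the role of the margins from the outer FGW loop, the inner linear step can be looked at on its own. The sketch below is a standalone NumPy implementation of the standard scaling iterations for entropic unbalanced OT, not OTT's code; the function name `unbalanced_sinkhorn`, its stopping rule, and the comparison of mass-96 against mass-1 margins are illustrative assumptions. It uses the same `tau = rho / (rho + epsilon)` exponent convention and a squared-Euclidean cost (the `PointCloud` default), but it does not reproduce the FGW failure itself.

```python
import numpy as np


def unbalanced_sinkhorn(a, b, cost, epsilon, tau_a, tau_b, max_iter=1000, tol=1e-6):
    """Minimal unbalanced Sinkhorn (scaling form, no log-domain stabilization).

    tau = rho / (rho + epsilon); tau = 1 recovers the balanced update for
    that marginal. Returns (plan, converged, n_iterations).
    """
    kernel = np.exp(-cost / epsilon)  # Gibbs kernel K
    u = np.ones_like(a, dtype=float)
    v = np.ones_like(b, dtype=float)
    for it in range(max_iter):
        u_prev = u
        # Relaxed marginal updates: the exponent tau damps each projection
        u = (a / (kernel @ v)) ** tau_a
        v = (b / (kernel.T @ u)) ** tau_b
        # Stop once the log-scalings stop changing
        if np.max(np.abs(np.log(u) - np.log(u_prev))) < tol:
            return u[:, None] * kernel * v[None, :], True, it + 1
    return u[:, None] * kernel * v[None, :], False, max_iter


# Same shapes as the reproduction: 96 random 2D points, squared-Euclidean cost
rng = np.random.default_rng(0)
x = rng.random((96, 2))
y = rng.random((96, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)

for name, a, b in [
    ("ones (mass 96)", np.ones(96), np.ones(96)),
    ("uniform (mass 1)", np.full(96, 1 / 96), np.full(96, 1 / 96)),
]:
    plan, converged, n_iter = unbalanced_sinkhorn(a, b, cost, epsilon=1.0, tau_a=0.9, tau_b=0.9)
    print(f"{name}: converged={converged} after {n_iter} iters, plan mass={plan.sum():.4f}")
```

Because the `tau < 1` exponent damps every update, this linear step on its own is expected to settle for both margin choices; comparing it with the FGW run is one way to see whether the instability comes from the margins or from the outer iterations.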
