Description
Describe the bug
Solving an entropic-regularized OT problem with Sinkhorn on a grid whose side length is 2 (e.g. `grid_size=(2,)` or `grid_size=(2, 2)`) and then accessing `primal_cost` on the output raises `AttributeError: 'PointCloud' object has no attribute 'cost_1'`.
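For comparison, the value `primal_cost` should return in the reproduction below can be worked out by hand. This sketch is plain JAX, not ott. It assumes that `Grid(grid_size=(2,))` places its two points at 0.0 and 1.0 and uses a squared-Euclidean cost. With the Dirac marginals `a = [1, 0]` and `b = [0, 1]`, the only feasible coupling sends all mass from the first point to the second, so under those assumptions a converged Sinkhorn plan should give a cost close to 1.0.

```python
import jax.numpy as jnp

# Assumption: a Grid with grid_size=(2,) has points at 0.0 and 1.0
# (evenly spaced on [0, 1]) and a squared-Euclidean cost.
x = jnp.linspace(0.0, 1.0, 2)
cost = (x[:, None] - x[None, :]) ** 2  # [[0, 1], [1, 0]]

a = jnp.array([1.0, 0.0])
b = jnp.array([0.0, 1.0])

# With Dirac marginals, the only coupling whose row sums are a and whose
# column sums are b is the outer product: all mass moves from point 0 to 1.
plan = jnp.outer(a, b)

expected_primal_cost = jnp.sum(plan * cost)
print(expected_primal_cost)
```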
Version info
- OS: macOS Sequoia 15.7.3
- Python 3.14.2
- jax 0.7.2 (running on CPU)
- ott 0.6.0 (installed with `conda install -c conda-forge ott-jax`)
To Reproduce
```python
import jax.numpy as jnp
from ott.geometry.grid import Grid
from ott.problems.linear.linear_problem import LinearProblem
from ott.solvers.linear.sinkhorn import Sinkhorn

grid = Grid(grid_size=(2,))
prob = LinearProblem(grid, a=jnp.array([1,0]), b=jnp.array([0,1]))
out = Sinkhorn()(prob)
print(out.primal_cost)  # raises AttributeError
```
**Traceback**
```
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[13], line 9
7 prob = LinearProblem(grid, a=jnp.array([1,0]), b=jnp.array([0,1]))
8 out = Sinkhorn()(prob)
----> 9 print(out.primal_cost)
File ~/anaconda3/envs/ottjax/lib/python3.14/site-packages/ott/solvers/linear/sinkhorn.py:348, in SinkhornOutput.primal_cost(self)
345 @property
346 def primal_cost(self) -> jnp.ndarray:
347 """Return transport cost of current transport solution at geometry."""
--> 348 return self.transport_cost_at_geom(other_geom=self.geom)
File ~/anaconda3/envs/ottjax/lib/python3.14/site-packages/ott/solvers/linear/sinkhorn.py:415, in SinkhornOutput.transport_cost_at_geom(self, other_geom)
412 # TODO(cuturi): handle online mode for non Euclidean pointcloud geometries.
413 # TODO(michalk8): handle SqEucl point cloud is not converted to LRCGeom
414 if other_geom.can_LRC:
--> 415 geom = other_geom.to_LRCGeometry()
416 return jnp.sum(self.apply(geom.cost_1.T) * geom.cost_2.T)
417 return jnp.sum(self.matrix * other_geom.cost_matrix)
File ~/anaconda3/envs/ottjax/lib/python3.14/site-packages/ott/geometry/grid.py:393, in Grid.to_LRCGeometry(self, scale, **kwargs)
381 for dimension, geom in enumerate(self.geometries):
382 # An overall low-rank conversion of the cost matrix on a grid, to an
383 # object of :class:`~ott.geometry.low_rank.LRCGeometry`, necesitates an
(...) 390 # decomposition, the parameter `rank` is set to `0`, triggering a full
391 # singular value decomposition if needed.
392 geom = geom.to_LRCGeometry(rank=0, scale=scale, **kwargs)
--> 393 c_1, c_2 = geom.cost_1, geom.cost_2
394 l, r = self.grid_size[:dimension], self.grid_size[dimension + 1:]
395 l = int(np.prod(np.array(l)))
AttributeError: 'PointCloud' object has no attribute 'cost_1'
```
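The traceback shows the failure path: `transport_cost_at_geom` takes the low-rank branch because `can_LRC` is true, `Grid.to_LRCGeometry` converts each per-dimension geometry with `to_LRCGeometry(rank=0, ...)`, and for this grid that conversion hands back a `PointCloud`, which has no `cost_1`/`cost_2` factors. The sketch below is not ott's code: `DenseGeom`, `LowRankGeom` and `transport_cost` are hypothetical stand-ins. It illustrates the identity the low-rank branch relies on, namely that for a cost `C = cost_1 @ cost_2.T` the value `<P, C>` can be computed without forming `C`, plus one possible defensive fallback to the dense cost when the factors are missing. Whether ott should fix this with such a fallback or by always returning factors is left open.

```python
import jax.numpy as jnp


class DenseGeom:
    """Hypothetical geometry that only exposes a dense cost matrix."""

    def __init__(self, cost_matrix):
        self.cost_matrix = cost_matrix


class LowRankGeom:
    """Hypothetical geometry whose cost is cost_1 @ cost_2.T."""

    def __init__(self, cost_1, cost_2):
        self.cost_1 = cost_1
        self.cost_2 = cost_2

    @property
    def cost_matrix(self):
        return self.cost_1 @ self.cost_2.T


def transport_cost(plan, geom):
    """Return <plan, C>, using low-rank factors when the geometry has them."""
    if hasattr(geom, "cost_1") and hasattr(geom, "cost_2"):
        # sum_ij P_ij sum_k C1_ik C2_jk == sum_jk (P^T C1)_jk C2_jk,
        # so the n x m cost matrix is never materialized.
        return jnp.sum((plan.T @ geom.cost_1) * geom.cost_2)
    # Fallback for a "conversion" that returned a geometry without factors
    # (the situation in the traceback): use the dense cost instead.
    return jnp.sum(plan * geom.cost_matrix)


# Exact rank-3 factorization of the 1D squared-Euclidean cost:
# (x - y)^2 = x^2 * 1 + 1 * y^2 + (-2x) * y
x = jnp.array([0.0, 1.0])
ones = jnp.ones_like(x)
cost_1 = jnp.stack([x**2, ones, -2.0 * x], axis=1)
cost_2 = jnp.stack([ones, x**2, x], axis=1)

lr = LowRankGeom(cost_1, cost_2)
dense = DenseGeom(lr.cost_matrix)
plan = jnp.outer(jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))

print(transport_cost(plan, lr), transport_cost(plan, dense))
```

Both calls evaluate the same inner product; only the low-rank path avoids building the full cost matrix, which is why ott prefers it when `can_LRC` is true.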