
SinkhornOutput.primal_cost crashes for grid of side length 2 #669

@mosco


**Describe the bug**

Solving an entropy-regularized OT problem with Sinkhorn on a grid of side length 2 (e.g. `grid_size=(2,)` or `grid_size=(2, 2)`) and then accessing `primal_cost` on the result raises `AttributeError: 'PointCloud' object has no attribute 'cost_1'`.

**Version info**

- OS: macOS Sequoia 15.7.3
- Python 3.14.2
- jax 0.7.2 (running on CPU)
- ott 0.6.0 (installed using `conda install -c conda-forge ott-jax`)

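**To Reproduce**

```python
import jax.numpy as jnp
from ott.geometry.grid import Grid
from ott.problems.linear.linear_problem import LinearProblem
from ott.solvers.linear.sinkhorn import Sinkhorn

# One-dimensional grid with two points; all mass moves from point 0 to point 1.
grid = Grid(grid_size=(2,))
prob = LinearProblem(grid, a=jnp.array([1,0]), b=jnp.array([0,1]))
out = Sinkhorn()(prob)
# Raises AttributeError: 'PointCloud' object has no attribute 'cost_1'
print(out.primal_cost)
```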
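As a reference point, the value `primal_cost` should return for this toy problem can be worked out by hand. The NumPy sketch below does that, assuming (not verified against the ott source) that `Grid(grid_size=(2,))` places its two points at 0 and 1 and uses the squared-Euclidean cost by default. Because the marginals a = [1, 0] and b = [0, 1] leave exactly one feasible coupling, the entropic Sinkhorn solution should coincide with it up to the solver's tolerance, so the expected cost is the dense sum <P, C>. This is the same quantity as the fallback `jnp.sum(self.matrix * other_geom.cost_matrix)` in the traceback.

```python
import numpy as np

# Assumption (not checked against ott): a side-2 grid on [0, 1] has points 0 and 1.
x = np.linspace(0.0, 1.0, 2)

# Squared-Euclidean cost between grid points: C[i, j] = (x_i - x_j)^2.
C = (x[:, None] - x[None, :]) ** 2

# Marginals from the reproduction script.
a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])

# The only nonnegative matrix with row sums a and column sums b sends all
# mass from point 0 to point 1; np.outer(a, b) happens to be that matrix here.
P = np.outer(a, b)
assert np.allclose(P.sum(axis=1), a) and np.allclose(P.sum(axis=0), b)

# Dense transport cost <P, C>.
primal_cost = np.sum(P * C)
print(primal_cost)  # 1.0
```

So, once the crash is fixed, `out.primal_cost` would be expected to be close to 1.0 under these assumptions.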

**Traceback**

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[13], line 9
      7 prob = LinearProblem(grid, a=jnp.array([1,0]), b=jnp.array([0,1]))
      8 out = Sinkhorn()(prob)
----> 9 print(out.primal_cost)

File ~/anaconda3/envs/ottjax/lib/python3.14/site-packages/ott/solvers/linear/sinkhorn.py:348, in SinkhornOutput.primal_cost(self)
    345 @property
    346 def primal_cost(self) -> jnp.ndarray:
    347   """Return transport cost of current transport solution at geometry."""
--> 348   return self.transport_cost_at_geom(other_geom=self.geom)

File ~/anaconda3/envs/ottjax/lib/python3.14/site-packages/ott/solvers/linear/sinkhorn.py:415, in SinkhornOutput.transport_cost_at_geom(self, other_geom)
    412 # TODO(cuturi): handle online mode for non Euclidean pointcloud geometries.
    413 # TODO(michalk8): handle SqEucl point cloud is not converted to LRCGeom
    414 if other_geom.can_LRC:
--> 415   geom = other_geom.to_LRCGeometry()
    416   return jnp.sum(self.apply(geom.cost_1.T) * geom.cost_2.T)
    417 return jnp.sum(self.matrix * other_geom.cost_matrix)

File ~/anaconda3/envs/ottjax/lib/python3.14/site-packages/ott/geometry/grid.py:393, in Grid.to_LRCGeometry(self, scale, **kwargs)
    381 for dimension, geom in enumerate(self.geometries):
    382   # An overall low-rank conversion of the cost matrix on a grid, to an
    383   # object of :class:`~ott.geometry.low_rank.LRCGeometry`, necesitates an
   (...)    390   # decomposition, the parameter `rank` is set to `0`, triggering a full
    391   # singular value decomposition if needed.
    392   geom = geom.to_LRCGeometry(rank=0, scale=scale, **kwargs)
--> 393   c_1, c_2 = geom.cost_1, geom.cost_2
    394   l, r = self.grid_size[:dimension], self.grid_size[dimension + 1:]
    395   l = int(np.prod(np.array(l)))

AttributeError: 'PointCloud' object has no attribute 'cost_1'
```
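The failing branch of `transport_cost_at_geom` (line 416 above) computes the cost from low-rank factors instead of a dense matrix, so it needs the converted geometry to expose `cost_1` and `cost_2` with C = cost_1 @ cost_2.T. Here the per-axis `to_LRCGeometry(rank=0, ...)` call in `Grid.to_LRCGeometry` hands back a `PointCloud` without those attributes, which matches the TODO at line 413 about a squared-Euclidean point cloud not being converted. The NumPy sketch below illustrates the identity that branch relies on, using the rank-3 factorization (x - y)^2 = x^2 * 1 + 1 * y^2 + (-2x) * y for a 1-D squared-Euclidean cost. The helper `transport_cost` and its `hasattr` guard are hypothetical, not ott code; `SimpleNamespace` objects stand in for the geometries, and `apply` is modelled as `cost_1.T @ P`, an assumption inferred from the shapes in line 416.

```python
from types import SimpleNamespace

import numpy as np


def sqeucl_factors(x, y):
    # Rank-3 factors of the 1-D squared-Euclidean cost:
    # (x_i - y_j)^2 = x_i^2 * 1 + 1 * y_j^2 + (-2 x_i) * y_j
    cost_1 = np.stack([x**2, np.ones_like(x), -2.0 * x], axis=1)  # (n, 3)
    cost_2 = np.stack([np.ones_like(y), y**2, y], axis=1)  # (m, 3)
    return cost_1, cost_2


def transport_cost(P, geom):
    # Hypothetical guard (not ott code): take the low-rank branch only when
    # the geometry actually carries `cost_1` and `cost_2`.
    if hasattr(geom, "cost_1") and hasattr(geom, "cost_2"):
        # Mirrors `self.apply(geom.cost_1.T) * geom.cost_2.T`, with `apply`
        # modelled as `cost_1.T @ P`, which has shape (3, m).
        return np.sum((geom.cost_1.T @ P) * geom.cost_2.T)
    # Dense fallback, mirroring `jnp.sum(self.matrix * other_geom.cost_matrix)`.
    return np.sum(P * geom.cost_matrix)


rng = np.random.default_rng(0)
x, y = rng.random(4), rng.random(5)
P = rng.random((4, 5))  # any nonnegative matrix is enough to check the identity
C = (x[:, None] - y[None, :]) ** 2

c1, c2 = sqeucl_factors(x, y)
low_rank_geom = SimpleNamespace(cost_1=c1, cost_2=c2)
dense_geom = SimpleNamespace(cost_matrix=C)  # stand-in with no `cost_1`

print(np.isclose(transport_cost(P, low_rank_geom), transport_cost(P, dense_geom)))  # True
```

Both paths give the same number; the low-rank one just never forms the n x m cost matrix, which is why it is preferred when the factors exist.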
