[MRG] Add LazyTensor for large scale OT (#544)
* add LazyTensor

* test marginals for OTResult

* test marginals for OTResult

* debug tensorflow and implement reduce function for LazyTensors

* clean up tensorflow stuff and debug backend detection in reduce function

* comments from agramfor

* remove trailing space

* address all comments from alex

* add simple low rank lazytensor creator and add it to log for factored OT

* better test coverage

* update doc function

* pep8
rflamary authored Oct 31, 2023
1 parent 6a29551 commit 53dde7a
Showing 5 changed files with 441 additions and 4 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
@@ -12,6 +12,7 @@
+ New LP solvers from scipy used by default for LP barycenter (PR #537)
+ Update wheels to Python 3.12 and remove old i686 arch that does not have scipy wheels (PR #543)
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
+ Add LazyTensor for modeling plans and low rank tensors in large scale OT (PR #544)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
3 changes: 2 additions & 1 deletion ot/factored.py
@@ -7,7 +7,7 @@
# License: MIT License

from .backend import get_backend
from .utils import dist, get_lowrank_lazytensor
from .lp import emd
from .bregman import sinkhorn

@@ -139,6 +139,7 @@ def solve_ot(X1, X2, w1, w2):
                   'vb': logb['v'],
                   'costa': loga['cost'],
                   'costb': logb['cost'],
                   'lazy_plan': get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx),
                   }
        return Ga, Gb, X, log_dic

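With this change, the factored solver's log exposes the implicit coupling between the two point clouds as a LazyTensor under the 'lazy_plan' key. A minimal usage sketch (sizes are illustrative; the call relies on the solver's existing signature):

import numpy as np
import ot

X1 = np.random.randn(30, 2)
X2 = np.random.randn(40, 2)
Ga, Gb, X, log = ot.factored_optimal_transport(X1, X2, r=4, log=True)

P = log['lazy_plan']  # LazyTensor of shape (30, 40), stored in factored form
row = P[0, :]         # one row of the plan, computed on the fly
full = P[:]           # dense plan; only advisable for small problems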
272 changes: 269 additions & 3 deletions ot/utils.py
@@ -492,6 +492,121 @@ def get_coordinate_circle(x):
return x_t


def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
    """ Reduce a LazyTensor along an axis with function func using batches.

    When axis=None, reduce the LazyTensor to a scalar as a sum of func over
    batches taken along axis 0.

    .. warning::
        This function works for tensors of any order, but the reduction can
        only be done along the first two axes (or globally). Also, it requires
        that the slice of size `batch_size` along the axis to reduce (or axis 0
        if `axis=None`) can be computed and fits in memory.

    Parameters
    ----------
    a : LazyTensor
        LazyTensor to reduce
    func : callable
        Function to apply to the LazyTensor
    axis : int, optional
        Axis along which to reduce the LazyTensor. If None, reduce the
        LazyTensor to a scalar as a sum of func over batches taken along
        axis 0. If 0 or 1, reduce the LazyTensor to a vector/matrix as a sum
        of func over batches taken along that axis.
    nx : Backend, optional
        Backend to use for the reduction
    batch_size : int, optional
        Size of the batches to use for the reduction (default=100)

    Returns
    -------
    res : array-like
        Result of the reduction
    """

    if nx is None:
        nx = get_backend(a[0])

    if axis is None:
        # global reduction: accumulate func over row batches
        res = 0.0
        for i in range(0, a.shape[0], batch_size):
            res += func(a[i:i + batch_size])
        return res
    elif axis == 0:
        res = nx.zeros(a.shape[1:], type_as=a[0])
        if nx.__name__ in ["jax", "tf"]:
            # jax and tf arrays do not support in-place assignment:
            # collect the batch results and concatenate them instead
            lst = []
            for j in range(0, a.shape[1], batch_size):
                lst.append(func(a[:, j:j + batch_size], 0))
            return nx.concatenate(lst, axis=0)
        else:
            for j in range(0, a.shape[1], batch_size):
                res[j:j + batch_size] = func(a[:, j:j + batch_size], axis=0)
            return res
    elif axis == 1:
        if len(a.shape) == 2:
            shape = (a.shape[0],)
        else:
            shape = (a.shape[0], *a.shape[2:])
        res = nx.zeros(shape, type_as=a[0])
        if nx.__name__ in ["jax", "tf"]:
            # same immutable-array workaround as above
            lst = []
            for i in range(0, a.shape[0], batch_size):
                lst.append(func(a[i:i + batch_size], 1))
            return nx.concatenate(lst, axis=0)
        else:
            for i in range(0, a.shape[0], batch_size):
                res[i:i + batch_size] = func(a[i:i + batch_size], axis=1)
            return res
    else:
        raise NotImplementedError("Only axis=None, 0 or 1 is implemented for now.")
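
As a quick NumPy-only illustration of the batched reduction (the pairwise-sum tensor below is purely illustrative):

import numpy as np
from ot.utils import LazyTensor, reduce_lazytensor

v = np.arange(1000.0)
# T[i, j] = v[i] + v[j], never stored as a full (1000, 1000) array
T = LazyTensor((1000, 1000), lambda i, j, v: v[i, None] + v[None, j], v=v)
row_sums = reduce_lazytensor(T, np.sum, axis=1, batch_size=100)  # shape (1000,)
total = reduce_lazytensor(T, np.sum, batch_size=100)             # scalar sum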


def get_lowrank_lazytensor(Q, R, d=None, nx=None):
    """ Get a low-rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T

    Parameters
    ----------
    Q : ndarray, shape (n, r)
        First factor of the low-rank tensor
    R : ndarray, shape (m, r)
        Second factor of the low-rank tensor
    d : ndarray, shape (r,), optional
        Diagonal of the low-rank tensor
    nx : Backend, optional
        Backend to use for the reduction

    Returns
    -------
    T : LazyTensor
        Low-rank tensor T=Q@R^T or T=Q@diag(d)@R^T
    """

    if nx is None:
        nx = get_backend(Q, R, d)

    shape = (Q.shape[0], R.shape[0])

    if d is None:

        def func(i, j, Q, R):
            return nx.dot(Q[i], R[j].T)

        T = LazyTensor(shape, func, Q=Q, R=R)

    else:

        def func(i, j, Q, R, d):
            return nx.dot(Q[i] * d[None, :], R[j].T)

        T = LazyTensor(shape, func, Q=Q, R=R, d=d)

    return T
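
For instance, a small NumPy sketch (sizes are illustrative; materializing T[:] is only advisable for small shapes):

import numpy as np
from ot.backend import NumpyBackend
from ot.utils import get_lowrank_lazytensor

Q = np.random.rand(500, 3)
R = np.random.rand(400, 3)
T = get_lowrank_lazytensor(Q, R, nx=NumpyBackend())  # T[i, j] = Q[i] @ R[j]
np.testing.assert_allclose(T[:], Q @ R.T)            # dense check against the explicit product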


def get_parameter_pair(parameter):
    r"""Extract a pair of parameters from a given parameter.
    Used in unbalanced OT and COOT solvers.
@@ -761,7 +876,76 @@ class UndefinedParameter(Exception):


class OTResult:
""" Base class for OT results.
Parameters
----------
potentials : tuple of array-like, shape (`n1`, `n2`)
Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
This pair of arrays has the same shape, numerical type
and properties as the input weights "a" and "b".
value : float, array-like
Full transport cost, including possible regularization terms and
quadratic term for Gromov Wasserstein solutions.
value_linear : float, array-like
The linear part of the transport cost, i.e. the product between the
transport plan and the cost.
value_quad : float, array-like
The quadratic part of the transport cost for Gromov-Wasserstein
solutions.
plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a dense array.
log : dict
Dictionary containing potential information about the solver.
backend : Backend
Backend used to compute the results.
sparse_plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a sparse array.
lazy_plan : LazyTensor
Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
status : int or str
Status of the solver.
batch_size : int
Batch size used to compute the results/marginals for LazyTensor.
Attributes
----------
potentials : tuple of array-like, shape (`n1`, `n2`)
Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
This pair of arrays has the same shape, numerical type
and properties as the input weights "a" and "b".
potential_a : array-like, shape (`n1`,)
First dual potential, associated to the "source" measure "a".
potential_b : array-like, shape (`n2`,)
Second dual potential, associated to the "target" measure "b".
value : float, array-like
Full transport cost, including possible regularization terms and
quadratic term for Gromov Wasserstein solutions.
value_linear : float, array-like
The linear part of the transport cost, i.e. the product between the
transport plan and the cost.
value_quad : float, array-like
The quadratic part of the transport cost for Gromov-Wasserstein
solutions.
plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a dense array.
sparse_plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a sparse array.
lazy_plan : LazyTensor
Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
marginals : tuple of array-like, shape (`n1`,), (`n2`,)
Marginals of the transport plan: should be very close to "a" and "b"
for balanced OT.
marginal_a : array-like, shape (`n1`,)
Marginal of the transport plan for the "source" measure "a".
marginal_b : array-like, shape (`n2`,)
Marginal of the transport plan for the "target" measure "b".
"""

    def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None, batch_size=100):

        self._potentials = potentials
        self._value = value
@@ -773,6 +957,7 @@ def __init__(self, potentials=None, value=None, value_linear=None, value_quad=No
        self._lazy_plan = lazy_plan
        self._backend = backend if backend is not None else NumpyBackend()
        self._status = status
        self._batch_size = batch_size

        # I assume that other solvers may return directly
        # some primal objects?
@@ -793,7 +978,8 @@ def __repr__(self):
        s += 'value_linear={},'.format(self._value_linear)
        if self._plan is not None:
            s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)
        if self._lazy_plan is not None:
            s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape)
        if s[-1] != '(':
            s = s[:-1] + ')'
        else:
@@ -853,7 +1039,10 @@ def sparse_plan(self):
    @property
    def lazy_plan(self):
        """Transport plan, encoded as a symbolic POT or KeOps LazyTensor."""
        if self._lazy_plan is not None:
            return self._lazy_plan
        else:
            raise NotImplementedError()

    # Loss values --------------------------------

@@ -897,6 +1086,11 @@ def marginal_a(self):
"""First marginal of the transport plan, with the same shape as "a"."""
if self._plan is not None:
return self._backend.sum(self._plan, 1)
elif self._lazy_plan is not None:
lp = self._lazy_plan
bs = self._batch_size
nx = self._backend
return reduce_lazytensor(lp, nx.sum, axis=1, nx=nx, batch_size=bs)
else:
raise NotImplementedError()

@@ -905,6 +1099,11 @@ def marginal_b(self):
"""Second marginal of the transport plan, with the same shape as "b"."""
if self._plan is not None:
return self._backend.sum(self._plan, 0)
elif self._lazy_plan is not None:
lp = self._lazy_plan
bs = self._batch_size
nx = self._backend
return reduce_lazytensor(lp, nx.sum, axis=0, nx=nx, batch_size=bs)
else:
raise NotImplementedError()
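
As a sketch of the new lazy branch (names and sizes below are illustrative), the marginals of a low-rank lazy plan can be obtained without ever forming the dense coupling:

import numpy as np
from ot.backend import NumpyBackend
from ot.utils import OTResult, get_lowrank_lazytensor

nx = NumpyBackend()
Q = np.random.rand(500, 4)
R = np.random.rand(400, 4)
res = OTResult(lazy_plan=get_lowrank_lazytensor(Q, R, nx=nx), backend=nx)

ma = res.marginal_a  # batched sums over axis 1, shape (500,)
mb = res.marginal_b  # batched sums over axis 0, shape (400,)
np.testing.assert_allclose(ma.sum(), mb.sum())  # both equal the total mass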

@@ -968,3 +1167,70 @@ def citation(self):
            url = {http://jmlr.org/papers/v22/20-451.html}
        }
        """


class LazyTensor(object):
    """ A lazy tensor is a tensor that is not stored in memory. Instead, it is
    defined by a function that computes its values on the fly from slices.

    Parameters
    ----------
    shape : tuple
        shape of the tensor
    getitem : callable
        function that computes the values, taking the indices/slices and the
        named tensors as arguments
    kwargs : dict
        named arguments for the function; those names will be used as
        attributes of the LazyTensor object

    Examples
    --------
    >>> import numpy as np
    >>> v = np.arange(5)
    >>> def getitem(i, j, v):
    ...     return v[i, None] + v[None, j]
    >>> T = LazyTensor((5, 5), getitem, v=v)
    >>> T[1, 2]
    array([3])
    >>> T[1, :]
    array([[1, 2, 3, 4, 5]])
    >>> T[:]
    array([[0, 1, 2, 3, 4],
           [1, 2, 3, 4, 5],
           [2, 3, 4, 5, 6],
           [3, 4, 5, 6, 7],
           [4, 5, 6, 7, 8]])
    """

    def __init__(self, shape, getitem, **kwargs):

        self._getitem = getitem
        self.shape = shape
        self.ndim = len(shape)
        self.kwargs = kwargs

        # set attributes for named arguments/arrays
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getitem__(self, key):
        k = []
        if isinstance(key, (int, slice)):
            # a single index/slice addresses the first axis; keep the others whole
            k.append(key)
            for i in range(self.ndim - 1):
                k.append(slice(None))
        elif isinstance(key, tuple):
            # complete a partial tuple of indices/slices with full slices
            k = list(key)
            for i in range(self.ndim - len(key)):
                k.append(slice(None))
        else:
            raise NotImplementedError("Only integer, slice, and tuple indexing is supported")

        return self._getitem(*k, **self.kwargs)

    def __repr__(self):
        return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
1 change: 1 addition & 0 deletions test/test_factored.py
@@ -28,6 +28,7 @@ def test_factored_ot():
    # check constraints
    np.testing.assert_allclose(u, Ga.sum(1))
    np.testing.assert_allclose(u, Gb.sum(0))
    np.testing.assert_allclose(1, log['lazy_plan'][:].sum())


def test_factored_ot_backends(nx):
