
[MRG] Add LazyTensor for large scale OT #544

Merged: 14 commits, Oct 31, 2023
1 change: 1 addition & 0 deletions RELEASES.md
@@ -12,6 +12,7 @@
+ New LP solvers from scipy used by default for LP barycenter (PR #537)
+ Update wheels to Python 3.12 and remove the old i686 arch that does not have scipy wheels (PR #543)
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
+ Add LazyTensor for modeling plans and low-rank tensors in large scale OT (PR #544)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
3 changes: 2 additions & 1 deletion ot/factored.py
@@ -7,7 +7,7 @@
# License: MIT License

from .backend import get_backend
from .utils import dist
from .utils import dist, get_lowrank_lazytensor
from .lp import emd
from .bregman import sinkhorn

@@ -139,6 +139,7 @@ def solve_ot(X1, X2, w1, w2):
'vb': logb['v'],
'costa': loga['cost'],
'costb': logb['cost'],
'lazy_plan': get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx),
}
return Ga, Gb, X, log_dic
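
A rough usage sketch (not part of the diff; the point clouds, `r=4`, and variable names are illustrative): with this change the factored solver exposes its low-rank plan lazily through the log.

import numpy as np
import ot

Xa = np.random.randn(100, 2)
Xb = np.random.randn(100, 2)

# factored OT through r=4 anchor points; log=True also returns the log dictionary
Ga, Gb, X, log = ot.factored_optimal_transport(Xa, Xb, r=4, log=True)

lazy_plan = log['lazy_plan']   # LazyTensor of shape (100, 100)
full_plan = lazy_plan[:]       # densify only when it fits in memory
print(full_plan.sum())         # close to 1 for uniform weights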

272 changes: 269 additions & 3 deletions ot/utils.py
@@ -492,6 +492,121 @@
return x_t


def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
""" Reduce a LazyTensor along an axis with function fun using batches.

When axis=None, reduce the LazyTensor to a scalar as a sum of func over
batches taken along axis 0.

.. warning::
This function works for tensors of any order, but the reduction can only be
done along the first two axes (or globally). It also requires that a slice of size `batch_size` along the axis to reduce (or axis 0 if `axis=None`) can be computed and fits in memory.


Parameters
----------
a : LazyTensor
LazyTensor to reduce
func : callable
Function to apply to the LazyTensor
axis : int, optional
Axis along which to reduce the LazyTensor. If None, reduce the
LazyTensor to a scalar as a sum of func over batches taken along axis 0.
If 0 or 1, reduce the LazyTensor to a vector/matrix as a sum of func over
batches taken along that axis.
nx : Backend, optional
Backend to use for the reduction
batch_size : int, optional
Size of the batches to use for the reduction (default=100)

Returns
-------
res : array-like
Result of the reduction

"""

if nx is None:
nx = get_backend(a[0])

if axis is None:
res = 0.0
for i in range(0, a.shape[0], batch_size):
res += func(a[i:i + batch_size])
return res
elif axis == 0:
res = nx.zeros(a.shape[1:], type_as=a[0])
if nx.__name__ in ["jax", "tf"]:
lst = []
for j in range(0, a.shape[1], batch_size):
lst.append(func(a[:, j:j + batch_size], 0))
return nx.concatenate(lst, axis=0)
else:
for j in range(0, a.shape[1], batch_size):
res[j:j + batch_size] = func(a[:, j:j + batch_size], axis=0)
return res
elif axis == 1:
if len(a.shape) == 2:
shape = (a.shape[0])
else:
shape = (a.shape[0], *a.shape[2:])
res = nx.zeros(shape, type_as=a[0])
if nx.__name__ in ["jax", "tf"]:
lst = []
for i in range(0, a.shape[0], batch_size):
lst.append(func(a[i:i + batch_size], 1))
return nx.concatenate(lst, axis=0)
else:
for i in range(0, a.shape[0], batch_size):
res[i:i + batch_size] = func(a[i:i + batch_size], axis=1)
return res

else:
raise NotImplementedError("Only axis=None, 0 or 1 is implemented for now.")
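
A minimal sketch of how this reduction can be used with the numpy backend (the outer-product tensor below is purely illustrative):

import numpy as np
from ot.backend import get_backend
from ot.utils import LazyTensor, reduce_lazytensor

u = np.random.rand(5000)
v = np.random.rand(4000)

def getitem(i, j, u, v):
    # only the requested block of the outer product u v^T is ever built
    return u[i, None] * v[None, j]

T = LazyTensor((5000, 4000), getitem, u=u, v=v)
nx = get_backend(u)

row_sums = reduce_lazytensor(T, nx.sum, axis=1, nx=nx, batch_size=256)  # shape (5000,)
col_sums = reduce_lazytensor(T, nx.sum, axis=0, nx=nx, batch_size=256)  # shape (4000,)
total = reduce_lazytensor(T, nx.sum, nx=nx)                             # scalar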


def get_lowrank_lazytensor(Q, R, d=None, nx=None):
""" Get a low rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T

Parameters
----------
Q : ndarray, shape (n, r)
First factor of the lowrank tensor
R : ndarray, shape (m, r)
Second factor of the lowrank tensor
d : ndarray, shape (r,), optional
Diagonal of the lowrank tensor
nx : Backend, optional
Backend to use for the reduction

Returns
-------
T : LazyTensor
Lowrank tensor T=Q@R^T or T=Q@diag(d)@R^T
"""

if nx is None:
nx = get_backend(Q, R, d)

shape = (Q.shape[0], R.shape[0])

if d is None:

def func(i, j, Q, R):
return nx.dot(Q[i], R[j].T)

T = LazyTensor(shape, func, Q=Q, R=R)

else:

def func(i, j, Q, R, d):
return nx.dot(Q[i] * d[None, :], R[j].T)

T = LazyTensor(shape, func, Q=Q, R=R, d=d)

return T
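
For illustration (a sketch with arbitrary shapes), the returned tensor can be queried block-wise without ever allocating the full n x m matrix:

import numpy as np
from ot.utils import get_lowrank_lazytensor

n, m, r = 10000, 8000, 10
Q = np.random.rand(n, r)
R = np.random.rand(m, r)
d = np.random.rand(r)

T = get_lowrank_lazytensor(Q, R, d)   # no n x m array is allocated
print(T.shape)                        # (10000, 8000)
block = T[:100, :200]                 # only a 100 x 200 block is computed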


def get_parameter_pair(parameter):
r"""Extract a pair of parameters from a given parameter
Used in unbalanced OT and COOT solvers
@@ -761,7 +876,76 @@


class OTResult:
def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
""" Base class for OT results.

Parameters
----------

potentials : tuple of array-like, shape (`n1`, `n2`)
Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
This pair of arrays has the same shape, numerical type
and properties as the input weights "a" and "b".
value : float, array-like
Full transport cost, including possible regularization terms and
quadratic term for Gromov-Wasserstein solutions.
value_linear : float, array-like
The linear part of the transport cost, i.e. the product between the
transport plan and the cost.
value_quad : float, array-like
The quadratic part of the transport cost for Gromov-Wasserstein
solutions.
plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a dense array.
log : dict
Dictionary containing potential information about the solver.
backend : Backend
Backend used to compute the results.
sparse_plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a sparse array.
lazy_plan : LazyTensor
Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
status : int or str
Status of the solver.
batch_size : int
Batch size used to compute the results/marginals for LazyTensor.

Attributes
----------

potentials : tuple of array-like, shape (`n1`, `n2`)
Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
This pair of arrays has the same shape, numerical type
and properties as the input weights "a" and "b".
potential_a : array-like, shape (`n1`,)
First dual potential, associated to the "source" measure "a".
potential_b : array-like, shape (`n2`,)
Second dual potential, associated to the "target" measure "b".
value : float, array-like
Full transport cost, including possible regularization terms and
quadratic term for Gromov-Wasserstein solutions.
value_linear : float, array-like
The linear part of the transport cost, i.e. the product between the
transport plan and the cost.
value_quad : float, array-like
The quadratic part of the transport cost for Gromov-Wasserstein
solutions.
plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a dense array.
sparse_plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a sparse array.
lazy_plan : LazyTensor
Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
marginals : tuple of array-like, shape (`n1`,), (`n2`,)
Marginals of the transport plan: should be very close to "a" and "b"
for balanced OT.
marginal_a : array-like, shape (`n1`,)
Marginal of the transport plan for the "source" measure "a".
marginal_b : array-like, shape (`n2`,)
Marginal of the transport plan for the "target" measure "b".

"""

def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None, batch_size=100):

self._potentials = potentials
self._value = value
@@ -773,6 +957,7 @@
self._lazy_plan = lazy_plan
self._backend = backend if backend is not None else NumpyBackend()
self._status = status
self._batch_size = batch_size

# I assume that other solvers may return directly
# some primal objects?
@@ -793,7 +978,8 @@
s += 'value_linear={},'.format(self._value_linear)
if self._plan is not None:
s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)

if self._lazy_plan is not None:
s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape)

if s[-1] != '(':
s = s[:-1] + ')'
else:
@@ -853,7 +1039,10 @@
@property
def lazy_plan(self):
"""Transport plan, encoded as a symbolic KeOps LazyTensor."""
raise NotImplementedError()
if self._lazy_plan is not None:
return self._lazy_plan

else:
raise NotImplementedError()

# Loss values --------------------------------

@@ -897,6 +1086,11 @@
"""First marginal of the transport plan, with the same shape as "a"."""
if self._plan is not None:
return self._backend.sum(self._plan, 1)
elif self._lazy_plan is not None:
lp = self._lazy_plan
bs = self._batch_size
nx = self._backend
return reduce_lazytensor(lp, nx.sum, axis=1, nx=nx, batch_size=bs)
else:
raise NotImplementedError()

@@ -905,6 +1099,11 @@
"""Second marginal of the transport plan, with the same shape as "b"."""
if self._plan is not None:
return self._backend.sum(self._plan, 0)
elif self._lazy_plan is not None:
lp = self._lazy_plan
bs = self._batch_size
nx = self._backend
return reduce_lazytensor(lp, nx.sum, axis=0, nx=nx, batch_size=bs)
else:
raise NotImplementedError()
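
A hedged sketch of the new lazy path, building an OTResult by hand the way a lazy solver would (shapes and batch size are illustrative):

import numpy as np
from ot.backend import get_backend
from ot.utils import OTResult, get_lowrank_lazytensor

Q = np.random.rand(300, 5)
R = np.random.rand(200, 5)
nx = get_backend(Q, R)

res = OTResult(lazy_plan=get_lowrank_lazytensor(Q, R, nx=nx),
               backend=nx, batch_size=64)

# marginals are computed by batched reductions over the lazy plan
print(res.marginal_a.shape)   # (300,)
print(res.marginal_b.shape)   # (200,)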

@@ -968,3 +1167,70 @@
url = {http://jmlr.org/papers/v22/20-451.html}
}
"""


class LazyTensor(object):
""" A lazy tensor is a tensor that is not stored in memory. Instead, it is
defined by a function that computes its values on the fly from slices.

Parameters
----------

shape : tuple
shape of the tensor
getitem : callable
function that computes the values of the tensor from the indices/slices
and the named tensors passed as arguments

kwargs : dict
named arguments for the function; these names will be used as attributes
of the LazyTensor object

Examples
--------
>>> import numpy as np
>>> v = np.arange(5)
>>> def getitem(i,j, v):
... return v[i,None]+v[None,j]
>>> T = LazyTensor((5,5),getitem, v=v)
>>> T[1,2]
array([3])
>>> T[1,:]
array([[1, 2, 3, 4, 5]])
>>> T[:]
array([[0, 1, 2, 3, 4],
[1, 2, 3, 4, 5],
[2, 3, 4, 5, 6],
[3, 4, 5, 6, 7],
[4, 5, 6, 7, 8]])

"""

def __init__(self, shape, getitem, **kwargs):

self._getitem = getitem
self.shape = shape
self.ndim = len(shape)
self.kwargs = kwargs

# set attributes for named arguments/arrays
for key, value in kwargs.items():
setattr(self, key, value)

def __getitem__(self, key):
k = []
if isinstance(key, int) or isinstance(key, slice):
k.append(key)
for i in range(self.ndim - 1):
k.append(slice(None))
elif isinstance(key, tuple):
k = list(key)
for i in range(self.ndim - len(key)):
k.append(slice(None))
else:
raise NotImplementedError("Only integer, slice, and tuple indexing is supported")

return self._getitem(*k, **self.kwargs)

def __repr__(self):
return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
1 change: 1 addition & 0 deletions test/test_factored.py
@@ -28,6 +28,7 @@ def test_factored_ot():
# check constraints
np.testing.assert_allclose(u, Ga.sum(1))
np.testing.assert_allclose(u, Gb.sum(0))
np.testing.assert_allclose(1, log['lazy_plan'][:].sum())


def test_factored_ot_backends(nx):