
[MRG] Add LazyTensor for large scale OT #544

Merged (14 commits, Oct 31, 2023)
230 changes: 227 additions & 3 deletions ot/utils.py
@@ -492,6 +492,79 @@
return x_t


def reduce_lazytensor(a, fun, axis=None, nx=None, batch_size=None):
""" Reduce a LazyTensor along an axis with function fun using batches.

When axis=None, reduce the LazyTensor to a scalar as a sum of fun over
batches taken along axis 0.

Parameters
----------

a : LazyTensor
LazyTensor to reduce
fun : callable
Function to apply to each batch of the LazyTensor; when axis is 0 or 1,
the reduction axis is also passed as second argument
axis : int, optional
Axis along which to reduce the LazyTensor. If None, reduce the
LazyTensor to a scalar as a sum of fun over batches taken along axis 0.
If 0 or 1, reduce the LazyTensor to a vector/matrix as a sum of fun over
batches taken along that axis.
nx : Backend, optional
Backend to use for the reduction
batch_size : int, optional
Size of the batches to use for the reduction (default=100)

Returns
-------

res : array-like
Result of the reduction

"""

if nx is None:
nx = get_backend(a[0])

if batch_size is None:
batch_size = 100

if axis is None:
res = 0.0
for i in range(0, a.shape[0], batch_size):
res += fun(a[i:i + batch_size])
return res
elif axis == 0:
res = nx.zeros(a.shape[1:], type_as=a[0])
# jax and tf arrays are immutable: gather batch results and concatenate
if nx.__name__ in ["jax", "tf"]:
lst = []
for j in range(0, a.shape[1], batch_size):
lst.append(fun(a[:, j:j + batch_size], 0))
return nx.concatenate(lst, axis=0)
else:
for j in range(0, a.shape[1], batch_size):
res[j:j + batch_size] = fun(a[:, j:j + batch_size], axis=0)
return res
elif axis == 1:
if len(a.shape) == 2:
shape = (a.shape[0],)
else:
shape = (a.shape[0], *a.shape[2:])

res = nx.zeros(shape, type_as=a[0])
# jax and tf arrays are immutable: gather batch results and concatenate
if nx.__name__ in ["jax", "tf"]:
lst = []
for i in range(0, a.shape[0], batch_size):
lst.append(fun(a[i:i + batch_size], 1))
return nx.concatenate(lst, axis=0)
else:
for i in range(0, a.shape[0], batch_size):
res[i:i + batch_size] = fun(a[i:i + batch_size], axis=1)
return res

else:
raise NotImplementedError("Only axis=None, 0 and 1 are implemented for now.")

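For intuition, here is a minimal usage sketch of reduce_lazytensor with the NumPy backend; it reuses the small LazyTensor from the class docstring further down this file, so the values below are only indicative, not part of this diff:

>>> import numpy as np
>>> v = np.arange(5)
>>> def getitem(i, j, v):
... return v[i, None] + v[None, j]
>>> T = LazyTensor((5, 5), getitem, v=v)
>>> reduce_lazytensor(T, np.sum) # sum of all entries, computed batch by batch
100.0
>>> reduce_lazytensor(T, np.sum, axis=1) # batched row sums
array([10, 15, 20, 25, 30])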


def get_parameter_pair(parameter):
r"""Extract a pair of parameters from a given parameter
Used in unbalanced OT and COOT solvers
@@ -761,7 +834,76 @@


class OTResult:
""" Base class for OT results.

Parameters
----------

potentials : tuple of array-like, shape (`n1`, `n2`)
Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
This pair of arrays has the same shape, numerical type
and properties as the input weights "a" and "b".
value : float, array-like
Full transport cost, including possible regularization terms and
quadratic term for Gromov-Wasserstein solutions.
value_linear : float, array-like
The linear part of the transport cost, i.e. the product between the
transport plan and the cost.
value_quad : float, array-like
The quadratic part of the transport cost for Gromov-Wasserstein
solutions.
plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a dense array.
log : dict
Dictionary containing additional information returned by the solver.
backend : Backend
Backend used to compute the results.
sparse_plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a sparse array.
lazy_plan : LazyTensor
Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
status : int or str
Status of the solver.
batch_size : int
Batch size used to compute marginals and other reductions of a lazy_plan.

Attributes
----------

potentials : tuple of array-like, shape (`n1`, `n2`)
Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
This pair of arrays has the same shape, numerical type
and properties as the input weights "a" and "b".
potential_a : array-like, shape (`n1`,)
First dual potential, associated to the "source" measure "a".
potential_b : array-like, shape (`n2`,)
Second dual potential, associated to the "target" measure "b".
value : float, array-like
Full transport cost, including possible regularization terms and
quadratic term for Gromov-Wasserstein solutions.
value_linear : float, array-like
The linear part of the transport cost, i.e. the product between the
transport plan and the cost.
value_quad : float, array-like
The quadratic part of the transport cost for Gromov-Wasserstein
solutions.
plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a dense array.
sparse_plan : array-like, shape (`n1`, `n2`)
Transport plan, encoded as a sparse array.
lazy_plan : LazyTensor
Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
marginals : tuple of array-like, shape (`n1`,), (`n2`,)
Marginals of the transport plan: should be very close to "a" and "b"
for balanced OT.
marginal_a : array-like, shape (`n1`,)
Marginal of the transport plan for the "source" measure "a".
marginal_b : array-like, shape (`n2`,)
Marginal of the transport plan for the "target" measure "b".

"""

def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None, batch_size=100):

self._potentials = potentials
self._value = value
@@ -773,6 +915,7 @@
self._lazy_plan = lazy_plan
self._backend = backend if backend is not None else NumpyBackend()
self._status = status
self._batch_size = batch_size

# I assume that other solvers may return directly
# some primal objects?
@@ -793,7 +936,8 @@
s += 'value_linear={},'.format(self._value_linear)
if self._plan is not None:
s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)

if self._lazy_plan is not None:
s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape)

if s[-1] != '(':
s = s[:-1] + ')'
else:
@@ -853,7 +997,10 @@
@property
def lazy_plan(self):
"""Transport plan, encoded as a symbolic KeOps LazyTensor."""
raise NotImplementedError()
if self._lazy_plan is not None:
return self._lazy_plan

else:
raise NotImplementedError()

# Loss values --------------------------------

@@ -897,6 +1044,11 @@
"""First marginal of the transport plan, with the same shape as "a"."""
if self._plan is not None:
return self._backend.sum(self._plan, 1)
elif self._lazy_plan is not None:
lp = self._lazy_plan
bs = self._batch_size
nx = self._backend
return reduce_lazytensor(lp, nx.sum, axis=1, nx=nx, batch_size=bs)
else:
raise NotImplementedError()

@@ -905,6 +1057,11 @@
"""Second marginal of the transport plan, with the same shape as "b"."""
if self._plan is not None:
return self._backend.sum(self._plan, 0)
elif self._lazy_plan is not None:
lp = self._lazy_plan
bs = self._batch_size
nx = self._backend
return reduce_lazytensor(lp, nx.sum, axis=0, nx=nx, batch_size=bs)
else:
raise NotImplementedError()
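For a quick sanity check, here is a sketch of the batched lazy marginals on a trivially factorized plan; it assumes the LazyTensor class defined at the end of this file and the default NumPy backend, and is illustrative only:

>>> import numpy as np
>>> a, b = np.ones(4) / 4, np.ones(5) / 5
>>> lp = LazyTensor((4, 5), lambda i, j, P: P[i, j], P=np.outer(a, b))
>>> res = OTResult(lazy_plan=lp)
>>> res.marginal_a # row sums computed by batches, recovers "a"
array([0.25, 0.25, 0.25, 0.25])
>>> res.marginal_b # column sums computed by batches, recovers "b"
array([0.2, 0.2, 0.2, 0.2, 0.2])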

@@ -968,3 +1125,70 @@
url = {http://jmlr.org/papers/v22/20-451.html}
}
"""


class LazyTensor(object):
""" A lazy tensor is a tensor that is not stored in memory. Instead, it is
defined by a function that computes its values on the fly from slices.

Parameters
----------

shape : tuple
shape of the tensor
getitem : callable
function that computes the values of the tensor from the requested
indices/slices, with the named tensors below passed as arguments

kwargs : dict
named arguments for the function; these names are also exposed as
attributes of the LazyTensor object

Examples
--------
>>> import numpy as np
>>> v = np.arange(5)
>>> def getitem(i,j, v):
... return v[i,None]+v[None,j]
>>> T = LazyTensor((5,5),getitem, v=v)
>>> T[1,2]
array([3])
>>> T[1,:]
array([[1, 2, 3, 4, 5]])
>>> T[:]
array([[0, 1, 2, 3, 4],
[1, 2, 3, 4, 5],
[2, 3, 4, 5, 6],
[3, 4, 5, 6, 7],
[4, 5, 6, 7, 8]])

"""

def __init__(self, shape, getitem, **kwargs):

self._getitem = getitem
self.shape = shape
self.ndim = len(shape)
self.kwargs = kwargs

# set attributes for named arguments/arrays
for key, value in kwargs.items():
setattr(self, key, value)

def __getitem__(self, key):
k = []
if isinstance(key, (int, slice)):
# a single index or slice applies to the first axis: pad with full slices
k.append(key)
for i in range(self.ndim - 1):
k.append(slice(None))
elif isinstance(key, tuple):
# a tuple of indices/slices: pad missing trailing axes with full slices
k = list(key)
for i in range(self.ndim - len(key)):
k.append(slice(None))

else:
raise NotImplementedError("Only integer, slice, and tuple indexing is supported")


return self._getitem(*k, **self.kwargs)

def __repr__(self):
return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))