Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] srFGW barycenters #659

Merged
merged 11 commits into from
Jul 19, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ POT provides the following generic OT solvers (links to examples):
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding barycenter solvers (exact and regularized [48]).
* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68].
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#### New features
- Add feature `mass=True` for `nx.kl_div` (PR #654)
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)

#### Closed issues

Expand Down
14 changes: 10 additions & 4 deletions ot/gromov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# All submodules and packages
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
update_square_loss, update_kl_loss, update_feature_matrix,
init_matrix_semirelaxed)
init_matrix_semirelaxed,
update_barycenter_structure, update_barycenter_feature,
)

from ._gw import (gromov_wasserstein, gromov_wasserstein2,
fused_gromov_wasserstein, fused_gromov_wasserstein2,
Expand Down Expand Up @@ -40,14 +42,16 @@
entropic_semirelaxed_gromov_wasserstein,
entropic_semirelaxed_gromov_wasserstein2,
entropic_semirelaxed_fused_gromov_wasserstein,
entropic_semirelaxed_fused_gromov_wasserstein2)
entropic_semirelaxed_fused_gromov_wasserstein2,
semirelaxed_fgw_barycenters)

from ._dictionary import (gromov_wasserstein_dictionary_learning,
gromov_wasserstein_linear_unmixing,
fused_gromov_wasserstein_dictionary_learning,
fused_gromov_wasserstein_linear_unmixing)

from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples)
from ._lowrank import (_flat_product_operator,
lowrank_gromov_wasserstein_samples)


from ._quantized import (quantized_fused_gromov_wasserstein_partitioned,
Expand All @@ -62,6 +66,7 @@

__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
'update_barycenter_structure', 'update_barycenter_feature',
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
Expand All @@ -80,4 +85,5 @@
'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition',
'get_graph_representants', 'format_partitioned_graph',
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples',
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples']
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples',
'semirelaxed_fgw_barycenters']
250 changes: 248 additions & 2 deletions ot/gromov/_semirelaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
import numpy as np


from ..utils import list_to_array, unif
from ..utils import (
list_to_array, unif, dist, UndefinedParameter, check_random_state
)
from ..optim import semirelaxed_cg, solve_1d_linesearch_quad
from ..backend import get_backend

from ._utils import init_matrix_semirelaxed, gwloss, gwggrad
from ._utils import (
init_matrix_semirelaxed, gwloss, gwggrad,
update_barycenter_structure, update_barycenter_feature,
)


def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None,
Expand Down Expand Up @@ -1100,3 +1105,244 @@ def entropic_semirelaxed_fused_gromov_wasserstein2(
return log_srfgw['srfgw_dist'], log_srfgw
else:
return log_srfgw['srfgw_dist']


def semirelaxed_fgw_barycenters(
        N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False,
        fixed_features=False, p=None, loss_fun='square_loss',
        symmetric=True, max_iter=100, tol=1e-9, stop_criterion='barycenter',
        warmstartT=False, verbose=False, log=False, init_C=None, init_X=None,
        random_state=None, **kwargs):
    r"""
    Returns the Semi-relaxed Fused Gromov-Wasserstein barycenters of `S` measurable networks
    with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}`
    (see eq (44) in :ref:`[48]`), estimated using the semi-relaxed FGW transports from
    Conditional Gradient solvers.

    The function solves the following optimization problem:

    .. math::

        \mathbf{C}^*, \mathbf{Y}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}, \mathbf{Y}\in \mathbb{R}^{N \times d}}
        \quad \sum_s \lambda_s \mathrm{srFGW}_{\alpha}(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s, \mathbf{C}, \mathbf{Y})

    Where :

    - :math:`\mathbf{Y}_s`: input feature matrix
    - :math:`\mathbf{C}_s`: input metric cost matrix
    - :math:`\mathbf{p}_s`: input distribution

    The solver alternates between (i) computing one semi-relaxed FGW transport
    plan per input, from the input to the current barycenter, and (ii) updating
    the barycenter features and/or structure from these plans.

    Parameters
    ----------
    N : int
        Desired number of samples of the target barycenter
    Ys: list of array-like, each element has shape (ns,d)
        Features of all samples
    Cs : list of array-like, each element has shape (ns,ns)
        Structure matrices of all samples
    ps : list of array-like, each element has shape (ns,), optional
        Masses of all samples.
        If left to its default value None, uniform distributions are taken.
    lambdas : list of float, optional
        List of the `S` spaces' weights.
        If left to its default value None, uniform weights are taken.
    alpha : float, optional
        Alpha parameter for the srFGW divergence in :math:`]0, 1[`.
    fixed_structure : bool, optional
        Whether to fix the structure of the barycenter during the updates.
        Requires `init_C` to be provided.
    fixed_features : bool, optional
        Whether to fix the feature of the barycenter during the updates.
        Requires `init_X` to be provided.
    p : array-like, optional
        Currently ignored on input: the barycenter masses are not fixed in the
        semi-relaxed setting and are recomputed from the transport plans at
        every iteration. The value is only returned as is in ``log['p']`` when
        no iteration is performed (`max_iter=0`).
    loss_fun : str, optional
        Loss function used for the solver either 'square_loss' or 'kl_loss'
    symmetric : bool, optional
        Whether the structures are to be assumed symmetric (True) or
        asymmetric (False). Default value is True. Forwarded to the
        semi-relaxed FGW solver.
    max_iter : int, optional
        Max number of barycenter updates. The same value is also forwarded to
        the inner semi-relaxed FGW solver as its own iteration budget.
    tol : float, optional
        Stop threshold (>0) applied to the criterion selected by
        `stop_criterion`.
    stop_criterion : str, optional. Default is 'barycenter'.
        Stop criterion taking values in ['barycenter', 'loss']. If set to 'barycenter'
        uses absolute norm variations of estimated barycenters. Else if set to 'loss'
        uses the relative variations of the loss.
    warmstartT: bool, optional
        Whether to warmstart each transport problem with the transport plan
        estimated for the same input at the previous barycenter iteration.
    verbose : bool, optional
        Print information along iterations.
    log : bool, optional
        Record log if True.
    init_C : array-like, shape (N,N), optional
        Initialization for the barycenters' structure matrix. If not set
        a random init is used.
    init_X : array-like, shape (N,d), optional
        Initialization for the barycenters' features. If not set, the
        features are initialized to zeros.
    random_state : int or RandomState instance, optional
        Fix the seed for reproducibility
    **kwargs : dict
        Extra arguments forwarded to `semirelaxed_fused_gromov_wasserstein`.

    Returns
    -------
    X : array-like, shape (`N`, `d`)
        Barycenters' features
    C : array-like, shape (`N`, `N`)
        Barycenters' structure matrix
    log : dict
        Only returned when log=True. It contains the keys:

        - 'T': list of the `S` transport plans, one per input. The barycenter
          masses are deduced from each plan by summing over its first axis
          (plans are presumably of shape (`ns`, `N`) -- to be confirmed against
          `semirelaxed_fused_gromov_wasserstein`).
        - 'p': barycenter masses deduced from each transport plan, stacked
          along the first axis (one row per input).
        - 'Ms': list of the `S` matrices `dist(Ys[s], X)` between the input
          features and the barycenter features.
        - 'err_feature' and 'err_structure' (if `stop_criterion='barycenter'`)
          or 'loss' and 'err_rel_loss' (if `stop_criterion='loss'`): values
          used in convergence evaluation, one entry per iteration.

    Raises
    ------
    ValueError
        If `stop_criterion` is neither 'barycenter' nor 'loss'.
    UndefinedParameter
        If `fixed_structure=True` without `init_C`, or `fixed_features=True`
        without `init_X`.

    References
    ----------
    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
            "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
            International Conference on Learning Representations (ICLR), 2022.
    """
    if stop_criterion not in ['barycenter', 'loss']:
        raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.")

    arr = [*Cs, *Ys]
    if ps is not None:
        arr += [*ps]
    else:
        ps = [unif(C.shape[0], type_as=C) for C in Cs]

    nx = get_backend(*arr)

    S = len(Cs)
    if lambdas is None:
        lambdas = [1. / S] * S

    d = Ys[0].shape[1]  # dimension on the node features

    # Barycenter structure initialization.
    if init_C is not None:
        C = init_C
    elif fixed_structure:
        raise UndefinedParameter('If C is fixed it must be initialized')
    else:
        # Random init: pairwise distances between N random 2D gaussian points.
        generator = check_random_state(random_state)
        xalea = generator.randn(N, 2)
        C = dist(xalea, xalea)
        C = nx.from_numpy(C, type_as=ps[0])

    # Barycenter features initialization.
    if init_X is not None:
        X = init_X
    elif fixed_features:
        raise UndefinedParameter('If X is fixed it must be initialized')
    else:
        X = nx.zeros((N, d), type_as=ps[0])

    Ms = [dist(Y, X) for Y in Ys]

    # Always bound so that the log can be built even if the loop body never
    # runs (e.g. max_iter=0); entries are only read when warmstartT is True.
    T = [None] * S

    cpt = 0

    # The errors of the criterion that is NOT monitored are set to 0 so that
    # they never keep the loop alive; the monitored ones start from a large
    # sentinel so that at least one iteration is performed.
    if stop_criterion == 'barycenter':
        inner_log = False
        err_feature = 1e15
        err_structure = 1e15
        err_rel_loss = 0.

    else:
        # the inner solver log is needed to read the srFGW loss values
        inner_log = True
        err_feature = 0.
        err_structure = 0.
        curr_loss = 1e15
        err_rel_loss = 1e15

    if log:
        log_ = {}
        if stop_criterion == 'barycenter':
            log_['err_feature'] = []
            log_['err_structure'] = []
        else:
            log_['loss'] = []
            log_['err_rel_loss'] = []

    while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
        if stop_criterion == 'barycenter':
            Cprev = C
            Xprev = X
        else:
            prev_loss = curr_loss

        # get transport plans (warmstarted from the previous plans if asked)
        res = [semirelaxed_fused_gromov_wasserstein(
            Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha,
            T[s] if warmstartT else None,
            inner_log, max_iter, tol_rel=1e-5, tol_abs=0., **kwargs)
            for s in range(S)]

        if stop_criterion == 'barycenter':
            T = res
        else:
            # with inner_log=True each output is a (plan, log) pair
            T = [output[0] for output in res]
            curr_loss = np.sum([output[1]['srfgw_dist'] for output in res])

        # update barycenters
        # barycenter masses deduced from each plan, one row per input
        p = nx.concatenate(
            [nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0)

        if not fixed_features:
            X = update_barycenter_feature(T, Ys, lambdas, p, nx=nx)
            Ms = [dist(Y, X) for Y in Ys]

        if not fixed_structure:
            C = update_barycenter_structure(T, Cs, lambdas, p, loss_fun, nx=nx)

        # update convergence criterion
        if stop_criterion == 'barycenter':
            err_feature, err_structure = 0., 0.
            if not fixed_features:
                err_feature = nx.norm(X - Xprev)
            if not fixed_structure:
                err_structure = nx.norm(C - Cprev)
            if log:
                log_['err_feature'].append(err_feature)
                log_['err_structure'].append(err_structure)

            if verbose:
                if cpt % 200 == 0:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err_structure))
                print('{:5d}|{:8e}|'.format(cpt, err_feature))
        else:
            # nan compares False against tol, hence a null previous loss
            # makes this term stop contributing to the loop condition.
            err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
            if log:
                log_['loss'].append(curr_loss)
                log_['err_rel_loss'].append(err_rel_loss)

            if verbose:
                if cpt % 200 == 0:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))

        cpt += 1

    if log:
        log_['T'] = T
        log_['p'] = p
        log_['Ms'] = Ms

        return X, C, log_
    else:
        return X, C
Loading
Loading