diff --git a/README.md b/README.md index 7f2ce3ee3..011af46d4 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ POT provides the following generic OT solvers (links to examples): * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. -* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]). +* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding barycenter solvers (exact and regularized [48]). * [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68]. * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. 
diff --git a/RELEASES.md b/RELEASES.md index e5a8ac54b..0812508be 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,6 +4,7 @@ #### New features - Add feature `mass=True` for `nx.kl_div` (PR #654) +- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659) #### Closed issues diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 03663dab4..675f42ccb 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -11,8 +11,9 @@ # All submodules and packages from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, - update_square_loss, update_kl_loss, update_feature_matrix, - init_matrix_semirelaxed) + init_matrix_semirelaxed, + update_barycenter_structure, update_barycenter_feature, + ) from ._gw import (gromov_wasserstein, gromov_wasserstein2, fused_gromov_wasserstein, fused_gromov_wasserstein2, @@ -40,14 +41,16 @@ entropic_semirelaxed_gromov_wasserstein, entropic_semirelaxed_gromov_wasserstein2, entropic_semirelaxed_fused_gromov_wasserstein, - entropic_semirelaxed_fused_gromov_wasserstein2) + entropic_semirelaxed_fused_gromov_wasserstein2, + semirelaxed_fgw_barycenters) from ._dictionary import (gromov_wasserstein_dictionary_learning, gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing) -from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples) +from ._lowrank import (_flat_product_operator, + lowrank_gromov_wasserstein_samples) from ._quantized import (quantized_fused_gromov_wasserstein_partitioned, @@ -60,8 +63,9 @@ quantized_fused_gromov_wasserstein_samples ) -__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss', - 'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed', +__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', + 'init_matrix_semirelaxed', + 'update_barycenter_structure', 
'update_barycenter_feature', 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', 'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2', @@ -80,4 +84,5 @@ 'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition', 'get_graph_representants', 'format_partitioned_graph', 'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples', - 'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples'] + 'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples', + 'semirelaxed_fgw_barycenters'] diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 6bb7a675a..c60e786f7 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -19,7 +19,7 @@ from ..backend import get_backend from ._utils import init_matrix, gwloss, gwggrad -from ._utils import update_square_loss, update_kl_loss, update_feature_matrix +from ._utils import update_barycenter_structure, update_barycenter_feature def entropic_gromov_wasserstein( @@ -807,10 +807,8 @@ def entropic_gromov_barycenters( curr_loss = np.sum([output[1]['gw_dist'] for output in res]) # update barycenters - if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs, nx) - elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs, nx) + C = update_barycenter_structure( + T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx) # update convergence criterion if stop_criterion == 'barycenter': @@ -1651,13 +1649,14 @@ def entropic_fused_gromov_barycenters( # Initialization of C : random euclidean distance matrix (if not provided by user) if fixed_structure: if init_C is None: - raise UndefinedParameter('If C is fixed it must be initialized') + raise UndefinedParameter( + 'If C is fixed it must be provided in init_C') else: C = init_C else: if init_C is None: - generator = check_random_state(random_state) - 
xalea = generator.randn(N, 2) + rng = check_random_state(random_state) + xalea = rng.randn(N, 2) C = dist(xalea, xalea) C = nx.from_numpy(C, type_as=ps[0]) else: @@ -1666,7 +1665,8 @@ def entropic_fused_gromov_barycenters( # Initialization of Y if fixed_features: if init_Y is None: - raise UndefinedParameter('If Y is fixed it must be initialized') + raise UndefinedParameter( + 'If Y is fixed it must be provided in init_Y') else: Y = init_Y else: @@ -1681,20 +1681,12 @@ def entropic_fused_gromov_barycenters( if warmstartT: T = [None] * S - cpt = 0 - if stop_criterion == 'barycenter': inner_log = False - err_feature = 1e15 - err_structure = 1e15 - err_rel_loss = 0. else: inner_log = True - err_feature = 0. - err_structure = 0. curr_loss = 1e15 - err_rel_loss = 1e15 if log: log_ = {} @@ -1706,7 +1698,8 @@ def entropic_fused_gromov_barycenters( log_['loss'] = [] log_['err_rel_loss'] = [] - while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter): + for cpt in range(max_iter): # break if specified errors are below tol. 
+ if stop_criterion == 'barycenter': Cprev = C Yprev = Y @@ -1732,16 +1725,14 @@ def entropic_fused_gromov_barycenters( # update barycenters if not fixed_features: - Ys_temp = [y.T for y in Ys] - X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T + X = update_barycenter_feature( + T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx) + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: - if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs, nx) - - elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs, nx) + C = update_barycenter_structure( + T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx) # update convergence criterion if stop_criterion == 'barycenter': @@ -1761,6 +1752,9 @@ def entropic_fused_gromov_barycenters( 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err_structure)) print('{:5d}|{:8e}|'.format(cpt, err_feature)) + + if (err_feature <= tol) or (err_structure <= tol): + break else: err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. 
else np.nan if log: @@ -1773,7 +1767,8 @@ def entropic_fused_gromov_barycenters( 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err_rel_loss)) - cpt += 1 + if err_rel_loss <= tol: + break if log: log_['T'] = T diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 86ff566ea..1cbc98909 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -21,7 +21,7 @@ from ..backend import get_backend, NumpyBackend from ._utils import init_matrix, gwloss, gwggrad -from ._utils import update_square_loss, update_kl_loss, update_feature_matrix +from ._utils import update_barycenter_structure, update_barycenter_feature def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None, @@ -833,17 +833,14 @@ def gromov_barycenters( # Initialization of C : random SPD matrix (if not provided by user) if init_C is None: - generator = check_random_state(random_state) - xalea = generator.randn(N, 2) + rng = check_random_state(random_state) + xalea = rng.randn(N, 2) C = dist(xalea, xalea) C /= C.max() C = nx.from_numpy(C, type_as=p) else: C = init_C - cpt = 0 - err = 1e15 # either the error on 'barycenter' or 'loss' - if warmstartT: T = [None] * S @@ -859,7 +856,8 @@ def gromov_barycenters( if stop_criterion == 'loss': log_['loss'] = [] - while (err > tol and cpt < max_iter): + for cpt in range(max_iter): + if stop_criterion == 'barycenter': Cprev = C else: @@ -883,11 +881,8 @@ def gromov_barycenters( curr_loss = np.sum([output[1]['gw_dist'] for output in res]) # update barycenters - if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs, nx) - - elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs, nx) + C = update_barycenter_structure( + T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx) # update convergence criterion if stop_criterion == 'barycenter': @@ -907,7 +902,8 @@ def gromov_barycenters( 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err)) - 
cpt += 1 + if err <= tol: + break if log: log_['T'] = T @@ -1046,13 +1042,14 @@ def fgw_barycenters( if fixed_structure: if init_C is None: - raise UndefinedParameter('If C is fixed it must be initialized') + raise UndefinedParameter( + 'If C is fixed it must be provided in init_C') else: C = init_C else: if init_C is None: - generator = check_random_state(random_state) - xalea = generator.randn(N, 2) + rng = check_random_state(random_state) + xalea = rng.randn(N, 2) C = dist(xalea, xalea) C = nx.from_numpy(C, type_as=ps[0]) else: @@ -1060,7 +1057,8 @@ def fgw_barycenters( if fixed_features: if init_X is None: - raise UndefinedParameter('If X is fixed it must be initialized') + raise UndefinedParameter( + 'If X is fixed it must be provided in init_X') else: X = init_X else: @@ -1075,20 +1073,12 @@ def fgw_barycenters( if warmstartT: T = [None] * S - cpt = 0 - if stop_criterion == 'barycenter': inner_log = False - err_feature = 1e15 - err_structure = 1e15 - err_rel_loss = 0. else: inner_log = True - err_feature = 0. - err_structure = 0. curr_loss = 1e15 - err_rel_loss = 1e15 if log: log_ = {} @@ -1100,7 +1090,8 @@ def fgw_barycenters( log_['loss'] = [] log_['err_rel_loss'] = [] - while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter): + for cpt in range(max_iter): # break if specified errors are below tol. 
+ if stop_criterion == 'barycenter': Cprev = C Xprev = X @@ -1126,16 +1117,14 @@ def fgw_barycenters( # update barycenters if not fixed_features: - Ys_temp = [y.T for y in Ys] - X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T + X = update_barycenter_feature( + T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx) + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: - if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs, nx) - - elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs, nx) + C = update_barycenter_structure( + T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx) # update convergence criterion if stop_criterion == 'barycenter': @@ -1155,6 +1144,9 @@ def fgw_barycenters( 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err_structure)) print('{:5d}|{:8e}|'.format(cpt, err_feature)) + + if (err_feature <= tol) or (err_structure <= tol): + break else: err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. 
else np.nan if log: @@ -1167,7 +1159,8 @@ def fgw_barycenters( 'It.', 'Err') + '\n' + '-' * 19) print('{:5d}|{:8e}|'.format(cpt, err_rel_loss)) - cpt += 1 + if err_rel_loss <= tol: + break if log: log_['T'] = T diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 0137a8ed8..a777239d3 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -11,11 +11,16 @@ import numpy as np -from ..utils import list_to_array, unif +from ..utils import ( + list_to_array, unif, dist, UndefinedParameter, check_random_state +) from ..optim import semirelaxed_cg, solve_1d_linesearch_quad from ..backend import get_backend -from ._utils import init_matrix_semirelaxed, gwloss, gwggrad +from ._utils import ( + init_matrix_semirelaxed, gwloss, gwggrad, + update_barycenter_structure, update_barycenter_feature, +) def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, @@ -1100,3 +1105,242 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( return log_srfgw['srfgw_dist'], log_srfgw else: return log_srfgw['srfgw_dist'] + + +def semirelaxed_fgw_barycenters( + N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False, + fixed_features=False, p=None, loss_fun='square_loss', + symmetric=True, max_iter=100, tol=1e-9, stop_criterion='barycenter', + warmstartT=False, verbose=False, log=False, init_C=None, init_X=None, + random_state=None, **kwargs): + r""" + Returns the Semi-relaxed Fused Gromov-Wasserstein barycenters of `S` measurable networks + with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}` + (see eq (44) in :ref:`[48]`, estimated using the semi-relaxed FGW transports from Conditional Gradient solvers. + + The function solves the following optimization problem: + + .. 
math:: + + \mathbf{C}^*, \mathbf{Y}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}, \mathbf{Y}\in \mathbb{R}^{N \times d}} + \quad \sum_s \lambda_s \mathrm{srFGW}_{\alpha}(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s, \mathbf{C}, \mathbf{Y}) + + Where : + + - :math:`\mathbf{Y}_s`: input feature matrix + - :math:`\mathbf{C}_s`: input metric cost matrix + - :math:`\mathbf{p}_s`: input distribution + + Parameters + ---------- + N : int + Desired number of samples of the target barycenter + Ys: list of array-like, each element has shape (ns,d) + Features of all samples + Cs : list of array-like, each element has shape (ns,ns) + Structure matrices of all samples + ps : list of array-like, each element has shape (ns,), optional + Masses of all samples. + If left to its default value None, uniform distributions are taken. + lambdas : list of float, optional + List of the `S` spaces' weights. + If left to its default value None, uniform weights are taken. + alpha : float, optional + Alpha parameter for the srFGW divergence in :math:`]0, 1[`. + fixed_structure : bool, optional + Whether to fix the structure of the barycenter during the updates. + fixed_features : bool, optional + Whether to fix the feature of the barycenter during the updates. + loss_fun : str, optional + Loss function used for the solver, either 'square_loss' or 'kl_loss'. + symmetric : bool, optional + Whether the structures are to be assumed symmetric or not. Default value is True. + If set to True (resp. False), the structure matrices will be assumed symmetric (resp. asymmetric). + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on relative error (>0) + stop_criterion : str, optional. Default is 'barycenter'. + Stop criterion taking values in ['barycenter', 'loss']. If set to 'barycenter' + uses absolute norm variations of estimated barycenters. Else if set to 'loss' + uses the relative variations of the loss. 
+ warmstartT: bool, optional + Whether to perform warmstart of transport plans in the successive + semi-relaxed fused Gromov-Wasserstein transport problems. + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : array-like, shape (N,N), optional + Initialization for the barycenters' structure matrix. If not set + a random init is used. + init_X : array-like, shape (N,d), optional + Initialization for the barycenters' features. If not set a + random init is used. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + X : array-like, shape (`N`, `d`) + Barycenters' features + C : array-like, shape (`N`, `N`) + Barycenters' structure matrix + log : dict + Only returned when log=True. It contains the keys: + + - :math:`\mathbf{T}_s`: list of (`ns`, `N`) transport matrices from which target masses can be deduced. + - :math:`(\mathbf{M}_s)_s`: all distance matrices between the other features and the feature of the barycenter :math:`(dist(\mathbf{Y}_s, \mathbf{X}))_s` shape (`ns`, `N`) + - values used in convergence evaluation. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + """ + if stop_criterion not in ['barycenter', 'loss']: + raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") + + arr = [*Cs, *Ys] + if ps is not None: + arr += [*ps] + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + + nx = get_backend(*arr) + + S = len(Cs) + if lambdas is None: + lambdas = [1. 
/ S] * S + + d = Ys[0].shape[1] # dimension on the node features + + if fixed_structure: + if init_C is None: + raise UndefinedParameter( + 'If C is fixed it must be provided in init_C') + else: + C = init_C + else: + if init_C is None: + rng = check_random_state(random_state) + xalea = rng.randn(N, 2) + C = dist(xalea, xalea) + C = nx.from_numpy(C, type_as=ps[0]) + else: + C = init_C + + if fixed_features: + if init_X is None: + raise UndefinedParameter( + 'If X is fixed it must be provided in init_X') + else: + X = init_X + else: + if init_X is None: + X = nx.zeros((N, d), type_as=ps[0]) + + else: + X = init_X + + Ms = [dist(Ys[s], X) for s in range(len(Ys))] + + if warmstartT: + T = [None] * S + + if stop_criterion == 'barycenter': + inner_log = False + + else: + inner_log = True + curr_loss = 1e15 + + if log: + log_ = {} + if stop_criterion == 'barycenter': + log_['err_feature'] = [] + log_['err_structure'] = [] + else: + log_['loss'] = [] + log_['err_rel_loss'] = [] + + for cpt in range(max_iter): # break if specified errors are below tol. 
+ + if stop_criterion == 'barycenter': + Cprev = C + Xprev = X + else: + prev_loss = curr_loss + + # get transport plans + if warmstartT: + res = [semirelaxed_fused_gromov_wasserstein( + Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, T[s], + inner_log, max_iter, tol_rel=1e-5, tol_abs=0., **kwargs) + for s in range(S)] + else: + res = [semirelaxed_fused_gromov_wasserstein( + Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, None, + inner_log, max_iter, tol_rel=1e-5, tol_abs=0., **kwargs) + for s in range(S)] + if stop_criterion == 'barycenter': + T = res + else: + T = [output[0] for output in res] + curr_loss = np.sum([output[1]['srfgw_dist'] for output in res]) + + # update barycenters + p = nx.concatenate( + [nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) + + if not fixed_features: + X = update_barycenter_feature(T, Ys, lambdas, p, nx=nx) + Ms = [dist(Ys[s], X) for s in range(len(Ys))] + + if not fixed_structure: + C = update_barycenter_structure(T, Cs, lambdas, p, loss_fun, nx=nx) + + # update convergence criterion + if stop_criterion == 'barycenter': + err_feature, err_structure = 0., 0. + if not fixed_features: + err_feature = nx.norm(X - Xprev) + if not fixed_structure: + err_structure = nx.norm(C - Cprev) + if log: + log_['err_feature'].append(err_feature) + log_['err_structure'].append(err_structure) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err_structure)) + print('{:5d}|{:8e}|'.format(cpt, err_feature)) + + if (err_feature <= tol) or (err_structure <= tol): + break + else: + err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. 
else np.nan + if log: + log_['loss'].append(curr_loss) + log_['err_rel_loss'].append(err_rel_loss) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err_rel_loss)) + + if err_rel_loss <= tol: + break + + if log: + log_['T'] = T + log_['p'] = p + log_['Ms'] = Ms + + return X, C, log_ + else: + return X, C diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 9a8111453..d4928d062 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -253,160 +253,6 @@ def gwggrad(constC, hC1, hC2, T, nx=None): T, nx) # [12] Prop. 2 misses a 2 factor -def update_square_loss(p, lambdas, T, Cs, nx=None): - r""" - Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` - :math:`\mathbf{T}_s` couplings calculated at each iteration of the GW - barycenter problem in :ref:`[12]`: - - .. math:: - - \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) - - Where : - - - :math:`\mathbf{C}_s`: metric cost matrix - - :math:`\mathbf{p}_s`: distribution - - Parameters - ---------- - p : array-like, shape (N,) - Masses in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights. - T : list of S array-like of shape (N, ns) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape(ns,ns) - Metric cost matrices. - nx : backend, optional - If let to its default value None, a backend test will be conducted. - - Returns - ---------- - C : array-like, shape (`nt`, `nt`) - Updated :math:`\mathbf{C}` matrix. - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. 
- - """ - if nx is None: - nx = get_backend(p, *T, *Cs) - - # Correct order mistake in Equation 14 in [12] - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s], Cs[s]), - T[s].T - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - - return tmpsum / ppt - - -def update_kl_loss(p, lambdas, T, Cs, nx=None): - r""" - Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` - :math:`\mathbf{T}_s` couplings calculated at each iteration of the GW - barycenter problem in :ref:`[12]`: - - .. math:: - - \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) - - Where : - - - :math:`\mathbf{C}_s`: metric cost matrix - - :math:`\mathbf{p}_s`: distribution - - - Parameters - ---------- - p : array-like, shape (N,) - Weights in the targeted barycenter. - lambdas : list of float - List of the `S` spaces' weights - T : list of S array-like of shape (N, ns) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. - Cs : list of S array-like, shape(ns,ns) - Metric cost matrices. - nx : backend, optional - If let to its default value None, a backend test will be conducted. - - Returns - ---------- - C : array-like, shape (`ns`, `ns`) - updated :math:`\mathbf{C}` matrix - - References - ---------- - .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, - "Gromov-Wasserstein averaging of kernel and distance matrices." - International Conference on Machine Learning (ICML). 2016. - - """ - if nx is None: - nx = get_backend(p, *T, *Cs) - - # Correct order mistake in Equation 15 in [12] - tmpsum = sum([ - lambdas[s] * nx.dot( - nx.dot(T[s], nx.log(nx.maximum(Cs[s], 1e-15))), - T[s].T - ) for s in range(len(T)) - ]) - ppt = nx.outer(p, p) - - return nx.exp(tmpsum / ppt) - - -def update_feature_matrix(lambdas, Ys, Ts, p, nx=None): - r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. 
- - - See "Solving the barycenter problem with Block Coordinate Descent (BCD)" - in :ref:`[24] ` calculated at each iteration - - Parameters - ---------- - p : array-like, shape (N,) - masses in the targeted barycenter - lambdas : list of float - List of the `S` spaces' weights - Ts : list of S array-like, shape (N, ns) - The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration - Ys : list of S array-like, shape (d,ns) - The features. - nx : backend, optional - If let to its default value None, a backend test will be conducted. - - Returns - ------- - X : array-like, shape (`d`, `N`) - - - .. _references-update-feature-matrix: - References - ---------- - .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. - """ - if nx is None: - nx = get_backend(*Ys, *Ts, p) - - p = 1. / p - tmpsum = sum([ - lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :] - for s in range(len(Ts)) - ]) - return tmpsum - - def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation @@ -522,3 +368,210 @@ def h2(b): hC2 = h2(C2) fC2t = f2(C2).T return constC, hC1, hC2, fC2t + + +def update_barycenter_structure( + Ts, Cs, lambdas, p=None, loss_fun='square_loss', target=True, + check_zeros=True, nx=None): + r""" + Updates :math:`\mathbf{C}` according to the inner loss L with the `S` + :math:`\mathbf{T}_s` couplings calculated at each iteration of variants of + the GW barycenter problem (e.g GW :ref:`[12]`, srGW :ref:`[48]`). + If `target=True` it solves for: + + .. math:: + + \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad + \sum_s \lambda_s \sum_{i,j,k,l} + L(\mathbf{C}^{(s)}_{i,k}, \mathbf{C}_{j,l}) \mathbf{T}^{(s)}_{i,j} \mathbf{T}^{(s)}_{k,l} + + Else it solves the symmetric problem: + + .. 
math:: + + \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad + \sum_s \lambda_s \sum_{i,j,k,l} + L(\mathbf{C}_{j,l}, \mathbf{C}^{(s)}_{i,k}) \mathbf{T}^{(s)}_{i,j} \mathbf{T}^{(s)}_{k,l} + + Where : + + - :math:`\mathbf{C}^{(s)}`: pairwise matrix in the :math:`s^{th}` source space. + - :math:`\mathbf{C}`: pairwise matrix in the target space. + - :math:`L`: inner divergence for the GW loss + + Parameters + ---------- + Ts : list of S array-like of shape (ns, N) if `target=True` else (N, ns). + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Cs : list of S array-like, shape(ns, ns) + Metric cost matrices. + lambdas : list of float + List of the `S` spaces' weights. + p : array-like, shape (N,) or (S,N), optional + Masses or list of masses in the targeted barycenter. If None, they are deduced from the marginals of `Ts`. + loss_fun : str, optional. Default is 'square_loss' + Name of loss function to use in ['square_loss', 'kl_loss']. + target: bool, optional. Default is True. + Whether the barycenter is positioned as target (True) or source (False). + check_zeros: bool, optional. Default is True. + Whether to check if marginals on the barycenter contain zeros or not. + Can be set to False to gain time if marginals are known to be positive. + nx : backend, optional + If left to its default value None, a backend test will be conducted. + + Returns + ------- + C : array-like, shape (`N`, `N`) + Updated :math:`\mathbf{C}` matrix. + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. 
+ + """ + + if nx is None: + arr = [*Ts, *Cs] + if p is not None: + arr += [p] + + nx = get_backend(*arr) + + S = len(Ts) + + if p is None: + p = nx.concatenate( + [nx.sum(Ts[s], int(not target))[None, :] for s in range(S)], + axis=0) + + # compute coefficients for the barycenter coming from marginals + + if len(p.shape) == 1: # shared target masses potentially with zeros + if check_zeros: + inv_p = nx.nan_to_num(1. / p, nan=1., posinf=1., neginf=1.) + else: + inv_p = 1. / p + + prod = nx.outer(inv_p, inv_p) + + else: + quotient = sum([nx.outer(p[s], p[s]) for s in range(S)]) + if check_zeros: + prod = nx.nan_to_num(1. / quotient, nan=1., posinf=1., neginf=1.) + else: + prod = 1. / quotient + + # compute coefficients for the barycenter coming from Ts and Cs + + if loss_fun == 'square_loss': + if target: + list_structures = [lambdas[s] * nx.dot( + nx.dot(Ts[s].T, Cs[s]), Ts[s]) for s in range(S)] + else: + list_structures = [lambdas[s] * nx.dot( + nx.dot(Ts[s], Cs[s]), Ts[s].T) for s in range(S)] + + return sum(list_structures) * prod + + elif loss_fun == 'kl_loss': + if target: + list_structures = [lambdas[s] * nx.dot( + nx.dot(Ts[s].T, Cs[s]), Ts[s]) + for s in range(S)] + + return sum(list_structures) * prod + else: + list_structures = [lambdas[s] * nx.dot( + nx.dot(Ts[s], nx.log(nx.maximum(Cs[s], 1e-16))), Ts[s].T) + for s in range(S)] + + return nx.exp(sum(list_structures) * prod) + + else: + raise ValueError(f"not supported loss_fun = {loss_fun}") + + +def update_barycenter_feature( + Ts, Ys, lambdas, p=None, loss_fun='square_loss', target=True, + check_zeros=True, nx=None): + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` + couplings calculated at each iteration of variants of the FGW + barycenter problem with inner wasserstein loss `loss_fun` + (e.g FGW :ref:`[24]`, srFGW :ref:`[48]`). + If `target=True` the barycenter is considered as the target else as the source. 
+ + Parameters + ---------- + Ts : list of S array-like of shape (ns, N) if `target=True` else (N, ns). + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + Ys : list of S array-like, shape (ns, d) + Feature matrices. + lambdas : list of float + List of the `S` spaces' weights. + p : array-like, shape (N,) or (S,N), optional + Masses or list of masses in the targeted barycenter. If None, they are deduced from the marginals of `Ts`. + loss_fun : str, optional. Default is 'square_loss' + Name of loss function to use in ['square_loss']. + target: bool, optional. Default is True. + Whether the barycenter is positioned as target (True) or source (False). + check_zeros: bool, optional. Default is True. + Whether to check if marginals on the barycenter contain zeros or not. + Can be set to False to gain time if marginals are known to be positive. + nx : backend, optional + If left to its default value None, a backend test will be conducted. + + Returns + ------- + X : array-like, shape (N, d) + + References + ---------- + .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas + "Optimal Transport for structured data with application on graphs" + International Conference on Machine Learning (ICML). 2019. + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. 
+ """ + if nx is None: + arr = [*Ts, *Ys] + if p is not None: + arr += [p] + + nx = get_backend(*arr) + + if loss_fun != 'square_loss': + raise ValueError(f"not supported loss_fun = {loss_fun}") + + S = len(Ts) + + if target: + list_features = [lambdas[s] * nx.dot(Ts[s].T, Ys[s]) for s in range(S)] + else: + list_features = [lambdas[s] * nx.dot(Ts[s], Ys[s]) for s in range(S)] + + if p is None: + p = nx.concatenate( + [nx.sum(Ts[s], int(not target))[None, :] for s in range(S)], + axis=0) + + if len(p.shape) == 1: # shared target masses potentially with zeros + if check_zeros: + inv_p = nx.nan_to_num(1. / p, nan=1., posinf=1., neginf=1.) + else: + inv_p = 1. / p + else: + p_sum = sum(p) + if check_zeros: + inv_p = nx.nan_to_num(1. / p_sum, nan=1., posinf=1., neginf=1.) + else: + inv_p = 1. / p_sum + + return sum(list_features) * inv_p[:, None] diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py index e76a33dcf..5b858f307 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -795,8 +795,10 @@ def test_fgw_barycenter(nx): random_state=12345, log=True ) # test correspondance with utils function - recovered_Cb = ot.gromov.update_square_loss(pb, lambdas, logb['Ts_iter'][-1], Csb) - recovered_Xb = ot.gromov.update_feature_matrix(lambdas, [y.T for y in Ysb], logb['Ts_iter'][-1], pb).T + recovered_Cb = ot.gromov.update_barycenter_structure( + logb['Ts_iter'][-1], Csb, lambdas, pb, target=False, check_zeros=True) + recovered_Xb = ot.gromov.update_barycenter_feature( + logb['Ts_iter'][-1], Ysb, lambdas, pb, target=False, check_zeros=True) np.testing.assert_allclose(Cb, recovered_Cb) np.testing.assert_allclose(Xb, recovered_Xb) @@ -864,7 +866,10 @@ def test_fgw_barycenter(nx): np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) # test correspondance with utils function - recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['T'], [C1, C2]) + recovered_C = ot.gromov.update_barycenter_structure( + log['T'], [C1, C2], lambdas, p, 
loss_fun='kl_loss', + target=False, check_zeros=False) + + np.testing.assert_allclose(C, recovered_C) # test edge cases for fgw barycenters: diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 6f23a6b62..2e4b2f128 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -613,3 +613,130 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx): nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, fgw_valb) + + +def test_semirelaxed_fgw_barycenter(nx): + ns = 10 + nt = 20 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + rng = np.random.RandomState(42) + ys = rng.randn(Xs.shape[0], 2) + yt = rng.randn(Xt.shape[0], 2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + C1 /= C1.max() + C2 /= C2.max() + + p1, p2 = ot.unif(ns), ot.unif(nt) + n_samples = 3 + + ysb, ytb, C1b, C2b, p1b, p2b = nx.from_numpy(ys, yt, C1, C2, p1, p2) + + lambdas = [.5, .5] + Csb = [C1b, C2b] + Ysb = [ysb, ytb] + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False, + fixed_features=False, loss_fun='square_loss', max_iter=10, tol=1e-3, + random_state=12345, log=True + ) + # test correspondence with utils function + recovered_Cb = ot.gromov.update_barycenter_structure( + logb['T'], Csb, lambdas) + recovered_Xb = ot.gromov.update_barycenter_feature( + logb['T'], Ysb, lambdas) + + np.testing.assert_allclose(Cb, recovered_Cb) + np.testing.assert_allclose(Xb, recovered_Xb) + + xalea = rng.randn(n_samples, 2) + init_C = ot.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True` and `init_C=None` + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None, alpha=0.5, + fixed_structure=True, init_C=None, 
fixed_features=False, + loss_fun='square_loss', max_iter=10, tol=1e-3 + ) + + Xb, Cb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, + alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, + loss_fun='square_loss', max_iter=10, tol=1e-3 + ) + Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + init_X = rng.randn(n_samples, ys.shape[1]) + init_Xb = nx.from_numpy(init_X) + + # Tests with `fixed_structure=False` and `fixed_features=True` + with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True` and `init_X=None` + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=None, + loss_fun='square_loss', max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_Xb, + loss_fun='square_loss', max_iter=10, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + + X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + # test with 'kl_loss': an unknown stop criterion is expected to raise a ValueError + with pytest.raises(ValueError): + stop_criterion = 'unknown stop criterion' + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='kl_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, + init_X=X, warmstartT=True, random_state=12345, log=True + ) + + for stop_criterion in ['barycenter', 'loss']: + X, C, log = 
ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='kl_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, init_C=C, + init_X=X, warmstartT=True, random_state=12345, log=True, verbose=True + ) + np.testing.assert_allclose(C.shape, (n_samples, n_samples)) + np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + + # test correspondence with utils function + + recovered_C = ot.gromov.update_barycenter_structure( + log['T'], [C1, C2], lambdas, None, 'kl_loss', True) + + np.testing.assert_allclose(C, recovered_C) + + # test edge cases for semirelaxed fgw barycenters: + # unique input structure + X, C = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys], [C1], [p1], None, 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=2, tol=1e-3, stop_criterion=stop_criterion, + warmstartT=True, random_state=12345, log=False, verbose=False + ) + Xb, Cb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb], [C1b], [p1b], [1.], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=2, tol=1e-3, stop_criterion=stop_criterion, + warmstartT=True, random_state=12345, log=False, verbose=False + ) + + np.testing.assert_allclose(C, Cb, atol=1e-06) + np.testing.assert_allclose(X, Xb, atol=1e-06) diff --git a/test/gromov/test_utils.py b/test/gromov/test_utils.py new file mode 100644 index 000000000..ad94a4042 --- /dev/null +++ b/test/gromov/test_utils.py @@ -0,0 +1,63 @@ +""" Tests for gromov._utils.py """ + +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +import numpy as np +import pytest + +import ot + + +def test_update_barycenter(nx): + ns = 5 + nt = 10 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + rng = np.random.RandomState(42) + ys = rng.randn(Xs.shape[0], 2) + yt = 
rng.randn(Xt.shape[0], 2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + C1 /= C1.max() + C2 /= C2.max() + + p1, p2 = ot.unif(ns), ot.unif(nt) + n_samples = 3 + + ysb, ytb, C1b, C2b, p1b, p2b = nx.from_numpy(ys, yt, C1, C2, p1, p2) + + lambdas = [.5, .5] + Csb = [C1b, C2b] + Ysb = [ysb, ytb] + + Tb = [nx.ones((m, n_samples), type_as=C1b) / (m * n_samples) for m in [ns, nt]] + pb = nx.concatenate( + [nx.sum(elem, 0)[None, :] for elem in Tb], axis=0) + + # test edge cases for the update of the barycenter with `p != None` + # and `target=False` + Cb = ot.gromov.update_barycenter_structure( + [elem.T for elem in Tb], Csb, lambdas, pb, target=False) + Xb = ot.gromov.update_barycenter_feature( + [elem.T for elem in Tb], Ysb, lambdas, pb, target=False) + + Cbt = ot.gromov.update_barycenter_structure( + Tb, Csb, lambdas, None, target=True, check_zeros=False) + Xbt = ot.gromov.update_barycenter_feature( + Tb, Ysb, lambdas, None, target=True, check_zeros=False) + + np.testing.assert_allclose(Cb, Cbt) + np.testing.assert_allclose(Xb, Xbt) + + # test unsupported loss functions + with pytest.raises(ValueError): + Cbt = ot.gromov.update_barycenter_structure( + Tb, Csb, lambdas, None, loss_fun='unknown', target=True) + with pytest.raises(ValueError): + Xbt = ot.gromov.update_barycenter_feature( + Tb, Ysb, lambdas, None, loss_fun='unknown', target=True)