Skip to content

Commit

Permalink
[MRG] srFGW barycenters (#659)
Browse files Browse the repository at this point in the history
* init commit - integrating sr(F)GW barycenter

* correct asymmetries in srgw

* fix tests srFGW bary

* fix pep8

* complete tests for srFGW barycenters and utils

* last updates

* update barycenter update functions and remove old ones

* take review into account

* fix pep8

* ot/gromov/__init__.py

* update test
  • Loading branch information
cedricvincentcuaz authored Jul 19, 2024
1 parent d0849a4 commit 47c5925
Show file tree
Hide file tree
Showing 10 changed files with 713 additions and 227 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ POT provides the following generic OT solvers (links to examples):
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding barycenter solvers (exact and regularized [48]).
* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68].
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#### New features
- Add feature `mass=True` for `nx.kl_div` (PR #654)
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)

#### Closed issues

Expand Down
19 changes: 12 additions & 7 deletions ot/gromov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

# All submodules and packages
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
update_square_loss, update_kl_loss, update_feature_matrix,
init_matrix_semirelaxed)
init_matrix_semirelaxed,
update_barycenter_structure, update_barycenter_feature,
)

from ._gw import (gromov_wasserstein, gromov_wasserstein2,
fused_gromov_wasserstein, fused_gromov_wasserstein2,
Expand Down Expand Up @@ -40,14 +41,16 @@
entropic_semirelaxed_gromov_wasserstein,
entropic_semirelaxed_gromov_wasserstein2,
entropic_semirelaxed_fused_gromov_wasserstein,
entropic_semirelaxed_fused_gromov_wasserstein2)
entropic_semirelaxed_fused_gromov_wasserstein2,
semirelaxed_fgw_barycenters)

from ._dictionary import (gromov_wasserstein_dictionary_learning,
gromov_wasserstein_linear_unmixing,
fused_gromov_wasserstein_dictionary_learning,
fused_gromov_wasserstein_linear_unmixing)

from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples)
from ._lowrank import (_flat_product_operator,
lowrank_gromov_wasserstein_samples)


from ._quantized import (quantized_fused_gromov_wasserstein_partitioned,
Expand All @@ -60,8 +63,9 @@
quantized_fused_gromov_wasserstein_samples
)

__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
'init_matrix_semirelaxed',
'update_barycenter_structure', 'update_barycenter_feature',
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
Expand All @@ -80,4 +84,5 @@
'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition',
'get_graph_representants', 'format_partitioned_graph',
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples',
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples']
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples',
'semirelaxed_fgw_barycenters']
47 changes: 21 additions & 26 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..backend import get_backend

from ._utils import init_matrix, gwloss, gwggrad
from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
from ._utils import update_barycenter_structure, update_barycenter_feature


def entropic_gromov_wasserstein(
Expand Down Expand Up @@ -807,10 +807,8 @@ def entropic_gromov_barycenters(
curr_loss = np.sum([output[1]['gw_dist'] for output in res])

# update barycenters
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)
C = update_barycenter_structure(
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)

# update convergence criterion
if stop_criterion == 'barycenter':
Expand Down Expand Up @@ -1651,13 +1649,14 @@ def entropic_fused_gromov_barycenters(
# Initialization of C : random euclidean distance matrix (if not provided by user)
if fixed_structure:
if init_C is None:
raise UndefinedParameter('If C is fixed it must be initialized')
raise UndefinedParameter(
'If C is fixed it must be provided in init_C')
else:
C = init_C
else:
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
rng = check_random_state(random_state)
xalea = rng.randn(N, 2)
C = dist(xalea, xalea)
C = nx.from_numpy(C, type_as=ps[0])
else:
Expand All @@ -1666,7 +1665,8 @@ def entropic_fused_gromov_barycenters(
# Initialization of Y
if fixed_features:
if init_Y is None:
raise UndefinedParameter('If Y is fixed it must be initialized')
raise UndefinedParameter(
'If Y is fixed it must be provided in init_Y')
else:
Y = init_Y
else:
Expand All @@ -1681,20 +1681,12 @@ def entropic_fused_gromov_barycenters(
if warmstartT:
T = [None] * S

cpt = 0

if stop_criterion == 'barycenter':
inner_log = False
err_feature = 1e15
err_structure = 1e15
err_rel_loss = 0.

else:
inner_log = True
err_feature = 0.
err_structure = 0.
curr_loss = 1e15
err_rel_loss = 1e15

if log:
log_ = {}
Expand All @@ -1706,7 +1698,8 @@ def entropic_fused_gromov_barycenters(
log_['loss'] = []
log_['err_rel_loss'] = []

while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
for cpt in range(max_iter): # break if specified errors are below tol.

if stop_criterion == 'barycenter':
Cprev = C
Yprev = Y
Expand All @@ -1732,16 +1725,14 @@ def entropic_fused_gromov_barycenters(

# update barycenters
if not fixed_features:
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
X = update_barycenter_feature(
T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx)

Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if not fixed_structure:
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)
C = update_barycenter_structure(
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)

# update convergence criterion
if stop_criterion == 'barycenter':
Expand All @@ -1761,6 +1752,9 @@ def entropic_fused_gromov_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_structure))
print('{:5d}|{:8e}|'.format(cpt, err_feature))

if (err_feature <= tol) or (err_structure <= tol):
break
else:
err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
if log:
Expand All @@ -1773,7 +1767,8 @@ def entropic_fused_gromov_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))

cpt += 1
if err_rel_loss <= tol:
break

if log:
log_['T'] = T
Expand Down
61 changes: 27 additions & 34 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..backend import get_backend, NumpyBackend

from ._utils import init_matrix, gwloss, gwggrad
from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
from ._utils import update_barycenter_structure, update_barycenter_feature


def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
Expand Down Expand Up @@ -833,17 +833,14 @@ def gromov_barycenters(

# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
rng = check_random_state(random_state)
xalea = rng.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
C = nx.from_numpy(C, type_as=p)
else:
C = init_C

cpt = 0
err = 1e15 # either the error on 'barycenter' or 'loss'

if warmstartT:
T = [None] * S

Expand All @@ -859,7 +856,8 @@ def gromov_barycenters(
if stop_criterion == 'loss':
log_['loss'] = []

while (err > tol and cpt < max_iter):
for cpt in range(max_iter):

if stop_criterion == 'barycenter':
Cprev = C
else:
Expand All @@ -883,11 +881,8 @@ def gromov_barycenters(
curr_loss = np.sum([output[1]['gw_dist'] for output in res])

# update barycenters
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)
C = update_barycenter_structure(
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)

# update convergence criterion
if stop_criterion == 'barycenter':
Expand All @@ -907,7 +902,8 @@ def gromov_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))

cpt += 1
if err <= tol:
break

if log:
log_['T'] = T
Expand Down Expand Up @@ -1046,21 +1042,23 @@ def fgw_barycenters(

if fixed_structure:
if init_C is None:
raise UndefinedParameter('If C is fixed it must be initialized')
raise UndefinedParameter(
'If C is fixed it must be provided in init_C')
else:
C = init_C
else:
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
rng = check_random_state(random_state)
xalea = rng.randn(N, 2)
C = dist(xalea, xalea)
C = nx.from_numpy(C, type_as=ps[0])
else:
C = init_C

if fixed_features:
if init_X is None:
raise UndefinedParameter('If X is fixed it must be initialized')
raise UndefinedParameter(
'If X is fixed it must be provided in init_X')
else:
X = init_X
else:
Expand All @@ -1075,20 +1073,12 @@ def fgw_barycenters(
if warmstartT:
T = [None] * S

cpt = 0

if stop_criterion == 'barycenter':
inner_log = False
err_feature = 1e15
err_structure = 1e15
err_rel_loss = 0.

else:
inner_log = True
err_feature = 0.
err_structure = 0.
curr_loss = 1e15
err_rel_loss = 1e15

if log:
log_ = {}
Expand All @@ -1100,7 +1090,8 @@ def fgw_barycenters(
log_['loss'] = []
log_['err_rel_loss'] = []

while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
for cpt in range(max_iter): # break if specified errors are below tol.

if stop_criterion == 'barycenter':
Cprev = C
Xprev = X
Expand All @@ -1126,16 +1117,14 @@ def fgw_barycenters(

# update barycenters
if not fixed_features:
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
X = update_barycenter_feature(
T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx)

Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if not fixed_structure:
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)
C = update_barycenter_structure(
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)

# update convergence criterion
if stop_criterion == 'barycenter':
Expand All @@ -1155,6 +1144,9 @@ def fgw_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_structure))
print('{:5d}|{:8e}|'.format(cpt, err_feature))

if (err_feature <= tol) or (err_structure <= tol):
break
else:
err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
if log:
Expand All @@ -1167,7 +1159,8 @@ def fgw_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))

cpt += 1
if err_rel_loss <= tol:
break

if log:
log_['T'] = T
Expand Down
Loading

0 comments on commit 47c5925

Please sign in to comment.