Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] srFGW barycenters #659

Merged
merged 11 commits into from
Jul 19, 2024
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ POT provides the following generic OT solvers (links to examples):
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding barycenter solvers (exact and regularized [48]).
* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68].
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#### New features
- Add feature `mass=True` for `nx.kl_div` (PR #654)
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)

#### Closed issues

Expand Down
19 changes: 12 additions & 7 deletions ot/gromov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

# All submodules and packages
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
update_square_loss, update_kl_loss, update_feature_matrix,
init_matrix_semirelaxed)
init_matrix_semirelaxed,
update_barycenter_structure, update_barycenter_feature,
)

from ._gw import (gromov_wasserstein, gromov_wasserstein2,
fused_gromov_wasserstein, fused_gromov_wasserstein2,
Expand Down Expand Up @@ -40,14 +41,16 @@
entropic_semirelaxed_gromov_wasserstein,
entropic_semirelaxed_gromov_wasserstein2,
entropic_semirelaxed_fused_gromov_wasserstein,
entropic_semirelaxed_fused_gromov_wasserstein2)
entropic_semirelaxed_fused_gromov_wasserstein2,
semirelaxed_fgw_barycenters)

from ._dictionary import (gromov_wasserstein_dictionary_learning,
gromov_wasserstein_linear_unmixing,
fused_gromov_wasserstein_dictionary_learning,
fused_gromov_wasserstein_linear_unmixing)

from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples)
from ._lowrank import (_flat_product_operator,
lowrank_gromov_wasserstein_samples)


from ._quantized import (quantized_fused_gromov_wasserstein_partitioned,
Expand All @@ -60,8 +63,9 @@
quantized_fused_gromov_wasserstein_samples
)

__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
'init_matrix_semirelaxed',
'update_barycenter_structure', 'update_barycenter_feature',
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
Expand All @@ -80,4 +84,5 @@
'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition',
'get_graph_representants', 'format_partitioned_graph',
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples',
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples']
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples',
'semirelaxed_fgw_barycenters']
47 changes: 21 additions & 26 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..backend import get_backend

from ._utils import init_matrix, gwloss, gwggrad
from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
from ._utils import update_barycenter_structure, update_barycenter_feature


def entropic_gromov_wasserstein(
Expand Down Expand Up @@ -807,10 +807,8 @@ def entropic_gromov_barycenters(
curr_loss = np.sum([output[1]['gw_dist'] for output in res])

# update barycenters
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)
C = update_barycenter_structure(
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)

# update convergence criterion
if stop_criterion == 'barycenter':
Expand Down Expand Up @@ -1651,13 +1649,14 @@ def entropic_fused_gromov_barycenters(
# Initialization of C : random Euclidean distance matrix (if not provided by user)
if fixed_structure:
if init_C is None:
raise UndefinedParameter('If C is fixed it must be initialized')
raise UndefinedParameter(
'If C is fixed it must be provided in init_C')
else:
C = init_C
else:
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
rng = check_random_state(random_state)
xalea = rng.randn(N, 2)
C = dist(xalea, xalea)
C = nx.from_numpy(C, type_as=ps[0])
else:
Expand All @@ -1666,7 +1665,8 @@ def entropic_fused_gromov_barycenters(
# Initialization of Y
if fixed_features:
if init_Y is None:
raise UndefinedParameter('If Y is fixed it must be initialized')
raise UndefinedParameter(
'If Y is fixed it must be provided in init_Y')
else:
Y = init_Y
else:
Expand All @@ -1681,20 +1681,12 @@ def entropic_fused_gromov_barycenters(
if warmstartT:
T = [None] * S

cpt = 0

if stop_criterion == 'barycenter':
inner_log = False
err_feature = 1e15
err_structure = 1e15
err_rel_loss = 0.

else:
inner_log = True
err_feature = 0.
err_structure = 0.
curr_loss = 1e15
err_rel_loss = 1e15

if log:
log_ = {}
Expand All @@ -1706,7 +1698,8 @@ def entropic_fused_gromov_barycenters(
log_['loss'] = []
log_['err_rel_loss'] = []

while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
for cpt in range(max_iter): # break if specified errors are below tol.

if stop_criterion == 'barycenter':
Cprev = C
Yprev = Y
Expand All @@ -1732,16 +1725,14 @@ def entropic_fused_gromov_barycenters(

# update barycenters
if not fixed_features:
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
X = update_barycenter_feature(
T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx)

Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if not fixed_structure:
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)
C = update_barycenter_structure(
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)

# update convergence criterion
if stop_criterion == 'barycenter':
Expand All @@ -1761,6 +1752,9 @@ def entropic_fused_gromov_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_structure))
print('{:5d}|{:8e}|'.format(cpt, err_feature))

if (err_feature <= tol) or (err_structure <= tol):
break
else:
err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
if log:
Expand All @@ -1773,7 +1767,8 @@ def entropic_fused_gromov_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))

cpt += 1
if err_rel_loss <= tol:
break

if log:
log_['T'] = T
Expand Down
61 changes: 27 additions & 34 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..backend import get_backend, NumpyBackend

from ._utils import init_matrix, gwloss, gwggrad
from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
from ._utils import update_barycenter_structure, update_barycenter_feature


def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
Expand Down Expand Up @@ -833,17 +833,14 @@ def gromov_barycenters(

# Initialization of C : random pairwise distance matrix rescaled by its maximum (if not provided by user)
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
rng = check_random_state(random_state)
xalea = rng.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
C = nx.from_numpy(C, type_as=p)
else:
C = init_C

cpt = 0
err = 1e15 # either the error on 'barycenter' or 'loss'

if warmstartT:
T = [None] * S

Expand All @@ -859,7 +856,8 @@ def gromov_barycenters(
if stop_criterion == 'loss':
log_['loss'] = []

while (err > tol and cpt < max_iter):
for cpt in range(max_iter):

if stop_criterion == 'barycenter':
Cprev = C
else:
Expand All @@ -883,11 +881,8 @@ def gromov_barycenters(
curr_loss = np.sum([output[1]['gw_dist'] for output in res])

# update barycenters
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)
C = update_barycenter_structure(
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)

# update convergence criterion
if stop_criterion == 'barycenter':
Expand All @@ -907,7 +902,8 @@ def gromov_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))

cpt += 1
if err <= tol:
break

if log:
log_['T'] = T
Expand Down Expand Up @@ -1046,21 +1042,23 @@ def fgw_barycenters(

if fixed_structure:
if init_C is None:
raise UndefinedParameter('If C is fixed it must be initialized')
raise UndefinedParameter(
'If C is fixed it must be provided in init_C')
else:
C = init_C
else:
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
rng = check_random_state(random_state)
xalea = rng.randn(N, 2)
C = dist(xalea, xalea)
C = nx.from_numpy(C, type_as=ps[0])
else:
C = init_C

if fixed_features:
if init_X is None:
raise UndefinedParameter('If X is fixed it must be initialized')
raise UndefinedParameter(
'If X is fixed it must be provided in init_X')
else:
X = init_X
else:
Expand All @@ -1075,20 +1073,12 @@ def fgw_barycenters(
if warmstartT:
T = [None] * S

cpt = 0

if stop_criterion == 'barycenter':
inner_log = False
err_feature = 1e15
err_structure = 1e15
err_rel_loss = 0.

else:
inner_log = True
err_feature = 0.
err_structure = 0.
curr_loss = 1e15
err_rel_loss = 1e15

if log:
log_ = {}
Expand All @@ -1100,7 +1090,8 @@ def fgw_barycenters(
log_['loss'] = []
log_['err_rel_loss'] = []

while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
for cpt in range(max_iter): # break if specified errors are below tol.

if stop_criterion == 'barycenter':
Cprev = C
Xprev = X
Expand All @@ -1126,16 +1117,14 @@ def fgw_barycenters(

# update barycenters
if not fixed_features:
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
X = update_barycenter_feature(
T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx)

Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if not fixed_structure:
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)
C = update_barycenter_structure(
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)

# update convergence criterion
if stop_criterion == 'barycenter':
Expand All @@ -1155,6 +1144,9 @@ def fgw_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_structure))
print('{:5d}|{:8e}|'.format(cpt, err_feature))

if (err_feature <= tol) or (err_structure <= tol):
break
else:
err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
if log:
Expand All @@ -1167,7 +1159,8 @@ def fgw_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))

cpt += 1
if err_rel_loss <= tol:
break

if log:
log_['T'] = T
Expand Down
Loading
Loading