Skip to content

Commit

Permalink
gaussian init
Browse files Browse the repository at this point in the history
  • Loading branch information
rflamary committed Nov 2, 2023
1 parent 53dde7a commit b3be5a6
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil

[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017.

[60] Thornton, James, and Marco Cuturi. [Rethinking initialization of the sinkhorn algorithm](https://arxiv.org/pdf/2206.07630.pdf). International Conference on Artificial Intelligence and Statistics. PMLR, 2023.
3 changes: 2 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
binary_search_circle, wasserstein_circle,
semidiscrete_wasserstein2_unif_circle)
from .bregman import sinkhorn, sinkhorn2, barycenter
from .bregman import (sinkhorn, sinkhorn2, barycenter, empirical_sinkhorn, empirical_sinkhorn2, empirical_sinkhorn_divergence)
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
sinkhorn_unbalanced2)
from .da import sinkhorn_lpl1_mm
Expand All @@ -61,6 +61,7 @@
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn_divergence',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',
Expand Down
37 changes: 31 additions & 6 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ot.utils import dist, list_to_array, unif

from .backend import get_backend
from .gaussian import dual_gaussian_init


def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
Expand Down Expand Up @@ -541,6 +542,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
log['niter'] = ii
log['u'] = u
log['v'] = v
log['warmstart'] = (nx.log(u), nx.log(v))

if n_hists: # return only loss
res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
Expand Down Expand Up @@ -697,6 +699,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
'log_v': nx.stack(lst_v, 1), }
log['u'] = nx.exp(log['log_u'])
log['v'] = nx.exp(log['log_v'])
log['warmstart'] = (log['log_u'], log['log_v'])
return res, log
else:
return res
Expand Down Expand Up @@ -2999,15 +3002,23 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
if b is None:
b = nx.from_numpy(unif(nt), type_as=X_s)

if warmstart is None:
f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
elif warmstart == 'gaussian':
# init only g since f is the first updated
f = dual_gaussian_init(X_s, X_t, a[:, None], b[:, None])
g = dual_gaussian_init(X_t, X_s, b[:, None], a[:, None])

Check warning on line 3010 in ot/bregman.py

View check run for this annotation

Codecov / codecov/patch

ot/bregman.py#L3009-L3010

Added lines #L3009 - L3010 were not covered by tests
elif (isinstance(warmstart, tuple) or isinstance(warmstart, list)) and len(warmstart) == 2:
f, g = warmstart
else:
raise ValueError(

Check warning on line 3014 in ot/bregman.py

View check run for this annotation

Codecov / codecov/patch

ot/bregman.py#L3014

Added line #L3014 was not covered by tests
"warmstart must be None, 'gaussian' or a tuple of two arrays")

if isLazy:
if log:
dict_log = {"err": []}

log_a, log_b = nx.log(a), nx.log(b)
if warmstart is None:
f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
else:
f, g = warmstart

if isinstance(batchSize, int):
bs, bt = batchSize, batchSize
Expand Down Expand Up @@ -3075,6 +3086,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
if log:
dict_log["u"] = f
dict_log["v"] = g
dict_log["warmstart"] = (f, g)
return (f, g, dict_log)
else:
return (f, g)
Expand All @@ -3083,11 +3095,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
M = dist(X_s, X_t, metric=metric)
if log:
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
verbose=verbose, log=True, warmstart=warmstart, **kwargs)
verbose=verbose, log=True, warmstart=(f, g), **kwargs)
return pi, log
else:
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
verbose=verbose, log=False, warmstart=warmstart, **kwargs)
verbose=verbose, log=False, warmstart=(f, g), **kwargs)
return pi


Expand Down Expand Up @@ -3201,6 +3213,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
if b is None:
b = nx.from_numpy(unif(nt), type_as=X_s)

if warmstart is None:
warmstart = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
elif warmstart == 'gaussian':
# init only g since f is the first updated
f = dual_gaussian_init(X_s, X_t, a[:, None], b[:, None])
g = dual_gaussian_init(X_t, X_s, b[:, None], a[:, None])
warmstart = (f, g)
elif (isinstance(warmstart, tuple) or isinstance(warmstart, list)) and len(warmstart) == 2:
warmstart = warmstart
else:
raise ValueError(

Check warning on line 3226 in ot/bregman.py

View check run for this annotation

Codecov / codecov/patch

ot/bregman.py#L3226

Added line #L3226 was not covered by tests
"warmstart must be None, 'gaussian' or a tuple of two arrays")

if isLazy:
if log:
f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
Expand Down
48 changes: 48 additions & 0 deletions ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,3 +645,51 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None,
return A, b, log
else:
return A, b


def dual_gaussian_init(xs, xt, ws=None, wt=None, reg=1e-6):
r""" Return the source dual potential gaussian initialization.
This function return the dual potential gaussian initialization that can be
used to initialize the Sinkhorn algorithm. This initialization is based on
the Monge mapping between the source and target distributions seen as two
Gaussian distributions [60].
Parameters
----------
xs : array-like (ns,ds)
samples in the source domain
xt : array-like (nt,dt)
samples in the target domain
ws : array-like (ns,1), optional
weights for the source samples
wt : array-like (ns,1), optional
weights for the target samples
reg : float,optional
regularization added to the diagonals of covariances (>0)
.. [60] Thornton, James, and Marco Cuturi. "Rethinking initialization of the
sinkhorn algorithm." International Conference on Artificial Intelligence
and Statistics. PMLR, 2023.
"""

nx = get_backend(xs, xt)

if ws is None:
ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

if wt is None:
wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

# estimate mean and covariance
mu_s = nx.dot(ws.T, xs) / nx.sum(ws)
mu_t = nx.dot(wt.T, xt) / nx.sum(wt)

A, b = empirical_bures_wasserstein_mapping(xs, xt, ws=ws, wt=wt, reg=reg)

xsc = xs - mu_s

# compute the dual potential (see appendix D in [60])
f = nx.sum(xs**2 - nx.dot(xsc, A) * xsc - mu_t * xs, 1)

return f
8 changes: 8 additions & 0 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,9 @@ def test_empirical_sinkhorn(nx):
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))

loss_emp_sinkhorn_gausss_warmstart = nx.to_numpy(
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, warmstart='gaussian'))

# check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
Expand All @@ -1055,6 +1058,7 @@ def test_empirical_sinkhorn(nx):
np.testing.assert_allclose(
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
np.testing.assert_allclose(loss_emp_sinkhorn_gausss_warmstart, loss_sinkhorn, atol=1e-05)


def test_lazy_empirical_sinkhorn(nx):
Expand Down Expand Up @@ -1095,6 +1099,9 @@ def test_lazy_empirical_sinkhorn(nx):
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))

loss_emp_sinkhorn_gausss_warmstart = nx.to_numpy(
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, warmstart='gaussian', isLazy=True))

# check constraints
np.testing.assert_allclose(
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
Expand All @@ -1109,6 +1116,7 @@ def test_lazy_empirical_sinkhorn(nx):
np.testing.assert_allclose(
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
np.testing.assert_allclose(loss_emp_sinkhorn_gausss_warmstart, loss_sinkhorn, atol=1e-05)


def test_empirical_sinkhorn_divergence(nx):
Expand Down
19 changes: 19 additions & 0 deletions test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,22 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target):

if d_target >= 2:
np.testing.assert_allclose(Cs, Ctt)


def test_gaussian_init(nx):
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)

a_s = np.ones((ns, 1)) / ns
a_t = np.ones((nt, 1)) / nt

Xsb, Xtb, a_sb, a_tb = nx.from_numpy(Xs, Xt, a_s, a_t)

f = ot.gaussian.dual_gaussian_init(Xsb, Xtb)

f2 = ot.gaussian.dual_gaussian_init(Xsb, Xtb, a_sb, a_tb)

np.testing.assert_allclose(nx.to_numpy(f), nx.to_numpy(f2))

0 comments on commit b3be5a6

Please sign in to comment.