Skip to content

Commit

Permalink
batchable BW distance
Browse files Browse the repository at this point in the history
  • Loading branch information
clbonet committed Nov 19, 2024
1 parent 447a1a6 commit d4045f1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 38 deletions.
57 changes: 28 additions & 29 deletions ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
empirical distributions source :math:`\mu_s` and target :math:`\mu_t`,
discussed in remark 2.31 :ref:`[1] <references-bures-wasserstein-distance>`.
The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}`
The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}_2`
.. math::
\mathcal{W}_2(\mu_s, \mu_t)^2 = \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}
Expand All @@ -219,21 +219,23 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
Parameters
----------
ms : array-like (d,)
ms : array-like (d,) or (n,d)
mean of the source distribution
mt : array-like (d,)
mt : array-like (d,) or (m,d)
mean of the target distribution
Cs : array-like (d,d)
Cs : array-like (d,d) or (n,d,d)
covariance of the source distribution
Ct : array-like (d,d)
Ct : array-like (d,d) or (m,d,d)
covariance of the target distribution
log : bool, optional
record log if True
Returns
-------
W : float
W : float if ms and mt of shape (d,), array-like (n,) if ms of shape (n,d) and
mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d),
array-like (n,m) if ms of shape (n,d) and mt of shape (m,d)
Bures Wasserstein distance
log : dict
log dictionary return only if log==True in parameters
Expand All @@ -250,30 +252,27 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
nx = get_backend(ms, mt, Cs, Ct)

Cs12 = nx.sqrtm(Cs)
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
W = nx.sqrt(nx.maximum(nx.norm(ms - mt) ** 2 + B, 0))

if log:
log = {}
log["Cs12"] = Cs12
return W, log
if len(ms.shape) == 1 and len(mt.shape) == 1:
# Return float
squared_dist_m = nx.norm(ms - mt) ** 2
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
elif len(ms.shape) == 1:
# Return shape (m,)
M = nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12)
B = nx.trace(Cs[None] + Ct - 2 * nx.sqrtm(M))
squared_dist_m = nx.norm(ms[None] - mt, axis=-1) ** 2
elif len(mt.shape) == 1:
# Return shape (n,)
M = nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12)
B = nx.trace(Cs + Ct[None] - 2 * nx.sqrtm(M))
squared_dist_m = nx.norm(ms - mt[None], axis=-1) ** 2
else:
return W


def bures_wasserstein_distance_batch(ms, mt, Cs, Ct, log=False):
"""
TODO
Maybe try to merge it with bures_wasserstein_distance
"""
ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
nx = get_backend(ms, mt, Cs, Ct)

Cs12 = nx.sqrtm(Cs)
M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12)
B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M))
# Return shape (n,m)
M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12)
B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M))
squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2

squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2
W = nx.sqrt(nx.maximum(squared_dist_m + B, 0))

if log:
Expand Down Expand Up @@ -361,12 +360,12 @@ def empirical_bures_wasserstein_distance(
Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)

if log:
W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log)
W, log = bures_wasserstein_distance(mxs[0], mxt[0], Cs, Ct, log=log)
log["Cs"] = Cs
log["Ct"] = Ct
return W, log
else:
W = bures_wasserstein_distance(mxs, mxt, Cs, Ct)
W = bures_wasserstein_distance(mxs[0], mxt[0], Cs, Ct)
return W


Expand Down
29 changes: 20 additions & 9 deletions test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,35 @@ def test_bures_wasserstein_distance_batch(nx):

Wb = ot.gaussian.bures_wasserstein_distance(m[0, 0], m[1, 0], C[0], C[1], log=False)

Wb2 = ot.gaussian.bures_wasserstein_distance_batch(
Wb2 = ot.gaussian.bures_wasserstein_distance(
m[0, 0][None], m[1, 0][None], C[0][None], C[1][None]
)
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5)
np.testing.assert_equal(Wb2.shape, (1, 1))

Wb2 = ot.gaussian.bures_wasserstein_distance_batch(
m[:, 0], m[1, 0][None], C, C[1][None]
)
Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[1, 0][None], C, C[1][None])
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5)
np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 0]), atol=1e-5)
np.testing.assert_equal(Wb2.shape, (2, 1))

Wb2 = ot.gaussian.bures_wasserstein_distance(
m[0, 0][None], m[1, 0], C[0][None], C[1]
)
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5)
np.testing.assert_equal(Wb2.shape, (1,))

Wb2 = ot.gaussian.bures_wasserstein_distance_batch(m[:, 0], m[:, 0], C, C)
Wb2 = ot.gaussian.bures_wasserstein_distance(
m[0, 0], m[1, 0][None], C[0], C[1][None]
)
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5)
np.testing.assert_equal(Wb2.shape, (1,))

Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[:, 0], C, C)
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[1, 0]), atol=1e-5)
np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 1]), atol=1e-5)
np.testing.assert_allclose(0, nx.to_numpy(Wb2[0, 0]), atol=1e-5)
np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 1]), atol=1e-5)
np.testing.assert_equal(Wb2.shape, (2, 2))


@pytest.mark.parametrize("bias", [True, False])
Expand Down Expand Up @@ -278,9 +291,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx):
m, C, method="fixed_point", log=False
)

loss = nx.mean(
ot.gaussian.bures_wasserstein_distance_batch(mb[None], m, Cb[None], C)
)
loss = nx.mean(ot.gaussian.bures_wasserstein_distance(mb[None], m, Cb[None], C))

n_samples = [1, 5]
for n in n_samples:
Expand All @@ -289,7 +300,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx):
)

loss2 = nx.mean(
ot.gaussian.bures_wasserstein_distance_batch(mb2[None], m, Cb2[None], C)
ot.gaussian.bures_wasserstein_distance(mb2[None], m, Cb2[None], C)
)

np.testing.assert_allclose(mb, mb2, atol=1e-5)
Expand Down

0 comments on commit d4045f1

Please sign in to comment.