From d4045f1cac80fdd8f2fbf4ca443dbebbc3ee3311 Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 19 Nov 2024 20:17:06 +0100 Subject: [PATCH] batchable BW distance --- ot/gaussian.py | 57 +++++++++++++++++++++---------------------- test/test_gaussian.py | 29 +++++++++++++++------- 2 files changed, 48 insertions(+), 38 deletions(-) diff --git a/ot/gaussian.py b/ot/gaussian.py index be8abd543..583e3c1f3 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -207,7 +207,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): empirical distributions source :math:`\mu_s` and target :math:`\mu_t`, discussed in remark 2.31 :ref:`[1] `. - The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}` + The Bures Wasserstein distance between source and target distribution :math:`\mathcal{W}_2` .. math:: \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} @@ -219,13 +219,13 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): Parameters ---------- - ms : array-like (d,) + ms : array-like (d,) or (n,d) mean of the source distribution - mt : array-like (d,) + mt : array-like (d,) or (m,d) mean of the target distribution - Cs : array-like (d,d) + Cs : array-like (d,d) or (n,d,d) covariance of the source distribution - Ct : array-like (d,d) + Ct : array-like (d,d) or (m,d,d) covariance of the target distribution log : bool, optional record log if True @@ -233,7 +233,9 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): Returns ------- - W : float + W : float if ms and mt of shape (d,), array-like (n,) if ms of shape (n,d), + mt of shape (d,), array-like (m,) if ms of shape (d,) and mt of shape (m,d), + array-like (n,m) if ms of shape (n,d) and mt of shape (m,d) Bures Wasserstein distance log : dict log dictionary return only if log==True in parameters @@ -250,30 +252,27 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): nx = get_backend(ms, mt, Cs, 
Ct) Cs12 = nx.sqrtm(Cs) - B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) - W = nx.sqrt(nx.maximum(nx.norm(ms - mt) ** 2 + B, 0)) - if log: - log = {} - log["Cs12"] = Cs12 - return W, log + if len(ms.shape) == 1 and len(mt.shape) == 1: + # Return float + squared_dist_m = nx.norm(ms - mt) ** 2 + B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) + elif len(ms.shape) == 1: + # Return shape (m,) + M = nx.einsum("ij, mjk, kl -> mil", Cs12, Ct, Cs12) + B = nx.trace(Cs[None] + Ct - 2 * nx.sqrtm(M)) + squared_dist_m = nx.norm(ms[None] - mt, axis=-1) ** 2 + elif len(mt.shape) == 1: + # Return shape (n,) + M = nx.einsum("nij, jk, nkl -> nil", Cs12, Ct, Cs12) + B = nx.trace(Cs + Ct[None] - 2 * nx.sqrtm(M)) + squared_dist_m = nx.norm(ms - mt[None], axis=-1) ** 2 else: - return W - - -def bures_wasserstein_distance_batch(ms, mt, Cs, Ct, log=False): - """ - TODO - Maybe try to merge it with bures_wasserstein_distance - """ - ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct) - nx = get_backend(ms, mt, Cs, Ct) - - Cs12 = nx.sqrtm(Cs) - M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12) - B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M)) + # Return shape (n,m) + M = nx.einsum("nij, mjk, nkl -> nmil", Cs12, Ct, Cs12) + B = nx.trace(Cs[:, None] + Ct[None] - 2 * nx.sqrtm(M)) + squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2 - squared_dist_m = nx.norm(ms[:, None] - mt[None], axis=-1) ** 2 W = nx.sqrt(nx.maximum(squared_dist_m + B, 0)) if log: @@ -361,12 +360,12 @@ def empirical_bures_wasserstein_distance( Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) if log: - W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log) + W, log = bures_wasserstein_distance(mxs[0], mxt[0], Cs, Ct, log=log) log["Cs"] = Cs log["Ct"] = Ct return W, log else: - W = bures_wasserstein_distance(mxs, mxt, Cs, Ct) + W = bures_wasserstein_distance(mxs[0], mxt[0], Cs, Ct) return W diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 
e353d748a..95843374e 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -122,22 +122,35 @@ def test_bures_wasserstein_distance_batch(nx): Wb = ot.gaussian.bures_wasserstein_distance(m[0, 0], m[1, 0], C[0], C[1], log=False) - Wb2 = ot.gaussian.bures_wasserstein_distance_batch( + Wb2 = ot.gaussian.bures_wasserstein_distance( m[0, 0][None], m[1, 0][None], C[0][None], C[1][None] ) np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1, 1)) - Wb2 = ot.gaussian.bures_wasserstein_distance_batch( - m[:, 0], m[1, 0][None], C, C[1][None] - ) + Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[1, 0][None], C, C[1][None]) np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 0]), atol=1e-5) np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (2, 1)) + + Wb2 = ot.gaussian.bures_wasserstein_distance( + m[0, 0][None], m[1, 0], C[0][None], C[1] + ) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1,)) - Wb2 = ot.gaussian.bures_wasserstein_distance_batch(m[:, 0], m[:, 0], C, C) + Wb2 = ot.gaussian.bures_wasserstein_distance( + m[0, 0], m[1, 0][None], C[0], C[1][None] + ) + np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (1,)) + + Wb2 = ot.gaussian.bures_wasserstein_distance(m[:, 0], m[:, 0], C, C) np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[1, 0]), atol=1e-5) np.testing.assert_allclose(nx.to_numpy(Wb), nx.to_numpy(Wb2[0, 1]), atol=1e-5) np.testing.assert_allclose(0, nx.to_numpy(Wb2[0, 0]), atol=1e-5) np.testing.assert_allclose(0, nx.to_numpy(Wb2[1, 1]), atol=1e-5) + np.testing.assert_equal(Wb2.shape, (2, 2)) @pytest.mark.parametrize("bias", [True, False]) @@ -278,9 +291,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx): m, C, method="fixed_point", log=False ) - loss = nx.mean( - 
ot.gaussian.bures_wasserstein_distance_batch(mb[None], m, Cb[None], C) - ) + loss = nx.mean(ot.gaussian.bures_wasserstein_distance(mb[None], m, Cb[None], C)) n_samples = [1, 5] for n in n_samples: @@ -289,7 +300,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx): ) loss2 = nx.mean( - ot.gaussian.bures_wasserstein_distance_batch(mb2[None], m, Cb2[None], C) + ot.gaussian.bures_wasserstein_distance(mb2[None], m, Cb2[None], C) ) np.testing.assert_allclose(mb, mb2, atol=1e-5)