From aa95c820c773ab709941aba168bc82728d73dc0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 27 Nov 2024 12:09:43 +0100 Subject: [PATCH] better dist and tests --- ot/utils.py | 43 +++++++++++++++++++++++++++++++------ test/test_utils.py | 53 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 81 insertions(+), 15 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index a2d328484..0c07be627 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -17,25 +17,25 @@ from inspect import signature from .backend import get_backend, Backend, NumpyBackend, JaxBackend -__time_tic_toc = time.time() +__time_tic_toc = time.perf_counter() def tic(): r"""Python implementation of Matlab tic() function""" global __time_tic_toc - __time_tic_toc = time.time() + __time_tic_toc = time.perf_counter() def toc(message="Elapsed time : {} s"): r"""Python implementation of Matlab toc() function""" - t = time.time() + t = time.perf_counter() print(message.format(t - __time_tic_toc)) return t - __time_tic_toc def toq(): r"""Python implementation of Julia toc() function""" - t = time.time() + t = time.perf_counter() return t - __time_tic_toc @@ -291,11 +291,12 @@ def euclidean_distances(X, Y, squared=False): return c -def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None): +def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None, nx=None): r"""Compute distance between samples in :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends for the following metrics: + 'sqeuclidean', 'euclidean', 'cityblock', 'minkowski', 'cosine', 'correlation'. Parameters ---------- @@ -315,7 +316,8 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None): p-norm for the Minkowski and the Weighted Minkowski metrics. Default value is 2. w : array-like, rank 1 Weights for the weighted metrics. - + nx : Backend, optional + Backend to perform computations on. If omitted, the backend defaults to that of `x1`. Returns ------- @@ -324,12 +326,39 @@ def dist(x1, x2=None, metric="sqeuclidean", p=2, w=None): distance matrix computed with given metric """ + if nx is None: + nx = get_backend(x1, x2) if x2 is None: x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) elif metric == "euclidean": return euclidean_distances(x1, x2, squared=False) + elif metric == "cityblock": + return nx.sum(nx.abs(x1[:, None, :] - x2[None, :, :]), axis=2) + elif metric == "minkowski": + if w is None: + return nx.power( + nx.sum(nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p), axis=2), + 1 / p, + ) + return nx.power( + nx.sum( + w[None, None, :] * nx.power(nx.abs(x1[:, None, :] - x2[None, :, :]), p), + axis=2, + ), + 1 / p, + ) + elif metric == "cosine": + nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1)) + nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2)) + return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :]) + elif metric == "correlation": + x1 = x1 - nx.mean(x1, axis=1)[:, None] + x2 = x2 - nx.mean(x2, axis=1)[:, None] + nx1 = nx.sqrt(nx.einsum("ij,ij->i", x1, x1)) + nx2 = nx.sqrt(nx.einsum("ij,ij->i", x2, x2)) + return 1.0 - (nx.dot(x1, nx.transpose(x2)) / nx1[:, None] / nx2[None, :]) else: if not get_backend(x1, x2).__name__ == "numpy": raise NotImplementedError() diff --git a/test/test_utils.py b/test/test_utils.py index d50f29915..3607510de 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -8,6 +8,31 @@ import numpy as np import sys import pytest +import scipy + +lst_metrics = [ + "euclidean", + "sqeuclidean", + "cityblock", + "cosine", + "minkowski", + "correlation", +] + +lst_all_metrics = lst_metrics + [ + "braycurtis", + "canberra", + "chebyshev", + "dice", + "hamming", + "jaccard", + "matching", + "rogerstanimoto", + "russellrao", + "sokalmichener", + "sokalsneath", + "yule", +] def get_LazyTensor(nx): @@ -185,7 +210,7 @@ def test_dist(): assert D4[0, 1] == D4[1, 0] - # dist shoul return squared euclidean + # dist should return squared euclidean np.testing.assert_allclose(D, D2, atol=1e-14) np.testing.assert_allclose(D, D3, atol=1e-14) @@ -230,20 +255,32 @@ def test_dist(): ot.dist(x, x, metric="wminkowski") -def test_dist_backends(nx): +@pytest.mark.parametrize("metric", lst_metrics) +def test_dist_backends(nx, metric): n = 100 rng = np.random.RandomState(0) x = rng.randn(n, 2) x1 = nx.from_numpy(x) - lst_metric = ["euclidean", "sqeuclidean"] + D = ot.dist(x, x, metric=metric) + D1 = ot.dist(x1, x1, metric=metric) - for metric in lst_metric: - D = ot.dist(x, x, metric=metric) - D1 = ot.dist(x1, x1, metric=metric) + # low atol because jax forces float32 + np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5) - # low atol because jax forces float32 - np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5) + +@pytest.mark.parametrize("metric", lst_all_metrics) +def test_dist_vs_cdist(metric): + n = 10 + + rng = np.random.RandomState(0) + x = rng.randn(n, 2) + y = rng.randn(n + 1, 2) + + D = ot.dist(x, y, metric=metric) + D2 = scipy.spatial.distance.cdist(x, y, metric=metric) + + np.testing.assert_allclose(D, D2, atol=1e-15) def test_dist0():