Patch summary (free text ahead of the first "diff --git" header):
  * _compare_groups: reshape the comparison matrix to (n1, n2); presumably
    this keeps the result 2-D when there are zero clusters -- confirm.
  * assign_clusterings: return the best permutation as an integer index
    array (position = row, value = column) instead of a list of
    (row, column) tuples.
  * calc_rms_nrmsd / calc_mean_pearson: read the assignment via enumerate().
  * Add unit tests for assign_clusterings in cluster/tests/compare_test.py.

diff --git a/src/seismicrna/cluster/compare.py b/src/seismicrna/cluster/compare.py
index 2d9a93cd..0e5641f5 100644
--- a/src/seismicrna/cluster/compare.py
+++ b/src/seismicrna/cluster/compare.py
@@ -124,7 +124,7 @@ def _compare_groups(func: Callable, mus1: np.ndarray, mus2: np.ndarray):
     _, n2 = mus2.shape
     return np.array([[func(mus1[:, cluster1], mus2[:, cluster2])
                       for cluster2 in range(n2)]
-                     for cluster1 in range(n1)])
+                     for cluster1 in range(n1)]).reshape((n1, n2))
 
 
 def calc_rmsd_groups(mus1: np.ndarray, mus2: np.ndarray):
@@ -157,14 +157,15 @@ def assign_clusterings(mus1: np.ndarray, mus2: np.ndarray):
     # the assignment problem in O(n³) time, and this naive approach runs
     # in O(n!) time, the latter is simpler and still sufficiently fast
     # when n is no more than about 6, which is almost always true.
-    best_assignment = list()
+    ns = np.arange(n)
+    best_assignment = ns
     min_cost = None
-    rows = list(range(n))
-    for columns in permutations(rows):
-        cost = np.sum(costs[rows, columns])
+    for cols in permutations(ns):
+        assignment = np.array(cols, dtype=int)
+        cost = np.sum(costs[ns, assignment])
         if min_cost is None or cost < min_cost:
-            min_cost = float(cost)
-            best_assignment = list(zip(rows, columns))
+            min_cost = cost
+            best_assignment = assignment
     return best_assignment
 
 
@@ -172,14 +173,16 @@ def calc_rms_nrmsd(run1: EmClustering, run2: EmClustering):
     """ Compute the root-mean-square NRMSD between the clusters. """
     costs = np.square(calc_nrmsd_groups(run1.p_mut, run2.p_mut))
     assignment = assign_clusterings(run1.p_mut, run2.p_mut)
-    return float(np.sqrt(np.mean([costs[i, j] for i, j in assignment])))
+    return float(np.sqrt(np.mean([costs[row, col]
+                                  for row, col in enumerate(assignment)])))
 
 
 def calc_mean_pearson(run1: EmClustering, run2: EmClustering):
     """ Compute the mean Pearson correlation between the clusters. """
     correlations = calc_pearson_groups(run1.p_mut, run2.p_mut)
     assignment = assign_clusterings(run1.p_mut, run2.p_mut)
-    return float(np.mean([correlations[i, j] for i, j in assignment]))
+    return float(np.mean([correlations[row, col]
+                          for row, col in enumerate(assignment)]))
 
 ########################################################################
 #                                                                      #
diff --git a/src/seismicrna/cluster/tests/compare_test.py b/src/seismicrna/cluster/tests/compare_test.py
new file mode 100644
index 00000000..999cf909
--- /dev/null
+++ b/src/seismicrna/cluster/tests/compare_test.py
@@ -0,0 +1,58 @@
+import unittest as ut
+import warnings
+from itertools import permutations
+
+import numpy as np
+
+from seismicrna.cluster.compare import assign_clusterings
+from seismicrna.core.array import calc_inverse
+
+rng = np.random.default_rng()
+
+
+class TestAssignClusterings(ut.TestCase):
+
+    def test_0_clusters(self):
+        for n in range(5):
+            x = np.empty((n, 0))
+            y = np.empty((n, 0))
+            self.assertTrue(np.array_equal(assign_clusterings(x, y),
+                                           np.empty(0, dtype=int)))
+
+    def test_0_positions(self):
+        for n in range(5):
+            x = np.empty((0, n))
+            y = np.empty((0, n))
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", RuntimeWarning)
+                with np.errstate(invalid="ignore"):
+                    self.assertTrue(np.array_equal(assign_clusterings(x, y),
+                                                   np.arange(n)))
+
+    def test_1_cluster(self):
+        for n in range(1, 10):
+            x = rng.random((n, 1))
+            y = rng.random((n, 1))
+            self.assertTrue(np.array_equal(assign_clusterings(x, y),
+                                           np.array([0])))
+
+    def test_more_clusters(self):
+        for ncls in range(2, 5):
+            for npos in range(2, 8):
+                with self.subTest(ncls=ncls, npos=npos):
+                    x = rng.random((npos, ncls))
+                    for permutation in permutations(range(ncls)):
+                        cols = np.array(permutation)
+                        y = x[:, cols]
+                        self.assertTrue(np.array_equal(
+                            assign_clusterings(x, y),
+                            calc_inverse(cols)
+                        ))
+                        self.assertTrue(np.array_equal(
+                            assign_clusterings(y, x),
+                            cols
+                        ))
+
+
+if __name__ == "__main__":
+    ut.main()