From 03750f51419e9a2c6b5c533c2e6fbc837bdb646a Mon Sep 17 00:00:00 2001 From: Amir Sadoughi Date: Wed, 24 Apr 2024 14:11:02 -0700 Subject: [PATCH] Fix IndexBinary.assign Python method Summary: Fixes #3343 Reviewed By: kuarora, junjieqi Differential Revision: D56526842 fbshipit-source-id: b7c4377495db4e68283cf4ce2b7c8fae008cd404 --- faiss/python/class_wrappers.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_index_binary.py | 3 +++ 2 files changed, 37 insertions(+) diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py index 4a6808d286..4af2345009 100644 --- a/faiss/python/class_wrappers.py +++ b/faiss/python/class_wrappers.py @@ -956,10 +956,44 @@ def replacement_remove_ids(self, x): sel = IDSelectorBatch(x.size, swig_ptr(x)) return self.remove_ids_c(sel) + def replacement_assign(self, x, k, labels=None): + """Find the k nearest neighbors of the set of vectors x in the index. + This is the same as the `search` method, but discards the distances. + + Parameters + ---------- + x : array_like + Query vectors, shape (n, d) where d is appropriate for the index. + `dtype` must be uint8. + k : int + Number of nearest neighbors. + labels : array_like, optional + Labels array to store the results. + + Returns + ------- + labels: array_like + Labels of the nearest neighbors, shape (n, k). + When not enough results are found, the label is set to -1 + """ + n, d = x.shape + x = _check_dtype_uint8(x) + assert d == self.code_size + assert k > 0 + + if labels is None: + labels = np.empty((n, k), dtype=np.int64) + else: + assert labels.shape == (n, k) + + self.assign_c(n, swig_ptr(x), swig_ptr(labels), k) + return labels + replace_method(the_class, 'add', replacement_add) replace_method(the_class, 'add_with_ids', replacement_add_with_ids) replace_method(the_class, 'train', replacement_train) replace_method(the_class, 'search', replacement_search) + replace_method(the_class, 'assign', replacement_assign) replace_method(the_class, 'range_search', replacement_range_search) replace_method(the_class, 'reconstruct', replacement_reconstruct) replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n) diff --git a/tests/test_index_binary.py b/tests/test_index_binary.py index b505e0ba1c..3acf622fd4 100644 --- a/tests/test_index_binary.py +++ b/tests/test_index_binary.py @@ -100,6 +100,9 @@ def test_flat(self): index.add(self.xb) D, I = index.search(self.xq, 3) + I2 = index.assign(x=self.xq, k=3, labels=None) + assert np.all(I == I2) + for i in range(nq): for j, dj in zip(I[i], D[i]): ref_dis = binary_dis(self.xq[i], self.xb[j])