def faiss_to_nanopq(pq_faiss):
    """Convert a trained faiss PQ/OPQ index to its nanopq counterpart.

    A ``faiss.IndexPQ`` is converted to :class:`nanopq.PQ`. A
    ``faiss.IndexPreTransform`` that wraps a single linear transform
    (e.g. ``faiss.OPQMatrix``) followed by a ``faiss.IndexPQ`` is converted
    to :class:`nanopq.OPQ`.

    To use this function, the faiss module needs to be installed.

    Args:
        pq_faiss (Union[faiss.IndexPQ, faiss.IndexPreTransform]): An input PQ
            or OPQ instance. It must have been trained.

    Returns:
        tuple:
            * Union[nanopq.PQ, nanopq.OPQ]: A converted PQ or OPQ instance,
              with the same codewords (and, for OPQ, the same rotation
              matrix) as the input.
            * np.ndarray: Stored PQ codes in the input index, with the
              shape=(N, M). This will be empty if codes are not stored.

    Raises:
        AssertionError: If ``pq_faiss`` is not one of the supported index
            types, is not trained, or is an ``IndexPreTransform`` whose
            content is not exactly one linear transform followed by an
            ``IndexPQ``.

    """
    assert isinstance(
        pq_faiss, (faiss.IndexPQ, faiss.IndexPreTransform)
    ), "Error. pq_faiss must be IndexPQ or IndexPreTransform"
    assert pq_faiss.is_trained, "Error. pq_faiss must have been trained"

    # `index_pq` always refers to the IndexPQ that owns the codewords and the
    # stored codes; `rotation` stays None for a plain (non-rotated) PQ.
    index_pq = pq_faiss
    rotation = None

    if isinstance(pq_faiss, faiss.IndexPreTransform):
        # Only a single pre-transform can be represented by nanopq.OPQ.
        # Silently dropping further transforms would yield a wrong quantizer.
        assert (
            pq_faiss.chain.size() == 1
        ), "Error. pq_faiss must contain exactly one pre-transform"

        # The chain stores base-class pointers; downcast to reach A/d_in/d_out.
        transform = faiss.downcast_VectorTransform(pq_faiss.chain.at(0))
        assert isinstance(
            transform, faiss.LinearTransform
        ), "Error. The pre-transform of pq_faiss must be a LinearTransform"
        # NOTE(review): a bias term on the transform (if any) is not carried
        # over; OPQMatrix is expected to have none -- confirm for other types.

        index_pq = faiss.downcast_index(pq_faiss.index)
        assert isinstance(
            index_pq, faiss.IndexPQ
        ), "Error. The index wrapped by pq_faiss must be IndexPQ"

        # A is flattened as (d_out, d_in); the transposed (d_in, d_out) matrix
        # is what is stored as the nanopq rotation R.
        rotation = (
            faiss.vector_to_array(transform.A)
            .reshape(transform.d_out, transform.d_in)
            .transpose(1, 0)
        )

    M = index_pq.pq.M
    Ks = int(2 ** index_pq.pq.nbits)
    Ds = index_pq.pq.d // M

    # Extract codewords from the faiss ProductQuantizer, reshape them to M*Ks*Ds
    codewords = faiss.vector_to_array(index_pq.pq.centroids).reshape(M, Ks, Ds)

    if rotation is None:
        converted = PQ(M=M, Ks=Ks)
        quantizer = converted
    else:
        converted = OPQ(M=M, Ks=Ks)
        converted.R = rotation
        # OPQ keeps its codebook in the wrapped PQ instance.
        quantizer = converted.pq

    quantizer.Ds = Ds
    quantizer.codewords = codewords

    # NOTE(review): reshaping to (-1, M) presumably relies on one byte per
    # sub-code (nbits == 8) -- confirm before using other bit widths.
    codes = faiss.vector_to_array(index_pq.codes).reshape(-1, M)
    return converted, codes
pq_nanopq, Cb_faiss = nanopq.faiss_to_nanopq(pq_faiss=pq_faiss) + self.assertIsInstance(pq_nanopq, nanopq.PQ) + self.assertEqual(pq_nanopq.codewords.shape, (M, Ks, int(D / M))) + + # Encoded results should be same + Cb_nanopq = pq_nanopq.encode(vecs=Xb) + self.assertTrue(np.array_equal(Cb_nanopq, Cb_faiss)) + + # Search result should be same + topk = 100 + _, ids1 = pq_faiss.search(x=Xq, k=topk) + ids2 = np.array( + [ + np.argsort(pq_nanopq.dtable(query=xq).adist(codes=Cb_nanopq))[:topk] + for xq in Xq + ] + ) + self.assertTrue(np.array_equal(ids1, ids2)) + + def test_faiss_to_nanopq_opq(self): + D, M, Ks = 32, 4, 256 + Nt, Nb, Nq = 2000, 10000, 100 + nbits = int(np.log2(Ks)) + assert nbits == 8 + Xt = np.random.rand(Nt, D).astype(np.float32) + Xb = np.random.rand(Nb, D).astype(np.float32) + Xq = np.random.rand(Nq, D).astype(np.float32) + + pq_faiss = faiss.IndexPQ(D, M, nbits) + opq_matrix = faiss.OPQMatrix(D, M=M) + pq_faiss = faiss.IndexPreTransform(opq_matrix, pq_faiss) + pq_faiss.train(x=Xt) + pq_faiss.add(x=Xb) + + pq_nanopq, Cb_faiss = nanopq.faiss_to_nanopq(pq_faiss=pq_faiss) + self.assertIsInstance(pq_nanopq, nanopq.OPQ) self.assertEqual(pq_nanopq.codewords.shape, (M, Ks, int(D / M))) # Encoded results should be same