-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix a bug in multi-vector with faiss engine
- Loading branch information
Showing 3 changed files with 84 additions and 35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001 | ||
From 35ef01f59b8903dfbd4d08ff874b085e851e4228 Mon Sep 17 00:00:00 2001 | ||
From: Heemin Kim <[email protected]> | ||
Date: Tue, 30 Jan 2024 14:43:56 -0800 | ||
Subject: [PATCH] Add IDGrouper for HNSW | ||
|
@@ -7,18 +7,18 @@ Signed-off-by: Heemin Kim <[email protected]> | |
--- | ||
faiss/CMakeLists.txt | 3 + | ||
faiss/Index.h | 8 +- | ||
faiss/IndexHNSW.cpp | 13 ++- | ||
faiss/IndexIDMap.cpp | 29 ++++++ | ||
faiss/IndexIDMap.h | 22 +++++ | ||
faiss/impl/HNSW.cpp | 10 +- | ||
faiss/impl/IDGrouper.cpp | 51 ++++++++++ | ||
faiss/impl/IDGrouper.h | 51 ++++++++++ | ||
faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++ | ||
faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++ | ||
faiss/IndexHNSW.cpp | 13 +- | ||
faiss/IndexIDMap.cpp | 29 +++++ | ||
faiss/IndexIDMap.h | 22 ++++ | ||
faiss/impl/HNSW.cpp | 6 + | ||
faiss/impl/IDGrouper.cpp | 51 ++++++++ | ||
faiss/impl/IDGrouper.h | 51 ++++++++ | ||
faiss/impl/ResultHandler.h | 190 +++++++++++++++++++++++++++++ | ||
faiss/utils/GroupHeap.h | 182 ++++++++++++++++++++++++++++ | ||
tests/CMakeLists.txt | 2 + | ||
tests/test_group_heap.cpp | 98 +++++++++++++++++++ | ||
tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++ | ||
13 files changed, 838 insertions(+), 7 deletions(-) | ||
tests/test_group_heap.cpp | 98 +++++++++++++++ | ||
tests/test_id_grouper.cpp | 241 +++++++++++++++++++++++++++++++++++++ | ||
13 files changed, 891 insertions(+), 5 deletions(-) | ||
create mode 100644 faiss/impl/IDGrouper.cpp | ||
create mode 100644 faiss/impl/IDGrouper.h | ||
create mode 100644 faiss/utils/GroupHeap.h | ||
|
@@ -54,7 +54,7 @@ index a890a46f..137e68d4 100644 | |
utils/WorkerThread.h | ||
utils/distances.h | ||
diff --git a/faiss/Index.h b/faiss/Index.h | ||
index 4b4b302b..3b673d1e 100644 | ||
index 3d1bdb99..a8622858 100644 | ||
--- a/faiss/Index.h | ||
+++ b/faiss/Index.h | ||
@@ -38,9 +38,10 @@ | ||
|
@@ -106,7 +106,7 @@ index 9a67332d..a5e0fea0 100644 | |
if (is_similarity_metric(this->metric_type)) { | ||
// we need to revert the negated distances | ||
diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp | ||
index e093bbda..e24365d5 100644 | ||
index dc84052b..3f375e7b 100644 | ||
--- a/faiss/IndexIDMap.cpp | ||
+++ b/faiss/IndexIDMap.cpp | ||
@@ -102,6 +102,23 @@ struct ScopedSelChange { | ||
|
@@ -198,20 +198,9 @@ index 2d164123..a68887bd 100644 | |
+ | ||
} // namespace faiss | ||
diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp | ||
index fb4de678..b6f602a0 100644 | ||
index a9fb9daf..33b56638 100644 | ||
--- a/faiss/impl/HNSW.cpp | ||
+++ b/faiss/impl/HNSW.cpp | ||
@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const { | ||
level, | ||
nb_neighbors(level)); | ||
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; | ||
-#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \ | ||
- reduction(+: tot_reciprocal) reduction(+: n_node) | ||
+#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \ | ||
+ reduction(+ : tot_reciprocal) reduction(+ : n_node) | ||
for (int i = 0; i < levels.size(); i++) { | ||
if (levels[i] > level) { | ||
n_node++; | ||
@@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) { | ||
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) { | ||
return hres->k; | ||
|
@@ -340,19 +329,20 @@ index 00000000..d56113d9 | |
+ | ||
+} // namespace faiss | ||
diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h | ||
index 270de8dc..2f7f3e7f 100644 | ||
index 270de8dc..3199634f 100644 | ||
--- a/faiss/impl/ResultHandler.h | ||
+++ b/faiss/impl/ResultHandler.h | ||
@@ -12,6 +12,8 @@ | ||
@@ -12,6 +12,9 @@ | ||
#pragma once | ||
|
||
#include <faiss/impl/AuxIndexStructures.h> | ||
+#include <faiss/impl/FaissException.h> | ||
+#include <faiss/impl/IDGrouper.h> | ||
+#include <faiss/utils/GroupHeap.h> | ||
#include <faiss/utils/Heap.h> | ||
#include <faiss/utils/partitioning.h> | ||
|
||
@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler<C> { | ||
@@ -265,6 +268,193 @@ struct HeapBlockResultHandler : BlockResultHandler<C> { | ||
} | ||
}; | ||
|
||
|
@@ -436,6 +426,7 @@ index 270de8dc..2f7f3e7f 100644 | |
+ idx, | ||
+ group_id, | ||
+ &group_id_to_index_in_heap); | ||
+ threshold = heap_dis[0]; | ||
+ return true; | ||
+ } else { | ||
+ size_t pos = it_pos->second; | ||
|
@@ -452,6 +443,7 @@ index 270de8dc..2f7f3e7f 100644 | |
+ idx, | ||
+ group_id, | ||
+ &group_id_to_index_in_heap); | ||
+ threshold = heap_dis[0]; | ||
+ return true; | ||
+ } | ||
+ } | ||
|
@@ -734,10 +726,10 @@ index 00000000..3b7078da | |
+} // namespace faiss | ||
\ No newline at end of file | ||
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt | ||
index cc0a4f4c..96e19328 100644 | ||
index 9017edc5..a8e9d30c 100644 | ||
--- a/tests/CMakeLists.txt | ||
+++ b/tests/CMakeLists.txt | ||
@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC | ||
@@ -27,6 +27,8 @@ set(FAISS_TEST_SRC | ||
test_approx_topk.cpp | ||
test_RCQ_cropping.cpp | ||
test_distances_simd.cpp | ||
|
@@ -852,10 +844,10 @@ index 00000000..0e8fe7a7 | |
+} | ||
diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp | ||
new file mode 100644 | ||
index 00000000..2aed5500 | ||
index 00000000..6601795b | ||
--- /dev/null | ||
+++ b/tests/test_id_grouper.cpp | ||
@@ -0,0 +1,189 @@ | ||
@@ -0,0 +1,241 @@ | ||
+/** | ||
+ * Copyright (c) Facebook, Inc. and its affiliates. | ||
+ * | ||
|
@@ -920,6 +912,58 @@ index 00000000..2aed5500 | |
+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1)); | ||
+} | ||
+ | ||
+TEST(IdGrouper, sanity_test) { | ||
+ int d = 1; // dimension | ||
+ int nb = 10; // database size | ||
+ | ||
+ std::mt19937 rng; | ||
+ std::uniform_real_distribution<> distrib; | ||
+ | ||
+ float* xb = new float[d * nb]; | ||
+ | ||
+ for (int i = 0; i < nb; i++) { | ||
+ for (int j = 0; j < d; j++) | ||
+ xb[d * i + j] = distrib(rng); | ||
+ xb[d * i] += i / 1000.; | ||
+ } | ||
+ | ||
+ uint64_t bitmap[1] = {}; | ||
+ faiss::IDGrouperBitmap id_grouper(1, bitmap); | ||
+ for (int i = 0; i < nb; i++) { | ||
+ id_grouper.set_group(i); | ||
+ } | ||
+ | ||
+ int k = 5; | ||
+ int m = 8; | ||
+ faiss::Index* index = | ||
+ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); | ||
+ index->add(nb, xb); // add vectors to the index | ||
+ | ||
+ // search | ||
+ auto pSearchParameters = new faiss::SearchParametersHNSW(); | ||
+ | ||
+ idx_t* expectedI = new idx_t[k]; | ||
+ float* expectedD = new float[k]; | ||
+ index->search(1, xb, k, expectedD, expectedI, pSearchParameters); | ||
+ | ||
+ idx_t* I = new idx_t[k]; | ||
+ float* D = new float[k]; | ||
+ pSearchParameters->grp = &id_grouper; | ||
+ index->search(1, xb, k, D, I, pSearchParameters); | ||
+ | ||
+ // compare | ||
+ for (int j = 0; j < k; j++) { | ||
+ ASSERT_EQ(expectedI[j], I[j]); | ||
+ ASSERT_EQ(expectedD[j], D[j]); | ||
+ } | ||
+ | ||
+ delete[] expectedI; | ||
+ delete[] expectedD; | ||
+ delete[] I; | ||
+ delete[] D; | ||
+ delete[] xb; | ||
+} | ||
+ | ||
+TEST(IdGrouper, bitmap_with_hnsw) { | ||
+ int d = 1; // dimension | ||
+ int nb = 10; // database size | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters