Skip to content

Commit

Permalink
Fix a bug in multi-vector with faiss engine
Browse files Browse the repository at this point in the history
  • Loading branch information
heemin32 committed May 22, 2024
1 parent 6efac69 commit f21fdf4
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 35 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696)
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
* Fix a bug in multi-vector with faiss engine [#1715](https://github.com/opensearch-project/k-NN/pull/1715)
### Infrastructure
### Documentation
### Maintenance
Expand Down
110 changes: 77 additions & 33 deletions jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001
From 35ef01f59b8903dfbd4d08ff874b085e851e4228 Mon Sep 17 00:00:00 2001
From: Heemin Kim <[email protected]>
Date: Tue, 30 Jan 2024 14:43:56 -0800
Subject: [PATCH] Add IDGrouper for HNSW
Expand All @@ -7,18 +7,18 @@ Signed-off-by: Heemin Kim <[email protected]>
---
faiss/CMakeLists.txt | 3 +
faiss/Index.h | 8 +-
faiss/IndexHNSW.cpp | 13 ++-
faiss/IndexIDMap.cpp | 29 ++++++
faiss/IndexIDMap.h | 22 +++++
faiss/impl/HNSW.cpp | 10 +-
faiss/impl/IDGrouper.cpp | 51 ++++++++++
faiss/impl/IDGrouper.h | 51 ++++++++++
faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++
faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++
faiss/IndexHNSW.cpp | 13 +-
faiss/IndexIDMap.cpp | 29 +++++
faiss/IndexIDMap.h | 22 ++++
faiss/impl/HNSW.cpp | 6 +
faiss/impl/IDGrouper.cpp | 51 ++++++++
faiss/impl/IDGrouper.h | 51 ++++++++
faiss/impl/ResultHandler.h | 190 +++++++++++++++++++++++++++++
faiss/utils/GroupHeap.h | 182 ++++++++++++++++++++++++++++
tests/CMakeLists.txt | 2 +
tests/test_group_heap.cpp | 98 +++++++++++++++++++
tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++
13 files changed, 838 insertions(+), 7 deletions(-)
tests/test_group_heap.cpp | 98 +++++++++++++++
tests/test_id_grouper.cpp | 241 +++++++++++++++++++++++++++++++++++++
13 files changed, 891 insertions(+), 5 deletions(-)
create mode 100644 faiss/impl/IDGrouper.cpp
create mode 100644 faiss/impl/IDGrouper.h
create mode 100644 faiss/utils/GroupHeap.h
Expand Down Expand Up @@ -54,7 +54,7 @@ index a890a46f..137e68d4 100644
utils/WorkerThread.h
utils/distances.h
diff --git a/faiss/Index.h b/faiss/Index.h
index 4b4b302b..3b673d1e 100644
index 3d1bdb99..a8622858 100644
--- a/faiss/Index.h
+++ b/faiss/Index.h
@@ -38,9 +38,10 @@
Expand Down Expand Up @@ -106,7 +106,7 @@ index 9a67332d..a5e0fea0 100644
if (is_similarity_metric(this->metric_type)) {
// we need to revert the negated distances
diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp
index e093bbda..e24365d5 100644
index dc84052b..3f375e7b 100644
--- a/faiss/IndexIDMap.cpp
+++ b/faiss/IndexIDMap.cpp
@@ -102,6 +102,23 @@ struct ScopedSelChange {
Expand Down Expand Up @@ -198,20 +198,9 @@ index 2d164123..a68887bd 100644
+
} // namespace faiss
diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp
index fb4de678..b6f602a0 100644
index a9fb9daf..33b56638 100644
--- a/faiss/impl/HNSW.cpp
+++ b/faiss/impl/HNSW.cpp
@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const {
level,
nb_neighbors(level));
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
-#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
- reduction(+: tot_reciprocal) reduction(+: n_node)
+#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
+ reduction(+ : tot_reciprocal) reduction(+ : n_node)
for (int i = 0; i < levels.size(); i++) {
if (levels[i] > level) {
n_node++;
@@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
return hres->k;
Expand Down Expand Up @@ -340,19 +329,20 @@ index 00000000..d56113d9
+
+} // namespace faiss
diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h
index 270de8dc..2f7f3e7f 100644
index 270de8dc..3199634f 100644
--- a/faiss/impl/ResultHandler.h
+++ b/faiss/impl/ResultHandler.h
@@ -12,6 +12,8 @@
@@ -12,6 +12,9 @@
#pragma once

#include <faiss/impl/AuxIndexStructures.h>
+#include <faiss/impl/FaissException.h>
+#include <faiss/impl/IDGrouper.h>
+#include <faiss/utils/GroupHeap.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/partitioning.h>

@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
@@ -265,6 +268,193 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
}
};

Expand Down Expand Up @@ -436,6 +426,7 @@ index 270de8dc..2f7f3e7f 100644
+ idx,
+ group_id,
+ &group_id_to_index_in_heap);
+ threshold = heap_dis[0];
+ return true;
+ } else {
+ size_t pos = it_pos->second;
Expand All @@ -452,6 +443,7 @@ index 270de8dc..2f7f3e7f 100644
+ idx,
+ group_id,
+ &group_id_to_index_in_heap);
+ threshold = heap_dis[0];
+ return true;
+ }
+ }
Expand Down Expand Up @@ -734,10 +726,10 @@ index 00000000..3b7078da
+} // namespace faiss
\ No newline at end of file
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index cc0a4f4c..96e19328 100644
index 9017edc5..a8e9d30c 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC
@@ -27,6 +27,8 @@ set(FAISS_TEST_SRC
test_approx_topk.cpp
test_RCQ_cropping.cpp
test_distances_simd.cpp
Expand Down Expand Up @@ -852,10 +844,10 @@ index 00000000..0e8fe7a7
+}
diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp
new file mode 100644
index 00000000..2aed5500
index 00000000..6601795b
--- /dev/null
+++ b/tests/test_id_grouper.cpp
@@ -0,0 +1,189 @@
@@ -0,0 +1,241 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
Expand Down Expand Up @@ -920,6 +912,58 @@ index 00000000..2aed5500
+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1));
+}
+
+TEST(IdGrouper, sanity_test) {
+ int d = 1; // dimension
+ int nb = 10; // database size
+
+ std::mt19937 rng;
+ std::uniform_real_distribution<> distrib;
+
+ float* xb = new float[d * nb];
+
+ for (int i = 0; i < nb; i++) {
+ for (int j = 0; j < d; j++)
+ xb[d * i + j] = distrib(rng);
+ xb[d * i] += i / 1000.;
+ }
+
+ uint64_t bitmap[1] = {};
+ faiss::IDGrouperBitmap id_grouper(1, bitmap);
+ for (int i = 0; i < nb; i++) {
+ id_grouper.set_group(i);
+ }
+
+ int k = 5;
+ int m = 8;
+ faiss::Index* index =
+ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2);
+ index->add(nb, xb); // add vectors to the index
+
+ // search
+ auto pSearchParameters = new faiss::SearchParametersHNSW();
+
+ idx_t* expectedI = new idx_t[k];
+ float* expectedD = new float[k];
+ index->search(1, xb, k, expectedD, expectedI, pSearchParameters);
+
+ idx_t* I = new idx_t[k];
+ float* D = new float[k];
+ pSearchParameters->grp = &id_grouper;
+ index->search(1, xb, k, D, I, pSearchParameters);
+
+ // compare
+ for (int j = 0; j < k; j++) {
+ ASSERT_EQ(expectedI[j], I[j]);
+ ASSERT_EQ(expectedD[j], D[j]);
+ }
+
+ delete[] expectedI;
+ delete[] expectedD;
+ delete[] I;
+ delete[] D;
+ delete[] xb;
+}
+
+TEST(IdGrouper, bitmap_with_hnsw) {
+ int d = 1; // dimension
+ int nb = 10; // database size
Expand Down
8 changes: 6 additions & 2 deletions src/test/java/org/opensearch/knn/index/NestedSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() {
refreshIndex(INDEX_NAME);
forceMergeKnnIndex(INDEX_NAME);

Float[] queryVector = { 1f, 1f };
Float[] queryVector = { 14f, 14f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
String entity = EntityUtils.toString(response.getEntity());
assertEquals(2, parseHits(entity));
assertEquals(2, parseTotalSearchHits(entity));
assertEquals("14", parseIds(entity).get(0));
assertEquals("13", parseIds(entity).get(1));
}

@SneakyThrows
Expand All @@ -97,11 +99,13 @@ public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() {
refreshIndex(INDEX_NAME);
forceMergeKnnIndex(INDEX_NAME);

Float[] queryVector = { 1f, 1f };
Float[] queryVector = { 14f, 14f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
String entity = EntityUtils.toString(response.getEntity());
assertEquals(2, parseHits(entity));
assertEquals(2, parseTotalSearchHits(entity));
assertEquals("14", parseIds(entity).get(0));
assertEquals("13", parseIds(entity).get(1));
}

/**
Expand Down

0 comments on commit f21fdf4

Please sign in to comment.