Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update threshold value after new result is added #1715

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696)
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
* Update threshold value after new result is added [#1715](https://github.com/opensearch-project/k-NN/pull/1715)
### Infrastructure
### Documentation
### Maintenance
Expand Down
110 changes: 77 additions & 33 deletions jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001
From 35ef01f59b8903dfbd4d08ff874b085e851e4228 Mon Sep 17 00:00:00 2001
From: Heemin Kim <[email protected]>
Date: Tue, 30 Jan 2024 14:43:56 -0800
Subject: [PATCH] Add IDGrouper for HNSW
Expand All @@ -7,18 +7,18 @@ Signed-off-by: Heemin Kim <[email protected]>
---
faiss/CMakeLists.txt | 3 +
faiss/Index.h | 8 +-
faiss/IndexHNSW.cpp | 13 ++-
faiss/IndexIDMap.cpp | 29 ++++++
faiss/IndexIDMap.h | 22 +++++
faiss/impl/HNSW.cpp | 10 +-
faiss/impl/IDGrouper.cpp | 51 ++++++++++
faiss/impl/IDGrouper.h | 51 ++++++++++
faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++
faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++
faiss/IndexHNSW.cpp | 13 +-
faiss/IndexIDMap.cpp | 29 +++++
faiss/IndexIDMap.h | 22 ++++
faiss/impl/HNSW.cpp | 6 +
faiss/impl/IDGrouper.cpp | 51 ++++++++
faiss/impl/IDGrouper.h | 51 ++++++++
faiss/impl/ResultHandler.h | 190 +++++++++++++++++++++++++++++
faiss/utils/GroupHeap.h | 182 ++++++++++++++++++++++++++++
tests/CMakeLists.txt | 2 +
tests/test_group_heap.cpp | 98 +++++++++++++++++++
tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++
13 files changed, 838 insertions(+), 7 deletions(-)
tests/test_group_heap.cpp | 98 +++++++++++++++
tests/test_id_grouper.cpp | 241 +++++++++++++++++++++++++++++++++++++
13 files changed, 891 insertions(+), 5 deletions(-)
create mode 100644 faiss/impl/IDGrouper.cpp
create mode 100644 faiss/impl/IDGrouper.h
create mode 100644 faiss/utils/GroupHeap.h
Expand Down Expand Up @@ -54,7 +54,7 @@ index a890a46f..137e68d4 100644
utils/WorkerThread.h
utils/distances.h
diff --git a/faiss/Index.h b/faiss/Index.h
index 4b4b302b..3b673d1e 100644
index 3d1bdb99..a8622858 100644
--- a/faiss/Index.h
+++ b/faiss/Index.h
@@ -38,9 +38,10 @@
Expand Down Expand Up @@ -106,7 +106,7 @@ index 9a67332d..a5e0fea0 100644
if (is_similarity_metric(this->metric_type)) {
// we need to revert the negated distances
diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp
index e093bbda..e24365d5 100644
index dc84052b..3f375e7b 100644
--- a/faiss/IndexIDMap.cpp
+++ b/faiss/IndexIDMap.cpp
@@ -102,6 +102,23 @@ struct ScopedSelChange {
Expand Down Expand Up @@ -198,20 +198,9 @@ index 2d164123..a68887bd 100644
+
} // namespace faiss
diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp
index fb4de678..b6f602a0 100644
index a9fb9daf..33b56638 100644
--- a/faiss/impl/HNSW.cpp
+++ b/faiss/impl/HNSW.cpp
@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const {
level,
nb_neighbors(level));
size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
-#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
- reduction(+: tot_reciprocal) reduction(+: n_node)
+#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
+ reduction(+ : tot_reciprocal) reduction(+ : n_node)
for (int i = 0; i < levels.size(); i++) {
if (levels[i] > level) {
n_node++;
@@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
return hres->k;
Expand Down Expand Up @@ -340,19 +329,20 @@ index 00000000..d56113d9
+
+} // namespace faiss
diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h
index 270de8dc..2f7f3e7f 100644
index 270de8dc..3199634f 100644
--- a/faiss/impl/ResultHandler.h
+++ b/faiss/impl/ResultHandler.h
@@ -12,6 +12,8 @@
@@ -12,6 +12,9 @@
#pragma once

#include <faiss/impl/AuxIndexStructures.h>
+#include <faiss/impl/FaissException.h>
+#include <faiss/impl/IDGrouper.h>
+#include <faiss/utils/GroupHeap.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/partitioning.h>

@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
@@ -265,6 +268,193 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
}
};

Expand Down Expand Up @@ -436,6 +426,7 @@ index 270de8dc..2f7f3e7f 100644
+ idx,
+ group_id,
+ &group_id_to_index_in_heap);
+ threshold = heap_dis[0];
+ return true;
+ } else {
+ size_t pos = it_pos->second;
Expand All @@ -452,6 +443,7 @@ index 270de8dc..2f7f3e7f 100644
+ idx,
+ group_id,
+ &group_id_to_index_in_heap);
+ threshold = heap_dis[0];
+ return true;
+ }
+ }
Expand Down Expand Up @@ -734,10 +726,10 @@ index 00000000..3b7078da
+} // namespace faiss
\ No newline at end of file
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index cc0a4f4c..96e19328 100644
index 9017edc5..a8e9d30c 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC
@@ -27,6 +27,8 @@ set(FAISS_TEST_SRC
test_approx_topk.cpp
test_RCQ_cropping.cpp
test_distances_simd.cpp
Expand Down Expand Up @@ -852,10 +844,10 @@ index 00000000..0e8fe7a7
+}
diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp
new file mode 100644
index 00000000..2aed5500
index 00000000..6601795b
--- /dev/null
+++ b/tests/test_id_grouper.cpp
@@ -0,0 +1,189 @@
@@ -0,0 +1,241 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
Expand Down Expand Up @@ -920,6 +912,58 @@ index 00000000..2aed5500
+ ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1));
+}
+
+TEST(IdGrouper, sanity_test) {
+ int d = 1; // dimension
+ int nb = 10; // database size
+
+ std::mt19937 rng;
+ std::uniform_real_distribution<> distrib;
+
+ float* xb = new float[d * nb];
+
+ for (int i = 0; i < nb; i++) {
+ for (int j = 0; j < d; j++)
+ xb[d * i + j] = distrib(rng);
+ xb[d * i] += i / 1000.;
+ }
+
+ uint64_t bitmap[1] = {};
+ faiss::IDGrouperBitmap id_grouper(1, bitmap);
+ for (int i = 0; i < nb; i++) {
+ id_grouper.set_group(i);
+ }
+
+ int k = 5;
+ int m = 8;
+ faiss::Index* index =
+ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2);
+ index->add(nb, xb); // add vectors to the index
+
+ // search
+ auto pSearchParameters = new faiss::SearchParametersHNSW();
+
+ idx_t* expectedI = new idx_t[k];
+ float* expectedD = new float[k];
+ index->search(1, xb, k, expectedD, expectedI, pSearchParameters);
+
+ idx_t* I = new idx_t[k];
+ float* D = new float[k];
+ pSearchParameters->grp = &id_grouper;
+ index->search(1, xb, k, D, I, pSearchParameters);
+
+ // compare
+ for (int j = 0; j < k; j++) {
+ ASSERT_EQ(expectedI[j], I[j]);
+ ASSERT_EQ(expectedD[j], D[j]);
+ }
+
+ delete[] expectedI;
+ delete[] expectedD;
+ delete[] I;
+ delete[] D;
+ delete[] xb;
+}
+
+TEST(IdGrouper, bitmap_with_hnsw) {
+ int d = 1; // dimension
+ int nb = 10; // database size
Expand Down
8 changes: 6 additions & 2 deletions src/test/java/org/opensearch/knn/index/NestedSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() {
refreshIndex(INDEX_NAME);
forceMergeKnnIndex(INDEX_NAME);

Float[] queryVector = { 1f, 1f };
Float[] queryVector = { 14f, 14f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
String entity = EntityUtils.toString(response.getEntity());
assertEquals(2, parseHits(entity));
assertEquals(2, parseTotalSearchHits(entity));
assertEquals("14", parseIds(entity).get(0));
assertEquals("13", parseIds(entity).get(1));
}

@SneakyThrows
Expand All @@ -97,11 +99,13 @@ public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() {
refreshIndex(INDEX_NAME);
forceMergeKnnIndex(INDEX_NAME);

Float[] queryVector = { 1f, 1f };
Float[] queryVector = { 14f, 14f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
String entity = EntityUtils.toString(response.getEntity());
assertEquals(2, parseHits(entity));
assertEquals(2, parseTotalSearchHits(entity));
assertEquals("14", parseIds(entity).get(0));
assertEquals("13", parseIds(entity).get(1));
}

/**
Expand Down
Loading