diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a7e1d6acd..a03d64f37f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696) ### Bug Fixes * Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692) +* Fix a bug in multi-vector with faiss engine [#1715](https://github.com/opensearch-project/k-NN/pull/1715) ### Infrastructure ### Documentation ### Maintenance diff --git a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch b/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch index a22e281305..227630c638 100644 --- a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch +++ b/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch @@ -1,4 +1,4 @@ -From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001 +From 35ef01f59b8903dfbd4d08ff874b085e851e4228 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Tue, 30 Jan 2024 14:43:56 -0800 Subject: [PATCH] Add IDGrouper for HNSW @@ -7,18 +7,18 @@ Signed-off-by: Heemin Kim --- faiss/CMakeLists.txt | 3 + faiss/Index.h | 8 +- - faiss/IndexHNSW.cpp | 13 ++- - faiss/IndexIDMap.cpp | 29 ++++++ - faiss/IndexIDMap.h | 22 +++++ - faiss/impl/HNSW.cpp | 10 +- - faiss/impl/IDGrouper.cpp | 51 ++++++++++ - faiss/impl/IDGrouper.h | 51 ++++++++++ - faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++ - faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++ + faiss/IndexHNSW.cpp | 13 +- + faiss/IndexIDMap.cpp | 29 +++++ + faiss/IndexIDMap.h | 22 ++++ + faiss/impl/HNSW.cpp | 6 + + faiss/impl/IDGrouper.cpp | 51 ++++++++ + faiss/impl/IDGrouper.h | 51 ++++++++ + faiss/impl/ResultHandler.h | 190 +++++++++++++++++++++++++++++ + faiss/utils/GroupHeap.h | 182 ++++++++++++++++++++++++++++ tests/CMakeLists.txt | 2 + - tests/test_group_heap.cpp | 98 +++++++++++++++++++ - tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++ - 13 files changed, 838 insertions(+), 7 deletions(-) + tests/test_group_heap.cpp | 98 +++++++++++++++ + tests/test_id_grouper.cpp | 241 +++++++++++++++++++++++++++++++++++++ + 13 files changed, 891 insertions(+), 5 deletions(-) create mode 100644 faiss/impl/IDGrouper.cpp create mode 100644 faiss/impl/IDGrouper.h create mode 100644 faiss/utils/GroupHeap.h @@ -54,7 +54,7 @@ index a890a46f..137e68d4 100644 utils/WorkerThread.h utils/distances.h diff --git a/faiss/Index.h b/faiss/Index.h -index 4b4b302b..3b673d1e 100644 +index 3d1bdb99..a8622858 100644 --- a/faiss/Index.h +++ b/faiss/Index.h @@ -38,9 +38,10 @@ @@ -106,7 +106,7 @@ index 9a67332d..a5e0fea0 100644 if (is_similarity_metric(this->metric_type)) { // we need to revert the negated distances diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp -index e093bbda..e24365d5 100644 +index dc84052b..3f375e7b 100644 --- a/faiss/IndexIDMap.cpp +++ b/faiss/IndexIDMap.cpp @@ -102,6 +102,23 @@ struct ScopedSelChange { @@ -198,20 +198,9 @@ index 2d164123..a68887bd 100644 + } // namespace faiss diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp -index fb4de678..b6f602a0 100644 +index a9fb9daf..33b56638 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp -@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const { - level, - nb_neighbors(level)); - size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; --#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \ -- reduction(+: tot_reciprocal) reduction(+: n_node) -+#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \ -+ reduction(+ : tot_reciprocal) reduction(+ : n_node) - for (int i = 0; i < levels.size(); i++) { - if (levels[i] > level) { - n_node++; @@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler& res) { if (auto hres = dynamic_cast(&res)) { return hres->k; @@ -340,19 +329,20 @@ index 00000000..d56113d9 + +} // namespace faiss diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h -index 270de8dc..2f7f3e7f 100644 +index 270de8dc..3199634f 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h -@@ -12,6 +12,8 @@ +@@ -12,6 +12,9 @@ #pragma once #include ++#include +#include +#include #include #include -@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler { +@@ -265,6 +268,193 @@ struct HeapBlockResultHandler : BlockResultHandler { } }; @@ -436,6 +426,7 @@ index 270de8dc..2f7f3e7f 100644 + idx, + group_id, + &group_id_to_index_in_heap); ++ threshold = heap_dis[0]; + return true; + } else { + size_t pos = it_pos->second; @@ -452,6 +443,7 @@ index 270de8dc..2f7f3e7f 100644 + idx, + group_id, + &group_id_to_index_in_heap); ++ threshold = heap_dis[0]; + return true; + } + } @@ -734,10 +726,10 @@ index 00000000..3b7078da +} // namespace faiss \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt -index cc0a4f4c..96e19328 100644 +index 9017edc5..a8e9d30c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt -@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC +@@ -27,6 +27,8 @@ set(FAISS_TEST_SRC test_approx_topk.cpp test_RCQ_cropping.cpp test_distances_simd.cpp @@ -852,10 +844,10 @@ index 00000000..0e8fe7a7 +} diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp new file mode 100644 -index 00000000..2aed5500 +index 00000000..6601795b --- /dev/null +++ b/tests/test_id_grouper.cpp -@@ -0,0 +1,189 @@ +@@ -0,0 +1,241 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * @@ -920,6 +912,58 @@ index 00000000..2aed5500 + ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1)); +} + ++TEST(IdGrouper, sanity_test) { ++ int d = 1; // dimension ++ int nb = 10; // database size ++ ++ std::mt19937 rng; ++ std::uniform_real_distribution<> distrib; ++ ++ float* xb = new float[d * nb]; ++ ++ for (int i = 0; i < nb; i++) { ++ for (int j = 0; j < d; j++) ++ xb[d * i + j] = distrib(rng); ++ xb[d * i] += i / 1000.; ++ } ++ ++ uint64_t bitmap[1] = {}; ++ faiss::IDGrouperBitmap id_grouper(1, bitmap); ++ for (int i = 0; i < nb; i++) { ++ id_grouper.set_group(i); ++ } ++ ++ int k = 5; ++ int m = 8; ++ faiss::Index* index = ++ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); ++ index->add(nb, xb); // add vectors to the index ++ ++ // search ++ auto pSearchParameters = new faiss::SearchParametersHNSW(); ++ ++ idx_t* expectedI = new idx_t[k]; ++ float* expectedD = new float[k]; ++ index->search(1, xb, k, expectedD, expectedI, pSearchParameters); ++ ++ idx_t* I = new idx_t[k]; ++ float* D = new float[k]; ++ pSearchParameters->grp = &id_grouper; ++ index->search(1, xb, k, D, I, pSearchParameters); ++ ++ // compare ++ for (int j = 0; j < k; j++) { ++ ASSERT_EQ(expectedI[j], I[j]); ++ ASSERT_EQ(expectedD[j], D[j]); ++ } ++ ++ delete[] expectedI; ++ delete[] expectedD; ++ delete[] I; ++ delete[] D; ++ delete[] xb; ++} ++ +TEST(IdGrouper, bitmap_with_hnsw) { + int d = 1; // dimension + int nb = 10; // database size diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java index 20b87708f1..73d00141f8 100644 --- a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -75,11 +75,13 @@ public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() { refreshIndex(INDEX_NAME); forceMergeKnnIndex(INDEX_NAME); - Float[] queryVector = { 1f, 1f }; + Float[] queryVector = { 14f, 14f }; Response response = queryNestedField(INDEX_NAME, 2, queryVector); String entity = EntityUtils.toString(response.getEntity()); assertEquals(2, parseHits(entity)); assertEquals(2, parseTotalSearchHits(entity)); + assertEquals("14", parseIds(entity).get(0)); + assertEquals("13", parseIds(entity).get(1)); } @SneakyThrows @@ -97,11 +99,13 @@ public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() { refreshIndex(INDEX_NAME); forceMergeKnnIndex(INDEX_NAME); - Float[] queryVector = { 1f, 1f }; + Float[] queryVector = { 14f, 14f }; Response response = queryNestedField(INDEX_NAME, 2, queryVector); String entity = EntityUtils.toString(response.getEntity()); assertEquals(2, parseHits(entity)); assertEquals(2, parseTotalSearchHits(entity)); + assertEquals("14", parseIds(entity).get(0)); + assertEquals("13", parseIds(entity).get(1)); } /**