From 58d5622516a38f1cb1054141cb850a1d3c3cf2f4 Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Sat, 1 Jun 2024 10:25:02 -0700 Subject: [PATCH] Updating feature branch with latest changes on main (#1728) * Fix flaky test in Faiss JNI range search (#1705) Signed-off-by: Junqiu Lei * Support script score when doc value is disabled and fix misusing DISI (#1696) * Revert "Revert 'Support script score when doc value is disabled' (#1662)" This reverts commit bd2f403cb1ff439e6f1d88efce71464682472544. Signed-off-by: panguixin * fix misusing doc value Signed-off-by: panguixin * add changelog Signed-off-by: panguixin --------- Signed-off-by: panguixin * --- (#1712) updated-dependencies: - dependency-name: requests dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Update threshold value after new result is added (#1715) Signed-off-by: Heemin Kim * Use the Lucene Distance Calculation Function in Script Scoring for doing exact search (#1699) * Use the Lucene Distance Calculation Function in Script Scoring for doing exact search Signed-off-by: Ryan Bogan * Add Changelog entry Signed-off-by: Ryan Bogan * Fix failing test Signed-off-by: Ryan Bogan * fix test Signed-off-by: Ryan Bogan * Fix test bug and remove unnecessary validation Signed-off-by: Ryan Bogan * Remove cosineSimilOptimized Signed-off-by: Ryan Bogan * Revert "Remove cosineSimilOptimized" This reverts commit f872d8389683186c9ff64f6a65fd77f170f4a47d. 
Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan * Add validation for pq m parameter before training starts (#1713) * Add validation for pq code count before training starts Signed-off-by: Ryan Bogan * Add integration test Signed-off-by: Ryan Bogan * Add unit tests Signed-off-by: Ryan Bogan * Clean up code Signed-off-by: Ryan Bogan * Remove unnecessary lines Signed-off-by: Ryan Bogan * Add changelog entry Signed-off-by: Ryan Bogan * Change framework to add validation with data Signed-off-by: Ryan Bogan * Remove unused error message Signed-off-by: Ryan Bogan * Add unit tests Signed-off-by: Ryan Bogan * Change space type check name for readability Signed-off-by: Ryan Bogan * Add javadocs Signed-off-by: Ryan Bogan * Modify validation error wording and add json structure to tests Signed-off-by: Ryan Bogan * Change TrainingDataSpec to VectorSpaceInfo Signed-off-by: Ryan Bogan * Add unit tests Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan * Updating the BWC test config after 2.14 release (#1724) Signed-off-by: Navneet Verma --------- Signed-off-by: Junqiu Lei Signed-off-by: panguixin Signed-off-by: dependabot[bot] Signed-off-by: Heemin Kim Signed-off-by: Ryan Bogan Signed-off-by: Navneet Verma Co-authored-by: Junqiu Lei Co-authored-by: panguixin Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Heemin Kim Co-authored-by: Ryan Bogan Co-authored-by: Navneet Verma --- ...backwards_compatibility_tests_workflow.yml | 4 +- CHANGELOG.md | 4 + benchmarks/perf-tool/requirements.txt | 2 +- ...Custom-patch-to-support-multi-vector.patch | 110 +++++++++---- jni/tests/faiss_wrapper_test.cpp | 33 ++-- .../org/opensearch/knn/index/KNNMethod.java | 41 ++++- .../knn/index/KNNMethodContext.java | 11 ++ .../knn/index/KNNVectorDVLeafFieldData.java | 28 +++- .../knn/index/KNNVectorScriptDocValues.java | 117 ++++++++++++-- .../opensearch/knn/index/MethodComponent.java | 38 +++++ 
.../org/opensearch/knn/index/Parameter.java | 149 +++++++++++++++++- .../knn/index/util/AbstractKNNLibrary.java | 7 + .../org/opensearch/knn/index/util/Faiss.java | 9 +- .../opensearch/knn/index/util/KNNEngine.java | 6 + .../opensearch/knn/index/util/KNNLibrary.java | 11 ++ .../knn/plugin/script/KNNScoringUtil.java | 32 +--- .../transport/TrainingModelRequest.java | 7 + .../knn/training/VectorSpaceInfo.java | 26 +++ .../org/opensearch/knn/index/FaissIT.java | 136 +++++++++++++++- .../opensearch/knn/index/KNNMethodTests.java | 55 ++++++- .../index/KNNVectorScriptDocValuesTests.java | 61 +++++-- .../opensearch/knn/index/NestedSearchIT.java | 8 +- .../opensearch/knn/index/ParameterTests.java | 109 +++++++++++++ .../knn/index/VectorDataTypeTests.java | 4 +- .../plugin/script/KNNScoringSpaceTests.java | 4 +- .../plugin/script/KNNScoringUtilTests.java | 2 +- .../knn/plugin/script/KNNScriptScoringIT.java | 53 +++++-- .../knn/plugin/script/PainlessScriptIT.java | 4 +- .../LibraryInitializedSupplierTests.java | 6 + .../org/opensearch/knn/KNNRestTestCase.java | 12 +- 30 files changed, 945 insertions(+), 144 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java diff --git a/.github/workflows/backwards_compatibility_tests_workflow.yml b/.github/workflows/backwards_compatibility_tests_workflow.yml index cfa1bb128..503b23326 100644 --- a/.github/workflows/backwards_compatibility_tests_workflow.yml +++ b/.github/workflows/backwards_compatibility_tests_workflow.yml @@ -15,7 +15,7 @@ jobs: matrix: java: [ 11, 17 ] os: [ubuntu-latest] - bwc_version : [ "2.0.1", "2.1.0", "2.2.1", "2.3.0", "2.4.1", "2.5.0", "2.6.0", "2.7.0", "2.8.0", "2.9.0", "2.10.0", "2.11.0", "2.12.0", "2.13.0", "2.14.0-SNAPSHOT"] + bwc_version : [ "2.0.1", "2.1.0", "2.2.1", "2.3.0", "2.4.1", "2.5.0", "2.6.0", "2.7.0", "2.8.0", "2.9.0", "2.10.0", "2.11.0", "2.12.0", "2.13.0", "2.14.0", "2.15.0-SNAPSHOT"] opensearch_version : [ "3.0.0-SNAPSHOT" ] exclude: - os: windows-latest 
@@ -94,7 +94,7 @@ jobs: matrix: java: [ 11, 17 ] os: [ubuntu-latest] - bwc_version: [ "2.14.0-SNAPSHOT" ] + bwc_version: [ "2.15.0-SNAPSHOT" ] opensearch_version: [ "3.0.0-SNAPSHOT" ] name: k-NN Rolling-Upgrade BWC Tests diff --git a/CHANGELOG.md b/CHANGELOG.md index ee01f2ac7..0a0f62ebd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,11 +14,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.14...2.x) ### Features +* Use the Lucene Distance Calculation Function in Script Scoring for doing exact search [#1699](https://github.com/opensearch-project/k-NN/pull/1699) ### Enhancements * Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688) * Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684) +* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696) +* Add validation for pq m parameter before training starts [#1713](https://github.com/opensearch-project/k-NN/pull/1713) ### Bug Fixes * Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692) +* Update threshold value after new result is added [#1715](https://github.com/opensearch-project/k-NN/pull/1715) ### Infrastructure ### Documentation ### Maintenance diff --git a/benchmarks/perf-tool/requirements.txt b/benchmarks/perf-tool/requirements.txt index 46cec00ed..10338bace 100644 --- a/benchmarks/perf-tool/requirements.txt +++ b/benchmarks/perf-tool/requirements.txt @@ -26,7 +26,7 @@ psutil==5.8.0 # via -r requirements.in pyyaml==5.4.1 # via -r requirements.in -requests==2.31.0 +requests==2.32.0 # via -r requirements.in urllib3==1.26.18 # via diff --git a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch b/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch index 
a22e28130..227630c63 100644 --- a/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch +++ b/jni/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch @@ -1,4 +1,4 @@ -From 0d1385959ddecabb2825957e48ff28ff0e8abf53 Mon Sep 17 00:00:00 2001 +From 35ef01f59b8903dfbd4d08ff874b085e851e4228 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Tue, 30 Jan 2024 14:43:56 -0800 Subject: [PATCH] Add IDGrouper for HNSW @@ -7,18 +7,18 @@ Signed-off-by: Heemin Kim --- faiss/CMakeLists.txt | 3 + faiss/Index.h | 8 +- - faiss/IndexHNSW.cpp | 13 ++- - faiss/IndexIDMap.cpp | 29 ++++++ - faiss/IndexIDMap.h | 22 +++++ - faiss/impl/HNSW.cpp | 10 +- - faiss/impl/IDGrouper.cpp | 51 ++++++++++ - faiss/impl/IDGrouper.h | 51 ++++++++++ - faiss/impl/ResultHandler.h | 187 ++++++++++++++++++++++++++++++++++++ - faiss/utils/GroupHeap.h | 182 +++++++++++++++++++++++++++++++++++ + faiss/IndexHNSW.cpp | 13 +- + faiss/IndexIDMap.cpp | 29 +++++ + faiss/IndexIDMap.h | 22 ++++ + faiss/impl/HNSW.cpp | 6 + + faiss/impl/IDGrouper.cpp | 51 ++++++++ + faiss/impl/IDGrouper.h | 51 ++++++++ + faiss/impl/ResultHandler.h | 190 +++++++++++++++++++++++++++++ + faiss/utils/GroupHeap.h | 182 ++++++++++++++++++++++++++++ tests/CMakeLists.txt | 2 + - tests/test_group_heap.cpp | 98 +++++++++++++++++++ - tests/test_id_grouper.cpp | 189 +++++++++++++++++++++++++++++++++++++ - 13 files changed, 838 insertions(+), 7 deletions(-) + tests/test_group_heap.cpp | 98 +++++++++++++++ + tests/test_id_grouper.cpp | 241 +++++++++++++++++++++++++++++++++++++ + 13 files changed, 891 insertions(+), 5 deletions(-) create mode 100644 faiss/impl/IDGrouper.cpp create mode 100644 faiss/impl/IDGrouper.h create mode 100644 faiss/utils/GroupHeap.h @@ -54,7 +54,7 @@ index a890a46f..137e68d4 100644 utils/WorkerThread.h utils/distances.h diff --git a/faiss/Index.h b/faiss/Index.h -index 4b4b302b..3b673d1e 100644 +index 3d1bdb99..a8622858 100644 --- a/faiss/Index.h +++ b/faiss/Index.h @@ -38,9 +38,10 @@ @@ -106,7 +106,7 @@ index 
9a67332d..a5e0fea0 100644 if (is_similarity_metric(this->metric_type)) { // we need to revert the negated distances diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp -index e093bbda..e24365d5 100644 +index dc84052b..3f375e7b 100644 --- a/faiss/IndexIDMap.cpp +++ b/faiss/IndexIDMap.cpp @@ -102,6 +102,23 @@ struct ScopedSelChange { @@ -198,20 +198,9 @@ index 2d164123..a68887bd 100644 + } // namespace faiss diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp -index fb4de678..b6f602a0 100644 +index a9fb9daf..33b56638 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp -@@ -110,8 +110,8 @@ void HNSW::print_neighbor_stats(int level) const { - level, - nb_neighbors(level)); - size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0; --#pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \ -- reduction(+: tot_reciprocal) reduction(+: n_node) -+#pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \ -+ reduction(+ : tot_reciprocal) reduction(+ : n_node) - for (int i = 0; i < levels.size(); i++) { - if (levels[i] > level) { - n_node++; @@ -804,6 +804,12 @@ int extract_k_from_ResultHandler(ResultHandler& res) { if (auto hres = dynamic_cast(&res)) { return hres->k; @@ -340,19 +329,20 @@ index 00000000..d56113d9 + +} // namespace faiss diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h -index 270de8dc..2f7f3e7f 100644 +index 270de8dc..3199634f 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h -@@ -12,6 +12,8 @@ +@@ -12,6 +12,9 @@ #pragma once #include ++#include +#include +#include #include #include -@@ -265,6 +267,191 @@ struct HeapBlockResultHandler : BlockResultHandler { +@@ -265,6 +268,193 @@ struct HeapBlockResultHandler : BlockResultHandler { } }; @@ -436,6 +426,7 @@ index 270de8dc..2f7f3e7f 100644 + idx, + group_id, + &group_id_to_index_in_heap); ++ threshold = heap_dis[0]; + return true; + } else { + size_t pos = it_pos->second; @@ -452,6 +443,7 @@ index 
270de8dc..2f7f3e7f 100644 + idx, + group_id, + &group_id_to_index_in_heap); ++ threshold = heap_dis[0]; + return true; + } + } @@ -734,10 +726,10 @@ index 00000000..3b7078da +} // namespace faiss \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt -index cc0a4f4c..96e19328 100644 +index 9017edc5..a8e9d30c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt -@@ -26,6 +26,8 @@ set(FAISS_TEST_SRC +@@ -27,6 +27,8 @@ set(FAISS_TEST_SRC test_approx_topk.cpp test_RCQ_cropping.cpp test_distances_simd.cpp @@ -852,10 +844,10 @@ index 00000000..0e8fe7a7 +} diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp new file mode 100644 -index 00000000..2aed5500 +index 00000000..6601795b --- /dev/null +++ b/tests/test_id_grouper.cpp -@@ -0,0 +1,189 @@ +@@ -0,0 +1,241 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * @@ -920,6 +912,58 @@ index 00000000..2aed5500 + ASSERT_EQ(bitmap.NO_MORE_DOCS, bitmap.get_group(group_ids[3] + 1)); +} + ++TEST(IdGrouper, sanity_test) { ++ int d = 1; // dimension ++ int nb = 10; // database size ++ ++ std::mt19937 rng; ++ std::uniform_real_distribution<> distrib; ++ ++ float* xb = new float[d * nb]; ++ ++ for (int i = 0; i < nb; i++) { ++ for (int j = 0; j < d; j++) ++ xb[d * i + j] = distrib(rng); ++ xb[d * i] += i / 1000.; ++ } ++ ++ uint64_t bitmap[1] = {}; ++ faiss::IDGrouperBitmap id_grouper(1, bitmap); ++ for (int i = 0; i < nb; i++) { ++ id_grouper.set_group(i); ++ } ++ ++ int k = 5; ++ int m = 8; ++ faiss::Index* index = ++ new faiss::IndexHNSWFlat(d, m, faiss::MetricType::METRIC_L2); ++ index->add(nb, xb); // add vectors to the index ++ ++ // search ++ auto pSearchParameters = new faiss::SearchParametersHNSW(); ++ ++ idx_t* expectedI = new idx_t[k]; ++ float* expectedD = new float[k]; ++ index->search(1, xb, k, expectedD, expectedI, pSearchParameters); ++ ++ idx_t* I = new idx_t[k]; ++ float* D = new float[k]; ++ pSearchParameters->grp = &id_grouper; ++ index->search(1, xb, 
k, D, I, pSearchParameters); ++ ++ // compare ++ for (int j = 0; j < k; j++) { ++ ASSERT_EQ(expectedI[j], I[j]); ++ ASSERT_EQ(expectedD[j], D[j]); ++ } ++ ++ delete[] expectedI; ++ delete[] expectedD; ++ delete[] I; ++ delete[] D; ++ delete[] xb; ++} ++ +TEST(IdGrouper, bitmap_with_hnsw) { + int d = 1; // dimension + int nb = 10; // database size diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index e9316dcc2..4cd3b319e 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -25,6 +25,9 @@ using ::testing::Return; float randomDataMin = -500.0; float randomDataMax = 500.0; +float rangeSearchRandomDataMin = -50; +float rangeSearchRandomDataMax = 50; +float rangeSearchRadius = 20000; TEST(FaissCreateIndexTest, BasicAssertions) { // Define the data @@ -621,13 +624,12 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { faiss::idx_t numIds = 200; int dim = 2; std::vector ids = test_util::Range(numIds); - std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + std::vector vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax); faiss::MetricType metricType = faiss::METRIC_L2; std::string method = "HNSW32,Flat"; // Define query data - float radius = 100000.0; int numQueries = 100; std::vector> queries; @@ -635,7 +637,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { std::vector query; query.reserve(dim); for (int j = 0; j < dim; j++) { - query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } queries.push_back(query); } @@ -659,7 +661,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow, nullptr))); + reinterpret_cast(&query), rangeSearchRadius, 
maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -677,13 +679,12 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ faiss::idx_t numIds = 200; int dim = 2; std::vector ids = test_util::Range(numIds); - std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + std::vector vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax); faiss::MetricType metricType = faiss::METRIC_L2; std::string method = "HNSW32,Flat"; // Define query data - float radius = 100000.0; int numQueries = 100; std::vector> queries; @@ -691,7 +692,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ std::vector query; query.reserve(dim); for (int j = 0; j < dim; j++) { - query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } queries.push_back(query); } @@ -715,7 +716,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow, nullptr))); + reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -734,13 +735,12 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { faiss::idx_t numIds = 200; int dim = 2; std::vector ids = test_util::Range(numIds); - std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + std::vector vectors = test_util::RandomVectors(dim, numIds, rangeSearchRandomDataMin, rangeSearchRandomDataMax); faiss::MetricType metricType = faiss::METRIC_L2; std::string method = "HNSW32,Flat"; // Define query data - float radius = 100000.0; int numQueries = 100; std::vector> queries; @@ -748,7 
+748,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { std::vector query; query.reserve(dim); for (int j = 0; j < dim; j++) { - query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } queries.push_back(query); } @@ -767,7 +767,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { std::vector bitmap(num_bits,0); std::vector filterIds; - for (int64_t i = 154; i < 163; i++) { + for (int64_t i = 1; i < 50; i++) { filterIds.push_back(i); test_util::setBitSet(i, bitmap.data(), bitmap.size()); } @@ -782,7 +782,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearchWithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow, + reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, reinterpret_cast(&bitmap), 0, nullptr))); // assert result size is not 0 @@ -814,7 +814,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { } ids.push_back(i); for (int j = 0; j < dim; j++) { - vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } } @@ -822,7 +822,6 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { std::string method = "HNSW32,Flat"; // Define query data - float radius = 100000.0; int numQueries = 1; std::vector> queries; @@ -830,7 +829,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { std::vector query; query.reserve(dim); for (int j = 0; j < dim; j++) { - query.push_back(test_util::RandomFloat(-500.0, 500.0)); + query.push_back(test_util::RandomFloat(rangeSearchRandomDataMin, rangeSearchRandomDataMax)); } queries.push_back(query); } @@ -858,7 +857,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, 
BasicAssertions) { knn_jni::faiss_wrapper::RangeSearchWithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), radius, maxResultWindow, nullptr, 0, + reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, nullptr, 0, reinterpret_cast(&parentIds)))); // assert result size is not 0 diff --git a/src/main/java/org/opensearch/knn/index/KNNMethod.java b/src/main/java/org/opensearch/knn/index/KNNMethod.java index 2d3672d87..7abd2ce39 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethod.java @@ -15,6 +15,7 @@ import lombok.Getter; import org.opensearch.common.ValidationException; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.ArrayList; import java.util.Arrays; @@ -41,7 +42,7 @@ public class KNNMethod { * @param space to be checked * @return true if the space is supported; false otherwise */ - public boolean containsSpace(SpaceType space) { + public boolean isSpaceTypeSupported(SpaceType space) { return spaces.contains(space); } @@ -53,7 +54,7 @@ public boolean containsSpace(SpaceType space) { */ public ValidationException validate(KNNMethodContext knnMethodContext) { List errorMessages = new ArrayList<>(); - if (!containsSpace(knnMethodContext.getSpaceType())) { + if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) { errorMessages.add( String.format( "\"%s\" configuration does not support space type: " + "\"%s\".", @@ -77,6 +78,42 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { return validationException; } + /** + * Validate that the configured KNNMethodContext is valid for this method, using additional data not present in the method context + * + * @param knnMethodContext to be validated + * @param vectorSpaceInfo additional data not present in the method context + * @return ValidationException produced by validation errors; null if no 
validations errors. + */ + public ValidationException validateWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { + List errorMessages = new ArrayList<>(); + if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) { + errorMessages.add( + String.format( + "\"%s\" configuration does not support space type: " + "\"%s\".", + this.methodComponent.getName(), + knnMethodContext.getSpaceType().getValue() + ) + ); + } + + ValidationException methodValidation = methodComponent.validateWithData( + knnMethodContext.getMethodComponentContext(), + vectorSpaceInfo + ); + if (methodValidation != null) { + errorMessages.addAll(methodValidation.validationErrors()); + } + + if (errorMessages.isEmpty()) { + return null; + } + + ValidationException validationException = new ValidationException(); + validationException.addValidationErrors(errorMessages); + return validationException; + } + /** * returns whether training is required or not * diff --git a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java index d4df713c2..ce48b06be 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java @@ -30,6 +30,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; +import org.opensearch.knn.training.VectorSpaceInfo; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; @@ -86,6 +87,16 @@ public ValidationException validate() { return knnEngine.validateMethod(this); } + /** + * This method uses the knnEngine to validate that the method is compatible with the engine, using additional data not present in the method context + * + * @param vectorSpaceInfo additional data not present in the method context + * @return ValidationException produced by 
validation errors; null if no validations errors. + */ + public ValidationException validateWithData(VectorSpaceInfo vectorSpaceInfo) { + return knnEngine.validateMethodWithData(this, vectorSpaceInfo); + } + /** * This method returns whether training is requires or not from knnEngine * diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index f4caa4f20..85f037c0f 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -5,9 +5,10 @@ package org.opensearch.knn.index; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; @@ -39,10 +40,29 @@ public long ramBytesUsed() { @Override public ScriptDocValues getScriptValues() { try { - BinaryDocValues values = DocValues.getBinary(reader, fieldName); - return new KNNVectorScriptDocValues(values, fieldName, vectorDataType); + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName); + if (fieldInfo == null) { + return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); + } + + DocIdSetIterator values; + if (fieldInfo.hasVectorValues()) { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + values = reader.getFloatVectorValues(fieldName); + break; + case BYTE: + values = reader.getByteVectorValues(fieldName); + break; + default: + throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding()); + } + } else { + values = DocValues.getBinary(reader, fieldName); + } + return KNNVectorScriptDocValues.create(values, fieldName, 
vectorDataType); } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); + throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e); } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 349988c93..55ff65516 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -6,30 +6,40 @@ package org.opensearch.knn.index; import java.io.IOException; +import java.util.Objects; +import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import java.io.IOException; - -@RequiredArgsConstructor -public final class KNNVectorScriptDocValues extends ScriptDocValues { +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public abstract class KNNVectorScriptDocValues extends ScriptDocValues { - private final BinaryDocValues binaryDocValues; + private final DocIdSetIterator vectorValues; private final String fieldName; @Getter private final VectorDataType vectorDataType; private boolean docExists = false; + private int lastDocID = -1; @Override public void setNextDocId(int docId) throws IOException { - if (binaryDocValues.advanceExact(docId)) { - docExists = true; - return; + if (docId < lastDocID) { + throw new IllegalArgumentException("docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + docId); + } + + lastDocID = docId; + + int curDocID = vectorValues.docID(); + if (lastDocID > curDocID) { + curDocID = vectorValues.advance(docId); } 
- docExists = false; + docExists = lastDocID == curDocID; } public float[] getValue() { @@ -44,12 +54,14 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - return vectorDataType.getVectorFromBytesRef(binaryDocValues.binaryValue()); + return doGetValue(); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } } + protected abstract float[] doGetValue() throws IOException; + @Override public int size() { return docExists ? 1 : 0; @@ -59,4 +71,89 @@ public int size() { public float[] get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } + + /** + * Creates a KNNVectorScriptDocValues object based on the provided parameters. + * + * @param values The DocIdSetIterator representing the vector values. + * @param fieldName The name of the field. + * @param vectorDataType The data type of the vector. + * @return A KNNVectorScriptDocValues object based on the type of the values. + * @throws IllegalArgumentException If the type of values is unsupported. 
+ */ + public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + Objects.requireNonNull(values, "values must not be null"); + if (values instanceof ByteVectorValues) { + return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof FloatVectorValues) { + return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof BinaryDocValues) { + return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); + } else { + throw new IllegalArgumentException("Unsupported values type: " + values.getClass()); + } + } + + private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { + private final ByteVectorValues values; + + KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + byte[] bytes = values.vectorValue(); + float[] value = new float[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + value[i] = (float) bytes[i]; + } + return value; + } + } + + private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { + private final FloatVectorValues values; + + KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return values.vectorValue(); + } + } + + private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { + private final BinaryDocValues values; + + KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected 
float[] doGetValue() throws IOException { + return getVectorDataType().getVectorFromBytesRef(values.binaryValue()); + } + } + + /** + * Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type. + * + * @param fieldName The name of the field. + * @param type The data type of the vector. + * @return An empty KNNVectorScriptDocValues object. + */ + public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + @Override + protected float[] doGetValue() throws IOException { + throw new UnsupportedOperationException("empty values"); + } + }; + } } diff --git a/src/main/java/org/opensearch/knn/index/MethodComponent.java b/src/main/java/org/opensearch/knn/index/MethodComponent.java index f2e2d878e..256d55ee5 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponent.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponent.java @@ -17,6 +17,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.IndexHyperParametersUtil; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.ArrayList; import java.util.HashMap; @@ -102,6 +103,43 @@ public ValidationException validate(MethodComponentContext methodComponentContex return validationException; } + /** + * Validate that the methodComponentContext is a valid configuration for this methodComponent, using additional data not present in the method component context + * + * @param methodComponentContext to be validated + * @param vectorSpaceInfo additional data not present in the method component context + * @return ValidationException produced by validation errors; null if no validations errors. 
+ */ + public ValidationException validateWithData(MethodComponentContext methodComponentContext, VectorSpaceInfo vectorSpaceInfo) { + Map providedParameters = methodComponentContext.getParameters(); + List errorMessages = new ArrayList<>(); + + if (providedParameters == null) { + return null; + } + + ValidationException parameterValidation; + for (Map.Entry parameter : providedParameters.entrySet()) { + if (!parameters.containsKey(parameter.getKey())) { + errorMessages.add(String.format("Invalid parameter for method \"%s\".", getName())); + continue; + } + + parameterValidation = parameters.get(parameter.getKey()).validateWithData(parameter.getValue(), vectorSpaceInfo); + if (parameterValidation != null) { + errorMessages.addAll(parameterValidation.validationErrors()); + } + } + + if (errorMessages.isEmpty()) { + return null; + } + + ValidationException validationException = new ValidationException(); + validationException.addValidationErrors(errorMessages); + return validationException; + } + /** * gets requiresTraining value * diff --git a/src/main/java/org/opensearch/knn/index/Parameter.java b/src/main/java/org/opensearch/knn/index/Parameter.java index e223909d5..a4520636e 100644 --- a/src/main/java/org/opensearch/knn/index/Parameter.java +++ b/src/main/java/org/opensearch/knn/index/Parameter.java @@ -12,8 +12,10 @@ package org.opensearch.knn.index; import org.opensearch.common.ValidationException; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Predicate; /** @@ -26,6 +28,7 @@ public abstract class Parameter { private String name; private T defaultValue; protected Predicate validator; + protected BiFunction validatorWithData; /** * Constructor @@ -38,6 +41,14 @@ public Parameter(String name, T defaultValue, Predicate validator) { this.name = name; this.defaultValue = defaultValue; this.validator = validator; + this.validatorWithData = null; + } + + public 
Parameter(String name, T defaultValue, Predicate validator, BiFunction validatorWithData) { + this.name = name; + this.defaultValue = defaultValue; + this.validator = validator; + this.validatorWithData = validatorWithData; } /** @@ -66,6 +77,15 @@ public T getDefaultValue() { */ public abstract ValidationException validate(Object value); + /** + * Check if the value passed in is valid, using additional data not present in the value + * + * @param value to be checked + * @param vectorSpaceInfo additional data not present in the value + * @return ValidationException produced by validation errors; null if no validations errors. + */ + public abstract ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo); + /** * Boolean method parameter */ @@ -74,12 +94,23 @@ public BooleanParameter(String name, Boolean defaultValue, Predicate va super(name, defaultValue, validator); } + public BooleanParameter( + String name, + Boolean defaultValue, + Predicate validator, + BiFunction validatorWithData + ) { + super(name, defaultValue, validator, validatorWithData); + } + @Override public ValidationException validate(Object value) { ValidationException validationException = null; if (!(value instanceof Boolean)) { validationException = new ValidationException(); - validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName())); + validationException.addValidationError( + String.format("value is not an instance of Boolean for Boolean parameter [%s].", getName()) + ); return validationException; } @@ -89,6 +120,27 @@ public ValidationException validate(Object value) { } return validationException; } + + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + ValidationException validationException = null; + if (!(value instanceof Boolean)) { + validationException = new ValidationException(); + 
validationException.addValidationError(String.format("value is not an instance of Boolean for Boolean parameter [%s].", getName())); + return validationException; + } + + if (validatorWithData == null) { + return null; + } + + if (!validatorWithData.apply((Boolean) value, vectorSpaceInfo)) { + validationException = new ValidationException(); + validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName())); + } + + return validationException; + } } /** @@ -99,6 +151,15 @@ public IntegerParameter(String name, Integer defaultValue, Predicate va super(name, defaultValue, validator); } + public IntegerParameter( + String name, + Integer defaultValue, + Predicate validator, + BiFunction validatorWithData + ) { + super(name, defaultValue, validator, validatorWithData); + } + @Override public ValidationException validate(Object value) { ValidationException validationException = null; @@ -118,6 +179,29 @@ public ValidationException validate(Object value) { } return validationException; } + + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + ValidationException validationException = null; + if (!(value instanceof Integer)) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format("value is not an instance of Integer for Integer parameter [%s].", getName()) + ); + return validationException; + } + + if (validatorWithData == null) { + return null; + } + + if (!validatorWithData.apply((Integer) value, vectorSpaceInfo)) { + validationException = new ValidationException(); + validationException.addValidationError(String.format("parameter validation failed for Integer parameter [%s].", getName())); + } + + return validationException; + } } /** @@ -136,6 +220,15 @@ public StringParameter(String name, String defaultValue, Predicate valid super(name, defaultValue, validator); } + public StringParameter( + String name, +
String defaultValue, + Predicate validator, + BiFunction validatorWithData + ) { + super(name, defaultValue, validator, validatorWithData); + } + /** * Check if the value passed in is valid * @@ -161,6 +254,29 @@ public ValidationException validate(Object value) { } return validationException; } + + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + ValidationException validationException = null; + if (!(value instanceof String)) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format("value is not an instance of String for String parameter [%s].", getName()) + ); + return validationException; + } + + if (validatorWithData == null) { + return null; + } + + if (!validatorWithData.apply((String) value, vectorSpaceInfo)) { + validationException = new ValidationException(); + validationException.addValidationError(String.format("parameter validation failed for String parameter [%s].", getName())); + } + + return validationException; + } } /** @@ -190,6 +306,12 @@ public MethodComponentContextParameter( } return methodComponents.get(methodComponentContext.getName()).validate(methodComponentContext) == null; + }, (methodComponentContext, vectorSpaceInfo) -> { + if (!methodComponents.containsKey(methodComponentContext.getName())) { + return false; + } + return methodComponents.get(methodComponentContext.getName()) + .validateWithData(methodComponentContext, vectorSpaceInfo) == null; }); this.methodComponents = methodComponents; } @@ -216,6 +338,31 @@ public ValidationException validate(Object value) { return validationException; } + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + ValidationException validationException = null; + if (!(value instanceof MethodComponentContext)) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format("value is not an 
instance of MethodComponentContext for MethodComponentContext parameter [%s].", getName()) + ); + return validationException; + } + + if (validatorWithData == null) { + return null; + } + + if (!validatorWithData.apply((MethodComponentContext) value, vectorSpaceInfo)) { + validationException = new ValidationException(); + validationException.addValidationError( + String.format("parameter validation failed for MethodComponentContext parameter [%s].", getName()) + ); + } + + return validationException; + } + /** * Get method component by name * diff --git a/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java index f97d18810..0fe311094 100644 --- a/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java @@ -11,6 +11,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Map; @@ -39,6 +40,12 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return getMethod(methodName).validate(knnMethodContext); } + @Override + public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { + String methodName = knnMethodContext.getMethodComponentContext().getName(); + return getMethod(methodName).validateWithData(knnMethodContext, vectorSpaceInfo); + } + @Override public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { String methodName = knnMethodContext.getMethodComponentContext().getName(); diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index efd8a637c..bbb58bf1e 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@
-109,9 +109,6 @@ class Faiss extends NativeLibrary { .build() ); - // TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now, - // we do not have a way to base validation off of dimension. Failure will happen during training in JNI. - // Define methods supported by faiss. See issue here: https://github.com/opensearch-project/k-NN/issues/1075 private final static Map HNSW_ENCODERS = ImmutableMap.builder() .putAll( ImmutableMap.of( @@ -122,7 +119,8 @@ class Faiss extends NativeLibrary { new Parameter.IntegerParameter( ENCODER_PARAMETER_PQ_M, ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT, + (v, vectorSpaceInfo) -> v > 0 && vectorSpaceInfo.getDimension() % v == 0 ) ) .addParameter( @@ -161,7 +159,8 @@ class Faiss extends NativeLibrary { new Parameter.IntegerParameter( ENCODER_PARAMETER_PQ_M, ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT, + (v, vectorSpaceInfo) -> v > 0 && vectorSpaceInfo.getDimension() % v == 0 ) ) .addParameter( diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index e282c69db..556785783 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -10,6 +10,7 @@ import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.List; import java.util.Map; @@ -168,6 +169,11 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return knnLibrary.validateMethod(knnMethodContext); } + @Override + public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext,
VectorSpaceInfo vectorSpaceInfo) { + return knnLibrary.validateMethodWithData(knnMethodContext, vectorSpaceInfo); + } + @Override public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { return knnLibrary.isTrainingRequired(knnMethodContext); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index f837566b8..cac5af2bb 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -15,6 +15,7 @@ import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Collections; import java.util.List; @@ -97,6 +98,16 @@ public interface KNNLibrary { */ ValidationException validateMethod(KNNMethodContext knnMethodContext); + /** + * Validate the knnMethodContext for the given library, using additional data not present in the method context. A ValidationException is returned (not thrown) if the method is + * deemed invalid. + * + * @param knnMethodContext to be validated + * @param vectorSpaceInfo additional data not present in the method context + * @return ValidationException produced by validation errors; null if there are no validation errors. + */ + ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo); + /** * Returns whether training is required or not from knnMethodContext for the given library.
* diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 114499100..84e986faa 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -10,6 +10,7 @@ import java.util.Objects; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.VectorUtil; import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -48,13 +49,7 @@ private static void requireEqualDimension(final float[] queryVector, final float * @return L2 score */ public static float l2Squared(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); - float squaredDistance = 0; - for (int i = 0; i < inputVector.length; i++) { - float diff = queryVector[i] - inputVector[i]; - squaredDistance += diff * diff; - } - return squaredDistance; + return VectorUtil.squareDistance(queryVector, inputVector); } private static float[] toFloat(List inputVector, VectorDataType vectorDataType) { @@ -148,20 +143,12 @@ public static float cosineSimilarity(List queryVector, KNNVectorScriptDo */ public static float cosinesimil(float[] queryVector, float[] inputVector) { requireEqualDimension(queryVector, inputVector); - float dotProduct = 0.0f; - float normQueryVector = 0.0f; - float normInputVector = 0.0f; - for (int i = 0; i < queryVector.length; i++) { - dotProduct += queryVector[i] * inputVector[i]; - normQueryVector += queryVector[i] * queryVector[i]; - normInputVector += inputVector[i] * inputVector[i]; - } - float normalizedProduct = normQueryVector * normInputVector; - if (normalizedProduct == 0) { + try { + return VectorUtil.cosine(queryVector, inputVector); + } catch (IllegalArgumentException | AssertionError e) { logger.debug("Invalid 
vectors for cosine. Returning minimum score to put this result to end"); return 0.0f; } - return (float) (dotProduct / (Math.sqrt(normalizedProduct))); } /** @@ -217,7 +204,6 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) { * @return L1 score */ public static float l1Norm(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); float distance = 0; for (int i = 0; i < inputVector.length; i++) { float diff = queryVector[i] - inputVector[i]; @@ -255,7 +241,6 @@ public static float l1Norm(List queryVector, KNNVectorScriptDocValues do * @return L-inf score */ public static float lInfNorm(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); float distance = 0; for (int i = 0; i < inputVector.length; i++) { float diff = queryVector[i] - inputVector[i]; @@ -293,12 +278,7 @@ public static float lInfNorm(List queryVector, KNNVectorScriptDocValues * @return dot product score */ public static float innerProduct(float[] queryVector, float[] inputVector) { - requireEqualDimension(queryVector, inputVector); - float distance = 0; - for (int i = 0; i < inputVector.length; i++) { - distance += queryVector[i] * inputVector[i]; - } - return distance; + return VectorUtil.dotProduct(queryVector, inputVector); } /** diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 9035a8e84..5f3913ac5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -22,6 +22,7 @@ import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.training.VectorSpaceInfo; import java.io.IOException; @@ -281,6 +282,12 @@ public ActionRequestValidationException validate() { 
exception.addValidationErrors(validationException.validationErrors()); } + validationException = this.knnMethodContext.validateWithData(new VectorSpaceInfo(dimension)); + if (validationException != null) { + exception = exception == null ? new ActionRequestValidationException() : exception; + exception.addValidationErrors(validationException.validationErrors()); + } + if (!this.knnMethodContext.isTrainingRequired()) { exception = exception == null ? new ActionRequestValidationException() : exception; exception.addValidationError("Method does not require training."); diff --git a/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java b/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java new file mode 100644 index 000000000..13843486d --- /dev/null +++ b/src/main/java/org/opensearch/knn/training/VectorSpaceInfo.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.training; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; + +/** + * A data spec containing relevant information for validation.
+ */ +@Getter +@Setter +@AllArgsConstructor +public class VectorSpaceInfo { + private int dimension; +} diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 145fa2cff..b018740bc 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -1311,8 +1311,8 @@ public void testSharedIndexState_whenOneIndexDeleted_thenSecondIndexIsStillSearc .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) .startObject(PARAMETERS) - .field(ENCODER_PARAMETER_PQ_M, pqCodeSize) - .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqM) + .field(ENCODER_PARAMETER_PQ_M, pqM) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqCodeSize) .endObject() .endObject() .endObject() @@ -1648,6 +1648,138 @@ public void testFiltering_whenUsingFaissExactSearchWithIP_thenMatchExpectedScore } } + @SneakyThrows + public void testHNSW_InvalidPQM_thenFail() { + String trainingIndexName = "training-index"; + String trainingFieldName = "training-field"; + + String modelId = "test-model"; + String modelDescription = "test model"; + + List mValues = ImmutableList.of(16, 32, 64, 128); + int invalidPQM = 3; + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 
8 because thats the only valid code_size for HNSWPQ + int trainingDataCount = 256; + + SpaceType spaceType = SpaceType.L2; + + Integer dimension = testData.indexData.vectors[0].length; + + /* + * Builds the below json: + * { + * "name": "hnsw", + * "engine": "faiss", + * "space_type": "l2", + * "parameters": { + * "encoder": { + * "name": "pq", + * "parameters": { + * "m": 3 + * } + * } + * } + * } + */ + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, invalidPQM) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ResponseException re = expectThrows( + ResponseException.class, + () -> ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount) + ); + assertTrue( + re.getMessage().contains("Validation Failed: 1: parameter validation failed for MethodComponentContext parameter [encoder].;") + ); + } + + @SneakyThrows + public void testIVF_InvalidPQM_thenFail() { + String trainingIndexName = "training-index"; + String trainingFieldName = "training-field"; + + String modelId = "test-model"; + String modelDescription = "test model"; + + List mValues = ImmutableList.of(16, 32, 64, 128); + int invalidPQM = 3; + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 
+ int trainingDataCount = 256; + + int dimension = testData.indexData.vectors[0].length; + SpaceType spaceType = SpaceType.L2; + int ivfNlist = 4; + int ivfNprobes = 4; + int pqCodeSize = 8; + + /* + * Builds the below json: + * { + * "name": "ivf", + * "engine": "faiss", + * "space_type": "l2", + * "parameters": { + * "nprobes": 8, + * "nlist": 4, + * "encoder": { + * "name": "pq", + * "parameters": { + * "m": 3, + * "code_size": 8 + * } + * } + * } + * } + */ + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NPROBES, ivfNprobes) + .field(METHOD_PARAMETER_NLIST, ivfNlist) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, invalidPQM) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqCodeSize) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ResponseException re = expectThrows( + ResponseException.class, + () -> ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount) + ); + assertTrue( + re.getMessage().contains("Validation Failed: 1: parameter validation failed for MethodComponentContext parameter [encoder].;") + ); + } + protected void setupKNNIndexForFilterQuery() throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodTests.java index d4dd989f7..607ca849e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodTests.java @@ -17,6 +17,7 @@ import 
org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.training.VectorSpaceInfo; import java.io.IOException; import java.util.HashMap; @@ -44,9 +45,9 @@ public void testHasSpace() { KNNMethod knnMethod = KNNMethod.Builder.builder(MethodComponent.Builder.builder(name).build()) .addSpaces(SpaceType.L2, SpaceType.COSINESIMIL) .build(); - assertTrue(knnMethod.containsSpace(SpaceType.L2)); - assertTrue(knnMethod.containsSpace(SpaceType.COSINESIMIL)); - assertFalse(knnMethod.containsSpace(SpaceType.INNER_PRODUCT)); + assertTrue(knnMethod.isSpaceTypeSupported(SpaceType.L2)); + assertTrue(knnMethod.isSpaceTypeSupported(SpaceType.COSINESIMIL)); + assertFalse(knnMethod.isSpaceTypeSupported(SpaceType.INNER_PRODUCT)); } /** @@ -93,6 +94,52 @@ public void testValidate() throws IOException { assertNull(knnMethod.validate(knnMethodContext3)); } + /** + * Test KNNMethod validateWithData + */ + public void testValidateWithData() throws IOException { + String methodName = "test-method"; + KNNMethod knnMethod = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName).build()) + .addSpaces(SpaceType.L2) + .build(); + + VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(4); + + // Invalid space + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in); + assertNotNull(knnMethod.validateWithData(knnMethodContext1, testVectorSpaceInfo)); + + // Invalid methodComponent + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .startObject(PARAMETERS) + .field("invalid", "invalid") + .endObject() + 
.endObject(); + in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in); + + assertNotNull(knnMethod.validateWithData(knnMethodContext2, testVectorSpaceInfo)); + + // Valid everything + xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .endObject(); + in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext3 = KNNMethodContext.parse(in); + assertNull(knnMethod.validateWithData(knnMethodContext3, testVectorSpaceInfo)); + } + public void testGetAsMap() { SpaceType spaceType = SpaceType.DEFAULT; String methodName = "test-method"; @@ -122,6 +169,6 @@ public void testBuilder() { builder.addSpaces(SpaceType.L2); knnMethod = builder.build(); - assertTrue(knnMethod.containsSpace(SpaceType.L2)); + assertTrue(knnMethod.isSpaceTypeSupported(SpaceType.L2)); } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 3f98a9136..66e2893c0 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -5,7 +5,15 @@ package org.opensearch.knn.index; -import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; @@ 
-33,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - createKNNVectorDocument(directory); + Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); + createKNNVectorDocument(directory, valuesClass); reader = DirectoryReader.open(directory); - LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( - leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), - MOCK_INDEX_FIELD_NAME, - VectorDataType.FLOAT - ); + LeafReader leafReader = reader.getContext().leaves().get(0).reader(); + DocIdSetIterator vectorValues; + if (BinaryDocValues.class.equals(valuesClass)) { + vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME); + } else if (ByteVectorValues.class.equals(valuesClass)) { + vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME); + } else { + vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME); + } + + scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); } - private void createKNNVectorDocument(Directory directory) throws IOException { + private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( + Field field; + if (BinaryDocValues.class.equals(valuesClass)) { + field = new BinaryDocValuesField( MOCK_INDEX_FIELD_NAME, new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + ); + } else if (ByteVectorValues.class.equals(valuesClass)) { + field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); + } else { + 
field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + } + + knnDocument.add(field); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -84,4 +105,18 @@ public void testSize() throws IOException { public void testGet() throws IOException { expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); } + + public void testUnsupportedValues() throws IOException { + expectThrows( + IllegalArgumentException.class, + () -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT) + ); + } + + public void testEmptyValues() throws IOException { + KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + assertEquals(0, values.size()); + values.setNextDocId(0); + assertEquals(0, values.size()); + } } diff --git a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java index 20b87708f..73d00141f 100644 --- a/src/test/java/org/opensearch/knn/index/NestedSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/NestedSearchIT.java @@ -75,11 +75,13 @@ public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() { refreshIndex(INDEX_NAME); forceMergeKnnIndex(INDEX_NAME); - Float[] queryVector = { 1f, 1f }; + Float[] queryVector = { 14f, 14f }; Response response = queryNestedField(INDEX_NAME, 2, queryVector); String entity = EntityUtils.toString(response.getEntity()); assertEquals(2, parseHits(entity)); assertEquals(2, parseTotalSearchHits(entity)); + assertEquals("14", parseIds(entity).get(0)); + assertEquals("13", parseIds(entity).get(1)); } @SneakyThrows @@ -97,11 +99,13 @@ public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() { refreshIndex(INDEX_NAME); forceMergeKnnIndex(INDEX_NAME); - Float[] queryVector = { 1f, 1f }; + Float[] queryVector = { 14f, 14f }; Response response = queryNestedField(INDEX_NAME, 2,
queryVector); String entity = EntityUtils.toString(response.getEntity()); assertEquals(2, parseHits(entity)); assertEquals(2, parseTotalSearchHits(entity)); + assertEquals("14", parseIds(entity).get(0)); + assertEquals("13", parseIds(entity).get(1)); } /** diff --git a/src/test/java/org/opensearch/knn/index/ParameterTests.java b/src/test/java/org/opensearch/knn/index/ParameterTests.java index 08decd592..2f3f19727 100644 --- a/src/test/java/org/opensearch/knn/index/ParameterTests.java +++ b/src/test/java/org/opensearch/knn/index/ParameterTests.java @@ -17,6 +17,7 @@ import org.opensearch.knn.index.Parameter.IntegerParameter; import org.opensearch.knn.index.Parameter.StringParameter; import org.opensearch.knn.index.Parameter.MethodComponentContextParameter; +import org.opensearch.knn.training.VectorSpaceInfo; import java.util.Map; @@ -31,6 +32,12 @@ public void testGetDefaultValue() { public ValidationException validate(Object value) { return null; } + + @Override + public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { + return null; + } + }; assertEquals(defaultValue, parameter.getDefaultValue()); @@ -52,6 +59,29 @@ public void testIntegerParameter_validate() { assertNull(parameter.validate(12)); } + /** + * Test integer parameter validate + */ + public void testIntegerParameter_validateWithData() { + final IntegerParameter parameter = new IntegerParameter( + "test", + 1, + v -> v > 0, + (v, vectorSpaceInfo) -> v > vectorSpaceInfo.getDimension() + ); + + VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(0); + + // Invalid type + assertNotNull(parameter.validateWithData("String", testVectorSpaceInfo)); + + // Invalid value + assertNotNull(parameter.validateWithData(-1, testVectorSpaceInfo)); + + // valid value + assertNull(parameter.validateWithData(12, testVectorSpaceInfo)); + } + public void testStringParameter_validate() { final StringParameter parameter = new StringParameter("test_parameter", "default_value", v -> 
"test".equals(v)); @@ -65,6 +95,36 @@ public void testStringParameter_validate() { assertNull(parameter.validate("test")); } + public void testStringParameter_validateWithData() { + final StringParameter parameter = new StringParameter( + "test_parameter", + "default_value", + v -> "test".equals(v), + (v, vectorSpaceInfo) -> { + if (vectorSpaceInfo.getDimension() > 0) { + return "test".equals(v); + } + return false; + } + ); + + VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(1); + + // Invalid type + assertNotNull(parameter.validateWithData(5, testVectorSpaceInfo)); + + // null + assertNotNull(parameter.validateWithData(null, testVectorSpaceInfo)); + + // valid value + assertNull(parameter.validateWithData("test", testVectorSpaceInfo)); + + testVectorSpaceInfo.setDimension(0); + + // invalid value + assertNotNull(parameter.validateWithData("test", testVectorSpaceInfo)); + } + public void testMethodComponentContextParameter_validate() { String methodComponentName1 = "method-1"; String parameterKey1 = "parameter_key_1"; @@ -109,6 +169,55 @@ public void testMethodComponentContextParameter_validate() { assertNull(parameter.validate(methodComponentContext)); } + public void testMethodComponentContextParameter_validateWithData() { + String methodComponentName1 = "method-1"; + String parameterKey1 = "parameter_key_1"; + Integer parameterValue1 = 12; + + Map defaultParameterMap = ImmutableMap.of(parameterKey1, parameterValue1); + MethodComponentContext methodComponentContext = new MethodComponentContext(methodComponentName1, defaultParameterMap); + + Map methodComponentMap = ImmutableMap.of( + methodComponentName1, + MethodComponent.Builder.builder(parameterKey1) + .addParameter( + parameterKey1, + new IntegerParameter(parameterKey1, 1, v -> v > 0, (v, vectorSpaceInfo) -> v > vectorSpaceInfo.getDimension()) + ) + .build() + ); + + final MethodComponentContextParameter parameter = new MethodComponentContextParameter( + "test", + methodComponentContext, + 
methodComponentMap + ); + + VectorSpaceInfo testVectorSpaceInfo = new VectorSpaceInfo(0); + + // Invalid type + assertNotNull(parameter.validateWithData(17, testVectorSpaceInfo)); + assertNotNull(parameter.validateWithData("invalid-value", testVectorSpaceInfo)); + + // Invalid value + String invalidMethodComponentName = "invalid-method"; + MethodComponentContext invalidMethodComponentContext1 = new MethodComponentContext(invalidMethodComponentName, defaultParameterMap); + assertNotNull(parameter.validateWithData(invalidMethodComponentContext1, testVectorSpaceInfo)); + + String invalidParameterKey = "invalid-parameter"; + Map invalidParameterMap1 = ImmutableMap.of(invalidParameterKey, parameterValue1); + MethodComponentContext invalidMethodComponentContext2 = new MethodComponentContext(methodComponentName1, invalidParameterMap1); + assertNotNull(parameter.validateWithData(invalidMethodComponentContext2, testVectorSpaceInfo)); + + String invalidParameterValue = "invalid-value"; + Map invalidParameterMap2 = ImmutableMap.of(parameterKey1, invalidParameterValue); + MethodComponentContext invalidMethodComponentContext3 = new MethodComponentContext(methodComponentName1, invalidParameterMap2); + assertNotNull(parameter.validateWithData(invalidMethodComponentContext3, testVectorSpaceInfo)); + + // valid value + assertNull(parameter.validateWithData(methodComponentContext, testVectorSpaceInfo)); + } + public void testMethodComponentContextParameter_getMethodComponent() { String methodComponentName1 = "method-1"; String parameterKey1 = "parameter_key_1"; diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 4423c85d8..19270717d 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { 
createKNNFloatVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, VectorDataType.FLOAT @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { createKNNByteVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, VectorDataType.BYTE diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 6b40f375c..3cfbe56f1 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -47,7 +47,7 @@ public void testL2() { public void testCosineSimilarity() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; - List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); + List arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f }; KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); @@ -59,7 +59,7 @@ public void testCosineSimilarity() { ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); - assertEquals(3F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F); + assertEquals(2F, cosineSimilarity.scoringMethod.apply(arrayFloat2, 
arrayFloat), 0.1F); // invalid zero vector final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f); diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 8c43a4acf..22110accd 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( + scriptDocValues = KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName, VectorDataType.FLOAT diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 901511a68..5a83891d9 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -7,6 +7,7 @@ import java.util.function.BiFunction; import java.util.function.Function; +import org.opensearch.ExceptionsHelper; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; @@ -193,7 +194,7 @@ public void testUnequalDimensions() throws Exception { } @SuppressWarnings("unchecked") - public void testKNNScoreforNonVectorDocument() throws Exception { + public void testKNNScoreForNonVectorDocument() throws Exception { /* * Create knn index and populate data */ @@ -599,7 +600,7 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception { if (spaceType != SpaceType.HAMMING_BIT) { final float[] queryVector = randomVector(dimensions); final BiFunction scoreFunction = 
getScoreFunction(spaceType, queryVector); - createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector); + createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector, true); } } } @@ -612,7 +613,16 @@ private List createMappers(int dimensions) throws Exception { dimensions, KNNConstants.METHOD_HNSW, KNNEngine.LUCENE.getName(), - SpaceType.DEFAULT.getValue() + SpaceType.DEFAULT.getValue(), + true + ), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + false ) ); } @@ -625,12 +635,22 @@ private float[] randomVector(int dimensions) { return vector; } - private Map createDataset(Function scoreFunction, int dimensions, int numDocs) { - final Map dataset = new HashMap<>(numDocs); - for (int i = 0; i < numDocs; i++) { + private Map createDataset( + Function scoreFunction, + int dimensions, + int numDocsWithField, + boolean dense + ) { + final Map dataset = new HashMap<>(dense ? numDocsWithField : numDocsWithField * 3); + int id = 0; + for (int i = 0; i < numDocsWithField; i++) { + final int dummyDocs = dense ? 
0 : randomIntBetween(2, 5); + for (int j = 0; j < dummyDocs; j++) { + dataset.put(Integer.toString(id++), null); + } final float[] vector = randomVector(dimensions); final float score = scoreFunction.apply(vector); - dataset.put(Integer.toString(i), new KNNResult(Integer.toString(i), vector, score)); + dataset.put(Integer.toString(id), new KNNResult(Integer.toString(id++), vector, score)); } return dataset; } @@ -669,7 +689,8 @@ private void testKNNScriptScore(SpaceType spaceType) throws Exception { final float[] queryVector = randomVector(dims); final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); for (String mapper : createMappers(dims)) { - createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector); + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true); + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false); } } @@ -678,16 +699,20 @@ private void createIndexAndAssertScriptScore( SpaceType spaceType, BiFunction scoreFunction, int dimensions, - float[] queryVector + float[] queryVector, + boolean dense ) throws Exception { /* * Create knn index and populate data */ createKnnIndex(INDEX_NAME, mapper); - Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, randomIntBetween(4, 10)); - for (Map.Entry entry : dataset.entrySet()) { - addKnnDoc(INDEX_NAME, entry.getKey(), FIELD_NAME, entry.getValue().getVector()); - } + final int numDocsWithField = randomIntBetween(4, 10); + Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, numDocsWithField, dense); + final float[] dummyVector = new float[1]; + dataset.forEach((k, v) -> { + final float[] vector = (v != null) ? v.getVector() : dummyVector; + ExceptionsHelper.catchAsRuntimeException(() -> addKnnDoc(INDEX_NAME, k, (v != null) ? 
FIELD_NAME : "dummy", vector)); + }); /** * Construct Search Request @@ -703,7 +728,7 @@ private void createIndexAndAssertScriptScore( params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", spaceType.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, numDocsWithField); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java index 0315c47c5..5325d1205 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java @@ -563,7 +563,9 @@ public void testL2ScriptingWithLuceneBackedIndex() throws Exception { new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); properties.add( - new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext) + new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2") + .knnMethodContext(knnMethodContext) + .docValues(randomBoolean()) ); String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME); diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index cff4d5805..46240e830 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -16,6 +16,7 @@ import org.opensearch.knn.index.KNNMethodContext; import 
org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNLibrary; +import org.opensearch.knn.training.VectorSpaceInfo; import org.opensearch.test.OpenSearchTestCase; import java.util.Map; @@ -78,6 +79,11 @@ public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return null; } + @Override + public ValidationException validateMethodWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) { + return null; + } + @Override public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { return false; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 5010ff6ee..fa6a13f2f 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -39,6 +39,7 @@ import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.core.rest.RestStatus; import org.opensearch.script.Script; +import org.opensearch.search.SearchService; import org.opensearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; import javax.management.MBeanServerInvocationHandler; @@ -954,9 +955,16 @@ protected Request constructScriptScoreContextSearchRequest( } protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params) throws Exception { + return constructKNNScriptQueryRequest(indexName, qb, params, SearchService.DEFAULT_SIZE); + } + + protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params, int size) + throws Exception { Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, KNNScoringScriptEngine.NAME, KNNScoringScriptEngine.SCRIPT_SOURCE, params); ScriptScoreQueryBuilder sc = new ScriptScoreQueryBuilder(qb, script); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); + XContentBuilder builder = 
XContentFactory.jsonBuilder().startObject(); + builder.field("size", size); + builder.startObject("query"); builder.startObject("script_score"); builder.field("query"); sc.query().toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -1422,7 +1430,7 @@ public void assertTrainingFails(String modelId, int attempts, int delayInMillis) assertNotEquals(ModelState.CREATED, modelState); } - fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms."); + fail("Training did not fail after " + attempts + " attempts with a delay of " + delayInMillis + " ms."); } protected boolean systemIndexExists(final String indexName) throws IOException {