diff --git a/CHANGELOG.md b/CHANGELOG.md index 2aed6ce6c2..ef3728e9d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.15...2.x) ### Features * Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783) +* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 1749b34af6..5ac17cfd13 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -102,6 +102,7 @@ namespace knn_jni { * @param indexPointerJ - pointer to the index * @param queryVectorJ - the query vector * @param radiusJ - the radius for the range search + * @param methodParamsJ - the method parameters * @param maxResultsWindowJ - the maximum number of results to return * @param filterIdsJ - the filter ids * @param filterIdsTypeJ - the filter ids type @@ -110,7 +111,7 @@ namespace knn_jni { * @return an array of RangeQueryResults */ jobjectArray RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); /* * Perform a range search against the index located in memory at indexPointerJ. @@ -118,13 +119,14 @@ namespace knn_jni { * @param indexPointerJ - pointer to the index * @param queryVectorJ - the query vector * @param radiusJ - the radius for the range search + * @param methodParamsJ - the method parameters * @param maxResultsWindowJ - the maximum number of results to return * @param parentIdsJ - the parent ids * * @return an array of RangeQueryResults */ jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ); + jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jintArray parentIdsJ); } } diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 9ead7dfe77..3d6aef45c4 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -150,18 +150,18 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors /* * Class: org_opensearch_knn_jni_FaissService * Method: rangeSearchIndexWithFilter -* Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult; +* Signature: (J[FJLjava/util/MapI[JII)[Lorg/opensearch/knn/index/query/RangeQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter - (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jlongArray, jint, jintArray); + (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jobject, jint, jlongArray, jint, jintArray); /* * Class: org_opensearch_knn_jni_FaissService * Method: rangeSearchIndex - * Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult; + * Signature: (J[FJLjava/util/MapII)[Lorg/opensearch/knn/index/query/RangeQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex - (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jintArray); + (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jobject, jint, jintArray); #ifdef __cplusplus } diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 3eda03b419..c4c6e18eb5 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -716,12 +716,12 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) { } jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, - jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ) { - return knn_jni::faiss_wrapper::RangeSearchWithFilter(jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, nullptr, 0, parentIdsJ); + jfloatArray queryVectorJ, jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jintArray parentIdsJ) { + return knn_jni::faiss_wrapper::RangeSearchWithFilter(jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, methodParamsJ, maxResultWindowJ, nullptr, 0, parentIdsJ); } jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, - jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + jfloatArray queryVectorJ, jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); } @@ -734,6 +734,11 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter float *rawQueryVector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr); + std::unordered_map methodParams; + if (methodParamsJ != nullptr) { + methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); + } + // The res will be freed by ~RangeSearchResult() in FAISS // The second parameter is always true, as lims is allocated by FAISS faiss::RangeSearchResult res(1, true); @@ -755,9 +760,8 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); if(hnswReader) { - // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default - // value of ef_search = 16 which will then be used. - hnswParams.efSearch = hnswReader->hnsw.efSearch; + // Query param ef_search supersedes ef_search provided during index setting. + hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); hnswParams.sel = idSelector.get(); if (parentIdsJ != nullptr) { idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); @@ -785,12 +789,13 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter std::unique_ptr idGrouper; std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); - if(hnswReader!= nullptr && parentIdsJ != nullptr) { - // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default - // value of ef_search = 16 which will then be used. - hnswParams.efSearch = hnswReader->hnsw.efSearch; - idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); - hnswParams.grp = idGrouper.get(); + if(hnswReader!= nullptr) { + // Query param ef_search supersedes ef_search provided during index setting. + hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } searchParameters = &hnswParams; } try { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index d901cde30d..5f9c83ea86 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -241,11 +241,11 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jint maxResultWindowJ, - jintArray parentIdsJ) + jfloat radiusJ, jobject methodParamsJ, + jint maxResultWindowJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, parentIdsJ); + return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, methodParamsJ, maxResultWindowJ, parentIdsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -255,12 +255,11 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSea JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jint maxResultWindowJ, + jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, - maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ); + return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, methodParamsJ, maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 030e10f752..c6663a19a4 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -787,6 +787,11 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { faiss::MetricType metricType = faiss::METRIC_L2; std::string method = "HNSW32,Flat"; + int efSearch = 20; + std::unordered_map methodParams; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + auto methodParamsJ = reinterpret_cast(&methodParams); + // Define query data int numQueries = 100; std::vector> queries; @@ -819,7 +824,7 @@ TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, nullptr))); + reinterpret_cast(&query), rangeSearchRadius, methodParamsJ, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -874,7 +879,7 @@ TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ knn_jni::faiss_wrapper::RangeSearch( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, nullptr))); + reinterpret_cast(&query), rangeSearchRadius, nullptr, maxResultWindow, nullptr))); // assert result size is not 0 ASSERT_NE(0, results->size()); @@ -940,7 +945,7 @@ TEST(FaissRangeSearchQueryIndexTestWithFilterTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearchWithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, + reinterpret_cast(&query), rangeSearchRadius, nullptr, maxResultWindow, reinterpret_cast(&bitmap), 0, nullptr))); // assert result size is not 0 @@ -1015,7 +1020,7 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { knn_jni::faiss_wrapper::RangeSearchWithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), rangeSearchRadius, maxResultWindow, nullptr, 0, + reinterpret_cast(&query), rangeSearchRadius, nullptr, maxResultWindow, nullptr, 0, reinterpret_cast(&parentIds)))); // assert result size is not 0 @@ -1032,4 +1037,4 @@ TEST(FaissRangeSearchQueryIndexTestWithParentFilterTest, BasicAssertions) { delete it; } } -} \ No newline at end of file +} diff --git a/jni/tests/faiss_wrapper_unit_test.cpp b/jni/tests/faiss_wrapper_unit_test.cpp index 77b38e3836..d9fdac23fb 100644 --- a/jni/tests/faiss_wrapper_unit_test.cpp +++ b/jni/tests/faiss_wrapper_unit_test.cpp @@ -22,23 +22,25 @@ #include "faiss/IndexIDMap.h" using ::testing::NiceMock; + using idx_t = faiss::idx_t; -struct FaissMockIndex : faiss::IndexHNSW { - explicit FaissMockIndex(idx_t d) : faiss::IndexHNSW(d, 32) { +struct MockIndex : faiss::IndexHNSW { + explicit MockIndex(idx_t d) : faiss::IndexHNSW(d, 32) { } }; - -struct FaissMockIdMap : faiss::IndexIDMap { - mutable idx_t nCalled; - mutable const float *xCalled; - mutable idx_t kCalled; - mutable float *distancesCalled; - mutable idx_t *labelsCalled; - mutable const faiss::SearchParametersHNSW *paramsCalled; - - explicit FaissMockIdMap(FaissMockIndex *index) : faiss::IndexIDMapTemplate(index) { +struct MockIdMap : faiss::IndexIDMap { + mutable idx_t nCalled{}; + mutable const float *xCalled{}; + mutable int kCalled{}; + mutable float radiusCalled{}; + mutable float *distancesCalled{}; + mutable idx_t *labelsCalled{}; + mutable const faiss::SearchParametersHNSW *paramsCalled{}; + mutable faiss::RangeSearchResult *resCalled{}; + + explicit MockIdMap(MockIndex *index) : faiss::IndexIDMapTemplate(index) { } void search( @@ -56,18 +58,33 @@ struct FaissMockIdMap : faiss::IndexIDMap { paramsCalled = dynamic_cast(params); } + void range_search( + idx_t n, + const float *x, + float radius, + faiss::RangeSearchResult *res, + const faiss::SearchParameters *params) const override { + nCalled = n; + xCalled = x; + radiusCalled = radius; + resCalled = res; + paramsCalled = dynamic_cast(params); + } + void resetMock() const { nCalled = 0; xCalled = nullptr; kCalled = 0; + radiusCalled = 0.0; distancesCalled = nullptr; labelsCalled = nullptr; + resCalled = nullptr; paramsCalled = nullptr; } }; struct QueryIndexHNSWTestInput { - string description; + std::string description; int k; int efSearch; int filterIdType; @@ -75,30 +92,46 @@ struct QueryIndexHNSWTestInput { bool parentIdsPresent; }; - +struct RangeSearchTestInput { + std::string description; + float radius; + int efSearch; + int filterIdType; + bool filterIdsPresent; + bool parentIdsPresent; +}; class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam { public: FaissWrappeterParametrizedTestFixture() : index_(3), id_map_(&index_) { index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere - }; + } + +protected: + MockIndex index_; + MockIdMap id_map_; +}; + +class FaissWrapperParametrizedRangeSearchTestFixture : public testing::TestWithParam { +public: + FaissWrapperParametrizedRangeSearchTestFixture() : index_(3), id_map_(&index_) { + index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere + } protected: - FaissMockIndex index_; - FaissMockIdMap id_map_; + MockIndex index_; + MockIdMap id_map_; }; namespace query_index_test { std::unordered_map methodParams; - TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexHNSWTests) { - //Given + // Given JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; - QueryIndexHNSWTestInput const &input = GetParam(); float query[] = {1.2, 2.3, 3.4}; @@ -136,7 +169,7 @@ namespace query_index_test { reinterpret_cast(&query), input.k, reinterpret_cast(&methodParams), reinterpret_cast(parentIdPtr)); - //Then + // Then int actualEfSearch = id_map_.paramsCalled->efSearch; // Asserting the captured argument EXPECT_EQ(input.k, id_map_.kCalled); @@ -164,11 +197,10 @@ namespace query_index_test { namespace query_index_with_filter_test { TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexWithFilterHNSWTests) { - //Given + // Given JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; - QueryIndexHNSWTestInput const &input = GetParam(); float query[] = {1.2, 2.3, 3.4}; @@ -217,7 +249,7 @@ namespace query_index_with_filter_test { input.filterIdType, reinterpret_cast(parentIdPtr)); - //Then + // Then int actualEfSearch = id_map_.paramsCalled->efSearch; // Asserting the captured argument EXPECT_EQ(input.k, id_map_.kCalled); @@ -248,3 +280,93 @@ namespace query_index_with_filter_test { ) ); } + +namespace range_search_test { + + TEST_P(FaissWrapperParametrizedRangeSearchTestFixture, RangeSearchHNSWTests) { + // Given + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + RangeSearchTestInput const &input = GetParam(); + float query[] = {1.2, 2.3, 3.4}; + float radius = input.radius; + int maxResultWindow = 100; // Set your max result window + + std::unordered_map methodParams; + int efSearch = input.efSearch; + int expectedEfSearch = 100; // default set in mock + if (efSearch != -1) { + expectedEfSearch = input.efSearch; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + } + + std::vector *parentIdPtr = nullptr; + if (input.parentIdsPresent) { + std::vector parentId; + parentId.reserve(2); + parentId.push_back(1); + parentId.push_back(2); + parentIdPtr = &parentId; + + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(parentIdPtr))) + .WillOnce(testing::Return(parentId.size())); + + EXPECT_CALL(mockJNIUtil, + GetIntArrayElements( + jniEnv, reinterpret_cast(parentIdPtr), nullptr)) + .WillOnce(testing::Return(new int[2]{1, 2})); + } + + std::vector filter; + std::vector *filterptr = nullptr; + if (input.filterIdsPresent) { + filter.reserve(2); + filter.push_back(1); + filter.push_back(2); + filterptr = &filter; + } + + // When + knn_jni::faiss_wrapper::RangeSearchWithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&id_map_), + reinterpret_cast(&query), radius, reinterpret_cast(&methodParams), + maxResultWindow, + reinterpret_cast(filterptr), + input.filterIdType, + reinterpret_cast(parentIdPtr)); + + // Then + int actualEfSearch = id_map_.paramsCalled->efSearch; + // Asserting the captured argument + EXPECT_EQ(expectedEfSearch, actualEfSearch); + if (input.parentIdsPresent) { + faiss::IDGrouper *grouper = id_map_.paramsCalled->grp; + EXPECT_TRUE(grouper != nullptr); + } + if (input.filterIdsPresent) { + faiss::IDSelector *sel = id_map_.paramsCalled->sel; + EXPECT_TRUE(sel != nullptr); + } + id_map_.resetMock(); + } + + INSTANTIATE_TEST_CASE_P( + RangeSearchHNSWTests, + FaissWrapperParametrizedRangeSearchTestFixture, + ::testing::Values( + RangeSearchTestInput{"algoParams present, parent absent, filter absent", 10.0f, 200, 0, false, false}, + RangeSearchTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10.0f, 200, 1, false, false}, + RangeSearchTestInput{"algoParams absent, parent absent, filter present", 10.0f, -1, 0, true, false}, + RangeSearchTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10.0f, -1, 1, true, false}, + RangeSearchTestInput{"algoParams present, parent present, filter absent", 10.0f, 200, 0, false, true}, + RangeSearchTestInput{"algoParams present, parent present, filter absent, filter type 1", 10.0f, 150, 1, false, true}, + RangeSearchTestInput{"algoParams absent, parent present, filter present", 10.0f, -1, 0, true, true}, + RangeSearchTestInput{"algoParams absent, parent present, filter present, filter type 1", 10.0f, -1, 1, true, true} + ) + ); +} + diff --git a/src/main/java/org/opensearch/knn/index/VectorQueryType.java b/src/main/java/org/opensearch/knn/index/VectorQueryType.java index 4697a917ed..fb7bfaafd4 100644 --- a/src/main/java/org/opensearch/knn/index/VectorQueryType.java +++ b/src/main/java/org/opensearch/knn/index/VectorQueryType.java @@ -54,4 +54,8 @@ public KNNCounter getQueryWithFilterStatCounter() { public abstract KNNCounter getQueryStatCounter(); public abstract KNNCounter getQueryWithFilterStatCounter(); + + public boolean isRadialSearch() { + return this == MAX_DISTANCE || this == MIN_SCORE; + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index d123cc1491..4b875d9a8b 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -7,6 +7,7 @@ import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import org.apache.lucene.search.BooleanClause; @@ -205,6 +206,7 @@ private boolean equalsTo(KNNQuery other) { @Setter @Getter @AllArgsConstructor + @EqualsAndHashCode public static class Context { int maxResultWindow; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index edf79bb5f3..86d8031bd9 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -36,6 +36,7 @@ import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.parser.MethodParametersParser; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.index.util.QueryContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -214,16 +215,13 @@ private void validate() { throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); } - VectorQueryType vectorQueryType = VectorQueryType.MAX_DISTANCE; if (k != null) { - vectorQueryType = VectorQueryType.K; if (k <= 0 || k > K_MAX) { throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX)); } } if (minScore != null) { - vectorQueryType = VectorQueryType.MIN_SCORE; if (minScore <= 0) { throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME)); } @@ -237,12 +235,6 @@ private void validate() { ); } } - - // Update stats - vectorQueryType.getQueryStatCounter().increment(); - if (filter != null) { - vectorQueryType.getQueryWithFilterStatCounter().increment(); - } } } @@ -499,6 +491,8 @@ protected Query doToQuery(QueryShardContext context) { KNNEngine knnEngine = KNNEngine.DEFAULT; VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); SpaceType spaceType = knnVectorFieldType.getSpaceType(); + VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); + updateQueryStats(vectorQueryType); if (fieldDimension == -1) { if (spaceType != null) { @@ -521,16 +515,18 @@ protected Query doToQuery(QueryShardContext context) { final String method = methodComponentContext != null ? methodComponentContext.getName() : null; if (StringUtils.isNotBlank(method)) { final EngineSpecificMethodContext engineSpecificMethodContext = knnEngine.getMethodContext(method); + QueryContext queryContext = new QueryContext(vectorQueryType); ValidationException validationException = validateParameters( - engineSpecificMethodContext.supportedMethodParameters(), + engineSpecificMethodContext.supportedMethodParameters(queryContext), (Map) methodParameters ); if (validationException != null) { throw new IllegalArgumentException( String.format( - "Parameters not valid for [%s]:[%s] combination: [%s]", + "Parameters not valid for [%s]:[%s]:[%s] combination: [%s]", knnEngine, method, + vectorQueryType.getQueryTypeName(), validationException.getMessage() ) ); @@ -611,6 +607,7 @@ protected Query doToQuery(QueryShardContext context) { .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) .vectorDataType(vectorDataType) .radius(radius) + .methodParameters(this.methodParameters) .filter(this.filter) .context(context) .build(); @@ -633,6 +630,38 @@ private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFie return modelMetadata; } + /** + * Function to get the vector query type based on the valid query parameter. + * + * @param k K nearest neighbours for the given vector, if k is set, then the query type is K + * @param maxDistance Maximum distance for the given vector, if maxDistance is set, then the query type is MAX_DISTANCE + * @param minScore Minimum score for the given vector, if minScore is set, then the query type is MIN_SCORE + */ + private VectorQueryType getVectorQueryType(int k, Float maxDistance, Float minScore) { + if (maxDistance != null) { + return VectorQueryType.MAX_DISTANCE; + } + if (minScore != null) { + return VectorQueryType.MIN_SCORE; + } + if (k != 0) { + return VectorQueryType.K; + } + throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); + } + + /** + * Function to update query stats. + * + * @param vectorQueryType The type of query to be executed + */ + private void updateQueryStats(VectorQueryType vectorQueryType) { + vectorQueryType.getQueryStatCounter().increment(); + if (filter != null) { + vectorQueryType.getQueryWithFilterStatCounter().increment(); + } + } + @Override protected boolean doEquals(KNNQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index a99ca5613e..fce8e8e04b 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -300,6 +300,7 @@ private Map doANNSearch(final LeafReaderContext context, final B indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getRadius(), + knnQuery.getMethodParameters(), knnEngine, knnQuery.getContext().getMaxResultWindow(), filterIds, diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index db80848642..dd5efc93f6 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -10,6 +10,7 @@ import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; import java.util.Locale; +import java.util.Map; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.ByteVectorSimilarityQuery; @@ -69,6 +70,7 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest final byte[] byteVector = createQueryRequest.getByteVector(); final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); + final Map methodParameters = createQueryRequest.getMethodParameters(); if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { BitSetProducer parentFilter = null; @@ -79,15 +81,17 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest } IndexSettings indexSettings = context.getIndexSettings(); KNNQuery.Context knnQueryContext = new KNNQuery.Context(indexSettings.getMaxResultWindow()); - KNNQuery rnnQuery = new KNNQuery(fieldName, vector, indexName, parentFilter).radius(radius).kNNQueryContext(knnQueryContext); - if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { - log.debug("Creating custom radius search with filters for index: {}, field: {} , r: {}", indexName, fieldName, radius); - rnnQuery.filterQuery(filterQuery); - } - log.debug( - String.format("Creating custom radius search for index: %s \"\", field: %s \"\", r: %f", indexName, fieldName, radius) - ); - return rnnQuery; + + return KNNQuery.builder() + .field(fieldName) + .queryVector(vector) + .indexName(indexName) + .parentsFilter(parentFilter) + .radius(radius) + .methodParameters(methodParameters) + .context(knnQueryContext) + .filterQuery(filterQuery) + .build(); } log.debug(String.format("Creating Lucene r-NN query for index: %s \"\", field: %s \"\", k: %f", indexName, fieldName, radius)); diff --git a/src/main/java/org/opensearch/knn/index/util/DefaultHnswContext.java b/src/main/java/org/opensearch/knn/index/util/DefaultHnswContext.java index c16f1b05e5..c2bbb9e6f7 100644 --- a/src/main/java/org/opensearch/knn/index/util/DefaultHnswContext.java +++ b/src/main/java/org/opensearch/knn/index/util/DefaultHnswContext.java @@ -27,7 +27,7 @@ public final class DefaultHnswContext implements EngineSpecificMethodContext { .build(); @Override - public Map> supportedMethodParameters() { + public Map> supportedMethodParameters(QueryContext ctx) { return supportedMethodParameters; } } diff --git a/src/main/java/org/opensearch/knn/index/util/EngineSpecificMethodContext.java b/src/main/java/org/opensearch/knn/index/util/EngineSpecificMethodContext.java index f669704adf..edb8e830ae 100644 --- a/src/main/java/org/opensearch/knn/index/util/EngineSpecificMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/util/EngineSpecificMethodContext.java @@ -25,7 +25,7 @@ */ public interface EngineSpecificMethodContext { - Map> supportedMethodParameters(); + Map> supportedMethodParameters(QueryContext ctx); - EngineSpecificMethodContext EMPTY = Collections::emptyMap; + EngineSpecificMethodContext EMPTY = ctx -> Collections.emptyMap(); } diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index ae6ea3a702..d98775f947 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -67,7 +67,7 @@ public class Lucene extends JVMLibrary { * @param distanceTransform Map of space type to distance transformation function */ Lucene(Map methods, String version, Map> distanceTransform) { - super(methods, Map.of(METHOD_HNSW, new DefaultHnswContext()), version); + super(methods, Map.of(METHOD_HNSW, new LuceneHNSWContext()), version); this.distanceTransform = distanceTransform; } diff --git a/src/main/java/org/opensearch/knn/index/util/LuceneHNSWContext.java b/src/main/java/org/opensearch/knn/index/util/LuceneHNSWContext.java new file mode 100644 index 0000000000..d9b6ba1c39 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/util/LuceneHNSWContext.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.util; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.knn.index.Parameter; +import org.opensearch.knn.index.query.request.MethodParameter; + +import java.util.Collections; +import java.util.Map; + +public class LuceneHNSWContext implements EngineSpecificMethodContext { + + private final Map> supportedMethodParameters = ImmutableMap.>builder() + .put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true)) + .build(); + + @Override + public Map> supportedMethodParameters(QueryContext ctx) { + if (ctx.queryType.isRadialSearch()) { + // return empty map if radial search is true + return Collections.emptyMap(); + } + // Return the supported method parameters for non-radial cases + return supportedMethodParameters; + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/QueryContext.java b/src/main/java/org/opensearch/knn/index/util/QueryContext.java new file mode 100644 index 0000000000..6bb4958146 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/util/QueryContext.java @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.util; + +import lombok.AllArgsConstructor; +import org.opensearch.knn.index.VectorQueryType; + +/** + * Context class for query-specific information. + */ +@AllArgsConstructor +public class QueryContext { + VectorQueryType queryType; +} diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 6b2990f373..f718ce6d5f 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -249,6 +249,7 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( * @param indexPointer pointer to index in memory * @param queryVector vector to be used for query * @param radius search within radius threshold + * @param methodParameters parameters to be used for the query * @param indexMaxResultWindow maximum number of results to return * @param filteredIds list of doc ids to include in the query result * @param filterIdsType type of filter ids @@ -259,6 +260,7 @@ public static native KNNQueryResult[] rangeSearchIndexWithFilter( long indexPointer, float[] queryVector, float radius, + Map methodParameters, int indexMaxResultWindow, long[] filteredIds, int filterIdsType, @@ -271,6 +273,7 @@ public static native KNNQueryResult[] rangeSearchIndexWithFilter( * @param indexPointer pointer to index in memory * @param queryVector vector to be used for query * @param radius search within radius threshold + * @param methodParameters parameters to be used for the query * @param indexMaxResultWindow maximum number of results to return * @param parentIds list of parent doc ids when the knn field is a nested field * @return KNNQueryResult array of neighbors within radius @@ -279,6 +282,7 @@ public static native KNNQueryResult[] rangeSearchIndex( long indexPointer, float[] queryVector, float radius, + Map methodParameters, int indexMaxResultWindow, int[] parentIds ); diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index fa3a29e3af..ed6a169c10 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -328,6 +328,7 @@ public static long transferVectors(long vectorsPointer, float[][] trainingData) * @param indexPointer pointer to index in memory * @param queryVector vector to be used for query * @param radius search within radius threshold + * @param methodParameters parameters to be used when loading index * @param knnEngine engine to query index * @param indexMaxResultWindow maximum number of results to return * @param filteredIds list of doc ids to include in the query result @@ -339,6 +340,7 @@ public static KNNQueryResult[] radiusQueryIndex( long indexPointer, float[] queryVector, float radius, + @Nullable Map methodParameters, KNNEngine knnEngine, int indexMaxResultWindow, long[] filteredIds, @@ -351,13 +353,14 @@ public static KNNQueryResult[] radiusQueryIndex( indexPointer, queryVector, radius, + methodParameters, indexMaxResultWindow, filteredIds, filterIdsType, parentIds ); } - return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, indexMaxResultWindow, parentIds); + return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, methodParameters, indexMaxResultWindow, parentIds); } throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine"); } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 993f94bab5..b9116b0b15 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -59,6 +59,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; @@ -141,7 +142,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); float distance = 300000000000f; - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType, null); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType, null, null); // Delete index deleteKNNIndex(INDEX_NAME); @@ -201,7 +202,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF float score = 0.00001f; - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType, null); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType, null, null); // Delete index deleteKNNIndex(INDEX_NAME); @@ -261,7 +262,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMe float score = 5f; - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType, null); + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType, null, null); // Delete index deleteKNNIndex(INDEX_NAME); @@ -345,8 +346,11 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN assertEquals(testData.indexData.docs.length, getDocCount(indexName)); float distance = 300000000000f; + // create method parameter wih ef_search + Map methodParameters = new ImmutableMap.Builder().put(KNNConstants.METHOD_PARAMETER_EF_SEARCH, 150) + .build(); - validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType, null); + validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType, null, methodParameters); // Delete index deleteKNNIndex(indexName); @@ -381,7 +385,8 @@ public void testRadialQuery_withFilter_thenSuccess() { distance, null, SpaceType.L2, - termQueryBuilder + termQueryBuilder, + null ); assertEquals(1, queryResult.get(0).size()); @@ -1666,7 +1671,8 @@ private List> validateRadiusSearchResults( Float distanceThreshold, Float scoreThreshold, final SpaceType spaceType, - TermQueryBuilder filterQuery + TermQueryBuilder filterQuery, + Map methodParameters ) throws IOException, ParseException { List> queryResults = new ArrayList<>(); for (float[] queryVector : queryVectors) { @@ -1684,6 +1690,13 @@ private List> validateRadiusSearchResults( if (filterQuery != null) { queryBuilder.field("filter", filterQuery); } + if (methodParameters != null) { + queryBuilder.startObject(METHOD_PARAMETER); + for (Map.Entry entry : methodParameters.entrySet()) { + queryBuilder.field(entry.getKey(), entry.getValue()); + } + queryBuilder.endObject(); + } queryBuilder.endObject(); queryBuilder.endObject(); queryBuilder.endObject().endObject(); diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 874e96ba52..e05b90360c 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -36,6 +36,7 @@ import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; @@ -339,7 +340,7 @@ public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingFloatTyp final float distance = 3.5f; final int[] expectedResults = { 2, 3, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null, null); } public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingFloatType() throws Exception { @@ -351,7 +352,7 @@ public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingFloatType() final float score = 0.23f; final int[] expectedResults = { 2, 3, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null, null); } public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingFloatType() throws Exception { @@ -363,7 +364,7 @@ public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingFloa final float distance = 0.03f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null, null); } public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingFloatType() throws Exception { @@ -375,7 +376,7 @@ public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingFloatTy final float score = 0.97f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null, null); } public void testRadiusSearch_usingScoreThreshold_usingInnerProductMetrics_usingFloatType() throws Exception { @@ -387,7 +388,7 @@ public void testRadiusSearch_usingScoreThreshold_usingInnerProductMetrics_usingF final float score = 2f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.INNER_PRODUCT, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.INNER_PRODUCT, expectedResults, null, null, null); } public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingByteType() throws Exception { @@ -399,7 +400,7 @@ public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingByteType final float distance = 3.5f; final int[] expectedResults = { 2, 2, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null, null); } public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingByteType() throws Exception { @@ -411,7 +412,7 @@ public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingByteType() final float score = 0.23f; final int[] expectedResults = { 2, 2, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null, null); } public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingByteType() throws Exception { @@ -423,7 +424,7 @@ public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingByte final float distance = 0.05f; final int[] expectedResults = { 2, 2, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null, null); } public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingByteType() throws Exception { @@ -435,7 +436,7 @@ public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingByteTyp final float score = 0.97f; final int[] expectedResults = { 2, 2, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null, null); } public void testRadiusSearch_usingDistanceThreshold_withFilter_usingL2Metrics_usingFloatType() throws Exception { @@ -449,7 +450,7 @@ public void testRadiusSearch_usingDistanceThreshold_withFilter_usingL2Metrics_us final float distance = 45.0f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, COLOR_FIELD_NAME, "red"); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, COLOR_FIELD_NAME, "red", null); } public void testRadiusSearch_usingScoreThreshold_withFilter_usingCosineMetrics_usingFloatType() throws Exception { @@ -463,7 +464,7 @@ public void testRadiusSearch_usingScoreThreshold_withFilter_usingCosineMetrics_u final float score = 0.02f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, COLOR_FIELD_NAME, "red"); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, COLOR_FIELD_NAME, "red", null); } private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception { @@ -642,6 +643,25 @@ public void test_whenUsingIP_thenSuccess() { } } + @SneakyThrows + public void testRadialSearch_whenEfSearchIsSet_thenThrowException() { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float score = 0.23f; + final int[] expectedResults = { 2, 3, 2 }; + + Map methodParameters = new ImmutableMap.Builder().put(KNNConstants.METHOD_PARAMETER_EF_SEARCH, 150) + .build(); + + expectThrows( + ResponseException.class, + () -> validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null, methodParameters) + ); + } + private void validateRadiusSearchResults( final float[][] searchVectors, final Float distanceThreshold, @@ -649,7 +669,8 @@ private void validateRadiusSearchResults( final SpaceType spaceType, final int[] expectedResults, @Nullable final String filterField, - @Nullable final String filterValue + @Nullable final String filterValue, + @Nullable final Map methodParameters ) throws Exception { for (int i = 0; i < searchVectors.length; i++) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); @@ -670,6 +691,13 @@ private void validateRadiusSearchResults( builder.endObject(); builder.endObject(); } + if (methodParameters != null) { + builder.startObject(METHOD_PARAMETER); + for (Map.Entry entry : methodParameters.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } builder.endObject(); builder.endObject(); builder.endObject().endObject(); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 968754c761..0492297cd8 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -161,7 +161,7 @@ public void testEmptyVector() { public void testFromXContent() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).k(K).build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -199,18 +199,22 @@ public void testFromXContent_KnnWithMethodParameters() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() throws Exception { + public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_whenMethodParameter_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() .fieldName(FIELD_NAME) .vector(queryVector) .maxDistance(MAX_DISTANCE) + .methodParameters(HNSW_METHOD_PARAMS) .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -219,18 +223,22 @@ public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSuccee assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() throws Exception { + public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_whenMethodParameter_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() .fieldName(FIELD_NAME) .vector(queryVector) .minScore(MAX_DISTANCE) + .methodParameters(HNSW_METHOD_PARAMS) .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -246,7 +254,12 @@ public void testFromXContent_withFilter() throws Exception { knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -1165,4 +1178,58 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } } + + public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowException() { + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.L2, + new MethodComponentContext(org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of()) + ); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .maxDistance(MAX_DISTANCE) + .methodParameters(Map.of("ef_search", EF_SEARCH)) + .build(); + + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + Index dummyIndex = new Index("dummy", "dummy"); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + + public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.FAISS, + SpaceType.L2, + new MethodComponentContext(org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of()) + ); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .minScore(MIN_SCORE) + .methodParameters(Map.of("ef_search", EF_SEARCH)) + .build(); + + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + Index dummyIndex = new Index("dummy", "dummy"); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertEquals(1 / MIN_SCORE - 1, query.getRadius(), 0); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 0caf404014..dc34543680 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -63,7 +63,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyFloat; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -806,12 +805,29 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { final float radius = 0.5f; final int maxResults = 1000; jniServiceMockedStatic.when( - () -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt(), any(), anyInt(), any()) + () -> JNIService.radiusQueryIndex( + anyLong(), + eq(queryVector), + eq(radius), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(maxResults), + any(), + anyInt(), + any() + ) ).thenReturn(getKNNQueryResults()); KNNQuery.Context context = mock(KNNQuery.Context.class); when(context.getMaxResultWindow()).thenReturn(maxResults); - final KNNQuery query = new KNNQuery(FIELD_NAME, queryVector, INDEX_NAME, null).radius(radius).kNNQueryContext(context); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); final float boost = (float) randomDoubleBetween(0, 10, true); final KNNWeight knnWeight = new KNNWeight(query, boost); @@ -859,7 +875,17 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); assertNotNull(knnScorer); jniServiceMockedStatic.verify( - () -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt(), any(), anyInt(), any()) + () -> JNIService.radiusQueryIndex( + anyLong(), + eq(queryVector), + eq(radius), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(maxResults), + any(), + anyInt(), + any() + ) ); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); diff --git a/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java index 5492b85066..af415f9c52 100644 --- a/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java @@ -9,9 +9,11 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.apache.lucene.search.ByteVectorSimilarityQuery; @@ -38,6 +40,7 @@ public class RNNQueryFactoryTests extends KNNTestCase { private final String testFieldName = "test-field"; private final Float testRadius = 0.5f; private final int maxResultWindow = 20000; + private final Map methodParameters = Map.of(METHOD_PARAMETER_EF_SEARCH, 100); public void testCreate_whenLucene_withRadiusQuery_withFloatVector() { List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) @@ -106,12 +109,24 @@ public void testCreate_whenLucene_withFilter_thenSucceed() { } public void testCreate_whenFaiss_thenSucceed() { + // Given QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); MappedFieldType testMapper = mock(MappedFieldType.class); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); when(mockQueryShardContext.getIndexSettings().getMaxResultWindow()).thenReturn(maxResultWindow); + + final KNNQuery expectedQuery = KNNQuery.builder() + .field(testFieldName) + .queryVector(testQueryVector) + .indexName(testIndexName) + .radius(testRadius) + .methodParameters(methodParameters) + .context(new KNNQuery.Context(maxResultWindow)) + .build(); + + // When final RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() .knnEngine(KNNEngine.FAISS) .indexName(testIndexName) @@ -120,15 +135,12 @@ public void testCreate_whenFaiss_thenSucceed() { .radius(testRadius) .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) .context(mockQueryShardContext) + .methodParameters(methodParameters) .build(); Query query = RNNQueryFactory.create(createQueryRequest); - assertTrue(query instanceof KNNQuery); - assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); - assertEquals(testFieldName, ((KNNQuery) query).getField()); - assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); - assertEquals(testRadius, ((KNNQuery) query).getRadius(), 0); - assertEquals(maxResultWindow, ((KNNQuery) query).getContext().getMaxResultWindow()); + // Then + assertEquals(expectedQuery, query); } } diff --git a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java index 0aab780423..8e5ae24f9e 100644 --- a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java @@ -11,12 +11,7 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.KNNMethod; -import org.opensearch.knn.index.KNNMethodContext; -import org.opensearch.knn.index.MethodComponent; -import org.opensearch.knn.index.MethodComponentContext; -import org.opensearch.knn.index.Parameter; -import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.*; import java.io.IOException; import java.util.Collections; @@ -78,9 +73,13 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { assertNotNull(testAbstractKNNLibrary2.validateMethod(knnMethodContext2)); } - public void testEngineSpecificMethods() throws IOException { + public void testEngineSpecificMethods() { String methodName1 = "test-method-1"; - EngineSpecificMethodContext context = () -> Map.of("myparameter", new Parameter.BooleanParameter("myparameter", false, o -> o)); + QueryContext engineSpecificMethodContext = new QueryContext(VectorQueryType.K); + EngineSpecificMethodContext context = ctx -> ImmutableMap.of( + "myparameter", + new Parameter.BooleanParameter("myparameter", null, value -> true) + ); TestAbstractKNNLibrary testAbstractKNNLibrary1 = new TestAbstractKNNLibrary( Collections.emptyMap(), @@ -89,7 +88,11 @@ public void testEngineSpecificMethods() throws IOException { ); assertNotNull(testAbstractKNNLibrary1.getMethodContext(methodName1)); - assertTrue(testAbstractKNNLibrary1.getMethodContext(methodName1).supportedMethodParameters().containsKey("myparameter")); + assertTrue( + testAbstractKNNLibrary1.getMethodContext(methodName1) + .supportedMethodParameters(engineSpecificMethodContext) + .containsKey("myparameter") + ); } public void testGetMethodAsMap() {