From 307609bbbff4889ecdf76d615157e223023c6df7 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Wed, 17 Apr 2024 12:58:26 -0700 Subject: [PATCH] Add radial search feature to main branch (#1617) * Support radial search in k-NN plugin Signed-off-by: Junqiu Lei --- CHANGELOG.md | 1 + jni/include/faiss_wrapper.h | 12 + .../org_opensearch_knn_jni_FaissService.h | 8 + jni/src/faiss_wrapper.cpp | 44 ++ .../org_opensearch_knn_jni_FaissService.cpp | 14 + jni/tests/faiss_wrapper_test.cpp | 113 ++++ .../opensearch/knn/common/KNNConstants.java | 2 + .../org/opensearch/knn/index/IndexUtil.java | 2 + .../org/opensearch/knn/index/SpaceType.java | 18 + .../knn/index/query/BaseQueryFactory.java | 95 ++++ .../opensearch/knn/index/query/KNNQuery.java | 66 ++- .../knn/index/query/KNNQueryBuilder.java | 246 ++++++++- .../knn/index/query/KNNQueryFactory.java | 120 +---- .../opensearch/knn/index/query/KNNWeight.java | 32 +- .../knn/index/query/RNNQueryFactory.java | 136 +++++ .../org/opensearch/knn/index/util/Faiss.java | 34 +- .../opensearch/knn/index/util/KNNEngine.java | 11 + .../opensearch/knn/index/util/KNNLibrary.java | 20 + .../org/opensearch/knn/index/util/Lucene.java | 33 +- .../org/opensearch/knn/index/util/Nmslib.java | 9 + .../org/opensearch/knn/jni/FaissService.java | 11 + .../org/opensearch/knn/jni/JNIService.java | 23 + .../org/opensearch/knn/index/FaissIT.java | 318 +++++++++++ .../opensearch/knn/index/LuceneEngineIT.java | 205 ++++++- .../knn/index/query/KNNQueryBuilderTests.java | 505 +++++++++++++++++- .../knn/index/query/KNNWeightTests.java | 63 +++ .../knn/index/query/RNNQueryFactoryTests.java | 134 +++++ .../index/util/AbstractKNNLibraryTests.java | 9 + .../knn/index/util/LuceneTests.java | 10 +- .../knn/index/util/NativeLibraryTests.java | 10 + .../LibraryInitializedSupplierTests.java | 10 + .../org/opensearch/knn/KNNRestTestCase.java | 17 + 32 files changed, 2171 insertions(+), 160 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java create mode 100644 src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java create mode 100644 src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 91a37642c..3a9da1d49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x) ### Features +* Support radial search in k-NN plugin [#1617](https://github.com/opensearch-project/k-NN/pull/1617) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 3e1adeac4..da67c0f59 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -73,6 +73,18 @@ namespace knn_jni { // Return the serialized representation jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + + /* + * Perform a range search against the index located in memory at indexPointerJ. + * + * @param indexPointerJ - pointer to the index + * @param queryVectorJ - the query vector + * @param radiusJ - the radius for the range search + * @param maxResultsWindowJ - the maximum number of results to return + * @return an array of RangeQueryResults + */ + jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, + jfloat radiusJ, jint maxResultsWindowJ); } } diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 32b6f22f1..3715730ab 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -122,6 +122,14 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors (JNIEnv *, jclass, jlong, jobjectArray); +/* +* Class: org_opensearch_knn_jni_FaissService +* Method: rangeSearchIndex +* Signature: (J[F[F)J +*/ +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex + (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint); + #ifdef __cplusplus } #endif diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 817bdb816..983cfa8a9 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -587,3 +587,47 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) { throw std::runtime_error("Unable to extract IVFPQ index. IVFPQ index not present."); } + +jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, + jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) { + if (queryVectorJ == nullptr) { + throw std::runtime_error("Query Vector cannot be null"); + } + + auto *indexReader = reinterpret_cast(indexPointerJ); + + if (indexReader == nullptr) { + throw std::runtime_error("Invalid pointer to indexReader"); + } + + float *rawQueryVector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr); + + // The res will be freed by ~RangeSearchResult() in FAISS + // The second parameter is always true, as lims is allocated by FAISS + faiss::RangeSearchResult res(1, true); + indexReader->range_search(1, rawQueryVector, radiusJ, &res); + + // lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries), + // lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN, + // res.lims[0] is always 0, and res.lims[1] gives the total number of matching entries found. + int resultSize = res.lims[1]; + + // Limit the result size to maxResultWindowJ so that we don't return more than the max result window + // TODO: In the future, we should prevent this via FAISS's ResultHandler. + if (resultSize > maxResultWindowJ) { + resultSize = maxResultWindowJ; + } + + jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult"); + jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", ""); + + jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr); + + jobject result; + for(int i = 0; i < resultSize; ++i) { + result = jniUtil->NewObject(env, resultClass, allArgs, res.labels[i], res.distances[i]); + jniUtil->SetObjectArrayElement(env, results, i, result); + } + + return results; +} diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 3249ed872..ab2a37e84 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -190,3 +190,17 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors return (jlong) vect; } + +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls, + jlong indexPointerJ, + jfloatArray queryVectorJ, + jfloat radiusJ, jint maxResultWindowJ) +{ + try { + return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ); + + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 05854f7ed..07b34976f 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -615,3 +615,116 @@ TEST(FaissInitAndSetSharedIndexState, BasicAssertions) { ASSERT_EQ(1, ivfpqIndex->use_precomputed_table); knn_jni::faiss_wrapper::FreeSharedIndexState(sharedModelAddress); } + +TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 200; + int dim = 2; + std::vector ids = test_util::Range(numIds); + std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + float radius = 100000.0; + int numQueries = 2; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + int maxResultWindow = 20000; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + + knn_jni::faiss_wrapper::RangeSearch( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), radius, maxResultWindow))); + + // assert result size is not 0 + ASSERT_NE(0, results->size()); + + + // Need to free up each result + for (auto it : *results) { + delete it; + } + } +} + +TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ + // Define the index data + faiss::idx_t numIds = 200; + int dim = 2; + std::vector ids = test_util::Range(numIds); + std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + float radius = 100000.0; + int numQueries = 2; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + int maxResultWindow = 10; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + + knn_jni::faiss_wrapper::RangeSearch( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), radius, maxResultWindow))); + + // assert result size is not 0 + ASSERT_NE(0, results->size()); + // assert result size is equal to maxResultWindow + ASSERT_EQ(maxResultWindow, results->size()); + + // Need to free up each result + for (auto it : *results) { + delete it; + } + } +} diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 98622ee85..69a2bc806 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -67,6 +67,8 @@ public class KNNConstants { public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; + public static final String RADIAL_SEARCH_KEY = "radial_search"; + // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index adfa611c7..c71f767a5 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -43,11 +43,13 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0; private static final Map minimalRequiredVersionMap = new HashMap() { { put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); + put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); } }; diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 50d8d352c..240bfbe91 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -36,6 +36,14 @@ public float scoreTranslation(float rawScore) { public VectorSimilarityFunction getVectorSimilarityFunction() { return VectorSimilarityFunction.EUCLIDEAN; } + + @Override + public float scoreToDistanceTranslation(float score) { + if (score == 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "score cannot be 0 when space type is [%s]", getValue())); + } + return 1 / score - 1; + } }, COSINESIMIL("cosinesimil") { @Override @@ -170,4 +178,14 @@ public static SpaceType getSpace(String spaceTypeName) { } throw new IllegalArgumentException("Unable to find space: " + spaceTypeName); } + + /** + * Translate a score to a distance for this space type + * + * @param score score to translate + * @return translated distance + */ + public float scoreToDistanceTranslation(float score) { + throw new UnsupportedOperationException(String.format("Space [%s] does not have a score to distance translation", getValue())); + } } diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java new file mode 100644 index 000000000..3146cd33e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.search.NestedHelper; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.util.Optional; + +/** +* Base class for creating vector search queries. +*/ +@Log4j2 +public abstract class BaseQueryFactory { + /** + * DTO object to hold data required to create a Query instance. + */ + @AllArgsConstructor + @Builder + @Getter + public static class CreateQueryRequest { + @NonNull + private KNNEngine knnEngine; + @NonNull + private String indexName; + private String fieldName; + private float[] vector; + private byte[] byteVector; + private VectorDataType vectorDataType; + private Integer k; + private Float radius; + private QueryBuilder filter; + private QueryShardContext context; + + public Optional getFilter() { + return Optional.ofNullable(filter); + } + + public Optional getContext() { + return Optional.ofNullable(context); + } + } + + /** + * Creates a query filter. + * + * @param createQueryRequest request object that has all required fields to construct the query + * @return Lucene Query + */ + protected static Query getFilterQuery(BaseQueryFactory.CreateQueryRequest createQueryRequest) { + if (!createQueryRequest.getFilter().isPresent()) { + return null; + } + + final QueryShardContext queryShardContext = createQueryRequest.getContext() + .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + log.debug( + String.format( + "Creating query with filter for index [%s], field [%s]", + createQueryRequest.getIndexName(), + createQueryRequest.getFieldName() + ) + ); + final Query filterQuery; + try { + filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); + } catch (IOException e) { + throw new RuntimeException("Cannot create query with filter", e); + } + BitSetProducer parentFilter = queryShardContext.getParentFilter(); + if (parentFilter != null) { + boolean mightMatch = new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery); + if (mightMatch) { + return filterQuery; + } + return new ToChildBlockJoinQuery(filterQuery, parentFilter); + } + return filterQuery; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 9c78b18a1..0862b2d93 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -7,6 +7,8 @@ import java.util.Arrays; import java.util.Objects; + +import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; import org.apache.lucene.search.BooleanClause; @@ -30,7 +32,7 @@ public class KNNQuery extends Query { private final String field; private final float[] queryVector; - private final int k; + private int k; private final String indexName; @Getter @@ -38,6 +40,10 @@ public class KNNQuery extends Query { private Query filterQuery; @Getter private BitSetProducer parentsFilter; + @Getter + private Float radius = null; + @Getter + private Context context; public KNNQuery( final String field, @@ -69,6 +75,54 @@ public KNNQuery( this.parentsFilter = parentsFilter; } + /** + * Constructor for KNNQuery with query vector, index name and parent filter + * + * @param field field name + * @param queryVector query vector + * @param indexName index name + * @param parentsFilter parent filter + */ + public KNNQuery(String field, float[] queryVector, String indexName, BitSetProducer parentsFilter) { + this.field = field; + this.queryVector = queryVector; + this.indexName = indexName; + this.parentsFilter = parentsFilter; + } + + /** + * Constructor for KNNQuery with radius + * + * @param radius engine radius + * @return KNNQuery + */ + public KNNQuery radius(Float radius) { + this.radius = radius; + return this; + } + + /** + * Constructor for KNNQuery with Context + * + * @param context Context for KNNQuery + * @return KNNQuery + */ + public KNNQuery kNNQueryContext(Context context) { + this.context = context; + return this; + } + + /** + * Constructor for KNNQuery with filter query + * + * @param filterQuery filter query + * @return KNNQuery + */ + public KNNQuery filterQuery(Query filterQuery) { + this.filterQuery = filterQuery; + return this; + } + public String getField() { return this.field; } @@ -144,4 +198,14 @@ private boolean equalsTo(KNNQuery other) { && Objects.equals(indexName, other.indexName) && Objects.equals(filterQuery, other.filterQuery); } + + /** + * Context for KNNQuery + */ + @Setter + @Getter + @AllArgsConstructor + public static class Context { + int maxResultWindow; + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 2140487c5..7d3667ac0 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -7,23 +7,16 @@ import java.io.IOException; import java.util.Arrays; + import java.util.List; import java.util.Objects; + import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.MatchNoDocsQuery; -import org.apache.lucene.search.Query; -import org.opensearch.core.ParseField; -import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumberFieldMapper; -import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -32,9 +25,20 @@ import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.plugin.stats.KNNCounter; +import org.apache.lucene.search.Query; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryShardContext; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; +import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; /** * Helper class to build the KNN query @@ -47,6 +51,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); + public static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance"); + public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); public static final int K_MAX = 10000; /** * The name for the knn query @@ -58,11 +64,91 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; + private Float max_distance = null; + private Float min_score = null; private QueryBuilder filter; private boolean ignoreUnmapped = false; /** - * Constructs a new knn query + * Constructs a new query with the given field name and vector + * + * @param fieldName Name of the field + * @param vector Array of floating points + */ + public KNNQueryBuilder(String fieldName, float[] vector) { + if (Strings.isNullOrEmpty(fieldName)) { + throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); + } + if (vector == null) { + throw new IllegalArgumentException("[" + NAME + "] requires query vector"); + } + if (vector.length == 0) { + throw new IllegalArgumentException("[" + NAME + "] query vector is empty"); + } + this.fieldName = fieldName; + this.vector = vector; + } + + /** + * Builder method for k + * + * @param k K nearest neighbours for the given vector + */ + public KNNQueryBuilder k(Integer k) { + if (k == null) { + throw new IllegalArgumentException("[" + NAME + "] requires k to be set"); + } + validateSingleQueryType(k, max_distance, min_score); + if (k <= 0 || k > K_MAX) { + throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); + } + this.k = k; + return this; + } + + /** + * Builder method for max_distance + * + * @param max_distance the max_distance threshold for the nearest neighbours + */ + public KNNQueryBuilder maxDistance(Float max_distance) { + if (max_distance == null) { + throw new IllegalArgumentException("[" + NAME + "] requires max_distance to be set"); + } + validateSingleQueryType(k, max_distance, min_score); + this.max_distance = max_distance; + return this; + } + + /** + * Builder method for min_score + * + * @param min_score the min_score threshold for the nearest neighbours + */ + public KNNQueryBuilder minScore(Float min_score) { + if (min_score == null) { + throw new IllegalArgumentException("[" + NAME + "] requires min_score to be set"); + } + validateSingleQueryType(k, max_distance, min_score); + if (min_score <= 0) { + throw new IllegalArgumentException("[" + NAME + "] requires min_score greater than 0"); + } + this.min_score = min_score; + return this; + } + + /** + * Builder method for filter + * + * @param filter QueryBuilder + */ + public KNNQueryBuilder filter(QueryBuilder filter) { + this.filter = filter; + return this; + } + + /** + * Constructs a new query for top k search * * @param fieldName Name of the filed * @param vector Array of floating points @@ -94,6 +180,8 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.k = k; this.filter = filter; this.ignoreUnmapped = false; + this.max_distance = null; + this.min_score = null; } public static void initialize(ModelDao modelDao) { @@ -128,6 +216,12 @@ public KNNQueryBuilder(StreamInput in) throws IOException { if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { ignoreUnmapped = in.readOptionalBoolean(); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + max_distance = in.readOptionalFloat(); + } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + min_score = in.readOptionalFloat(); + } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } @@ -137,7 +231,9 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String fieldName = null; List vector = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; - int k = 0; + Integer k = null; + Float max_distance = null; + Float min_score = null; QueryBuilder filter = null; String queryName = null; String currentFieldName = null; @@ -166,6 +262,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); + } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + max_distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + min_score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else { throw new ParsingException( parser.getTokenLocation(), @@ -195,10 +295,21 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter); - knnQueryBuilder.ignoreUnmapped(ignoreUnmapped); - knnQueryBuilder.queryName(queryName); - knnQueryBuilder.boost(boost); + validateSingleQueryType(k, max_distance, min_score); + + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) + .ignoreUnmapped(ignoreUnmapped) + .boost(boost) + .queryName(queryName); + + if (k != null) { + knnQueryBuilder.k(k); + } else if (max_distance != null) { + knnQueryBuilder.maxDistance(max_distance); + } else if (min_score != null) { + knnQueryBuilder.minScore(min_score); + } + return knnQueryBuilder; } @@ -211,6 +322,12 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { out.writeOptionalBoolean(ignoreUnmapped); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + out.writeOptionalFloat(max_distance); + } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + out.writeOptionalFloat(min_score); + } } /** @@ -231,6 +348,14 @@ public int getK() { return this.k; } + public float getMaxDistance() { + return this.max_distance; + } + + public float getMinScore() { + return this.min_score; + } + public QueryBuilder getFilter() { return this.filter; } @@ -259,9 +384,15 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (filter != null) { builder.field(FILTER_FIELD.getPreferredName(), filter); } + if (max_distance != null) { + builder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance); + } if (ignoreUnmapped) { builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); } + if (min_score != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), min_score); + } printBoostAndQueryName(builder); builder.endObject(); builder.endObject(); @@ -298,6 +429,24 @@ protected Query doToQuery(QueryShardContext context) { } else if (knnMethodContext != null) { // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping knnEngine = knnMethodContext.getKnnEngine(); + spaceType = knnMethodContext.getSpaceType(); + } + + // Currently, k-NN supports distance and score types radial search + // We need transform distance/score to right type of engine required radius. + Float radius = null; + if (this.max_distance != null) { + if (this.max_distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType); + } + radius = knnEngine.distanceToRadialThreshold(this.max_distance, spaceType); + } + + if (this.min_score != null) { + if (this.min_score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType); + } + radius = knnEngine.scoreToRadialThreshold(this.min_score, spaceType); } if (fieldDimension != vector.length) { @@ -325,18 +474,39 @@ protected Query doToQuery(QueryShardContext context) { } String indexName = context.index().getName(); - KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() - .knnEngine(knnEngine) - .indexName(indexName) - .fieldName(this.fieldName) - .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) - .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) - .vectorDataType(vectorDataType) - .k(this.k) - .filter(this.filter) - .context(context) - .build(); - return KNNQueryFactory.create(createQueryRequest); + + if (k != 0) { + KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(this.fieldName) + .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) + .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) + .vectorDataType(vectorDataType) + .k(this.k) + .filter(this.filter) + .context(context) + .build(); + return KNNQueryFactory.create(createQueryRequest); + } + if (radius != null) { + if (!ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) { + throw new UnsupportedOperationException(String.format("Engine [%s] does not support radial search", knnEngine)); + } + RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(this.fieldName) + .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) + .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) + .vectorDataType(vectorDataType) + .radius(radius) + .filter(this.filter) + .context(context) + .build(); + return RNNQueryFactory.create(createQueryRequest); + } + throw new IllegalArgumentException("[" + NAME + "] requires either k or distance or score to be set"); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { @@ -371,4 +541,24 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static void validateSingleQueryType(Integer k, Float distance, Float score) { + int countSetFields = 0; + + if (k != null && k != 0) { + countSetFields++; + } + if (distance != null) { + countSetFields++; + } + if (score != null) { + countSetFields++; + } + + if (countSetFields != 1) { + throw new IllegalArgumentException( + "[" + NAME + "] requires only one query type to be set, it can be either k, distance, or score" + ); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 2ab0e62af..ec1f53d13 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -5,11 +5,6 @@ package org.opensearch.knn.index.query; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -17,16 +12,11 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; -import org.apache.lucene.search.join.ToChildBlockJoinQuery; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.search.NestedHelper; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; -import java.io.IOException; import java.util.Locale; -import java.util.Optional; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -35,7 +25,7 @@ * Creates the Lucene k-NN queries */ @Log4j2 -public class KNNQueryFactory { +public class KNNQueryFactory extends BaseQueryFactory { /** * Creates a Lucene query for a particular engine. @@ -82,7 +72,12 @@ public static Query create(CreateQueryRequest createQueryRequest) { final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); - BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter(); + BitSetProducer parentFilter = null; + if (createQueryRequest.getContext().isPresent()) { + QueryShardContext context = createQueryRequest.getContext().get(); + parentFilter = context.getParentFilter(); + } + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k); @@ -93,19 +88,21 @@ public static Query create(CreateQueryRequest createQueryRequest) { } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - if (VectorDataType.BYTE == vectorDataType) { - return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter); - } else if (VectorDataType.FLOAT == vectorDataType) { - return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter); - } else { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "Invalid value provided for [%s] field. Supported values are [%s]", - VECTOR_DATA_TYPE_FIELD, - SUPPORTED_VECTOR_DATA_TYPES - ) - ); + switch (vectorDataType) { + case BYTE: + return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter); + case FLOAT: + return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter); + default: + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s], but got: %s", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES, + vectorDataType + ) + ); } } @@ -144,77 +141,4 @@ private static Query getKnnFloatVectorQuery( return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter); } } - - private static Query getFilterQuery(CreateQueryRequest createQueryRequest) { - if (createQueryRequest.getFilter().isPresent()) { - final QueryShardContext queryShardContext = createQueryRequest.getContext() - .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); - log.debug( - String.format( - "Creating k-NN query with filter for index [%s], field [%s] and k [%d]", - createQueryRequest.getIndexName(), - createQueryRequest.fieldName, - createQueryRequest.k - ) - ); - final Query filterQuery; - try { - filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); - } catch (IOException e) { - throw new RuntimeException("Cannot create knn query with filter", e); - } - // If k-NN Field is nested field then parentFilter will not be null. This parentFilter is set by the - // Opensearch core. Ref PR: https://github.com/opensearch-project/OpenSearch/pull/10246 - if (queryShardContext.getParentFilter() != null) { - // if the filter is also a nested query clause then we should just return the same query without - // considering it to join with the parent documents. - if (new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery)) { - return filterQuery; - } - // This condition will be hit when filters are getting applied on the top level fields and k-nn - // query field is a nested field. In this case we need to wrap the filter query with - // ToChildBlockJoinQuery to ensure parent documents which will be retrieved from filters can be - // joined with the child documents containing vector field. - return new ToChildBlockJoinQuery(filterQuery, queryShardContext.getParentFilter()); - } - return filterQuery; - } - return null; - } - - /** - * DTO object to hold data required to create a Query instance. - */ - @AllArgsConstructor - @Builder - @Setter - static class CreateQueryRequest { - @Getter - @NonNull - private KNNEngine knnEngine; - @Getter - @NonNull - private String indexName; - @Getter - private String fieldName; - @Getter - private float[] vector; - @Getter - private byte[] byteVector; - @Getter - private VectorDataType vectorDataType; - @Getter - private int k; - private QueryBuilder filter; - // can be null in cases filter not passed with the knn query - private QueryShardContext context; - - public Optional getFilter() { - return Optional.ofNullable(filter); - } - - public Optional getContext() { - return Optional.ofNullable(context); - } - } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 06bf96d63..6b323e124 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -277,16 +277,25 @@ private Map doANNSearch(final LeafReaderContext context, final B throw new RuntimeException("Index has already been closed"); } int[] parentIds = getParentIdsArray(context); - results = JNIService.queryIndex( - indexAllocation.getMemoryAddress(), - knnQuery.getQueryVector(), - knnQuery.getK(), - knnEngine, - filterIds, - filterType.getValue(), - parentIds - ); - + if (knnQuery.getK() > 0) { + results = JNIService.queryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getQueryVector(), + knnQuery.getK(), + knnEngine, + filterIds, + filterType.getValue(), + parentIds + ); + } else { + results = JNIService.radiusQueryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getQueryVector(), + knnQuery.getRadius(), + knnEngine, + knnQuery.getContext().getMaxResultWindow() + ); + } } catch (Exception e) { GRAPH_QUERY_ERRORS.increment(); throw new RuntimeException(e); @@ -406,6 +415,9 @@ private boolean canDoExactSearch(final int filterIdsCount) { filterIdsCount, KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()) ); + if (knnQuery.getRadius() != null) { + return false; + } int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic if (filterIdsCount <= knnQuery.getK()) { diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java new file mode 100644 index 000000000..cd32ac4f3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; + +import java.util.Locale; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.ByteVectorSimilarityQuery; +import org.apache.lucene.search.FloatVectorSimilarityQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNEngine; + +/** + * Class to create radius nearest neighbor queries + */ +@Log4j2 +public class RNNQueryFactory extends BaseQueryFactory { + + /** + * Creates a Lucene query for a particular engine. + * + * @param knnEngine Engine to create the query for + * @param indexName Name of the OpenSearch index that is being queried + * @param fieldName Name of the field in the OpenSearch index that will be queried + * @param vector The query vector to get the nearest neighbors for + * @param radius the radius threshold for the nearest neighbors + * @return Lucene Query + */ + public static Query create( + KNNEngine knnEngine, + String indexName, + String fieldName, + float[] vector, + Float radius, + VectorDataType vectorDataType + ) { + final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(fieldName) + .vector(vector) + .vectorDataType(vectorDataType) + .radius(radius) + .build(); + return create(createQueryRequest); + } + + /** + * Creates a Lucene query for a particular engine. + * @param createQueryRequest request object that has all required fields to construct the query + * @return Lucene Query + */ + public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest) { + final String indexName = createQueryRequest.getIndexName(); + final String fieldName = createQueryRequest.getFieldName(); + final Float radius = createQueryRequest.getRadius(); + final float[] vector = createQueryRequest.getVector(); + final byte[] byteVector = createQueryRequest.getByteVector(); + final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); + final Query filterQuery = getFilterQuery(createQueryRequest); + + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + BitSetProducer parentFilter = null; + QueryShardContext context = createQueryRequest.getContext().get(); + + if (createQueryRequest.getContext().isPresent()) { + parentFilter = context.getParentFilter(); + } + IndexSettings indexSettings = context.getIndexSettings(); + KNNQuery.Context knnQueryContext = new KNNQuery.Context(indexSettings.getMaxResultWindow()); + KNNQuery rnnQuery = new KNNQuery(fieldName, vector, indexName, parentFilter).radius(radius).kNNQueryContext(knnQueryContext); + if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { + log.debug("Creating custom radius search with filters for index: {}, field: {} , r: {}", indexName, fieldName, radius); + rnnQuery.filterQuery(filterQuery); + } + log.debug( + String.format("Creating custom radius search for index: %s \"\", field: %s \"\", r: %f", indexName, fieldName, radius) + ); + return rnnQuery; + } + + log.debug(String.format("Creating Lucene r-NN query for index: %s \"\", field: %s \"\", k: %f", indexName, fieldName, radius)); + switch (vectorDataType) { + case BYTE: + return getByteVectorSimilarityQuery(fieldName, byteVector, radius, filterQuery); + case FLOAT: + return getFloatVectorSimilarityQuery(fieldName, vector, radius, filterQuery); + default: + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s], but got: %s", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES, + vectorDataType + ) + ); + } + } + + /** + * If radius is greater than 0, we return {@link FloatVectorSimilarityQuery} which will return all documents with similarity + * greater than or equal to the resultSimilarity. If filterQuery is not null, it will be used to filter the documents. + */ + private static Query getFloatVectorSimilarityQuery( + final String fieldName, + final float[] floatVector, + final float resultSimilarity, + final Query filterQuery + ) { + return new FloatVectorSimilarityQuery(fieldName, floatVector, resultSimilarity, filterQuery); + } + + /** + * If radius is greater than 0, we return {@link ByteVectorSimilarityQuery} which will return all documents with similarity + * greater than or equal to the resultSimilarity. If filterQuery is not null, it will be used to filter the documents. + */ + private static Query getByteVectorSimilarityQuery( + final String fieldName, + final byte[] byteVector, + final float resultSimilarity, + final Query filterQuery + ) { + return new ByteVectorSimilarityQuery(fieldName, byteVector, resultSimilarity, filterQuery); + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 563311c49..efd8a637c 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -56,6 +56,7 @@ * Implements NativeLibrary for the faiss native library */ class Faiss extends NativeLibrary { + Map> scoreTransform; // TODO: Current version is not really current version. Instead, it encodes information in the file name // about the compatibility version the file is created with. In the future, we should refactor this so that it @@ -68,6 +69,12 @@ class Faiss extends NativeLibrary { rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore) ); + // Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation: + // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces + private final static Map> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.< + SpaceType, + Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); + // Define encoders supported by faiss private final static MethodComponentContext ENCODER_DEFAULT = new MethodComponentContext( KNNConstants.ENCODER_FLAT, @@ -301,7 +308,13 @@ class Faiss extends NativeLibrary { ).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build() ); - final static Faiss INSTANCE = new Faiss(METHODS, SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION); + final static Faiss INSTANCE = new Faiss( + METHODS, + SCORE_TRANSLATIONS, + CURRENT_VERSION, + KNNConstants.FAISS_EXTENSION, + SCORE_TO_DISTANCE_TRANSFORMATIONS + ); /** * Constructor for Faiss @@ -315,9 +328,26 @@ private Faiss( Map methods, Map> scoreTranslation, String currentVersion, - String extension + String extension, + Map> scoreTransform ) { super(methods, scoreTranslation, currentVersion, extension); + this.scoreTransform = scoreTransform; + } + + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + // Faiss engine uses distance as is and does not need transformation + return distance; + } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + // Faiss engine uses distance as is and need transformation + if (this.scoreTransform.containsKey(spaceType)) { + return this.scoreTransform.get(spaceType).apply(score); + } + return spaceType.scoreToDistanceTranslation(score); } /** diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index 8d03d9a9e..e282c69db 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -32,6 +32,7 @@ public enum KNNEngine implements KNNLibrary { private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); + public static final Set ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, @@ -152,6 +153,16 @@ public float score(float rawScore, SpaceType spaceType) { return knnLibrary.score(rawScore, spaceType); } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return knnLibrary.distanceToRadialThreshold(distance, spaceType); + } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return knnLibrary.scoreToRadialThreshold(score, spaceType); + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return knnLibrary.validateMethod(knnMethodContext); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index ba1d3ac84..f837566b8 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -68,6 +68,26 @@ public interface KNNLibrary { */ float score(float rawScore, SpaceType spaceType); + /** + * Translate the distance radius input from end user to the engine's threshold. + * + * @param distance distance radius input from end user + * @param spaceType spaceType used to compute the radius + * + * @return transformed distance for the library + */ + Float distanceToRadialThreshold(Float distance, SpaceType spaceType); + + /** + * Translate the score threshold input from end user to the engine's threshold. + * + * @param score score threshold input from end user + * @param spaceType spaceType used to compute the threshold + * + * @return transformed score for the library + */ + Float scoreToRadialThreshold(Float score, SpaceType spaceType); + /** * Validate the knnMethodContext for the given library. A ValidationException should be thrown if the method is * deemed invalid. diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index 63642ae2c..630d7a2c2 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -15,6 +15,7 @@ import java.util.List; import java.util.Map; +import java.util.function.Function; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; @@ -25,6 +26,8 @@ */ public class Lucene extends JVMLibrary { + Map> distanceTransform; + final static Map METHODS = ImmutableMap.of( METHOD_HNSW, KNNMethod.Builder.builder( @@ -45,16 +48,27 @@ public class Lucene extends JVMLibrary { ).addSpaces(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() ); - final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString()); + // Map that overrides the default distance translations for Lucene, check more details in knn documentation: + // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces + private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< + SpaceType, + Function>builder() + .put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2) + .put(SpaceType.INNER_PRODUCT, distance -> distance <= 0 ? 1 / (1 - distance) : distance + 1) + .build(); + + final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS); /** * Constructor * * @param methods Map of k-NN methods that the library supports * @param version String representing version of library + * @param distanceTransform Map of space type to distance transformation function */ - Lucene(Map methods, String version) { + Lucene(Map methods, String version, Map> distanceTransform) { super(methods, version); + this.distanceTransform = distanceTransform; } @Override @@ -75,6 +89,21 @@ public float score(float rawScore, SpaceType spaceType) { return rawScore; } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + // Lucene requires score threshold to be parameterized when calling the radius search. + if (this.distanceTransform.containsKey(spaceType)) { + return this.distanceTransform.get(spaceType).apply(distance); + } + return spaceType.scoreTranslation(distance); + } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + // Lucene engine uses distance as is and does not need transformation + return score; + } + @Override public List mmapFileExtensions() { return List.of("vec", "vex"); diff --git a/src/main/java/org/opensearch/knn/index/util/Nmslib.java b/src/main/java/org/opensearch/knn/index/util/Nmslib.java index 617b311f4..64af43520 100644 --- a/src/main/java/org/opensearch/knn/index/util/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/util/Nmslib.java @@ -68,4 +68,13 @@ private Nmslib( ) { super(methods, scoreTranslation, currentVersion, extension); } + + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return distance; + } + + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return score; + } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 32516ef9d..b59ac4bcf 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -190,4 +190,15 @@ public static native KNNQueryResult[] queryIndexWithFilter( */ @Deprecated(since = "2.14.0", forRemoval = true) public static native long transferVectors(long vectorsPointer, float[][] trainingData); + + /** + * Range search index + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param indexMaxResultWindow maximum number of results to return + * @return KNNQueryResult array of neighbors within radius + */ + public static native KNNQueryResult[] rangeSearchIndex(long indexPointer, float[] queryVector, float radius, int indexMaxResultWindow); } diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 5a5b6794a..e846f02d1 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -262,4 +262,27 @@ public static byte[] trainIndex(Map indexParameters, int dimensi public static long transferVectors(long vectorsPointer, float[][] trainingData) { return FaissService.transferVectors(vectorsPointer, trainingData); } + + /** + * Range search index for a given query vector + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param knnEngine engine to query index + * @param indexMaxResultWindow maximum number of results to return + * @return KNNQueryResult array of neighbors within radius + */ + public static KNNQueryResult[] radiusQueryIndex( + long indexPointer, + float[] queryVector, + float radius, + KNNEngine knnEngine, + int indexMaxResultWindow + ) { + if (KNNEngine.FAISS == knnEngine) { + return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, indexMaxResultWindow); + } + throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine"); + } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 85dd3f169..bcefeb7f4 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -277,6 +277,283 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { fail("Graphs are not getting evicted"); } + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWFlat_thenSucceed() { + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float distance = 300000000000f; + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float score = 0.00001f; + + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.INNER_PRODUCT; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float score = 5f; + + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsHNSWPQ_thenSucceed() { + String indexName = "test-index"; + String fieldName = "test-field"; + String trainingIndexName = "training-index"; + String trainingFieldName = "training-field"; + + String modelId = "test-model"; + String modelDescription = "test model"; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + List pqMValues = ImmutableList.of(2, 4, 8); + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ + int trainingDataCount = 256; + + SpaceType spaceType = SpaceType.L2; + + int dimension = testData.indexData.vectors[0].length; + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqMValues.get(random().nextInt(pqMValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount); + assertTrainingSucceeds(modelId, 360, 1000); + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("model_id", modelId) + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + float distance = 300000000000f; + + validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType); + + // Delete index + deleteKNNIndex(indexName); + deleteModel(modelId); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; @@ -1410,4 +1687,45 @@ private void validateGraphEviction() throws Exception { fail("Graphs are not getting evicted"); } + + private void validateRadiusSearchResults( + String indexName, + String fieldName, + float[][] queryVectors, + Float distanceThreshold, + Float scoreThreshold, + final SpaceType spaceType + ) throws IOException, ParseException { + for (float[] queryVector : queryVectors) { + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilder.startObject("knn"); + queryBuilder.startObject(fieldName); + queryBuilder.field("vector", queryVector); + if (distanceThreshold != null) { + queryBuilder.field("max_distance", distanceThreshold); + } else if (scoreThreshold != null) { + queryBuilder.field("min_score", scoreThreshold); + } else { + throw new IllegalArgumentException("Invalid threshold"); + } + queryBuilder.endObject(); + queryBuilder.endObject(); + queryBuilder.endObject().endObject(); + final String responseBody = EntityUtils.toString(searchKNNIndex(indexName, queryBuilder, 10).getEntity()); + + List knnResults = parseSearchResponse(responseBody, fieldName); + + for (KNNResult knnResult : knnResults) { + float[] vector = knnResult.getVector(); + float distance = TestUtils.computeDistFromSpaceType(spaceType, vector, queryVector); + if (spaceType == SpaceType.L2) { + assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= distance); + } else if (spaceType == SpaceType.INNER_PRODUCT) { + assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= distance); + } else { + throw new IllegalArgumentException("Invalid space type"); + } + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index b17155704..ab55741d3 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -15,6 +15,7 @@ import org.junit.After; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; +import org.opensearch.common.Nullable; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.query.QueryBuilders; @@ -33,7 +34,6 @@ import java.util.function.Function; import java.util.stream.Collectors; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; @@ -50,6 +50,14 @@ public class LuceneEngineIT extends KNNRestTestCase { private static final int M = 16; private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; + private static final Float[][] TEST_COSINESIMIL_INDEX_VECTORS = { { 6.0f, 7.0f, 3.0f }, { 3.0f, 2.0f, 5.0f }, { 4.0f, 5.0f, 7.0f } }; + private static final Float[][] TEST_INNER_PRODUCT_INDEX_VECTORS = { + { 1.0f, 1.0f, 1.0f }, + { 2.0f, 2.0f, 2.0f }, + { 3.0f, 3.0f, 3.0f }, + { -1.0f, -1.0f, -1.0f }, + { -2.0f, -2.0f, -2.0f }, + { -3.0f, -3.0f, -3.0f } }; private static final float[][] TEST_QUERY_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; @@ -59,7 +67,9 @@ public class LuceneEngineIT extends KNNRestTestCase { VectorSimilarityFunction.DOT_PRODUCT, (similarity) -> (1 + similarity) / 2, VectorSimilarityFunction.COSINE, - (similarity) -> (1 + similarity) / 2 + (similarity) -> (1 + similarity) / 2, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + (similarity) -> similarity <= 0 ? 1 / (1 - similarity) : similarity + 1 ); private static final String DIMENSION_FIELD_NAME = "dimension"; private static final String KNN_VECTOR_TYPE = "knn_vector"; @@ -318,6 +328,142 @@ public void testIndexReopening() throws Exception { assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } + public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float distance = 3.5f; + final int[] expectedResults = { 2, 3, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float score = 0.23f; + final int[] expectedResults = { 2, 3, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float distance = 0.03f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float score = 0.97f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingInnerProductMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.INNER_PRODUCT, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INNER_PRODUCT_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INNER_PRODUCT_INDEX_VECTORS[j]); + } + + final float score = 2f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.INNER_PRODUCT, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.BYTE); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float distance = 3.5f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.BYTE); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float score = 0.23f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.BYTE); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float distance = 0.05f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.BYTE); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float score = 0.97f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_withFilter_usingL2Metrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f, 3.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + + refreshIndex(INDEX_NAME); + + final float distance = 45.0f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, COLOR_FIELD_NAME, "red"); + } + + public void testRadiusSearch_usingScoreThreshold_withFilter_usingCosineMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); + addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f, 3.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + + refreshIndex(INDEX_NAME); + + final float score = 0.02f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, COLOR_FIELD_NAME, "red"); + } + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() @@ -467,4 +613,59 @@ public void test_whenUsingIP_thenSuccess() { assertEquals(expectedScores.get(i), knnResults.get(i), 0.0000001); } } + + private void validateRadiusSearchResults( + final float[][] searchVectors, + final Float distanceThreshold, + final Float scoreThreshold, + final SpaceType spaceType, + final int[] expectedResults, + @Nullable final String filterField, + @Nullable final String filterValue + ) throws Exception { + for (int i = 0; i < searchVectors.length; i++) { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); + builder.startObject("knn"); + builder.startObject(FIELD_NAME); + builder.field("vector", searchVectors[i]); + if (distanceThreshold != null) { + builder.field("max_distance", distanceThreshold); + } else if (scoreThreshold != null) { + builder.field("min_score", scoreThreshold); + } else { + throw new IllegalArgumentException("Either distance or score must be provided"); + } + if (filterField != null && filterValue != null) { + builder.startObject("filter"); + builder.startObject("term"); + builder.field(filterField, filterValue); + builder.endObject(); + builder.endObject(); + } + builder.endObject(); + builder.endObject(); + builder.endObject().endObject(); + + final String responseBody = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, expectedResults[i]).getEntity()); + final List radiusResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedResults[i], radiusResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, FIELD_NAME); + for (KNNResult result : radiusResults) { + float[] vector = result.getVector(); + float distance = TestUtils.computeDistFromSpaceType(spaceType, vector, searchVectors[i]); + float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance); + if (spaceType == SpaceType.COSINESIMIL) { + distance = 1 - distance; + } + if (distanceThreshold != null) { + assertTrue(distance <= distanceThreshold); + } else { + assertTrue(rawScore >= scoreThreshold); + } + assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(radiusResults.indexOf(result)), 0.0001); + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index bcd784e23..1922e5a08 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import com.google.common.collect.ImmutableMap; +import org.apache.lucene.search.FloatVectorSimilarityQuery; import java.util.Locale; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; @@ -18,6 +19,7 @@ import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexSettings; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; @@ -41,8 +43,10 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.anyString; @@ -50,11 +54,14 @@ import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; +import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; + private static final Float MAX_DISTANCE = 1.0f; + private static final Float MIN_SCORE = 0.5f; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -77,6 +84,32 @@ public void testInvalidK() { expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, KNNQueryBuilder.K_MAX + K)); } + public void testInvalidDistance() { + float[] queryVector = { 1.0f, 1.0f }; + /** + * null distance + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(null)); + } + + public void testInvalidScore() { + float[] queryVector = { 1.0f, 1.0f }; + /** + * null min_score + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(null)); + + /** + * negative min_score + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(-1.0f)); + + /** + * min_score = 0 + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(0.0f)); + } + public void testEmptyVector() { /** * null query vector @@ -89,6 +122,18 @@ public void testEmptyVector() { */ float[] queryVector1 = {}; expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K)); + + /** + * null query vector with distance + */ + float[] queryVector2 = null; + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector2).maxDistance(MAX_DISTANCE)); + + /** + * empty query vector with distance + */ + float[] queryVector3 = {}; + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector3).maxDistance(MAX_DISTANCE)); } public void testFromXContent() throws Exception { @@ -107,7 +152,39 @@ public void testFromXContent() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_WithFilter() throws Exception { + public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MAX_DISTANCE); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_withFilter() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); @@ -129,7 +206,51 @@ public void testFromXContent_WithFilter() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_invalidQueryVectorType() throws Exception { + public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_wenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_InvalidQueryVectorType() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); @@ -157,6 +278,34 @@ public void testFromXContent_invalidQueryVectorType() throws Exception { assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers")); } + public void testFromXContent_whenDoRadiusSearch_whenInputInvalidQueryVectorType_thenException() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + List invalidTypeQueryVector = new ArrayList<>(); + invalidTypeQueryVector.add(1.5); + invalidTypeQueryVector.add(2.5); + invalidTypeQueryVector.add("a"); + invalidTypeQueryVector.add(null); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(FIELD_NAME); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), MAX_DISTANCE); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.fromXContent(contentParser) + ); + assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers")); + } + public void testFromXContent_missingQueryVector() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); @@ -231,6 +380,178 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2))); + } + + public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertTrue(query.toString().contains("resultSimilarity=" + 0.5f)); + } + + public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(negativeDistance, query.getRadius(), 0); + } + + public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + + public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float score = 5f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(1 - score, query.getRadius(), 0); + } + + public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupportedSpaceType_thenException() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float score = 5f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + + public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(negativeDistance, query.getRadius(), 0); + } + + public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -250,6 +571,42 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } + public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); + } + + public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); + } + public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -315,6 +672,70 @@ public void testDoToQuery_FromModel() { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + when(mockKNNVectorField.getDimension()).thenReturn(-K); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); + String modelId = "test-model-id"; + when(mockKNNVectorField.getModelId()).thenReturn(modelId); + + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + KNNQueryBuilder.initialize(modelDao); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertEquals(knnQueryBuilder.getMaxDistance(), query.getRadius(), 0); + assertEquals(knnQueryBuilder.fieldName(), query.getField()); + assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); + } + + public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + when(mockKNNVectorField.getDimension()).thenReturn(-K); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); + String modelId = "test-model-id"; + when(mockKNNVectorField.getModelId()).thenReturn(modelId); + + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + KNNQueryBuilder.initialize(modelDao); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(1 / knnQueryBuilder.getMinScore() - 1, query.getRadius(), 0); + assertEquals(knnQueryBuilder.fieldName(), query.getField()); + assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); + } + public void testDoToQuery_InvalidDimensions() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); @@ -383,17 +804,28 @@ public void testDoToQuery_InvalidZeroByteVector() { } public void testSerialization() throws Exception { - assertSerialization(Version.CURRENT, Optional.empty()); - - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY)); - - assertSerialization(Version.V_2_3_0, Optional.empty()); + // For k-NN search + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null); + + // For distance threshold search + assertSerialization(Version.CURRENT, Optional.empty(), null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, MAX_DISTANCE, null); + + // For score threshold search + assertSerialization(Version.CURRENT, Optional.empty(), null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MIN_SCORE); } - private void assertSerialization(final Version version, final Optional queryBuilderOptional) throws Exception { - final KNNQueryBuilder knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + private void assertSerialization( + final Version version, + final Optional queryBuilderOptional, + Integer k, + Float distance, + Float score + ) throws Exception { + final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance, score); final ClusterService clusterService = mockClusterService(version); @@ -412,7 +844,13 @@ private void assertSerialization(final Version version, final Optional queryBuilderOptional, Integer k, Float distance, Float score) { + final KNNQueryBuilder knnQueryBuilder; + if (k != null) { + knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k, queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k); + } else if (distance != null) { + knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance); + } else if (score != null) { + knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score); + } else { + throw new IllegalArgumentException("Either k or distance must be provided"); + } + return knnQueryBuilder; + } + public void testIgnoreUnmapped() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); @@ -434,4 +892,27 @@ public void testIgnoreUnmapped() throws IOException { knnQueryBuilder.ignoreUnmapped(false); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mock(QueryShardContext.class))); } + + public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { + List unsupportedEngines = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : unsupportedEngines) { + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.L2, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()) + ); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(MAX_DISTANCE); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + Index dummyIndex = new Index("dummy", "dummy"); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index da4b7093a..581d3c0b4 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -61,6 +61,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyFloat; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -691,6 +692,68 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { assertEquals(DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); } + @SneakyThrows + public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { + final float[] queryVector = new float[] { 0.1f, 0.3f }; + final float radius = 0.5f; + final int maxResults = 1000; + jniServiceMockedStatic.when(() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt())) + .thenReturn(getKNNQueryResults()); + KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + + final KNNQuery query = new KNNQuery(FIELD_NAME, queryVector, INDEX_NAME, null).radius(radius).kNNQueryContext(context); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, boost); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + true, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(SEGMENT_FILES_FAISS); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(Map.of(SPACE_TYPE, SpaceType.L2.getValue(), KNN_ENGINE, KNNEngine.FAISS.getName())); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + jniServiceMockedStatic.verify(() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt())); + + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + private SegmentReader getMockedSegmentReader() { final SegmentReader reader = mock(SegmentReader.class); when(reader.maxDoc()).thenReturn(1); diff --git a/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java new file mode 100644 index 000000000..5492b8506 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.lucene.search.ByteVectorSimilarityQuery; +import org.apache.lucene.search.FloatVectorSimilarityQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNEngine; + +public class RNNQueryFactoryTests extends KNNTestCase { + private static final String FILTER_FILED_NAME = "foo"; + private static final String FILTER_FILED_VALUE = "fooval"; + private static final QueryBuilder FILTER_QUERY_BUILDER = new TermQueryBuilder(FILTER_FILED_NAME, FILTER_FILED_VALUE); + private final int testQueryDimension = 17; + private final float[] testQueryVector = new float[testQueryDimension]; + private final byte[] testByteQueryVector = new byte[testQueryDimension]; + private final String testIndexName = "test-index"; + private final String testFieldName = "test-field"; + private final Float testRadius = 0.5f; + private final int maxResultWindow = 20000; + + public void testCreate_whenLucene_withRadiusQuery_withFloatVector() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + Query query = RNNQueryFactory.create( + knnEngine, + testIndexName, + testFieldName, + testQueryVector, + testRadius, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); + assertEquals(FloatVectorSimilarityQuery.class, query.getClass()); + } + } + + public void testCreate_whenLucene_withRadiusQuery_withByteVector() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + final RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .radius(testRadius) + .byteVector(testByteQueryVector) + .vectorDataType(VectorDataType.BYTE) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + Query query = RNNQueryFactory.create(createQueryRequest); + assertEquals(ByteVectorSimilarityQuery.class, query.getClass()); + } + } + + public void testCreate_whenLucene_withFilter_thenSucceed() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + final RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .radius(testRadius) + .build(); + Query query = RNNQueryFactory.create(createQueryRequest); + assertEquals(FloatVectorSimilarityQuery.class, query.getClass()); + } + } + + public void testCreate_whenFaiss_thenSucceed() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + when(mockQueryShardContext.getIndexSettings().getMaxResultWindow()).thenReturn(maxResultWindow); + final RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.FAISS) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .radius(testRadius) + .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) + .context(mockQueryShardContext) + .build(); + + Query query = RNNQueryFactory.create(createQueryRequest); + + assertTrue(query instanceof KNNQuery); + assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); + assertEquals(testFieldName, ((KNNQuery) query).getField()); + assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); + assertEquals(testRadius, ((KNNQuery) query).getRadius(), 0); + assertEquals(maxResultWindow, ((KNNQuery) query).getContext().getMaxResultWindow()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java index 916e87414..9e6bd67ea 100644 --- a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java @@ -127,6 +127,15 @@ public float score(float rawScore, SpaceType spaceType) { return 0; } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return 0f; + } + + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0f; + } + @Override public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { return 0; diff --git a/src/test/java/org/opensearch/knn/index/util/LuceneTests.java b/src/test/java/org/opensearch/knn/index/util/LuceneTests.java index 6de46b52d..c9ffd13b2 100644 --- a/src/test/java/org/opensearch/knn/index/util/LuceneTests.java +++ b/src/test/java/org/opensearch/knn/index/util/LuceneTests.java @@ -99,28 +99,28 @@ public void testLucenHNSWMethod() throws IOException { } public void testGetExtension() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); expectThrows(UnsupportedOperationException.class, luceneLibrary::getExtension); } public void testGetCompundExtension() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); expectThrows(UnsupportedOperationException.class, luceneLibrary::getCompoundExtension); } public void testScore() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); float rawScore = 10.0f; assertEquals(rawScore, luceneLibrary.score(rawScore, SpaceType.DEFAULT), 0.001); } public void testIsInitialized() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); assertFalse(luceneLibrary.isInitialized()); } public void testSetInitialized() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); luceneLibrary.setInitialized(true); assertTrue(luceneLibrary.isInitialized()); } diff --git a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java index 00a628f1e..3c3afbee6 100644 --- a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java @@ -64,5 +64,15 @@ public TestNativeLibrary( ) { super(methods, scoreTranslation, currentVersion, extension); } + + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return 0.0f; + } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0.0f; + } } } diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 64cf86381..cff4d5805 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -63,6 +63,16 @@ public float score(float rawScore, SpaceType spaceType) { return 0; } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return 0.0f; + } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0.0f; + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return null; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 68255388b..396c8ea64 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -202,6 +202,23 @@ protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, return response; } + /** + * Run KNN Search on Index with XContentBuilder query + */ + protected Response searchKNNIndex(String index, XContentBuilder xContentBuilder, int resultSize) throws IOException { + Request request = new Request("POST", "/" + index + "/_search"); + request.setJsonEntity(xContentBuilder.toString()); + + request.addParameter("size", Integer.toString(resultSize)); + request.addParameter("explain", Boolean.toString(true)); + request.addParameter("search_type", "query_then_fetch"); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + return response; + } + /** * Run exists search */