From 36dc0bca1670c3b31ce52bd498740b357b224f61 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Tue, 12 Mar 2024 11:08:40 -0700 Subject: [PATCH 1/6] Support distance type radius search for Lucene engine (#1498) Signed-off-by: Junqiu Lei --- CHANGELOG.md | 1 + .../opensearch/knn/common/KNNConstants.java | 2 + .../org/opensearch/knn/index/IndexUtil.java | 2 + .../knn/index/query/BaseQueryFactory.java | 95 +++++++++ .../knn/index/query/KNNQueryBuilder.java | 161 ++++++++++++-- .../knn/index/query/KNNQueryFactory.java | 120 ++--------- .../knn/index/query/RNNQueryFactory.java | 113 ++++++++++ .../org/opensearch/knn/index/util/Faiss.java | 6 + .../opensearch/knn/index/util/KNNEngine.java | 6 + .../opensearch/knn/index/util/KNNLibrary.java | 10 + .../org/opensearch/knn/index/util/Lucene.java | 22 +- .../org/opensearch/knn/index/util/Nmslib.java | 5 + .../opensearch/knn/index/LuceneEngineIT.java | 110 +++++++++- .../knn/index/query/KNNQueryBuilderTests.java | 198 +++++++++++++++++- .../knn/index/query/RNNQueryFactoryTests.java | 105 ++++++++++ .../index/util/AbstractKNNLibraryTests.java | 5 + .../knn/index/util/LuceneTests.java | 10 +- .../knn/index/util/NativeLibraryTests.java | 5 + .../LibraryInitializedSupplierTests.java | 5 + .../org/opensearch/knn/KNNRestTestCase.java | 17 ++ 20 files changed, 862 insertions(+), 136 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java create mode 100644 src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java create mode 100644 src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 91a37642c..ecea33b1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x) ### Features +* Support distance type radius search for Lucene engine 
[#1498](https://github.com/opensearch-project/k-NN/pull/1498) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 98622ee85..69a2bc806 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -67,6 +67,8 @@ public class KNNConstants { public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; + public static final String RADIAL_SEARCH_KEY = "radial_search"; + // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index adfa611c7..cd4ca7822 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -43,11 +43,13 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_13_0; private static final Map minimalRequiredVersionMap = new HashMap() { { put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, 
MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); + put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); } }; diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java new file mode 100644 index 000000000..3146cd33e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.search.NestedHelper; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; +import java.util.Optional; + +/** +* Base class for creating vector search queries. +*/ +@Log4j2 +public abstract class BaseQueryFactory { + /** + * DTO object to hold data required to create a Query instance. + */ + @AllArgsConstructor + @Builder + @Getter + public static class CreateQueryRequest { + @NonNull + private KNNEngine knnEngine; + @NonNull + private String indexName; + private String fieldName; + private float[] vector; + private byte[] byteVector; + private VectorDataType vectorDataType; + private Integer k; + private Float radius; + private QueryBuilder filter; + private QueryShardContext context; + + public Optional getFilter() { + return Optional.ofNullable(filter); + } + + public Optional getContext() { + return Optional.ofNullable(context); + } + } + + /** + * Creates a query filter. 
+ * + * @param createQueryRequest request object that has all required fields to construct the query + * @return Lucene Query + */ + protected static Query getFilterQuery(BaseQueryFactory.CreateQueryRequest createQueryRequest) { + if (!createQueryRequest.getFilter().isPresent()) { + return null; + } + + final QueryShardContext queryShardContext = createQueryRequest.getContext() + .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); + log.debug( + String.format( + "Creating query with filter for index [%s], field [%s]", + createQueryRequest.getIndexName(), + createQueryRequest.getFieldName() + ) + ); + final Query filterQuery; + try { + filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); + } catch (IOException e) { + throw new RuntimeException("Cannot create query with filter", e); + } + BitSetProducer parentFilter = queryShardContext.getParentFilter(); + if (parentFilter != null) { + boolean mightMatch = new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery); + if (mightMatch) { + return filterQuery; + } + return new ToChildBlockJoinQuery(filterQuery, parentFilter); + } + return filterQuery; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 2140487c5..2dd21ed5a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -7,6 +7,7 @@ import java.io.IOException; import java.util.Arrays; + import java.util.List; import java.util.Objects; import lombok.extern.log4j.Log4j2; @@ -24,6 +25,7 @@ import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; 
import org.opensearch.knn.index.VectorDataType; @@ -35,6 +37,7 @@ import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; +import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; /** * Helper class to build the KNN query @@ -47,6 +50,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); + public static final ParseField DISTANCE_FIELD = new ParseField("distance"); public static final int K_MAX = 10000; /** * The name for the knn query @@ -58,11 +62,74 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; + private Float distance = null; private QueryBuilder filter; private boolean ignoreUnmapped = false; /** - * Constructs a new knn query + * Constructs a new query with the given field name and vector + * + * @param fieldName Name of the field + * @param vector Array of floating points + */ + public KNNQueryBuilder(String fieldName, float[] vector) { + if (Strings.isNullOrEmpty(fieldName)) { + throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); + } + if (vector == null) { + throw new IllegalArgumentException("[" + NAME + "] requires query vector"); + } + if (vector.length == 0) { + throw new IllegalArgumentException("[" + NAME + "] query vector is empty"); + } + this.fieldName = fieldName; + this.vector = vector; + } + + /** + * Builder method for k + * + * @param k K nearest neighbours for the given vector + */ + public KNNQueryBuilder k(int k) { + if (k <= 0 || k > K_MAX) { + throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); + } + if (distance != 
null) { + throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); + } + this.k = k; + return this; + } + + /** + * Builder method for distance + * + * @param distance the distance threshold for the nearest neighbours + */ + public KNNQueryBuilder distance(Float distance) { + if (distance == null || distance < 0) { + throw new IllegalArgumentException("[" + NAME + "] requires distance >= 0"); + } + if (k != 0) { + throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); + } + this.distance = distance; + return this; + } + + /** + * Builder method for filter + * + * @param filter QueryBuilder + */ + public KNNQueryBuilder filter(QueryBuilder filter) { + this.filter = filter; + return this; + } + + /** + * Constructs a new query for top k search * * @param fieldName Name of the filed * @param vector Array of floating points @@ -94,6 +161,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.k = k; this.filter = filter; this.ignoreUnmapped = false; + this.distance = null; } public static void initialize(ModelDao modelDao) { @@ -128,6 +196,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException { if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { ignoreUnmapped = in.readOptionalBoolean(); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + distance = in.readOptionalFloat(); + } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } @@ -137,7 +208,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep String fieldName = null; List vector = null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; - int k = 0; + Integer k = null; + Float distance = null; QueryBuilder filter = null; String queryName = null; String currentFieldName = null; @@ -166,6 +238,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws 
IOExcep } } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); + } else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else { throw new ParsingException( parser.getTokenLocation(), @@ -195,10 +269,21 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector), k, filter); - knnQueryBuilder.ignoreUnmapped(ignoreUnmapped); - knnQueryBuilder.queryName(queryName); - knnQueryBuilder.boost(boost); + if ((k != null && distance != null) || (k == null && distance == null)) { + throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); + } + + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) + .ignoreUnmapped(ignoreUnmapped) + .boost(boost) + .queryName(queryName); + + if (k != null) { + knnQueryBuilder.k(k); + } else { + knnQueryBuilder.distance(distance); + } + return knnQueryBuilder; } @@ -211,6 +296,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { out.writeOptionalBoolean(ignoreUnmapped); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + out.writeOptionalFloat(distance); + } } /** @@ -231,6 +319,10 @@ public int getK() { return this.k; } + public float getDistance() { + return this.distance; + } + public QueryBuilder getFilter() { return this.filter; } @@ -259,6 +351,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (filter != null) { builder.field(FILTER_FIELD.getPreferredName(), filter); } + if (distance != null) { + builder.field(DISTANCE_FIELD.getPreferredName(), distance); + } if (ignoreUnmapped) { 
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); } @@ -298,6 +393,14 @@ protected Query doToQuery(QueryShardContext context) { } else if (knnMethodContext != null) { // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping knnEngine = knnMethodContext.getKnnEngine(); + spaceType = knnMethodContext.getSpaceType(); + } + + // Currently, k-NN supports distance type radius search. + // We need transform distance radius to right type of engine required radius. + Float radius = null; + if (this.distance != null) { + radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType); } if (fieldDimension != vector.length) { @@ -325,18 +428,40 @@ protected Query doToQuery(QueryShardContext context) { } String indexName = context.index().getName(); - KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() - .knnEngine(knnEngine) - .indexName(indexName) - .fieldName(this.fieldName) - .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) - .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) - .vectorDataType(vectorDataType) - .k(this.k) - .filter(this.filter) - .context(context) - .build(); - return KNNQueryFactory.create(createQueryRequest); + + if (k != 0) { + KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(this.fieldName) + .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) + .byteVector(VectorDataType.BYTE == vectorDataType ? 
byteVector : null) + .vectorDataType(vectorDataType) + .k(this.k) + .filter(this.filter) + .context(context) + .build(); + return KNNQueryFactory.create(createQueryRequest); + } + if (radius != null) { + if (!ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) { + throw new UnsupportedOperationException(String.format("Engine [%s] does not support radial search", knnEngine)); + } + RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(this.fieldName) + .vector(VectorDataType.FLOAT == vectorDataType ? this.vector : null) + .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) + .vectorDataType(vectorDataType) + .radius(radius) + .filter(this.filter) + .context(context) + .radius(radius) + .build(); + return RNNQueryFactory.create(createQueryRequest); + } + throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 2ab0e62af..ec1f53d13 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -5,11 +5,6 @@ package org.opensearch.knn.index.query; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; @@ -17,16 +12,11 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; -import 
org.apache.lucene.search.join.ToChildBlockJoinQuery; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.search.NestedHelper; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; -import java.io.IOException; import java.util.Locale; -import java.util.Optional; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -35,7 +25,7 @@ * Creates the Lucene k-NN queries */ @Log4j2 -public class KNNQueryFactory { +public class KNNQueryFactory extends BaseQueryFactory { /** * Creates a Lucene query for a particular engine. @@ -82,7 +72,12 @@ public static Query create(CreateQueryRequest createQueryRequest) { final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); - BitSetProducer parentFilter = createQueryRequest.context == null ? 
null : createQueryRequest.context.getParentFilter(); + BitSetProducer parentFilter = null; + if (createQueryRequest.getContext().isPresent()) { + QueryShardContext context = createQueryRequest.getContext().get(); + parentFilter = context.getParentFilter(); + } + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k); @@ -93,19 +88,21 @@ public static Query create(CreateQueryRequest createQueryRequest) { } log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - if (VectorDataType.BYTE == vectorDataType) { - return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter); - } else if (VectorDataType.FLOAT == vectorDataType) { - return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter); - } else { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "Invalid value provided for [%s] field. Supported values are [%s]", - VECTOR_DATA_TYPE_FIELD, - SUPPORTED_VECTOR_DATA_TYPES - ) - ); + switch (vectorDataType) { + case BYTE: + return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter); + case FLOAT: + return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter); + default: + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] field. 
Supported values are [%s], but got: %s", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES, + vectorDataType + ) + ); } } @@ -144,77 +141,4 @@ private static Query getKnnFloatVectorQuery( return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter); } } - - private static Query getFilterQuery(CreateQueryRequest createQueryRequest) { - if (createQueryRequest.getFilter().isPresent()) { - final QueryShardContext queryShardContext = createQueryRequest.getContext() - .orElseThrow(() -> new RuntimeException("Shard context cannot be null")); - log.debug( - String.format( - "Creating k-NN query with filter for index [%s], field [%s] and k [%d]", - createQueryRequest.getIndexName(), - createQueryRequest.fieldName, - createQueryRequest.k - ) - ); - final Query filterQuery; - try { - filterQuery = createQueryRequest.getFilter().get().toQuery(queryShardContext); - } catch (IOException e) { - throw new RuntimeException("Cannot create knn query with filter", e); - } - // If k-NN Field is nested field then parentFilter will not be null. This parentFilter is set by the - // Opensearch core. Ref PR: https://github.com/opensearch-project/OpenSearch/pull/10246 - if (queryShardContext.getParentFilter() != null) { - // if the filter is also a nested query clause then we should just return the same query without - // considering it to join with the parent documents. - if (new NestedHelper(queryShardContext.getMapperService()).mightMatchNestedDocs(filterQuery)) { - return filterQuery; - } - // This condition will be hit when filters are getting applied on the top level fields and k-nn - // query field is a nested field. In this case we need to wrap the filter query with - // ToChildBlockJoinQuery to ensure parent documents which will be retrieved from filters can be - // joined with the child documents containing vector field. 
- return new ToChildBlockJoinQuery(filterQuery, queryShardContext.getParentFilter()); - } - return filterQuery; - } - return null; - } - - /** - * DTO object to hold data required to create a Query instance. - */ - @AllArgsConstructor - @Builder - @Setter - static class CreateQueryRequest { - @Getter - @NonNull - private KNNEngine knnEngine; - @Getter - @NonNull - private String indexName; - @Getter - private String fieldName; - @Getter - private float[] vector; - @Getter - private byte[] byteVector; - @Getter - private VectorDataType vectorDataType; - @Getter - private int k; - private QueryBuilder filter; - // can be null in cases filter not passed with the knn query - private QueryShardContext context; - - public Optional getFilter() { - return Optional.ofNullable(filter); - } - - public Optional getContext() { - return Optional.ofNullable(context); - } - } } diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java new file mode 100644 index 000000000..69018fa73 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; + +import java.util.Locale; + +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.ByteVectorSimilarityQuery; +import org.apache.lucene.search.FloatVectorSimilarityQuery; +import org.apache.lucene.search.Query; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNEngine; + +/** + * Class to create radius nearest neighbor queries + */ +@Log4j2 +public class RNNQueryFactory extends BaseQueryFactory { + + /** + * Creates a Lucene query for a particular engine. 
+ * + * @param knnEngine Engine to create the query for + * @param indexName Name of the OpenSearch index that is being queried + * @param fieldName Name of the field in the OpenSearch index that will be queried + * @param vector The query vector to get the nearest neighbors for + * @param radius the radius threshold for the nearest neighbors + * @return Lucene Query + */ + public static Query create( + KNNEngine knnEngine, + String indexName, + String fieldName, + float[] vector, + Float radius, + VectorDataType vectorDataType + ) { + final CreateQueryRequest createQueryRequest = CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(indexName) + .fieldName(fieldName) + .vector(vector) + .vectorDataType(vectorDataType) + .radius(radius) + .build(); + return create(createQueryRequest); + } + + /** + * Creates a Lucene query for a particular engine. + * @param createQueryRequest request object that has all required fields to construct the query + * @return Lucene Query + */ + public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest) { + final String indexName = createQueryRequest.getIndexName(); + final String fieldName = createQueryRequest.getFieldName(); + final Float radius = createQueryRequest.getRadius(); + final float[] vector = createQueryRequest.getVector(); + final byte[] byteVector = createQueryRequest.getByteVector(); + final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); + final Query filterQuery = getFilterQuery(createQueryRequest); + + log.debug(String.format("Creating Lucene r-NN query for index: %s \"\", field: %s \"\", k: %f", indexName, fieldName, radius)); + switch (vectorDataType) { + case BYTE: + return getByteVectorSimilarityQuery(fieldName, byteVector, radius, filterQuery); + case FLOAT: + return getFloatVectorSimilarityQuery(fieldName, vector, radius, filterQuery); + default: + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Invalid value provided for [%s] 
field. Supported values are [%s], but got: %s", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES, + vectorDataType + ) + ); + } + } + + /** + * If radius is greater than 0, we return {@link FloatVectorSimilarityQuery} which will return all documents with similarity + * greater than or equal to the resultSimilarity. If filterQuery is not null, it will be used to filter the documents. + */ + private static Query getFloatVectorSimilarityQuery( + final String fieldName, + final float[] floatVector, + final float resultSimilarity, + final Query filterQuery + ) { + return new FloatVectorSimilarityQuery(fieldName, floatVector, resultSimilarity, filterQuery); + } + + /** + * If radius is greater than 0, we return {@link ByteVectorSimilarityQuery} which will return all documents with similarity + * greater than or equal to the resultSimilarity. If filterQuery is not null, it will be used to filter the documents. + */ + private static Query getByteVectorSimilarityQuery( + final String fieldName, + final byte[] byteVector, + final float resultSimilarity, + final Query filterQuery + ) { + return new ByteVectorSimilarityQuery(fieldName, byteVector, resultSimilarity, filterQuery); + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 563311c49..8b58b7e74 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -320,6 +320,12 @@ private Faiss( super(methods, scoreTranslation, currentVersion, extension); } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + // Faiss engine uses distance as is and does not need transformation + return distance; + } + /** * MethodAsMap builder is used to create the map that will be passed to the jni to create the faiss index. * Faiss's index factory takes an "index description" that it uses to build the index. 
In this description, diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index 8d03d9a9e..5d334e7ef 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -32,6 +32,7 @@ public enum KNNEngine implements KNNLibrary { private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); + public static final Set ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, @@ -152,6 +153,11 @@ public float score(float rawScore, SpaceType spaceType) { return knnLibrary.score(rawScore, spaceType); } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return knnLibrary.distanceToRadialThreshold(distance, spaceType); + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return knnLibrary.validateMethod(knnMethodContext); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index ba1d3ac84..2c1f454c0 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -68,6 +68,16 @@ public interface KNNLibrary { */ float score(float rawScore, SpaceType spaceType); + /** + * Translate the distance radius input from end user to the engine's threshold. + * + * @param distance distance radius input from end user + * @param spaceType spaceType used to compute the radius + * + * @return transformed distance for the library + */ + Float distanceToRadialThreshold(Float distance, SpaceType spaceType); + /** * Validate the knnMethodContext for the given library. 
A ValidationException should be thrown if the method is * deemed invalid. diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index 63642ae2c..54a752408 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -15,6 +15,7 @@ import java.util.List; import java.util.Map; +import java.util.function.Function; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; @@ -25,6 +26,8 @@ */ public class Lucene extends JVMLibrary { + Map<SpaceType, Function<Float, Float>> distanceTransform; + final static Map METHODS = ImmutableMap.of( METHOD_HNSW, KNNMethod.Builder.builder( @@ -45,16 +48,25 @@ public class Lucene extends JVMLibrary { ).addSpaces(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() ); - final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString()); + private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.< + SpaceType, + Function<Float, Float>>builder() + .put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2) + .put(SpaceType.L2, distance -> 1 / (1 + distance)) + .build(); + + final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS); /** * Constructor * * @param methods Map of k-NN methods that the library supports * @param version String representing version of library + * @param distanceTransform Map of space type to distance transformation function */ - Lucene(Map methods, String version) { + Lucene(Map methods, String version, Map<SpaceType, Function<Float, Float>> distanceTransform) { super(methods, version); + this.distanceTransform = distanceTransform; } @Override @@ -75,6 +87,12 @@ public float score(float rawScore, SpaceType spaceType) { return rawScore; } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + // Lucene requires score threshold to be 
parameterized when calling the radius search. + return this.distanceTransform.get(spaceType).apply(distance); + } + @Override public List mmapFileExtensions() { return List.of("vec", "vex"); diff --git a/src/main/java/org/opensearch/knn/index/util/Nmslib.java b/src/main/java/org/opensearch/knn/index/util/Nmslib.java index 617b311f4..8993cac8e 100644 --- a/src/main/java/org/opensearch/knn/index/util/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/util/Nmslib.java @@ -68,4 +68,9 @@ private Nmslib( ) { super(methods, scoreTranslation, currentVersion, extension); } + + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return distance; + } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index b17155704..8f55730fd 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -15,6 +15,7 @@ import org.junit.After; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; +import org.opensearch.common.Nullable; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.query.QueryBuilders; @@ -33,7 +34,6 @@ import java.util.function.Function; import java.util.stream.Collectors; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; @@ -50,7 +50,7 @@ public class LuceneEngineIT extends KNNRestTestCase { private static final int M = 16; private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; - + private static final Float[][] TEST_COSINESIMIL_INDEX_VECTORS = { { 6.0f, 7.0f, 3.0f 
}, { 3.0f, 2.0f, 5.0f }, { 4.0f, 5.0f, 7.0f } }; private static final float[][] TEST_QUERY_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; private static final Map> VECTOR_SIMILARITY_TO_SCORE = ImmutableMap.of( @@ -318,6 +318,68 @@ public void testIndexReopening() throws Exception { assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } + public void testRadiusSearch_usingL2Metrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float radius = 3.5f; + final int[] expectedResults = { 2, 3, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingCosineMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float radius = 0.03f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingL2Metrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.BYTE); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float radius = 3.5f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, null, null); + } + + public void 
testRadiusSearch_usingCosineMetrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.BYTE); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float radius = 0.05f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_withFilter_usingL2Metrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f, 3.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + + refreshIndex(INDEX_NAME); + + final float radius = 45.0f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, COLOR_FIELD_NAME, "red"); + } + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() @@ -467,4 +529,48 @@ public void test_whenUsingIP_thenSuccess() { assertEquals(expectedScores.get(i), knnResults.get(i), 0.0000001); } } + + private void validateRadiusSearchResults( + final float[][] searchVectors, + final float radius, + final SpaceType spaceType, + final int[] expectedResults, + @Nullable final String filterField, + @Nullable final String filterValue + ) throws Exception { + for (int i = 0; i < searchVectors.length; i++) { + XContentBuilder builder = 
XContentFactory.jsonBuilder().startObject().startObject("query"); + builder.startObject("knn"); + builder.startObject(FIELD_NAME); + builder.field("vector", searchVectors[i]); + builder.field("distance", radius); + if (filterField != null && filterValue != null) { + builder.startObject("filter"); + builder.startObject("term"); + builder.field(filterField, filterValue); + builder.endObject(); + builder.endObject(); + } + builder.endObject(); + builder.endObject(); + builder.endObject().endObject(); + + final String responseBody = EntityUtils.toString(searchKNNIndex(INDEX_NAME, builder, expectedResults[i]).getEntity()); + final List radiusResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(expectedResults[i], radiusResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, FIELD_NAME); + for (KNNResult result : radiusResults) { + float[] primitiveArray = Floats.toArray(Arrays.stream(result.getVector()).collect(Collectors.toList())); + float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, searchVectors[i]); + float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance); + if (spaceType == SpaceType.COSINESIMIL) { + distance = 1 - distance; + } + assertTrue(distance <= radius); + assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(radiusResults.indexOf(result)), 0.0001); + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index bcd784e23..de56aeb15 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import com.google.common.collect.ImmutableMap; +import org.apache.lucene.search.FloatVectorSimilarityQuery; import java.util.Locale; import 
org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; @@ -41,8 +42,10 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.anyString; @@ -50,11 +53,13 @@ import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; +import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; + private static final Float DISTANCE = 1.0f; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -77,6 +82,20 @@ public void testInvalidK() { expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector, KNNQueryBuilder.K_MAX + K)); } + public void testInvalidDistance() { + float[] queryVector = { 1.0f, 1.0f }; + + /** + * -ve distance + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).distance(-1.0f)); + + /** + * null distance + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).distance(null)); + } + public void testEmptyVector() { /** * null query vector @@ -89,6 +108,18 @@ public void testEmptyVector() { */ float[] queryVector1 = {}; expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K)); + + /** + * null query vector with distance + */ + float[] queryVector2 = null; + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, 
queryVector2).distance(DISTANCE)); + + /** + * empty query vector with distance + */ + float[] queryVector3 = {}; + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector3).distance(DISTANCE)); } public void testFromXContent() throws Exception { @@ -107,7 +138,23 @@ public void testFromXContent() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_WithFilter() throws Exception { + public void testFromXContent_whenDoRadiusSearch_thenSucceed() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getDistance()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_withFilter() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); @@ -129,7 +176,29 @@ public void testFromXContent_WithFilter() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_invalidQueryVectorType() throws Exception { + public void testFromXContent_wenDoRadiusSearch_whenFilter_thenSucceed() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] 
queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE).filter(TERM_QUERY); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getDistance()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + + public void testFromXContent_InvalidQueryVectorType() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); @@ -157,6 +226,34 @@ public void testFromXContent_invalidQueryVectorType() throws Exception { assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers")); } + public void testFromXContent_whenDoRadiusSearch_whenInputInvalidQueryVectorType_thenException() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + List invalidTypeQueryVector = new ArrayList<>(); + invalidTypeQueryVector.add(1.5); + invalidTypeQueryVector.add(2.5); + invalidTypeQueryVector.add("a"); + invalidTypeQueryVector.add(null); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(FIELD_NAME); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector); + 
builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), DISTANCE); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.fromXContent(contentParser) + ); + assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers")); + } + public void testFromXContent_missingQueryVector() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); @@ -231,6 +328,23 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertTrue(query.toString().contains("resultSimilarity=" + 
KNNEngine.LUCENE.distanceToRadialThreshold(DISTANCE, SpaceType.L2))); + } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -250,6 +364,24 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } + public void testDoToQuery_whenDoRadiusSearch_whenFilter_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE).filter(TERM_QUERY); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); + } + public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -383,17 +515,18 @@ public void 
testDoToQuery_InvalidZeroByteVector() { } public void testSerialization() throws Exception { - assertSerialization(Version.CURRENT, Optional.empty()); + assertSerialization(Version.CURRENT, Optional.empty(), K, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY)); - - assertSerialization(Version.V_2_3_0, Optional.empty()); + // For radius search + assertSerialization(Version.CURRENT, Optional.empty(), null, DISTANCE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, DISTANCE); } - private void assertSerialization(final Version version, final Optional queryBuilderOptional) throws Exception { - final KNNQueryBuilder knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K, queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, K); + private void assertSerialization(final Version version, final Optional queryBuilderOptional, Integer k, Float distance) + throws Exception { + final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance); final ClusterService clusterService = mockClusterService(version); @@ -412,7 +545,11 @@ private void assertSerialization(final Version version, final Optional queryBuilderOptional, Integer k, Float distance) { + final KNNQueryBuilder knnQueryBuilder; + if (k != null) { + knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k, queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k); + } else if (distance != null) { + knnQueryBuilder = queryBuilderOptional.isPresent() + ? 
new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance); + } else { + throw new IllegalArgumentException("Either k or distance must be provided"); + } + return knnQueryBuilder; + } + public void testIgnoreUnmapped() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); @@ -434,4 +587,27 @@ public void testIgnoreUnmapped() throws IOException { knnQueryBuilder.ignoreUnmapped(false); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mock(QueryShardContext.class))); } + + public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { + List unsupportedEngines = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : unsupportedEngines) { + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + SpaceType.L2, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()) + ); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(DISTANCE); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + Index dummyIndex = new Index("dummy", "dummy"); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + } } diff --git a/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java 
b/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java new file mode 100644 index 000000000..42a3c2897 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.lucene.search.ByteVectorSimilarityQuery; +import org.apache.lucene.search.FloatVectorSimilarityQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNEngine; + +public class RNNQueryFactoryTests extends KNNTestCase { + private static final String FILTER_FILED_NAME = "foo"; + private static final String FILTER_FILED_VALUE = "fooval"; + private static final QueryBuilder FILTER_QUERY_BUILDER = new TermQueryBuilder(FILTER_FILED_NAME, FILTER_FILED_VALUE); + private final int testQueryDimension = 17; + private final float[] testQueryVector = new float[testQueryDimension]; + private final byte[] testByteQueryVector = new byte[testQueryDimension]; + private final String testIndexName = "test-index"; + private final String testFieldName = "test-field"; + private final Float testRadius = 0.5f; + + public void testCreate_whenLucene_withRadiusQuery_withFloatVector() { + List luceneDefaultQueryEngineList = 
Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + Query query = RNNQueryFactory.create( + knnEngine, + testIndexName, + testFieldName, + testQueryVector, + testRadius, + DEFAULT_VECTOR_DATA_TYPE_FIELD + ); + assertEquals(FloatVectorSimilarityQuery.class, query.getClass()); + } + } + + public void testCreate_whenLucene_withRadiusQuery_withByteVector() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + BitSetProducer parentFilter = mock(BitSetProducer.class); + when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + final RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .radius(testRadius) + .byteVector(testByteQueryVector) + .vectorDataType(VectorDataType.BYTE) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + Query query = RNNQueryFactory.create(createQueryRequest); + assertEquals(ByteVectorSimilarityQuery.class, query.getClass()); + } + } + + public void testCreateLuceneRadiusQueryWithFilter() { + List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) + .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) + .collect(Collectors.toList()); + for (KNNEngine knnEngine : luceneDefaultQueryEngineList) { + QueryShardContext 
mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + final RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .radius(testRadius) + .build(); + Query query = RNNQueryFactory.create(createQueryRequest); + assertEquals(FloatVectorSimilarityQuery.class, query.getClass()); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java index 916e87414..747bfc751 100644 --- a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java @@ -127,6 +127,11 @@ public float score(float rawScore, SpaceType spaceType) { return 0; } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return 0f; + } + @Override public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { return 0; diff --git a/src/test/java/org/opensearch/knn/index/util/LuceneTests.java b/src/test/java/org/opensearch/knn/index/util/LuceneTests.java index 6de46b52d..c9ffd13b2 100644 --- a/src/test/java/org/opensearch/knn/index/util/LuceneTests.java +++ b/src/test/java/org/opensearch/knn/index/util/LuceneTests.java @@ -99,28 +99,28 @@ public void testLucenHNSWMethod() throws IOException { } public void testGetExtension() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); expectThrows(UnsupportedOperationException.class, 
luceneLibrary::getExtension); } public void testGetCompundExtension() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); expectThrows(UnsupportedOperationException.class, luceneLibrary::getCompoundExtension); } public void testScore() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); float rawScore = 10.0f; assertEquals(rawScore, luceneLibrary.score(rawScore, SpaceType.DEFAULT), 0.001); } public void testIsInitialized() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); assertFalse(luceneLibrary.isInitialized()); } public void testSetInitialized() { - Lucene luceneLibrary = new Lucene(Collections.emptyMap(), ""); + Lucene luceneLibrary = new Lucene(Collections.emptyMap(), "", Collections.emptyMap()); luceneLibrary.setInitialized(true); assertTrue(luceneLibrary.isInitialized()); } diff --git a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java index 00a628f1e..5824dfbce 100644 --- a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java @@ -64,5 +64,10 @@ public TestNativeLibrary( ) { super(methods, scoreTranslation, currentVersion, extension); } + + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return 0.0f; + } } } diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 64cf86381..6cebded1b 100644 --- 
a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -63,6 +63,11 @@ public float score(float rawScore, SpaceType spaceType) { return 0; } + @Override + public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { + return 0.0f; + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return null; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 68255388b..396c8ea64 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -202,6 +202,23 @@ protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, return response; } + /** + * Run KNN Search on Index with XContentBuilder query + */ + protected Response searchKNNIndex(String index, XContentBuilder xContentBuilder, int resultSize) throws IOException { + Request request = new Request("POST", "/" + index + "/_search"); + request.setJsonEntity(xContentBuilder.toString()); + + request.addParameter("size", Integer.toString(resultSize)); + request.addParameter("explain", Boolean.toString(true)); + request.addParameter("search_type", "query_then_fetch"); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + return response; + } + /** * Run exists search */ From 1887e1256dacb58264ed8031a0c174ad8fbc4027 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Fri, 29 Mar 2024 17:43:51 -0700 Subject: [PATCH 2/6] Support distance type radius search for Faiss engine (#1546) * Support distance type radius search for Faiss engine Signed-off-by: Junqiu Lei --- CHANGELOG.md | 1 + jni/include/faiss_wrapper.h | 
12 ++ .../org_opensearch_knn_jni_FaissService.h | 8 + jni/src/faiss_wrapper.cpp | 44 ++++ .../org_opensearch_knn_jni_FaissService.cpp | 14 ++ jni/tests/faiss_wrapper_test.cpp | 113 +++++++++++ .../opensearch/knn/index/query/KNNQuery.java | 66 +++++- .../knn/index/query/KNNQueryBuilder.java | 29 +-- .../opensearch/knn/index/query/KNNWeight.java | 32 ++- .../knn/index/query/RNNQueryFactory.java | 23 +++ .../opensearch/knn/index/util/KNNEngine.java | 2 +- .../org/opensearch/knn/jni/FaissService.java | 11 + .../org/opensearch/knn/jni/JNIService.java | 23 +++ .../org/opensearch/knn/index/FaissIT.java | 192 ++++++++++++++++++ .../knn/index/query/KNNQueryBuilderTests.java | 53 ++++- .../knn/index/query/KNNWeightTests.java | 63 ++++++ .../knn/index/query/RNNQueryFactoryTests.java | 31 ++- 17 files changed, 685 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ecea33b1a..daa6cb18d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x) ### Features * Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498) +* Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 3e1adeac4..da67c0f59 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -73,6 +73,18 @@ namespace knn_jni { // Return the serialized representation jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint 
dimension, jlong trainVectorsPointerJ); + + /* + * Perform a range search against the index located in memory at indexPointerJ. + * + * @param indexPointerJ - pointer to the index + * @param queryVectorJ - the query vector + * @param radiusJ - the radius for the range search + * @param maxResultsWindowJ - the maximum number of results to return + * @return an array of KNNQueryResult objects + */ + jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, + jfloat radiusJ, jint maxResultsWindowJ); } } diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 32b6f22f1..3715730ab 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -122,6 +122,14 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors (JNIEnv *, jclass, jlong, jobjectArray); +/* +* Class: org_opensearch_knn_jni_FaissService +* Method: rangeSearchIndex +* Signature: (J[FFI)[Lorg/opensearch/knn/index/query/KNNQueryResult; +*/ +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex + (JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint); + #ifdef __cplusplus } #endif diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 817bdb816..983cfa8a9 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -587,3 +587,47 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) { throw std::runtime_error("Unable to extract IVFPQ index.
IVFPQ index not present."); } + +jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, + jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) { + if (queryVectorJ == nullptr) { + throw std::runtime_error("Query Vector cannot be null"); + } + + auto *indexReader = reinterpret_cast(indexPointerJ); + + if (indexReader == nullptr) { + throw std::runtime_error("Invalid pointer to indexReader"); + } + + float *rawQueryVector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr); + + // The res will be freed by ~RangeSearchResult() in FAISS + // The second parameter is always true, as lims is allocated by FAISS + faiss::RangeSearchResult res(1, true); + indexReader->range_search(1, rawQueryVector, radiusJ, &res); + + // lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries), + // lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN, + // res.lims[0] is always 0, and res.lims[1] gives the total number of matching entries found. + int resultSize = res.lims[1]; + + // Limit the result size to maxResultWindowJ so that we don't return more than the max result window + // TODO: In the future, we should prevent this via FAISS's ResultHandler. 
+ if (resultSize > maxResultWindowJ) { + resultSize = maxResultWindowJ; + } + + jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult"); + jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", ""); + + jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr); + + jobject result; + for(int i = 0; i < resultSize; ++i) { + result = jniUtil->NewObject(env, resultClass, allArgs, res.labels[i], res.distances[i]); + jniUtil->SetObjectArrayElement(env, results, i, result); + } + + return results; +} diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 3249ed872..ab2a37e84 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -190,3 +190,17 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors return (jlong) vect; } + +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls, + jlong indexPointerJ, + jfloatArray queryVectorJ, + jfloat radiusJ, jint maxResultWindowJ) +{ + try { + return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ); + + } catch (...) 
{ + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 05854f7ed..07b34976f 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -615,3 +615,116 @@ TEST(FaissInitAndSetSharedIndexState, BasicAssertions) { ASSERT_EQ(1, ivfpqIndex->use_precomputed_table); knn_jni::faiss_wrapper::FreeSharedIndexState(sharedModelAddress); } + +TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) { + // Define the index data + faiss::idx_t numIds = 200; + int dim = 2; + std::vector ids = test_util::Range(numIds); + std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + float radius = 100000.0; + int numQueries = 2; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + int maxResultWindow = 20000; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + + knn_jni::faiss_wrapper::RangeSearch( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), radius, maxResultWindow))); + + // assert result size is not 0 + ASSERT_NE(0, results->size()); + + + // Need to free up each result + for (auto it : *results) { + delete it; + } + } +} + +TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){ + // Define the index data + faiss::idx_t numIds = 
200; + int dim = 2; + std::vector ids = test_util::Range(numIds); + std::vector vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax); + + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; + + // Define query data + float radius = 100000.0; + int numQueries = 2; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomFloat(randomDataMin, randomDataMax)); + } + queries.push_back(query); + } + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateIndex(dim, method, metricType)); + auto createdIndexWithData = + test_util::FaissAddData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + int maxResultWindow = 10; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + + knn_jni::faiss_wrapper::RangeSearch( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), radius, maxResultWindow))); + + // assert result size is not 0 + ASSERT_NE(0, results->size()); + // assert result size is equal to maxResultWindow + ASSERT_EQ(maxResultWindow, results->size()); + + // Need to free up each result + for (auto it : *results) { + delete it; + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 9c78b18a1..0862b2d93 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -7,6 +7,8 @@ import java.util.Arrays; import java.util.Objects; + +import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; import org.apache.lucene.search.BooleanClause; @@ -30,7 +32,7 @@ public class KNNQuery extends Query { private final String field; private final float[] queryVector; 
- private final int k; + private int k; private final String indexName; @Getter @@ -38,6 +40,10 @@ public class KNNQuery extends Query { private Query filterQuery; @Getter private BitSetProducer parentsFilter; + @Getter + private Float radius = null; + @Getter + private Context context; public KNNQuery( final String field, @@ -69,6 +75,54 @@ public KNNQuery( this.parentsFilter = parentsFilter; } + /** + * Constructor for KNNQuery with query vector, index name and parent filter + * + * @param field field name + * @param queryVector query vector + * @param indexName index name + * @param parentsFilter parent filter + */ + public KNNQuery(String field, float[] queryVector, String indexName, BitSetProducer parentsFilter) { + this.field = field; + this.queryVector = queryVector; + this.indexName = indexName; + this.parentsFilter = parentsFilter; + } + + /** + * Sets the radius on this KNNQuery and returns the same instance for chaining + * + * @param radius engine radius + * @return this KNNQuery + */ + public KNNQuery radius(Float radius) { + this.radius = radius; + return this; + } + + /** + * Sets the Context on this KNNQuery and returns the same instance for chaining + * + * @param context Context for KNNQuery + * @return this KNNQuery + */ + public KNNQuery kNNQueryContext(Context context) { + this.context = context; + return this; + } + + /** + * Sets the filter query on this KNNQuery and returns the same instance for chaining + * + * @param filterQuery filter query + * @return this KNNQuery + */ + public KNNQuery filterQuery(Query filterQuery) { + this.filterQuery = filterQuery; + return this; + } + public String getField() { return this.field; } @@ -144,4 +198,14 @@ private boolean equalsTo(KNNQuery other) { && Objects.equals(indexName, other.indexName) && Objects.equals(filterQuery, other.filterQuery); } + + /** + * Context for KNNQuery + */ + @Setter + @Getter + @AllArgsConstructor + public static class Context { + int maxResultWindow; + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
index 2dd21ed5a..24487e2d2 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -10,21 +10,12 @@ import java.util.List; import java.util.Objects; + import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.MatchNoDocsQuery; -import org.apache.lucene.search.Query; -import org.opensearch.core.ParseField; -import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumberFieldMapper; -import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; @@ -34,6 +25,16 @@ import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.plugin.stats.KNNCounter; +import org.apache.lucene.search.Query; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryShardContext; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; @@ 
-108,8 +109,8 @@ public KNNQueryBuilder k(int k) { * @param distance the distance threshold for the nearest neighbours */ public KNNQueryBuilder distance(Float distance) { - if (distance == null || distance < 0) { - throw new IllegalArgumentException("[" + NAME + "] requires distance >= 0"); + if (distance == null) { + throw new IllegalArgumentException("[" + NAME + "] requires distance to be set"); } if (k != 0) { throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); @@ -400,6 +401,9 @@ protected Query doToQuery(QueryShardContext context) { // We need transform distance radius to right type of engine required radius. Float radius = null; if (this.distance != null) { + if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType); + } radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType); } @@ -457,7 +461,6 @@ protected Query doToQuery(QueryShardContext context) { .radius(radius) .filter(this.filter) .context(context) - .radius(radius) .build(); return RNNQueryFactory.create(createQueryRequest); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 06bf96d63..6b323e124 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -277,16 +277,25 @@ private Map doANNSearch(final LeafReaderContext context, final B throw new RuntimeException("Index has already been closed"); } int[] parentIds = getParentIdsArray(context); - results = JNIService.queryIndex( - indexAllocation.getMemoryAddress(), - knnQuery.getQueryVector(), - knnQuery.getK(), - knnEngine, - filterIds, - filterType.getValue(), - parentIds - ); - + if (knnQuery.getK() > 0) { + results = JNIService.queryIndex( + indexAllocation.getMemoryAddress(), 
+ knnQuery.getQueryVector(), + knnQuery.getK(), + knnEngine, + filterIds, + filterType.getValue(), + parentIds + ); + } else { + results = JNIService.radiusQueryIndex( + indexAllocation.getMemoryAddress(), + knnQuery.getQueryVector(), + knnQuery.getRadius(), + knnEngine, + knnQuery.getContext().getMaxResultWindow() + ); + } } catch (Exception e) { GRAPH_QUERY_ERRORS.increment(); throw new RuntimeException(e); @@ -406,6 +415,9 @@ private boolean canDoExactSearch(final int filterIdsCount) { filterIdsCount, KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()) ); + if (knnQuery.getRadius() != null) { + return false; + } int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic if (filterIdsCount <= knnQuery.getK()) { diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index 69018fa73..cd32ac4f3 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -14,6 +14,9 @@ import org.apache.lucene.search.ByteVectorSimilarityQuery; import org.apache.lucene.search.FloatVectorSimilarityQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.index.IndexSettings; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; @@ -66,6 +69,26 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { + BitSetProducer 
parentFilter = null; + QueryShardContext context = createQueryRequest.getContext().get(); + + if (createQueryRequest.getContext().isPresent()) { + parentFilter = context.getParentFilter(); + } + IndexSettings indexSettings = context.getIndexSettings(); + KNNQuery.Context knnQueryContext = new KNNQuery.Context(indexSettings.getMaxResultWindow()); + KNNQuery rnnQuery = new KNNQuery(fieldName, vector, indexName, parentFilter).radius(radius).kNNQueryContext(knnQueryContext); + if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { + log.debug("Creating custom radius search with filters for index: {}, field: {} , r: {}", indexName, fieldName, radius); + rnnQuery.filterQuery(filterQuery); + } + log.debug( + String.format("Creating custom radius search for index: %s \"\", field: %s \"\", r: %f", indexName, fieldName, radius) + ); + return rnnQuery; + } + log.debug(String.format("Creating Lucene r-NN query for index: %s \"\", field: %s \"\", k: %f", indexName, fieldName, radius)); switch (vectorDataType) { case BYTE: diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index 5d334e7ef..6551f5a39 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -32,7 +32,7 @@ public enum KNNEngine implements KNNLibrary { private static final Set CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS); private static final Set ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); - public static final Set ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE); + public static final Set ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS); private static Map MAX_DIMENSIONS_BY_ENGINE = Map.of( KNNEngine.NMSLIB, diff --git 
a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 32516ef9d..b59ac4bcf 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -190,4 +190,15 @@ public static native KNNQueryResult[] queryIndexWithFilter( */ @Deprecated(since = "2.14.0", forRemoval = true) public static native long transferVectors(long vectorsPointer, float[][] trainingData); + + /** + * Range search index + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param indexMaxResultWindow maximum number of results to return + * @return KNNQueryResult array of neighbors within radius + */ + public static native KNNQueryResult[] rangeSearchIndex(long indexPointer, float[] queryVector, float radius, int indexMaxResultWindow); } diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 5a5b6794a..e846f02d1 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -262,4 +262,27 @@ public static byte[] trainIndex(Map indexParameters, int dimensi public static long transferVectors(long vectorsPointer, float[][] trainingData) { return FaissService.transferVectors(vectorsPointer, trainingData); } + + /** + * Range search index for a given query vector + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param knnEngine engine to query index + * @param indexMaxResultWindow maximum number of results to return + * @return KNNQueryResult array of neighbors within radius + */ + public static KNNQueryResult[] radiusQueryIndex( + long indexPointer, + float[] queryVector, + float radius, + KNNEngine knnEngine, + int 
indexMaxResultWindow + ) { + if (KNNEngine.FAISS == knnEngine) { + return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, indexMaxResultWindow); + } + throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine"); + } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 85dd3f169..653750719 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -277,6 +277,167 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { fail("Graphs are not getting evicted"); } + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWFlat_thenSucceed() { + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + 
.endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float radius = 300000000000f; + for (float[] queryVector : testData.queries) { + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, queryVector, radius, spaceType); + } + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWPQ_thenSucceed() { + String indexName = "test-index"; + String fieldName = "test-field"; + String trainingIndexName = "training-index"; + String trainingFieldName = "training-field"; + + String modelId = "test-model"; + String modelDescription = "test model"; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + List pqMValues = ImmutableList.of(2, 4, 8); + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 
8 because that's the only valid code_size for HNSWPQ + int trainingDataCount = 256; + + SpaceType spaceType = SpaceType.L2; + + int dimension = testData.indexData.vectors[0].length; + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqMValues.get(random().nextInt(pqMValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount); + assertTrainingSucceeds(modelId, 360, 1000); + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("model_id", modelId) + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in 
the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + float radius = 300000000000f; + + for (float[] queryVector : testData.queries) { + validateRadiusSearchResults(indexName, fieldName, queryVector, radius, spaceType); + } + + // Delete index + deleteKNNIndex(indexName); + deleteModel(modelId); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; @@ -1410,4 +1571,35 @@ private void validateGraphEviction() throws Exception { fail("Graphs are not getting evicted"); } + + private void validateRadiusSearchResults( + String indexName, + String fieldName, + float[] queryVector, + float radius, + final SpaceType spaceType + ) throws IOException, ParseException { + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilder.startObject("knn"); + queryBuilder.startObject(fieldName); + queryBuilder.field("vector", queryVector); + queryBuilder.field("distance", radius); + queryBuilder.endObject(); + queryBuilder.endObject(); + queryBuilder.endObject().endObject(); + final String responseBody = EntityUtils.toString(searchKNNIndex(indexName, queryBuilder, 10).getEntity()); + + List knnResults = parseSearchResponse(responseBody, fieldName); + + for (KNNResult knnResult : knnResults) { + float[] vector = Floats.toArray(Arrays.stream(knnResult.getVector()).collect(Collectors.toList())); + if (spaceType == SpaceType.L2) { + assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= radius); + } else if (spaceType == SpaceType.INNER_PRODUCT) { + assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= radius); + } else { + throw new 
IllegalArgumentException("Invalid space type"); + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index de56aeb15..3b0f03b38 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -19,6 +19,7 @@ import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexSettings; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; @@ -84,12 +85,6 @@ public void testInvalidK() { public void testInvalidDistance() { float[] queryVector = { 1.0f, 1.0f }; - - /** - * -ve distance - */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).distance(-1.0f)); - /** * null distance */ @@ -345,6 +340,52 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenSucceed() { assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(DISTANCE, SpaceType.L2))); } + public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + 
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(negativeDistance, query.getRadius(), 0); + } + + public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + 
when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index da4b7093a..581d3c0b4 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -61,6 +61,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyFloat; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -691,6 +692,68 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { assertEquals(DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); } + @SneakyThrows + public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { + final float[] queryVector = new float[] { 0.1f, 0.3f }; + final float radius = 0.5f; + final int maxResults = 1000; + jniServiceMockedStatic.when(() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt())) + .thenReturn(getKNNQueryResults()); + KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + + final KNNQuery query = new KNNQuery(FIELD_NAME, queryVector, INDEX_NAME, null).radius(radius).kNNQueryContext(context); + final float boost = (float) randomDoubleBetween(0, 10, true); + final KNNWeight knnWeight = new KNNWeight(query, 
boost); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + true, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(SEGMENT_FILES_FAISS); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(Map.of(SPACE_TYPE, SpaceType.L2.getValue(), KNN_ENGINE, KNNEngine.FAISS.getName())); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + jniServiceMockedStatic.verify(() -> JNIService.radiusQueryIndex(anyLong(), any(), anyFloat(), any(), anyInt())); + + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + + final List actualDocIds = new ArrayList<>(); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + 
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + private SegmentReader getMockedSegmentReader() { final SegmentReader reader = mock(SegmentReader.class); when(reader.maxDoc()).thenReturn(1); diff --git a/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java index 42a3c2897..5492b8506 100644 --- a/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/RNNQueryFactoryTests.java @@ -18,6 +18,7 @@ import org.apache.lucene.search.FloatVectorSimilarityQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; @@ -36,6 +37,7 @@ public class RNNQueryFactoryTests extends KNNTestCase { private final String testIndexName = "test-index"; private final String testFieldName = "test-field"; private final Float testRadius = 0.5f; + private final int maxResultWindow = 20000; public void testCreate_whenLucene_withRadiusQuery_withFloatVector() { List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) @@ -80,7 +82,7 @@ public void testCreate_whenLucene_withRadiusQuery_withByteVector() { } } - public void testCreateLuceneRadiusQueryWithFilter() { + public void testCreate_whenLucene_withFilter_thenSucceed() { List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) .collect(Collectors.toList()); @@ -102,4 +104,31 @@ public void testCreateLuceneRadiusQueryWithFilter() { assertEquals(FloatVectorSimilarityQuery.class, query.getClass()); } } + + public void testCreate_whenFaiss_thenSucceed() { + QueryShardContext mockQueryShardContext = 
mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + when(mockQueryShardContext.getIndexSettings().getMaxResultWindow()).thenReturn(maxResultWindow); + final RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.FAISS) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .radius(testRadius) + .vectorDataType(DEFAULT_VECTOR_DATA_TYPE_FIELD) + .context(mockQueryShardContext) + .build(); + + Query query = RNNQueryFactory.create(createQueryRequest); + + assertTrue(query instanceof KNNQuery); + assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); + assertEquals(testFieldName, ((KNNQuery) query).getField()); + assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); + assertEquals(testRadius, ((KNNQuery) query).getRadius(), 0); + assertEquals(maxResultWindow, ((KNNQuery) query).getContext().getMaxResultWindow()); + } } From 244450b3c7e6bab30a678a0551d8e7001af1fa02 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Fri, 5 Apr 2024 11:16:25 -0700 Subject: [PATCH 3/6] Support score type threshold in radial search (#1589) * Support score type threshold in radial search Signed-off-by: Junqiu Lei --- CHANGELOG.md | 1 + .../org/opensearch/knn/index/SpaceType.java | 18 ++ .../knn/index/query/KNNQueryBuilder.java | 90 ++++++- .../org/opensearch/knn/index/util/Faiss.java | 28 +- .../opensearch/knn/index/util/KNNEngine.java | 5 + .../opensearch/knn/index/util/KNNLibrary.java | 10 + .../org/opensearch/knn/index/util/Lucene.java | 15 +- .../org/opensearch/knn/index/util/Nmslib.java | 4 + .../org/opensearch/knn/index/FaissIT.java | 186 ++++++++++--- .../opensearch/knn/index/LuceneEngineIT.java | 137 ++++++++-- 
.../knn/index/query/KNNQueryBuilderTests.java | 254 ++++++++++++++++-- .../index/util/AbstractKNNLibraryTests.java | 4 + .../knn/index/util/NativeLibraryTests.java | 5 + .../LibraryInitializedSupplierTests.java | 5 + 14 files changed, 675 insertions(+), 87 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index daa6cb18d..c9e34def9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features * Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498) * Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546) +* Support score type threshold in radial search [#1589](https://github.com/opensearch-project/k-NN/pull/1589) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 50d8d352c..240bfbe91 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -36,6 +36,14 @@ public float scoreTranslation(float rawScore) { public VectorSimilarityFunction getVectorSimilarityFunction() { return VectorSimilarityFunction.EUCLIDEAN; } + + @Override + public float scoreToDistanceTranslation(float score) { + if (score == 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "score cannot be 0 when space type is [%s]", getValue())); + } + return 1 / score - 1; + } }, COSINESIMIL("cosinesimil") { @Override @@ -170,4 +178,14 @@ public static SpaceType getSpace(String spaceTypeName) { } throw new IllegalArgumentException("Unable to find space: " + spaceTypeName); } + + /** 
+ * Translate a score to a distance for this space type + * + * @param score score to translate + * @return translated distance + */ + public float scoreToDistanceTranslation(float score) { + throw new UnsupportedOperationException(String.format("Space [%s] does not have a score to distance translation", getValue())); + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 24487e2d2..78ddb532d 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -52,6 +52,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); public static final ParseField DISTANCE_FIELD = new ParseField("distance"); + public static final ParseField SCORE_FIELD = new ParseField("score"); public static final int K_MAX = 10000; /** * The name for the knn query @@ -64,6 +65,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final float[] vector; private int k = 0; private Float distance = null; + private Float score = null; private QueryBuilder filter; private boolean ignoreUnmapped = false; @@ -92,13 +94,14 @@ public KNNQueryBuilder(String fieldName, float[] vector) { * * @param k K nearest neighbours for the given vector */ - public KNNQueryBuilder k(int k) { + public KNNQueryBuilder k(Integer k) { + if (k == null) { + throw new IllegalArgumentException("[" + NAME + "] requires k to be set"); + } + validateSingleQueryType(k, distance, score); if (k <= 0 || k > K_MAX) { throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); } - if (distance != null) { - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); - } this.k = k; return this; } @@ 
-112,13 +115,28 @@ public KNNQueryBuilder distance(Float distance) { if (distance == null) { throw new IllegalArgumentException("[" + NAME + "] requires distance to be set"); } - if (k != 0) { - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); - } + validateSingleQueryType(k, distance, score); this.distance = distance; return this; } + /** + * Builder method for score + * + * @param score the score threshold for the nearest neighbours + */ + public KNNQueryBuilder score(Float score) { + if (score == null) { + throw new IllegalArgumentException("[" + NAME + "] requires score to be set"); + } + validateSingleQueryType(k, distance, score); + if (score <= 0) { + throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0"); + } + this.score = score; + return this; + } + /** * Builder method for filter * @@ -163,6 +181,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.filter = filter; this.ignoreUnmapped = false; this.distance = null; + this.score = null; } public static void initialize(ModelDao modelDao) { @@ -200,6 +219,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { distance = in.readOptionalFloat(); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + score = in.readOptionalFloat(); + } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } @@ -211,6 +233,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep float boost = AbstractQueryBuilder.DEFAULT_BOOST; Integer k = null; Float distance = null; + Float score = null; QueryBuilder filter = null; String queryName = null; String currentFieldName = null; @@ -241,6 +264,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep queryName = parser.text(); } else if 
(DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + } else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else { throw new ParsingException( parser.getTokenLocation(), @@ -270,9 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - if ((k != null && distance != null) || (k == null && distance == null)) { - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); - } + validateSingleQueryType(k, distance, score); KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .ignoreUnmapped(ignoreUnmapped) @@ -281,8 +304,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (k != null) { knnQueryBuilder.k(k); - } else { + } else if (distance != null) { knnQueryBuilder.distance(distance); + } else if (score != null) { + knnQueryBuilder.score(score); } return knnQueryBuilder; @@ -300,6 +325,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { out.writeOptionalFloat(distance); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + out.writeOptionalFloat(score); + } } /** @@ -324,6 +352,10 @@ public float getDistance() { return this.distance; } + public float getScore() { + return this.score; + } + public QueryBuilder getFilter() { return this.filter; } @@ -358,6 +390,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (ignoreUnmapped) { builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); } + if (score != null) { + builder.field(SCORE_FIELD.getPreferredName(), score); + } printBoostAndQueryName(builder); 
builder.endObject(); builder.endObject(); @@ -397,8 +432,8 @@ protected Query doToQuery(QueryShardContext context) { spaceType = knnMethodContext.getSpaceType(); } - // Currently, k-NN supports distance type radius search. - // We need transform distance radius to right type of engine required radius. + // Currently, k-NN supports distance and score types radial search + // We need transform distance/score to right type of engine required radius. Float radius = null; if (this.distance != null) { if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { @@ -407,6 +442,13 @@ protected Query doToQuery(QueryShardContext context) { radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType); } + if (this.score != null) { + if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType); + } + radius = knnEngine.scoreToRadialThreshold(this.score, spaceType); + } + if (fieldDimension != vector.length) { throw new IllegalArgumentException( String.format("Query vector has invalid dimension: %d. 
Dimension should be: %d", vector.length, fieldDimension) @@ -464,7 +506,7 @@ protected Query doToQuery(QueryShardContext context) { .build(); return RNNQueryFactory.create(createQueryRequest); } - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); + throw new IllegalArgumentException("[" + NAME + "] requires either k or distance or score to be set"); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { @@ -499,4 +541,24 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static void validateSingleQueryType(Integer k, Float distance, Float score) { + int countSetFields = 0; + + if (k != null && k != 0) { + countSetFields++; + } + if (distance != null) { + countSetFields++; + } + if (score != null) { + countSetFields++; + } + + if (countSetFields != 1) { + throw new IllegalArgumentException( + "[" + NAME + "] requires only one query type to be set, it can be either k, distance, or score" + ); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 8b58b7e74..efd8a637c 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -56,6 +56,7 @@ * Implements NativeLibrary for the faiss native library */ class Faiss extends NativeLibrary { + Map> scoreTransform; // TODO: Current version is not really current version. Instead, it encodes information in the file name // about the compatibility version the file is created with. 
In the future, we should refactor this so that it @@ -68,6 +69,12 @@ class Faiss extends NativeLibrary { rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore) ); + // Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation: + // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces + private final static Map> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.< + SpaceType, + Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); + // Define encoders supported by faiss private final static MethodComponentContext ENCODER_DEFAULT = new MethodComponentContext( KNNConstants.ENCODER_FLAT, @@ -301,7 +308,13 @@ class Faiss extends NativeLibrary { ).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build() ); - final static Faiss INSTANCE = new Faiss(METHODS, SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION); + final static Faiss INSTANCE = new Faiss( + METHODS, + SCORE_TRANSLATIONS, + CURRENT_VERSION, + KNNConstants.FAISS_EXTENSION, + SCORE_TO_DISTANCE_TRANSFORMATIONS + ); /** * Constructor for Faiss @@ -315,9 +328,11 @@ private Faiss( Map methods, Map> scoreTranslation, String currentVersion, - String extension + String extension, + Map> scoreTransform ) { super(methods, scoreTranslation, currentVersion, extension); + this.scoreTransform = scoreTransform; } @Override @@ -326,6 +341,15 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return distance; } + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + // Faiss engine uses distance as is and need transformation + if (this.scoreTransform.containsKey(spaceType)) { + return this.scoreTransform.get(spaceType).apply(score); + } + return spaceType.scoreToDistanceTranslation(score); + } + /** * MethodAsMap builder is used to create the map that will be passed to the jni to create the faiss index. 
* Faiss's index factory takes an "index description" that it uses to build the index. In this description, diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index 6551f5a39..e282c69db 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -158,6 +158,11 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return knnLibrary.distanceToRadialThreshold(distance, spaceType); } + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return knnLibrary.scoreToRadialThreshold(score, spaceType); + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return knnLibrary.validateMethod(knnMethodContext); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index 2c1f454c0..f837566b8 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -78,6 +78,16 @@ public interface KNNLibrary { */ Float distanceToRadialThreshold(Float distance, SpaceType spaceType); + /** + * Translate the score threshold input from end user to the engine's threshold. + * + * @param score score threshold input from end user + * @param spaceType spaceType used to compute the threshold + * + * @return transformed score for the library + */ + Float scoreToRadialThreshold(Float score, SpaceType spaceType); + /** * Validate the knnMethodContext for the given library. A ValidationException should be thrown if the method is * deemed invalid. 
diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index 54a752408..630d7a2c2 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -48,11 +48,13 @@ public class Lucene extends JVMLibrary { ).addSpaces(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() ); + // Map that overrides the default distance translations for Lucene, check more details in knn documentation: + // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< SpaceType, Function>builder() .put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2) - .put(SpaceType.L2, distance -> 1 / (1 + distance)) + .put(SpaceType.INNER_PRODUCT, distance -> distance <= 0 ? 1 / (1 - distance) : distance + 1) .build(); final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS); @@ -90,7 +92,16 @@ public float score(float rawScore, SpaceType spaceType) { @Override public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { // Lucene requires score threshold to be parameterized when calling the radius search. 
- return this.distanceTransform.get(spaceType).apply(distance); + if (this.distanceTransform.containsKey(spaceType)) { + return this.distanceTransform.get(spaceType).apply(distance); + } + return spaceType.scoreTranslation(distance); + } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + // Lucene engine uses distance as is and does not need transformation + return score; } @Override diff --git a/src/main/java/org/opensearch/knn/index/util/Nmslib.java b/src/main/java/org/opensearch/knn/index/util/Nmslib.java index 8993cac8e..64af43520 100644 --- a/src/main/java/org/opensearch/knn/index/util/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/util/Nmslib.java @@ -73,4 +73,8 @@ private Nmslib( public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return distance; } + + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return score; + } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 653750719..16eb1a4c3 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -278,7 +278,7 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { } @SneakyThrows - public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWFlat_thenSucceed() { + public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWFlat_thenSucceed() { KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); SpaceType spaceType = SpaceType.L2; @@ -329,17 +329,135 @@ public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWFlat_thenSucceed() { refreshAllNonSystemIndices(); assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); - float radius = 300000000000f; - for (float[] queryVector : testData.queries) { - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, queryVector, radius, spaceType); + 
float distance = 300000000000f; + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + 
Floats.asList(testData.indexData.vectors[i]).toArray() + ); } + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float score = 0.00001f; + + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + // Delete index deleteKNNIndex(INDEX_NAME); } @SneakyThrows - public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWPQ_thenSucceed() { + public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.INNER_PRODUCT; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, 
mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float score = 5f; + + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; String fieldName = "test-field"; String trainingIndexName = "training-index"; @@ -415,11 +533,9 @@ public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWPQ_thenSucceed() { refreshAllNonSystemIndices(); assertEquals(testData.indexData.docs.length, getDocCount(indexName)); - float radius = 300000000000f; + float distance = 300000000000f; - for (float[] queryVector : testData.queries) { - validateRadiusSearchResults(indexName, fieldName, queryVector, radius, spaceType); - } + validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType); // Delete index deleteKNNIndex(indexName); @@ -1575,30 +1691,40 @@ private void validateGraphEviction() throws Exception { private void validateRadiusSearchResults( String indexName, String fieldName, - float[] queryVector, - float radius, + float[][] queryVectors, + Float distanceThreshold, + Float scoreThreshold, final SpaceType spaceType ) throws IOException, ParseException { - XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); - queryBuilder.startObject("knn"); - queryBuilder.startObject(fieldName); - queryBuilder.field("vector", 
queryVector); - queryBuilder.field("distance", radius); - queryBuilder.endObject(); - queryBuilder.endObject(); - queryBuilder.endObject().endObject(); - final String responseBody = EntityUtils.toString(searchKNNIndex(indexName, queryBuilder, 10).getEntity()); - - List knnResults = parseSearchResponse(responseBody, fieldName); - - for (KNNResult knnResult : knnResults) { - float[] vector = Floats.toArray(Arrays.stream(knnResult.getVector()).collect(Collectors.toList())); - if (spaceType == SpaceType.L2) { - assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= radius); - } else if (spaceType == SpaceType.INNER_PRODUCT) { - assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= radius); + for (float[] queryVector : queryVectors) { + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilder.startObject("knn"); + queryBuilder.startObject(fieldName); + queryBuilder.field("vector", queryVector); + if (distanceThreshold != null) { + queryBuilder.field("distance", distanceThreshold); + } else if (scoreThreshold != null) { + queryBuilder.field("score", scoreThreshold); } else { - throw new IllegalArgumentException("Invalid space type"); + throw new IllegalArgumentException("Invalid threshold"); + } + queryBuilder.endObject(); + queryBuilder.endObject(); + queryBuilder.endObject().endObject(); + final String responseBody = EntityUtils.toString(searchKNNIndex(indexName, queryBuilder, 10).getEntity()); + + List knnResults = parseSearchResponse(responseBody, fieldName); + + for (KNNResult knnResult : knnResults) { + float[] vector = knnResult.getVector(); + float distance = TestUtils.computeDistFromSpaceType(spaceType, vector, queryVector); + if (spaceType == SpaceType.L2) { + assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= distance); + } else if (spaceType == SpaceType.INNER_PRODUCT) { + assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= distance); + } else { + throw new 
IllegalArgumentException("Invalid space type"); + } } } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 8f55730fd..f721606e1 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -51,6 +51,14 @@ public class LuceneEngineIT extends KNNRestTestCase { private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; private static final Float[][] TEST_COSINESIMIL_INDEX_VECTORS = { { 6.0f, 7.0f, 3.0f }, { 3.0f, 2.0f, 5.0f }, { 4.0f, 5.0f, 7.0f } }; + private static final Float[][] TEST_INNER_PRODUCT_INDEX_VECTORS = { + { 1.0f, 1.0f, 1.0f }, + { 2.0f, 2.0f, 2.0f }, + { 3.0f, 3.0f, 3.0f }, + { -1.0f, -1.0f, -1.0f }, + { -2.0f, -2.0f, -2.0f }, + { -3.0f, -3.0f, -3.0f } }; + private static final float[][] TEST_QUERY_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; private static final Map> VECTOR_SIMILARITY_TO_SCORE = ImmutableMap.of( @@ -59,7 +67,9 @@ public class LuceneEngineIT extends KNNRestTestCase { VectorSimilarityFunction.DOT_PRODUCT, (similarity) -> (1 + similarity) / 2, VectorSimilarityFunction.COSINE, - (similarity) -> (1 + similarity) / 2 + (similarity) -> (1 + similarity) / 2, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + (similarity) -> similarity <= 0 ? 
1 / (1 - similarity) : similarity + 1 ); private static final String DIMENSION_FIELD_NAME = "dimension"; private static final String KNN_VECTOR_TYPE = "knn_vector"; @@ -318,55 +328,115 @@ public void testIndexReopening() throws Exception { assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } - public void testRadiusSearch_usingL2Metrics_usingFloatType() throws Exception { + public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingFloatType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } - final float radius = 3.5f; + final float distance = 3.5f; final int[] expectedResults = { 2, 3, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null); } - public void testRadiusSearch_usingCosineMetrics_usingFloatType() throws Exception { + public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float score = 0.23f; + final int[] expectedResults = { 2, 3, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingFloatType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), 
FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); } - final float radius = 0.03f; + final float distance = 0.03f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.COSINESIMIL, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null); } - public void testRadiusSearch_usingL2Metrics_usingByteType() throws Exception { + public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float score = 0.97f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingInnerProductMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.INNER_PRODUCT, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INNER_PRODUCT_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INNER_PRODUCT_INDEX_VECTORS[j]); + } + + final float score = 2f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.INNER_PRODUCT, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingByteType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.BYTE); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } - final float radius = 3.5f; + final float 
distance = 3.5f; final int[] expectedResults = { 2, 2, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null); } - public void testRadiusSearch_usingCosineMetrics_usingByteType() throws Exception { + public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.BYTE); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float score = 0.23f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.BYTE); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float distance = 0.05f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingByteType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.BYTE); for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); } - final float radius = 0.05f; + final float score = 0.97f; final int[] expectedResults = { 2, 2, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, 
SpaceType.COSINESIMIL, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null); } - public void testRadiusSearch_withFilter_usingL2Metrics_usingFloatType() throws Exception { + public void testRadiusSearch_usingDistanceThreshold_withFilter_usingL2Metrics_usingFloatType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f, 3.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); @@ -374,10 +444,24 @@ public void testRadiusSearch_withFilter_usingL2Metrics_usingFloatType() throws E refreshIndex(INDEX_NAME); - final float radius = 45.0f; + final float distance = 45.0f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, COLOR_FIELD_NAME, "red"); + } + + public void testRadiusSearch_usingScoreThreshold_withFilter_usingCosineMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); + addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f, 3.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + + refreshIndex(INDEX_NAME); + + final float score = 0.02f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, COLOR_FIELD_NAME, "red"); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, COLOR_FIELD_NAME, "red"); } private void 
createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception { @@ -532,7 +616,8 @@ public void test_whenUsingIP_thenSuccess() { private void validateRadiusSearchResults( final float[][] searchVectors, - final float radius, + final Float distanceThreshold, + final Float scoreThreshold, final SpaceType spaceType, final int[] expectedResults, @Nullable final String filterField, @@ -543,7 +628,13 @@ private void validateRadiusSearchResults( builder.startObject("knn"); builder.startObject(FIELD_NAME); builder.field("vector", searchVectors[i]); - builder.field("distance", radius); + if (distanceThreshold != null) { + builder.field("distance", distanceThreshold); + } else if (scoreThreshold != null) { + builder.field("score", scoreThreshold); + } else { + throw new IllegalArgumentException("Either distance or score must be provided"); + } if (filterField != null && filterValue != null) { builder.startObject("filter"); builder.startObject("term"); @@ -562,13 +653,17 @@ private void validateRadiusSearchResults( List actualScores = parseSearchResponseScore(responseBody, FIELD_NAME); for (KNNResult result : radiusResults) { - float[] primitiveArray = Floats.toArray(Arrays.stream(result.getVector()).collect(Collectors.toList())); - float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, searchVectors[i]); + float[] vector = result.getVector(); + float distance = TestUtils.computeDistFromSpaceType(spaceType, vector, searchVectors[i]); float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance); if (spaceType == SpaceType.COSINESIMIL) { distance = 1 - distance; } - assertTrue(distance <= radius); + if (distanceThreshold != null) { + assertTrue(distance <= distanceThreshold); + } else { + assertTrue(rawScore >= scoreThreshold); + } assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(radiusResults.indexOf(result)), 0.0001); } } 
diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 3b0f03b38..8998dce69 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -61,6 +61,7 @@ public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; private static final Float DISTANCE = 1.0f; + private static final Float SCORE = 0.5f; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -91,6 +92,24 @@ public void testInvalidDistance() { expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).distance(null)); } + public void testInvalidScore() { + float[] queryVector = { 1.0f, 1.0f }; + /** + * null score + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(null)); + + /** + * negative score + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(-1.0f)); + + /** + * score = 0 + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(0.0f)); + } + public void testEmptyVector() { /** * null query vector @@ -133,7 +152,7 @@ public void testFromXContent() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_whenDoRadiusSearch_thenSucceed() throws Exception { + public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); XContentBuilder builder = 
XContentFactory.jsonBuilder(); @@ -149,6 +168,22 @@ public void testFromXContent_whenDoRadiusSearch_thenSucceed() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } + public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(DISTANCE); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getScore()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + public void testFromXContent_withFilter() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); @@ -171,7 +206,7 @@ public void testFromXContent_withFilter() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_wenDoRadiusSearch_whenFilter_thenSucceed() throws Exception { + public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); @@ -193,6 +228,28 @@ public void testFromXContent_wenDoRadiusSearch_whenFilter_thenSucceed() throws E assertEquals(knnQueryBuilder, actualBuilder); } + public void testFromXContent_wenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil 
knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE).filter(TERM_QUERY); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getScore()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + public void testFromXContent_InvalidQueryVectorType() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); @@ -323,7 +380,7 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } - public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenSucceed() { + public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); Index dummyIndex = new Index("dummy", "dummy"); @@ -340,7 +397,24 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenSucceed() { assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(DISTANCE, SpaceType.L2))); } - public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { + public void 
testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertTrue(query.toString().contains("resultSimilarity=" + 0.5f)); + } + + public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); @@ -364,7 +438,7 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu assertEquals(negativeDistance, query.getRadius(), 0); } - public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { + public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; 
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); @@ -386,6 +460,52 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float score = 5f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(score); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(1 - score, query.getRadius(), 0); + } + + public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupportedSpaceType_thenException() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float score = 5f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, 
queryVector).score(score); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -405,7 +525,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } - public void testDoToQuery_whenDoRadiusSearch_whenFilter_thenSucceed() { + public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE).filter(TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); @@ -423,6 +543,24 @@ public void testDoToQuery_whenDoRadiusSearch_whenFilter_thenSucceed() { 
assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); } + public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE).filter(TERM_QUERY); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); + } + public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -488,6 +626,70 @@ public void testDoToQuery_FromModel() { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + Index dummyIndex = new Index("dummy", 
"dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + when(mockKNNVectorField.getDimension()).thenReturn(-K); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); + String modelId = "test-model-id"; + when(mockKNNVectorField.getModelId()).thenReturn(modelId); + + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + KNNQueryBuilder.initialize(modelDao); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertEquals(knnQueryBuilder.getDistance(), query.getRadius(), 0); + assertEquals(knnQueryBuilder.fieldName(), query.getField()); + assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); + } + + public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = 
mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + when(mockKNNVectorField.getDimension()).thenReturn(-K); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); + String modelId = "test-model-id"; + when(mockKNNVectorField.getModelId()).thenReturn(modelId); + + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + KNNQueryBuilder.initialize(modelDao); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(1 / knnQueryBuilder.getScore() - 1, query.getRadius(), 0); + assertEquals(knnQueryBuilder.fieldName(), query.getField()); + assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); + } + public void testDoToQuery_InvalidDimensions() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); @@ -556,18 +758,28 @@ public void testDoToQuery_InvalidZeroByteVector() { } public void testSerialization() throws Exception { - assertSerialization(Version.CURRENT, Optional.empty(), K, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, null); - - // For radius search - assertSerialization(Version.CURRENT, Optional.empty(), null, DISTANCE); 
- assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, DISTANCE); + // For k-NN search + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null); + + // For distance threshold search + assertSerialization(Version.CURRENT, Optional.empty(), null, DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, DISTANCE, null); + + // For score threshold search + assertSerialization(Version.CURRENT, Optional.empty(), null, null, SCORE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, SCORE); } - private void assertSerialization(final Version version, final Optional queryBuilderOptional, Integer k, Float distance) - throws Exception { - final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance); + private void assertSerialization( + final Version version, + final Optional queryBuilderOptional, + Integer k, + Float distance, + Float score + ) throws Exception { + final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance, score); final ClusterService clusterService = mockClusterService(version); @@ -588,8 +800,10 @@ private void assertSerialization(final Version version, final Optional queryBuilderOptional, Integer k, Float distance) { + private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBuilderOptional, Integer k, Float distance, Float score) { final KNNQueryBuilder knnQueryBuilder; if (k != null) { knnQueryBuilder = queryBuilderOptional.isPresent() @@ -611,6 +825,10 @@ private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBu knnQueryBuilder = queryBuilderOptional.isPresent() ? 
new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance).filter(queryBuilderOptional.get()) : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance); + } else if (score != null) { + knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).score(score).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).score(score); } else { throw new IllegalArgumentException("Either k or distance must be provided"); } diff --git a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java index 747bfc751..9e6bd67ea 100644 --- a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java @@ -132,6 +132,10 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return 0f; } + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0f; + } + @Override public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { return 0; diff --git a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java index 5824dfbce..3c3afbee6 100644 --- a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java @@ -69,5 +69,10 @@ public TestNativeLibrary( public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return 0.0f; } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0.0f; + } } } diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 6cebded1b..cff4d5805 100644 --- 
a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -68,6 +68,11 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return 0.0f; } + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0.0f; + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return null; From 4352d4b03b9e643e7fa68b1115252749e745675b Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Fri, 12 Apr 2024 09:51:13 -0700 Subject: [PATCH 4/6] Rename radial search parameters score and distance to min_score and max_distance (#1609) * Rename radial search parameters score and distance to min_score and max_distance Signed-off-by: Junqiu Lei --- .../knn/index/query/KNNQueryBuilder.java | 108 +++++++------- .../org/opensearch/knn/index/FaissIT.java | 4 +- .../opensearch/knn/index/LuceneEngineIT.java | 4 +- .../knn/index/query/KNNQueryBuilderTests.java | 134 ++++++++++++------ 4 files changed, 148 insertions(+), 102 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 78ddb532d..7d3667ac0 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -51,8 +51,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); - public static final ParseField DISTANCE_FIELD = new ParseField("distance"); - public static final ParseField SCORE_FIELD = new ParseField("score"); + public static final ParseField MAX_DISTANCE_FIELD = new 
ParseField("max_distance"); + public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score"); public static final int K_MAX = 10000; /** * The name for the knn query @@ -64,17 +64,17 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final String fieldName; private final float[] vector; private int k = 0; - private Float distance = null; - private Float score = null; + private Float max_distance = null; + private Float min_score = null; private QueryBuilder filter; private boolean ignoreUnmapped = false; /** - * Constructs a new query with the given field name and vector + * Constructs a new query with the given field name and vector * * @param fieldName Name of the field * @param vector Array of floating points - */ + */ public KNNQueryBuilder(String fieldName, float[] vector) { if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException("[" + NAME + "] requires fieldName"); @@ -98,7 +98,7 @@ public KNNQueryBuilder k(Integer k) { if (k == null) { throw new IllegalArgumentException("[" + NAME + "] requires k to be set"); } - validateSingleQueryType(k, distance, score); + validateSingleQueryType(k, max_distance, min_score); if (k <= 0 || k > K_MAX) { throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); } @@ -107,33 +107,33 @@ public KNNQueryBuilder k(Integer k) { } /** - * Builder method for distance + * Builder method for max_distance * - * @param distance the distance threshold for the nearest neighbours + * @param max_distance the max_distance threshold for the nearest neighbours */ - public KNNQueryBuilder distance(Float distance) { - if (distance == null) { - throw new IllegalArgumentException("[" + NAME + "] requires distance to be set"); + public KNNQueryBuilder maxDistance(Float max_distance) { + if (max_distance == null) { + throw new IllegalArgumentException("[" + NAME + "] requires max_distance to be set"); } - validateSingleQueryType(k, distance, score); - this.distance = 
distance; + validateSingleQueryType(k, max_distance, min_score); + this.max_distance = max_distance; return this; } /** - * Builder method for score + * Builder method for min_score * - * @param score the score threshold for the nearest neighbours + * @param min_score the min_score threshold for the nearest neighbours */ - public KNNQueryBuilder score(Float score) { - if (score == null) { - throw new IllegalArgumentException("[" + NAME + "] requires score to be set"); + public KNNQueryBuilder minScore(Float min_score) { + if (min_score == null) { + throw new IllegalArgumentException("[" + NAME + "] requires min_score to be set"); } - validateSingleQueryType(k, distance, score); - if (score <= 0) { - throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0"); + validateSingleQueryType(k, max_distance, min_score); + if (min_score <= 0) { + throw new IllegalArgumentException("[" + NAME + "] requires min_score greater than 0"); } - this.score = score; + this.min_score = min_score; return this; } @@ -180,8 +180,8 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.k = k; this.filter = filter; this.ignoreUnmapped = false; - this.distance = null; - this.score = null; + this.max_distance = null; + this.min_score = null; } public static void initialize(ModelDao modelDao) { @@ -217,10 +217,10 @@ public KNNQueryBuilder(StreamInput in) throws IOException { ignoreUnmapped = in.readOptionalBoolean(); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - distance = in.readOptionalFloat(); + max_distance = in.readOptionalFloat(); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - score = in.readOptionalFloat(); + min_score = in.readOptionalFloat(); } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); @@ -232,8 +232,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep List vector = 
null; float boost = AbstractQueryBuilder.DEFAULT_BOOST; Integer k = null; - Float distance = null; - Float score = null; + Float max_distance = null; + Float min_score = null; QueryBuilder filter = null; String queryName = null; String currentFieldName = null; @@ -262,10 +262,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); - } else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); - } else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + max_distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + min_score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else { throw new ParsingException( parser.getTokenLocation(), @@ -295,7 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - validateSingleQueryType(k, distance, score); + validateSingleQueryType(k, max_distance, min_score); KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .ignoreUnmapped(ignoreUnmapped) @@ -304,10 +304,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (k != null) { knnQueryBuilder.k(k); - } else if (distance != null) { - knnQueryBuilder.distance(distance); - } else if (score != null) { - knnQueryBuilder.score(score); + } else if (max_distance != null) { + knnQueryBuilder.maxDistance(max_distance); + } else if 
(min_score != null) { + knnQueryBuilder.minScore(min_score); } return knnQueryBuilder; @@ -323,10 +323,10 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(ignoreUnmapped); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(distance); + out.writeOptionalFloat(max_distance); } if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { - out.writeOptionalFloat(score); + out.writeOptionalFloat(min_score); } } @@ -348,12 +348,12 @@ public int getK() { return this.k; } - public float getDistance() { - return this.distance; + public float getMaxDistance() { + return this.max_distance; } - public float getScore() { - return this.score; + public float getMinScore() { + return this.min_score; } public QueryBuilder getFilter() { @@ -384,14 +384,14 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (filter != null) { builder.field(FILTER_FIELD.getPreferredName(), filter); } - if (distance != null) { - builder.field(DISTANCE_FIELD.getPreferredName(), distance); + if (max_distance != null) { + builder.field(MAX_DISTANCE_FIELD.getPreferredName(), max_distance); } if (ignoreUnmapped) { builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); } - if (score != null) { - builder.field(SCORE_FIELD.getPreferredName(), score); + if (min_score != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), min_score); } printBoostAndQueryName(builder); builder.endObject(); @@ -435,18 +435,18 @@ protected Query doToQuery(QueryShardContext context) { // Currently, k-NN supports distance and score types radial search // We need transform distance/score to right type of engine required radius. 
Float radius = null; - if (this.distance != null) { - if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + if (this.max_distance != null) { + if (this.max_distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType); } - radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType); + radius = knnEngine.distanceToRadialThreshold(this.max_distance, spaceType); } - if (this.score != null) { - if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + if (this.min_score != null) { + if (this.min_score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType); } - radius = knnEngine.scoreToRadialThreshold(this.score, spaceType); + radius = knnEngine.scoreToRadialThreshold(this.min_score, spaceType); } if (fieldDimension != vector.length) { diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 16eb1a4c3..bcefeb7f4 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -1702,9 +1702,9 @@ private void validateRadiusSearchResults( queryBuilder.startObject(fieldName); queryBuilder.field("vector", queryVector); if (distanceThreshold != null) { - queryBuilder.field("distance", distanceThreshold); + queryBuilder.field("max_distance", distanceThreshold); } else if (scoreThreshold != null) { - queryBuilder.field("score", scoreThreshold); + queryBuilder.field("min_score", scoreThreshold); } else { throw new IllegalArgumentException("Invalid threshold"); } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index f721606e1..ab55741d3 100644 --- 
a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -629,9 +629,9 @@ private void validateRadiusSearchResults( builder.startObject(FIELD_NAME); builder.field("vector", searchVectors[i]); if (distanceThreshold != null) { - builder.field("distance", distanceThreshold); + builder.field("max_distance", distanceThreshold); } else if (scoreThreshold != null) { - builder.field("score", scoreThreshold); + builder.field("min_score", scoreThreshold); } else { throw new IllegalArgumentException("Either distance or score must be provided"); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 8998dce69..1922e5a08 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -60,8 +60,8 @@ public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; - private static final Float DISTANCE = 1.0f; - private static final Float SCORE = 0.5f; + private static final Float MAX_DISTANCE = 1.0f; + private static final Float MIN_SCORE = 0.5f; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -89,25 +89,25 @@ public void testInvalidDistance() { /** * null distance */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).distance(null)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(null)); } public void testInvalidScore() { float[] queryVector = { 1.0f, 1.0f }; /** - * null score + * null min_score */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, 
queryVector).score(null)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(null)); /** - * negative score + * negative min_score */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(-1.0f)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(-1.0f)); /** - * score = 0 + * min_score = 0 */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(0.0f)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(0.0f)); } public void testEmptyVector() { @@ -127,13 +127,13 @@ public void testEmptyVector() { * null query vector with distance */ float[] queryVector2 = null; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector2).distance(DISTANCE)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector2).maxDistance(MAX_DISTANCE)); /** * empty query vector with distance */ float[] queryVector3 = {}; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector3).distance(DISTANCE)); + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector3).maxDistance(MAX_DISTANCE)); } public void testFromXContent() throws Exception { @@ -154,12 +154,12 @@ public void testFromXContent() throws Exception { public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); 
builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getDistance()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -170,12 +170,12 @@ public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSuccee public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MAX_DISTANCE); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getScore()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -213,12 +213,12 @@ public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_ knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); 
builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getDistance()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMaxDistance()); builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); builder.endObject(); builder.endObject(); @@ -235,12 +235,12 @@ public void testFromXContent_wenDoRadiusSearch_whenScoreThreshold_whenFilter_the knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getScore()); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); builder.endObject(); builder.endObject(); @@ -294,7 +294,7 @@ public void testFromXContent_whenDoRadiusSearch_whenInputInvalidQueryVectorType_ builder.startObject(); builder.startObject(FIELD_NAME); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector); - builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), DISTANCE); + builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), MAX_DISTANCE); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -382,7 +382,7 @@ public void testDoToQuery_Normal() throws Exception { 
public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -394,12 +394,12 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); - assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(DISTANCE, SpaceType.L2))); + assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2))); } public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -417,7 +417,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS public void 
testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -441,7 +441,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -463,7 +463,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(score); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); 
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -487,7 +487,53 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(score); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + + public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + Index dummyIndex = new 
Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(negativeDistance, query.getRadius(), 0); + } + + public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float negativeDistance = -1.0f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -527,7 +573,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE).filter(TERM_QUERY); + 
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -545,7 +591,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -628,7 +674,7 @@ public void testDoToQuery_FromModel() { public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -653,14 +699,14 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); - 
assertEquals(knnQueryBuilder.getDistance(), query.getRadius(), 0); + assertEquals(knnQueryBuilder.getMaxDistance(), query.getRadius(), 0); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -685,7 +731,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); - assertEquals(1 / knnQueryBuilder.getScore() - 1, query.getRadius(), 0); + assertEquals(1 / knnQueryBuilder.getMinScore() - 1, query.getRadius(), 0); assertEquals(knnQueryBuilder.fieldName(), query.getField()); assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } @@ -764,12 +810,12 @@ public void testSerialization() throws Exception { assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null); // For distance threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, DISTANCE, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.empty(), null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, MAX_DISTANCE, null); // For score threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, 
null, SCORE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, SCORE); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MIN_SCORE); } private void assertSerialization( @@ -801,9 +847,9 @@ private void assertSerialization( if (k != null) { assertEquals(k.intValue(), deserializedKnnQueryBuilder.getK()); } else if (distance != null) { - assertEquals(distance.floatValue(), deserializedKnnQueryBuilder.getDistance(), 0.0f); + assertEquals(distance.floatValue(), deserializedKnnQueryBuilder.getMaxDistance(), 0.0f); } else { - assertEquals(score.floatValue(), deserializedKnnQueryBuilder.getScore(), 0.0f); + assertEquals(score.floatValue(), deserializedKnnQueryBuilder.getMinScore(), 0.0f); } if (queryBuilderOptional.isPresent()) { assertNotNull(deserializedKnnQueryBuilder.getFilter()); @@ -823,12 +869,12 @@ private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBu : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k); } else if (distance != null) { knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance).filter(queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance); + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance); } else if (score != null) { knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).score(score).filter(queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).score(score); + ? 
new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score); } else { throw new IllegalArgumentException("Either k or distance must be provided"); } @@ -857,7 +903,7 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { SpaceType.L2, new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()) ); - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(DISTANCE); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(MAX_DISTANCE); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); From 9ba05bd32d76ae6f3ec737a732964622ffc7c596 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Wed, 17 Apr 2024 08:40:24 -0700 Subject: [PATCH 5/6] Update MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH to 2.14 Signed-off-by: Junqiu Lei --- src/main/java/org/opensearch/knn/index/IndexUtil.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index cd4ca7822..c71f767a5 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -43,7 +43,7 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0; - private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_13_0; + private static final Version 
MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0; private static final Map minimalRequiredVersionMap = new HashMap() { { put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); From cb1b3f8907a5584c602f0c137a156dd732befe18 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Wed, 17 Apr 2024 10:58:49 -0700 Subject: [PATCH 6/6] Update CHANGELOG.md Signed-off-by: Junqiu Lei --- CHANGELOG.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9e34def9..3a9da1d49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,9 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x) ### Features -* Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498) -* Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546) -* Support score type threshold in radial search [#1589](https://github.com/opensearch-project/k-NN/pull/1589) +* Support radial search in k-NN plugin [#1617](https://github.com/opensearch-project/k-NN/pull/1617) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)