diff --git a/CHANGELOG.md b/CHANGELOG.md index d11cb8869..90be873d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features * Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498) * Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546) +* Support score type threshold in radial search [#1589](https://github.com/opensearch-project/k-NN/pull/1589) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 50d8d352c..240bfbe91 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -36,6 +36,14 @@ public float scoreTranslation(float rawScore) { public VectorSimilarityFunction getVectorSimilarityFunction() { return VectorSimilarityFunction.EUCLIDEAN; } + + @Override + public float scoreToDistanceTranslation(float score) { + if (score == 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "score cannot be 0 when space type is [%s]", getValue())); + } + return 1 / score - 1; + } }, COSINESIMIL("cosinesimil") { @Override @@ -170,4 +178,14 @@ public static SpaceType getSpace(String spaceTypeName) { } throw new IllegalArgumentException("Unable to find space: " + spaceTypeName); } + + /** + * Translate a score to a distance for this space type + * + * @param score score to translate + * @return translated distance + */ + public float scoreToDistanceTranslation(float score) { + throw new UnsupportedOperationException(String.format("Space [%s] does not have a score to distance translation", getValue())); + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 24487e2d2..78ddb532d 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -52,6 +52,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); public static final ParseField DISTANCE_FIELD = new ParseField("distance"); + public static final ParseField SCORE_FIELD = new ParseField("score"); public static final int K_MAX = 10000; /** * The name for the knn query @@ -64,6 +65,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { private final float[] vector; private int k = 0; private Float distance = null; + private Float score = null; private QueryBuilder filter; private boolean ignoreUnmapped = false; @@ -92,13 +94,14 @@ public KNNQueryBuilder(String fieldName, float[] vector) { * * @param k K nearest neighbours for the given vector */ - public KNNQueryBuilder k(int k) { + public KNNQueryBuilder k(Integer k) { + if (k == null) { + throw new IllegalArgumentException("[" + NAME + "] requires k to be set"); + } + validateSingleQueryType(k, distance, score); if (k <= 0 || k > K_MAX) { throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); } - if (distance != null) { - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); - } this.k = k; return this; } @@ -112,13 +115,28 @@ public KNNQueryBuilder distance(Float distance) { if (distance == null) { throw new IllegalArgumentException("[" + NAME + "] requires distance to be set"); } - if (k != 0) { - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); - } + validateSingleQueryType(k, distance, score); this.distance = distance; return this; } + /** + * Builder method for score + * + * @param score the score threshold for the nearest neighbours + */ + public KNNQueryBuilder score(Float score) { + if (score == null) { + throw new IllegalArgumentException("[" + NAME + "] requires score to be set"); + } + validateSingleQueryType(k, distance, score); + if (score <= 0) { + throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0"); + } + this.score = score; + return this; + } + /** * Builder method for filter * @@ -163,6 +181,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.filter = filter; this.ignoreUnmapped = false; this.distance = null; + this.score = null; } public static void initialize(ModelDao modelDao) { @@ -200,6 +219,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { distance = in.readOptionalFloat(); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + score = in.readOptionalFloat(); + } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } @@ -211,6 +233,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep float boost = AbstractQueryBuilder.DEFAULT_BOOST; Integer k = null; Float distance = null; + Float score = null; QueryBuilder filter = null; String queryName = null; String currentFieldName = null; @@ -241,6 +264,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep queryName = parser.text(); } else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); + } else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else { throw new ParsingException( parser.getTokenLocation(), @@ -270,9 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - if ((k != null && distance != null) || (k == null && distance == null)) { - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); - } + validateSingleQueryType(k, distance, score); KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .ignoreUnmapped(ignoreUnmapped) @@ -281,8 +304,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep if (k != null) { knnQueryBuilder.k(k); - } else { + } else if (distance != null) { knnQueryBuilder.distance(distance); + } else if (score != null) { + knnQueryBuilder.score(score); } return knnQueryBuilder; @@ -300,6 +325,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { out.writeOptionalFloat(distance); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { + out.writeOptionalFloat(score); + } } /** @@ -324,6 +352,10 @@ public float getDistance() { return this.distance; } + public float getScore() { + return this.score; + } + public QueryBuilder getFilter() { return this.filter; } @@ -358,6 +390,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (ignoreUnmapped) { builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped); } + if (score != null) { + builder.field(SCORE_FIELD.getPreferredName(), score); + } printBoostAndQueryName(builder); builder.endObject(); builder.endObject(); @@ -397,8 +432,8 @@ protected Query doToQuery(QueryShardContext context) { spaceType = knnMethodContext.getSpaceType(); } - // Currently, k-NN supports distance type radius search. - // We need transform distance radius to right type of engine required radius. + // Currently, k-NN supports distance and score types radial search + // We need transform distance/score to right type of engine required radius. Float radius = null; if (this.distance != null) { if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { @@ -407,6 +442,13 @@ protected Query doToQuery(QueryShardContext context) { radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType); } + if (this.score != null) { + if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) { + throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType); + } + radius = knnEngine.scoreToRadialThreshold(this.score, spaceType); + } + if (fieldDimension != vector.length) { throw new IllegalArgumentException( String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vector.length, fieldDimension) @@ -464,7 +506,7 @@ protected Query doToQuery(QueryShardContext context) { .build(); return RNNQueryFactory.create(createQueryRequest); } - throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set"); + throw new IllegalArgumentException("[" + NAME + "] requires either k or distance or score to be set"); } private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { @@ -499,4 +541,24 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static void validateSingleQueryType(Integer k, Float distance, Float score) { + int countSetFields = 0; + + if (k != null && k != 0) { + countSetFields++; + } + if (distance != null) { + countSetFields++; + } + if (score != null) { + countSetFields++; + } + + if (countSetFields != 1) { + throw new IllegalArgumentException( + "[" + NAME + "] requires only one query type to be set, it can be either k, distance, or score" + ); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 8b58b7e74..efd8a637c 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -56,6 +56,7 @@ * Implements NativeLibrary for the faiss native library */ class Faiss extends NativeLibrary { + Map> scoreTransform; // TODO: Current version is not really current version. Instead, it encodes information in the file name // about the compatibility version the file is created with. In the future, we should refactor this so that it @@ -68,6 +69,12 @@ class Faiss extends NativeLibrary { rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore) ); + // Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation: + // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces + private final static Map> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.< + SpaceType, + Function>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build(); + // Define encoders supported by faiss private final static MethodComponentContext ENCODER_DEFAULT = new MethodComponentContext( KNNConstants.ENCODER_FLAT, @@ -301,7 +308,13 @@ class Faiss extends NativeLibrary { ).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build() ); - final static Faiss INSTANCE = new Faiss(METHODS, SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION); + final static Faiss INSTANCE = new Faiss( + METHODS, + SCORE_TRANSLATIONS, + CURRENT_VERSION, + KNNConstants.FAISS_EXTENSION, + SCORE_TO_DISTANCE_TRANSFORMATIONS + ); /** * Constructor for Faiss @@ -315,9 +328,11 @@ private Faiss( Map methods, Map> scoreTranslation, String currentVersion, - String extension + String extension, + Map> scoreTransform ) { super(methods, scoreTranslation, currentVersion, extension); + this.scoreTransform = scoreTransform; } @Override @@ -326,6 +341,15 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return distance; } + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + // Faiss engine uses distance as is and need transformation + if (this.scoreTransform.containsKey(spaceType)) { + return this.scoreTransform.get(spaceType).apply(score); + } + return spaceType.scoreToDistanceTranslation(score); + } + /** * MethodAsMap builder is used to create the map that will be passed to the jni to create the faiss index. * Faiss's index factory takes an "index description" that it uses to build the index. In this description, diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index 6551f5a39..e282c69db 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -158,6 +158,11 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return knnLibrary.distanceToRadialThreshold(distance, spaceType); } + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return knnLibrary.scoreToRadialThreshold(score, spaceType); + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return knnLibrary.validateMethod(knnMethodContext); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index 2c1f454c0..f837566b8 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -78,6 +78,16 @@ public interface KNNLibrary { */ Float distanceToRadialThreshold(Float distance, SpaceType spaceType); + /** + * Translate the score threshold input from end user to the engine's threshold. + * + * @param score score threshold input from end user + * @param spaceType spaceType used to compute the threshold + * + * @return transformed score for the library + */ + Float scoreToRadialThreshold(Float score, SpaceType spaceType); + /** * Validate the knnMethodContext for the given library. A ValidationException should be thrown if the method is * deemed invalid. diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index 54a752408..630d7a2c2 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -48,11 +48,13 @@ public class Lucene extends JVMLibrary { ).addSpaces(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build() ); + // Map that overrides the default distance translations for Lucene, check more details in knn documentation: + // https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces private final static Map> DISTANCE_TRANSLATIONS = ImmutableMap.< SpaceType, Function>builder() .put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2) - .put(SpaceType.L2, distance -> 1 / (1 + distance)) + .put(SpaceType.INNER_PRODUCT, distance -> distance <= 0 ? 1 / (1 - distance) : distance + 1) .build(); final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS); @@ -90,7 +92,16 @@ public float score(float rawScore, SpaceType spaceType) { @Override public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { // Lucene requires score threshold to be parameterized when calling the radius search. - return this.distanceTransform.get(spaceType).apply(distance); + if (this.distanceTransform.containsKey(spaceType)) { + return this.distanceTransform.get(spaceType).apply(distance); + } + return spaceType.scoreTranslation(distance); + } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + // Lucene engine uses distance as is and does not need transformation + return score; } @Override diff --git a/src/main/java/org/opensearch/knn/index/util/Nmslib.java b/src/main/java/org/opensearch/knn/index/util/Nmslib.java index 8993cac8e..64af43520 100644 --- a/src/main/java/org/opensearch/knn/index/util/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/util/Nmslib.java @@ -73,4 +73,8 @@ private Nmslib( public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return distance; } + + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return score; + } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 653750719..16eb1a4c3 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -278,7 +278,7 @@ public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { } @SneakyThrows - public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWFlat_thenSucceed() { + public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWFlat_thenSucceed() { KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); SpaceType spaceType = SpaceType.L2; @@ -329,17 +329,135 @@ public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWFlat_thenSucceed() { refreshAllNonSystemIndices(); assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); - float radius = 300000000000f; - for (float[] queryVector : testData.queries) { - validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, queryVector, radius, spaceType); + float distance = 300000000000f; + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, distance, null, spaceType); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); } + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float score = 0.00001f; + + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + // Delete index deleteKNNIndex(INDEX_NAME); } @SneakyThrows - public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWPQ_thenSucceed() { + public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMethodIsHNSWFlat_thenSucceed() { + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.INNER_PRODUCT; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(INDEX_NAME, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(INDEX_NAME))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float score = 5f; + + validateRadiusSearchResults(INDEX_NAME, FIELD_NAME, testData.queries, null, score, spaceType); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch__whenDistanceThreshold_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; String fieldName = "test-field"; String trainingIndexName = "training-index"; @@ -415,11 +533,9 @@ public void testEndToEnd_whenDoRadiusSearch_whenMethodIsHNSWPQ_thenSucceed() { refreshAllNonSystemIndices(); assertEquals(testData.indexData.docs.length, getDocCount(indexName)); - float radius = 300000000000f; + float distance = 300000000000f; - for (float[] queryVector : testData.queries) { - validateRadiusSearchResults(indexName, fieldName, queryVector, radius, spaceType); - } + validateRadiusSearchResults(indexName, fieldName, testData.queries, distance, null, spaceType); // Delete index deleteKNNIndex(indexName); @@ -1575,30 +1691,40 @@ private void validateGraphEviction() throws Exception { private void validateRadiusSearchResults( String indexName, String fieldName, - float[] queryVector, - float radius, + float[][] queryVectors, + Float distanceThreshold, + Float scoreThreshold, final SpaceType spaceType ) throws IOException, ParseException { - XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); - queryBuilder.startObject("knn"); - queryBuilder.startObject(fieldName); - queryBuilder.field("vector", queryVector); - queryBuilder.field("distance", radius); - queryBuilder.endObject(); - queryBuilder.endObject(); - queryBuilder.endObject().endObject(); - final String responseBody = EntityUtils.toString(searchKNNIndex(indexName, queryBuilder, 10).getEntity()); - - List knnResults = parseSearchResponse(responseBody, fieldName); - - for (KNNResult knnResult : knnResults) { - float[] vector = Floats.toArray(Arrays.stream(knnResult.getVector()).collect(Collectors.toList())); - if (spaceType == SpaceType.L2) { - assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= radius); - } else if (spaceType == SpaceType.INNER_PRODUCT) { - assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= radius); + for (float[] queryVector : queryVectors) { + XContentBuilder queryBuilder = XContentFactory.jsonBuilder().startObject().startObject("query"); + queryBuilder.startObject("knn"); + queryBuilder.startObject(fieldName); + queryBuilder.field("vector", queryVector); + if (distanceThreshold != null) { + queryBuilder.field("distance", distanceThreshold); + } else if (scoreThreshold != null) { + queryBuilder.field("score", scoreThreshold); } else { - throw new IllegalArgumentException("Invalid space type"); + throw new IllegalArgumentException("Invalid threshold"); + } + queryBuilder.endObject(); + queryBuilder.endObject(); + queryBuilder.endObject().endObject(); + final String responseBody = EntityUtils.toString(searchKNNIndex(indexName, queryBuilder, 10).getEntity()); + + List knnResults = parseSearchResponse(responseBody, fieldName); + + for (KNNResult knnResult : knnResults) { + float[] vector = knnResult.getVector(); + float distance = TestUtils.computeDistFromSpaceType(spaceType, vector, queryVector); + if (spaceType == SpaceType.L2) { + assertTrue(KNNScoringUtil.l2Squared(queryVector, vector) <= distance); + } else if (spaceType == SpaceType.INNER_PRODUCT) { + assertTrue(KNNScoringUtil.innerProduct(queryVector, vector) >= distance); + } else { + throw new IllegalArgumentException("Invalid space type"); + } } } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 8f55730fd..f721606e1 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -51,6 +51,14 @@ public class LuceneEngineIT extends KNNRestTestCase { private static final Float[][] TEST_INDEX_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; private static final Float[][] TEST_COSINESIMIL_INDEX_VECTORS = { { 6.0f, 7.0f, 3.0f }, { 3.0f, 2.0f, 5.0f }, { 4.0f, 5.0f, 7.0f } }; + private static final Float[][] TEST_INNER_PRODUCT_INDEX_VECTORS = { + { 1.0f, 1.0f, 1.0f }, + { 2.0f, 2.0f, 2.0f }, + { 3.0f, 3.0f, 3.0f }, + { -1.0f, -1.0f, -1.0f }, + { -2.0f, -2.0f, -2.0f }, + { -3.0f, -3.0f, -3.0f } }; + private static final float[][] TEST_QUERY_VECTORS = { { 1.0f, 1.0f, 1.0f }, { 2.0f, 2.0f, 2.0f }, { 3.0f, 3.0f, 3.0f } }; private static final Map> VECTOR_SIMILARITY_TO_SCORE = ImmutableMap.of( @@ -59,7 +67,9 @@ public class LuceneEngineIT extends KNNRestTestCase { VectorSimilarityFunction.DOT_PRODUCT, (similarity) -> (1 + similarity) / 2, VectorSimilarityFunction.COSINE, - (similarity) -> (1 + similarity) / 2 + (similarity) -> (1 + similarity) / 2, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + (similarity) -> similarity <= 0 ? 1 / (1 - similarity) : similarity + 1 ); private static final String DIMENSION_FIELD_NAME = "dimension"; private static final String KNN_VECTOR_TYPE = "knn_vector"; @@ -318,55 +328,115 @@ public void testIndexReopening() throws Exception { assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } - public void testRadiusSearch_usingL2Metrics_usingFloatType() throws Exception { + public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingFloatType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } - final float radius = 3.5f; + final float distance = 3.5f; final int[] expectedResults = { 2, 3, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null); } - public void testRadiusSearch_usingCosineMetrics_usingFloatType() throws Exception { + public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float score = 0.23f; + final int[] expectedResults = { 2, 3, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingFloatType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); } - final float radius = 0.03f; + final float distance = 0.03f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.COSINESIMIL, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null); } - public void testRadiusSearch_usingL2Metrics_usingByteType() throws Exception { + public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float score = 0.97f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingInnerProductMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.INNER_PRODUCT, VectorDataType.FLOAT); + for (int j = 0; j < TEST_INNER_PRODUCT_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INNER_PRODUCT_INDEX_VECTORS[j]); + } + + final float score = 2f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.INNER_PRODUCT, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingL2Metrics_usingByteType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.BYTE); for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); } - final float radius = 3.5f; + final float distance = 3.5f; final int[] expectedResults = { 2, 2, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, null, null); } - public void testRadiusSearch_usingCosineMetrics_usingByteType() throws Exception { + public void testRadiusSearch_usingScoreThreshold_usingL2Metrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.BYTE); + for (int j = 0; j < TEST_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_INDEX_VECTORS[j]); + } + + final float score = 0.23f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.L2, expectedResults, null, null); + } + + public void testRadiusSearch_usingDistanceThreshold_usingCosineMetrics_usingByteType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.BYTE); + for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { + addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); + } + + final float distance = 0.05f; + final int[] expectedResults = { 2, 2, 2 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.COSINESIMIL, expectedResults, null, null); + } + + public void testRadiusSearch_usingScoreThreshold_usingCosineMetrics_usingByteType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.BYTE); for (int j = 0; j < TEST_COSINESIMIL_INDEX_VECTORS.length; j++) { addKnnDoc(INDEX_NAME, Integer.toString(j + 1), FIELD_NAME, TEST_COSINESIMIL_INDEX_VECTORS[j]); } - final float radius = 0.05f; + final float score = 0.97f; final int[] expectedResults = { 2, 2, 2 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.COSINESIMIL, expectedResults, null, null); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, null, null); } - public void testRadiusSearch_withFilter_usingL2Metrics_usingFloatType() throws Exception { + public void testRadiusSearch_usingDistanceThreshold_withFilter_usingL2Metrics_usingFloatType() throws Exception { createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.L2, VectorDataType.FLOAT); addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f, 3.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); @@ -374,10 +444,24 @@ public void testRadiusSearch_withFilter_usingL2Metrics_usingFloatType() throws E refreshIndex(INDEX_NAME); - final float radius = 45.0f; + final float distance = 45.0f; + final int[] expectedResults = { 1, 1, 1 }; + + validateRadiusSearchResults(TEST_QUERY_VECTORS, distance, null, SpaceType.L2, expectedResults, COLOR_FIELD_NAME, "red"); + } + + public void testRadiusSearch_usingScoreThreshold_withFilter_usingCosineMetrics_usingFloatType() throws Exception { + createKnnIndexMappingWithLuceneEngine(DIMENSION, SpaceType.COSINESIMIL, VectorDataType.FLOAT); + addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f, 3.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "red")); + addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "green")); + + refreshIndex(INDEX_NAME); + + final float score = 0.02f; final int[] expectedResults = { 1, 1, 1 }; - validateRadiusSearchResults(TEST_QUERY_VECTORS, radius, SpaceType.L2, expectedResults, COLOR_FIELD_NAME, "red"); + validateRadiusSearchResults(TEST_QUERY_VECTORS, null, score, SpaceType.COSINESIMIL, expectedResults, COLOR_FIELD_NAME, "red"); } private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, VectorDataType vectorDataType) throws Exception { @@ -532,7 +616,8 @@ public void test_whenUsingIP_thenSuccess() { private void validateRadiusSearchResults( final float[][] searchVectors, - final float radius, + final Float distanceThreshold, + final Float scoreThreshold, final SpaceType spaceType, final int[] expectedResults, @Nullable final String filterField, @@ -543,7 +628,13 @@ private void validateRadiusSearchResults( builder.startObject("knn"); builder.startObject(FIELD_NAME); builder.field("vector", searchVectors[i]); - builder.field("distance", radius); + if (distanceThreshold != null) { + builder.field("distance", distanceThreshold); + } else if (scoreThreshold != null) { + builder.field("score", scoreThreshold); + } else { + throw new IllegalArgumentException("Either distance or score must be provided"); + } if (filterField != null && filterValue != null) { builder.startObject("filter"); builder.startObject("term"); @@ -562,13 +653,17 @@ private void validateRadiusSearchResults( List actualScores = parseSearchResponseScore(responseBody, FIELD_NAME); for (KNNResult result : radiusResults) { - float[] primitiveArray = Floats.toArray(Arrays.stream(result.getVector()).collect(Collectors.toList())); - float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, searchVectors[i]); + float[] vector = result.getVector(); + float distance = TestUtils.computeDistFromSpaceType(spaceType, vector, searchVectors[i]); float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance); if (spaceType == SpaceType.COSINESIMIL) { distance = 1 - distance; } - assertTrue(distance <= radius); + if (distanceThreshold != null) { + assertTrue(distance <= distanceThreshold); + } else { + assertTrue(rawScore >= scoreThreshold); + } assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(radiusResults.indexOf(result)), 0.0001); } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 3b0f03b38..8998dce69 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -61,6 +61,7 @@ public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; private static final Float DISTANCE = 1.0f; + private static final Float SCORE = 0.5f; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); private static final float[] QUERY_VECTOR = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; @@ -91,6 +92,24 @@ public void testInvalidDistance() { expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).distance(null)); } + public void testInvalidScore() { + float[] queryVector = { 1.0f, 1.0f }; + /** + * null score + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(null)); + + /** + * negative score + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(-1.0f)); + + /** + * score = 0 + */ + expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).score(0.0f)); + } + public void testEmptyVector() { /** * null query vector @@ -133,7 +152,7 @@ public void testFromXContent() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_whenDoRadiusSearch_thenSucceed() throws Exception { + public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); XContentBuilder builder = XContentFactory.jsonBuilder(); @@ -149,6 +168,22 @@ public void testFromXContent_whenDoRadiusSearch_thenSucceed() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } + public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(DISTANCE); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getScore()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + public void testFromXContent_withFilter() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); @@ -171,7 +206,7 @@ public void testFromXContent_withFilter() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXContent_wenDoRadiusSearch_whenFilter_thenSucceed() throws Exception { + public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); @@ -193,6 +228,28 @@ public void testFromXContent_wenDoRadiusSearch_whenFilter_thenSucceed() throws E assertEquals(knnQueryBuilder, actualBuilder); } + public void testFromXContent_wenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE).filter(TERM_QUERY); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getScore()); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + public void testFromXContent_InvalidQueryVectorType() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); @@ -323,7 +380,7 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } - public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenSucceed() { + public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); Index dummyIndex = new Index("dummy", "dummy"); @@ -340,7 +397,24 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenSucceed() { assertTrue(query.toString().contains("resultSimilarity=" + KNNEngine.LUCENE.distanceToRadialThreshold(DISTANCE, SpaceType.L2))); } - public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { + public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertTrue(query.toString().contains("resultSimilarity=" + 0.5f)); + } + + public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); @@ -364,7 +438,7 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu assertEquals(negativeDistance, query.getRadius(), 0); } - public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { + public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(negativeDistance); @@ -386,6 +460,52 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } + public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float score = 5f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(score); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(1 - score, query.getRadius(), 0); + } + + public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupportedSpaceType_thenException() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + float score = 5f; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(score); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) + ); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -405,7 +525,7 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } - public void testDoToQuery_whenDoRadiusSearch_whenFilter_thenSucceed() { + public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE).filter(TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); @@ -423,6 +543,24 @@ public void testDoToQuery_whenDoRadiusSearch_whenFilter_thenSucceed() { assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); } + public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE).filter(TERM_QUERY); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query.getClass().isAssignableFrom(FloatVectorSimilarityQuery.class)); + } + public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -488,6 +626,70 @@ public void testDoToQuery_FromModel() { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).distance(DISTANCE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + when(mockKNNVectorField.getDimension()).thenReturn(-K); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); + String modelId = "test-model-id"; + when(mockKNNVectorField.getModelId()).thenReturn(modelId); + + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + KNNQueryBuilder.initialize(modelDao); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + assertEquals(knnQueryBuilder.getDistance(), query.getRadius(), 0); + assertEquals(knnQueryBuilder.fieldName(), query.getField()); + assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); + } + + public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).score(SCORE); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + when(mockKNNVectorField.getDimension()).thenReturn(-K); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); + String modelId = "test-model-id"; + when(mockKNNVectorField.getModelId()).thenReturn(modelId); + + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + KNNQueryBuilder.initialize(modelDao); + IndexSettings indexSettings = mock(IndexSettings.class); + when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); + when(indexSettings.getMaxResultWindow()).thenReturn(1000); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); + + assertEquals(1 / knnQueryBuilder.getScore() - 1, query.getRadius(), 0); + assertEquals(knnQueryBuilder.fieldName(), query.getField()); + assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); + } + public void testDoToQuery_InvalidDimensions() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); @@ -556,18 +758,28 @@ public void testDoToQuery_InvalidZeroByteVector() { } public void testSerialization() throws Exception { - assertSerialization(Version.CURRENT, Optional.empty(), K, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, null); - - // For radius search - assertSerialization(Version.CURRENT, Optional.empty(), null, DISTANCE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, DISTANCE); + // For k-NN search + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null); + + // For distance threshold search + assertSerialization(Version.CURRENT, Optional.empty(), null, DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, DISTANCE, null); + + // For score threshold search + assertSerialization(Version.CURRENT, Optional.empty(), null, null, SCORE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, SCORE); } - private void assertSerialization(final Version version, final Optional queryBuilderOptional, Integer k, Float distance) - throws Exception { - final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance); + private void assertSerialization( + final Version version, + final Optional queryBuilderOptional, + Integer k, + Float distance, + Float score + ) throws Exception { + final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance, score); final ClusterService clusterService = mockClusterService(version); @@ -588,8 +800,10 @@ private void assertSerialization(final Version version, final Optional queryBuilderOptional, Integer k, Float distance) { + private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBuilderOptional, Integer k, Float distance, Float score) { final KNNQueryBuilder knnQueryBuilder; if (k != null) { knnQueryBuilder = queryBuilderOptional.isPresent() @@ -611,6 +825,10 @@ private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBu knnQueryBuilder = queryBuilderOptional.isPresent() ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance).filter(queryBuilderOptional.get()) : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).distance(distance); + } else if (score != null) { + knnQueryBuilder = queryBuilderOptional.isPresent() + ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).score(score).filter(queryBuilderOptional.get()) + : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).score(score); } else { throw new IllegalArgumentException("Either k or distance must be provided"); } diff --git a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java index 747bfc751..9e6bd67ea 100644 --- a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java @@ -132,6 +132,10 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return 0f; } + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0f; + } + @Override public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) { return 0; diff --git a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java index 5824dfbce..3c3afbee6 100644 --- a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java @@ -69,5 +69,10 @@ public TestNativeLibrary( public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return 0.0f; } + + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0.0f; + } } } diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 6cebded1b..cff4d5805 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -68,6 +68,11 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) { return 0.0f; } + @Override + public Float scoreToRadialThreshold(Float score, SpaceType spaceType) { + return 0.0f; + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { return null;