diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 8e5849abb6..b58924601c 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -21,6 +21,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator; @@ -36,6 +37,7 @@ import java.io.IOException; import java.util.HashMap; +import java.util.Locale; import java.util.Map; @Log4j2 @@ -55,6 +57,9 @@ public class ExactSearcher { public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + if (exactSearcherContext.getKnnQuery().getRadius() != null) { + return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); + } if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); @@ -62,6 +67,33 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, return searchTopK(iterator, exactSearcherContext.getK()); } + /** + * Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search. + * Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores + * to filter out the documents that does not have given min score. + * @param leafReaderContext + * @param exactSearcherContext + * @param iterator + * @return Map of docId and score + * @throws IOException + */ + private Map doRadialSearch( + LeafReaderContext leafReaderContext, + ExactSearcherContext exactSearcherContext, + KNNIterator iterator + ) throws IOException { + final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); + final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo); + if (KNNEngine.FAISS != engine) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine)); + } + final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); + final float minScore = spaceType.scoreTranslation(knnQuery.getRadius()); + return filterDocsByMinScore(iterator, minScore, knnQuery.getContext().getMaxResultWindow()); + } + private Map scoreAllDocs(KNNIterator iterator) throws IOException { final Map docToScore = new HashMap<>(); int docId; @@ -102,6 +134,42 @@ private Map searchTopK(KNNIterator iterator, int k) throws IOExc return docToScore; } + private Map filterDocsByMinScore(KNNIterator iterator, float minScore, int maxResultWindow) throws IOException { + // Creating min heap and init with MAX DocID and Score as -INF. + final HitQueue queue = new HitQueue(maxResultWindow, true); + ScoreDoc topDoc = queue.top(); + final Map docToScore = new HashMap<>(); + int docId; + while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + final float currentScore = iterator.score(); + // Consider docs which has at least minScore + if (currentScore < minScore) { + continue; + } + if (currentScore > topDoc.score) { + topDoc.score = currentScore; + topDoc.doc = docId; + // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we + // have seen till now on top. + topDoc = queue.updateTop(); + } + } + + // If scores are negative we will remove them. + // This is done, because there can be negative values in the Heap as we initialize the heap with Score as -INF. + // If filterIds < maxResultWindow, then some values in heap can have a negative score. + while (queue.size() > 0 && queue.top().score < 0) { + queue.pop(); + } + + while (queue.size() > 0) { + final ScoreDoc doc = queue.pop(); + docToScore.put(doc.doc, doc.score); + } + + return docToScore; + } + private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index d5cd809341..a81e68e222 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -204,8 +204,8 @@ private int[] bitSetToIntArray(final BitSet bitSet) { private Map doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException { final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder() - .k(k) .isParentHits(true) + .k(k) // setting to true, so that if quantization details are present we want to do search on the quantized // vectors as this flow is used in first pass of search. .useQuantizedVectorsForSearch(true) diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 2df1d8a608..07293a0ed0 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -1708,6 +1708,85 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() { validateGraphEviction(); } + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenNoGraphFileIsCreated_whenDistanceThreshold_thenSucceed() { + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + createKnnIndex(INDEX_NAME, knnIndexSettings, mapping); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + float distance = 300000000000f; + final List> resultsFromDistance = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + distance, + null, + spaceType, + null, + null + ); + assertFalse(resultsFromDistance.isEmpty()); + resultsFromDistance.forEach(result -> { assertFalse(result.isEmpty()); }); + float score = spaceType.scoreTranslation(distance); + final List> resultsFromScore = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + null, + score, + spaceType, + null, + null + ); + assertFalse(resultsFromScore.isEmpty()); + resultsFromScore.forEach(result -> { assertFalse(result.isEmpty()); }); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + @SneakyThrows public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index bf8168d376..c6e5c8fd4a 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -814,6 +814,92 @@ public void testKNNIndex_whenBuildVectorDataStructureIsLessThanDocCount_thenBuil deleteKNNIndex(indexName); } + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold as -1, then index few documents, perform knn search, + then, confirm hits because of exact search though there are no graph. In next step, update setting to 0, force merge segment to 1, perform knn search and confirm expected + hits are returned. + */ + public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSettingUsingRadialSearch() + throws Exception { + final String indexName = "test-index-1"; + final String fieldName1 = "test-field-1"; + final String fieldName2 = "test-field-2"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final List nmslibNeighbors = getResults(indexName, fieldName1, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", nmslibNeighbors.size(), nmslibNeighbors.size()); + + final List faissNeighbors = getResults(indexName, fieldName2, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", faissNeighbors.size(), faissNeighbors.size()); + + // update build vector data structure setting + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + forceMergeKnnIndex(indexName, 1); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search nmslib field + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, nmslibValidNeighbors.size()); + // Search faiss field + final List faissValidNeighbors = getResults(indexName, fieldName2, testData.queries[i], k); + assertEquals(k, faissValidNeighbors.size()); + } + + // Delete index + deleteKNNIndex(indexName); + } + private List getResults(final String indexName, final String fieldName, final float[] vector, final int k) throws IOException, ParseException { final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k); diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index eed2772b41..cca40c55dc 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -155,17 +155,6 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ } } - @SneakyThrows - public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { - // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); - - // Query - float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 }; - Exception e = expectThrows(Exception.class, () -> runRnnQuery(INDEX_NAME, FIELD_NAME, queryVector, 1, 4)); - assertTrue(e.getMessage(), e.getMessage().contains("Binary data type does not support radial search")); - } - private float getRecall(final Set truth, final Set result) { // Count the number of relevant documents retrieved result.retainAll(truth); @@ -178,23 +167,6 @@ private float getRecall(final Set truth, final Set result) { return (float) relevantRetrieved / totalRelevant; } - private List runRnnQuery( - final String indexName, - final String fieldName, - final float[] queryVector, - final float minScore, - final int size - ) throws Exception { - String query = KNNJsonQueryBuilder.builder() - .fieldName(fieldName) - .vector(ArrayUtils.toObject(queryVector)) - .minScore(minScore) - .build() - .getQueryString(); - Response response = searchKNNIndex(indexName, query, size); - return parseSearchResponse(EntityUtils.toString(response.getEntity()), fieldName); - } - private List runKnnQuery(final String indexName, final String fieldName, final float[] queryVector, final int k) throws Exception { String query = KNNJsonQueryBuilder.builder()