diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 050c36881..5f96a7971 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -6,13 +6,18 @@ package org.opensearch.knn.index.query; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValues; import org.apache.lucene.search.FilteredDocIdSetIterator; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; @@ -35,10 +40,12 @@ import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.plugin.stats.KNNCounter; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -97,6 +104,79 @@ public Scorer scorer(LeafReaderContext context) throws IOException { if (filterWeight != null && filterIdsArray.length == 0) { return KNNScorer.emptyScorer(this); } + final Map docIdsToScoreMap = new HashMap<>(); + + /* + * The idea for this optimization is to get K results, we need to atleast look at K vectors in the HNSW graph + * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. + * This improves the recall. + */ + if (filterWeight != null && filterIdsArray.length <= knnQuery.getK()) { + docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray)); + } else { + final Map annResults = doANNSearch(context, filterIdsArray); + if (annResults == null) { + return null; + } + docIdsToScoreMap.putAll(annResults); + } + return convertSearchResponseToScorer(docIdsToScoreMap); + } + + private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException { + final Bits liveDocs = ctx.reader().getLiveDocs(); + final int maxDoc = ctx.reader().maxDoc(); + + final Scorer scorer = filterWeight.scorer(ctx); + if (scorer == null) { + return new FixedBitSet(0); + } + + final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc); + // TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the + // distance for K vectors. This can avoid calls to native layer and save some latency. + final int cost = acceptDocs.cardinality(); + log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost); + return acceptDocs; + } + + private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException { + if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) { + // If we already have a BitSet and no deletions, reuse the BitSet + return ((BitSetIterator) filteredDocIdsIterator).getBitSet(); + } + // Create a new BitSet from matching and live docs + FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) { + @Override + protected boolean match(int doc) { + return liveDocs == null || liveDocs.get(doc); + } + }; + return BitSet.of(filterIterator, maxDoc); + } + + private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException { + if (filterWeight == null) { + return new int[0]; + } + final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight); + final int[] filteredIds = new int[filteredDocsBitSet.cardinality()]; + int filteredIdsIndex = 0; + int docId = 0; + while (true) { + docId = filteredDocsBitSet.nextSetBit(docId); + if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) { + break; + } + log.debug("Docs in filtered docs id set is : {}", docId); + filteredIds[filteredIdsIndex] = docId; + filteredIdsIndex++; + docId++; + } + return filteredIds; + } + + private Map doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); @@ -200,70 +280,45 @@ public Scorer scorer(LeafReaderContext context) throws IOException { return null; } - Map scores = Arrays.stream(results) + return Arrays.stream(results) .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); - int maxDoc = Collections.max(scores.keySet()) + 1; - DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); - - // The docIdSetIterator will contain the docids of the returned results. So, before adding results to - // the builder, we can grow to results.length - DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(results.length); - Arrays.stream(results).forEach(result -> setAdder.add(result.getId())); - DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); - return new KNNScorer(this, docIdSetIter, scores, boost); } - private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException { - final Bits liveDocs = ctx.reader().getLiveDocs(); - final int maxDoc = ctx.reader().maxDoc(); - - final Scorer scorer = filterWeight.scorer(ctx); - if (scorer == null) { - return new FixedBitSet(0); - } - - final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc); - // TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the - // distance for K vectors. This can avoid calls to native layer and save some latency. - final int cost = acceptDocs.cardinality(); - log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost); - return acceptDocs; - } - - private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException { - if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) { - // If we already have a BitSet and no deletions, reuse the BitSet - return ((BitSetIterator) filteredDocIdsIterator).getBitSet(); - } - // Create a new BitSet from matching and live docs - FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) { - @Override - protected boolean match(int doc) { - return liveDocs == null || liveDocs.get(doc); + private Map doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) { + final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + float[] queryVector = this.knnQuery.getQueryVector(); + try { + final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.name); + final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE)); + + final Map docToScore = new HashMap<>(); + for (int j : filterIdsArray) { + int docId = values.advance(j); + BytesRef value = values.binaryValue(); + ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length); + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + // making min score as high score as this is closest to the vector + float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector); + docToScore.put(docId, score); } - }; - return BitSet.of(filterIterator, maxDoc); + return docToScore; + } catch (Exception e) { + log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery); + } + return Collections.emptyMap(); } - private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException { - if (filterWeight == null) { - return new int[0]; - } - final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight); - final int[] filteredIds = new int[filteredDocsBitSet.cardinality()]; - int filteredIdsIndex = 0; - int docId = 0; - while (true) { - docId = filteredDocsBitSet.nextSetBit(docId); - if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) { - break; - } - log.debug("Docs in filtered docs id set is : {}", docId); - filteredIds[filteredIdsIndex] = docId; - filteredIdsIndex++; - docId++; - } - return filteredIds; + private Scorer convertSearchResponseToScorer(final Map docsToScore) throws IOException { + final int maxDoc = Collections.max(docsToScore.keySet()) + 1; + final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc); + // The docIdSetIterator will contain the docids of the returned results. So, before adding results to + // the builder, we can grow to docsToScore.size() + final DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(docsToScore.size()); + docsToScore.keySet().forEach(setAdder::add); + final DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator(); + return new KNNScorer(this, docIdSetIter, docsToScore, boost); } @Override