Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added exact search for cases when filteredIds < k to improve the recall for filtered k-NN Search #928

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 113 additions & 58 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
Expand All @@ -35,10 +40,12 @@
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -97,6 +104,79 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
if (filterWeight != null && filterIdsArray.length == 0) {
return KNNScorer.emptyScorer(this);
}
final Map<Integer, Float> docIdsToScoreMap = new HashMap<>();

/*
* The idea for this optimization is to get K results, we need to atleast look at K vectors in the HNSW graph
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search.
* This improves the recall.
*/
if (filterWeight != null && filterIdsArray.length <= knnQuery.getK()) {
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray));
} else {
final Map<Integer, Float> annResults = doANNSearch(context, filterIdsArray);
if (annResults == null) {
return null;
}
docIdsToScoreMap.putAll(annResults);
}
return convertSearchResponseToScorer(docIdsToScoreMap);
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException {
final Bits liveDocs = ctx.reader().getLiveDocs();
final int maxDoc = ctx.reader().maxDoc();

final Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return new FixedBitSet(0);
}

final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
// TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the
// distance for K vectors. This can avoid calls to native layer and save some latency.
final int cost = acceptDocs.cardinality();
log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost);
return acceptDocs;
}

private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return ((BitSetIterator) filteredDocIdsIterator).getBitSet();
}
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
return BitSet.of(filterIterator, maxDoc);
}

private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException {
if (filterWeight == null) {
return new int[0];
}
final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight);
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
int docId = 0;
while (true) {
docId = filteredDocsBitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
log.debug("Docs in filtered docs id set is : {}", docId);
filteredIds[filteredIdsIndex] = docId;
filteredIdsIndex++;
docId++;
}
return filteredIds;
}

private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException {
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

Expand Down Expand Up @@ -200,70 +280,45 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return null;
}

Map<Integer, Float> scores = Arrays.stream(results)
return Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
int maxDoc = Collections.max(scores.keySet()) + 1;
DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);

// The docIdSetIterator will contain the docids of the returned results. So, before adding results to
// the builder, we can grow to results.length
DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(results.length);
Arrays.stream(results).forEach(result -> setAdder.add(result.getId()));
DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
return new KNNScorer(this, docIdSetIter, scores, boost);
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException {
final Bits liveDocs = ctx.reader().getLiveDocs();
final int maxDoc = ctx.reader().maxDoc();

final Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return new FixedBitSet(0);
}

final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
// TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the
// distance for K vectors. This can avoid calls to native layer and save some latency.
final int cost = acceptDocs.cardinality();
log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost);
return acceptDocs;
}

private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return ((BitSetIterator) filteredDocIdsIterator).getBitSet();
}
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.name);
final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));

final Map<Integer, Float> docToScore = new HashMap<>();
for (int j : filterIdsArray) {
int docId = values.advance(j);
BytesRef value = values.binaryValue();
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
// making min score as high score as this is closest to the vector
float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
docToScore.put(docId, score);
}
};
return BitSet.of(filterIterator, maxDoc);
return docToScore;
} catch (Exception e) {
log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery);
}
return Collections.emptyMap();
}

private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException {
if (filterWeight == null) {
return new int[0];
}
final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight);
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
int docId = 0;
while (true) {
docId = filteredDocsBitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
log.debug("Docs in filtered docs id set is : {}", docId);
filteredIds[filteredIdsIndex] = docId;
filteredIdsIndex++;
docId++;
}
return filteredIds;
private Scorer convertSearchResponseToScorer(final Map<Integer, Float> docsToScore) throws IOException {
final int maxDoc = Collections.max(docsToScore.keySet()) + 1;
final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
// The docIdSetIterator will contain the docids of the returned results. So, before adding results to
// the builder, we can grow to docsToScore.size()
final DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(docsToScore.size());
docsToScore.keySet().forEach(setAdder::add);
final DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
return new KNNScorer(this, docIdSetIter, docsToScore, boost);
}

@Override
Expand Down