Skip to content

Commit

Permalink
Add comment and remove duplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
dungba88 committed Jul 20, 2024
1 parent 62e08f5 commit d9d2205
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ abstract class AbstractKnnVectorQuery extends Query {

private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

/** the KNN vector field to search */
protected final String field;

/** the number of documents to find */
protected final int k;

/** the filter to be executed before the KNN search; may be null when no filtering is needed */
private final Query filter;

public AbstractKnnVectorQuery(String field, int k, Query filter) {
Expand All @@ -70,18 +75,7 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
IndexReader reader = indexSearcher.getIndexReader();

final Weight filterWeight;
if (filter != null) {
BooleanQuery booleanQuery =
new BooleanQuery.Builder()
.add(filter, BooleanClause.Occur.FILTER)
.add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
.build();
Query rewritten = indexSearcher.rewrite(booleanQuery);
filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
} else {
filterWeight = null;
}
final Weight filterWeight = createFilterWeight(indexSearcher);

TimeLimitingKnnCollectorManager knnCollectorManager =
new TimeLimitingKnnCollectorManager(
Expand All @@ -102,6 +96,21 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return createRewrittenQuery(reader, topK);
}

/**
 * Creates a {@link Weight} for the configured filter query, or {@code null} when no filter was
 * supplied. The filter is combined with a {@link FieldExistsQuery} so that only documents that
 * actually contain the KNN vector field can be accepted by the filter.
 *
 * @param indexSearcher the searcher used to rewrite the filter and create its weight
 * @return the filter weight, or {@code null} if this query has no filter
 * @throws IOException if rewriting the filter or creating the weight fails
 */
private Weight createFilterWeight(IndexSearcher indexSearcher) throws IOException {
  if (filter == null) {
    return null;
  }
  // Restrict the user-supplied filter to documents that have the vector field at all.
  Query filterWithVectorField =
      new BooleanQuery.Builder()
          .add(filter, BooleanClause.Occur.FILTER)
          .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
          .build();
  // COMPLETE_NO_SCORES: the filter only selects candidate docs, scores come from the KNN search.
  return indexSearcher.createWeight(
      indexSearcher.rewrite(filterWithVectorField), ScoreMode.COMPLETE_NO_SCORES, 1f);
}

private TopDocs searchLeaf(
LeafReaderContext ctx,
Weight filterWeight,
Expand All @@ -116,6 +125,7 @@ private TopDocs searchLeaf(
return results;
}

// Execute the filter if any and perform KNN search at each segment.
private TopDocs getLeafResults(
LeafReaderContext ctx,
Weight filterWeight,
Expand Down Expand Up @@ -156,7 +166,7 @@ private TopDocs getLeafResults(
}
}

private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
static BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
throws IOException {
if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
Expand Down Expand Up @@ -188,6 +198,8 @@ protected abstract TopDocs approximateSearch(
abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
throws IOException;

// Perform a brute-force search by computing the vector score for each accepted doc and try to
// take the top k docs.
// We allow this to be overridden so that tests can check what search strategy is used
protected TopDocs exactSearch(
LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
Expand Down Expand Up @@ -255,6 +267,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
return TopDocs.merge(k, perLeafResults);
}

// At this point we already collected top k matching docs, thus we only wrap the cached docs with
// their scores here.
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;

Expand All @@ -272,6 +286,8 @@ private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
}

// For each segment, find the first index in <code>docs</code> belonging to that segment.
// This method essentially partitions <code>docs</code> by segment.
static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) {
int[] starts = new int[leaves.size() + 1];
starts[starts.length - 1] = docs.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.lucene.search;

import static org.apache.lucene.search.AbstractKnnVectorQuery.createBitSet;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
Expand Down Expand Up @@ -111,30 +113,15 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
if (results.scoreDocs.length == 0) {
return null;
}
vectorSimilarityScorer =
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
vectorSimilarityScorer = VectorSimilarityScorer.fromScoreDocs(boost, results.scoreDocs);
} else {
Scorer scorer = filterWeight.scorer(context);
if (scorer == null) {
// If the filter does not match any documents
return null;
}

BitSet acceptDocs;
if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) {
// If there are no deletions, and matching docs are already cached
acceptDocs = bitSetIterator.getBitSet();
} else {
// Else collect all matching docs
FilteredDocIdSetIterator filtered =
new FilteredDocIdSetIterator(scorer.iterator()) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
acceptDocs = BitSet.of(filtered, leafReader.maxDoc());
}
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, leafReader.maxDoc());

int cardinality = acceptDocs.cardinality();
if (cardinality == 0) {
Expand All @@ -150,7 +137,6 @@ protected boolean match(int doc) {
// Return a lazy-loading iterator
vectorSimilarityScorer =
VectorSimilarityScorer.fromAcceptDocs(
this,
boost,
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
Expand All @@ -159,8 +145,7 @@ protected boolean match(int doc) {
return null;
} else {
// Return an iterator over the collected results
vectorSimilarityScorer =
VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
vectorSimilarityScorer = VectorSimilarityScorer.fromScoreDocs(boost, results.scoreDocs);
}
}
return new DefaultScorerSupplier(vectorSimilarityScorer);
Expand Down Expand Up @@ -206,7 +191,7 @@ private static class VectorSimilarityScorer extends Scorer {
this.cachedScore = cachedScore;
}

static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
static VectorSimilarityScorer fromScoreDocs(float boost, ScoreDoc[] scoreDocs) {
// Sort in ascending order of docid
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));

Expand Down Expand Up @@ -256,11 +241,7 @@ public long cost() {
}

static VectorSimilarityScorer fromAcceptDocs(
Weight weight,
float boost,
VectorScorer scorer,
DocIdSetIterator acceptDocs,
float threshold) {
float boost, VectorScorer scorer, DocIdSetIterator acceptDocs, float threshold) {
if (scorer == null) {
return null;
}
Expand Down

0 comments on commit d9d2205

Please sign in to comment.