From d9d2205617dc7f0c4d6078de8f3c1aab74a80646 Mon Sep 17 00:00:00 2001
From: dungba88
Date: Sat, 20 Jul 2024 18:50:42 +0900
Subject: [PATCH] Add comment and remove duplicate code

---
 .../lucene/search/AbstractKnnVectorQuery.java | 42 +++++++++++++------
 .../search/AbstractVectorSimilarityQuery.java | 33 ++++-----------
 2 files changed, 36 insertions(+), 39 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
index c0ce4eea3c6b..f8335bae18f7 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java
@@ -53,8 +53,13 @@ abstract class AbstractKnnVectorQuery extends Query {
 
   private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
 
+  /** the KNN vector field to search */
   protected final String field;
+
+  /** the number of documents to find */
   protected final int k;
+
+  /** the filter to be executed before the KNN search */
   private final Query filter;
 
   public AbstractKnnVectorQuery(String field, int k, Query filter) {
@@ -70,18 +75,7 @@ public AbstractKnnVectorQuery(String field, int k, Query filter) {
   public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     IndexReader reader = indexSearcher.getIndexReader();
 
-    final Weight filterWeight;
-    if (filter != null) {
-      BooleanQuery booleanQuery =
-          new BooleanQuery.Builder()
-              .add(filter, BooleanClause.Occur.FILTER)
-              .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
-              .build();
-      Query rewritten = indexSearcher.rewrite(booleanQuery);
-      filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
-    } else {
-      filterWeight = null;
-    }
+    final Weight filterWeight = createFilterWeight(indexSearcher);
 
     TimeLimitingKnnCollectorManager knnCollectorManager =
         new TimeLimitingKnnCollectorManager(
@@ -102,6 +96,21 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
     return createRewrittenQuery(reader, topK);
   }
 
+  // Create a Weight for the filter query. The filter will also be enhanced to only match documents
+  // with the KNN vector field.
+  private Weight createFilterWeight(IndexSearcher indexSearcher) throws IOException {
+    if (filter == null) {
+      return null;
+    }
+    BooleanQuery booleanQuery =
+        new BooleanQuery.Builder()
+            .add(filter, BooleanClause.Occur.FILTER)
+            .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
+            .build();
+    Query rewritten = indexSearcher.rewrite(booleanQuery);
+    return indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1f);
+  }
+
   private TopDocs searchLeaf(
       LeafReaderContext ctx,
       Weight filterWeight,
@@ -116,6 +125,7 @@ private TopDocs searchLeaf(
     return results;
   }
 
+  // Execute the filter (if any) and perform the KNN search on each segment.
   private TopDocs getLeafResults(
       LeafReaderContext ctx,
       Weight filterWeight,
@@ -156,7 +166,7 @@ private TopDocs getLeafResults(
     }
   }
 
-  private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
+  static BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
       throws IOException {
     if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
       // If we already have a BitSet and no deletions, reuse the BitSet
@@ -188,6 +198,8 @@ protected abstract TopDocs approximateSearch(
   abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi)
       throws IOException;
 
+  // Perform a brute-force search by computing the vector score for each accepted doc and
+  // collecting the top k docs.
   // We allow this to be overridden so that tests can check what search strategy is used
   protected TopDocs exactSearch(
       LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout)
@@ -255,6 +267,8 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
     return TopDocs.merge(k, perLeafResults);
   }
 
+  // At this point we have already collected the top k matching docs, so we only wrap the cached
+  // docs with their scores here.
   private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
     int len = topK.scoreDocs.length;
 
@@ -272,6 +286,8 @@ private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
     return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
   }
 
+  // For each segment, find the first index in docs that belongs to that segment.
+  // This method essentially partitions docs by segment.
   static int[] findSegmentStarts(List<LeafReaderContext> leaves, int[] docs) {
     int[] starts = new int[leaves.size() + 1];
     starts[starts.length - 1] = docs.length;
diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java
index 77a5ff6f24f0..e89197c535ab 100644
--- a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java
@@ -16,6 +16,8 @@
  */
 package org.apache.lucene.search;
 
+import static org.apache.lucene.search.AbstractKnnVectorQuery.createBitSet;
+
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -111,8 +113,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
         if (results.scoreDocs.length == 0) {
           return null;
         }
-        vectorSimilarityScorer =
-            VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
+        vectorSimilarityScorer = VectorSimilarityScorer.fromScoreDocs(boost, results.scoreDocs);
       } else {
         Scorer scorer = filterWeight.scorer(context);
         if (scorer == null) {
@@ -120,21 +121,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
           return null;
         }
 
-        BitSet acceptDocs;
-        if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) {
-          // If there are no deletions, and matching docs are already cached
-          acceptDocs = bitSetIterator.getBitSet();
-        } else {
-          // Else collect all matching docs
-          FilteredDocIdSetIterator filtered =
-              new FilteredDocIdSetIterator(scorer.iterator()) {
-                @Override
-                protected boolean match(int doc) {
-                  return liveDocs == null || liveDocs.get(doc);
-                }
-              };
-          acceptDocs = BitSet.of(filtered, leafReader.maxDoc());
-        }
+        BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, leafReader.maxDoc());
 
         int cardinality = acceptDocs.cardinality();
         if (cardinality == 0) {
@@ -150,7 +137,6 @@ protected boolean match(int doc) {
           // Return a lazy-loading iterator
           vectorSimilarityScorer =
               VectorSimilarityScorer.fromAcceptDocs(
-                  this,
                   boost,
                   createVectorScorer(context),
                   new BitSetIterator(acceptDocs, cardinality),
@@ -159,8 +145,7 @@ protected boolean match(int doc) {
           return null;
         } else {
           // Return an iterator over the collected results
-          vectorSimilarityScorer =
-              VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
+          vectorSimilarityScorer = VectorSimilarityScorer.fromScoreDocs(boost, results.scoreDocs);
         }
       }
       return new DefaultScorerSupplier(vectorSimilarityScorer);
@@ -206,7 +191,7 @@ private static class VectorSimilarityScorer extends Scorer {
       this.cachedScore = cachedScore;
     }
 
-    static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
+    static VectorSimilarityScorer fromScoreDocs(float boost, ScoreDoc[] scoreDocs) {
       // Sort in ascending order of docid
       Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
 
@@ -256,11 +241,7 @@ public long cost() {
     }
 
     static VectorSimilarityScorer fromAcceptDocs(
-        Weight weight,
-        float boost,
-        VectorScorer scorer,
-        DocIdSetIterator acceptDocs,
-        float threshold) {
+        float boost, VectorScorer scorer, DocIdSetIterator acceptDocs, float threshold) {
       if (scorer == null) {
         return null;
       }
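
For reference, the sketch below is not part of the patch. It is a minimal, self-contained illustration of the filter-to-BitSet logic that this change deduplicates by making AbstractKnnVectorQuery.createBitSet static and reusing it from AbstractVectorSimilarityQuery. The class name FilterBitSetExample and the method name toAcceptDocs are hypothetical; only Lucene core types (BitSet, BitSetIterator, Bits, DocIdSetIterator, FilteredDocIdSetIterator) are assumed on the classpath.

// Hypothetical standalone sketch of the shared helper's behavior: reuse a cached BitSet when
// there are no deletions, otherwise collect only the live matching docs.
import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;

final class FilterBitSetExample {

  static BitSet toAcceptDocs(DocIdSetIterator iterator, Bits liveDocs, int maxDoc)
      throws IOException {
    if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) {
      // No deletions and the matching docs are already cached as a BitSet: reuse it directly
      return bitSetIterator.getBitSet();
    }
    // Otherwise materialize a new BitSet holding only the matching docs that are still live
    FilteredDocIdSetIterator filtered =
        new FilteredDocIdSetIterator(iterator) {
          @Override
          protected boolean match(int doc) {
            return liveDocs == null || liveDocs.get(doc);
          }
        };
    return BitSet.of(filtered, maxDoc);
  }
}

With the patch applied, both query classes funnel their filter results through this one code path, so the BitSet-reuse optimization and the live-docs handling stay identical between the KNN and the vector-similarity queries.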