forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Faiss efficient filter exact search using byte vector datatype
Signed-off-by: Naveen Tatikonda <[email protected]>
- Loading branch information
1 parent
6f6dd56
commit beb98d0
Showing
11 changed files
with
685 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
src/main/java/org/opensearch/knn/index/query/iterators/BinaryVectorIdsKNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.iterators; | ||
|
||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.apache.lucene.util.BitSet; | ||
import org.apache.lucene.util.BitSetIterator; | ||
import org.opensearch.common.Nullable; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene | ||
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 | ||
* | ||
* The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided | ||
*/ | ||
public class BinaryVectorIdsKNNIterator implements KNNIterator { | ||
protected final BitSetIterator bitSetIterator; | ||
protected final byte[] queryVector; | ||
protected final KNNBinaryVectorValues binaryVectorValues; | ||
protected final SpaceType spaceType; | ||
protected float currentScore = Float.NEGATIVE_INFINITY; | ||
protected int docId; | ||
|
||
public BinaryVectorIdsKNNIterator( | ||
@Nullable final BitSet filterIdsBitSet, | ||
final byte[] queryVector, | ||
final KNNBinaryVectorValues binaryVectorValues, | ||
final SpaceType spaceType | ||
) throws IOException { | ||
this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); | ||
this.queryVector = queryVector; | ||
this.binaryVectorValues = binaryVectorValues; | ||
this.spaceType = spaceType; | ||
// This cannot be moved inside nextDoc() method since it will break when we have nested field, where | ||
// nextDoc should already be referring to next knnVectorValues | ||
this.docId = getNextDocId(); | ||
} | ||
|
||
public BinaryVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType) | ||
throws IOException { | ||
this(null, queryVector, binaryVectorValues, spaceType); | ||
} | ||
|
||
/** | ||
* Advance to the next doc and update score value with score of the next doc. | ||
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs | ||
* | ||
* @return next doc id | ||
*/ | ||
@Override | ||
public int nextDoc() throws IOException { | ||
|
||
if (docId == DocIdSetIterator.NO_MORE_DOCS) { | ||
return DocIdSetIterator.NO_MORE_DOCS; | ||
} | ||
currentScore = computeScore(); | ||
int currentDocId = docId; | ||
docId = getNextDocId(); | ||
return currentDocId; | ||
} | ||
|
||
@Override | ||
public float score() { | ||
return currentScore; | ||
} | ||
|
||
protected float computeScore() throws IOException { | ||
final byte[] vector = binaryVectorValues.getVector(); | ||
// Calculates a similarity score between the two vectors with a specified function. Higher similarity | ||
// scores correspond to closer vectors. | ||
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); | ||
} | ||
|
||
protected int getNextDocId() throws IOException { | ||
if (bitSetIterator == null) { | ||
return binaryVectorValues.nextDoc(); | ||
} | ||
int nextDocID = this.bitSetIterator.nextDoc(); | ||
// For filter case, advance vector values to corresponding doc id from filter bit set | ||
if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) { | ||
binaryVectorValues.advance(nextDocID); | ||
} | ||
return nextDocID; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
src/main/java/org/opensearch/knn/index/query/iterators/NestedBinaryVectorIdsKNNIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.iterators; | ||
|
||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.apache.lucene.util.BitSet; | ||
import org.opensearch.common.Nullable; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* This iterator iterates filterIdsArray to scoreif filter is provided else it iterates over all docs. | ||
* However, it dedupe docs per each parent doc | ||
* of which ID is set in parentBitSet and only return best child doc with the highest score. | ||
*/ | ||
public class NestedBinaryVectorIdsKNNIterator extends BinaryVectorIdsKNNIterator { | ||
private final BitSet parentBitSet; | ||
|
||
public NestedBinaryVectorIdsKNNIterator( | ||
@Nullable final BitSet filterIdsArray, | ||
final byte[] queryVector, | ||
final KNNBinaryVectorValues binaryVectorValues, | ||
final SpaceType spaceType, | ||
final BitSet parentBitSet | ||
) throws IOException { | ||
super(filterIdsArray, queryVector, binaryVectorValues, spaceType); | ||
this.parentBitSet = parentBitSet; | ||
} | ||
|
||
public NestedBinaryVectorIdsKNNIterator( | ||
final byte[] queryVector, | ||
final KNNBinaryVectorValues binaryVectorValues, | ||
final SpaceType spaceType, | ||
final BitSet parentBitSet | ||
) throws IOException { | ||
super(null, queryVector, binaryVectorValues, spaceType); | ||
this.parentBitSet = parentBitSet; | ||
} | ||
|
||
/** | ||
* Advance to the next best child doc per parent and update score with the best score among child docs from the parent. | ||
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs | ||
* | ||
* @return next best child doc id | ||
*/ | ||
@Override | ||
public int nextDoc() throws IOException { | ||
if (docId == DocIdSetIterator.NO_MORE_DOCS) { | ||
return DocIdSetIterator.NO_MORE_DOCS; | ||
} | ||
|
||
currentScore = Float.NEGATIVE_INFINITY; | ||
int currentParent = parentBitSet.nextSetBit(docId); | ||
int bestChild = -1; | ||
|
||
// In order to traverse all children for given parent, we have to use docId < parentId, because, | ||
// kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child | ||
// and for doc id 5, there are three children. In that case knnVectorValues iterator will have [0, 2, 3, 4] | ||
// and parentBitSet will have [1,5] | ||
// Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId | ||
while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { | ||
float score = computeScore(); | ||
if (score > currentScore) { | ||
bestChild = docId; | ||
currentScore = score; | ||
} | ||
docId = getNextDocId(); | ||
} | ||
|
||
return bestChild; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.