forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Faiss efficient filter exact search using byte vector datatype
Signed-off-by: Naveen Tatikonda <[email protected]>
- Loading branch information
1 parent
23b95e7
commit faeffe5
Showing
9 changed files
with
299 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNBinaryIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.apache.lucene.util.BitSet; | ||
import org.apache.lucene.util.BitSetIterator; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene | ||
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 | ||
* | ||
* The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray. | ||
*/ | ||
public class FilteredIdsKNNBinaryIterator implements KNNIterator { | ||
// Array of doc ids to iterate | ||
protected final BitSet filterIdsBitSet; | ||
protected final BitSetIterator bitSetIterator; | ||
protected final byte[] queryVector; | ||
protected final KNNBinaryVectorValues binaryVectorValues; | ||
protected final SpaceType spaceType; | ||
protected float currentScore = Float.NEGATIVE_INFINITY; | ||
protected int docId; | ||
|
||
public FilteredIdsKNNBinaryIterator( | ||
final BitSet filterIdsBitSet, | ||
final byte[] queryVector, | ||
final KNNBinaryVectorValues binaryVectorValues, | ||
final SpaceType spaceType | ||
) { | ||
this.filterIdsBitSet = filterIdsBitSet; | ||
this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); | ||
this.queryVector = queryVector; | ||
this.binaryVectorValues = binaryVectorValues; | ||
this.spaceType = spaceType; | ||
this.docId = bitSetIterator.nextDoc(); | ||
} | ||
|
||
/** | ||
* Advance to the next doc and update score value with score of the next doc. | ||
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs | ||
* | ||
* @return next doc id | ||
*/ | ||
@Override | ||
public int nextDoc() throws IOException { | ||
|
||
if (docId == DocIdSetIterator.NO_MORE_DOCS) { | ||
return DocIdSetIterator.NO_MORE_DOCS; | ||
} | ||
int doc = binaryVectorValues.advance(docId); | ||
currentScore = computeScore(); | ||
docId = bitSetIterator.nextDoc(); | ||
return doc; | ||
} | ||
|
||
@Override | ||
public float score() { | ||
return currentScore; | ||
} | ||
|
||
protected float computeScore() throws IOException { | ||
final byte[] vector = binaryVectorValues.getVector(); | ||
// Calculates a similarity score between the two vectors with a specified function. Higher similarity | ||
// scores correspond to closer vectors. | ||
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
...main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNBinaryIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.apache.lucene.util.BitSet; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc | ||
* of which ID is set in parentBitSet and only return best child doc with the highest score. | ||
*/ | ||
public class NestedFilteredIdsKNNBinaryIterator extends FilteredIdsKNNBinaryIterator { | ||
private final BitSet parentBitSet; | ||
|
||
public NestedFilteredIdsKNNBinaryIterator( | ||
final BitSet filterIdsArray, | ||
final byte[] queryVector, | ||
final KNNBinaryVectorValues binaryVectorValues, | ||
final SpaceType spaceType, | ||
final BitSet parentBitSet | ||
) { | ||
super(filterIdsArray, queryVector, binaryVectorValues, spaceType); | ||
this.parentBitSet = parentBitSet; | ||
} | ||
|
||
/** | ||
* Advance to the next best child doc per parent and update score with the best score among child docs from the parent. | ||
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs | ||
* | ||
* @return next best child doc id | ||
*/ | ||
@Override | ||
public int nextDoc() throws IOException { | ||
if (docId == DocIdSetIterator.NO_MORE_DOCS) { | ||
return DocIdSetIterator.NO_MORE_DOCS; | ||
} | ||
|
||
currentScore = Float.NEGATIVE_INFINITY; | ||
int currentParent = parentBitSet.nextSetBit(docId); | ||
int bestChild = -1; | ||
|
||
while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { | ||
binaryVectorValues.advance(docId); | ||
float score = computeScore(); | ||
if (score > currentScore) { | ||
bestChild = docId; | ||
currentScore = score; | ||
} | ||
docId = bitSetIterator.nextDoc(); | ||
} | ||
|
||
return bestChild; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
50 changes: 50 additions & 0 deletions
50
src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNBinaryIteratorTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.query.filtered; | ||
|
||
import junit.framework.TestCase; | ||
import lombok.SneakyThrows; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
import org.apache.lucene.util.FixedBitSet; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.stream.Collectors; | ||
|
||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.when; | ||
|
||
public class FilteredIdsKNNBinaryIteratorTests extends TestCase { | ||
@SneakyThrows | ||
public void testNextDoc_whenCalled_IterateAllDocs() { | ||
final SpaceType spaceType = SpaceType.HAMMING; | ||
final byte[] queryVector = { 1, 2, 3 }; | ||
final int[] filterIds = { 1, 2, 3 }; | ||
final List<byte[]> dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); | ||
final List<Float> expectedScores = dataVectors.stream() | ||
.map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) | ||
.collect(Collectors.toList()); | ||
|
||
KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); | ||
when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); | ||
|
||
FixedBitSet filterBitSet = new FixedBitSet(4); | ||
for (int id : filterIds) { | ||
when(values.advance(id)).thenReturn(id); | ||
filterBitSet.set(id); | ||
} | ||
|
||
// Execute and verify | ||
FilteredIdsKNNBinaryIterator iterator = new FilteredIdsKNNBinaryIterator(filterBitSet, queryVector, values, spaceType); | ||
for (int i = 0; i < filterIds.length; i++) { | ||
assertEquals(filterIds[i], iterator.nextDoc()); | ||
assertEquals(expectedScores.get(i), (Float) iterator.score()); | ||
} | ||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.