Skip to content

Commit

Permalink
Fix Faiss efficient filter exact search using byte vector datatype
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Sep 28, 2024
1 parent 23b95e7 commit faeffe5
Show file tree
Hide file tree
Showing 9 changed files with 299 additions and 29 deletions.
21 changes: 19 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.filtered.FilteredIdsKNNBinaryIterator;
import org.opensearch.knn.index.query.filtered.FilteredIdsKNNByteIterator;
import org.opensearch.knn.index.query.filtered.FilteredIdsKNNIterator;
import org.opensearch.knn.index.query.filtered.KNNIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNBinaryIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNByteIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
Expand Down Expand Up @@ -110,7 +113,7 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E

if (VectorDataType.BINARY == knnQuery.getVectorDataType() && isNestedRequired) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
return new NestedFilteredIdsKNNByteIterator(
return new NestedFilteredIdsKNNBinaryIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
Expand All @@ -121,13 +124,27 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E

if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
return new FilteredIdsKNNByteIterator(
return new FilteredIdsKNNBinaryIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
);
}

if (VectorDataType.BYTE == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedFilteredIdsKNNByteIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}
return new FilteredIdsKNNByteIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType);
}
final byte[] quantizedQueryVector;
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

/**
 * Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene
 * https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162
 *
 * Scores a filtered KNN field holding binary vectors: every doc id that is set in the filter
 * bit set is visited in order, the vector values are positioned on it, and the stored vector is
 * compared against the query vector with the similarity function of the configured space type.
 */
public class FilteredIdsKNNBinaryIterator implements KNNIterator {
    // Bit set whose set bits are the doc ids to visit
    protected final BitSet filterIdsBitSet;
    // Iterator over the set bits of filterIdsBitSet
    protected final BitSetIterator bitSetIterator;
    protected final byte[] queryVector;
    protected final KNNBinaryVectorValues binaryVectorValues;
    protected final SpaceType spaceType;
    // Score of the doc most recently returned by nextDoc()
    protected float currentScore = Float.NEGATIVE_INFINITY;
    // Doc id that the next call to nextDoc() will score; NO_MORE_DOCS once exhausted
    protected int docId;

    /**
     * @param filterIdsBitSet    bit set marking the doc ids to score
     * @param queryVector        query vector to compare every stored vector with
     * @param binaryVectorValues source of the stored binary vectors
     * @param spaceType          space type supplying the similarity function
     */
    public FilteredIdsKNNBinaryIterator(
        final BitSet filterIdsBitSet,
        final byte[] queryVector,
        final KNNBinaryVectorValues binaryVectorValues,
        final SpaceType spaceType
    ) {
        this.spaceType = spaceType;
        this.queryVector = queryVector;
        this.binaryVectorValues = binaryVectorValues;
        this.filterIdsBitSet = filterIdsBitSet;
        this.bitSetIterator = new BitSetIterator(this.filterIdsBitSet, this.filterIdsBitSet.length());
        // Prime the iterator so that docId always holds the doc to be scored next.
        this.docId = this.bitSetIterator.nextDoc();
    }

    /**
     * Scores the pending doc, moves on to the following filtered doc and returns the doc that was scored.
     * Once every filtered doc has been consumed, DocIdSetIterator.NO_MORE_DOCS is returned.
     *
     * @return id of the doc whose score is now available through {@link #score()}
     */
    @Override
    public int nextDoc() throws IOException {
        final int pendingDoc = docId;
        if (pendingDoc != DocIdSetIterator.NO_MORE_DOCS) {
            final int positionedDoc = binaryVectorValues.advance(pendingDoc);
            currentScore = computeScore();
            docId = bitSetIterator.nextDoc();
            return positionedDoc;
        }
        return DocIdSetIterator.NO_MORE_DOCS;
    }

    /**
     * @return score computed for the doc returned by the latest {@link #nextDoc()} call
     */
    @Override
    public float score() {
        return currentScore;
    }

    /**
     * Compares the query vector with the vector the vector values are currently positioned on.
     *
     * @return similarity score; a higher score means a closer vector
     */
    protected float computeScore() throws IOException {
        final byte[] storedVector = binaryVectorValues.getVector();
        // The similarity function of the space type turns the pair of vectors into a score.
        return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, storedVector);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;

import java.io.IOException;

Expand All @@ -23,22 +23,22 @@ public class FilteredIdsKNNByteIterator implements KNNIterator {
// Array of doc ids to iterate
protected final BitSet filterIdsBitSet;
protected final BitSetIterator bitSetIterator;
protected final byte[] queryVector;
protected final KNNBinaryVectorValues binaryVectorValues;
protected final float[] queryVector;
protected final KNNByteVectorValues byteVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public FilteredIdsKNNByteIterator(
final BitSet filterIdsBitSet,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final float[] queryVector,
final KNNByteVectorValues byteVectorValues,
final SpaceType spaceType
) {
this.filterIdsBitSet = filterIdsBitSet;
this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryVectorValues = binaryVectorValues;
this.byteVectorValues = byteVectorValues;
this.spaceType = spaceType;
this.docId = bitSetIterator.nextDoc();
}
Expand All @@ -55,7 +55,7 @@ public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}
int doc = binaryVectorValues.advance(docId);
int doc = byteVectorValues.advance(docId);
currentScore = computeScore();
docId = bitSetIterator.nextDoc();
return doc;
Expand All @@ -67,9 +67,13 @@ public float score() {
}

protected float computeScore() throws IOException {
final byte[] vector = binaryVectorValues.getVector();
final byte[] vector = byteVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
final byte[] byteQueryvector = new byte[queryVector.length];
for (int i = 0; i < queryVector.length; i++) {
byteQueryvector[i] = (byte) queryVector[i];
}
return spaceType.getKnnVectorSimilarityFunction().compare(byteQueryvector, vector);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

/**
 * Variant of the filtered binary iterator for nested documents. The filtered doc ids are scored
 * in groups delimited by the ids set in parentBitSet, and for every group only the child doc
 * with the highest score is returned.
 */
public class NestedFilteredIdsKNNBinaryIterator extends FilteredIdsKNNBinaryIterator {
    // Set bits mark the parent docs that delimit each group of child docs
    private final BitSet parentBitSet;

    /**
     * @param filterIdsArray     bit set marking the doc ids to score
     * @param queryVector        query vector to compare every stored vector with
     * @param binaryVectorValues source of the stored binary vectors
     * @param spaceType          space type supplying the similarity function
     * @param parentBitSet       bit set marking the parent doc ids
     */
    public NestedFilteredIdsKNNBinaryIterator(
        final BitSet filterIdsArray,
        final byte[] queryVector,
        final KNNBinaryVectorValues binaryVectorValues,
        final SpaceType spaceType,
        final BitSet parentBitSet
    ) {
        super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
        this.parentBitSet = parentBitSet;
    }

    /**
     * Scores every remaining filtered doc that precedes the next parent doc and returns the one with the
     * highest score, leaving that score available through score().
     * DocIdSetIterator.NO_MORE_DOCS is returned once the filtered docs are exhausted.
     *
     * @return id of the best scoring child doc of the next parent
     */
    @Override
    public int nextDoc() throws IOException {
        if (docId == DocIdSetIterator.NO_MORE_DOCS) {
            return DocIdSetIterator.NO_MORE_DOCS;
        }

        // Start from the lowest possible score so that any real score replaces it.
        currentScore = Float.NEGATIVE_INFINITY;
        final int parentDoc = parentBitSet.nextSetBit(docId);
        int topChild = -1;

        // Consume filtered docs up to, but not including, the parent doc.
        for (; docId != DocIdSetIterator.NO_MORE_DOCS && docId < parentDoc; docId = bitSetIterator.nextDoc()) {
            binaryVectorValues.advance(docId);
            final float childScore = computeScore();
            if (childScore > currentScore) {
                topChild = docId;
                currentScore = childScore;
            }
        }

        return topChild;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;

import java.io.IOException;

Expand All @@ -21,12 +21,12 @@ public class NestedFilteredIdsKNNByteIterator extends FilteredIdsKNNByteIterator

public NestedFilteredIdsKNNByteIterator(
final BitSet filterIdsArray,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final float[] queryVector,
final KNNByteVectorValues byteVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) {
super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
super(filterIdsArray, queryVector, byteVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

Expand All @@ -47,7 +47,7 @@ public int nextDoc() throws IOException {
int bestChild = -1;

while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
binaryVectorValues.advance(docId);
byteVectorValues.advance(docId);
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;

import junit.framework.TestCase;
import lombok.SneakyThrows;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class FilteredIdsKNNBinaryIteratorTests extends TestCase {
    /**
     * Checks that the iterator returns every filtered doc id in order together with the score of the
     * matching stored vector, and reports NO_MORE_DOCS afterwards.
     */
    @SneakyThrows
    public void testNextDoc_whenCalled_IterateAllDocs() {
        // Fixture: three filtered docs, each with its own stored vector
        final SpaceType hamming = SpaceType.HAMMING;
        final byte[] query = { 1, 2, 3 };
        final int[] docIds = { 1, 2, 3 };
        final List<byte[]> storedVectors = Arrays.asList(
            new byte[] { 11, 12, 13 },
            new byte[] { 14, 15, 16 },
            new byte[] { 17, 18, 19 }
        );
        final List<Float> wantedScores = storedVectors.stream()
            .map(stored -> hamming.getKnnVectorSimilarityFunction().compare(query, stored))
            .collect(Collectors.toList());

        // The mocked vector values hand out the stored vectors one after another
        final KNNBinaryVectorValues vectorValues = mock(KNNBinaryVectorValues.class);
        when(vectorValues.getVector()).thenReturn(storedVectors.get(0), storedVectors.get(1), storedVectors.get(2));

        final FixedBitSet allowedDocs = new FixedBitSet(4);
        for (int idx = 0; idx < docIds.length; idx++) {
            final int docId = docIds[idx];
            when(vectorValues.advance(docId)).thenReturn(docId);
            allowedDocs.set(docId);
        }

        // Execute and verify
        final FilteredIdsKNNBinaryIterator iterator = new FilteredIdsKNNBinaryIterator(allowedDocs, query, vectorValues, hamming);
        int position = 0;
        for (final int expectedDoc : docIds) {
            assertEquals(expectedDoc, iterator.nextDoc());
            assertEquals(wantedScores.get(position), (Float) iterator.score());
            position++;
        }
        assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

package org.opensearch.knn.index.query.filtered;

import junit.framework.TestCase;
import lombok.SneakyThrows;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;

import java.util.Arrays;
import java.util.List;
Expand All @@ -19,18 +19,19 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class FilteredIdsKNNByteIteratorTests extends TestCase {
public class FilteredIdsKNNByteIteratorTests extends KNNTestCase {
@SneakyThrows
public void testNextDoc_whenCalled_IterateAllDocs() {
final SpaceType spaceType = SpaceType.HAMMING;
final byte[] queryVector = { 1, 2, 3 };
final SpaceType spaceType = SpaceType.L2;
final byte[] byteQueryVector = { 1, 2, 3 };
final float[] queryVector = { 1f, 2f, 3f };
final int[] filterIds = { 1, 2, 3 };
final List<byte[]> dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 });
final List<Float> expectedScores = dataVectors.stream()
.map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector))
.map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector))
.collect(Collectors.toList());

KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class);
KNNByteVectorValues values = mock(KNNByteVectorValues.class);
when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2));

FixedBitSet filterBitSet = new FixedBitSet(4);
Expand Down
Loading

0 comments on commit faeffe5

Please sign in to comment.