diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 5b6029766c..091e6201a1 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -20,12 +20,15 @@ import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.query.filtered.FilteredIdsKNNBinaryIterator; import org.opensearch.knn.index.query.filtered.FilteredIdsKNNByteIterator; import org.opensearch.knn.index.query.filtered.FilteredIdsKNNIterator; import org.opensearch.knn.index.query.filtered.KNNIterator; +import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNBinaryIterator; import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNByteIterator; import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator; import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; @@ -110,7 +113,7 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E if (VectorDataType.BINARY == knnQuery.getVectorDataType() && isNestedRequired) { final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - return new NestedFilteredIdsKNNByteIterator( + return new NestedFilteredIdsKNNBinaryIterator( matchedDocs, knnQuery.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, @@ -121,13 +124,27 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E if (VectorDataType.BINARY == knnQuery.getVectorDataType()) { final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - return new FilteredIdsKNNByteIterator( + return new FilteredIdsKNNBinaryIterator( matchedDocs, knnQuery.getByteQueryVector(), (KNNBinaryVectorValues) vectorValues, spaceType ); } + + if (VectorDataType.BYTE == knnQuery.getVectorDataType()) { + final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + if (isNestedRequired) { + return new NestedFilteredIdsKNNByteIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNByteVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } + return new FilteredIdsKNNByteIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType); + } final byte[] quantizedQueryVector; final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; if (exactSearcherContext.isUseQuantizedVectorsForSearch()) { diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNBinaryIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNBinaryIterator.java new file mode 100644 index 0000000000..1a7ae62993 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNBinaryIterator.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; + +/** + * Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene + * https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162 + * + * The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray. + */ +public class FilteredIdsKNNBinaryIterator implements KNNIterator { + // Array of doc ids to iterate + protected final BitSet filterIdsBitSet; + protected final BitSetIterator bitSetIterator; + protected final byte[] queryVector; + protected final KNNBinaryVectorValues binaryVectorValues; + protected final SpaceType spaceType; + protected float currentScore = Float.NEGATIVE_INFINITY; + protected int docId; + + public FilteredIdsKNNBinaryIterator( + final BitSet filterIdsBitSet, + final byte[] queryVector, + final KNNBinaryVectorValues binaryVectorValues, + final SpaceType spaceType + ) { + this.filterIdsBitSet = filterIdsBitSet; + this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); + this.queryVector = queryVector; + this.binaryVectorValues = binaryVectorValues; + this.spaceType = spaceType; + this.docId = bitSetIterator.nextDoc(); + } + + /** + * Advance to the next doc and update score value with score of the next doc. + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next doc id + */ + @Override + public int nextDoc() throws IOException { + + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + int doc = binaryVectorValues.advance(docId); + currentScore = computeScore(); + docId = bitSetIterator.nextDoc(); + return doc; + } + + @Override + public float score() { + return currentScore; + } + + protected float computeScore() throws IOException { + final byte[] vector = binaryVectorValues.getVector(); + // Calculates a similarity score between the two vectors with a specified function. Higher similarity + // scores correspond to closer vectors. + return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java index ccfe626a0e..a1bdb48543 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIterator.java @@ -9,7 +9,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.io.IOException; @@ -23,22 +23,22 @@ public class FilteredIdsKNNByteIterator implements KNNIterator { // Array of doc ids to iterate protected final BitSet filterIdsBitSet; protected final BitSetIterator bitSetIterator; - protected final byte[] queryVector; - protected final KNNBinaryVectorValues binaryVectorValues; + protected final float[] queryVector; + protected final KNNByteVectorValues byteVectorValues; protected final SpaceType spaceType; protected float currentScore = Float.NEGATIVE_INFINITY; protected int docId; public FilteredIdsKNNByteIterator( final BitSet filterIdsBitSet, - final byte[] queryVector, - final KNNBinaryVectorValues binaryVectorValues, + final float[] queryVector, + final KNNByteVectorValues byteVectorValues, final SpaceType spaceType ) { this.filterIdsBitSet = filterIdsBitSet; this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); this.queryVector = queryVector; - this.binaryVectorValues = binaryVectorValues; + this.byteVectorValues = byteVectorValues; this.spaceType = spaceType; this.docId = bitSetIterator.nextDoc(); } @@ -55,7 +55,7 @@ public int nextDoc() throws IOException { if (docId == DocIdSetIterator.NO_MORE_DOCS) { return DocIdSetIterator.NO_MORE_DOCS; } - int doc = binaryVectorValues.advance(docId); + int doc = byteVectorValues.advance(docId); currentScore = computeScore(); docId = bitSetIterator.nextDoc(); return doc; @@ -67,9 +67,13 @@ public float score() { } protected float computeScore() throws IOException { - final byte[] vector = binaryVectorValues.getVector(); + final byte[] vector = byteVectorValues.getVector(); // Calculates a similarity score between the two vectors with a specified function. Higher similarity // scores correspond to closer vectors. - return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + final byte[] byteQueryvector = new byte[queryVector.length]; + for (int i = 0; i < queryVector.length; i++) { + byteQueryvector[i] = (byte) queryVector[i]; + } + return spaceType.getKnnVectorSimilarityFunction().compare(byteQueryvector, vector); } } diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNBinaryIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNBinaryIterator.java new file mode 100644 index 0000000000..862150413b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNBinaryIterator.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.io.IOException; + +/** + * This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc + * of which ID is set in parentBitSet and only return best child doc with the highest score. + */ +public class NestedFilteredIdsKNNBinaryIterator extends FilteredIdsKNNBinaryIterator { + private final BitSet parentBitSet; + + public NestedFilteredIdsKNNBinaryIterator( + final BitSet filterIdsArray, + final byte[] queryVector, + final KNNBinaryVectorValues binaryVectorValues, + final SpaceType spaceType, + final BitSet parentBitSet + ) { + super(filterIdsArray, queryVector, binaryVectorValues, spaceType); + this.parentBitSet = parentBitSet; + } + + /** + * Advance to the next best child doc per parent and update score with the best score among child docs from the parent. + * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs + * + * @return next best child doc id + */ + @Override + public int nextDoc() throws IOException { + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + return DocIdSetIterator.NO_MORE_DOCS; + } + + currentScore = Float.NEGATIVE_INFINITY; + int currentParent = parentBitSet.nextSetBit(docId); + int bestChild = -1; + + while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { + binaryVectorValues.advance(docId); + float score = computeScore(); + if (score > currentScore) { + bestChild = docId; + currentScore = score; + } + docId = bitSetIterator.nextDoc(); + } + + return bestChild; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java index b69a90518f..0803247a15 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIterator.java @@ -8,7 +8,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.io.IOException; @@ -21,12 +21,12 @@ public class NestedFilteredIdsKNNByteIterator extends FilteredIdsKNNByteIterator public NestedFilteredIdsKNNByteIterator( final BitSet filterIdsArray, - final byte[] queryVector, - final KNNBinaryVectorValues binaryVectorValues, + final float[] queryVector, + final KNNByteVectorValues byteVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) { - super(filterIdsArray, queryVector, binaryVectorValues, spaceType); + super(filterIdsArray, queryVector, byteVectorValues, spaceType); this.parentBitSet = parentBitSet; } @@ -47,7 +47,7 @@ public int nextDoc() throws IOException { int bestChild = -1; while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { - binaryVectorValues.advance(docId); + byteVectorValues.advance(docId); float score = computeScore(); if (score > currentScore) { bestChild = docId; diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNBinaryIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNBinaryIteratorTests.java new file mode 100644 index 0000000000..35818726d5 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNBinaryIteratorTests.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class FilteredIdsKNNBinaryIteratorTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenCalled_IterateAllDocs() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 1, 2, 3 }; + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + FilteredIdsKNNBinaryIterator iterator = new FilteredIdsKNNBinaryIterator(filterBitSet, queryVector, values, spaceType); + for (int i = 0; i < filterIds.length; i++) { + assertEquals(filterIds[i], iterator.nextDoc()); + assertEquals(expectedScores.get(i), (Float) iterator.score()); + } + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java index c52798c05a..45ce1133f1 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNByteIteratorTests.java @@ -5,12 +5,12 @@ package org.opensearch.knn.index.query.filtered; -import junit.framework.TestCase; import lombok.SneakyThrows; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.FixedBitSet; +import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.util.Arrays; import java.util.List; @@ -19,18 +19,19 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class FilteredIdsKNNByteIteratorTests extends TestCase { +public class FilteredIdsKNNByteIteratorTests extends KNNTestCase { @SneakyThrows public void testNextDoc_whenCalled_IterateAllDocs() { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1f, 2f, 3f }; final int[] filterIds = { 1, 2, 3 }; final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4); diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNBinaryIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNBinaryIteratorTests.java new file mode 100644 index 0000000000..2b8a8421bb --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNBinaryIteratorTests.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.filtered; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.FixedBitSet; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class NestedFilteredIdsKNNBinaryIteratorTests extends TestCase { + @SneakyThrows + public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { + final SpaceType spaceType = SpaceType.HAMMING; + final byte[] queryVector = { 1, 2, 3 }; + final int[] filterIds = { 0, 2, 3 }; + // Parent id for 0 -> 1 + // Parent id for 2, 3 -> 4 + // In bit representation, it is 10010. In long, it is 18. + final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + + KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + FixedBitSet filterBitSet = new FixedBitSet(4); + for (int id : filterIds) { + when(values.advance(id)).thenReturn(id); + filterBitSet.set(id); + } + + // Execute and verify + NestedFilteredIdsKNNBinaryIterator iterator = new NestedFilteredIdsKNNBinaryIterator( + filterBitSet, + queryVector, + values, + spaceType, + parentBitSet + ); + assertEquals(filterIds[0], iterator.nextDoc()); + assertEquals(expectedScores.get(0), iterator.score()); + assertEquals(filterIds[2], iterator.nextDoc()); + assertEquals(expectedScores.get(2), iterator.score()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java index 1940ffe123..987bf6f70c 100644 --- a/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java +++ b/src/test/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNByteIteratorTests.java @@ -11,7 +11,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues; import java.util.Arrays; import java.util.List; @@ -23,19 +23,20 @@ public class NestedFilteredIdsKNNByteIteratorTests extends TestCase { @SneakyThrows public void testNextDoc_whenIterate_ReturnBestChildDocsPerParent() { - final SpaceType spaceType = SpaceType.HAMMING; - final byte[] queryVector = { 1, 2, 3 }; + final SpaceType spaceType = SpaceType.L2; + final byte[] byteQueryVector = { 1, 2, 3 }; + final float[] queryVector = { 1.0f, 2.0f, 3.0f }; final int[] filterIds = { 0, 2, 3 }; // Parent id for 0 -> 1 // Parent id for 2, 3 -> 4 // In bit representation, it is 10010. In long, it is 18. final BitSet parentBitSet = new FixedBitSet(new long[] { 18 }, 5); - final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 14, 15, 16 }, new byte[] { 17, 18, 19 }); + final List dataVectors = Arrays.asList(new byte[] { 11, 12, 13 }, new byte[] { 17, 18, 19 }, new byte[] { 14, 15, 16 }); final List expectedScores = dataVectors.stream() - .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector)) .collect(Collectors.toList()); - KNNBinaryVectorValues values = mock(KNNBinaryVectorValues.class); + KNNByteVectorValues values = mock(KNNByteVectorValues.class); when(values.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); FixedBitSet filterBitSet = new FixedBitSet(4);