From ea65786a5847f3272abdb4bb754800bfca077a38 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Sun, 17 Nov 2024 14:30:13 -0800 Subject: [PATCH] Fix NPE in ANN search when a segment doesn't contain vector field Signed-off-by: Navneet Verma --- CHANGELOG.md | 1 + build.gradle | 6 +- .../knn/index/query/ExactSearcher.java | 167 ++++++++++-------- .../knn/index/query/ResultUtil.java | 19 +- .../nativelib/NativeEngineKnnVectorQuery.java | 16 +- .../knn/index/query/ResultUtilTests.java | 10 +- 6 files changed, 119 insertions(+), 100 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c57523c3..5695d48ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] ### Bug Fixes +* Fix NPE in ANN search when a segment doesn't contain vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278] ### Infrastructure * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259) ### Documentation diff --git a/build.gradle b/build.gradle index 132fd7f43..2ad973676 100644 --- a/build.gradle +++ b/build.gradle @@ -295,9 +295,9 @@ dependencies { api group: 'com.google.guava', name: 'guava', version:'32.1.3-jre' api group: 'commons-lang', name: 'commons-lang', version: '2.6' testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}" - testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.4' - testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.2' - testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4' + testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10' + testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3' + testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.10' testFixturesImplementation "org.opensearch:common-utils:${version}" implementation 'com.github.oshi:oshi-core:6.4.13' api "net.java.dev.jna:jna:5.13.0" diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 77e993297..d44b48a5d 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -21,7 +21,6 @@ import org.opensearch.common.lucene.Lucene; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; @@ -38,9 +37,11 @@ import org.opensearch.knn.indices.ModelDao; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.Predicate; @Log4j2 @@ -59,23 +60,27 @@ public class ExactSearcher { */ public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { - KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + final Optional iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + // because of any reason if we are not able to get KNNIterator, return an empty map + if (iterator.isEmpty()) { + return Collections.emptyMap(); + } if (exactSearcherContext.getKnnQuery().getRadius() != null) { - return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); + return doRadialSearch(leafReaderContext, exactSearcherContext, iterator.get()); } if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { - return scoreAllDocs(iterator); + return scoreAllDocs(iterator.get()); } - return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue()); + return searchTopCandidates(iterator.get(), exactSearcherContext.getK(), Predicates.alwaysTrue()); } /** * Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search. * Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores * to filter out the documents that does not have given min score. - * @param leafReaderContext - * @param exactSearcherContext + * @param leafReaderContext {@link LeafReaderContext} + * @param exactSearcherContext {@link ExactSearcherContext} * @param iterator {@link KNNIterator} * @return Map of docId and score * @throws IOException exception raised by iterator during traversal @@ -145,79 +150,99 @@ private Map filterDocsByMinScore(ExactSearcherContext context, K return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore); } - private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { + private Optional getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) + throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + if (fieldInfo == null) { + log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); + return Optional.empty(); + } final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null; - - if (VectorDataType.BINARY == knnQuery.getVectorDataType()) { - final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - if (isNestedRequired) { - return new NestedBinaryVectorIdsKNNIterator( - matchedDocs, - knnQuery.getByteQueryVector(), - (KNNBinaryVectorValues) vectorValues, - spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext) - ); - } - return new BinaryVectorIdsKNNIterator( - matchedDocs, - knnQuery.getByteQueryVector(), - (KNNBinaryVectorValues) vectorValues, - spaceType - ); - } - - if (VectorDataType.BYTE == knnQuery.getVectorDataType()) { - final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - if (isNestedRequired) { - return new NestedByteVectorIdsKNNIterator( - matchedDocs, - knnQuery.getQueryVector(), - (KNNByteVectorValues) vectorValues, - spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext) + final KNNIterator knnIterator; + KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + switch (knnQuery.getVectorDataType()) { + case BINARY: + if (isNestedRequired) { + knnIterator = new NestedBinaryVectorIdsKNNIterator( + matchedDocs, + knnQuery.getByteQueryVector(), + (KNNBinaryVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } else { + knnIterator = new BinaryVectorIdsKNNIterator( + matchedDocs, + knnQuery.getByteQueryVector(), + (KNNBinaryVectorValues) vectorValues, + spaceType + ); + } + return Optional.of(knnIterator); + case BYTE: + if (isNestedRequired) { + knnIterator = new NestedByteVectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNByteVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } else { + knnIterator = new ByteVectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNByteVectorValues) vectorValues, + spaceType + ); + } + return Optional.of(knnIterator); + case FLOAT: + final byte[] quantizedQueryVector; + final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; + if (exactSearcherContext.isUseQuantizedVectorsForSearch()) { + // Build Segment Level Quantization info. + segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField()); + // Quantize the Query Vector Once. + quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector( + knnQuery.getQueryVector(), + segmentLevelQuantizationInfo + ); + } else { + segmentLevelQuantizationInfo = null; + quantizedQueryVector = null; + } + if (isNestedRequired) { + knnIterator = new NestedVectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNFloatVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext), + quantizedQueryVector, + segmentLevelQuantizationInfo + ); + } else { + knnIterator = new VectorIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNFloatVectorValues) vectorValues, + spaceType, + quantizedQueryVector, + segmentLevelQuantizationInfo + ); + } + return Optional.of(knnIterator); + default: + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Vector data type [%s] is not supported", knnQuery.getVectorDataType()) ); - } - return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType); - } - final byte[] quantizedQueryVector; - final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; - if (exactSearcherContext.isUseQuantizedVectorsForSearch()) { - // Build Segment Level Quantization info. - segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField()); - // Quantize the Query Vector Once. - quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo); - } else { - segmentLevelQuantizationInfo = null; - quantizedQueryVector = null; - } - - final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - if (isNestedRequired) { - return new NestedVectorIdsKNNIterator( - matchedDocs, - knnQuery.getQueryVector(), - (KNNFloatVectorValues) vectorValues, - spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext), - quantizedQueryVector, - segmentLevelQuantizationInfo - ); } - return new VectorIdsKNNIterator( - matchedDocs, - knnQuery.getQueryVector(), - (KNNFloatVectorValues) vectorValues, - spaceType, - quantizedQueryVector, - segmentLevelQuantizationInfo - ); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java index f62c09cb0..ff373e7f3 100644 --- a/src/main/java/org/opensearch/knn/index/query/ResultUtil.java +++ b/src/main/java/org/opensearch/knn/index/query/ResultUtil.java @@ -13,11 +13,7 @@ import org.apache.lucene.util.DocIdSetBuilder; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.PriorityQueue; +import java.util.*; /** * Utility class used for processing results @@ -58,19 +54,20 @@ public static void reduceToTopK(List> perLeafResults, int k) } /** - * Convert map to bit set + * Convert map to bit set, if resultMap is empty or null then returns an Optional. Returning an optional here to + * ensure that the caller is aware that BitSet may not be present * * @param resultMap Map of results - * @return BitSet of results + * @return Optional BitSet of results * @throws IOException If an error occurs during the search. */ - public static BitSet resultMapToMatchBitSet(Map resultMap) throws IOException { - if (resultMap.isEmpty()) { - return BitSet.of(DocIdSetIterator.empty(), 0); + public static Optional resultMapToMatchBitSet(Map resultMap) throws IOException { + if (resultMap == null || resultMap.isEmpty()) { + return Optional.empty(); } final int maxDoc = Collections.max(resultMap.keySet()) + 1; - return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc); + return Optional.of(BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc)); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index a34a0f1ee..74a8ecaa9 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -28,12 +28,7 @@ import org.opensearch.knn.index.query.rescore.RescoreContext; import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Objects; +import java.util.*; import java.util.concurrent.Callable; /** @@ -112,9 +107,14 @@ private List> doRescore( LeafReaderContext leafReaderContext = leafReaderContexts.get(i); int finalI = i; rescoreTasks.add(() -> { - BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); + final Optional convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); + // if there is no docIds to re-score from a segment we should return early to ensure that we are not + // wasting any computation + if (convertedBitSet.isEmpty()) { + return Collections.emptyMap(); + } final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() - .matchedDocs(convertedBitSet) + .matchedDocs(convertedBitSet.get()) // setting to false because in re-scoring we want to do exact search on full precision vectors .useQuantizedVectorsForSearch(false) .k(k) diff --git a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java index 70cb86e02..5d869aaff 100644 --- a/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ResultUtilTests.java @@ -12,11 +12,7 @@ import org.opensearch.knn.KNNTestCase; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; public class ResultUtilTests extends KNNTestCase { @@ -44,8 +40,8 @@ public void testReduceToTopK() { public void testResultMapToMatchBitSet() throws IOException { int firstPassK = 35; Map perLeafResults = getRandomResults(firstPassK); - BitSet resultBitset = ResultUtil.resultMapToMatchBitSet(perLeafResults); - assertResultMapToMatchBitSet(perLeafResults, resultBitset); + Optional resultBitset = ResultUtil.resultMapToMatchBitSet(perLeafResults); + assertResultMapToMatchBitSet(perLeafResults, resultBitset.get()); } public void testResultMapToDocIds() throws IOException {