From abcd3392be8572a502433fcbfd718bb76598d818 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Thu, 21 Nov 2024 23:38:38 -0800 Subject: [PATCH] Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field (#2278) Signed-off-by: Navneet Verma (cherry picked from commit 7523cc317d5575c74b82f315f271d63668c8f623) --- CHANGELOG.md | 3 +- .../knn/common/FieldInfoExtractor.java | 15 ++++- .../knn/index/KNNVectorDVLeafFieldData.java | 3 +- .../knn/index/query/ExactSearcher.java | 22 ++++++-- .../opensearch/knn/index/query/KNNWeight.java | 4 +- .../knn/common/FieldInfoExtractorTests.java | 13 +++++ .../knn/index/query/ExactSearcherTests.java | 55 +++++++++++++++++++ 7 files changed, 105 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dadf0d19c..a39546843 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,11 +18,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] ### Bug Fixes -* Fix NPE in ANN search when a segment doesn't contain vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278] +* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282] ### Infrastructure * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259) * Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279) ### Documentation ### Maintenance * Select index settings based on cluster version[2236](https://github.com/opensearch-project/k-NN/pull/2236) +* Added null checks for fieldInfo in ExactSearcher to avoid NPE while running exact search for segments with no vector field (#2278)[https://github.com/opensearch-project/k-NN/pull/2278] ### Refactoring diff --git a/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java index 4ded68237..16bf0fb54 100644 --- a/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java +++ b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java @@ -8,6 +8,8 @@ import lombok.experimental.UtilityClass; import org.apache.commons.lang.StringUtils; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReader; +import org.opensearch.common.Nullable; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; @@ -28,7 +30,7 @@ import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; /** - * A utility class to extract information from FieldInfo. + * A utility class to extract information from FieldInfo and also provides utility functions to extract fieldInfo */ @UtilityClass public class FieldInfoExtractor { @@ -104,4 +106,15 @@ public static SpaceType getSpaceType(final ModelDao modelDao, final FieldInfo fi } return modelMetadata.getSpaceType(); } + + /** + * Get the field info for the given field name, do a null check on the fieldInfo, as this function can return null, + * if the field is not found. + * @param leafReader {@link LeafReader} + * @param fieldName {@link String} + * @return {@link FieldInfo} + */ + public static @Nullable FieldInfo getFieldInfo(final LeafReader leafReader, final String fieldName) { + return leafReader.getFieldInfos().fieldInfo(fieldName); + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 85f037c0f..7053e6151 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -12,6 +12,7 @@ import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; +import org.opensearch.knn.common.FieldInfoExtractor; import java.io.IOException; @@ -40,7 +41,7 @@ public long ramBytesUsed() { @Override public ScriptDocValues getScriptValues() { try { - FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName); + FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, fieldName); if (fieldInfo == null) { return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); } diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 77e993297..6a97b4083 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -38,6 +38,7 @@ import org.opensearch.knn.indices.ModelDao; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Locale; import java.util.Map; @@ -59,7 +60,11 @@ public class ExactSearcher { */ public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { - KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + final KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + // because of any reason if we are not able to get KNNIterator, return an empty map + if (iterator == null) { + return Collections.emptyMap(); + } if (exactSearcherContext.getKnnQuery().getRadius() != null) { return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); } @@ -74,8 +79,8 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, * Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search. * Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores * to filter out the documents that does not have given min score. - * @param leafReaderContext - * @param exactSearcherContext + * @param leafReaderContext {@link LeafReaderContext} + * @param exactSearcherContext {@link ExactSearcherContext} * @param iterator {@link KNNIterator} * @return Map of docId and score * @throws IOException exception raised by iterator during traversal @@ -87,7 +92,10 @@ private Map doRadialSearch( ) throws IOException { final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); - final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); + if (fieldInfo == null) { + return Collections.emptyMap(); + } final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo); if (KNNEngine.FAISS != engine) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine)); @@ -149,7 +157,11 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); - final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); + if (fieldInfo == null) { + log.debug("[KNN] Cannot get KNNIterator as Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); + return null; + } final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 04c2ce587..b64472994 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -227,7 +227,7 @@ private Map doANNSearch( ) throws IOException { final SegmentReader reader = Lucene.segmentReader(context.reader()); - FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); if (fieldInfo == null) { log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); @@ -479,7 +479,7 @@ private boolean isFilteredExactSearchRequireAfterANNSearch(final int filterIdsCo */ private boolean isMissingNativeEngineFiles(LeafReaderContext context) { final SegmentReader reader = Lucene.segmentReader(context.reader()); - final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); // if segment has no documents with at least 1 vector field, field info will be null if (fieldInfo == null) { return false; diff --git a/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java b/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java index 27aedd1d0..dd3721071 100644 --- a/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java +++ b/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java @@ -6,6 +6,8 @@ package org.opensearch.knn.common; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.LeafReader; import org.junit.Assert; import org.mockito.MockedStatic; import org.mockito.Mockito; @@ -63,4 +65,15 @@ public void testExtractVectorDataType() { when(fieldInfo.getAttribute("model_id")).thenReturn(null); assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo)); } + + public void testGetFieldInfo_whenDifferentInput_thenSuccess() { + LeafReader leafReader = Mockito.mock(LeafReader.class); + FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + Mockito.when(leafReader.getFieldInfos()).thenReturn(fieldInfos); + Mockito.when(fieldInfos.fieldInfo("invalid")).thenReturn(null); + Mockito.when(fieldInfos.fieldInfo("valid")).thenReturn(fieldInfo); + Assert.assertNull(FieldInfoExtractor.getFieldInfo(leafReader, "invalid")); + Assert.assertEquals(fieldInfo, FieldInfoExtractor.getFieldInfo(leafReader, "valid")); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java b/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java index 8492ca1f0..a4b853560 100644 --- a/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java @@ -20,6 +20,7 @@ import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNNCodecVersion; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; @@ -50,6 +51,59 @@ public class ExactSearcherTests extends KNNTestCase { private static final String SEGMENT_NAME = "0"; + @SneakyThrows + public void testExactSearch_whenSegmentHasNoVectorField_thenNoDocsReturned() { + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final KNNQuery query = KNNQuery.builder().field(FIELD_NAME).queryVector(queryVector).k(10).indexName(INDEX_NAME).build(); + + final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder = + ExactSearcher.ExactSearcherContext.builder().knnQuery(query); + + ExactSearcher exactSearcher = new ExactSearcher(null); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(query.getField())).thenReturn(null); + Map docIds = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build()); + Mockito.verify(fieldInfos).fieldInfo(query.getField()); + Mockito.verify(reader).getFieldInfos(); + Mockito.verify(leafReaderContext).reader(); + assertEquals(0, docIds.size()); + } + + @SneakyThrows + public void testRadialSearchExactSearch_whenSegmentHasNoVectorField_thenNoDocsReturned() { + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + KNNQuery.Context context = new KNNQuery.Context(10); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .context(context) + .radius(1.0f) + .indexName(INDEX_NAME) + .build(); + + final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder = + ExactSearcher.ExactSearcherContext.builder().knnQuery(query); + + ExactSearcher exactSearcher = new ExactSearcher(null); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(query.getField())).thenReturn(null); + Map docIds = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build()); + Mockito.verify(fieldInfos).fieldInfo(query.getField()); + Mockito.verify(reader).getFieldInfos(); + Mockito.verify(leafReaderContext).reader(); + assertEquals(0, docIds.size()); + } + @SneakyThrows public void testRadialSearch_whenNoEngineFiles_thenSuccess() { try (MockedStatic valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) { @@ -75,6 +129,7 @@ public void testRadialSearch_whenNoEngineFiles_thenSuccess() { .queryVector(queryVector) .radius(radius) .indexName(INDEX_NAME) + .vectorDataType(VectorDataType.FLOAT) .context(context) .build();