diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index d43a97c43a..3d41c52d3b 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -99,6 +100,12 @@ public static void initialize(ModelDao modelDao) { KNNWeight.DEFAULT_EXACT_SEARCHER = new ExactSearcher(modelDao); } + @VisibleForTesting + static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { + KNNWeight.modelDao = modelDao; + KNNWeight.DEFAULT_EXACT_SEARCHER = exactSearcher; + } + @Override public Explanation explain(LeafReaderContext context, int doc) { return Explanation.match(1.0f, "No Explanation"); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 7edd4a414f..a5a00d2482 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -89,6 +89,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; @@ -327,9 +328,8 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { } @SneakyThrows - public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned() { + public void testShardWithoutFiles() { final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); - KNNWeight.initialize(null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -360,9 +360,10 @@ public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned() final Path path = mock(Path.class); when(directory.getDirectory()).thenReturn(path); final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); when(reader.getFieldInfos()).thenReturn(fieldInfos); - // When no knn fields are available , field info for vector field will be null - when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(null); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + final Scorer knnScorer = knnWeight.scorer(leafReaderContext); assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); } @@ -762,6 +763,83 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean } } + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + KNNWeight.initialize(null, mockedExactSearcher); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.00000001f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } + @SneakyThrows public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess() { ModelDao modelDao = mock(ModelDao.class);