From d89c85e71609c11f74a18cf60a2f063e91295c04 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 23 Aug 2023 16:53:52 -0700 Subject: [PATCH] Improved the logic to switch to exact search for restrictive filters search. (#1059) (#1060) This change includes: * adding 2 extra advanced K-NN settings on when to do exact search for users to tune. Signed-off-by: Navneet Verma (cherry picked from commit 4edc1bf5d7d2fab39fc689ea35d0778b559f5ebc) Co-authored-by: Navneet Verma --- CHANGELOG.md | 3 +- .../org/opensearch/knn/index/KNNSettings.java | 49 +++++++++- .../opensearch/knn/index/query/KNNWeight.java | 44 ++++++++- .../knn/index/KNNSettingsTests.java | 85 +++++++++++++++++ .../knn/index/query/KNNWeightTests.java | 91 ++++++++++++++++++- 5 files changed, 263 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3dea320b4..3ddd7fca2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,10 +16,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements * Enabled the IVF algorithm to work with Filters of K-NN Query. [#1013](https://github.com/opensearch-project/k-NN/pull/1013) +* Improved the logic to switch to exact search for restrictive filters search for better recall. [#1059](https://github.com/opensearch-project/k-NN/pull/1059) ### Bug Fixes ### Infrastructure ### Documentation ### Maintenance * Update Guava Version to 32.0.1 [#1019](https://github.com/opensearch-project/k-NN/pull/1019) ### Refactoring -* Fix TransportAddress Refactoring Changes in Core [#1020](https://github.com/opensearch-project/k-NN/pull/1020) \ No newline at end of file +* Fix TransportAddress Refactoring Changes in Core [#1020](https://github.com/opensearch-project/k-NN/pull/1020) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 6c7a80c82..a1958233b 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -73,6 +73,8 @@ public class KNNSettings { public static final String MODEL_INDEX_NUMBER_OF_SHARDS = "knn.model.index.number_of_shards"; public static final String MODEL_INDEX_NUMBER_OF_REPLICAS = "knn.model.index.number_of_replicas"; public static final String MODEL_CACHE_SIZE_LIMIT = "knn.model.cache.size.limit"; + public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD = "index.knn.advanced.filtered_exact_search_threshold"; + public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT = "index.knn.advanced.filtered_exact_search_threshold_pct"; /** * Default setting values @@ -87,6 +89,9 @@ public class KNNSettings { public static final Integer KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 25; // Model cache limit cannot exceed 25% of the JVM heap public static final String KNN_DEFAULT_MEMORY_CIRCUIT_BREAKER_LIMIT = "50%"; + public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE = 2000; + public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT_DEFAULT_VALUE = 10; + /** * Settings Definition */ @@ -154,6 +159,22 @@ public class KNNSettings { Setting.Property.Dynamic ); + public static final Setting ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING = Setting.intSetting( + ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, + ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE, + 0, + IndexScope, + Setting.Property.Dynamic + ); + + public static final Setting ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT_SETTING = Setting.intSetting( + ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, + ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT_DEFAULT_VALUE, + 0, + IndexScope, + Setting.Property.Dynamic + ); + public static final Setting MODEL_CACHE_SIZE_LIMIT_SETTING = new Setting<>( MODEL_CACHE_SIZE_LIMIT, percentageAsString(KNN_DEFAULT_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE), @@ -323,6 +344,14 @@ private Setting getSetting(String key) { return KNN_ALGO_PARAM_INDEX_THREAD_QTY_SETTING; } + if (ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD.equals(key)) { + return ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING; + } + + if (ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT.equals(key)) { + return ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -338,7 +367,9 @@ public List> getSettings() { IS_KNN_INDEX_SETTING, MODEL_INDEX_NUMBER_OF_SHARDS_SETTING, MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, - MODEL_CACHE_SIZE_LIMIT_SETTING + MODEL_CACHE_SIZE_LIMIT_SETTING, + ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, + ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT_SETTING ); return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList()); } @@ -359,6 +390,22 @@ public static double getCircuitBreakerUnsetPercentage() { return KNNSettings.state().getSettingValue(KNNSettings.KNN_CIRCUIT_BREAKER_UNSET_PERCENTAGE); } + public static int getFilteredExactSearchThreshold(final String indexName) { + return KNNSettings.state().clusterService.state() + .getMetadata() + .index(indexName) + .getSettings() + .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); + } + + public static int getFilteredExactSearchThresholdPct(final String indexName) { + return KNNSettings.state().clusterService.state() + .getMetadata() + .index(indexName) + .getSettings() + .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT_DEFAULT_VALUE); + } + public void initialize(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 4bbf61f25..7352dd436 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -18,6 +18,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; @@ -115,13 +116,16 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. * This improves the recall. */ - if (filterWeight != null && filterIdsArray.length <= knnQuery.getK()) { + if (filterWeight != null && canDoExactSearch(filterIdsArray.length, getTotalDocsInSegment(context))) { docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray)); } else { - final Map annResults = doANNSearch(context, filterIdsArray); + Map annResults = doANNSearch(context, filterIdsArray); if (annResults == null) { return null; } + if (canDoExactSearchAfterANNSearch(filterIdsArray.length, annResults.size())) { + annResults = doExactSearch(context, filterIdsArray); + } docIdsToScoreMap.putAll(annResults); } if (docIdsToScoreMap.isEmpty()) { @@ -170,7 +174,6 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) { break; } - log.debug("Docs in filtered docs id set is : {}", docId); filteredIds[filteredIdsIndex] = docId; filteredIdsIndex++; docId++; @@ -369,4 +372,39 @@ private SpaceType getSpaceType(final FieldInfo fieldInfo) { String.format(Locale.ROOT, "Unable to find the Space Type from Field Info attribute for field %s", fieldInfo.getName()) ); } + + private boolean canDoExactSearch(final int filterIdsCount, final int searchableDocs) { + log.debug( + "Info for doing exact search Live Docs: {}, filterIdsLength : {}, Threshold value: {} , Threshold %age : {}", + searchableDocs, + filterIdsCount, + KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()), + KNNSettings.getFilteredExactSearchThresholdPct(knnQuery.getIndexName()) + ); + // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic + return filterIdsCount <= knnQuery.getK() + || (filterIdsCount <= KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()) + && (((float) filterIdsCount / (float) searchableDocs) * 100) <= (float) KNNSettings.getFilteredExactSearchThresholdPct( + knnQuery.getIndexName() + )); + } + + /** + * This condition mainly checks during filtered search we have more than K elements in filterIds but the ANN + * doesn't yeild K nearest neighbors. + * @param filterIdsCount count of filtered Doc ids + * @param annResultCount Count of Nearest Neighbours we got after doing filtered ANN Search. + * @return boolean - true if exactSearch needs to be done after ANNSearch. + */ + private boolean canDoExactSearchAfterANNSearch(final int filterIdsCount, final int annResultCount) { + return filterWeight != null && filterIdsCount >= knnQuery.getK() && knnQuery.getK() > annResultCount; + } + + private int getTotalDocsInSegment(final LeafReaderContext context) { + // This means that there is no deleted documents, hence the live docs bitset is null + if (context.reader().getLiveDocs() == null) { + return context.reader().maxDoc(); + } + return context.reader().getLiveDocs().length(); + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index 4adad3fe6..9432be33e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -6,6 +6,10 @@ package org.opensearch.knn.index; import lombok.SneakyThrows; +import org.junit.Assert; +import org.opensearch.action.admin.cluster.state.ClusterStateRequest; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.settings.put.UpdateSettingsRequest; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.network.NetworkModule; @@ -33,6 +37,8 @@ public class KNNSettingsTests extends KNNTestCase { + private static final String INDEX_NAME = "myindex"; + @SneakyThrows public void testGetSettingValueFromConfig() { long expectedKNNCircuitBreakerLimit = 13; @@ -70,6 +76,85 @@ public void testGetSettingValueDefault() { assertWarnings(); } + @SneakyThrows + public void testFilteredSearchAdvanceSetting_whenNoValuesProvidedByUsers_thenDefaultSettingsUsed() { + Node mockNode = createMockNode(Collections.emptyMap()); + mockNode.start(); + ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); + mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet(); + mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); + KNNSettings.state().setClusterService(clusterService); + + int filteredSearchThresholdPct = KNNSettings.getFilteredExactSearchThresholdPct(INDEX_NAME); + int filteredSearchThreshold = KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME); + mockNode.close(); + assertEquals((int) KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT_DEFAULT_VALUE, filteredSearchThresholdPct); + assertEquals((int) KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE, filteredSearchThreshold); + assertWarnings(); + } + + @SneakyThrows + public void testFilteredSearchAdvanceSetting_whenValuesProvidedByUsers_thenValidateSameValues() { + int userDefinedPctThreshold = 20; + int userDefinedThreshold = 1000; + int userDefinedPctThresholdMinValue = 0; + int userDefinedThresholdMinValue = 0; + Node mockNode = createMockNode(Collections.emptyMap()); + mockNode.start(); + ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); + mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet(); + mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); + KNNSettings.state().setClusterService(clusterService); + + final Settings filteredSearchAdvanceSettings = Settings.builder() + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, userDefinedThreshold) + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, userDefinedPctThreshold) + .build(); + + mockNode.client() + .admin() + .indices() + .updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettings, INDEX_NAME)) + .actionGet(); + + int filteredSearchThresholdPct = KNNSettings.getFilteredExactSearchThresholdPct(INDEX_NAME); + int filteredSearchThreshold = KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME); + + // validate if we are able to set MinValues for the setting + final Settings filteredSearchAdvanceSettingsWithMinValues = Settings.builder() + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, userDefinedThresholdMinValue) + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, userDefinedPctThresholdMinValue) + .build(); + + mockNode.client() + .admin() + .indices() + .updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettingsWithMinValues, INDEX_NAME)) + .actionGet(); + + int filteredSearchThresholdPctMinValue = KNNSettings.getFilteredExactSearchThresholdPct(INDEX_NAME); + int filteredSearchThresholdMinValue = KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME); + + // Validate if less than MinValues are set then Exception Happens + final Settings filteredSearchAdvanceSettingsWithLessThanMinValues = Settings.builder() + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, -1) + .put(KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_PCT, -1) + .build(); + + Assert.assertThrows(IllegalArgumentException.class, () -> mockNode.client() + .admin() + .indices() + .updateSettings(new UpdateSettingsRequest(filteredSearchAdvanceSettingsWithLessThanMinValues, INDEX_NAME)) + .actionGet()); + + mockNode.close(); + assertEquals(userDefinedPctThreshold, filteredSearchThresholdPct); + assertEquals(userDefinedThreshold, filteredSearchThreshold); + assertEquals(userDefinedPctThresholdMinValue, filteredSearchThresholdPctMinValue); + assertEquals(userDefinedThresholdMinValue, filteredSearchThresholdMinValue); + assertWarnings(); + } + private Node createMockNode(Map configSettings) throws IOException { Path configDir = createTempDir(); File configFile = configDir.resolve("opensearch.yml").toFile(); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 20ee69744..9cc624377 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -23,9 +23,11 @@ import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.Weight; import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.Bits; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; +import org.junit.Before; import org.junit.BeforeClass; import org.mockito.MockedStatic; import org.opensearch.common.io.PathUtils; @@ -86,10 +88,12 @@ public class KNNWeightTests extends KNNTestCase { private static MockedStatic nativeMemoryCacheManagerMockedStatic; private static MockedStatic jniServiceMockedStatic; + private static MockedStatic knnSettingsMockedStatic; + @BeforeClass public static void setUpClass() throws Exception { final KNNSettings knnSettings = mock(KNNSettings.class); - final MockedStatic knnSettingsMockedStatic = mockStatic(KNNSettings.class); + knnSettingsMockedStatic = mockStatic(KNNSettings.class); when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED))).thenReturn(true); when(knnSettings.getSettingValue(eq(KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_LIMIT))).thenReturn(CIRCUIT_BREAKER_LIMIT_100KB); when(knnSettings.getSettingValue(eq(KNNSettings.KNN_CACHE_ITEM_EXPIRY_ENABLED))).thenReturn(false); @@ -118,6 +122,12 @@ public static void setUpClass() throws Exception { pathUtilsMockedStatic.when(() -> PathUtils.get(anyString(), anyString())).thenReturn(indexPath); } + @Before + public void setupBeforeTest() { + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(0); + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThresholdPct(INDEX_NAME)).thenReturn(0); + } + @SneakyThrows public void testQueryResultScoreNmslib() { for (SpaceType space : List.of(SpaceType.L2, SpaceType.L1, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT, SpaceType.LINF)) { @@ -327,20 +337,27 @@ public void testEmptyQueryResults() { @SneakyThrows public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { + int k = 3; final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds))) .thenReturn(getFilteredKNNQueryResults()); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); - when(reader.maxDoc()).thenReturn(K + 1); + final Bits liveDocsBits = mock(Bits.class); + when(reader.maxDoc()).thenReturn(filterDocIds.length); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + for (int filterDocId : filterDocIds) { + when(liveDocsBits.get(filterDocId)).thenReturn(true); + } + when(liveDocsBits.length()).thenReturn(1000); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY); + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY); final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); // Just to make sure that we are not hitting the exact search condition - when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(K + 1)); + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1)); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); @@ -406,6 +423,10 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { // scorer will return 2 documents when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1)); when(reader.maxDoc()).thenReturn(1); + final Bits liveDocsBits = mock(Bits.class); + when(reader.getLiveDocs()).thenReturn(liveDocsBits); + when(liveDocsBits.get(filterDocId)).thenReturn(true); + when(liveDocsBits.length()).thenReturn(1000); final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); @@ -437,6 +458,68 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } + /** + * This test ensure that we do the exact search when threshold settings are correct and not using filteredIds<=K + * condition to do exact search. + * FilteredIdThreshold: 10 + * FilteredIdThresholdPct: 10% + * FilteredIdsCount: 6 + * liveDocs : null, as there is no deleted documents + * MaxDoc: 100 + * K : 1 + */ + @SneakyThrows + public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSuccess() { + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(10); + knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThresholdPct(INDEX_NAME)).thenReturn(10); + float[] vector = new float[] { 0.1f, 0.3f }; + int k = 1; + final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + when(reader.maxDoc()).thenReturn(100); + when(reader.getLiveDocs()).thenReturn(null); + final Weight filterQueryWeight = mock(Weight.class); + final Scorer filterScorer = mock(Scorer.class); + when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); + + when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length)); + + + final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY); + + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight); + final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn(attributesMap); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); + when(fieldInfo.getName()).thenReturn(FIELD_NAME); + when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); + when(binaryDocValues.advance(0)).thenReturn(0); + BytesRef vectorByteRef = new BytesRef(new KNNVectorAsArraySerializer().floatToByteArray(vector)); + when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); + + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.01f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + } + @SneakyThrows public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() { final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);