From b6780d5b8559e7d3d808975162ff5c87c4c816ce Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Tue, 24 Dec 2024 17:34:07 -0600 Subject: [PATCH] Fix shard level rescoring disabled setting flag Signed-off-by: Naveen Tatikonda --- src/main/java/org/opensearch/knn/index/KNNSettings.java | 2 +- .../index/query/nativelib/NativeEngineKnnVectorQuery.java | 6 +++--- .../knn/index/query/rescore/RescoreContext.java | 8 ++++---- .../java/org/opensearch/knn/index/KNNSettingsTests.java | 6 +++--- .../query/nativelib/NativeEngineKNNVectorQueryTests.java | 4 ++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index b81a54124..6dc72a22b 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -577,7 +577,7 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } - public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) { + public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) { return KNNSettings.state().clusterService.state() .getMetadata() .index(indexName) diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 47ea215f3..5b4d6e7a1 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -63,11 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo if (rescoreContext == null) { perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK); } else { - boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); + boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); int dimension = knnQuery.getQueryVector().length; - int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); + int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension); perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); - if (isShardLevelRescoringEnabled == true) { + if (isShardLevelRescoringDisabled == false) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index 09aeb7591..03c03a87f 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -61,17 +61,17 @@ public static RescoreContext getDefault() { * based on the vector dimension if shard-level rescoring is disabled. * * @param finalK The final number of results to return for the entire shard. - * @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled. - * If true, the dimension-based oversampling logic is bypassed. + * @param isShardLevelRescoringDisabled A boolean flag indicating whether shard-level rescoring is disabled. + * If false, the dimension-based oversampling logic is bypassed. * @param dimension The dimension of the vector. This is used to determine the oversampling factor when * shard-level rescoring is disabled. * @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor. */ - public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) { + public int getFirstPassK(int finalK, boolean isShardLevelRescoringDisabled, int dimension) { // Only apply default dimension-based oversampling logic when: // 1. Shard-level rescoring is disabled // 2. The oversample factor was not provided by the user - if (!isShardLevelRescoringEnabled && !userProvided) { + if (isShardLevelRescoringDisabled && !userProvided) { // Apply new dimension-based oversampling logic when shard-level rescoring is disabled if (dimension >= DIMENSION_THRESHOLD_1000) { oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000 diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index c7a8e7ed8..24990dd36 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -159,7 +159,7 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() { } @SneakyThrows - public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { + public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { Node mockNode = createMockNode(Collections.emptyMap()); mockNode.start(); ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); @@ -167,7 +167,7 @@ public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaul mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet(); KNNSettings.state().setClusterService(clusterService); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertFalse(shardLevelRescoringDisabled); } @@ -188,7 +188,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet(); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled); } diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 789bd1054..87c4a5014 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -183,7 +183,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); // Mock ResultUtil to return valid TopDocs mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt())) @@ -265,7 +265,7 @@ public void testRescore() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);