Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Fix shard level rescoring disabled setting flag #2354

Merged
merged 1 commit into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
* Release query vector memory after execution (#2346)[https://github.com/opensearch-project/k-NN/pull/2346]
* Fix shard level rescoring disabled setting flag (#2352)[https://github.com/opensearch-project/k-NN/pull/2352]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) {
public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
.index(indexName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
if (rescoreContext == null) {
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
} else {
boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName());
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension);
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
if (isShardLevelRescoringEnabled == true) {
if (isShardLevelRescoringDisabled == false) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ public static RescoreContext getDefault() {
* based on the vector dimension if shard-level rescoring is disabled.
*
* @param finalK The final number of results to return for the entire shard.
* @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled.
* If true, the dimension-based oversampling logic is bypassed.
* @param isShardLevelRescoringDisabled A boolean flag indicating whether shard-level rescoring is disabled.
* If false, the dimension-based oversampling logic is bypassed.
* @param dimension The dimension of the vector. This is used to determine the oversampling factor when
* shard-level rescoring is disabled.
* @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor.
*/
public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) {
public int getFirstPassK(int finalK, boolean isShardLevelRescoringDisabled, int dimension) {
// Only apply default dimension-based oversampling logic when:
// 1. Shard-level rescoring is disabled
// 2. The oversample factor was not provided by the user
if (!isShardLevelRescoringEnabled && !userProvided) {
if (isShardLevelRescoringDisabled && !userProvided) {
// Apply new dimension-based oversampling logic when shard-level rescoring is disabled
if (dimension >= DIMENSION_THRESHOLD_1000) {
oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,15 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() {
}

@SneakyThrows
public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertFalse(shardLevelRescoringDisabled);
}
Expand All @@ -192,7 +192,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA

mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet();

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() {
) {

// When shard-level re-scoring is enabled
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);

// Mock ResultUtil to return valid TopDocs
mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt()))
Expand Down Expand Up @@ -265,7 +265,7 @@ public void testRescore() {
) {

// When shard-level re-scoring is enabled
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);

mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ public void testGetFirstPassK() {
float oversample = 2.6f;
RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build();
int finalK = 100;
boolean isShardLevelRescoringEnabled = true;
boolean isShardLevelRescoringDisabled = false;
int dimension = 500;

// Case 1: Test with standard oversample factor when shard-level rescoring is enabled
assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 2: Test with a very small finalK that should result in a value less than MIN_FIRST_PASS_RESULTS
finalK = 1;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 3: Test with finalK = 0, should return MIN_FIRST_PASS_RESULTS
finalK = 0;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 4: Test with finalK = MAX_FIRST_PASS_RESULTS, should cap at MAX_FIRST_PASS_RESULTS
finalK = MAX_FIRST_PASS_RESULTS;
assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));
}

public void testGetFirstPassKWithDimensionBasedOversampling() {
Expand All @@ -42,44 +42,44 @@ public void testGetFirstPassKWithDimensionBasedOversampling() {
// Case 1: Test no oversampling for dimensions >= 1000 when shard-level rescoring is disabled
dimension = 1000;
RescoreContext rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensuring dimension-based logic applies
assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // No oversampling
assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // No oversampling

// Case 2: Test 2x oversampling for dimensions >= 768 but < 1000 when shard-level rescoring is disabled
dimension = 800;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over
assertEquals(200, rescoreContext.getFirstPassK(finalK, false, dimension)); // 2x oversampling
assertEquals(200, rescoreContext.getFirstPassK(finalK, true, dimension)); // 2x oversampling

// Case 3: Test 3x oversampling for dimensions < 768 when shard-level rescoring is disabled
dimension = 700;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over
assertEquals(300, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling
assertEquals(300, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling

// Case 4: Shard-level rescoring enabled, oversample factor should be used as provided by the user (ignore dimension)
rescoreContext = RescoreContext.builder().oversampleFactor(5.0f).userProvided(true).build(); // Provided by user
dimension = 500;
assertEquals(500, rescoreContext.getFirstPassK(finalK, true, dimension)); // User-defined oversample factor should be used
assertEquals(500, rescoreContext.getFirstPassK(finalK, false, dimension)); // User-defined oversample factor should be used

// Case 5: Test finalK where oversampling factor results in a value less than MIN_FIRST_PASS_RESULTS
finalK = 10;
dimension = 700;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure dimension-based logic applies
assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling results in 30
assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling results in 30
}

public void testGetFirstPassKWithMinPassK() {
float oversample = 0.5f;
RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided
boolean isShardLevelRescoringEnabled = false;
boolean isShardLevelRescoringDisabled = true;

// Case 1: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS
int finalK = 10;
int dimension = 700;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 2: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS
finalK = 100;
oversample = 1.0f; // This will result in exactly 100 (MIN_FIRST_PASS_RESULTS)
rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));
}
}
Loading