Skip to content

Commit

Permalink
Adding Support to Enable/Disble Share level Rescoring and Update Over…
Browse files Browse the repository at this point in the history
…sampling Factor (#2172)

Signed-off-by: VIKASH TIWARI <[email protected]>
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
Vikasht34 authored and jmazanec15 committed Oct 1, 2024
1 parent 035e5da commit 7c8f5cf
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Optimize reduceToTopK in ResultUtil by removing pre-filling and reducing peek calls [#2146](https://github.com/opensearch-project/k-NN/pull/2146)
* Update Default Rescore Context based on Dimension [#2149](https://github.com/opensearch-project/k-NN/pull/2149)
* KNNIterators should support with and without filters [#2155](https://github.com/opensearch-project/k-NN/pull/2155)
* Adding Support to Enable/Disble Share level Rescoring and Update Oversampling Factor[#2172](https://github.com/opensearch-project/k-NN/pull/2172)
### Bug Fixes
* KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147)
### Infrastructure
Expand Down
36 changes: 35 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public class KNNSettings {
public static final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit";
public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes";
public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled";
public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled";

/**
* Default setting values
Expand All @@ -112,11 +113,31 @@ public class KNNSettings {
public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed
// 10% of the JVM heap
public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60;
public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = true;

/**
* Settings Definition
*/

/**
* This setting controls whether shard-level re-scoring for KNN disk-based vectors is turned off.
* The setting uses:
* <ul>
* <li><b>KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED:</b> The name of the setting.</li>
* <li><b>KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE:</b> The default value (true or false).</li>
* <li><b>IndexScope:</b> The setting works at the index level.</li>
* <li><b>Dynamic:</b> This setting can be changed without restarting the cluster.</li>
* </ul>
*
* @see Setting#boolSetting(String, boolean, Setting.Property...)
*/
public static final Setting<Boolean> KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING = Setting.boolSetting(
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE,
IndexScope,
Dynamic
);

// This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default
// 1% of the JVM heap
public static final Setting<ByteSizeValue> KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting(
Expand Down Expand Up @@ -454,6 +475,10 @@ private Setting<?> getSetting(String key) {
return QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING;
}

if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) {
return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
}

Expand All @@ -475,7 +500,8 @@ public List<Setting<?>> getSettings() {
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING,
KNN_FAISS_AVX512_DISABLED_SETTING,
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
Expand Down Expand Up @@ -528,6 +554,14 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
.index(indexName)
.getSettings()
.getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, true);
}

public void initialize(Client client, ClusterService clusterService) {
this.client = client;
this.clusterService = clusterService;
Expand Down
31 changes: 17 additions & 14 deletions src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,32 +97,35 @@ public static boolean isConfigured(CompressionLevel compressionLevel) {
/**
* Returns the appropriate {@link RescoreContext} based on the given {@code mode} and {@code dimension}.
*
* <p>If the {@code mode} is present in the valid {@code modesForRescore} set, the method checks the value of
* {@code dimension}:
* <p>If the {@code mode} is present in the valid {@code modesForRescore} set, the method adjusts the oversample factor based on the
* {@code dimension} value:
* <ul>
* <li>If {@code dimension} is less than or equal to 1000, it returns a {@link RescoreContext} with an
* oversample factor of 5.0f.</li>
* <li>If {@code dimension} is greater than 1000, it returns the default {@link RescoreContext} associated with
* the {@link CompressionLevel}. If no default is set, it falls back to {@link RescoreContext#getDefault()}.</li>
* <li>If {@code dimension} is greater than or equal to 1000, no oversampling is applied (oversample factor = 1.0).</li>
* <li>If {@code dimension} is greater than or equal to 768 but less than 1000, a 2x oversample factor is applied (oversample factor = 2.0).</li>
* <li>If {@code dimension} is less than 768, a 3x oversample factor is applied (oversample factor = 3.0).</li>
* </ul>
* If the {@code mode} is not valid, the method returns {@code null}.
* If the {@code mode} is not present in the {@code modesForRescore} set, the method returns {@code null}.
*
* @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}.
* @param dimension The dimensional value that determines the {@link RescoreContext} behavior.
* @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than
* or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode
* is invalid.
* @return A {@link RescoreContext} with the appropriate oversample factor based on the dimension, or {@code null} if the mode
* is not valid.
*/
public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) {
if (modesForRescore.contains(mode)) {
// Adjust RescoreContext based on dimension
if (dimension <= RescoreContext.DIMENSION_THRESHOLD) {
// For dimensions <= 1000, return a RescoreContext with 5.0f oversample factor
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD).build();
if (dimension >= RescoreContext.DIMENSION_THRESHOLD_1000) {
// No oversampling for dimensions >= 1000
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_1000).build();
} else if (dimension >= RescoreContext.DIMENSION_THRESHOLD_768) {
// 2x oversampling for dimensions >= 768 but < 1000
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_768).build();
} else {
return defaultRescoreContext;
// 3x oversampling for dimensions < 768
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_768).build();
}
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.query.ExactSearcher;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
Expand Down Expand Up @@ -54,7 +55,6 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
final IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();

List<Map<Integer, Float>> perLeafResults;
RescoreContext rescoreContext = knnQuery.getRescoreContext();
int finalK = knnQuery.getK();
Expand All @@ -63,7 +63,9 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
} else {
int firstPassK = rescoreContext.getFirstPassK(finalK);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
if (KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()) == false) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

StopWatch stopWatch = new StopWatch().start();
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ public final class RescoreContext {
public static final int DIMENSION_THRESHOLD = 1000;
public static final float OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD = 5.0f;

// Dimension thresholds for adjusting oversample factor
public static final int DIMENSION_THRESHOLD_1000 = 1000;
public static final int DIMENSION_THRESHOLD_768 = 768;

// Oversample factors based on dimension thresholds
public static final float OVERSAMPLE_FACTOR_1000 = 1.0f; // No oversampling for dimensions >= 1000
public static final float OVERSAMPLE_FACTOR_768 = 2.0f; // 2x oversampling for dimensions >= 768 and < 1000
public static final float OVERSAMPLE_FACTOR_BELOW_768 = 3.0f; // 3x oversampling for dimensions < 768

// Todo:- We will improve this in upcoming releases
public static final int MIN_FIRST_PASS_RESULTS = 100;

Expand Down
35 changes: 35 additions & 0 deletions src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,41 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() {
assertEquals(userProvidedEfSearch, efSearchValue);
}

@SneakyThrows
public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertTrue(shardLevelRescoringDisabled);
}

@SneakyThrows
public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingApplied() {
boolean userDefinedRescoringDisabled = false;
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

final Settings rescoringDisabledSetting = Settings.builder()
.put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, userDefinedRescoringDisabled)
.build();

mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet();

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled);
}

@SneakyThrows
public void testGetFaissAVX2DisabledSettingValueFromConfig_enableSetting_thenValidateAndSucceed() {
boolean expectedKNNFaissAVX2Disabled = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,65 +44,84 @@ public void testIsConfigured() {
public void testGetDefaultRescoreContext() {
// Test rescore context for ON_DISK mode
Mode mode = Mode.ON_DISK;
int belowThresholdDimension = 500; // A dimension below the threshold
int aboveThresholdDimension = 1500; // A dimension above the threshold

// x32 with dimension <= 1000 should have an oversample factor of 5.0f
// Test various dimensions based on the updated oversampling logic
int belowThresholdDimension = 500; // A dimension below 768
int between768and1000Dimension = 800; // A dimension between 768 and 1000
int above1000Dimension = 1500; // A dimension above 1000

// Compression level x32 with dimension < 768 should have an oversample factor of 3.0f
RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x32 with dimension > 1000 should have an oversample factor of 3.0f
rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x32 with dimension between 768 and 1000 should have an oversample factor of 2.0f
rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, between768and1000Dimension);
assertNotNull(rescoreContext);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x16 with dimension <= 1000 should have an oversample factor of 5.0f
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension);
// Compression level x32 with dimension > 1000 should have no oversampling (1.0f)
rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, above1000Dimension);
assertNotNull(rescoreContext);
assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x16 with dimension > 1000 should have an oversample factor of 3.0f
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x16 with dimension < 768 should have an oversample factor of 3.0f
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x8 with dimension <= 1000 should have an oversample factor of 5.0f
// Compression level x16 with dimension between 768 and 1000 should have an oversample factor of 2.0f
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, between768and1000Dimension);
assertNotNull(rescoreContext);
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);

// Compression level x16 with dimension > 1000 should have no oversampling (1.0f)
rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, above1000Dimension);
assertNotNull(rescoreContext);
assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);

// Compression level x8 with dimension < 768 should have an oversample factor of 3.0f
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x8 with dimension > 1000 should have an oversample factor of 2.0f
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x8 with dimension between 768 and 1000 should have an oversample factor of 2.0f
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, between768and1000Dimension);
assertNotNull(rescoreContext);
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext)
// Compression level x8 with dimension > 1000 should have no oversampling (1.0f)
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, above1000Dimension);
assertNotNull(rescoreContext);
assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);

// Compression level x4 with dimension < 768 should return null (no RescoreContext)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);
// x4 with dimension > 1000 should return null (no RescoreContext is configured for x4)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension);
assertNull(rescoreContext);

// Other compression levels should behave similarly with respect to dimension
// Compression level x4 with dimension > 1000 should return null (no RescoreContext)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, above1000Dimension);
assertNull(rescoreContext);

// Compression level x2 with dimension < 768 should return null
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);

// x2 with dimension > 1000 should return null
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x2 with dimension > 1000 should return null
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, above1000Dimension);
assertNull(rescoreContext);

// Compression level x1 with dimension < 768 should return null
rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);

// x1 with dimension > 1000 should return null
rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension);
// Compression level x1 with dimension > 1000 should return null
rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, above1000Dimension);
assertNull(rescoreContext);

// NOT_CONFIGURED with dimension <= 1000 should return a RescoreContext with an oversample factor of 5.0f
// NOT_CONFIGURED mode should return null for any dimension
rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);

}

}
Loading

0 comments on commit 7c8f5cf

Please sign in to comment.