Skip to content

Commit

Permalink
Score Fix for Binary Quantized Vector and Seeting Default value in ca…
Browse files Browse the repository at this point in the history
…se of shard level rescoring is disabled for oversampling factor

Signed-off-by: VIKASH TIWARI <[email protected]>
  • Loading branch information
Vikasht34 committed Oct 3, 2024
1 parent 207e341 commit 5ba078c
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 122 deletions.
21 changes: 11 additions & 10 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,11 @@ public class KNNSettings {
public static final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit";
public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes";
public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled";
public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled";
public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED = "index.knn.disk.vector.shard_level_rescoring_enabled";

/**
* Default setting values
*
*/
public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false;
public static final boolean KNN_DEFAULT_FAISS_AVX512_DISABLED_VALUE = false;
Expand All @@ -113,7 +114,7 @@ public class KNNSettings {
public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed
// 10% of the JVM heap
public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60;
public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = true;
public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_VALUE = true;

/**
* Settings Definition
Expand All @@ -131,9 +132,9 @@ public class KNNSettings {
*
* @see Setting#boolSetting(String, boolean, Setting.Property...)
*/
public static final Setting<Boolean> KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING = Setting.boolSetting(
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE,
public static final Setting<Boolean> KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_SETTING = Setting.boolSetting(
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_VALUE,
IndexScope,
Dynamic
);
Expand Down Expand Up @@ -475,8 +476,8 @@ private Setting<?> getSetting(String key) {
return QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING;
}

if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) {
return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING;
if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED.equals(key)) {
return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_SETTING;
}

throw new IllegalArgumentException("Cannot find setting by key [" + key + "]");
Expand All @@ -501,7 +502,7 @@ public List<Setting<?>> getSettings() {
KNN_FAISS_AVX512_DISABLED_SETTING,
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_SETTING
);
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
Expand Down Expand Up @@ -554,12 +555,12 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
.index(indexName)
.getSettings()
.getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, true);
.getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED, true);
}

public void initialize(Client client, ClusterService clusterService) {
Expand Down
39 changes: 20 additions & 19 deletions src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ public enum CompressionLevel {
x1(1, "1x", null, Collections.emptySet()),
x2(2, "2x", null, Collections.emptySet()),
x4(4, "4x", null, Collections.emptySet()),
x8(8, "8x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK)),
x16(16, "16x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK)),
x32(32, "32x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK));
x8(8, "8x", new RescoreContext(2.0f, false), Set.of(Mode.ON_DISK)),
x16(16, "16x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)),
x32(32, "32x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK));

// Internally, an empty string is easier to deal with them null. However, from the mapping,
// we do not want users to pass in the empty string and instead want null. So we make the conversion here
Expand Down Expand Up @@ -97,32 +97,33 @@ public static boolean isConfigured(CompressionLevel compressionLevel) {
/**
* Returns the appropriate {@link RescoreContext} based on the given {@code mode} and {@code dimension}.
*
* <p>If the {@code mode} is present in the valid {@code modesForRescore} set, the method adjusts the oversample factor based on the
* {@code dimension} value:
* <p>If the {@code mode} is present in the valid {@code modesForRescore} set, the method checks the value of
* {@code dimension}:
* <ul>
* <li>If {@code dimension} is greater than or equal to 1000, no oversampling is applied (oversample factor = 1.0).</li>
* <li>If {@code dimension} is greater than or equal to 768 but less than 1000, a 2x oversample factor is applied (oversample factor = 2.0).</li>
* <li>If {@code dimension} is less than 768, a 3x oversample factor is applied (oversample factor = 3.0).</li>
* <li>If {@code dimension} is less than or equal to 1000, it returns a {@link RescoreContext} with an
* oversample factor of 5.0f.</li>
* <li>If {@code dimension} is greater than 1000, it returns the default {@link RescoreContext} associated with
* the {@link CompressionLevel}. If no default is set, it falls back to {@link RescoreContext#getDefault()}.</li>
* </ul>
* If the {@code mode} is not present in the {@code modesForRescore} set, the method returns {@code null}.
* If the {@code mode} is not valid, the method returns {@code null}.
*
* @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}.
* @param dimension The dimensional value that determines the {@link RescoreContext} behavior.
* @return A {@link RescoreContext} with the appropriate oversample factor based on the dimension, or {@code null} if the mode
* is not valid.
* @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than
* or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode
* is invalid.
*/
public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) {
if (modesForRescore.contains(mode)) {
// Adjust RescoreContext based on dimension
if (dimension >= RescoreContext.DIMENSION_THRESHOLD_1000) {
// No oversampling for dimensions >= 1000
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_1000).build();
} else if (dimension >= RescoreContext.DIMENSION_THRESHOLD_768) {
// 2x oversampling for dimensions >= 768 but < 1000
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_768).build();
if (dimension <= RescoreContext.DIMENSION_THRESHOLD) {
// For dimensions <= 1000, return a RescoreContext with 5.0f oversample factor
return RescoreContext.builder()
.oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD)
.userProvided(false)
.build();
} else {
// 3x oversampling for dimensions < 768
return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_768).build();
return defaultRescoreContext;
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
if (rescoreContext == null) {
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
} else {
int firstPassK = rescoreContext.getFirstPassK(finalK);
boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
if (KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()) == false) {
if (isShardLevelRescoringEnabled == true) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ public static RescoreContext streamInput(StreamInput in) throws IOException {
return null;
}
Float oversample = in.readOptionalFloat();
Boolean userProvided = in.readOptionalBoolean();
if (oversample == null) {
return null;
}
return RescoreContext.builder().oversampleFactor(oversample).build();
return RescoreContext.builder().oversampleFactor(oversample).userProvided(userProvided).build();
}

/**
Expand All @@ -106,6 +107,7 @@ public static void streamOutput(StreamOutput out, RescoreContext rescoreContext)
return;
}
out.writeOptionalFloat(rescoreContext == null ? null : rescoreContext.getOversampleFactor());
out.writeOptionalBoolean(rescoreContext == null ? null : rescoreContext.isUserProvided());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,74 @@ public final class RescoreContext {
@Builder.Default
private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR;

/**
* Flag to track whether the oversample factor is user-provided or default. The Reason to introduce
* this is to set default when Shard Level rescoring is false,
* else we end up overriding user provided value in NativeEngineKnnVectorQuery
*
*
* This flag is crucial to differentiate between user-defined oversample factors and system-assigned
* default values. The behavior of oversampling logic, especially when shard-level rescoring is disabled,
* depends on whether the user explicitly provided an oversample factor or whether the system is using
* a default value.
*
* When shard-level rescoring is disabled, the system applies dimension-based oversampling logic,
* overriding any default values. However, if the user provides their own oversample factor, the system
* should respect the user’s input and avoid overriding it with the dimension-based logic.
*
* This flag is set to {@code true} when the oversample factor is provided by the user, ensuring
* that their value is not overridden. It is set to {@code false} when the oversample factor is
* determined by system defaults (e.g., through a compression level or automatic logic). The system
* then applies its own oversampling rules if necessary.
*
* Key scenarios:
* - If {@code userProvided} is {@code true} and shard-level rescoring is disabled, the user's
* oversample factor is used as is, without applying the dimension-based logic.
* - If {@code userProvided} is {@code false}, the system applies dimension-based oversampling
* when shard-level rescoring is disabled.
*
* This flag enables flexibility, allowing the system to handle both user-defined and default
* behaviors, ensuring the correct oversampling logic is applied based on the context.
*/
@Builder.Default
private boolean userProvided = true;

/**
*
* @return default RescoreContext
*/
public static RescoreContext getDefault() {
return RescoreContext.builder().build();
return RescoreContext.builder().oversampleFactor(DEFAULT_OVERSAMPLE_FACTOR).userProvided(false).build();
}

/**
* Gets the number of results to return for the first pass of rescoring.
* Calculates the number of results to return for the first pass of rescoring (firstPassK).
* This method considers whether shard-level rescoring is enabled and adjusts the oversample factor
* based on the vector dimension if shard-level rescoring is disabled.
*
* @param finalK The final number of results to return for the entire shard
* @return The number of results to return for the first pass of rescoring
* @param finalK The final number of results to return for the entire shard.
* @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled.
* If true, the dimension-based oversampling logic is bypassed.
* @param dimension The dimension of the vector. This is used to determine the oversampling factor when
* shard-level rescoring is disabled.
* @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor.
*/
public int getFirstPassK(int finalK) {
public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) {
// Only apply default dimension-based oversampling logic when:
// 1. Shard-level rescoring is disabled
// 2. The oversample factor was not provided by the user
if (!isShardLevelRescoringEnabled && !userProvided) {
// Apply new dimension-based oversampling logic when shard-level rescoring is disabled
if (dimension >= DIMENSION_THRESHOLD_1000) {
oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000
} else if (dimension >= DIMENSION_THRESHOLD_768) {
oversampleFactor = OVERSAMPLE_FACTOR_768; // 2x oversampling for dimensions >= 768 and < 1000
} else {
oversampleFactor = OVERSAMPLE_FACTOR_BELOW_768; // 3x oversampling for dimensions < 768
}
}
// The calculation for firstPassK remains the same, applying the oversample factor
return Math.min(MAX_FIRST_PASS_RESULTS, Math.max(MIN_FIRST_PASS_RESULTS, (int) Math.ceil(finalK * oversampleFactor)));
}

}
8 changes: 4 additions & 4 deletions src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() {
}

@SneakyThrows
public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertTrue(shardLevelRescoringDisabled);
}
Expand All @@ -183,12 +183,12 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA
KNNSettings.state().setClusterService(clusterService);

final Settings rescoringDisabledSetting = Settings.builder()
.put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, userDefinedRescoringDisabled)
.put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED, userDefinedRescoringDisabled)
.build();

mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet();

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled);
}
Expand Down
Loading

0 comments on commit 5ba078c

Please sign in to comment.