From bc96b8c7fa44683743b9046de49e2168fd8664ab Mon Sep 17 00:00:00 2001 From: VIKASH TIWARI Date: Thu, 3 Oct 2024 14:35:34 -0700 Subject: [PATCH] Score Fix for Binary Quantized Vector and Setting Default value in case of shard level rescoring is disabled for oversampling factor Signed-off-by: VIKASH TIWARI --- .../org/opensearch/knn/index/KNNSettings.java | 21 ++--- .../knn/index/mapper/CompressionLevel.java | 39 ++++----- .../nativelib/NativeEngineKnnVectorQuery.java | 6 +- .../knn/index/query/parser/RescoreParser.java | 4 +- .../index/query/rescore/RescoreContext.java | 63 +++++++++++++-- .../knn/index/KNNSettingsTests.java | 8 +- .../index/mapper/CompressionLevelTests.java | 76 ++++++------------ .../NativeEngineKNNVectorQueryTests.java | 6 +- .../query/rescore/RescoreContextTests.java | 79 ++++++++++++------- 9 files changed, 181 insertions(+), 121 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 1753140e63..b53570feba 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -88,10 +88,11 @@ public class KNNSettings { public static final String QUANTIZATION_STATE_CACHE_SIZE_LIMIT = "knn.quantization.cache.size.limit"; public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes"; public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled"; - public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; + public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED = "index.knn.disk.vector.shard_level_rescoring_enabled"; /** * Default setting values + * */ public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false; public static final boolean KNN_DEFAULT_FAISS_AVX512_DISABLED_VALUE = false; @@ -113,7 +114,7 @@ public class 
KNNSettings { public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed // 10% of the JVM heap public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60; - public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = true; + public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_VALUE = true; /** * Settings Definition @@ -131,9 +132,9 @@ public class KNNSettings { * * @see Setting#boolSetting(String, boolean, Setting.Property...) */ - public static final Setting KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING = Setting.boolSetting( - KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, - KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE, + public static final Setting KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_SETTING = Setting.boolSetting( + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED, + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_VALUE, IndexScope, Dynamic ); @@ -475,8 +476,8 @@ private Setting getSetting(String key) { return QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING; } - if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED.equals(key)) { - return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING; + if (KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED.equals(key)) { + return KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_SETTING; } throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); @@ -501,7 +502,7 @@ public List> getSettings() { KNN_FAISS_AVX512_DISABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, - KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -554,12 +555,12 @@ public static 
Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } - public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) { + public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) { return KNNSettings.state().clusterService.state() .getMetadata() .index(indexName) .getSettings() - .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, true); + .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED, true); } public void initialize(Client client, ClusterService clusterService) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index c9a169efca..ab583a2e08 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -25,9 +25,9 @@ public enum CompressionLevel { x1(1, "1x", null, Collections.emptySet()), x2(2, "2x", null, Collections.emptySet()), x4(4, "4x", null, Collections.emptySet()), - x8(8, "8x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK)), - x16(16, "16x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK)), - x32(32, "32x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK)); + x8(8, "8x", new RescoreContext(2.0f, false), Set.of(Mode.ON_DISK)), + x16(16, "16x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)), + x32(32, "32x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)); // Internally, an empty string is easier to deal with them null. However, from the mapping, // we do not want users to pass in the empty string and instead want null. So we make the conversion here @@ -97,32 +97,33 @@ public static boolean isConfigured(CompressionLevel compressionLevel) { /** * Returns the appropriate {@link RescoreContext} based on the given {@code mode} and {@code dimension}. * - *

If the {@code mode} is present in the valid {@code modesForRescore} set, the method adjusts the oversample factor based on the - * {@code dimension} value: + *

If the {@code mode} is present in the valid {@code modesForRescore} set, the method checks the value of + * {@code dimension}: *

    - *
  • If {@code dimension} is greater than or equal to 1000, no oversampling is applied (oversample factor = 1.0).
  • - *
  • If {@code dimension} is greater than or equal to 768 but less than 1000, a 2x oversample factor is applied (oversample factor = 2.0).
  • - *
  • If {@code dimension} is less than 768, a 3x oversample factor is applied (oversample factor = 3.0).
  • + *
  • If {@code dimension} is less than or equal to 1000, it returns a {@link RescoreContext} with an + * oversample factor of 5.0f.
  • + *
  • If {@code dimension} is greater than 1000, it returns the default {@link RescoreContext} associated with + * the {@link CompressionLevel} as-is; no fallback to {@link RescoreContext#getDefault()} is applied.
  • *
- * If the {@code mode} is not present in the {@code modesForRescore} set, the method returns {@code null}. + * If the {@code mode} is not valid, the method returns {@code null}. * * @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}. * @param dimension The dimensional value that determines the {@link RescoreContext} behavior. - * @return A {@link RescoreContext} with the appropriate oversample factor based on the dimension, or {@code null} if the mode - * is not valid. + * @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than + * or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode + * is invalid. */ public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) { if (modesForRescore.contains(mode)) { // Adjust RescoreContext based on dimension - if (dimension >= RescoreContext.DIMENSION_THRESHOLD_1000) { - // No oversampling for dimensions >= 1000 - return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_1000).build(); - } else if (dimension >= RescoreContext.DIMENSION_THRESHOLD_768) { - // 2x oversampling for dimensions >= 768 but < 1000 - return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_768).build(); + if (dimension <= RescoreContext.DIMENSION_THRESHOLD) { + // For dimensions <= 1000, return a RescoreContext with 5.0f oversample factor + return RescoreContext.builder() + .oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD) + .userProvided(false) + .build(); } else { - // 3x oversampling for dimensions < 768 - return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_768).build(); + return defaultRescoreContext; } } return null; diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 
adb2875d5e..c97a0d061c 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -61,9 +61,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo if (rescoreContext == null) { perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK); } else { - int firstPassK = rescoreContext.getFirstPassK(finalK); + boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName()); + int dimension = knnQuery.getQueryVector().length; + int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension); perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK); - if (KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()) == false) { + if (isShardLevelRescoringEnabled == true) { ResultUtil.reduceToTopK(perLeafResults, firstPassK); } diff --git a/src/main/java/org/opensearch/knn/index/query/parser/RescoreParser.java b/src/main/java/org/opensearch/knn/index/query/parser/RescoreParser.java index 06062aed1c..c174e2f2ea 100644 --- a/src/main/java/org/opensearch/knn/index/query/parser/RescoreParser.java +++ b/src/main/java/org/opensearch/knn/index/query/parser/RescoreParser.java @@ -89,10 +89,11 @@ public static RescoreContext streamInput(StreamInput in) throws IOException { return null; } Float oversample = in.readOptionalFloat(); + Boolean userProvided = in.readOptionalBoolean(); if (oversample == null) { return null; } - return RescoreContext.builder().oversampleFactor(oversample).build(); + return RescoreContext.builder().oversampleFactor(oversample).userProvided(userProvided).build(); } /** @@ -106,6 +107,7 @@ public static void streamOutput(StreamOutput out, RescoreContext rescoreContext) return; } out.writeOptionalFloat(rescoreContext == null ? 
null : rescoreContext.getOversampleFactor()); + out.writeOptionalBoolean(rescoreContext == null ? null : rescoreContext.isUserProvided()); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java index a2563b2a61..0f8c594990 100644 --- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java +++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java @@ -39,21 +39,74 @@ public final class RescoreContext { @Builder.Default private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR; + /** + * Flag to track whether the oversample factor is user-provided or default. The Reason to introduce + * this is to set default when Shard Level rescoring is false, + * else we end up overriding user provided value in NativeEngineKnnVectorQuery + * + * + * This flag is crucial to differentiate between user-defined oversample factors and system-assigned + * default values. The behavior of oversampling logic, especially when shard-level rescoring is disabled, + * depends on whether the user explicitly provided an oversample factor or whether the system is using + * a default value. + * + * When shard-level rescoring is disabled, the system applies dimension-based oversampling logic, + * overriding any default values. However, if the user provides their own oversample factor, the system + * should respect the user’s input and avoid overriding it with the dimension-based logic. + * + * This flag is set to {@code true} when the oversample factor is provided by the user, ensuring + * that their value is not overridden. It is set to {@code false} when the oversample factor is + * determined by system defaults (e.g., through a compression level or automatic logic). The system + * then applies its own oversampling rules if necessary. 
+ * + * Key scenarios: + * - If {@code userProvided} is {@code true} and shard-level rescoring is disabled, the user's + * oversample factor is used as is, without applying the dimension-based logic. + * - If {@code userProvided} is {@code false}, the system applies dimension-based oversampling + * when shard-level rescoring is disabled. + * + * This flag enables flexibility, allowing the system to handle both user-defined and default + * behaviors, ensuring the correct oversampling logic is applied based on the context. + */ + @Builder.Default + private boolean userProvided = true; + /** * * @return default RescoreContext */ public static RescoreContext getDefault() { - return RescoreContext.builder().build(); + return RescoreContext.builder().oversampleFactor(DEFAULT_OVERSAMPLE_FACTOR).userProvided(false).build(); } /** - * Gets the number of results to return for the first pass of rescoring. + * Calculates the number of results to return for the first pass of rescoring (firstPassK). + * This method considers whether shard-level rescoring is enabled and adjusts the oversample factor + * based on the vector dimension if shard-level rescoring is disabled. * - * @param finalK The final number of results to return for the entire shard - * @return The number of results to return for the first pass of rescoring + * @param finalK The final number of results to return for the entire shard. + * @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled. + * If true, the dimension-based oversampling logic is bypassed. + * @param dimension The dimension of the vector. This is used to determine the oversampling factor when + * shard-level rescoring is disabled. + * @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor. 
*/ - public int getFirstPassK(int finalK) { + public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) { + // Only apply default dimension-based oversampling logic when: + // 1. Shard-level rescoring is disabled + // 2. The oversample factor was not provided by the user + if (!isShardLevelRescoringEnabled && !userProvided) { + // Apply new dimension-based oversampling logic when shard-level rescoring is disabled + if (dimension >= DIMENSION_THRESHOLD_1000) { + oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000 + } else if (dimension >= DIMENSION_THRESHOLD_768) { + oversampleFactor = OVERSAMPLE_FACTOR_768; // 2x oversampling for dimensions >= 768 and < 1000 + } else { + oversampleFactor = OVERSAMPLE_FACTOR_BELOW_768; // 3x oversampling for dimensions < 768 + } + } + // The calculation for firstPassK remains the same, applying the oversample factor return Math.min(MAX_FIRST_PASS_RESULTS, Math.max(MIN_FIRST_PASS_RESULTS, (int) Math.ceil(finalK * oversampleFactor))); } + } diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java index fd25699ccd..d22ce11f8e 100644 --- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java @@ -159,7 +159,7 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() { } @SneakyThrows - public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { + public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() { Node mockNode = createMockNode(Collections.emptyMap()); mockNode.start(); ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class); @@ -167,7 +167,7 @@ public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefau mockNode.client().admin().indices().create(new 
CreateIndexRequest(INDEX_NAME)).actionGet(); KNNSettings.state().setClusterService(clusterService); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertTrue(shardLevelRescoringDisabled); } @@ -183,12 +183,12 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA KNNSettings.state().setClusterService(clusterService); final Settings rescoringDisabledSetting = Settings.builder() - .put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, userDefinedRescoringDisabled) + .put(KNNSettings.KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_ENABLED, userDefinedRescoringDisabled) .build(); mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet(); - boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME); + boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME); mockNode.close(); assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java index 57372b11ee..e882d6697a 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java @@ -45,83 +45,57 @@ public void testGetDefaultRescoreContext() { // Test rescore context for ON_DISK mode Mode mode = Mode.ON_DISK; - // Test various dimensions based on the updated oversampling logic - int belowThresholdDimension = 500; // A dimension below 768 - int between768and1000Dimension = 800; // A dimension between 768 and 1000 - int above1000Dimension = 1500; // A dimension 
above 1000 + int belowThresholdDimension = 500; // A dimension below the threshold + int aboveThresholdDimension = 1500; // A dimension above the threshold - // Compression level x32 with dimension < 768 should have an oversample factor of 3.0f + // x32 with dimension <= 1000 should have an oversample factor of 5.0f RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - - // Compression level x32 with dimension between 768 and 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, between768and1000Dimension); - assertNotNull(rescoreContext); - assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x32 with dimension > 1000 should have no oversampling (1.0f) - rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, above1000Dimension); - assertNotNull(rescoreContext); - assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); - - // Compression level x16 with dimension < 768 should have an oversample factor of 3.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); + // x32 with dimension > 1000 should have an oversample factor of 3.0f + rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNotNull(rescoreContext); assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x16 with dimension between 768 and 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, between768and1000Dimension); + // x16 with dimension <= 1000 should have an oversample factor of 5.0f + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); 
assertNotNull(rescoreContext); - assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x16 with dimension > 1000 should have no oversampling (1.0f) - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, above1000Dimension); + // x16 with dimension > 1000 should have an oversample factor of 3.0f + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNotNull(rescoreContext); - assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); + assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x8 with dimension < 768 should have an oversample factor of 3.0f + // x8 with dimension <= 1000 should have an oversample factor of 5.0f rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension); assertNotNull(rescoreContext); - assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); - - // Compression level x8 with dimension between 768 and 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, between768and1000Dimension); + assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + // x8 with dimension > 1000 should have an oversample factor of 2.0f + rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNotNull(rescoreContext); assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); - // Compression level x8 with dimension > 1000 should have no oversampling (1.0f) - rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, above1000Dimension); - assertNotNull(rescoreContext); - assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f); - - // Compression level x4 with dimension < 768 should return null (no RescoreContext) + // x4 with dimension <= 1000 should return null (it 
doesn't have its own RescoreContext) rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - - // Compression level x4 with dimension > 1000 should return null (no RescoreContext) - rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, above1000Dimension); + // x4 with dimension > 1000 should return null (no RescoreContext is configured for x4) + rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNull(rescoreContext); - - // Compression level x2 with dimension < 768 should return null + // x2 with dimension <= 1000 should return null rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - - // Compression level x2 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, above1000Dimension); + // x2 with dimension > 1000 should return null + rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNull(rescoreContext); - - // Compression level x1 with dimension < 768 should return null rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); - - // Compression level x1 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, above1000Dimension); + // x1 with dimension > 1000 should return null + rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension); assertNull(rescoreContext); - - // NOT_CONFIGURED mode should return null for any dimension + // NOT_CONFIGURED with dimension <= 1000 should return null (no RescoreContext is configured) rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension); assertNull(rescoreContext); } - } diff 
--git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index 7fd96c6df4..53873e15f6 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -103,6 +103,8 @@ public void setUp() throws Exception { // Set ClusterService in KNNSettings KNNSettings.state().setClusterService(clusterService); + when(knnQuery.getQueryVector()).thenReturn(new float[] { 1.0f, 2.0f, 3.0f }); // Example vector + } @SneakyThrows @@ -166,7 +168,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); // Mock ResultUtil to return valid TopDocs mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt())) @@ -250,7 +252,7 @@ public void testRescore() { ) { // When shard-level re-scoring is enabled - mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(true); + mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true); mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf1Results), anyInt())).thenAnswer(t -> topDocs1); diff --git a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java index fd94667dbc..2b309e4abd 100644 --- 
a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java +++ b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java @@ -14,47 +14,72 @@ public class RescoreContextTests extends KNNTestCase { public void testGetFirstPassK() { float oversample = 2.6f; - RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build(); + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); int finalK = 100; - assertEquals(260, rescoreContext.getFirstPassK(finalK)); - finalK = 1; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); - finalK = 0; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); - finalK = MAX_FIRST_PASS_RESULTS; - assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); - } - - public void testGetFirstPassKWithMinPassK() { - float oversample = 2.6f; - RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build(); + boolean isShardLevelRescoringEnabled = true; + int dimension = 500; - // Case 1: Test with a finalK that results in a value greater than MIN_FIRST_PASS_RESULTS - int finalK = 100; - assertEquals(260, rescoreContext.getFirstPassK(finalK)); + // Case 1: Test with standard oversample factor when shard-level rescoring is enabled + assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); // Case 2: Test with a very small finalK that should result in a value less than MIN_FIRST_PASS_RESULTS finalK = 1; - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); - // Case 3: Test with finalK = 0, should return 0 + // Case 3: Test with finalK = 0, should return MIN_FIRST_PASS_RESULTS finalK = 0; - assertEquals(MIN_FIRST_PASS_RESULTS, 
rescoreContext.getFirstPassK(finalK)); + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); // Case 4: Test with finalK = MAX_FIRST_PASS_RESULTS, should cap at MAX_FIRST_PASS_RESULTS finalK = MAX_FIRST_PASS_RESULTS; - assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); + } - // Case 5: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS + public void testGetFirstPassKWithDimensionBasedOversampling() { + int finalK = 100; + int dimension; + + // Case 1: Test no oversampling for dimensions >= 1000 when shard-level rescoring is disabled + dimension = 1000; + RescoreContext rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensuring dimension-based logic applies + assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // No oversampling + + // Case 2: Test 2x oversampling for dimensions >= 768 but < 1000 when shard-level rescoring is disabled + dimension = 800; + rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over + assertEquals(200, rescoreContext.getFirstPassK(finalK, false, dimension)); // 2x oversampling + + // Case 3: Test 3x oversampling for dimensions < 768 when shard-level rescoring is disabled + dimension = 700; + rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over + assertEquals(300, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling + + // Case 4: Shard-level rescoring enabled, oversample factor should be used as provided by the user (ignore dimension) + rescoreContext = RescoreContext.builder().oversampleFactor(5.0f).userProvided(true).build(); // Provided by user + dimension = 500; + assertEquals(500, rescoreContext.getFirstPassK(finalK, true, 
dimension)); // User-defined oversample factor should be used + + // Case 5: Test finalK where oversampling factor results in a value less than MIN_FIRST_PASS_RESULTS finalK = 10; - oversample = 0.5f; // This will result in 5, which is less than MIN_FIRST_PASS_RESULTS - rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build(); - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + dimension = 700; + rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure dimension-based logic applies + assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling gives 30, which is raised to MIN_FIRST_PASS_RESULTS (100) + } + + public void testGetFirstPassKWithMinPassK() { + float oversample = 0.5f; + RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided + boolean isShardLevelRescoringEnabled = false; + + // Case 1: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS + int finalK = 10; + int dimension = 700; + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); - // Case 6: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS + // Case 2: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS finalK = 100; oversample = 1.0f; // This will result in exactly 100 (MIN_FIRST_PASS_RESULTS) - rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build(); - assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK)); + rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided + assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension)); } }