diff --git a/CHANGELOG.md b/CHANGELOG.md
index fa86cbe3f..f615a78fb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Adding Support to Enable/Disable Shard Level Rescoring and Update Oversampling Factor [#2172](https://github.com/opensearch-project/k-NN/pull/2172)
### Bug Fixes
* KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147)
+* Fix score for binary quantized vectors and set a default oversampling factor when shard-level rescoring is disabled [#2183](https://github.com/opensearch-project/k-NN/pull/2183)
### Infrastructure
### Documentation
### Maintenance
diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java
index 1753140e6..f5980879a 100644
--- a/src/main/java/org/opensearch/knn/index/KNNSettings.java
+++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java
@@ -92,6 +92,7 @@ public class KNNSettings {
/**
* Default setting values
+ *
*/
public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false;
public static final boolean KNN_DEFAULT_FAISS_AVX512_DISABLED_VALUE = false;
@@ -113,7 +114,7 @@ public class KNNSettings {
public static final Integer KNN_MAX_QUANTIZATION_STATE_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // Quantization state cache limit cannot exceed
// 10% of the JVM heap
public static final Integer KNN_DEFAULT_QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = 60;
- public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = true;
+ public static final boolean KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_VALUE = false;
/**
* Settings Definition
@@ -554,12 +555,12 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}
- public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
+ public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
.index(indexName)
.getSettings()
- .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, true);
+ .getAsBoolean(KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED, false);
}
public void initialize(Client client, ClusterService clusterService) {
diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java
index c9a169efc..ab583a2e0 100644
--- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java
+++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java
@@ -25,9 +25,9 @@ public enum CompressionLevel {
x1(1, "1x", null, Collections.emptySet()),
x2(2, "2x", null, Collections.emptySet()),
x4(4, "4x", null, Collections.emptySet()),
- x8(8, "8x", new RescoreContext(2.0f), Set.of(Mode.ON_DISK)),
- x16(16, "16x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK)),
- x32(32, "32x", new RescoreContext(3.0f), Set.of(Mode.ON_DISK));
+ x8(8, "8x", new RescoreContext(2.0f, false), Set.of(Mode.ON_DISK)),
+ x16(16, "16x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK)),
+ x32(32, "32x", new RescoreContext(3.0f, false), Set.of(Mode.ON_DISK));
// Internally, an empty string is easier to deal with than null. However, from the mapping,
// we do not want users to pass in the empty string and instead want null. So we make the conversion here
@@ -97,32 +97,33 @@ public static boolean isConfigured(CompressionLevel compressionLevel) {
/**
* Returns the appropriate {@link RescoreContext} based on the given {@code mode} and {@code dimension}.
*
- *
- * If the {@code mode} is present in the valid {@code modesForRescore} set, the method adjusts the oversample factor based on the
- * {@code dimension} value:
+ *
+ * If the {@code mode} is present in the valid {@code modesForRescore} set, the method checks the value of
+ * {@code dimension}:
*
- * - If {@code dimension} is greater than or equal to 1000, no oversampling is applied (oversample factor = 1.0).
- * - If {@code dimension} is greater than or equal to 768 but less than 1000, a 2x oversample factor is applied (oversample factor = 2.0).
- * - If {@code dimension} is less than 768, a 3x oversample factor is applied (oversample factor = 3.0).
+ * - If {@code dimension} is less than or equal to 1000, it returns a {@link RescoreContext} with an
+ * oversample factor of 5.0f.
+ * - If {@code dimension} is greater than 1000, it returns the default {@link RescoreContext} associated with
+ * the {@link CompressionLevel}. If no default is set, it falls back to {@link RescoreContext#getDefault()}.
*
- * If the {@code mode} is not present in the {@code modesForRescore} set, the method returns {@code null}.
+ * If the {@code mode} is not valid, the method returns {@code null}.
*
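+ * <p>A brief illustrative sketch of the resulting defaults (dimension values are examples; factors follow the x32 level described above):
+ * <pre>{@code
+ * // dimension <= 1000: oversample factor of 5.0f
+ * RescoreContext ctx = CompressionLevel.x32.getDefaultRescoreContext(Mode.ON_DISK, 768);
+ * // ctx.getOversampleFactor() == 5.0f
+ *
+ * // dimension > 1000: the compression level's own default (3.0f for x32)
+ * ctx = CompressionLevel.x32.getDefaultRescoreContext(Mode.ON_DISK, 1536);
+ * // ctx.getOversampleFactor() == 3.0f
+ * }</pre>
+ *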
* @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}.
* @param dimension The dimensional value that determines the {@link RescoreContext} behavior.
- * @return A {@link RescoreContext} with the appropriate oversample factor based on the dimension, or {@code null} if the mode
- * is not valid.
+ * @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than
+ * or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode
+ * is invalid.
*/
public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) {
if (modesForRescore.contains(mode)) {
// Adjust RescoreContext based on dimension
- if (dimension >= RescoreContext.DIMENSION_THRESHOLD_1000) {
- // No oversampling for dimensions >= 1000
- return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_1000).build();
- } else if (dimension >= RescoreContext.DIMENSION_THRESHOLD_768) {
- // 2x oversampling for dimensions >= 768 but < 1000
- return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_768).build();
+ if (dimension <= RescoreContext.DIMENSION_THRESHOLD) {
+ // For dimensions <= 1000, return a RescoreContext with 5.0f oversample factor
+ return RescoreContext.builder()
+ .oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_DIMENSION_THRESHOLD)
+ .userProvided(false)
+ .build();
} else {
- // 3x oversampling for dimensions < 768
- return RescoreContext.builder().oversampleFactor(RescoreContext.OVERSAMPLE_FACTOR_BELOW_768).build();
+ return defaultRescoreContext;
}
}
return null;
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
index 0fd2fddf7..37695c208 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -376,6 +376,10 @@ private Map<Integer, Float> doANNSearch(
return null;
}
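+ // Quantized (binary) query vectors are scored in Hamming space, regardless of the index's configured space type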
+ if (quantizedVector != null) {
+ return Arrays.stream(results)
+ .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), SpaceType.HAMMING)));
+ }
return Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
}
diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java
index adb2875d5..c97a0d061 100644
--- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java
+++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java
@@ -61,9 +61,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
if (rescoreContext == null) {
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
} else {
- int firstPassK = rescoreContext.getFirstPassK(finalK);
+ boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName());
+ int dimension = knnQuery.getQueryVector().length;
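+ // getFirstPassK applies dimension-based oversampling only when shard-level rescoring is disabled and the factor was not user provided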
+ int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
- if (KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()) == false) {
+ if (isShardLevelRescoringEnabled) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}
diff --git a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java
index a2563b2a6..0f8c59499 100644
--- a/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java
+++ b/src/main/java/org/opensearch/knn/index/query/rescore/RescoreContext.java
@@ -39,21 +39,74 @@ public final class RescoreContext {
@Builder.Default
private float oversampleFactor = DEFAULT_OVERSAMPLE_FACTOR;
+ /**
+ * Flag to track whether the oversample factor was provided by the user or assigned by the system. The reason
+ * to introduce this is to apply the default when shard-level rescoring is disabled without overriding a
+ * user-provided value in NativeEngineKnnVectorQuery.
+ *
+ * This flag is crucial to differentiate between user-defined oversample factors and system-assigned
+ * default values. The behavior of oversampling logic, especially when shard-level rescoring is disabled,
+ * depends on whether the user explicitly provided an oversample factor or whether the system is using
+ * a default value.
+ *
+ * When shard-level rescoring is disabled, the system applies dimension-based oversampling logic,
+ * overriding any default values. However, if the user provides their own oversample factor, the system
+ * should respect the user’s input and avoid overriding it with the dimension-based logic.
+ *
+ * This flag is set to {@code true} when the oversample factor is provided by the user, ensuring
+ * that their value is not overridden. It is set to {@code false} when the oversample factor is
+ * determined by system defaults (e.g., through a compression level or automatic logic). The system
+ * then applies its own oversampling rules if necessary.
+ *
+ * Key scenarios:
+ * - If {@code userProvided} is {@code true} and shard-level rescoring is disabled, the user's
+ * oversample factor is used as is, without applying the dimension-based logic.
+ * - If {@code userProvided} is {@code false}, the system applies dimension-based oversampling
+ * when shard-level rescoring is disabled.
+ *
+ * This flag enables flexibility, allowing the system to handle both user-defined and default
+ * behaviors, ensuring the correct oversampling logic is applied based on the context.
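+ *
+ * <p>A minimal sketch of the two cases (the factor values are illustrative):
+ * <pre>{@code
+ * // user-provided factor: respected even when shard-level rescoring is disabled
+ * RescoreContext userCtx = RescoreContext.builder().oversampleFactor(4.0f).userProvided(true).build();
+ *
+ * // system default: the dimension-based logic in getFirstPassK may replace the factor
+ * RescoreContext defaultCtx = RescoreContext.getDefault(); // userProvided = false
+ * }</pre>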
+ */
+ @Builder.Default
+ private boolean userProvided = true;
+
/**
*
* @return default RescoreContext
*/
public static RescoreContext getDefault() {
- return RescoreContext.builder().build();
+ return RescoreContext.builder().oversampleFactor(DEFAULT_OVERSAMPLE_FACTOR).userProvided(false).build();
}
/**
- * Gets the number of results to return for the first pass of rescoring.
+ * Calculates the number of results to return for the first pass of rescoring (firstPassK).
+ * This method considers whether shard-level rescoring is enabled and adjusts the oversample factor
+ * based on the vector dimension if shard-level rescoring is disabled.
*
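+ * <p>Illustrative values (assuming the 768/1000 dimension thresholds and a finalK of 100):
+ * <pre>{@code
+ * RescoreContext ctx = RescoreContext.getDefault();   // oversample factor not user provided
+ * ctx.getFirstPassK(100, false, 1000); // 100 — no oversampling for dimensions >= 1000
+ * ctx.getFirstPassK(100, false, 800);  // 200 — 2x for dimensions in [768, 1000)
+ * ctx.getFirstPassK(100, false, 500);  // 300 — 3x for dimensions < 768
+ * }</pre>
+ *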
- * @param finalK The final number of results to return for the entire shard
- * @return The number of results to return for the first pass of rescoring
+ * @param finalK The final number of results to return for the entire shard.
+ * @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled.
+ * If true, the dimension-based oversampling logic is bypassed.
+ * @param dimension The dimension of the vector. This is used to determine the oversampling factor when
+ * shard-level rescoring is disabled.
+ * @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor.
*/
- public int getFirstPassK(int finalK) {
+ public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) {
+ // Only apply default dimension-based oversampling logic when:
+ // 1. Shard-level rescoring is disabled
+ // 2. The oversample factor was not provided by the user
+ if (!isShardLevelRescoringEnabled && !userProvided) {
+ // Apply new dimension-based oversampling logic when shard-level rescoring is disabled
+ if (dimension >= DIMENSION_THRESHOLD_1000) {
+ oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000
+ } else if (dimension >= DIMENSION_THRESHOLD_768) {
+ oversampleFactor = OVERSAMPLE_FACTOR_768; // 2x oversampling for dimensions >= 768 and < 1000
+ } else {
+ oversampleFactor = OVERSAMPLE_FACTOR_BELOW_768; // 3x oversampling for dimensions < 768
+ }
+ }
+ // The calculation for firstPassK remains the same, applying the oversample factor
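+ // The result is clamped between MIN_FIRST_PASS_RESULTS and MAX_FIRST_PASS_RESULTS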
return Math.min(MAX_FIRST_PASS_RESULTS, Math.max(MIN_FIRST_PASS_RESULTS, (int) Math.ceil(finalK * oversampleFactor)));
}
+
}
diff --git a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
index fd25699cc..c7a8e7ed8 100644
--- a/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
+++ b/src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
@@ -159,7 +159,7 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() {
}
@SneakyThrows
- public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
+ public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
@@ -167,14 +167,14 @@ public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefau
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);
- boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
+ boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
- assertTrue(shardLevelRescoringDisabled);
+ assertFalse(shardLevelRescoringDisabled);
}
@SneakyThrows
public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingApplied() {
- boolean userDefinedRescoringDisabled = false;
+ boolean userDefinedRescoringDisabled = true;
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
@@ -188,7 +188,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA
mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet();
- boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
+ boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled);
}
diff --git a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java
index 57372b11e..e882d6697 100644
--- a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java
+++ b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java
@@ -45,83 +45,57 @@ public void testGetDefaultRescoreContext() {
// Test rescore context for ON_DISK mode
Mode mode = Mode.ON_DISK;
- // Test various dimensions based on the updated oversampling logic
- int belowThresholdDimension = 500; // A dimension below 768
- int between768and1000Dimension = 800; // A dimension between 768 and 1000
- int above1000Dimension = 1500; // A dimension above 1000
+ int belowThresholdDimension = 500; // A dimension below the threshold
+ int aboveThresholdDimension = 1500; // A dimension above the threshold
- // Compression level x32 with dimension < 768 should have an oversample factor of 3.0f
+ // x32 with dimension <= 1000 should have an oversample factor of 5.0f
RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
- assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);
-
- // Compression level x32 with dimension between 768 and 1000 should have an oversample factor of 2.0f
- rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, between768and1000Dimension);
- assertNotNull(rescoreContext);
- assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);
+ assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
- // Compression level x32 with dimension > 1000 should have no oversampling (1.0f)
- rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, above1000Dimension);
- assertNotNull(rescoreContext);
- assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);
-
- // Compression level x16 with dimension < 768 should have an oversample factor of 3.0f
- rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension);
+ // x32 with dimension > 1000 should have an oversample factor of 3.0f
+ rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension);
assertNotNull(rescoreContext);
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);
- // Compression level x16 with dimension between 768 and 1000 should have an oversample factor of 2.0f
- rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, between768and1000Dimension);
+ // x16 with dimension <= 1000 should have an oversample factor of 5.0f
+ rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
- assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);
+ assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
- // Compression level x16 with dimension > 1000 should have no oversampling (1.0f)
- rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, above1000Dimension);
+ // x16 with dimension > 1000 should have an oversample factor of 3.0f
+ rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension);
assertNotNull(rescoreContext);
- assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);
+ assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);
- // Compression level x8 with dimension < 768 should have an oversample factor of 3.0f
+ // x8 with dimension <= 1000 should have an oversample factor of 5.0f
rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNotNull(rescoreContext);
- assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f);
-
- // Compression level x8 with dimension between 768 and 1000 should have an oversample factor of 2.0f
- rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, between768and1000Dimension);
+ assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f);
+ // x8 with dimension > 1000 should have an oversample factor of 2.0f
+ rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension);
assertNotNull(rescoreContext);
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);
- // Compression level x8 with dimension > 1000 should have no oversampling (1.0f)
- rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, above1000Dimension);
- assertNotNull(rescoreContext);
- assertEquals(1.0f, rescoreContext.getOversampleFactor(), 0.0f);
-
- // Compression level x4 with dimension < 768 should return null (no RescoreContext)
+ // x4 with dimension <= 1000 should return null (no RescoreContext is configured for x4)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);
-
- // Compression level x4 with dimension > 1000 should return null (no RescoreContext)
- rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, above1000Dimension);
+ // x4 with dimension > 1000 should return null (no RescoreContext is configured for x4)
+ rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension);
assertNull(rescoreContext);
-
- // Compression level x2 with dimension < 768 should return null
+ // x2 with dimension <= 1000 should return null (no RescoreContext is configured for x2)
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);
-
- // Compression level x2 with dimension > 1000 should return null
- rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, above1000Dimension);
+ // x2 with dimension > 1000 should return null
+ rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension);
assertNull(rescoreContext);
-
- // Compression level x1 with dimension < 768 should return null
rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);
-
- // Compression level x1 with dimension > 1000 should return null
- rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, above1000Dimension);
+ // x1 with dimension > 1000 should return null
+ rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension);
assertNull(rescoreContext);
-
- // NOT_CONFIGURED mode should return null for any dimension
+ // NOT_CONFIGURED should return null for any dimension
rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension);
assertNull(rescoreContext);
}
-
}
diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
index 3db03085b..b28b790d1 100644
--- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java
@@ -912,7 +912,8 @@ private void assertRescore(Version version, RescoreContext expectedRescoreContex
}
if (expectedRescoreContext != null) {
- assertEquals(expectedRescoreContext, actualRescoreContext);
+ assertNotNull(actualRescoreContext);
+ assertEquals(expectedRescoreContext.getOversampleFactor(), actualRescoreContext.getOversampleFactor(), 0.0f);
}
}
diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
index f92f32406..2a2c3ed4d 100644
--- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
@@ -79,6 +79,7 @@
import static java.util.Collections.emptyMap;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyFloat;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
@@ -516,6 +517,111 @@ public void testANNWithFilterQuery_whenDoingANNBinary_thenSuccess() {
validateANNWithFilterQuery_whenDoingANN_thenSuccess(true);
}
+ @SneakyThrows
+ public void testScorerWithQuantizedVector() {
+ // Given
+ int k = 3;
+ byte[] quantizedVector = new byte[] { 1, 2, 3 }; // Mocked quantized vector
+ float[] queryVector = new float[] { 0.1f, 0.3f };
+
+ // Mock the JNI service to return KNNQueryResults
+ KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {
+ new KNNQueryResult(1, 10.0f), // Mock result with id 1 and score 10
+ new KNNQueryResult(2, 20.0f) // Mock result with id 2 and score 20
+ };
+ jniServiceMockedStatic.when(
+ () -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any())
+ ).thenReturn(knnQueryResults);
+
+ KNNEngine knnEngine = mock(KNNEngine.class);
+ when(knnEngine.score(anyFloat(), eq(SpaceType.HAMMING))).thenAnswer(invocation -> {
+ Float score = invocation.getArgument(0);
+ return 1 / (1 + score);
+ });
+
+ // Build the KNNQuery object
+ final KNNQuery query = KNNQuery.builder()
+ .field(FIELD_NAME)
+ .queryVector(queryVector)
+ .k(k)
+ .indexName(INDEX_NAME)
+ .vectorDataType(VectorDataType.BINARY) // Simulate binary vector type for quantization
+ .build();
+
+ final float boost = 1.0F;
+ final KNNWeight knnWeight = new KNNWeight(query, boost);
+
+ final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
+ final SegmentReader reader = mock(SegmentReader.class);
+ when(leafReaderContext.reader()).thenReturn(reader);
+
+ final FieldInfos fieldInfos = mock(FieldInfos.class);
+ final FieldInfo fieldInfo = mock(FieldInfo.class);
+ when(reader.getFieldInfos()).thenReturn(fieldInfos);
+ when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo);
+
+ when(fieldInfo.attributes()).thenReturn(Map.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.HAMMING.getValue()));
+
+ FSDirectory directory = mock(FSDirectory.class);
+ when(reader.directory()).thenReturn(directory);
+ Path path = mock(Path.class);
+ when(directory.getDirectory()).thenReturn(path);
+ when(path.toString()).thenReturn("/fake/directory");
+
+ SegmentInfo segmentInfo = new SegmentInfo(
+ directory, // The directory where the segment is stored
+ Version.LATEST, // Lucene version
+ Version.LATEST, // Version of the segment info
+ "0", // Segment name
+ 100, // Max document count for this segment
+ false, // Is this a compound file segment
+ false, // Is this a merged segment
+ KNNCodecVersion.current().getDefaultCodecDelegate(), // Codec delegate for KNN
+ Map.of(), // Diagnostics map
+ new byte[StringHelper.ID_LENGTH], // Segment ID
+ Map.of(), // Attributes
+ Sort.RELEVANCE // Default sort order
+ );
+
+ final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]);
+
+ when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo);
+
+ try (MockedStatic<KNNCodecUtil> knnCodecUtilMockedStatic = mockStatic(KNNCodecUtil.class)) {
+ List<String> engineFiles = List.of("_0_1_target_field.faiss");
+ knnCodecUtilMockedStatic.when(() -> KNNCodecUtil.getEngineFiles(anyString(), anyString(), eq(segmentInfo)))
+ .thenReturn(engineFiles);
+
+ try (MockedStatic<SegmentLevelQuantizationUtil> quantizationUtilMockedStatic = mockStatic(SegmentLevelQuantizationUtil.class)) {
+ quantizationUtilMockedStatic.when(() -> SegmentLevelQuantizationUtil.quantizeVector(any(), any()))
+ .thenReturn(quantizedVector);
+
+ // When: Call the scorer method
+ final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
+
+ // Then: Ensure scorer is not null
+ assertNotNull(knnScorer);
+
+ // Verify that JNIService.queryBinaryIndex is called with the quantized vector
+ jniServiceMockedStatic.verify(
+ () -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any()),
+ times(1)
+ );
+
+ // Iterate over the results and ensure they are scored with SpaceType.HAMMING
+ final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
+ assertNotNull(docIdSetIterator);
+ while (docIdSetIterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+ int docId = docIdSetIterator.docID();
+ float expectedScore = knnEngine.score(knnQueryResults[docId - 1].getScore(), SpaceType.HAMMING);
+ float actualScore = knnScorer.score();
+ // Check if the score is calculated using HAMMING
+ assertEquals(expectedScore, actualScore, 0.01f); // Tolerance for floating-point comparison
+ }
+ }
+ }
+ }
+
public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean isBinary) throws IOException {
// Given
int k = 3;
diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java
index 7fd96c6df..53873e15f 100644
--- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java
@@ -103,6 +103,8 @@ public void setUp() throws Exception {
// Set ClusterService in KNNSettings
KNNSettings.state().setClusterService(clusterService);
+ when(knnQuery.getQueryVector()).thenReturn(new float[] { 1.0f, 2.0f, 3.0f }); // Example vector
+
}
@SneakyThrows
@@ -166,7 +168,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() {
) {
// When shard-level re-scoring is enabled
- mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);
+ mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
// Mock ResultUtil to return valid TopDocs
mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt()))
@@ -250,7 +252,7 @@ public void testRescore() {
) {
// When shard-level re-scoring is enabled
- mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(true);
+ mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf1Results), anyInt())).thenAnswer(t -> topDocs1);
diff --git a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java
index fd94667db..2b309e4ab 100644
--- a/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java
+++ b/src/test/java/org/opensearch/knn/index/query/rescore/RescoreContextTests.java
@@ -14,47 +14,72 @@ public class RescoreContextTests extends KNNTestCase {
public void testGetFirstPassK() {
float oversample = 2.6f;
- RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build();
+ RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build();
int finalK = 100;
- assertEquals(260, rescoreContext.getFirstPassK(finalK));
- finalK = 1;
- assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
- finalK = 0;
- assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
- finalK = MAX_FIRST_PASS_RESULTS;
- assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
- }
-
- public void testGetFirstPassKWithMinPassK() {
- float oversample = 2.6f;
- RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build();
+ boolean isShardLevelRescoringEnabled = true;
+ int dimension = 500;
- // Case 1: Test with a finalK that results in a value greater than MIN_FIRST_PASS_RESULTS
- int finalK = 100;
- assertEquals(260, rescoreContext.getFirstPassK(finalK));
+ // Case 1: Test with standard oversample factor when shard-level rescoring is enabled
+ assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
// Case 2: Test with a very small finalK that should result in a value less than MIN_FIRST_PASS_RESULTS
finalK = 1;
- assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
+ assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
- // Case 3: Test with finalK = 0, should return 0
+ // Case 3: Test with finalK = 0, should return MIN_FIRST_PASS_RESULTS
finalK = 0;
- assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
+ assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
// Case 4: Test with finalK = MAX_FIRST_PASS_RESULTS, should cap at MAX_FIRST_PASS_RESULTS
finalK = MAX_FIRST_PASS_RESULTS;
- assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
+ assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
+ }
- // Case 5: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS
+ public void testGetFirstPassKWithDimensionBasedOversampling() {
+ int finalK = 100;
+ int dimension;
+
+ // Case 1: Test no oversampling for dimensions >= 1000 when shard-level rescoring is disabled
+ dimension = 1000;
+ RescoreContext rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensuring dimension-based logic applies
+ assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // No oversampling
+
+ // Case 2: Test 2x oversampling for dimensions >= 768 but < 1000 when shard-level rescoring is disabled
+ dimension = 800;
+ rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over
+ assertEquals(200, rescoreContext.getFirstPassK(finalK, false, dimension)); // 2x oversampling
+
+ // Case 3: Test 3x oversampling for dimensions < 768 when shard-level rescoring is disabled
+ dimension = 700;
+ rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over
+ assertEquals(300, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling
+
+ // Case 4: Shard-level rescoring enabled, oversample factor should be used as provided by the user (ignore dimension)
+ rescoreContext = RescoreContext.builder().oversampleFactor(5.0f).userProvided(true).build(); // Provided by user
+ dimension = 500;
+ assertEquals(500, rescoreContext.getFirstPassK(finalK, true, dimension)); // User-defined oversample factor should be used
+
+ // Case 5: Test finalK where oversampling factor results in a value less than MIN_FIRST_PASS_RESULTS
finalK = 10;
- oversample = 0.5f; // This will result in 5, which is less than MIN_FIRST_PASS_RESULTS
- rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build();
- assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
+ dimension = 700;
+ rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure dimension-based logic applies
+ assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling yields 30, clamped up to MIN_FIRST_PASS_RESULTS (100)
+ }
+
+ public void testGetFirstPassKWithMinPassK() {
+ float oversample = 0.5f;
+ RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided
+ boolean isShardLevelRescoringEnabled = false;
+
+ // Case 1: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS
+ int finalK = 10;
+ int dimension = 700;
+ assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
- // Case 6: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS
+ // Case 2: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS
finalK = 100;
oversample = 1.0f; // This will result in exactly 100 (MIN_FIRST_PASS_RESULTS)
- rescoreContext = RescoreContext.builder().oversampleFactor(oversample).build();
- assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK));
+ rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided
+ assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
}
}