From 207e3415089a1d5e25cbdb4d21c64e3fdd7e991a Mon Sep 17 00:00:00 2001 From: VIKASH TIWARI Date: Tue, 1 Oct 2024 16:31:47 -0700 Subject: [PATCH] Fixing Scoring Issue with Binary Quanyized Vector --- .../opensearch/knn/index/query/KNNWeight.java | 4 + .../knn/index/query/KNNWeightTests.java | 106 ++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 1c31ed7253..c75d400a7b 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -375,6 +375,10 @@ private Map doANNSearch( return null; } + if (quantizedVector != null) { + return Arrays.stream(results) + .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), SpaceType.HAMMING))); + } return Arrays.stream(results) .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index f92f324061..2a2c3ed4db 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -79,6 +79,7 @@ import static java.util.Collections.emptyMap; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyFloat; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; @@ -516,6 +517,111 @@ public void testANNWithFilterQuery_whenDoingANNBinary_thenSuccess() { validateANNWithFilterQuery_whenDoingANN_thenSuccess(true); } + @SneakyThrows + public void testScorerWithQuantizedVector() { + // Given + int k = 3; + byte[] quantizedVector = new byte[] { 1, 2, 3 }; // Mocked quantized vector + float[] queryVector = new float[] { 0.1f, 0.3f }; + + // Mock the JNI service to return KNNQueryResults + KNNQueryResult[] knnQueryResults = new KNNQueryResult[] { + new KNNQueryResult(1, 10.0f), // Mock result with id 1 and score 10 + new KNNQueryResult(2, 20.0f) // Mock result with id 2 and score 20 + }; + jniServiceMockedStatic.when( + () -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any()) + ).thenReturn(knnQueryResults); + + KNNEngine knnEngine = mock(KNNEngine.class); + when(knnEngine.score(anyFloat(), eq(SpaceType.HAMMING))).thenAnswer(invocation -> { + Float score = invocation.getArgument(0); + return 1 / (1 + score); + }); + + // Build the KNNQuery object + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .k(k) + .indexName(INDEX_NAME) + .vectorDataType(VectorDataType.BINARY) // Simulate binary vector type for quantization + .build(); + + final float boost = 1.0F; + final KNNWeight knnWeight = new KNNWeight(query, boost); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo); + + when(fieldInfo.attributes()).thenReturn(Map.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.HAMMING.getValue())); + + FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + when(path.toString()).thenReturn("/fake/directory"); + + SegmentInfo segmentInfo = new SegmentInfo( + directory, // The directory where the segment is stored + Version.LATEST, // Lucene version + Version.LATEST, // Version of the segment info + "0", // Segment name + 100, // Max document count for this segment + false, // Is this a compound file segment + false, // Is this a merged segment + KNNCodecVersion.current().getDefaultCodecDelegate(), // Codec delegate for KNN + Map.of(), // Diagnostics map + new byte[StringHelper.ID_LENGTH], // Segment ID + Map.of(), // Attributes + Sort.RELEVANCE // Default sort order + ); + + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + try (MockedStatic knnCodecUtilMockedStatic = mockStatic(KNNCodecUtil.class)) { + List engineFiles = List.of("_0_1_target_field.faiss"); + knnCodecUtilMockedStatic.when(() -> KNNCodecUtil.getEngineFiles(anyString(), anyString(), eq(segmentInfo))) + .thenReturn(engineFiles); + + try (MockedStatic quantizationUtilMockedStatic = mockStatic(SegmentLevelQuantizationUtil.class)) { + quantizationUtilMockedStatic.when(() -> SegmentLevelQuantizationUtil.quantizeVector(any(), any())) + .thenReturn(quantizedVector); + + // When: Call the scorer method + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + + // Then: Ensure scorer is not null + assertNotNull(knnScorer); + + // Verify that JNIService.queryBinaryIndex is called with the quantized vector + jniServiceMockedStatic.verify( + () -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any()), + times(1) + ); + + // Iterate over the results and ensure they are scored with SpaceType.HAMMING + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + while (docIdSetIterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + int docId = docIdSetIterator.docID(); + float expectedScore = knnEngine.score(knnQueryResults[docId - 1].getScore(), SpaceType.HAMMING); + float actualScore = knnScorer.score(); + // Check if the score is calculated using HAMMING + assertEquals(expectedScore, actualScore, 0.01f); // Tolerance for floating-point comparison + } + } + } + } + public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean isBinary) throws IOException { // Given int k = 3;