Skip to content

Commit

Permalink
Fixing Scoring Issue with Binary Quanyized Vector
Browse files Browse the repository at this point in the history
  • Loading branch information
Vikasht34 committed Oct 1, 2024
1 parent 07f4df2 commit 207e341
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ private Map<Integer, Float> doANNSearch(
return null;
}

if (quantizedVector != null) {
return Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), SpaceType.HAMMING)));
}
return Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
}
Expand Down
106 changes: 106 additions & 0 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
import static java.util.Collections.emptyMap;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyFloat;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
Expand Down Expand Up @@ -516,6 +517,111 @@ public void testANNWithFilterQuery_whenDoingANNBinary_thenSuccess() {
validateANNWithFilterQuery_whenDoingANN_thenSuccess(true);
}

@SneakyThrows
public void testScorerWithQuantizedVector() {
// Given
int k = 3;
byte[] quantizedVector = new byte[] { 1, 2, 3 }; // Mocked quantized vector
float[] queryVector = new float[] { 0.1f, 0.3f };

// Mock the JNI service to return KNNQueryResults
KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {
new KNNQueryResult(1, 10.0f), // Mock result with id 1 and score 10
new KNNQueryResult(2, 20.0f) // Mock result with id 2 and score 20
};
jniServiceMockedStatic.when(
() -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any())
).thenReturn(knnQueryResults);

KNNEngine knnEngine = mock(KNNEngine.class);
when(knnEngine.score(anyFloat(), eq(SpaceType.HAMMING))).thenAnswer(invocation -> {
Float score = invocation.getArgument(0);
return 1 / (1 + score);
});

// Build the KNNQuery object
final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(queryVector)
.k(k)
.indexName(INDEX_NAME)
.vectorDataType(VectorDataType.BINARY) // Simulate binary vector type for quantization
.build();

final float boost = 1.0F;
final KNNWeight knnWeight = new KNNWeight(query, boost);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);

final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo);

when(fieldInfo.attributes()).thenReturn(Map.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.HAMMING.getValue()));

FSDirectory directory = mock(FSDirectory.class);
when(reader.directory()).thenReturn(directory);
Path path = mock(Path.class);
when(directory.getDirectory()).thenReturn(path);
when(path.toString()).thenReturn("/fake/directory");

SegmentInfo segmentInfo = new SegmentInfo(
directory, // The directory where the segment is stored
Version.LATEST, // Lucene version
Version.LATEST, // Version of the segment info
"0", // Segment name
100, // Max document count for this segment
false, // Is this a compound file segment
false, // Is this a merged segment
KNNCodecVersion.current().getDefaultCodecDelegate(), // Codec delegate for KNN
Map.of(), // Diagnostics map
new byte[StringHelper.ID_LENGTH], // Segment ID
Map.of(), // Attributes
Sort.RELEVANCE // Default sort order
);

final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]);

when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo);

try (MockedStatic<KNNCodecUtil> knnCodecUtilMockedStatic = mockStatic(KNNCodecUtil.class)) {
List<String> engineFiles = List.of("_0_1_target_field.faiss");
knnCodecUtilMockedStatic.when(() -> KNNCodecUtil.getEngineFiles(anyString(), anyString(), eq(segmentInfo)))
.thenReturn(engineFiles);

try (MockedStatic<SegmentLevelQuantizationUtil> quantizationUtilMockedStatic = mockStatic(SegmentLevelQuantizationUtil.class)) {
quantizationUtilMockedStatic.when(() -> SegmentLevelQuantizationUtil.quantizeVector(any(), any()))
.thenReturn(quantizedVector);

// When: Call the scorer method
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);

// Then: Ensure scorer is not null
assertNotNull(knnScorer);

// Verify that JNIService.queryBinaryIndex is called with the quantized vector
jniServiceMockedStatic.verify(
() -> JNIService.queryBinaryIndex(anyLong(), eq(quantizedVector), eq(k), any(), any(), any(), anyInt(), any()),
times(1)
);

// Iterate over the results and ensure they are scored with SpaceType.HAMMING
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
while (docIdSetIterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
int docId = docIdSetIterator.docID();
float expectedScore = knnEngine.score(knnQueryResults[docId - 1].getScore(), SpaceType.HAMMING);
float actualScore = knnScorer.score();
// Check if the score is calculated using HAMMING
assertEquals(expectedScore, actualScore, 0.01f); // Tolerance for floating-point comparison
}
}
}
}

public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean isBinary) throws IOException {
// Given
int k = 3;
Expand Down

0 comments on commit 207e341

Please sign in to comment.