Skip to content

Commit

Permalink
Add recall test with small dataset
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jul 31, 2024
1 parent 245995a commit b83c482
Show file tree
Hide file tree
Showing 6 changed files with 386 additions and 15 deletions.
46 changes: 31 additions & 15 deletions src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@

import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

Expand All @@ -41,9 +44,11 @@ public static void setUpClass() throws IOException {
}
URL testIndexVectors = BinaryIndexIT.class.getClassLoader().getResource("data/test_vectors_binary_1000x128.json");
URL testQueries = BinaryIndexIT.class.getClassLoader().getResource("data/test_queries_binary_100x128.csv");
URL groundTruthValues = BinaryIndexIT.class.getClassLoader().getResource("data/test_ground_truth_binary_100.csv");
assert testIndexVectors != null;
assert testQueries != null;
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath());
assert groundTruthValues != null;
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath(), groundTruthValues.getPath());
}

@After
Expand Down Expand Up @@ -83,18 +88,19 @@ public void testFaissHnswBinary_whenSmallDataSet_thenCreateIngestQueryWorks() {
}

@SneakyThrows
public void testFaissHnswBinary_when1000Data_thenCreateIngestQueryWorks() {
public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() {
// Create Index
createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128);
ingestTestData(INDEX_NAME, FIELD_NAME);

int k = 10;
int k = 100;
for (int i = 0; i < testData.queries.length; i++) {
// Query
List<KNNResult> knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k);

// Validate
assertEquals(k, knnResults.size());
float recall = getRecall(
Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)),
knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet())
);
assertTrue("Recall: " + recall, recall > 0.9);
}
}

Expand All @@ -109,6 +115,18 @@ public void testFaissHnswBinary_whenRadialSearch_thenThrowException() {
assertTrue(e.getMessage(), e.getMessage().contains("Binary data type does not support radial search"));
}

private float getRecall(final Set<String> truth, final Set<String> result) {
// Count the number of relevant documents retrieved
result.retainAll(truth);
int relevantRetrieved = result.size();

// Total number of relevant documents
int totalRelevant = truth.size();

// Calculate recall
return (float) relevantRetrieved / totalRelevant;
}

private List<KNNResult> runRnnQuery(
final String indexName,
final String fieldName,
Expand Down Expand Up @@ -156,8 +174,14 @@ private void ingestTestData(final String indexName, final String fieldName) thro

private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension)
throws IOException {
KNNJsonIndexMappingsBuilder.Method.Parameters parameters = KNNJsonIndexMappingsBuilder.Method.Parameters.builder()
.efSearch(100)
.efConstruction(100)
.build();

KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder()
.methodName(METHOD_HNSW)
.parameters(parameters)
.engine(knnEngine.getName())
.build();

Expand All @@ -171,12 +195,4 @@ private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String in

createKnnIndex(indexName, knnIndexMapping);
}

private byte[] toByte(final float[] vector) {
byte[] bytes = new byte[vector.length];
for (int i = 0; i < vector.length; i++) {
bytes[i] = (byte) vector[i];
}
return bytes;
}
}
139 changes: 139 additions & 0 deletions src/test/java/org/opensearch/knn/integ/IndexIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.integ;

import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.ArrayUtils;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.After;
import org.junit.BeforeClass;
import org.opensearch.client.Response;
import org.opensearch.knn.KNNJsonIndexMappingsBuilder;
import org.opensearch.knn.KNNJsonQueryBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

/**
* This class contains integration tests for index
*/
@Log4j2
public class IndexIT extends KNNRestTestCase {
private static TestUtils.TestData testData;

@BeforeClass
public static void setUpClass() throws IOException {
if (IndexIT.class.getClassLoader() == null) {
throw new IllegalStateException("ClassLoader of IndexIT Class is null");
}
URL testIndexVectors = IndexIT.class.getClassLoader().getResource("data/test_vectors_1000x128.json");
URL testQueries = IndexIT.class.getClassLoader().getResource("data/test_queries_100x128.csv");
URL groundTruthValues = IndexIT.class.getClassLoader().getResource("data/test_ground_truth_l2_100.csv");
assert testIndexVectors != null;
assert testQueries != null;
assert groundTruthValues != null;
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath(), groundTruthValues.getPath());
}

@After
public void cleanUp() {
try {
deleteKNNIndex(INDEX_NAME);
} catch (Exception e) {
log.error(e);
}
}

@SneakyThrows
public void testFaissHnsw_when1000Data_thenRecallIsAboveNinePointZero() {
// Create Index
createKnnHnswIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128);
ingestTestData(INDEX_NAME, FIELD_NAME);

int k = 100;
for (int i = 0; i < testData.queries.length; i++) {
List<KNNResult> knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k);
float recall = getRecall(
Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)),
knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet())
);
assertTrue("Recall: " + recall, recall > 0.9);
}
}

private float getRecall(final Set<String> truth, final Set<String> result) {
// Count the number of relevant documents retrieved
result.retainAll(truth);
int relevantRetrieved = result.size();

// Total number of relevant documents
int totalRelevant = truth.size();

// Calculate recall
return (float) relevantRetrieved / totalRelevant;
}

private List<KNNResult> runKnnQuery(final String indexName, final String fieldName, final float[] queryVector, final int k)
throws Exception {
String query = KNNJsonQueryBuilder.builder()
.fieldName(fieldName)
.vector(ArrayUtils.toObject(queryVector))
.k(k)
.build()
.getQueryString();
Response response = searchKNNIndex(indexName, query, k);
return parseSearchResponse(EntityUtils.toString(response.getEntity()), fieldName);
}

private void ingestTestData(final String indexName, final String fieldName) throws Exception {
// Index the test data
for (int i = 0; i < testData.indexData.docs.length; i++) {
addKnnDoc(
indexName,
Integer.toString(testData.indexData.docs[i]),
fieldName,
Floats.asList(testData.indexData.vectors[i]).toArray()
);
}

// Assert we have the right number of documents in the index
refreshAllIndices();
assertEquals(testData.indexData.docs.length, getDocCount(indexName));
}

private void createKnnHnswIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension)
throws IOException {
KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder()
.methodName(METHOD_HNSW)
.spaceType(SpaceType.L2.getValue())
.engine(knnEngine.getName())
.build();

String knnIndexMapping = KNNJsonIndexMappingsBuilder.builder()
.fieldName(fieldName)
.dimension(dimension)
.vectorDataType(VectorDataType.FLOAT.getValue())
.method(method)
.build()
.getIndexMapping();

createKnnIndex(indexName, knnIndexMapping);
}
}
8 changes: 8 additions & 0 deletions src/test/resources/data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ test_queries_100x128.csv and packing 8 bits to 1 byte with ends up with 16 lengt
For quantization technique, we calculated the median(49935.95941056451) of all values in test_vectors_1000x128.json
and converted it as 0 if it is less than the median and 1 if it is equal to or larger than the median.

# test_ground_truth_binary_100.csv
The file contains the ground truth for the query test_queries_binary_100x128.csv against the data
test_vectors_binary_1000x128.json using hamming distance.

# test_ground_truth_l2_100.csv
The file contains the ground truth for the query test_queries_100x128.csv against the data test_vectors_1000x128.json
using l2 distance

# test_vectors_nested_1000x128.json
The file contains a simulated data to represent nested field.
Consecutive ids are assigned for data from same parent document.
Expand Down
Loading

0 comments on commit b83c482

Please sign in to comment.