diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index d5b79768de..93c0c5b16d 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.integ; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; @@ -13,11 +14,13 @@ import org.junit.After; import org.junit.BeforeClass; import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; import org.opensearch.knn.KNNJsonIndexMappingsBuilder; import org.opensearch.knn.KNNJsonQueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; @@ -36,6 +39,8 @@ @Log4j2 public class BinaryIndexIT extends KNNRestTestCase { private static TestUtils.TestData testData; + private static final int NEVER_BUILD_GRAPH = -1; + private static final int ALWAYS_BUILD_GRAPH = 0; @BeforeClass public static void setUpClass() throws IOException { @@ -104,6 +109,52 @@ public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { } } + @SneakyThrows + public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_thenBuildGraphBasedOnSetting() { + // Create Index + createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH); + ingestTestData(INDEX_NAME, FIELD_NAME); + + assertEquals(0, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); + + // update build vector data structure setting + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + forceMergeKnnIndex(INDEX_NAME, 1); + + int k = 100; + for (int i = 0; i < testData.queries.length; i++) { + List knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k); + float recall = getRecall( + Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)), + knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet()) + ); + assertTrue("Recall: " + recall, recall > 0.1); + } + } + + @SneakyThrows + public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception { + // Create Index + createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); + ingestTestData(INDEX_NAME, FIELD_NAME, false); + + assertEquals(0, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); + + // update build vector data structure setting + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + forceMergeKnnIndex(INDEX_NAME, 1); + + int k = 100; + for (int i = 0; i < testData.queries.length; i++) { + List knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k); + float recall = getRecall( + Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)), + knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet()) + ); + assertTrue("Recall: " + recall, recall > 0.1); + } + } + @SneakyThrows public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { // Create Index @@ -157,13 +208,18 @@ private List runKnnQuery(final String indexName, final String fieldNa } private void ingestTestData(final String indexName, final String fieldName) throws Exception { + ingestTestData(indexName, fieldName, true); + } + + private void ingestTestData(final String indexName, final String fieldName, boolean refresh) throws Exception { // Index the test data for (int i = 0; i < testData.indexData.docs.length; i++) { addKnnDoc( indexName, Integer.toString(testData.indexData.docs[i]), - fieldName, - Floats.asList(testData.indexData.vectors[i]).toArray() + ImmutableList.of(fieldName), + ImmutableList.of(Floats.asList(testData.indexData.vectors[i]).toArray()), + refresh ); } @@ -172,8 +228,13 @@ private void ingestTestData(final String indexName, final String fieldName) thro assertEquals(testData.indexData.docs.length, getDocCount(indexName)); } - private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension) - throws IOException { + private void createKnnHnswBinaryIndex( + final KNNEngine knnEngine, + final String indexName, + final String fieldName, + final int dimension, + final int threshold + ) throws IOException { KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder() .methodName(METHOD_HNSW) .engine(knnEngine.getName()) @@ -186,7 +247,11 @@ private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String in .method(method) .build() .getIndexMapping(); + createKnnIndex(indexName, buildKNNIndexSettings(threshold), knnIndexMapping); + } - createKnnIndex(indexName, knnIndexMapping); + private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension) + throws IOException { + createKnnHnswBinaryIndex(knnEngine, indexName, fieldName, dimension, ALWAYS_BUILD_GRAPH); } }