From 102f239c2af2d7a1b55048c93233341a8b25ff5d Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Tue, 17 Sep 2024 11:28:16 -0700 Subject: [PATCH] Introduce new setting to configure when to build graph during segment creation (#2007) Added new updatable index setting "build_vector_data_structure_threshold", which will be considered when to build braph or not for native engines. This is noop for lucene. This depends on use lucene format as prerequisite. We don't need to add flag since it is only enable if lucene format is already enabled. Signed-off-by: Vijayan Balasubramanian --- CHANGELOG.md | 1 + .../org/opensearch/knn/index/KNNSettings.java | 20 ++ .../codec/BasePerFieldKnnVectorsFormat.java | 18 +- .../NativeEngines990KnnVectorsFormat.java | 24 +- .../NativeEngines990KnnVectorsWriter.java | 35 +- .../opensearch/knn/index/OpenSearchIT.java | 210 +++++++++++ ...eEngines990KnnVectorsWriterFlushTests.java | 330 +++++++++++++++++- ...eEngines990KnnVectorsWriterMergeTests.java | 123 ++++++- .../org/opensearch/knn/KNNRestTestCase.java | 18 +- 9 files changed, 768 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f615a78fb9..0fbaeae1b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features +* Introduce new setting to configure whether to build vector data structure or not during segment creation [#2007](https://github.com/opensearch-project/k-NN/pull/2007) ### Enhancements * Introducing a loading layer in FAISS [#2033](https://github.com/opensearch-project/k-NN/issues/2033) ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index f5980879ab..ecacc500b7 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -67,6 +67,7 @@ public class KNNSettings { * Settings name */ public static final String KNN_SPACE_TYPE = "index.knn.space_type"; + public static final String INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = "index.knn.build_vector_data_structure_threshold"; public static final String KNN_ALGO_PARAM_M = "index.knn.algo_param.m"; public static final String KNN_ALGO_PARAM_EF_CONSTRUCTION = "index.knn.algo_param.ef_construction"; public static final String KNN_ALGO_PARAM_EF_SEARCH = "index.knn.algo_param.ef_search"; @@ -97,6 +98,9 @@ public class KNNSettings { public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false; public static final boolean KNN_DEFAULT_FAISS_AVX512_DISABLED_VALUE = false; public static final String INDEX_KNN_DEFAULT_SPACE_TYPE = "l2"; + public static final Integer INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = 0; + public static final Integer INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN = -1; + public static final Integer INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX = Integer.MAX_VALUE - 2; public static final String INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY = "hamming"; public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_M = 16; public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH = 100; @@ -156,6 +160,21 @@ public class KNNSettings { Setting.Property.Deprecated ); + /** + * build_vector_data_structure_threshold - This parameter determines when to build vector data structure for knn fields during indexing + * and merging. Setting -1 (min) will skip building graph, whereas on any other values, the graph will be built if + * number of live docs in segment is greater than this threshold. Since max number of documents in a segment can + * be Integer.MAX_VALUE - 1, this setting will allow threshold to be up to 1 less than max number of documents in a segment + */ + public static final Setting INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_SETTING = Setting.intSetting( + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, + INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX, + IndexScope, + Dynamic + ); + /** * M - the number of bi-directional links created for every new element during construction. * Reasonable range for M is 2-100. Higher M work better on datasets with high intrinsic @@ -486,6 +505,7 @@ private Setting getSetting(String key) { public List> getSettings() { List> settings = Arrays.asList( INDEX_KNN_SPACE_TYPE, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_SETTING, INDEX_KNN_ALGO_PARAM_M_SETTING, INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING, INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING, diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index b06c2a1d8a..516f832232 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -11,7 +11,9 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; @@ -129,7 +131,21 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { } private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { - return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())); + int buildVectorDatastructureThreshold = getBuildVectorDatastructureThresholdSetting(mapperService.get()); + return new NativeEngines990KnnVectorsFormat( + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()), + buildVectorDatastructureThreshold + ); + } + + private int getBuildVectorDatastructureThresholdSetting(final MapperService knnMapperService) { + final IndexSettings indexSettings = knnMapperService.getIndexSettings(); + final Integer buildVectorDatastructureThreshold = indexSettings.getValue( + KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_SETTING + ); + return buildVectorDatastructureThreshold != null + ? buildVectorDatastructureThreshold + : KNNSettings.INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD; } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index 626210f250..5b2c9cdddb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -19,6 +19,7 @@ import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.knn.index.KNNSettings; import java.io.IOException; @@ -30,15 +31,20 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { /** The format for storing, reading, merging vectors on disk */ private static FlatVectorsFormat flatVectorsFormat; private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; + private static int buildVectorDatastructureThreshold; public NativeEngines990KnnVectorsFormat() { - super(FORMAT_NAME); - flatVectorsFormat = new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()); + this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer())); + } + + public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat) { + this(flatVectorsFormat, KNNSettings.INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); } - public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat lucene99FlatVectorsFormat) { + public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int buildVectorDatastructureThreshold) { super(FORMAT_NAME); - flatVectorsFormat = lucene99FlatVectorsFormat; + NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat; + NativeEngines990KnnVectorsFormat.buildVectorDatastructureThreshold = buildVectorDatastructureThreshold; } /** @@ -48,7 +54,7 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat lucene99FlatVect */ @Override public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException { - return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state)); + return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), buildVectorDatastructureThreshold); } /** @@ -63,6 +69,12 @@ public KnnVectorsReader fieldsReader(final SegmentReadState state) throws IOExce @Override public String toString() { - return "NativeEngines99KnnVectorsFormat(name=" + this.getClass().getSimpleName() + ", flatVectorsFormat=" + flatVectorsFormat + ")"; + return "NativeEngines99KnnVectorsFormat(name=" + + this.getClass().getSimpleName() + + ", flatVectorsFormat=" + + flatVectorsFormat + + ", buildVectorDatastructureThreshold" + + buildVectorDatastructureThreshold + + ")"; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 2f22565c98..1251b9763f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -53,10 +53,16 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; + private final Integer buildVectorDataStructureThreshold; - public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { + public NativeEngines990KnnVectorsWriter( + SegmentWriteState segmentWriteState, + FlatVectorsWriter flatVectorsWriter, + Integer buildVectorDataStructureThreshold + ) { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; + this.buildVectorDataStructureThreshold = buildVectorDataStructureThreshold; } /** @@ -94,6 +100,16 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); + // Will consider building vector data structure based on threshold only for non quantization indices + if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + log.info( + "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush", + fieldInfo.name, + totalLiveDocs, + buildVectorDataStructureThreshold + ); + continue; + } final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); @@ -123,6 +139,16 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState } final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); + // Will consider building vector data structure based on threshold only for non quantization indices + if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + log.info( + "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge", + fieldInfo.name, + totalLiveDocs, + buildVectorDataStructureThreshold + ); + return; + } final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); @@ -253,4 +279,11 @@ private void initQuantizationStateWriterIfNecessary() throws IOException { quantizationStateWriter.writeHeader(segmentWriteState); } } + + private boolean shouldSkipBuildingVectorDataStructure(final long docCount) { + if (buildVectorDataStructureThreshold < 0) { + return true; + } + return docCount < buildVectorDataStructureThreshold; + } } diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 0bb7947baf..e9da88c0a6 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -15,6 +15,7 @@ import com.google.common.primitives.Floats; import java.util.Locale; import lombok.SneakyThrows; +import org.apache.hc.core5.http.ParseException; import org.junit.BeforeClass; import org.junit.Ignore; import org.opensearch.knn.KNNRestTestCase; @@ -42,6 +43,8 @@ import java.util.TreeMap; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN; public class OpenSearchIT extends KNNRestTestCase { @@ -611,4 +614,211 @@ public void testCacheClear_whenCloseIndex() throws Exception { fail("Graphs are not getting evicted"); } + + public void testKNNIndex_whenBuildGraphThresholdIsPresent_thenGetThresholdValue() throws Exception { + final Integer buildVectorDataStructureThreshold = randomIntBetween( + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX + ); + final Settings settings = Settings.builder().put(buildKNNIndexSettings(buildVectorDataStructureThreshold)).build(); + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-graph-settings"; + createKnnIndex(indexName, settings, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD + ); + assertNotNull("build_vector_data_structure_threshold index setting is not found", buildVectorDataStructureThresholdSetting); + assertEquals( + "incorrect setting for build_vector_data_structure_threshold", + buildVectorDataStructureThreshold, + Integer.valueOf(buildVectorDataStructureThresholdSetting) + ); + deleteKNNIndex(indexName); + } + + public void testKNNIndex_whenBuildThresholdIsNotProvided_thenShouldNotReturnSetting() throws Exception { + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-graph-settings"; + createKnnIndex(indexName, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD + ); + assertNull( + "build_vector_data_structure_threshold index setting should not be added in index setting", + buildVectorDataStructureThresholdSetting + ); + deleteKNNIndex(indexName); + } + + public void testKNNIndex_whenGetIndexSettingWithDefaultIsCalled_thenReturnDefaultBuildGraphThresholdValue() throws Exception { + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-vector-graph-settings"; + createKnnIndex(indexName, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, + true + ); + assertNotNull("build_vector_data_structure index setting is not found", buildVectorDataStructureThresholdSetting); + assertEquals( + "incorrect default setting for build_vector_data_structure_threshold", + KNNSettings.INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, + Integer.valueOf(buildVectorDataStructureThresholdSetting) + ); + deleteKNNIndex(indexName); + } + + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold as -1, then index few documents, perform knn search, + then, confirm no hits since there are no graph. In next step, update setting to 0, force merge segment to 1, perform knn search and confirm expected + hits are returned. + */ + public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception { + final String indexName = "test-index-1"; + final String fieldName1 = "test-field-1"; + final String fieldName2 = "test-field-2"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final List nmslibNeighbors = getResults(indexName, fieldName1, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", 0, nmslibNeighbors.size()); + + final List faissNeighbors = getResults(indexName, fieldName2, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", 0, faissNeighbors.size()); + + // update build vector data structure setting + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + forceMergeKnnIndex(indexName, 1); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search nmslib field + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, nmslibValidNeighbors.size()); + // Search faiss field + final List faissValidNeighbors = getResults(indexName, fieldName2, testData.queries[i], k); + assertEquals(k, faissValidNeighbors.size()); + } + + // Delete index + deleteKNNIndex(indexName); + } + + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold number of documents to ingest, then index x documents, perform knn search, + then, confirm expected hits are returned. Here, we don't need force merge to build graph, since, threshold is less than + actual number of documents in segments + */ + public void testKNNIndex_whenBuildVectorDataStructureIsLessThanDocCount_thenBuildGraphBasedSuccessfully() throws Exception { + final String indexName = "test-index-1"; + final String fieldName = "test-field-1"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(testData.indexData.docs.length); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + // Disable refresh + updateIndexSettings(indexName, Settings.builder().put("index.refresh_interval", -1)); + + // Index the test data without refresh on every document + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName), + ImmutableList.of(Floats.asList(testData.indexData.vectors[i]).toArray()), + false + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName); + assertEquals(k, nmslibValidNeighbors.size()); + } + // Delete index + deleteKNNIndex(indexName); + } + + private List getResults(final String indexName, final String fieldName, final float[] vector, final int k) + throws IOException, ParseException { + final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k); + final String searchResponseBody = EntityUtils.toString(searchResponseField.getEntity()); + return parseSearchResponse(searchResponseBody, fieldName); + } + } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 9f74b2c104..659c980ff6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -29,10 +29,12 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Predicate; @@ -50,6 +52,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @RequiredArgsConstructor @@ -72,12 +75,14 @@ public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCa private final String description; private final List> vectorsPerField; + private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0; + private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1; @Override public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter, BUILD_GRAPH_ALWAYS_THRESHOLD); } @ParametersFactory @@ -290,6 +295,329 @@ public void testFlush_WithQuantization() { } } + public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + maxThreshold + 1 + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + + final int minThreshold = sizeMap.values().stream().filter(count -> count != 0).min(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + minThreshold + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).size() > 0) { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWriterIsCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + final int threshold = 4; + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + threshold + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).size() >= threshold) { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { FieldInfo fieldInfo = mock(FieldInfo.class); when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index 41940c4d47..58de424a38 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -32,6 +32,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -74,12 +75,14 @@ public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCa private final String description; private final Map mergedVectors; + private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0; + private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1; @Override public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter, BUILD_GRAPH_ALWAYS_THRESHOLD); } @ParametersFactory @@ -155,6 +158,124 @@ public void testMerge() { } } + public void testMerge_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + nativeEngineWriter.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + mergedVectors.size() + ); + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + nativeEngineWriter.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + if (!mergedVectors.isEmpty()) { + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + } else { + verifyNoInteractions(nativeIndexWriter); + } + } + } + @SneakyThrows public void testMerge_WithQuantization() { // Given diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 8c033ca030..557817f583 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -100,6 +100,8 @@ import static org.opensearch.knn.TestUtils.computeGroundTruthValues; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; import static org.opensearch.knn.index.SpaceType.L2; import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; import static org.opensearch.knn.index.engine.KNNEngine.FAISS; @@ -638,7 +640,12 @@ protected void addKnnDocWithNestedField(String index, String docId, String neste * Add a single KNN Doc to an index with multiple fields */ protected void addKnnDoc(String index, String docId, List fieldNames, List vectors) throws IOException { - Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + addKnnDoc(index, docId, fieldNames, vectors, true); + } + + protected void addKnnDoc(String index, String docId, List fieldNames, List vectors, boolean refresh) + throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=" + refresh); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); for (int i = 0; i < fieldNames.size(); i++) { @@ -768,6 +775,15 @@ protected Settings getKNNSegmentReplicatedIndexSettings() { .build(); } + protected Settings buildKNNIndexSettings(int buildVectorDatastructureThreshold) { + return Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put(KNN_INDEX, true) + .put(INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, buildVectorDatastructureThreshold) + .build(); + } + @SneakyThrows protected int getDataNodeCount() { Request request = new Request("GET", "_nodes/stats?filter_path=nodes.*.roles");