diff --git a/CHANGELOG.md b/CHANGELOG.md index f615a78fb9..0fbaeae1b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features +* Introduce new setting to configure whether to build vector data structure or not during segment creation [#2007](https://github.com/opensearch-project/k-NN/pull/2007) ### Enhancements * Introducing a loading layer in FAISS [#2033](https://github.com/opensearch-project/k-NN/issues/2033) ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index f5980879ab..ecacc500b7 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -67,6 +67,7 @@ public class KNNSettings { * Settings name */ public static final String KNN_SPACE_TYPE = "index.knn.space_type"; + public static final String INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = "index.knn.build_vector_data_structure_threshold"; public static final String KNN_ALGO_PARAM_M = "index.knn.algo_param.m"; public static final String KNN_ALGO_PARAM_EF_CONSTRUCTION = "index.knn.algo_param.ef_construction"; public static final String KNN_ALGO_PARAM_EF_SEARCH = "index.knn.algo_param.ef_search"; @@ -97,6 +98,9 @@ public class KNNSettings { public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false; public static final boolean KNN_DEFAULT_FAISS_AVX512_DISABLED_VALUE = false; public static final String INDEX_KNN_DEFAULT_SPACE_TYPE = "l2"; + public static final Integer INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = 0; + public static final Integer INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN = -1; + public static final Integer INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX = Integer.MAX_VALUE - 2; public static final String INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY = "hamming"; public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_M = 16; public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH = 100; @@ -156,6 +160,21 @@ public class KNNSettings { Setting.Property.Deprecated ); + /** + * build_vector_data_structure_threshold - This parameter determines when to build vector data structure for knn fields during indexing + * and merging. Setting -1 (min) will skip building graph, whereas on any other values, the graph will be built if + * number of live docs in segment is greater than this threshold. Since max number of documents in a segment can + * be Integer.MAX_VALUE - 1, this setting will allow threshold to be up to 1 less than max number of documents in a segment + */ + public static final Setting INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_SETTING = Setting.intSetting( + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, + INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX, + IndexScope, + Dynamic + ); + /** * M - the number of bi-directional links created for every new element during construction. * Reasonable range for M is 2-100. Higher M work better on datasets with high intrinsic @@ -486,6 +505,7 @@ private Setting getSetting(String key) { public List> getSettings() { List> settings = Arrays.asList( INDEX_KNN_SPACE_TYPE, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_SETTING, INDEX_KNN_ALGO_PARAM_M_SETTING, INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING, INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING, diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index b06c2a1d8a..516f832232 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -11,7 +11,9 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; @@ -129,7 +131,21 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { } private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { - return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())); + int buildVectorDatastructureThreshold = getBuildVectorDatastructureThresholdSetting(mapperService.get()); + return new NativeEngines990KnnVectorsFormat( + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()), + buildVectorDatastructureThreshold + ); + } + + private int getBuildVectorDatastructureThresholdSetting(final MapperService knnMapperService) { + final IndexSettings indexSettings = knnMapperService.getIndexSettings(); + final Integer buildVectorDatastructureThreshold = indexSettings.getValue( + KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_SETTING + ); + return buildVectorDatastructureThreshold != null + ? buildVectorDatastructureThreshold + : KNNSettings.INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD; } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index 626210f250..5b2c9cdddb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -19,6 +19,7 @@ import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.knn.index.KNNSettings; import java.io.IOException; @@ -30,15 +31,20 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { /** The format for storing, reading, merging vectors on disk */ private static FlatVectorsFormat flatVectorsFormat; private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; + private static int buildVectorDatastructureThreshold; public NativeEngines990KnnVectorsFormat() { - super(FORMAT_NAME); - flatVectorsFormat = new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()); + this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer())); + } + + public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat) { + this(flatVectorsFormat, KNNSettings.INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); } - public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat lucene99FlatVectorsFormat) { + public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int buildVectorDatastructureThreshold) { super(FORMAT_NAME); - flatVectorsFormat = lucene99FlatVectorsFormat; + NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat; + NativeEngines990KnnVectorsFormat.buildVectorDatastructureThreshold = buildVectorDatastructureThreshold; } /** @@ -48,7 +54,7 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat lucene99FlatVect */ @Override public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException { - return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state)); + return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), buildVectorDatastructureThreshold); } /** @@ -63,6 +69,12 @@ public KnnVectorsReader fieldsReader(final SegmentReadState state) throws IOExce @Override public String toString() { - return "NativeEngines99KnnVectorsFormat(name=" + this.getClass().getSimpleName() + ", flatVectorsFormat=" + flatVectorsFormat + ")"; + return "NativeEngines99KnnVectorsFormat(name=" + + this.getClass().getSimpleName() + + ", flatVectorsFormat=" + + flatVectorsFormat + + ", buildVectorDatastructureThreshold" + + buildVectorDatastructureThreshold + + ")"; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 2f22565c98..1251b9763f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -53,10 +53,16 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; + private final Integer buildVectorDataStructureThreshold; - public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { + public NativeEngines990KnnVectorsWriter( + SegmentWriteState segmentWriteState, + FlatVectorsWriter flatVectorsWriter, + Integer buildVectorDataStructureThreshold + ) { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; + this.buildVectorDataStructureThreshold = buildVectorDataStructureThreshold; } /** @@ -94,6 +100,16 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); + // Will consider building vector data structure based on threshold only for non quantization indices + if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + log.info( + "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush", + fieldInfo.name, + totalLiveDocs, + buildVectorDataStructureThreshold + ); + continue; + } final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); @@ -123,6 +139,16 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState } final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); + // Will consider building vector data structure based on threshold only for non quantization indices + if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + log.info( + "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge", + fieldInfo.name, + totalLiveDocs, + buildVectorDataStructureThreshold + ); + return; + } final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); @@ -253,4 +279,11 @@ private void initQuantizationStateWriterIfNecessary() throws IOException { quantizationStateWriter.writeHeader(segmentWriteState); } } + + private boolean shouldSkipBuildingVectorDataStructure(final long docCount) { + if (buildVectorDataStructureThreshold < 0) { + return true; + } + return docCount < buildVectorDataStructureThreshold; + } } diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 0bb7947baf..e9da88c0a6 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -15,6 +15,7 @@ import com.google.common.primitives.Floats; import java.util.Locale; import lombok.SneakyThrows; +import org.apache.hc.core5.http.ParseException; import org.junit.BeforeClass; import org.junit.Ignore; import org.opensearch.knn.KNNRestTestCase; @@ -42,6 +43,8 @@ import java.util.TreeMap; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN; public class OpenSearchIT extends KNNRestTestCase { @@ -611,4 +614,211 @@ public void testCacheClear_whenCloseIndex() throws Exception { fail("Graphs are not getting evicted"); } + + public void testKNNIndex_whenBuildGraphThresholdIsPresent_thenGetThresholdValue() throws Exception { + final Integer buildVectorDataStructureThreshold = randomIntBetween( + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX + ); + final Settings settings = Settings.builder().put(buildKNNIndexSettings(buildVectorDataStructureThreshold)).build(); + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-graph-settings"; + createKnnIndex(indexName, settings, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD + ); + assertNotNull("build_vector_data_structure_threshold index setting is not found", buildVectorDataStructureThresholdSetting); + assertEquals( + "incorrect setting for build_vector_data_structure_threshold", + buildVectorDataStructureThreshold, + Integer.valueOf(buildVectorDataStructureThresholdSetting) + ); + deleteKNNIndex(indexName); + } + + public void testKNNIndex_whenBuildThresholdIsNotProvided_thenShouldNotReturnSetting() throws Exception { + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-graph-settings"; + createKnnIndex(indexName, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD + ); + assertNull( + "build_vector_data_structure_threshold index setting should not be added in index setting", + buildVectorDataStructureThresholdSetting + ); + deleteKNNIndex(indexName); + } + + public void testKNNIndex_whenGetIndexSettingWithDefaultIsCalled_thenReturnDefaultBuildGraphThresholdValue() throws Exception { + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-vector-graph-settings"; + createKnnIndex(indexName, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, + true + ); + assertNotNull("build_vector_data_structure index setting is not found", buildVectorDataStructureThresholdSetting); + assertEquals( + "incorrect default setting for build_vector_data_structure_threshold", + KNNSettings.INDEX_KNN_DEFAULT_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, + Integer.valueOf(buildVectorDataStructureThresholdSetting) + ); + deleteKNNIndex(indexName); + } + + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold as -1, then index few documents, perform knn search, + then, confirm no hits since there are no graph. In next step, update setting to 0, force merge segment to 1, perform knn search and confirm expected + hits are returned. + */ + public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception { + final String indexName = "test-index-1"; + final String fieldName1 = "test-field-1"; + final String fieldName2 = "test-field-2"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final List nmslibNeighbors = getResults(indexName, fieldName1, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", 0, nmslibNeighbors.size()); + + final List faissNeighbors = getResults(indexName, fieldName2, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", 0, faissNeighbors.size()); + + // update build vector data structure setting + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + forceMergeKnnIndex(indexName, 1); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search nmslib field + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, nmslibValidNeighbors.size()); + // Search faiss field + final List faissValidNeighbors = getResults(indexName, fieldName2, testData.queries[i], k); + assertEquals(k, faissValidNeighbors.size()); + } + + // Delete index + deleteKNNIndex(indexName); + } + + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold number of documents to ingest, then index x documents, perform knn search, + then, confirm expected hits are returned. Here, we don't need force merge to build graph, since, threshold is less than + actual number of documents in segments + */ + public void testKNNIndex_whenBuildVectorDataStructureIsLessThanDocCount_thenBuildGraphBasedSuccessfully() throws Exception { + final String indexName = "test-index-1"; + final String fieldName = "test-field-1"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(testData.indexData.docs.length); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + // Disable refresh + updateIndexSettings(indexName, Settings.builder().put("index.refresh_interval", -1)); + + // Index the test data without refresh on every document + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName), + ImmutableList.of(Floats.asList(testData.indexData.vectors[i]).toArray()), + false + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName); + assertEquals(k, nmslibValidNeighbors.size()); + } + // Delete index + deleteKNNIndex(indexName); + } + + private List getResults(final String indexName, final String fieldName, final float[] vector, final int k) + throws IOException, ParseException { + final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k); + final String searchResponseBody = EntityUtils.toString(searchResponseField.getEntity()); + return parseSearchResponse(searchResponseBody, fieldName); + } + } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 9f74b2c104..659c980ff6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -29,10 +29,12 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Predicate; @@ -50,6 +52,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @RequiredArgsConstructor @@ -72,12 +75,14 @@ public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCa private final String description; private final List> vectorsPerField; + private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0; + private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1; @Override public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter, BUILD_GRAPH_ALWAYS_THRESHOLD); } @ParametersFactory @@ -290,6 +295,329 @@ public void testFlush_WithQuantization() { } } + public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + maxThreshold + 1 + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + + final int minThreshold = sizeMap.values().stream().filter(count -> count != 0).min(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + minThreshold + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).size() > 0) { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWriterIsCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + final int threshold = 4; + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + threshold + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).size() >= threshold) { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { FieldInfo fieldInfo = mock(FieldInfo.class); when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index 41940c4d47..58de424a38 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -32,6 +32,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -74,12 +75,14 @@ public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCa private final String description; private final Map mergedVectors; + private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0; + private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1; @Override public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter, BUILD_GRAPH_ALWAYS_THRESHOLD); } @ParametersFactory @@ -155,6 +158,124 @@ public void testMerge() { } } + public void testMerge_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + nativeEngineWriter.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + mergedVectors.size() + ); + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + nativeEngineWriter.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + if (!mergedVectors.isEmpty()) { + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + } else { + verifyNoInteractions(nativeIndexWriter); + } + } + } + @SneakyThrows public void testMerge_WithQuantization() { // Given diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 8c033ca030..557817f583 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -100,6 +100,8 @@ import static org.opensearch.knn.TestUtils.computeGroundTruthValues; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; import static org.opensearch.knn.index.SpaceType.L2; import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; import static org.opensearch.knn.index.engine.KNNEngine.FAISS; @@ -638,7 +640,12 @@ protected void addKnnDocWithNestedField(String index, String docId, String neste * Add a single KNN Doc to an index with multiple fields */ protected void addKnnDoc(String index, String docId, List fieldNames, List vectors) throws IOException { - Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + addKnnDoc(index, docId, fieldNames, vectors, true); + } + + protected void addKnnDoc(String index, String docId, List fieldNames, List vectors, boolean refresh) + throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=" + refresh); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); for (int i = 0; i < fieldNames.size(); i++) { @@ -768,6 +775,15 @@ protected Settings getKNNSegmentReplicatedIndexSettings() { .build(); } + protected Settings buildKNNIndexSettings(int buildVectorDatastructureThreshold) { + return Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put(KNN_INDEX, true) + .put(INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, buildVectorDatastructureThreshold) + .build(); + } + @SneakyThrows protected int getDataNodeCount() { Request request = new Request("GET", "_nodes/stats?filter_path=nodes.*.roles");