diff --git a/CHANGELOG.md b/CHANGELOG.md index c6ebc3c3f..1a8d808f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features ### Enhancements -### Bug Fixes +### Bug Fixes +* Add DocValuesProducers for releasing memory when close index [#1946](https://github.com/opensearch-project/k-NN/pull/1946) ### Infrastructure * Removed JDK 11 and 17 version from CI runs [#1921](https://github.com/opensearch-project/k-NN/pull/1921) ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundDirectory.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundDirectory.java new file mode 100644 index 000000000..0821b19ef --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundDirectory.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import lombok.Getter; +import org.apache.lucene.codecs.CompoundDirectory; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.io.IOException; +import java.util.Set; + +public class KNN80CompoundDirectory extends CompoundDirectory { + + @Getter + private CompoundDirectory delegate; + @Getter + private Directory dir; + + public KNN80CompoundDirectory(CompoundDirectory delegate, Directory dir) { + this.delegate = delegate; + this.dir = dir; + } + + @Override + public void checkIntegrity() throws IOException { + delegate.checkIntegrity(); + } + + @Override + public String[] listAll() throws IOException { + return delegate.listAll(); + } + + @Override + public long fileLength(String name) throws IOException { + return delegate.fileLength(name); + } + + @Override + public IndexInput openInput(String name, IOContext context) throws IOException { + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().stream().anyMatch(engine -> name.endsWith(engine.getCompoundExtension()))) { + return dir.openInput(name, context); + } + return delegate.openInput(name, context); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public Set getPendingDeletions() throws IOException { + return delegate.getPendingDeletions(); + } + +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java index 0f51bdcd5..24dbfb78b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormat.java @@ -41,7 +41,7 @@ public KNN80CompoundFormat(CompoundFormat delegate) { @Override public CompoundDirectory getCompoundReader(Directory dir, SegmentInfo si, IOContext context) throws IOException { - return delegate.getCompoundReader(dir, si, context); + return new KNN80CompoundDirectory(delegate.getCompoundReader(dir, si, context), dir); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java index fe329eb1c..7e45270b6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesFormat.java @@ -41,6 +41,6 @@ public DocValuesConsumer fieldsConsumer(SegmentWriteState state) throws IOExcept @Override public DocValuesProducer fieldsProducer(SegmentReadState state) throws IOException { - return delegate.fieldsProducer(state); + return new KNN80DocValuesProducer(delegate.fieldsProducer(state), state); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java new file mode 100644 index 000000000..0cfd9c668 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducer.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.FilterDirectory; +import org.opensearch.common.io.PathUtils; +import org.opensearch.knn.common.FieldInfoExtractor; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapper.KNN_FIELD; + +@Log4j2 +public class KNN80DocValuesProducer extends DocValuesProducer { + + private final SegmentReadState state; + private final DocValuesProducer delegate; + private final NativeMemoryCacheManager nativeMemoryCacheManager; + private final Map indexPathMap = new HashMap(); + + public KNN80DocValuesProducer(DocValuesProducer delegate, SegmentReadState state) { + this.delegate = delegate; + this.state = state; + this.nativeMemoryCacheManager = NativeMemoryCacheManager.getInstance(); + + Directory directory = state.directory; + // directory would be CompoundDirectory, we need get directory firstly and then unwrap + if (state.directory instanceof KNN80CompoundDirectory) { + directory = ((KNN80CompoundDirectory) state.directory).getDir(); + } + + Directory dir = FilterDirectory.unwrap(directory); + if (!(dir instanceof FSDirectory)) { + log.warn("{} can not casting to FSDirectory", directory); + return; + } + String directoryPath = ((FSDirectory) dir).getDirectory().toString(); + for (FieldInfo field : state.fieldInfos) { + if (!field.attributes().containsKey(KNN_FIELD)) { + continue; + } + // Only Native Engine put into indexPathMap + KNNEngine knnEngine = getNativeKNNEngine(field); + if (knnEngine == null) { + continue; + } + List engineFiles = KNNCodecUtil.getEngineFiles(knnEngine.getExtension(), field.name, state.segmentInfo); + Path indexPath = PathUtils.get(directoryPath, engineFiles.get(0)); + indexPathMap.putIfAbsent(field.getName(), indexPath.toString()); + } + } + + @Override + public BinaryDocValues getBinary(FieldInfo field) throws IOException { + return delegate.getBinary(field); + } + + @Override + public NumericDocValues getNumeric(FieldInfo field) throws IOException { + return delegate.getNumeric(field); + } + + @Override + public SortedDocValues getSorted(FieldInfo field) throws IOException { + return delegate.getSorted(field); + } + + @Override + public SortedNumericDocValues getSortedNumeric(FieldInfo field) throws IOException { + return delegate.getSortedNumeric(field); + } + + @Override + public SortedSetDocValues getSortedSet(FieldInfo field) throws IOException { + return delegate.getSortedSet(field); + } + + @Override + public void checkIntegrity() throws IOException { + delegate.checkIntegrity(); + } + + @Override + public void close() throws IOException { + for (String path : indexPathMap.values()) { + nativeMemoryCacheManager.invalidate(path); + } + delegate.close(); + } + + public final List getOpenedIndexPath() { + return new ArrayList<>(indexPathMap.values()); + } + + /** + * Get KNNEngine From FieldInfo + * + * @param field which field we need produce from engine + * @return if and only if Native Engine we return specific engine, else return null + */ + private KNNEngine getNativeKNNEngine(@NonNull FieldInfo field) { + + final String modelId = field.attributes().get(MODEL_ID); + if (modelId != null) { + return null; + } + KNNEngine engine = FieldInfoExtractor.extractKNNEngine(field); + if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(engine)) { + return engine; + } + return null; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 51100a1e0..84c7c4675 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -6,8 +6,15 @@ package org.opensearch.knn.index.codec.util; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.SegmentInfo; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; public class KNNCodecUtil { // Floats are 4 bytes in size @@ -53,4 +60,28 @@ public static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) } return totalLiveDocs; } + + /** + * Get Engine Files from segment with specific fieldName and engine extension + * + * @param extension Engine extension comes from {@link KNNEngine#getExtension()}} + * @param fieldName Filed for knn field + * @param segmentInfo {@link SegmentInfo} One Segment info to use for compute. + * @return List of engine files + */ + public static List getEngineFiles(String extension, String fieldName, SegmentInfo segmentInfo) { + /* + * In case of compound file, extension would be + c otherwise + */ + String engineExtension = segmentInfo.getUseCompoundFile() ? extension + KNNConstants.COMPOUND_EXTENSION : extension; + String engineSuffix = fieldName + engineExtension; + String underLineEngineSuffix = "_" + engineSuffix; + + List engineFiles = segmentInfo.files() + .stream() + .filter(fileName -> fileName.endsWith(underLineEngineSuffix)) + .sorted(Comparator.comparingInt(String::length)) + .collect(Collectors.toList()); + return engineFiles; + } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 0bb8a556f..635bc3883 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -15,6 +15,7 @@ import lombok.Setter; import org.apache.lucene.index.LeafReaderContext; import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.common.concurrent.RefCountedReleasable; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.query.KNNWeight; @@ -81,6 +82,26 @@ public interface NativeMemoryAllocation { */ int getSizeInKB(); + /** + * Increments the refCount of this instance. + * + * @see #decRef + * @throws IllegalStateException iff the reference counter can not be incremented. + */ + default void incRef() {} + + /** + * Decreases the refCount of this instance. If the refCount drops to 0, then this + * instance is considered as closed and should not be used anymore. + * + * @see #incRef + * + * @return returns {@code true} if the ref count dropped to 0 as a result of calling this method + */ + default boolean decRef() { + return true; + } + /** * Represents native indices loaded into memory. Because these indices are backed by files, they should be * freed when file is deleted. @@ -102,6 +123,7 @@ class IndexAllocation implements NativeMemoryAllocation { private final SharedIndexState sharedIndexState; @Getter private final boolean isBinaryIndex; + private final RefCountedReleasable refCounted; /** * Constructor @@ -160,10 +182,10 @@ class IndexAllocation implements NativeMemoryAllocation { this.watcherHandle = watcherHandle; this.sharedIndexState = sharedIndexState; this.isBinaryIndex = isBinaryIndex; + this.refCounted = new RefCountedReleasable<>("IndexAllocation-Reference", this, this::closeInternal); } - @Override - public void close() { + protected void closeInternal() { Runnable onClose = () -> { writeLock(); cleanup(); @@ -179,6 +201,13 @@ public void close() { } } + @Override + public void close() { + if (!closed && refCounted.refCount() > 0) { + refCounted.close(); + } + } + private void cleanup() { if (this.closed) { return; @@ -242,6 +271,16 @@ public void writeUnlock() { public int getSizeInKB() { return size; } + + @Override + public void incRef() { + refCounted.incRef(); + } + + @Override + public boolean decRef() { + return refCounted.decRef(); + } } /** diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index b1ba9de59..1c31ed725 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index.query; -import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -27,6 +26,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; @@ -43,7 +43,6 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -272,7 +271,7 @@ private Map doANNSearch( // TODO: Change type of vector once more quantization methods are supported final byte[] quantizedVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo); - List engineFiles = getEngineFiles(reader, knnEngine.getExtension()); + List engineFiles = KNNCodecUtil.getEngineFiles(knnEngine.getExtension(), knnQuery.getField(), reader.getSegmentInfo().info); if (engineFiles.isEmpty()) { log.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); return null; @@ -312,6 +311,7 @@ private Map doANNSearch( FilterIdsSelector.FilterIdsSelectorType filterType = filterIdsSelector.getFilterType(); // Now that we have the allocation, we need to readLock it indexAllocation.readLock(); + indexAllocation.incRef(); try { if (indexAllocation.isClosed()) { throw new RuntimeException("Index has already been closed"); @@ -361,6 +361,7 @@ private Map doANNSearch( throw new RuntimeException(e); } finally { indexAllocation.readUnlock(); + indexAllocation.decRef(); } /* @@ -378,25 +379,6 @@ private Map doANNSearch( .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } - @VisibleForTesting - List getEngineFiles(SegmentReader reader, String extension) throws IOException { - /* - * In case of compound file, extension would be + c otherwise - */ - String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile() - ? extension + KNNConstants.COMPOUND_EXTENSION - : extension; - String engineSuffix = knnQuery.getField() + engineExtension; - String underLineEngineSuffix = "_" + engineSuffix; - List engineFiles = reader.getSegmentInfo() - .files() - .stream() - .filter(fileName -> fileName.endsWith(underLineEngineSuffix)) - .sorted(Comparator.comparingInt(String::length)) - .collect(Collectors.toList()); - return engineFiles; - } - /** * Execute exact search for the given matched doc ids and return the results as a map of docId to score. * diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 1ea7ecca6..ded3a827c 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -483,4 +483,128 @@ public void testIndexingVectorValidation_updateVectorWithNull() throws Exception assertArrayEquals(vectorForDocumentOne, vectorRestoreInitialValue); } + public void testCacheClear_whenCloseIndex() throws Exception { + String indexName = "test-index-1"; + KNNEngine knnEngine1 = KNNEngine.NMSLIB; + KNNEngine knnEngine2 = KNNEngine.FAISS; + String fieldName1 = "test-field-1"; + String fieldName2 = "test-field-2"; + SpaceType spaceType1 = SpaceType.COSINESIMIL; + SpaceType spaceType2 = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType1.getValue()) + .field(KNNConstants.KNN_ENGINE, knnEngine1.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType2.getValue()) + .field(KNNConstants.KNN_ENGINE, knnEngine2.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + // Assert we have the right number of documents in the index + refreshAllIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search the first field + Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName1); + for (int j = 0; j < k; j++) { + float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + knnEngine1.score(1 - KNNScoringUtil.cosinesimil(testData.queries[i], primitiveArray), spaceType1), + actualScores.get(j), + 0.0001 + ); + } + + // Search the second field + response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName2, testData.queries[i], k), k); + responseBody = EntityUtils.toString(response.getEntity()); + knnResults = parseSearchResponse(responseBody, fieldName2); + assertEquals(k, knnResults.size()); + + actualScores = parseSearchResponseScore(responseBody, fieldName2); + for (int j = 0; j < k; j++) { + float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + knnEngine2.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType2), + actualScores.get(j), + 0.0001 + ); + } + } + + // Get Stats + int graphCount = getTotalGraphsInCache(); + assertTrue(graphCount > 0); + // Close index + closeKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java index 0ecabcce6..6001a9729 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80CompoundFormatTests.java @@ -49,7 +49,9 @@ public void testGetCompoundReader() throws IOException { CompoundFormat delegate = mock(CompoundFormat.class); when(delegate.getCompoundReader(null, null, null)).thenReturn(dir); KNN80CompoundFormat knn80CompoundFormat = new KNN80CompoundFormat(delegate); - assertEquals(dir, knn80CompoundFormat.getCompoundReader(null, null, null)); + CompoundDirectory knnDir = knn80CompoundFormat.getCompoundReader(null, null, null); + assertTrue(knnDir instanceof KNN80CompoundDirectory); + assertEquals(dir, ((KNN80CompoundDirectory) knnDir).getDelegate()); } public void testWrite() throws IOException { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducerTests.java new file mode 100644 index 000000000..b9a85bbcc --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesProducerTests.java @@ -0,0 +1,130 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN80Codec; + +import com.google.common.collect.ImmutableMap; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.DocValuesProducer; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.junit.Before; +import org.opensearch.Version; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNN87Codec.KNN87Codec; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodConfigContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; + +public class KNN80DocValuesProducerTests extends KNNTestCase { + + private static Directory directory; + + @Before + public void setUp() throws Exception { + super.setUp(); + directory = newFSDirectory(createTempDir()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + directory.close(); + } + + public void testProduceKNNBinaryField_fromCodec_nmslibCurrent() throws IOException { + // Set information about the segment and the fields + DocValuesFormat mockDocValuesFormat = mock(DocValuesFormat.class); + Codec mockDelegateCodec = mock(Codec.class); + DocValuesProducer mockDocValuesProducer = mock(DocValuesProducer.class); + when(mockDelegateCodec.docValuesFormat()).thenReturn(mockDocValuesFormat); + when(mockDocValuesFormat.fieldsProducer(any())).thenReturn(mockDocValuesProducer); + when(mockDocValuesFormat.getName()).thenReturn("mockDocValuesFormat"); + Codec codec = new KNN87Codec(mockDelegateCodec); + + String segmentName = "_test"; + int docsInSegment = 100; + String fieldName1 = String.format("test_field1%s", randomAlphaOfLength(4)); + String fieldName2 = String.format("test_field2%s", randomAlphaOfLength(4)); + List segmentFiles = Arrays.asList( + String.format("%s_2011_%s%s", segmentName, fieldName1, KNNEngine.NMSLIB.getExtension()), + String.format("%s_165_%s%s", segmentName, fieldName2, KNNEngine.FAISS.getExtension()) + ); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() + .directory(directory) + .segmentName(segmentName) + .docsInSegment(docsInSegment) + .codec(codec) + .build(); + + for (String name : segmentFiles) { + IndexOutput indexOutput = directory.createOutput(name, IOContext.DEFAULT); + indexOutput.close(); + } + segmentInfo.setFiles(segmentFiles); + + KNNMethodContext knnMethodContext = new KNNMethodContext( + knnEngine, + spaceType, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .versionCreated(Version.CURRENT) + .build(); + String parameterString = XContentFactory.jsonBuilder() + .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext).getLibraryParameters()) + .toString(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName1) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.KNN_ENGINE, knnEngine.getName()) + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentReadState state = new SegmentReadState(directory, segmentInfo, fieldInfos, IOContext.DEFAULT); + + DocValuesFormat docValuesFormat = codec.docValuesFormat(); + assertTrue(docValuesFormat instanceof KNN80DocValuesFormat); + DocValuesProducer producer = docValuesFormat.fieldsProducer(state); + assertTrue(producer instanceof KNN80DocValuesProducer); + int pathSize = ((KNN80DocValuesProducer) producer).getOpenedIndexPath().size(); + assertEquals(pathSize, 1); + + String path = ((KNN80DocValuesProducer) producer).getOpenedIndexPath().get(0); + assertTrue(path.contains(segmentFiles.get(0))); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index dbea6375b..86e22cd88 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -6,8 +6,15 @@ package org.opensearch.knn.index.codec.util; import junit.framework.TestCase; +import org.apache.lucene.index.SegmentInfo; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; +import java.util.List; +import java.util.Set; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; public class KNNCodecUtilTests extends TestCase { @@ -28,4 +35,15 @@ public void testCalculateArraySize() { vectorDataType = VectorDataType.BINARY; assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType)); } + + public void testGetKNNEngines() { + SegmentInfo segmentInfo = mock(SegmentInfo.class); + KNNEngine knnEngine = KNNEngine.FAISS; + Set SEGMENT_MULTI_FIELD_FILES_FAISS = Set.of("_0.cfe", "_0_2011_long_target_field.faissc", "_0_2011_target_field.faissc"); + when(segmentInfo.getUseCompoundFile()).thenReturn(true); + when(segmentInfo.files()).thenReturn(SEGMENT_MULTI_FIELD_FILES_FAISS); + List engineFiles = KNNCodecUtil.getEngineFiles(knnEngine.getExtension(), "target_field", segmentInfo); + assertEquals(engineFiles.size(), 2); + assertTrue(engineFiles.get(0).equals("_0_2011_target_field.faissc")); + } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 810f49c15..f92f32406 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -41,6 +41,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; +import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -1318,7 +1319,7 @@ private void testQueryScore( String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); KNNEngine knnEngine = KNNEngine.getEngine(engineName); - List engineFiles = knnWeight.getEngineFiles(reader, knnEngine.getExtension()); + List engineFiles = KNNCodecUtil.getEngineFiles(knnEngine.getExtension(), query.getField(), reader.getSegmentInfo().info); String expectIndexPath = String.format("%s_%s_%s%s%s", SEGMENT_NAME, 2011, FIELD_NAME, knnEngine.getExtension(), "c"); assertEquals(engineFiles.get(0), expectIndexPath); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index fb974b6e1..8c033ca03 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -357,6 +357,12 @@ protected void deleteKNNIndex(String index) throws IOException { assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + protected void closeKNNIndex(String index) throws IOException { + Request request = new Request("POST", "/" + index + "/_close"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * For a given index, make a mapping request */