From fd59b9adf42b07aa2b2058c4badff6dacf8306a8 Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Tue, 20 Aug 2024 11:32:28 -0700 Subject: [PATCH] Integrates FAISS iterative builds with NativeEngines990KnnVectorsFormat (#1950) * Iterative Vector Insertion (#1840) * Rebased with new version of k-NN Signed-off-by: Andrew Klepchick * Optimized faiss insertion Signed-off-by: Andrew Klepchick * Optimized threadCount logic Signed-off-by: Andrew Klepchick * Removed IDEA files Signed-off-by: Andrew Klepchick * Removed unnecessary cmake file Signed-off-by: Andrew Klepchick * Added comments to new functions Signed-off-by: Andrew Klepchick * Removed createIndex and fixed test cases that use it Signed-off-by: Andrew Klepchick * Removed unused code Signed-off-by: Andrew Klepchick * Explained zero initialization for vector transfer Signed-off-by: Andrew Klepchick * Added locale Signed-off-by: Andrew Klepchick * Spotless Apply Signed-off-by: Andrew Klepchick * Account for zero documents in finished batch Signed-off-by: Andrew Klepchick * Changed where we check for zero docs Signed-off-by: Andrew Klepchick * Changed tip for return Signed-off-by: Andrew Klepchick * Use unique pointers to make sure resources are released on exception Signed-off-by: Andrew Klepchick * Moved createIndex to testUtils Signed-off-by: Andrew Klepchick * Fixed memory management so that the underlying index is not deleted after initialized Signed-off-by: Andrew Klepchick * Created new KNNIndexBuilder graph to make index building more modular Signed-off-by: Andrew Klepchick * Streamlined logic in KNNIndexBuilder. Signed-off-by: Andrew Klepchick * Cleaned up unnecessary code in KNN80DocValuesConsumer Signed-off-by: Andrew Klepchick * Fixed memory management process Signed-off-by: Andrew Klepchick * Added note about index initialization in faiss_index_service Signed-off-by: Andrew Klepchick * Accounted for case where the exception happens after the indexWriter is released. Signed-off-by: Andrew Klepchick * Delete jni/src/.idea/modules.xml Signed-off-by: Andrew Klepchick * Delete jni/src/.idea/vcs.xml Signed-off-by: Andrew Klepchick * Delete jni/src/.idea/workspace.xml Signed-off-by: Andrew Klepchick * Spotless apply and free iterative index on exception Signed-off-by: Andrew Klepchick * Undid hack for checking first document metrics Signed-off-by: Andrew Klepchick * Removed print statements Signed-off-by: Andrew Klepchick * Free Vector Transfer on batch ingestion Signed-off-by: Andrew Klepchick * Undid free Signed-off-by: Andrew Klepchick * Fixed check for transfer ready Signed-off-by: Andrew Klepchick * Don't crash when zero vectors inserted? Signed-off-by: Andrew Klepchick * Reverted to old insertion process? Signed-off-by: Andrew Klepchick * Spotless apply Signed-off-by: Andrew Klepchick * Added back createOutput Signed-off-by: Andrew Klepchick * Removed prior createOutput Signed-off-by: Andrew Klepchick * Test remaking vectorTransfer Signed-off-by: Andrew Klepchick * Test restructuring of insertion Signed-off-by: Andrew Klepchick * Fixed case where vector address is immediately discarded Signed-off-by: Andrew Klepchick * Spotless apply Signed-off-by: Andrew Klepchick * Split Index Builder into multiple classes Signed-off-by: Andrew Klepchick * Fixed descriptions of functions in faiss_index_service Signed-off-by: Andrew Klepchick * Added back copyright files Signed-off-by: Andrew Klepchick * Removed unused builder names Signed-off-by: Andrew Klepchick * Modified tests to work with new insertion methods Signed-off-by: Andrew Klepchick * Track index insertions Signed-off-by: Andrew Klepchick * Tracked insertions for binary indices Signed-off-by: Andrew Klepchick * Added back insertIds Signed-off-by: Andrew Klepchick * Added check for freeVectorData to see if it works with an already deleted address Signed-off-by: Andrew Klepchick * Cleaned up logs and comments in KNNIndexBuilder Signed-off-by: Andrew Klepchick * Restructured the logic for KNNIndexBuilder Signed-off-by: Andrew Klepchick * Changed package name of KNNIndexBuilder Signed-off-by: Andrew Klepchick * Changed all package names and deleted unnecessary headers Signed-off-by: Andrew Klepchick * Fixed for loop Signed-off-by: Andrew Klepchick * Removed createIndex methods for faiss index service Signed-off-by: Andrew Klepchick * Fixed package to fit naming conventions Signed-off-by: Andrew Klepchick * Changed name of index builder Signed-off-by: Andrew Klepchick * Spotless apply Signed-off-by: Andrew Klepchick * Added comments to NativeIndexBuilder and restructured Signed-off-by: Andrew Klepchick * Added deletion for memoryAddress Signed-off-by: Andrew Klepchick * Spotless apply Signed-off-by: Andrew Klepchick * Changed naming of classes to Writer and changed package name to fit conventions Signed-off-by: Andrew Klepchick * Changed NativeIndexInfo and NativeVectorInfo to follow builder pattern Signed-off-by: Andrew Klepchick * Added feature to changelog Signed-off-by: Andrew Klepchick * Added class descriptions to each NativeIndexWriter Signed-off-by: Andrew Klepchick * Changed name to getBytesPerVector Signed-off-by: Andrew Klepchick * Added == false instead of ! for readability Signed-off-by: Andrew Klepchick * Fixed changelog Signed-off-by: Andrew Klepchick * Fixed naming in docvaluesconsumer Signed-off-by: Andrew Klepchick * SpotlessApply Signed-off-by: Andrew Klepchick * Made it so that we don't reuse testValues and removed a foot gun Signed-off-by: Andrew Klepchick * Removed another foot gun in getIndexInfo Signed-off-by: Andrew Klepchick * Fixed javadoc Signed-off-by: Andrew Klepchick * Added deletion on exception cases Signed-off-by: Andrew Klepchick * Removed unnecessary delete (NativeIndexWriter will handle deletion of vectors on exception) Signed-off-by: Andrew Klepchick * Added correct logger and getWriter method to NativeIndexWriter Signed-off-by: Andrew Klepchick * Ensured memory safety on JNI layer so that Java doesn't have to wrap everything in a try catch loop. Signed-off-by: Andrew Klepchick * Refactored NativeIndexWriter and added comments to FaissService Signed-off-by: Andrew Klepchick * Removed free in the JNIExport since index will always be freed in writeIndex. Signed-off-by: Andrew Klepchick * Changed getVectorTransfer back to accept VectorDataType Signed-off-by: Andrew Klepchick * Reverted free since not guaranteed to be IDMap. Signed-off-by: Andrew Klepchick * Added all processes in addKNNBinaryField to NativeIndexWriter.createKNNIndex Signed-off-by: Andrew Klepchick * Fixed javadoc Signed-off-by: Andrew Klepchick * Applied spotless Signed-off-by: Andrew Klepchick * Added back writeFooter Signed-off-by: Andrew Klepchick * Removed threadCount fron writeIndex Signed-off-by: Andrew Klepchick * Removed redundancies in KNN80DocValuesConsumer Signed-off-by: Andrew Klepchick * Removed serializationMode Signed-off-by: Andrew Klepchick * Fixed changelog Signed-off-by: Andrew Klepchick * Fixed changelog Signed-off-by: Andrew Klepchick * Removed double free test as we don't have to worry about that anymore Signed-off-by: Andrew Klepchick * Accounted for HNSWSQ in index service Signed-off-by: Andrew Klepchick * Removed delete in catch Signed-off-by: Andrew Klepchick * Fixed faiss tests to work with writeIndex Signed-off-by: Andrew Klepchick --------- Signed-off-by: Andrew Klepchick * Index Initialization Alloc Method (#1933) * Added methods for allocating memory before inserting vectors to a faiss index Signed-off-by: Andrew Klepchick * Fixed logic that gets type of index Signed-off-by: Andrew Klepchick * Removed print statement Signed-off-by: Andrew Klepchick * Removed unnecessary iostream Signed-off-by: Andrew Klepchick * Removed flat index Signed-off-by: Andrew Klepchick * Fixed flat index case Signed-off-by: Andrew Klepchick * Fixed naming Signed-off-by: Andrew Klepchick * Properly allocate HNSWSQ storage Signed-off-by: Andrew Klepchick * Removed print statements Signed-off-by: Andrew Klepchick * Fixed changelog Signed-off-by: Andrew Klepchick * Removed unnecessary lib Signed-off-by: Andrew Klepchick * Made alloc adaptive to different code sizes Signed-off-by: Andrew Klepchick --------- Signed-off-by: Andrew Klepchick * Integrates FAISS iterative builds with NativeEngines990KnnVectorsFormat Changes include reusing the same vector buffer in the JNI layer Signed-off-by: Tejas Shah --------- Signed-off-by: Andrew Klepchick Signed-off-by: Tejas Shah Co-authored-by: Andrew Klepchick --- CHANGELOG.md | 6 +- jni/include/commons.h | 17 +- jni/include/faiss_index_service.h | 87 +++-- jni/include/faiss_wrapper.h | 9 +- .../org_opensearch_knn_jni_FaissService.h | 49 ++- .../org_opensearch_knn_jni_JNICommons.h | 6 +- jni/src/commons.cpp | 14 +- jni/src/faiss_index_service.cpp | 168 +++++++--- jni/src/faiss_wrapper.cpp | 72 +++-- .../org_opensearch_knn_jni_FaissService.cpp | 76 ++++- jni/src/org_opensearch_knn_jni_JNICommons.cpp | 8 +- jni/tests/commons_test.cpp | 116 ++++++- jni/tests/faiss_index_service_test.cpp | 30 +- jni/tests/faiss_wrapper_test.cpp | 94 +++++- jni/tests/mocks/faiss_index_service_mock.h | 25 +- .../knn/common/FieldInfoExtractor.java | 20 +- .../opensearch/knn/common/KNNVectorUtil.java | 35 +- .../codec/BasePerFieldKnnVectorsFormat.java | 69 ++-- .../KNN80Codec/KNN80DocValuesConsumer.java | 284 +---------------- .../NativeEngineFieldVectorsWriter.java | 3 + .../NativeEngines990KnnVectorsWriter.java | 42 ++- .../DefaultIndexBuildStrategy.java | 95 ++++++ .../MemOptimizedNativeIndexBuildStrategy.java | 117 +++++++ .../nativeindex/NativeIndexBuildStrategy.java | 19 ++ .../codec/nativeindex/NativeIndexWriter.java | 298 ++++++++++++++++++ .../nativeindex/model/BuildIndexParams.java | 26 ++ .../transfer/OffHeapBinaryVectorTransfer.java | 32 ++ .../transfer/OffHeapByteVectorTransfer.java | 38 +++ .../transfer/OffHeapFloatVectorTransfer.java | 36 +++ .../codec/transfer/OffHeapVectorTransfer.java | 99 ++++++ .../OffHeapVectorTransferFactory.java | 37 +++ .../index/codec/transfer/VectorTransfer.java | 54 ---- .../codec/transfer/VectorTransferByte.java | 68 ---- .../codec/transfer/VectorTransferFloat.java | 71 ----- .../knn/index/codec/util/KNNCodecUtil.java | 49 --- .../vectorvalues/KNNBinaryVectorValues.java | 17 +- .../vectorvalues/KNNByteVectorValues.java | 12 + .../vectorvalues/KNNFloatVectorValues.java | 11 + .../index/vectorvalues/KNNVectorValues.java | 25 +- .../org/opensearch/knn/indices/ModelUtil.java | 4 +- .../org/opensearch/knn/jni/FaissService.java | 64 +++- .../org/opensearch/knn/jni/JNICommons.java | 76 ++++- .../org/opensearch/knn/jni/JNIService.java | 197 ++++++++---- .../knn/common/FieldInfoExtractorTests.java | 24 ++ .../knn/common/KNNVectorUtilTests.java | 34 ++ .../KNN80DocValuesConsumerTests.java | 82 +++-- ...NativeEngines990KnnVectorsFormatTests.java | 14 +- .../DefaultIndexBuildStrategyTests.java | 167 ++++++++++ ...ptimizedNativeIndexBuildStrategyTests.java | 129 ++++++++ .../OffHeapVectorTransferFactoryTests.java | 26 ++ .../transfer/OffHeapVectorTransferTests.java | 92 ++++++ .../transfer/VectorTransferByteTests.java | 56 ---- .../transfer/VectorTransferFloatTests.java | 66 ---- .../index/codec/util/KNNCodecUtilTests.java | 43 --- .../memory/NativeMemoryAllocationTests.java | 5 +- .../memory/NativeMemoryLoadStrategyTests.java | 7 +- .../vectorvalues/KNNVectorValuesTests.java | 42 ++- .../opensearch/knn/jni/JNIServiceTests.java | 93 +++--- .../java/org/opensearch/knn/TestUtils.java | 17 +- 59 files changed, 2484 insertions(+), 1088 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java delete mode 100644 src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java delete mode 100644 src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java delete mode 100644 src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index ad8f13179..a0c61ae55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ + # CHANGELOG All notable changes to this project are documented in this file. @@ -6,7 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features ### Enhancements -### Bug Fixes +### Bug Fixes ### Infrastructure * Removed JDK 11 and 17 version from CI runs [#1921](https://github.com/opensearch-project/k-NN/pull/1921) ### Documentation @@ -17,10 +18,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features * Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) ### Enhancements +* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950) ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) -* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) +* Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) * Disallow a vector field to have an invalid character for a physical file name. [#1936](https://github.com/opensearch-project/k-NN/pull/1936) ### Infrastructure ### Documentation diff --git a/jni/include/commons.h b/jni/include/commons.h index d02439377..4cdaf28fc 100644 --- a/jni/include/commons.h +++ b/jni/include/commons.h @@ -19,12 +19,19 @@ namespace knn_jni { * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location * will throw Exception. * + * append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector + * without reallocating new memory. This helps with reducing memory frangmentation and overhead of allocating + * and deallocating when the memory address needs to be reused. + * + * CAUTION: The behavior is undefined if the memory address is deallocated and the method is called + * * @param memoryAddress The address of the memory location where data will be stored. * @param data 2D float array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. + * @param append whether to append or start from index 0 when called subsequently with the same address * @return memory address of std::vector where the data is stored. */ - jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); /** * This is utility function that can be used to store data in native memory. This function will allocate memory for @@ -33,12 +40,18 @@ namespace knn_jni { * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location * will throw Exception. * + * append tells the method to keep appending to the existing vector. Passing the value as false will clear the vector + * without reallocating new memory. This helps with reducing memory frangmentation and overhead of allocating + * and deallocating when the memory address needs to be reused. + * + * CAUTION: The behavior is undefined if the memory address is deallocated and the method is called + * * @param memoryAddress The address of the memory location where data will be stored. * @param data 2D byte array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @return memory address of std::vector where the data is stored. */ - jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong, jboolean); /** * Free up the memory allocated for the data stored in memory address. This function should be used with the memory diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 59f15fda9..c57309cfc 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -31,38 +31,41 @@ namespace faiss_wrapper { class IndexService { public: IndexService(std::unique_ptr faissMethods); - //TODO Remove dependency on JNIUtilInterface and JNIEnv - //TODO Reduce the number of parameters - /** - * Create index + * Initialize index * * @param jniUtil jni util * @param env jni environment * @param metric space type for distance calculation * @param indexDescription index description to be used by faiss index factory * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return memory address of the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters); + /** + * Add vectors to index + * + * @param dim dimension of vectors * @param numIds number of vectors * @param threadCount number of thread count to be used while adding data * @param vectorsAddress memory address which is holding vector data - * @param ids a list of document ids for corresponding vectors + * @param idMapAddress memory address of the native index object + */ + virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress); + /** + * Write index to disk + * + * @param threadCount number of thread count to be used while adding data * @param indexPath path to write index - * @param parameters parameters to be applied to faiss index + * @param idMap memory address of the native index object */ - virtual void createIndex( - knn_jni::JNIUtilInterface * jniUtil, - JNIEnv * env, - faiss::MetricType metric, - std::string indexDescription, - int dim, - int numIds, - int threadCount, - int64_t vectorsAddress, - std::vector ids, - std::string indexPath, - std::unordered_map parameters); + virtual void writeIndex(std::string indexPath, jlong idMapAddress); virtual ~IndexService() = default; protected: + virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors); std::unique_ptr faissMethods; }; @@ -76,7 +79,21 @@ class BinaryIndexService : public IndexService { //TODO Reduce the number of parameters BinaryIndexService(std::unique_ptr faissMethods); /** - * Create binary index + * Initialize index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param parameters parameters to be applied to faiss index + * @return memory address of the native index object + */ + virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) override; + /** + * Add vectors to index * * @param jniUtil jni util * @param env jni environment @@ -86,28 +103,30 @@ class BinaryIndexService : public IndexService { * @param numIds number of vectors * @param threadCount number of thread count to be used while adding data * @param vectorsAddress memory address which is holding vector data - * @param ids a list of document ids for corresponding vectors + * @param idMap a map of document id and vector id + * @param parameters parameters to be applied to faiss index + */ + virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress) override; + /** + * Write index to disk + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param threadCount number of thread count to be used while adding data * @param indexPath path to write index + * @param idMap a map of document id and vector id * @param parameters parameters to be applied to faiss index */ - virtual void createIndex( - knn_jni::JNIUtilInterface * jniUtil, - JNIEnv * env, - faiss::MetricType metric, - std::string indexDescription, - int dim, - int numIds, - int threadCount, - int64_t vectorsAddress, - std::vector ids, - std::string indexPath, - std::unordered_map parameters - ) override; + virtual void writeIndex(std::string indexPath, jlong idMapAddress) override; virtual ~BinaryIndexService() = default; +protected: + virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override; }; } } -#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H +#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H \ No newline at end of file diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 5ad0dedc4..574efb6fd 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -18,10 +18,11 @@ namespace knn_jni { namespace faiss_wrapper { - // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. - // The index is serialized to indexPathJ. - void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ, IndexService* indexService); + jlong InitIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong numDocs, jint dimJ, jobject parametersJ, IndexService *indexService); + + void InsertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddr, jint threadCount, IndexService *indexService); + + void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jstring indexPathJ, jlong indexAddr, IndexService *indexService); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 025fb12e8..19e13d402 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -18,23 +18,54 @@ #ifdef __cplusplus extern "C" { #endif - /* * Class: org_opensearch_knn_jni_FaissService - * Method: createIndex + * Method: initIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); - +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); /* * Class: org_opensearch_knn_jni_FaissService - * Method: createBinaryIndex + * Method: initBinaryIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); - +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: insertToBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ); /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h index 89de76520..03c0d023a 100644 --- a/jni/include/org_opensearch_knn_jni_JNICommons.h +++ b/jni/include/org_opensearch_knn_jni_JNICommons.h @@ -21,10 +21,10 @@ extern "C" { /* * Class: org_opensearch_knn_jni_JNICommons * Method: storeVectorData - * Signature: (J[[FJJ) + * Signature: (J[[FJJJ) */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData - (JNIEnv *, jclass, jlong, jobjectArray, jlong); + (JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean); /* * Class: org_opensearch_knn_jni_JNICommons @@ -32,7 +32,7 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData * Signature: (J[[FJJ) */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData - (JNIEnv *, jclass, jlong, jobjectArray, jlong); + (JNIEnv *, jclass, jlong, jobjectArray, jlong, jboolean); /* * Class: org_opensearch_knn_jni_JNICommons diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp index 13f59194e..f9764db73 100644 --- a/jni/src/commons.cpp +++ b/jni/src/commons.cpp @@ -18,7 +18,7 @@ #include "commons.h" jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, - jobjectArray dataJ, jlong initialCapacityJ) { + jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { std::vector *vect; if ((long) memoryAddressJ == 0) { vect = new std::vector(); @@ -26,6 +26,11 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE } else { vect = reinterpret_cast*>(memoryAddressJ); } + + if (appendJ == JNI_FALSE) { + vect->clear(); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ); jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect); @@ -33,7 +38,7 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE } jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, - jobjectArray dataJ, jlong initialCapacityJ) { + jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { std::vector *vect; if ((long) memoryAddressJ == 0) { vect = new std::vector(); @@ -41,6 +46,11 @@ jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, } else { vect = reinterpret_cast*>(memoryAddressJ); } + + if (appendJ == JNI_FALSE) { + vect->clear(); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ); jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect); diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 8c5ba36af..f76c54428 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -57,21 +57,69 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, IndexService::IndexService(std::unique_ptr faissMethods) : faissMethods(std::move(faissMethods)) {} -void IndexService::createIndex( +void IndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { + if(auto * indexHNSWSQ = dynamic_cast(index)) { + if(auto * indexScalarQuantizer = dynamic_cast(indexHNSWSQ->storage)) { + indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors); + } + return; + } + if(auto * indexHNSW = dynamic_cast(index)) { + if(auto * indexFlat = dynamic_cast(indexHNSW->storage)) { + indexFlat->codes.reserve(indexFlat->code_size * numVectors); + } + return; + } +} + +jlong IndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, + int dim, + int numVectors, + int threadCount, + std::unordered_map parameters + ) { + // Create index using Faiss factory method + std::unique_ptr index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, index.get()); + + // Check that the index does not need to be trained + if(!index->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + std::unique_ptr idMap (faissMethods->indexIdMap(index.get())); + //Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor + idMap->own_fields = true; + + allocIndex(dynamic_cast(idMap->index), dim, numVectors); + + //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later + //in insert and write operations + index.release(); + return reinterpret_cast(idMap.release()); +} + +void IndexService::insertToIndex( int dim, int numIds, int threadCount, int64_t vectorsAddress, - std::vector ids, - std::string indexPath, - std::unordered_map parameters + std::vector & ids, + jlong idMapAddress ) { // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddress); + std::vector * inputVectors = reinterpret_cast*>(vectorsAddress); // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) dim); @@ -83,50 +131,89 @@ void IndexService::createIndex( throw std::runtime_error("Number of IDs does not match number of vectors"); } - std::unique_ptr indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread if(threadCount != 0) { omp_set_num_threads(threadCount); } - // Add extra parameters that cant be configured with the index factory - SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); - - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } + faiss::IndexIDMap * idMap = reinterpret_cast (idMapAddress); // Add vectors - std::unique_ptr idMap(faissMethods->indexIdMap(indexWriter.get())); idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); +} - // Write the index to disk - faissMethods->writeIndex(idMap.get(), indexPath.c_str()); +void IndexService::writeIndex( + std::string indexPath, + jlong idMapAddress + ) { + std::unique_ptr idMap (reinterpret_cast (idMapAddress)); + + try { + // Write the index to disk + faissMethods->writeIndex(idMap.get(), indexPath.c_str()); + } catch(std::exception &e) { + throw std::runtime_error("Failed to write index to disk"); + } } BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} -void BinaryIndexService::createIndex( +void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { + if(auto * indexBinaryHNSW = dynamic_cast(index)) { + auto * indexBinaryFlat = dynamic_cast(indexBinaryHNSW->storage); + indexBinaryFlat->xb.reserve(dim * numVectors / 8); + return; + } +} + +jlong BinaryIndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, int dim, - int numIds, + int numVectors, int threadCount, - int64_t vectorsAddress, - std::vector ids, - std::string indexPath, std::unordered_map parameters ) { - // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddress); + // Create index using Faiss factory method + std::unique_ptr index(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, index.get()); - if (dim % 8 != 0) { - throw std::runtime_error("Dimensions should be multiply of 8"); + // Check that the index does not need to be trained + if(!index->is_trained) { + throw std::runtime_error("Index is not trained"); } + + std::unique_ptr idMap(faissMethods->indexBinaryIdMap(index.get())); + //Makes sure the index is deleted when the destructor is called + idMap->own_fields = true; + + allocIndex(dynamic_cast(idMap->index), dim, numVectors); + + //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later + //in insert and write operations + index.release(); + return reinterpret_cast(idMap.release()); +} + +void BinaryIndexService::insertToIndex( + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector & ids, + jlong idMapAddress + ) { + // Read vectors from memory address (unique ptr since we want to remove from memory after use) + std::vector * inputVectors = reinterpret_cast*>(vectorsAddress); + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); if(numVectors == 0) { @@ -137,28 +224,31 @@ void BinaryIndexService::createIndex( throw std::runtime_error("Number of IDs does not match number of vectors"); } - std::unique_ptr indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread if(threadCount != 0) { omp_set_num_threads(threadCount); } - // Add extra parameters that cant be configured with the index factory - SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); - - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } + faiss::IndexBinaryIDMap * idMap = reinterpret_cast (idMapAddress); // Add vectors - std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); +} - // Write the index to disk - faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); +void BinaryIndexService::writeIndex( + std::string indexPath, + jlong idMapAddress + ) { + + std::unique_ptr idMap (reinterpret_cast (idMapAddress)); + + try { + // Write the index to disk + faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); + } catch(std::exception &e) { + throw std::runtime_error("Failed to write index to disk"); + } } } // namespace faiss_wrapper -} // namesapce knn_jni +} // namesapce knn_jni \ No newline at end of file diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 1d4437414..0e1029ecf 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -88,24 +88,13 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); -void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ, IndexService* indexService) { - if (idsJ == nullptr) { - throw std::runtime_error("IDs cannot be null"); - } - - if (vectorsAddressJ <= 0) { - throw std::runtime_error("VectorsAddress cannot be less than 0"); - } +jlong knn_jni::faiss_wrapper::InitIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong numDocs, jint dimJ, + jobject parametersJ, IndexService* indexService) { if(dimJ <= 0) { throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); - } - if (parametersJ == nullptr) { throw std::runtime_error("Parameters cannot be null"); } @@ -124,8 +113,8 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN // Dimension int dim = (int)dimJ; - // Number of vectors - int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + // Number of docs + int docs = (int)numDocs; // Index description jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); @@ -138,25 +127,60 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); } + // Extra parameters + // TODO: parse the entire map and remove jni object + std::unordered_map subParametersCpp; + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); + } + // end parameters to pass + + // Create index + return indexService->initIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numDocs, threadCount, subParametersCpp); +} + +void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jlong index_ptr, jint threadCount, IndexService* indexService) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + // Dimension + int dim = (int)dimJ; + + // Number of vectors + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + // Vectors address int64_t vectorsAddress = (int64_t)vectorsAddressJ; // Ids auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); - // Index path - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + // Create index + indexService->insertToIndex(dim, numIds, threadCount, vectorsAddress, ids, index_ptr); +} - // Extra parameters - // TODO: parse the entire map and remove jni object - std::unordered_map subParametersCpp; - if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { - subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); +void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, + jstring indexPathJ, jlong index_ptr, IndexService* indexService) { + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); } - // end parameters to pass + + // Index path + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); // Create index - indexService->createIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, indexPathCpp, subParametersCpp); + indexService->writeIndex(indexPathCpp, index_ptr); } void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 2394e2951..663e18457 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -39,37 +39,83 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { jniUtil.Uninitialize(env); } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ) +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &indexService); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} + +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndex(&jniUtil, env, numDocs, dimJ, parametersJ, &binaryIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 - delete reinterpret_cast*>(vectorsAddressJ); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &indexService); } catch (...) { + // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH! jniUtil.CatchCppExceptionAndThrowJava(env); } } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ) + jlong indexAddress, jint threadCount) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &binaryIndexService); + knn_jni::faiss_wrapper::InsertToIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddress, threadCount, &binaryIndexService); + } catch (...) { + // NOTE: ADDING DELETE STATEMENT HERE CAUSES A CRASH! + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 - delete reinterpret_cast*>(vectorsAddressJ); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, + jlong indexAddress, + jstring indexPathJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &binaryIndexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp index 0bc2e4633..7432c44d3 100644 --- a/jni/src/org_opensearch_knn_jni_JNICommons.cpp +++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp @@ -38,11 +38,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls, -jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { try { - return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -50,11 +50,11 @@ jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) } JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls, -jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ, jboolean appendJ) { try { - return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ, appendJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/commons_test.cpp b/jni/tests/commons_test.cpp index 630358919..d469fe268 100644 --- a/jni/tests/commons_test.cpp +++ b/jni/tests/commons_test.cpp @@ -33,7 +33,7 @@ TEST(CommonsTests, BasicAssertions) { testing::NiceMock mockJNIUtil; jlong memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, (jlong)0, - reinterpret_cast(&data), (jlong)(totalNumberOfVector * dim)); + reinterpret_cast(&data), (jlong)(totalNumberOfVector * dim), true); ASSERT_NE(memoryAddress, 0); auto *vect = reinterpret_cast*>(memoryAddress); ASSERT_EQ(vect->size(), data.size() * dim); @@ -48,12 +48,13 @@ TEST(CommonsTests, BasicAssertions) { } data2.push_back(vector); memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, - reinterpret_cast(&data2), (jlong)(totalNumberOfVector * dim)); + reinterpret_cast(&data2), (jlong)(totalNumberOfVector * dim), true); ASSERT_NE(memoryAddress, 0); ASSERT_EQ(memoryAddress, oldMemoryAddress); vect = reinterpret_cast*>(memoryAddress); int currentIndex = 0; - ASSERT_EQ(vect->size(), totalNumberOfVector*dim); + std::cout << vect->size() + "\n"; + ASSERT_EQ(vect->size(), totalNumberOfVector * dim); ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); // Validate if all vectors data are at correct location @@ -70,6 +71,115 @@ TEST(CommonsTests, BasicAssertions) { currentIndex++; } } + + // test append == true + std::vector> data3; + std::vector vecto3; + for(int j = 0 ; j < dim ; j ++) { + vecto3.push_back((float)j); + } + data3.push_back(vecto3); + memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast(&data3), (jlong)(totalNumberOfVector * dim), false); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast*>(memoryAddress); + + ASSERT_EQ(vect->size(), dim); //Since we just added 1 vector + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); //This is the initial capacity allocated + + currentIndex = 0; + for(auto & i : data3) { + for(float j : i) { + ASSERT_FLOAT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // Check that freeing vector data works + knn_jni::commons::freeVectorData(memoryAddress); +} + +TEST(StoreByteVectorTest, BasicAssertions) { + long dim = 3; + long totalNumberOfVector = 5; + std::vector> data; + for(int i = 0 ; i < totalNumberOfVector - 1 ; i++) { + std::vector vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((uint8_t)j); + } + data.push_back(vector); + } + JNIEnv *jniEnv = nullptr; + + testing::NiceMock mockJNIUtil; + + jlong memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, (jlong)0, + reinterpret_cast(&data), (jlong)(totalNumberOfVector * dim), true); + ASSERT_NE(memoryAddress, 0); + auto *vect = reinterpret_cast*>(memoryAddress); + ASSERT_EQ(vect->size(), data.size() * dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + // Check by inserting more vectors at same memory location + jlong oldMemoryAddress = memoryAddress; + std::vector> data2; + std::vector vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((uint8_t)j); + } + data2.push_back(vector); + memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast(&data2), (jlong)(totalNumberOfVector * dim), true); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast*>(memoryAddress); + int currentIndex = 0; + ASSERT_EQ(vect->size(), totalNumberOfVector*dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + // Validate if all vectors data are at correct location + for(auto & i : data) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + for(auto & i : data2) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // test append == true + std::vector> data3; + std::vector vecto3; + for(int j = 0 ; j < dim ; j ++) { + vecto3.push_back((uint8_t)j); + } + data3.push_back(vecto3); + memoryAddress = knn_jni::commons::storeByteVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast(&data3), (jlong)(totalNumberOfVector * dim), false); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast*>(memoryAddress); + + ASSERT_EQ(vect->size(), dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + currentIndex = 0; + for(auto & i : data3) { + for(uint8_t j : i) { + ASSERT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + // Check that freeing vector data works + knn_jni::commons::freeVectorData(memoryAddress); } TEST(CommonTests, GetIntegerMethodParam) { diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp index f876edced..1f00f6a1d 100644 --- a/jni/tests/faiss_index_service_test.cpp +++ b/jni/tests/faiss_index_service_test.cpp @@ -64,18 +64,9 @@ TEST(CreateIndexTest, BasicAssertions) { // Create the index knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); - indexService.createIndex( - &mockJNIUtil, - jniEnv, - metricType, - indexDescription, - dim, - numIds, - threadCount, - (int64_t) &vectors, - ids, - indexPath, - parametersMap); + long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); + indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); + indexService.writeIndex(indexPath, indexAddress); } TEST(CreateBinaryIndexTest, BasicAssertions) { @@ -119,16 +110,7 @@ TEST(CreateBinaryIndexTest, BasicAssertions) { // Create the index knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(mockFaissMethods)); - indexService.createIndex( - &mockJNIUtil, - jniEnv, - metricType, - indexDescription, - dim, - numIds, - threadCount, - (int64_t) &vectors, - ids, - indexPath, - parametersMap); + long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); + indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); + indexService.writeIndex(indexPath, indexAddress); } \ No newline at end of file diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 5ae443837..a1839c6ce 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -32,6 +32,70 @@ float rangeSearchRandomDataMin = -50; float rangeSearchRandomDataMax = 50; float rangeSearchRadius = 20000; +void createIndexIteratively( + knn_jni::JNIUtilInterface * JNIUtil, + JNIEnv *jniEnv, + std::vector & ids, + std::vector & vectors, + int dim, + std::string & indexPath, + std::unordered_map parametersMap, + IndexService * indexService, + int insertions = 10 + ) { + long numDocs = ids.size(); + if(numDocs % insertions != 0) { + throw std::invalid_argument("Number of documents should be divisible by number of insertions"); + } + long docsPerInsertion = numDocs / insertions; + long index_ptr = knn_jni::faiss_wrapper::InitIndex(JNIUtil, jniEnv, numDocs, dim, (jobject)¶metersMap, indexService); + for(int i = 0; i < insertions; i++) { + int start_idx = i * docsPerInsertion; + int end_idx = start_idx + docsPerInsertion; + std::vector insertIds; + std::vector insertVecs; + for(int j = start_idx; j < end_idx; j++) { + insertIds.push_back(j); + for(int k = 0; k < dim; k++) { + insertVecs.push_back(vectors[j * dim + k]); + } + } + knn_jni::faiss_wrapper::InsertToIndex(JNIUtil, jniEnv, reinterpret_cast(&insertIds), (jlong)&insertVecs, dim, index_ptr, 0, indexService); + } + knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, (jstring)&indexPath, index_ptr, indexService); +} + +void createBinaryIndexIteratively( + knn_jni::JNIUtilInterface * JNIUtil, + JNIEnv *jniEnv, + std::vector & ids, + std::vector & vectors, + int dim, + std::string & indexPath, + std::unordered_map parametersMap, + IndexService * indexService, + int insertions = 10 + ) { + long numDocs = ids.size();; + long index_ptr = knn_jni::faiss_wrapper::InitIndex(JNIUtil, jniEnv, numDocs, dim, (jobject)¶metersMap, indexService); + for(int i = 0; i < insertions; i++) { + int start_idx = numDocs * i / insertions; + int end_idx = numDocs * (i + 1) / insertions; + int docs_to_insert = end_idx - start_idx; + if(docs_to_insert == 0) continue; + std::vector insertIds; + std::vector insertVecs; + for(int j = start_idx; j < end_idx; j++) { + insertIds.push_back(j); + for(int k = 0; k < dim / 8; k++) { + insertVecs.push_back(vectors[j * (dim / 8) + k]); + } + } + knn_jni::faiss_wrapper::InsertToIndex(JNIUtil, jniEnv, reinterpret_cast(&insertIds), (jlong)&insertVecs, dim, index_ptr, 0, indexService); + } + knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, (jstring)&indexPath, index_ptr, indexService); +} + TEST(FaissCreateIndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; @@ -63,13 +127,15 @@ TEST(FaissCreateIndexTest, BasicAssertions) { // Create the index std::unique_ptr faissMethods(new FaissMethods()); NiceMock mockIndexService(std::move(faissMethods)); - EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) + int insertions = 10; + EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) + .Times(1); + EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) + .Times(insertions); + EXPECT_CALL(mockIndexService, writeIndex(indexPath, _)) .Times(1); - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong) &vectors, dim , (jstring)&indexPath, - (jobject)¶metersMap, &mockIndexService); + createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &mockIndexService, insertions); } TEST(FaissCreateBinaryIndexTest, BasicAssertions) { @@ -103,14 +169,16 @@ TEST(FaissCreateBinaryIndexTest, BasicAssertions) { // Create the index std::unique_ptr faissMethods(new FaissMethods()); NiceMock mockIndexService(std::move(faissMethods)); - EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) + int insertions = 10; + EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) + .Times(1); + EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) + .Times(insertions); + EXPECT_CALL(mockIndexService, writeIndex(indexPath, _)) .Times(1); // This method calls delete vectors at the end - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong) &vectors, dim , (jstring)&indexPath, - (jobject)¶metersMap, &mockIndexService); + createBinaryIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &mockIndexService, insertions); } TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { @@ -683,10 +751,8 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Create the index std::unique_ptr faissMethods(new FaissMethods()); knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong)&vectors, dim, (jstring)&indexPath, - (jobject)¶metersMap, &IndexService); + + createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &IndexService); // Make sure index can be loaded std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); diff --git a/jni/tests/mocks/faiss_index_service_mock.h b/jni/tests/mocks/faiss_index_service_mock.h index 7af08c82e..285e34053 100644 --- a/jni/tests/mocks/faiss_index_service_mock.h +++ b/jni/tests/mocks/faiss_index_service_mock.h @@ -23,20 +23,37 @@ class MockIndexService : public IndexService { public: MockIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {}; MOCK_METHOD( - void, - createIndex, + long, + initIndex, ( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, std::string indexDescription, + int dim, + int numIds, + int threadCount, + StringToJObjectMap parameters + ), + (override)); + MOCK_METHOD( + void, + insertToIndex, + ( int dim, int numIds, int threadCount, int64_t vectorsAddress, - std::vector ids, + std::vector & ids, + long indexPtr + ), + (override)); + MOCK_METHOD( + void, + writeIndex, + ( std::string indexPath, - StringToJObjectMap parameters + long indexPtr ), (override)); }; diff --git a/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java index 591f16735..4fba5fd5a 100644 --- a/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java +++ b/src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java @@ -9,9 +9,13 @@ import org.apache.commons.lang.StringUtils; import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.indices.ModelUtil.getModelMetadata; + /** * A utility class to extract information from FieldInfo. */ @@ -19,7 +23,21 @@ public class FieldInfoExtractor { /** - * Extract vector data type from fieldInfo + * Extracts KNNEngine from FieldInfo + * @param field {@link FieldInfo} + * @return {@link KNNEngine} + */ + public static KNNEngine extractKNNEngine(final FieldInfo field) { + final ModelMetadata modelMetadata = getModelMetadata(field.attributes().get(MODEL_ID)); + if (modelMetadata != null) { + return modelMetadata.getKnnEngine(); + } + final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); + return KNNEngine.getEngine(engineName); + } + + /** + * Extracts VectorDataType from FieldInfo * @param fieldInfo {@link FieldInfo} * @return {@link VectorDataType} */ diff --git a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java index fd9e5b6c2..778cc164d 100644 --- a/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java +++ b/src/main/java/org/opensearch/knn/common/KNNVectorUtil.java @@ -5,9 +5,13 @@ package org.opensearch.knn.common; -import java.util.Objects; import lombok.AccessLevel; import lombok.NoArgsConstructor; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorUtil { @@ -42,4 +46,33 @@ public static boolean isZeroVector(float[] vector) { } return true; } + + /** + * Converts an integer List to and array + * @param integerList + * @return null if list is null or empty, int[] otherwise + */ + public static int[] intListToArray(final List integerList) { + if (integerList == null || integerList.isEmpty()) { + return null; + } + int[] intArray = new int[integerList.size()]; + for (int i = 0; i < integerList.size(); i++) { + intArray[i] = integerList.get(i); + } + return intArray; + } + + /** + * Iterates vector values once if it is not at start of the location, + * Intended to be done to make sure dimension and bytesPerVector are available + * @param vectorValues + * @throws IOException + */ + public static void iterateVectorValuesOnce(final KNNVectorValues vectorValues) throws IOException { + if (vectorValues.docId() == -1) { + vectorValues.nextDoc(); + vectorValues.getVector(); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 69229036e..8beced605 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -8,8 +8,11 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; import org.opensearch.knn.index.engine.KNNEngine; @@ -17,6 +20,7 @@ import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import java.util.Map; import java.util.Optional; import java.util.function.Function; import java.util.function.Supplier; @@ -78,42 +82,47 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); - var params = knnMethodContext.getMethodComponentContext().getParameters(); + final KNNEngine engine = knnMethodContext.getKnnEngine(); + final Map params = knnMethodContext.getMethodComponentContext().getParameters(); - if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE && params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { - KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( - params, - defaultMaxConnections, - defaultBeamWidth - ); - if (knnScalarQuantizedVectorsFormatParams.validate(params)) { - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"", - field, - MAX_CONNECTIONS, - knnScalarQuantizedVectorsFormatParams.getMaxConnections(), - BEAM_WIDTH, - knnScalarQuantizedVectorsFormatParams.getBeamWidth(), - LUCENE_SQ_CONFIDENCE_INTERVAL, - knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), - LUCENE_SQ_BITS, - knnScalarQuantizedVectorsFormatParams.getBits() + if (engine == KNNEngine.LUCENE) { + if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + params, + defaultMaxConnections, + defaultBeamWidth ); - return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); + if (knnScalarQuantizedVectorsFormatParams.validate(params)) { + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"", + field, + MAX_CONNECTIONS, + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, + knnScalarQuantizedVectorsFormatParams.getBeamWidth(), + LUCENE_SQ_CONFIDENCE_INTERVAL, + knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), + LUCENE_SQ_BITS, + knnScalarQuantizedVectorsFormatParams.getBits() + ); + return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); + } } + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", + field, + MAX_CONNECTIONS, + knnVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, + knnVectorsFormatParams.getBeamWidth() + ); + return vectorsFormatSupplier.apply(knnVectorsFormatParams); } - KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", - field, - MAX_CONNECTIONS, - knnVectorsFormatParams.getMaxConnections(), - BEAM_WIDTH, - knnVectorsFormatParams.getBeamWidth() - ); - return vectorsFormatSupplier.apply(knnVectorsFormatParams); + // All native engines to use NativeEngines990KnnVectorsFormat + return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())); } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 0a0776e83..218c9d891 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -5,76 +5,40 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import lombok.NonNull; import lombok.extern.log4j.Log4j2; -import org.apache.lucene.store.ChecksumIndexInput; import org.opensearch.common.StopWatch; -import org.opensearch.common.xcontent.XContentHelper; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; -import org.opensearch.knn.index.codec.transfer.VectorTransferByte; -import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; -import org.opensearch.knn.jni.JNIService; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.indices.Model; -import org.opensearch.knn.indices.ModelCache; -import org.opensearch.knn.plugin.stats.KNNCounter; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesConsumer; import org.apache.lucene.codecs.DocValuesProducer; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.store.FSDirectory; -import org.apache.lucene.store.FilterDirectory; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.stats.KNNGraphValue; -import java.io.Closeable; import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.nio.file.StandardOpenOption; -import java.security.AccessController; -import java.security.PrivilegedAction; -import java.util.HashMap; -import java.util.Map; -import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; -import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; +import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; /** * This class writes the KNN docvalues to the segments */ @Log4j2 -class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable { +class KNN80DocValuesConsumer extends DocValuesConsumer { private final Logger logger = LogManager.getLogger(KNN80DocValuesConsumer.class); private final DocValuesConsumer delegatee; private final SegmentWriteState state; - private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; - KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) { this.delegatee = delegatee; this.state = state; @@ -86,7 +50,7 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th if (isKNNBinaryFieldRequired(field)) { StopWatch stopWatch = new StopWatch(); stopWatch.start(); - addKNNBinaryField(field, valuesProducer, false, true); + addKNNBinaryField(field, valuesProducer, false); stopWatch.stop(); long time_in_millis = stopWatch.totalTime().millis(); KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); @@ -95,193 +59,21 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th } private boolean isKNNBinaryFieldRequired(FieldInfo field) { - final KNNEngine knnEngine = getKNNEngine(field); + final KNNEngine knnEngine = extractKNNEngine(field); log.debug(String.format("Read engine [%s] for field [%s]", knnEngine.getName(), field.getName())); return field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD) && KNNEngine.getEnginesThatCreateCustomSegmentFiles().stream().anyMatch(engine -> engine == knnEngine); } - private KNNEngine getKNNEngine(@NonNull FieldInfo field) { - final String modelId = field.attributes().get(MODEL_ID); - if (modelId != null) { - var model = ModelCache.getInstance().get(modelId); - return model.getModelMetadata().getKnnEngine(); - } - final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName()); - return KNNEngine.getEngine(engineName); - } - - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) - throws IOException { - // Get values to be indexed - BinaryDocValues values = valuesProducer.getBinary(field); - final KNNEngine knnEngine = getKNNEngine(field); - final String engineFileName = buildEngineFileName( - state.segmentInfo.name, - knnEngine.getVersion(), - field.name, - knnEngine.getExtension() - ); - final String indexPath = Paths.get( - ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName - ).toString(); - - // Determine if we are creating an index from a model or from scratch - NativeIndexCreator indexCreator; - KNNCodecUtil.Pair pair; - Map fieldAttributes = field.attributes(); - VectorDataType vectorDataType; - - if (fieldAttributes.containsKey(MODEL_ID)) { - String modelId = fieldAttributes.get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - vectorDataType = model.getModelMetadata().getVectorDataType(); - pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); - indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); - } else { - // get vector data type from field attributes or provide default value - vectorDataType = VectorDataType.get( - fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) - ); - pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - } - - // Skip index creation if no vectors or docs in segment - if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); - return; - } - - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType); + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) throws IOException { + final VectorDataType vectorDataType = extractVectorDataType(field); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field)); if (isMerge) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); - recordMergeStats(pair.docs.length, arraySize); - } - - // Increment counter for number of graph index requests - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - - if (isRefresh) { - recordRefreshStats(); - } - - // Ensure engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper - state.directory.createOutput(engineFileName, state.context).close(); - indexCreator.createIndex(); - writeFooter(indexPath, engineFileName); - } - - private void recordMergeStats(int length, long arraySize) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); - KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); - KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); - KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length); - KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize); - } - - private void recordRefreshStats() { - KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); - } - - private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map parameters = new HashMap<>(); - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - - IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); - - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndexFromTemplate( - pair.docs, - pair.getVectorAddress(), - pair.getDimension(), - indexPath, - model.getModelBlob(), - parameters, - knnEngine - ); - return null; - }); - } - - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) - throws IOException { - Map parameters = new HashMap<>(); - Map fieldAttributes = fieldInfo.attributes(); - String parametersString = fieldAttributes.get(PARAMETERS); - // parametersString will be null when legacy mapper is used - if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); - - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } - - String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); + NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues); } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); - } - - // In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic. - // After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility, - // we need to ensure that if the description does not contain the prefix but the type is binary, we add the - // description. - maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes); - - // Used to determine how many threads to use when indexing - parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - - // Pass the path for the nms library to save the file - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); - return null; - }); - } - - private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map parameters, Map fieldAttributes) { - if (KNNEngine.FAISS != knnEngine) { - return; - } - - if (!VectorDataType.BINARY.getValue() - .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) { - return; - } - - if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) { - return; - } - - if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) { - return; + NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues); } - - parameters.put( - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() - ); - IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); } /** @@ -300,7 +92,7 @@ public void merge(MergeState mergeState) { if (type == DocValuesType.BINARY && fieldInfo.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)) { StopWatch stopWatch = new StopWatch(); stopWatch.start(); - addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true, false); + addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true); stopWatch.stop(); long time_in_millis = stopWatch.totalTime().millis(); KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.set(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() + time_in_millis); @@ -336,52 +128,4 @@ public void addNumericField(FieldInfo field, DocValuesProducer valuesProducer) t public void close() throws IOException { delegatee.close(); } - - @FunctionalInterface - private interface NativeIndexCreator { - void createIndex() throws IOException; - } - - private void writeFooter(String indexPath, String engineFileName) throws IOException { - // Opens the engine file that was created and appends a footer to it. The footer consists of - // 1. A Footer magic number (int - 4 bytes) - // 2. A checksum algorithm id (int - 4 bytes) - // 3. A checksum (long - bytes) - // The checksum is computed on all the bytes written to the file up to that point. - // Logic where footer is written in Lucene can be found here: - // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412 - OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND); - ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); - byteBuffer.putInt(FOOTER_MAGIC); - byteBuffer.putInt(0); - os.write(byteBuffer.array()); - os.flush(); - - ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context); - checksumIndexInput.seek(checksumIndexInput.length()); - long value = checksumIndexInput.getChecksum(); - checksumIndexInput.close(); - - if (isChecksumValid(value)) { - throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")"); - } - - // Write the CRC checksum to the end of the OutputStream and close the stream - byteBuffer.putLong(0, value); - os.write(byteBuffer.array()); - os.close(); - } - - private boolean isChecksumValid(long value) { - // Check pulled from - // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 - return (value & CRC32_CHECKSUM_SANITY) != 0; - } - - private VectorTransfer getVectorTransfer(VectorDataType vectorDataType) { - if (VectorDataType.BINARY == vectorDataType) { - return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); - } - return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java index e4860af31..1abb84944 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java @@ -30,6 +30,7 @@ */ class NativeEngineFieldVectorsWriter extends KnnFieldVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngineFieldVectorsWriter.class); + @Getter private final FieldInfo fieldInfo; /** * We are using a map here instead of list, because for sampler interface for quantization we have to advance the iterator @@ -77,6 +78,8 @@ public void addValue(int docID, T vectorValue) { + "\" appears more than once in this document (only one value is allowed per field)" ); } + // TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative + // graph build support on the JNI layer. assert docID > lastDocID; vectors.put(docID, vectorValue); docsWithField.add(docID); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index b81ec9789..65736a63e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -12,23 +12,33 @@ package org.opensearch.knn.index.codec.KNN990Codec; import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; + /** * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. */ +@Log4j2 @RequiredArgsConstructor public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class); @@ -46,8 +56,6 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { @Override public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOException { final NativeEngineFieldVectorsWriter newField = NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream); - // TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative - // graph build support on the JNI layer. fields.add(newField); return flatVectorsWriter.addField(fieldInfo, newField); } @@ -62,14 +70,40 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { // simply write data in the flat file flatVectorsWriter.flush(maxDoc, sortMap); - // TODO: add code for creating Vector datastructures during lucene flush operation + for (final NativeEngineFieldVectorsWriter field : fields) { + final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo()); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + vectorDataType, + field.getDocsWithField(), + field.getVectors() + ); + + NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues); + } } @Override public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException { // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - // TODO: add code for creating Vector datastructures during merge operation + + // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + final KNNVectorValues knnVectorValues; + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + final FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); + break; + case BYTE: + final ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); + break; + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + + NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java new file mode 100644 index 000000000..5787ea76b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNVectorUtil.intListToArray; +import static org.opensearch.knn.common.KNNVectorUtil.iterateVectorValuesOnce; +import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; + +/** + * Transfers all vectors to off heap and then builds an index + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +final class DefaultIndexBuildStrategy implements NativeIndexBuildStrategy { + + private static DefaultIndexBuildStrategy INSTANCE = new DefaultIndexBuildStrategy(); + + public static DefaultIndexBuildStrategy getInstance() { + return INSTANCE; + } + + public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { + iterateVectorValuesOnce(knnVectorValues); // to get bytesPerVector + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector()); + try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + + final List tranferredDocIds = new ArrayList<>(); + while (knnVectorValues.docId() != NO_MORE_DOCS) { + // append is true here so off heap memory buffer isn't overwritten + vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true); + tranferredDocIds.add(knnVectorValues.docId()); + knnVectorValues.nextDoc(); + } + vectorTransfer.flush(true); + + final Map params = indexInfo.getParameters(); + long vectorAddress = vectorTransfer.getVectorAddress(); + // Currently this is if else as there are only two cases, with more cases this will have to be made + // more maintainable + if (params.containsKey(MODEL_ID)) { + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndexFromTemplate( + intListToArray(tranferredDocIds), + vectorAddress, + knnVectorValues.dimension(), + indexInfo.getIndexPath(), + (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), + indexInfo.getParameters(), + indexInfo.getKnnEngine() + ); + return null; + }); + } else { + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.createIndex( + intListToArray(tranferredDocIds), + vectorAddress, + knnVectorValues.dimension(), + indexInfo.getIndexPath(), + indexInfo.getParameters(), + indexInfo.getKnnEngine() + ); + return null; + }); + } + // Resetting here as vectors are deleted in JNILayer for non-iterative index builds + vectorTransfer.reset(); + } catch (Exception exception) { + throw new RuntimeException( + "Failed to build index, field name " + indexInfo.getFieldName() + ", parameters " + indexInfo, + exception + ); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java new file mode 100644 index 000000000..af80215b6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.jni.JNIService; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.KNNVectorUtil.intListToArray; +import static org.opensearch.knn.common.KNNVectorUtil.iterateVectorValuesOnce; +import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; + +/** + * Iteratively builds the index. Iterative builds are memory optimized as it does not require all vectors + * to be transferred. It transfers vectors in small batches, builds index and can clear the offheap space where + * the vectors were transferred + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +final class MemOptimizedNativeIndexBuildStrategy implements NativeIndexBuildStrategy { + + private static MemOptimizedNativeIndexBuildStrategy INSTANCE = new MemOptimizedNativeIndexBuildStrategy(); + + public static MemOptimizedNativeIndexBuildStrategy getInstance() { + return INSTANCE; + } + + public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { + // Needed to make sure we dont get 0 dimensions while initializing index + iterateVectorValuesOnce(knnVectorValues); + KNNEngine engine = indexInfo.getKnnEngine(); + Map indexParameters = indexInfo.getParameters(); + + // Initialize the index + long indexMemoryAddress = AccessController.doPrivileged( + (PrivilegedAction) () -> JNIService.initIndex( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + indexParameters, + engine + ) + ); + + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector()); + try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + + final List tranferredDocIds = new ArrayList<>(transferLimit); + while (knnVectorValues.docId() != NO_MORE_DOCS) { + // append is false to be able to reuse the memory location + boolean transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false); + tranferredDocIds.add(knnVectorValues.docId()); + if (transferred) { + // Insert vectors + long vectorAddress = vectorTransfer.getVectorAddress(); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.insertToIndex( + intListToArray(tranferredDocIds), + vectorAddress, + knnVectorValues.dimension(), + indexParameters, + indexMemoryAddress, + engine + ); + return null; + }); + tranferredDocIds.clear(); + } + knnVectorValues.nextDoc(); + } + + boolean flush = vectorTransfer.flush(false); + // Need to make sure that the flushed vectors are indexed + if (flush) { + long vectorAddress = vectorTransfer.getVectorAddress(); + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.insertToIndex( + intListToArray(tranferredDocIds), + vectorAddress, + knnVectorValues.dimension(), + indexParameters, + indexMemoryAddress, + engine + ); + return null; + }); + tranferredDocIds.clear(); + } + + // Write vector + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.writeIndex(indexInfo.getIndexPath(), indexMemoryAddress, engine, indexParameters); + return null; + }); + + } catch (Exception exception) { + throw new RuntimeException( + "Failed to build index, field name [" + indexInfo.getFieldName() + "], parameters " + indexInfo, + exception + ); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java new file mode 100644 index 000000000..19475adfa --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.IOException; + +/** + * Interface which dictates how the index needs to be built + */ +public interface NativeIndexBuildStrategy { + + void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java new file mode 100644 index 000000000..61500371b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -0,0 +1,298 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.FilterDirectory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.util.IndexUtil; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.indices.Model; +import org.opensearch.knn.indices.ModelCache; +import org.opensearch.knn.plugin.stats.KNNGraphValue; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNVectorUtil.iterateVectorValuesOnce; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; +import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + +/** + * Writes KNN Index for a field in a segment. This is intended to be used for native engines + */ +@AllArgsConstructor +@Log4j2 +public class NativeIndexWriter { + private static final Long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L; + + private final SegmentWriteState state; + private final FieldInfo fieldInfo; + private final NativeIndexBuildStrategy indexBuilder; + + /** + * Gets the correct writer type from fieldInfo + * + * @param fieldInfo + * @return correct NativeIndexWriter to make index specified in fieldInfo + */ + public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWriteState state) { + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); + boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; + if (iterative) { + return new NativeIndexWriter(state, fieldInfo, MemOptimizedNativeIndexBuildStrategy.getInstance()); + } + return new NativeIndexWriter(state, fieldInfo, DefaultIndexBuildStrategy.getInstance()); + } + + /** + * flushes the index + * + * @param knnVectorValues + * @throws IOException + */ + public void flushIndex(final KNNVectorValues knnVectorValues) throws IOException { + iterateVectorValuesOnce(knnVectorValues); + buildAndWriteIndex(knnVectorValues); + recordRefreshStats(); + } + + /** + * Merges kNN index + * @param knnVectorValues + * @throws IOException + */ + public void mergeIndex(final KNNVectorValues knnVectorValues) throws IOException { + iterateVectorValuesOnce(knnVectorValues); + if (knnVectorValues.docId() == NO_MORE_DOCS) { + // This is in place so we do not add metrics + log.debug("Skipping mergeIndex, vector values are already iterated for {}", fieldInfo.name); + return; + } + + long bytesPerVector = knnVectorValues.bytesPerVector(); + startMergeStats((int) knnVectorValues.totalLiveDocs(), bytesPerVector); + buildAndWriteIndex(knnVectorValues); + endMergeStats((int) knnVectorValues.totalLiveDocs(), bytesPerVector); + } + + private void buildAndWriteIndex(final KNNVectorValues knnVectorValues) throws IOException { + if (knnVectorValues.totalLiveDocs() == 0) { + log.debug("No live docs for field " + fieldInfo.name); + return; + } + + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + final String engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getVersion(), + fieldInfo.name, + knnEngine.getExtension() + ); + final String indexPath = Paths.get( + ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), + engineFileName + ).toString(); + state.directory.createOutput(engineFileName, state.context).close(); + + final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine); + indexBuilder.buildAndWriteIndex(nativeIndexParams, knnVectorValues); + writeFooter(indexPath, engineFileName, state); + } + + // The logic for building parameters need to be cleaned up. There are various cases handled here + // Currently it falls under two categories - with model and without model. Without model is further divided based on vector data type + // TODO: Refactor this so its scalable. Possibly move it out of this class + private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException { + final Map parameters; + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + if (fieldInfo.attributes().containsKey(MODEL_ID)) { + Model model = getModel(fieldInfo); + parameters = getTemplateParameters(fieldInfo, model); + } else { + parameters = getParameters(fieldInfo, vectorDataType, knnEngine); + } + + return BuildIndexParams.builder() + .fieldName(fieldInfo.name) + .parameters(parameters) + .vectorDataType(vectorDataType) + .knnEngine(knnEngine) + .indexPath(indexPath) + .build(); + } + + private Map getParameters(FieldInfo fieldInfo, VectorDataType vectorDataType, KNNEngine knnEngine) throws IOException { + Map parameters = new HashMap<>(); + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + Map algoParams = new HashMap<>(); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); + } else { + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + } + + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + // In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic. + // After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility, + // we need to ensure that if the description does not contain the prefix but the type is binary, we add the + // description. + maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes); + + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + return parameters; + } + + private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map parameters, Map fieldAttributes) { + if (KNNEngine.FAISS != knnEngine) { + return; + } + + if (!VectorDataType.BINARY.getValue() + .equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) { + return; + } + + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) { + return; + } + + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) { + return; + } + + parameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY); + } + + private Map getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID)); + parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob()); + IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType()); + return parameters; + } + + private Model getModel(FieldInfo fieldInfo) { + String modelId = fieldInfo.attributes().get(MODEL_ID); + Model model = ModelCache.getInstance().get(modelId); + if (model.getModelBlob() == null) { + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); + } + return model; + } + + private void startMergeStats(int numDocs, long bytesPerVector) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); + KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(bytesPerVector); + KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment(); + KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(numDocs); + KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(bytesPerVector); + } + + private void endMergeStats(int numDocs, long arraySize) { + KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); + KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(numDocs); + KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize); + } + + private void recordRefreshStats() { + KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); + } + + private boolean isChecksumValid(long value) { + // Check pulled from + // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 + return (value & CRC32_CHECKSUM_SANITY) != 0; + } + + private void writeFooter(String indexPath, String engineFileName, SegmentWriteState state) throws IOException { + // Opens the engine file that was created and appends a footer to it. The footer consists of + // 1. A Footer magic number (int - 4 bytes) + // 2. A checksum algorithm id (int - 4 bytes) + // 3. A checksum (long - bytes) + // The checksum is computed on all the bytes written to the file up to that point. + // Logic where footer is written in Lucene can be found here: + // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412 + OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND); + ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); + byteBuffer.putInt(FOOTER_MAGIC); + byteBuffer.putInt(0); + os.write(byteBuffer.array()); + os.flush(); + + ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context); + checksumIndexInput.seek(checksumIndexInput.length()); + long value = checksumIndexInput.getChecksum(); + checksumIndexInput.close(); + + if (isChecksumValid(value)) { + throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")"); + } + + // Write the CRC checksum to the end of the OutputStream and close the stream + byteBuffer.putLong(0, value); + os.write(byteBuffer.array()); + os.close(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java new file mode 100644 index 000000000..af43ff37e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.model; + +import lombok.Builder; +import lombok.ToString; +import lombok.Value; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNEngine; + +import java.util.Map; + +@Value +@Builder +@ToString +public class BuildIndexParams { + String fieldName; + KNNEngine knnEngine; + String indexPath; + VectorDataType vectorDataType; + Map parameters; + // TODO: Add quantization state as parameter to build index +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java new file mode 100644 index 000000000..c9d4802fe --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapBinaryVectorTransfer.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer quantized binary vectors to off heap memory + * The reason this is different from {@link OffHeapByteVectorTransfer} is because of allocation and deallocation + * of memory on JNI layer. Use this if unsigned int is needed on JNI layer + */ +public final class OffHeapBinaryVectorTransfer extends OffHeapVectorTransfer { + + public OffHeapBinaryVectorTransfer(int transferLimit) { + super(transferLimit); + } + + @Override + public void deallocate() { + // TODO: deallocate the memory location + } + + @Override + protected long transfer(List vectorsToTransfer, boolean append) throws IOException { + // TODO: call to JNIService to transfer vector + return 0L; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java new file mode 100644 index 000000000..83ebf2fa3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapByteVectorTransfer.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.opensearch.knn.jni.JNICommons; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer quantized byte vectors to off heap memory. + * The reason this is different from {@link OffHeapBinaryVectorTransfer} is because of allocation and deallocation + * of memory on JNI layer. Use this if signed int is needed on JNI layer + */ +public final class OffHeapByteVectorTransfer extends OffHeapVectorTransfer { + + public OffHeapByteVectorTransfer(int transferLimit) { + super(transferLimit); + } + + @Override + protected long transfer(List batch, boolean append) throws IOException { + return JNICommons.storeByteVectorData( + getVectorAddress(), + batch.toArray(new byte[][] {}), + (long) batch.get(0).length * transferLimit, + append + ); + } + + @Override + public void deallocate() { + JNICommons.freeByteVectorData(getVectorAddress()); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java new file mode 100644 index 000000000..0eb28d791 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapFloatVectorTransfer.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.opensearch.knn.jni.JNICommons; + +import java.io.IOException; +import java.util.List; + +/** + * Transfer float vectors to off heap memory. + */ +public final class OffHeapFloatVectorTransfer extends OffHeapVectorTransfer { + + public OffHeapFloatVectorTransfer(int transferLimit) { + super(transferLimit); + } + + @Override + protected long transfer(final List vectorsToTransfer, boolean append) throws IOException { + return JNICommons.storeVectorData( + getVectorAddress(), + vectorsToTransfer.toArray(new float[][] {}), + (long) vectorsToTransfer.get(0).length * this.transferLimit, + append + ); + } + + @Override + public void deallocate() { + JNICommons.freeVectorData(getVectorAddress()); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java new file mode 100644 index 000000000..43c27c8da --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransfer.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.Getter; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + *

+ * The class is intended to transfer {@link KNNVectorValues} to off heap memory. + *

+ *

+ * The class is not thread safe. + *

+ * + * @param byte[] or float[] + */ +public abstract class OffHeapVectorTransfer implements Closeable { + + @Getter + private long vectorAddress; + protected final int transferLimit; + + private final List vectorsToTransfer; + + public OffHeapVectorTransfer(final int transferLimit) { + this.transferLimit = transferLimit; + this.vectorsToTransfer = new ArrayList<>(transferLimit); + this.vectorAddress = 0; + } + + /** + * Transfer vectors to off-heap + * @param vector float[] or byte[] + * @param append This indicates whether to append or rewrite the off-heap buffer + * @return true of the vectors were transferred, false if not + * @throws IOException + */ + public boolean transfer(T vector, boolean append) throws IOException { + vectorsToTransfer.add(vector); + if (vectorsToTransfer.size() == this.transferLimit) { + vectorAddress = transfer(vectorsToTransfer, append); + vectorsToTransfer.clear(); + return true; + } + return false; + } + + /** + * Empties the {@link #vectorsToTransfer} if its not empty. Intended to be used before + * closing the transfer + * + * @param append This indicates whether to append or rewrite the off-heap buffer + * @return true of the vectors were transferred, false if not + * @throws IOException + */ + public boolean flush(boolean append) throws IOException { + // flush before closing + if (!vectorsToTransfer.isEmpty()) { + vectorAddress = transfer(vectorsToTransfer, append); + vectorsToTransfer.clear(); + return true; + } + return false; + } + + @Override + public void close() { + // Remove this if condition once create and write index is separated for nmslib + if (vectorAddress != 0) { + deallocate(); + } + reset(); + } + + /** + * Resets address and vectortoTransfer + * + * DO NOT USE this in the middle of the transfer, The behavior is undefined + * + * TODO: Make it package private once create and write index is separated for nmslib + */ + public void reset() { + vectorAddress = 0; + vectorsToTransfer.clear(); + } + + protected abstract void deallocate(); + + protected abstract long transfer(final List vectorsToTransfer, boolean append) throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java new file mode 100644 index 000000000..bfcc13491 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactory.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.VectorDataType; + +/** + * Factory to get the right implementation of vector transfer + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class OffHeapVectorTransferFactory { + + /** + * Gets the right vector transfer object based on vector data type + * @param vectorDataType {@link VectorDataType} + * @param transferLimit max number of vectors that can be transferred to off heap in one transfer + * @return Correct implementation of {@link OffHeapVectorTransfer} + * @param float[] or byte[] + */ + public static OffHeapVectorTransfer getVectorTransfer(final VectorDataType vectorDataType, final int transferLimit) { + switch (vectorDataType) { + case FLOAT: + return (OffHeapVectorTransfer) new OffHeapFloatVectorTransfer(transferLimit); + case BINARY: + // TODO: Add binary here + case BYTE: + return (OffHeapVectorTransfer) new OffHeapByteVectorTransfer(transferLimit); + default: + throw new IllegalArgumentException("Unsupported vector data type: " + vectorDataType); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java deleted file mode 100644 index c23bd4317..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransfer.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import lombok.Data; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.SerializationMode; - -/** - * Abstract class to transfer vector value from Java to native memory - */ -@Data -public abstract class VectorTransfer { - protected final long vectorsStreamingMemoryLimit; - protected long totalLiveDocs; - protected long vectorsPerTransfer; - protected long vectorAddress; - protected int dimension; - - public VectorTransfer(final long vectorsStreamingMemoryLimit) { - this.vectorsStreamingMemoryLimit = vectorsStreamingMemoryLimit; - this.vectorsPerTransfer = Integer.MIN_VALUE; - } - - /** - * Initialize the transfer - * - * @param totalLiveDocs total number of vectors to be transferred - */ - abstract public void init(final long totalLiveDocs); - - /** - * Transfer a single vector - * - * @param bytesRef a vector in bytes format - */ - abstract public void transfer(final BytesRef bytesRef); - - /** - * Close the transfer - */ - abstract public void close(); - - /** - * Get serialization mode of given byte stream - * - * @param bytesRef bytes of a vector - * @return serialization mode - */ - abstract public SerializationMode getSerializationMode(final BytesRef bytesRef); -} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java deleted file mode 100644 index e81ac35fc..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import org.apache.lucene.util.ArrayUtil; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.util.ArrayList; -import java.util.List; - -/** - * Vector transfer for byte - */ -public class VectorTransferByte extends VectorTransfer { - private List vectorList; - - public VectorTransferByte(final long vectorsStreamingMemoryLimit) { - super(vectorsStreamingMemoryLimit); - vectorList = new ArrayList<>(); - } - - @Override - public void init(final long totalLiveDocs) { - this.totalLiveDocs = totalLiveDocs; - vectorList.clear(); - } - - @Override - public void transfer(final BytesRef bytesRef) { - dimension = bytesRef.length * 8; - if (vectorsPerTransfer == Integer.MIN_VALUE) { - // if vectorsStreamingMemoryLimit is 100 bytes and we have 50 vectors with length of 5, then per - // transfer we have to send 100/5 => 20 vectors. - vectorsPerTransfer = vectorsStreamingMemoryLimit / bytesRef.length; - // If vectorsPerTransfer comes out to be 0, then we set number of vectors per transfer to 1, to ensure that - // we are sending minimum number of vectors. - if (vectorsPerTransfer == 0) { - vectorsPerTransfer = 1; - } - } - - vectorList.add(ArrayUtil.copyOfSubArray(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length)); - if (vectorList.size() == vectorsPerTransfer) { - transfer(); - } - } - - @Override - public void close() { - transfer(); - } - - @Override - public SerializationMode getSerializationMode(final BytesRef bytesRef) { - return SerializationMode.COLLECTIONS_OF_BYTES; - } - - private void transfer() { - int lengthOfVector = dimension / 8; - vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * lengthOfVector); - vectorList.clear(); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java deleted file mode 100644 index a9c792398..000000000 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.KNNVectorSerializer; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.util.ArrayList; -import java.util.List; - -/** - * Vector transfer for float - */ -public class VectorTransferFloat extends VectorTransfer { - private List vectorList; - - public VectorTransferFloat(final long vectorsStreamingMemoryLimit) { - super(vectorsStreamingMemoryLimit); - vectorList = new ArrayList<>(); - } - - @Override - public void init(final long totalLiveDocs) { - this.totalLiveDocs = totalLiveDocs; - vectorList.clear(); - } - - @Override - public void transfer(final BytesRef bytesRef) { - final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(bytesRef); - final float[] vector = vectorSerializer.byteToFloatArray(bytesRef); - dimension = vector.length; - - if (vectorsPerTransfer == Integer.MIN_VALUE) { - // if vectorsStreamingMemoryLimit is 100 bytes and we have 50 vectors with 5 dimension, then per - // transfer we have to send 100/(5 * 4) => 5 vectors. - vectorsPerTransfer = vectorsStreamingMemoryLimit / ((long) dimension * Float.BYTES); - // If vectorsPerTransfer comes out to be 0, then we set number of vectors per transfer to 1, to ensure that - // we are sending minimum number of vectors. - if (vectorsPerTransfer == 0) { - vectorsPerTransfer = 1; - } - } - - vectorList.add(vector); - if (vectorList.size() == vectorsPerTransfer) { - transfer(); - } - } - - @Override - public void close() { - transfer(); - } - - @Override - public SerializationMode getSerializationMode(final BytesRef bytesRef) { - return KNNVectorSerializerFactory.getSerializerModeFromBytesRef(bytesRef); - } - - private void transfer() { - vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); - vectorList.clear(); - } -} diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index ea14fe883..51100a1e0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -5,63 +5,14 @@ package org.opensearch.knn.index.codec.util; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; public class KNNCodecUtil { // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; - @AllArgsConstructor - public static final class Pair { - public int[] docs; - @Getter - @Setter - private long vectorAddress; - @Getter - @Setter - private int dimension; - public SerializationMode serializationMode; - } - - /** - * Extract docIds and vectors from binary doc values. - * - * @param values Binary doc values - * @param vectorTransfer Utility to make transfer - * @return KNNCodecUtil.Pair representing doc ids and corresponding vectors - * @throws IOException thrown when unable to get binary of vectors - */ - public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final VectorTransfer vectorTransfer) throws IOException { - List docIdList = new ArrayList<>(); - SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; - vectorTransfer.init(getTotalLiveDocsCount(values)); - for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { - BytesRef bytesref = values.binaryValue(); - serializationMode = vectorTransfer.getSerializationMode(bytesref); - vectorTransfer.transfer(bytesref); - docIdList.add(doc); - } - vectorTransfer.close(); - return new KNNCodecUtil.Pair( - docIdList.stream().mapToInt(Integer::intValue).toArray(), - vectorTransfer.getVectorAddress(), - vectorTransfer.getDimension(), - serializationMode - ); - } - /** * This method provides a rough estimate of the number of bytes used for storing an array with the given parameters. * @param numVectors number of vectors in the array diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java index f38099b74..5da093fd5 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNBinaryVectorValues.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.ByteVectorValues; import java.io.IOException; +import java.util.Arrays; /** * Concrete implementation of {@link KNNVectorValues} that returns byte[] as vector where binary vector is stored and @@ -25,17 +26,17 @@ public class KNNBinaryVectorValues extends KNNVectorValues { @Override public byte[] getVector() throws IOException { final byte[] vector = VectorValueExtractorStrategy.extractBinaryVector(vectorValuesIterator); - this.dimension = vector.length; + this.dimension = vector.length * Byte.SIZE; + this.bytesPerVector = vector.length; return vector; } - /** - * Binary Vector values gets stored as byte[], hence for dimension of the binary vector we have to multiply the - * byte[] size with {@link Byte#SIZE} - * @return int - */ @Override - public int dimension() { - return super.dimension() * Byte.SIZE; + public byte[] conditionalCloneVector() throws IOException { + byte[] vector = getVector(); + if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + return Arrays.copyOf(vector, vector.length); + } + return vector; } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java index ccbbfab77..1ebc50970 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNByteVectorValues.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.ByteVectorValues; import java.io.IOException; +import java.util.Arrays; /** * Concrete implementation of {@link KNNVectorValues} that returns float[] as vector and provides an abstraction over @@ -26,6 +27,17 @@ public class KNNByteVectorValues extends KNNVectorValues { public byte[] getVector() throws IOException { final byte[] vector = VectorValueExtractorStrategy.extractByteVector(vectorValuesIterator); this.dimension = vector.length; + this.bytesPerVector = vector.length; + return vector; + } + + @Override + public byte[] conditionalCloneVector() throws IOException { + byte[] vector = getVector(); + if (vectorValuesIterator.getDocIdSetIterator() instanceof ByteVectorValues) { + return Arrays.copyOf(vector, vector.length); + + } return vector; } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java index 174f3a89e..dffdd8f0d 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNFloatVectorValues.java @@ -10,6 +10,7 @@ import org.apache.lucene.index.FloatVectorValues; import java.io.IOException; +import java.util.Arrays; /** * Concrete implementation of {@link KNNVectorValues} that returns float[] as vector and provides an abstraction over @@ -24,6 +25,16 @@ public class KNNFloatVectorValues extends KNNVectorValues { public float[] getVector() throws IOException { final float[] vector = VectorValueExtractorStrategy.extractFloatVector(vectorValuesIterator); this.dimension = vector.length; + this.bytesPerVector = vector.length * 4; + return vector; + } + + @Override + public float[] conditionalCloneVector() throws IOException { + float[] vector = getVector(); + if (vectorValuesIterator.getDocIdSetIterator() instanceof FloatVectorValues) { + return Arrays.copyOf(vector, vector.length); + } return vector; } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java index c4ed64bc2..56ebd208f 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java @@ -23,6 +23,7 @@ public abstract class KNNVectorValues { protected final KNNVectorValuesIterator vectorValuesIterator; protected int dimension; + protected int bytesPerVector; protected KNNVectorValues(final KNNVectorValuesIterator vectorValuesIterator) { this.vectorValuesIterator = vectorValuesIterator; @@ -37,6 +38,20 @@ protected KNNVectorValues(final KNNVectorValuesIterator vectorValuesIterator) { */ public abstract T getVector() throws IOException; + /** + * Intended to return a vector reference either after deep copy of the vector obtained from {@code getVector} + * or return the vector itself. + *

+ * This decision to clone depends on the vector returned based on the type of iterator + *

+ * Running this function can incur latency hence should be absolutely used when necessary. + * For most of the cases {@link #getVector()} function should work. + * + * @return T an array of byte[], float[] Or a deep copy of it + * @throws IOException + */ + public abstract T conditionalCloneVector() throws IOException; + /** * Dimension of vector is returned. Do call getVector function first before calling this function otherwise you will get 0 value. * @return int @@ -46,6 +61,15 @@ public int dimension() { return dimension; } + /** + * Size of a vector in bytes is returned. Do call getVector function first before calling this function otherwise you will get 0 value. + * @return int + */ + public int bytesPerVector() { + assert docId() != -1 && bytesPerVector != 0 : "Cannot get bytesPerVector before we retrieve a vector from KNNVectorValues"; + return bytesPerVector; + } + /** * Returns the total live docs for KNNVectorValues. * @return long @@ -81,5 +105,4 @@ public int advance(int docId) throws IOException { public int nextDoc() throws IOException { return vectorValuesIterator.nextDoc(); } - } diff --git a/src/main/java/org/opensearch/knn/indices/ModelUtil.java b/src/main/java/org/opensearch/knn/indices/ModelUtil.java index 0f5a049fc..ac0e4fb79 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelUtil.java +++ b/src/main/java/org/opensearch/knn/indices/ModelUtil.java @@ -11,6 +11,7 @@ package org.opensearch.knn.indices; +import lombok.experimental.UtilityClass; import org.apache.commons.lang.StringUtils; import java.util.Locale; @@ -18,6 +19,7 @@ /** * A utility class for models. */ +@UtilityClass public class ModelUtil { public static void blockCommasInModelDescription(String description) { @@ -48,7 +50,7 @@ public static ModelMetadata getModelMetadata(final String modelId) { } final Model model = ModelCache.getInstance().get(modelId); final ModelMetadata modelMetadata = model.getModelMetadata(); - if (ModelUtil.isModelCreated(modelMetadata) == false) { + if (isModelCreated(modelMetadata) == false) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); } return modelMetadata; diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 4f57b616a..a402be1f3 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -50,32 +50,70 @@ class FaissService { } /** - * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the + * Initialize an index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + */ + public static native long initIndex(long numDocs, int dim, Map parameters); + + /** + * Initialize an index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + */ + public static native long initBinaryIndex(long numDocs, int dim, Map parameters); + + /** + * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer - * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this - * issue + * created the memory address and that should only free up the memory. * - * @param ids array of ids mapping to the data passed in + * @param ids ids of documents * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to - * @param parameters parameters to build index + * @param indexAddress address of native memory where index is stored + * @param threadCount number of threads to use for insertion */ - public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); + public static native void insertToIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); /** - * Create a binary index for the native library The memory occupied by the vectorsAddress will be freed up during the + * Inserts to a faiss index. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer - * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this - * issue + * created the memory address and that should only free up the memory. * - * @param ids array of ids mapping to the data passed in + * @param ids ids of documents * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed + * @param indexAddress address of native memory where index is stored + * @param threadCount number of threads to use for insertion + */ + public static native void insertToBinaryIndex(int[] ids, long vectorsAddress, int dim, long indexAddress, int threadCount); + + /** + * Writes a faiss index. + * + * NOTE: This will always free the index. Do not call free after this. + * + * @param indexAddress address of native memory where index is stored + * @param indexPath path to save index file to + */ + public static native void writeIndex(long indexAddress, String indexPath); + + /** + * Writes a faiss index. + * + * NOTE: This will always free the index. Do not call free after this. + * + * @param indexAddress address of native memory where index is stored * @param indexPath path to save index file to - * @param parameters parameters to build index */ - public static native void createBinaryIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); + public static native void writeBinaryIndex(long indexAddress, String indexPath); /** * Create an index for the native library with a provided template index diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java index 31a8f43cc..c7222738e 100644 --- a/src/main/java/org/opensearch/knn/jni/JNICommons.java +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -36,16 +36,59 @@ public class JNICommons { * will throw Exception. * *

- * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can - * lead to data corruption. + * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. *

* - * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D float array containing data to be stored in native memory. + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. * @return memory address where the data is stored. */ - public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity); + public static long storeVectorData(long memoryAddress, float[][] data, long initialCapacity) { + return storeVectorData(memoryAddress, data, initialCapacity, true); + } + + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @param append append the data or rewrite the memory location + * @return memory address where the data is stored. + */ + public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity, boolean append); + + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + public static long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity) { + return storeByteVectorData(memoryAddress, data, initialCapacity, true); + } /** * This is utility function that can be used to store data in native memory. This function will allocate memory for @@ -55,24 +98,25 @@ public class JNICommons { * will throw Exception. * *

- * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can - * lead to data corruption. + * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. *

* - * @param memoryAddress The address of the memory location where data will be stored. - * @param data 2D byte array containing data to be stored in native memory. + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. + * @param append append the data or rewrite the memory location * @return memory address where the data is stored. */ - public static native long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity); + public static native long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity, boolean append); /** * Free up the memory allocated for the data stored in memory address. This function should be used with the memory - * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * address returned by {@link JNICommons#storeVectorData(long, float[][], long, boolean)} * *

- * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can - * lead to errors. + * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. *

* * @param memoryAddress address to be freed. @@ -81,11 +125,11 @@ public class JNICommons { /** * Free up the memory allocated for the byte data stored in memory address. This function should be used with the memory - * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * address returned by {@link JNICommons#storeVectorData(long, float[][], long, boolean)} * *

- * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can - * lead to errors. + * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. *

* * @param memoryAddress address to be freed. diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index de696b5ce..d1d5f6c11 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -13,28 +13,110 @@ import org.apache.commons.lang.ArrayUtils; import org.opensearch.common.Nullable; -import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.query.KNNQueryResult; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryResult; +import org.opensearch.knn.index.util.IndexUtil; +import java.util.Locale; import java.util.Map; /** * Service to distribute requests to the proper engine jni service */ public class JNIService { + /** + * Initialize an index for the native library. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + * @param knnEngine knn engine + * @return address of the index in memory + */ + public static long initIndex(long numDocs, int dim, Map parameters, KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + return FaissService.initBinaryIndex(numDocs, dim, parameters); + } else { + return FaissService.initIndex(numDocs, dim, parameters); + } + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "initIndexFromScratch not supported for provided engine : %s", knnEngine.getName()) + ); + } + + /** + * Inserts to a faiss index. + * + * @param docs ids of documents + * @param vectorsAddress address of native memory where vectors are stored + * @param dimension dimension of the vector to be indexed + * @param parameters parameters to build index + * @param indexAddress address of native memory where index is stored + * @param knnEngine knn engine + */ + public static void insertToIndex( + int[] docs, + long vectorsAddress, + int dimension, + Map parameters, + long indexAddress, + KNNEngine knnEngine + ) { + int threadCount = (int) parameters.getOrDefault(KNNConstants.INDEX_THREAD_QTY, 0); + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.insertToBinaryIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); + } else { + FaissService.insertToIndex(docs, vectorsAddress, dimension, indexAddress, threadCount); + } + return; + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "insertToIndex not supported for provided engine : %s", knnEngine.getName()) + ); + } + + /** + * Writes a faiss index to disk. + * + * @param indexPath path to save index to + * @param indexAddress address of native memory where index is stored + * @param knnEngine knn engine + * @param parameters parameters to build index + */ + public static void writeIndex(String indexPath, long indexAddress, KNNEngine knnEngine, Map parameters) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.writeBinaryIndex(indexAddress, indexPath); + } else { + FaissService.writeIndex(indexAddress, indexPath); + } + return; + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "writeIndex not supported for provided engine : %s", knnEngine.getName()) + ); + } + /** * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this * issue * - * @param ids array of ids mapping to the data passed in + * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored - * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to - * @param parameters parameters to build index - * @param knnEngine engine to build index for + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param parameters parameters to build index + * @param knnEngine engine to build index for */ public static void createIndex( int[] ids, @@ -50,28 +132,21 @@ public static void createIndex( return; } - if (KNNEngine.FAISS == knnEngine) { - if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { - FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters); - } else { - FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); - } - return; - } - - throw new IllegalArgumentException(String.format("CreateIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "CreateIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** * Create an index for the native library with a provided template index * - * @param ids array of ids mapping to the data passed in + * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored - * @param dim dimension of vectors to be indexed - * @param indexPath path to save index file to - * @param templateIndex empty template index - * @param parameters parameters to build index - * @param knnEngine engine to build index for + * @param dim dimension of vectors to be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters parameters to build index + * @param knnEngine engine to build index for */ public static void createIndexFromTemplate( int[] ids, @@ -93,7 +168,7 @@ public static void createIndexFromTemplate( } throw new IllegalArgumentException( - String.format("CreateIndexFromTemplate not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "CreateIndexFromTemplate not supported for provided engine : %s", knnEngine.getName()) ); } @@ -118,7 +193,9 @@ public static long loadIndex(String indexPath, Map parameters, K } } - throw new IllegalArgumentException(String.format("LoadIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "LoadIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -150,7 +227,7 @@ public static long initSharedIndexState(long indexAddr, KNNEngine knnEngine) { return FaissService.initSharedIndexState(indexAddr); } throw new IllegalArgumentException( - String.format("InitSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "InitSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } @@ -168,20 +245,20 @@ public static void setSharedIndexState(long indexAddr, long shareIndexStateAddr, } throw new IllegalArgumentException( - String.format("SetSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "SetSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } /** * Query an index * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param methodParameters method parameter - * @param knnEngine engine to query index - * @param filteredIds array of ints on which should be used for search. - * @param filterIdsType how to filter ids: Batch or BitMap + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryIndex( @@ -216,19 +293,21 @@ public static KNNQueryResult[] queryIndex( } return FaissService.queryIndex(indexPointer, queryVector, k, methodParameters, parentIds); } - throw new IllegalArgumentException(String.format("QueryIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "QueryIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** * Query a binary index * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param methodParameters method parameter - * @param knnEngine engine to query index - * @param filteredIds array of ints on which should be used for search. - * @param filterIdsType how to filter ids: Batch or BitMap + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryBinaryIndex( @@ -252,7 +331,9 @@ public static KNNQueryResult[] queryBinaryIndex( parentIds ); } - throw new IllegalArgumentException(String.format("QueryBinaryIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "QueryBinaryIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** @@ -283,7 +364,7 @@ public static void free(final long indexPointer, final KNNEngine knnEngine, fina return; } - throw new IllegalArgumentException(String.format("Free not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException(String.format(Locale.ROOT, "Free not supported for provided engine : %s", knnEngine.getName())); } /** @@ -298,7 +379,7 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE return; } throw new IllegalArgumentException( - String.format("FreeSharedIndexState not supported for provided engine : %s", knnEngine.getName()) + String.format(Locale.ROOT, "FreeSharedIndexState not supported for provided engine : %s", knnEngine.getName()) ); } @@ -319,17 +400,19 @@ public static byte[] trainIndex(Map indexParameters, int dimensi return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); } - throw new IllegalArgumentException(String.format("TrainIndex not supported for provided engine : %s", knnEngine.getName())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "TrainIndex not supported for provided engine : %s", knnEngine.getName()) + ); } /** *

- * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} + * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long, boolean)} *

* Transfer vectors from Java to native * * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well - * @param trainingData data to be transferred + * @param trainingData data to be transferred * @return pointer to native memory location of training data */ @Deprecated(since = "2.14.0", forRemoval = true) @@ -340,15 +423,15 @@ public static long transferVectors(long vectorsPointer, float[][] trainingData) /** * Range search index for a given query vector * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param radius search within radius threshold - * @param methodParameters parameters to be used when loading index - * @param knnEngine engine to query index + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param radius search within radius threshold + * @param methodParameters parameters to be used when loading index + * @param knnEngine engine to query index * @param indexMaxResultWindow maximum number of results to return - * @param filteredIds list of doc ids to include in the query result - * @param filterIdsType how to filter ids: Batch or BitMap - * @param parentIds parent ids of the vectors + * @param filteredIds list of doc ids to include in the query result + * @param filterIdsType how to filter ids: Batch or BitMap + * @param parentIds parent ids of the vectors * @return KNNQueryResult array of neighbors within radius */ public static KNNQueryResult[] radiusQueryIndex( @@ -377,6 +460,6 @@ public static KNNQueryResult[] radiusQueryIndex( } return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, methodParameters, indexMaxResultWindow, parentIds); } - throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine"); + throw new IllegalArgumentException(String.format(Locale.ROOT, "RadiusQueryIndex not supported for provided engine")); } } diff --git a/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java b/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java index df529bb47..27aedd1d0 100644 --- a/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java +++ b/src/test/java/org/opensearch/knn/common/FieldInfoExtractorTests.java @@ -14,6 +14,8 @@ import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; +import static org.mockito.Mockito.when; + public class FieldInfoExtractorTests extends KNNTestCase { private static final String MODEL_ID = "model_id"; @@ -39,4 +41,26 @@ public void testExtractVectorDataType_whenDifferentConditions_thenSuccess() { Assert.assertEquals(VectorDataType.BYTE, FieldInfoExtractor.extractVectorDataType(fieldInfo)); } } + + public void testExtractVectorDataType() { + FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + when(fieldInfo.getAttribute("data_type")).thenReturn(VectorDataType.BINARY.getValue()); + + assertEquals(VectorDataType.BINARY, FieldInfoExtractor.extractVectorDataType(fieldInfo)); + when(fieldInfo.getAttribute("data_type")).thenReturn(null); + + when(fieldInfo.getAttribute("model_id")).thenReturn(MODEL_ID); + try (MockedStatic modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class)) { + ModelMetadata modelMetadata = Mockito.mock(ModelMetadata.class); + modelUtilMockedStatic.when(() -> ModelUtil.getModelMetadata(MODEL_ID)).thenReturn(modelMetadata); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.BYTE); + + assertEquals(VectorDataType.BYTE, FieldInfoExtractor.extractVectorDataType(fieldInfo)); + when(modelMetadata.getVectorDataType()).thenReturn(null); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.DEFAULT); + } + + when(fieldInfo.getAttribute("model_id")).thenReturn(null); + assertEquals(VectorDataType.DEFAULT, FieldInfoExtractor.extractVectorDataType(fieldInfo)); + } } diff --git a/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java b/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java index 457ea8c5b..d64b73c9a 100644 --- a/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java +++ b/src/test/java/org/opensearch/knn/common/KNNVectorUtilTests.java @@ -11,7 +11,16 @@ package org.opensearch.knn.common; +import lombok.SneakyThrows; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; + +import java.util.List; + +import static org.opensearch.knn.common.KNNVectorUtil.iterateVectorValuesOnce; public class KNNVectorUtilTests extends KNNTestCase { public void testByteZeroVector() { @@ -23,4 +32,29 @@ public void testFloatZeroVector() { assertTrue(KNNVectorUtil.isZeroVector(new float[] { 0.0f, 0.0f, 0.0f })); assertFalse(KNNVectorUtil.isZeroVector(new float[] { 1.0f, 1.0f, 1.0f })); } + + public void testIntListToArray() { + assertArrayEquals(new int[] { 1, 2, 3 }, KNNVectorUtil.intListToArray(List.of(1, 2, 3))); + assertNull(KNNVectorUtil.intListToArray(List.of())); + assertNull(KNNVectorUtil.intListToArray(null)); + } + + @SneakyThrows + public void testInit() { + // Give + final List floatArray = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }); + final int dimension = floatArray.get(0).length; + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + floatArray + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + // When + iterateVectorValuesOnce(knnVectorValues); + + // Then + assertNotEquals(-1, knnVectorValues.docId()); + assertArrayEquals(floatArray.get(0), knnVectorValues.getVector(), 0.001f); + assertEquals(dimension, knnVectorValues.dimension()); + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 4c235a896..f49587bc5 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -105,7 +105,7 @@ public void testAddBinaryField_withKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, null) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { called[0] = true; } }; @@ -142,7 +142,7 @@ public void testAddBinaryField_withoutKNN() throws IOException { KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(delegate, state) { @Override - public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) { + public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge) { called[0] = true; } }; @@ -160,7 +160,6 @@ public void testAddKNNBinaryField_noVectors() throws IOException { 128 ); Long initialGraphIndexRequests = KNNCounter.GRAPH_INDEX_REQUESTS.getCount(); - Long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); Long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); Long initialMergeSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); Long initialMergeDocs = KNNGraphValue.MERGE_TOTAL_DOCS.getValue(); @@ -178,9 +177,8 @@ public void testAddKNNBinaryField_noVectors() throws IOException { SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").build(); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfo, randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfo, randomVectorDocValuesProducer, true); assertEquals(initialGraphIndexRequests, KNNCounter.GRAPH_INDEX_REQUESTS.getCount()); - assertEquals(initialRefreshOperations, KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(initialMergeOperations, KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertEquals(initialMergeSize, KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); assertEquals(initialMergeDocs, KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); @@ -228,7 +226,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field @@ -237,7 +234,61 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); + + // The document should be created in the correct location + String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); + assertFileInCorrectLocation(state, expectedFile); + + // The footer should be valid + assertValidFooter(state.directory, expectedFile); + + // The document should be readable by nmslib + assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); + + // The graph creation statistics should be updated + assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); + assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + } + + public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { + // Set information about the segment and the fields + String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); + int docsInSegment = 100; + String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); + + KNNEngine knnEngine = KNNEngine.NMSLIB; + SpaceType spaceType = SpaceType.COSINESIMIL; + int dimension = 16; + + SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() + .directory(directory) + .segmentName(segmentName) + .docsInSegment(docsInSegment) + .codec(codec) + .build(); + + FieldInfo[] fieldInfoArray = new FieldInfo[] { + KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") + .addAttribute(KNNConstants.HNSW_ALGO_M, "16") + .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) + .build() }; + + FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); + SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); + + long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + + // Add documents to the field + KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); + TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( + docsInSegment, + dimension + ); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -250,7 +301,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -298,7 +348,6 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); - long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); @@ -306,7 +355,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, false); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -320,9 +369,6 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); - assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException { @@ -368,7 +414,6 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field @@ -377,7 +422,7 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -390,7 +435,6 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException assertBinaryIndexLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension, dataType); // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -467,7 +511,6 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); // Add documents to the field @@ -476,7 +519,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio docsInSegment, dimension ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); + knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true); // The document should be created in the correct location String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); @@ -489,7 +532,6 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); @@ -561,6 +603,6 @@ public void testAddBinaryField_luceneEngine_noInvocations_addKNNBinary() throws knn80DocValuesConsumer.addBinaryField(fieldInfo, docValuesProducer); verify(delegate, times(1)).addBinaryField(fieldInfo, docValuesProducer); - verify(knn80DocValuesConsumer, never()).addKNNBinaryField(any(), any(), eq(false), eq(true)); + verify(knn80DocValuesConsumer, never()).addKNNBinaryField(any(), any(), eq(false)); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 322b714f2..3810d46fd 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -46,6 +46,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; @@ -99,10 +100,13 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc byte[] byteVector = { 6, 14 }; addFieldToIndex( - new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, createVectorField(3, VectorEncoding.FLOAT32)), + new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, createVectorField(3, VectorEncoding.FLOAT32, VectorDataType.FLOAT)), + indexWriter + ); + addFieldToIndex( + new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, createVectorField(2, VectorEncoding.BYTE, VectorDataType.BINARY)), indexWriter ); - addFieldToIndex(new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, createVectorField(2, VectorEncoding.BYTE)), indexWriter); final IndexReader indexReader = indexWriter.getReader(); // ensuring segments are created indexWriter.flush(); @@ -187,17 +191,19 @@ private void addFieldToIndex(final Field vectorField, final RandomIndexWriter in indexWriter.addDocument(doc1); } - private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding) { + private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding, VectorDataType vectorDataType) { FieldType nativeVectorField = new FieldType(); // TODO: Replace this with the default field which will be created in mapper for Native Engines with KNNVectorsFormat nativeVectorField.setTokenized(false); nativeVectorField.setIndexOptions(IndexOptions.NONE); nativeVectorField.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); nativeVectorField.putAttribute(KNNConstants.KNN_METHOD, KNNConstants.METHOD_HNSW); - nativeVectorField.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()); + nativeVectorField.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()); nativeVectorField.putAttribute(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()); nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); + nativeVectorField.putAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + nativeVectorField.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); nativeVectorField.setVectorAttributes( dimension, vectorEncoding, diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java new file mode 100644 index 000000000..34a333471 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.SneakyThrows; +import org.apache.lucene.index.DocsWithFieldSet; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DefaultIndexBuildStrategyTests extends OpenSearchTestCase { + + ArgumentCaptor vectorTransferCapture = ArgumentCaptor.forClass(float[].class); + + @Before + public void init() { + vectorTransferCapture = ArgumentCaptor.forClass(float[].class); + } + + @SneakyThrows + public void testBuildAndWrite() { + // Given + List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); + + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = mockStatic(JNIService.class); + MockedStatic mockedOffHeapVectorTransferFactory = mockStatic(OffHeapVectorTransferFactory.class) + ) { + + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + .thenReturn(offHeapVectorTransfer); + + when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.NMSLIB) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .build(); + + // When + DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.createIndex( + eq(new int[] { 0, 1, 2 }), + eq(200L), + eq(knnVectorValues.dimension()), + eq("indexPath"), + eq(Map.of("index", "param")), + eq(KNNEngine.NMSLIB) + ) + ); + mockedJNIService.verifyNoMoreInteractions(); + verify(offHeapVectorTransfer).flush(true); + verify(offHeapVectorTransfer, times(3)).transfer(vectorTransferCapture.capture(), eq(true)); + verify(offHeapVectorTransfer).reset(); + + float[] prev = null; + for (float[] vector : vectorTransferCapture.getAllValues()) { + if (prev != null) { + assertNotSame(prev, vector); + } + prev = vector; + } + } + } + + @SneakyThrows + public void testBuildAndWriteWithModel() { + // Given + final Map docs = Map.of(0, new float[] { 1, 2 }, 1, new float[] { 2, 3 }, 2, new float[] { 3, 4 }); + DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet(); + docs.keySet().stream().sorted().forEach(docsWithFieldSet::add); + + byte[] modelBlob = new byte[] { 1 }; + + KNNFloatVectorValues knnVectorValues = (KNNFloatVectorValues) KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + docsWithFieldSet, + docs + ); + try ( + MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = mockStatic(JNIService.class); + MockedStatic mockedOffHeapVectorTransferFactory = mockStatic(OffHeapVectorTransferFactory.class) + ) { + + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + .thenReturn(offHeapVectorTransfer); + + when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.NMSLIB) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("model_id", "id", "model_blob", modelBlob)) + .build(); + + // When + DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.createIndexFromTemplate( + eq(new int[] { 0, 1, 2 }), + eq(200L), + eq(2), + eq("indexPath"), + eq(modelBlob), + eq(Map.of("model_id", "id", "model_blob", modelBlob)), + eq(KNNEngine.NMSLIB) + ) + ); + mockedJNIService.verifyNoMoreInteractions(); + verify(offHeapVectorTransfer).flush(true); + verify(offHeapVectorTransfer, times(3)).transfer(vectorTransferCapture.capture(), eq(true)); + + float[] prev = null; + for (float[] vector : vectorTransferCapture.getAllValues()) { + if (prev != null) { + assertNotSame(prev, vector); + } + prev = vector; + } + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java new file mode 100644 index 000000000..2ecfe9259 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.SneakyThrows; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class MemOptimizedNativeIndexBuildStrategyTests extends OpenSearchTestCase { + + @SneakyThrows + public void testBuildAndWrite() { + // Given + ArgumentCaptor vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + ArgumentCaptor vectorTransferCapture = ArgumentCaptor.forClass(float[].class); + + List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); + MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( + OffHeapVectorTransferFactory.class + ); + ) { + + // Limits transfer to 2 vectors + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); + + OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + .thenReturn(offHeapVectorTransfer); + + when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) + .thenReturn(true) + .thenReturn(false); + when(offHeapVectorTransfer.flush(false)).thenReturn(true); + when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .build(); + + // When + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.initIndex( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + Map.of("index", "param"), + KNNEngine.FAISS + ) + ); + + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 0, 1 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + // For the flush + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 2 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + mockedJNIService.verify( + () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + ); + assertEquals(200L, vectorAddressCaptor.getValue().longValue()); + assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); + verify(offHeapVectorTransfer, times(0)).reset(); + + float[] prev = null; + for (float[] vector : vectorTransferCapture.getAllValues()) { + if (prev != null) { + assertNotSame(prev, vector); + } + prev = vector; + } + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java new file mode 100644 index 000000000..cef875cfc --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferFactoryTests.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.test.OpenSearchTestCase; + +public class OffHeapVectorTransferFactoryTests extends OpenSearchTestCase { + + public void testOffHeapVectorTransferFactory() { + var floatVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10); + assertEquals(OffHeapFloatVectorTransfer.class, floatVectorTransfer.getClass()); + assertNotSame(floatVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 10)); + + var byteVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10); + assertEquals(OffHeapByteVectorTransfer.class, byteVectorTransfer.getClass()); + assertNotSame(byteVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BYTE, 10)); + + var binaryVectorTransfer = OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10); + assertEquals(OffHeapByteVectorTransfer.class, binaryVectorTransfer.getClass()); + assertNotSame(binaryVectorTransfer, OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.BINARY, 10)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java new file mode 100644 index 000000000..f1650db8f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/OffHeapVectorTransferTests.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.transfer; + +import lombok.SneakyThrows; +import org.opensearch.knn.KNNTestCase; + +import java.util.List; + +public class OffHeapVectorTransferTests extends KNNTestCase { + + @SneakyThrows + public void testFloatTransfer() { + List vectors = List.of( + new float[] { 0.1f, 0.2f }, + new float[] { 0.2f, 0.3f }, + new float[] { 0.3f, 0.4f }, + new float[] { 0.3f, 0.4f }, + new float[] { 0.3f, 0.4f } + ); + + OffHeapFloatVectorTransfer vectorTransfer = new OffHeapFloatVectorTransfer(2); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.reset(); + assertEquals(0, vectorTransfer.getVectorAddress()); + vectorTransfer.close(); + } + + @SneakyThrows + public void testByteTransfer() { + List vectors = List.of( + new byte[] { 0, 1 }, + new byte[] { 2, 3 }, + new byte[] { 4, 5 }, + new byte[] { 6, 7 }, + new byte[] { 8, 9 } + ); + + OffHeapByteVectorTransfer vectorTransfer = new OffHeapByteVectorTransfer(2); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); + assertEquals(0, vectorTransfer.getVectorAddress()); + } + + @SneakyThrows + public void testBinaryTransfer() { + List vectors = List.of( + new byte[] { 0, 1 }, + new byte[] { 2, 3 }, + new byte[] { 4, 5 }, + new byte[] { 6, 7 }, + new byte[] { 8, 9 } + ); + + OffHeapBinaryVectorTransfer vectorTransfer = new OffHeapBinaryVectorTransfer(2); + long vectorAddress = 0; + assertFalse(vectorTransfer.transfer(vectors.get(0), false)); + assertEquals(0, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(1), false)); + vectorAddress = vectorTransfer.getVectorAddress(); + assertFalse(vectorTransfer.transfer(vectors.get(2), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertTrue(vectorTransfer.transfer(vectors.get(3), false)); + assertEquals(vectorAddress, vectorTransfer.getVectorAddress()); + assertFalse(vectorTransfer.transfer(vectors.get(4), false)); + assertTrue(vectorTransfer.flush(false)); + vectorTransfer.close(); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java deleted file mode 100644 index 2f091a035..000000000 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import junit.framework.TestCase; -import lombok.SneakyThrows; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.SerializationMode; -import org.opensearch.knn.jni.JNICommons; - -import java.io.IOException; -import java.util.Random; - -import static org.junit.Assert.assertNotEquals; - -public class VectorTransferByteTests extends TestCase { - @SneakyThrows - public void testTransfer_whenCalled_thenAdded() { - final BytesRef bytesRef1 = getByteArrayOfVectors(20); - final BytesRef bytesRef2 = getByteArrayOfVectors(20); - VectorTransferByte vectorTransfer = new VectorTransferByte(40); - try { - vectorTransfer.init(2); - - vectorTransfer.transfer(bytesRef1); - // flush is not called - assertEquals(0, vectorTransfer.getVectorAddress()); - - vectorTransfer.transfer(bytesRef2); - // flush should be called - assertNotEquals(0, vectorTransfer.getVectorAddress()); - } finally { - if (vectorTransfer.getVectorAddress() != 0) { - JNICommons.freeVectorData(vectorTransfer.getVectorAddress()); - } - } - } - - @SneakyThrows - public void testSerializationMode_whenCalled_thenReturn() { - final BytesRef bytesRef = getByteArrayOfVectors(20); - VectorTransferByte vectorTransfer = new VectorTransferByte(1000); - - // Verify - assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, vectorTransfer.getSerializationMode(bytesRef)); - } - - private BytesRef getByteArrayOfVectors(int vectorLength) throws IOException { - byte[] vector = new byte[vectorLength]; - new Random().nextBytes(vector); - return new BytesRef(vector); - } -} diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java deleted file mode 100644 index 620fd7c65..000000000 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.codec.transfer; - -import junit.framework.TestCase; -import lombok.SneakyThrows; -import org.apache.lucene.util.BytesRef; -import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; -import org.opensearch.knn.jni.JNICommons; - -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.util.Random; -import java.util.stream.IntStream; - -import static org.junit.Assert.assertNotEquals; - -public class VectorTransferFloatTests extends TestCase { - @SneakyThrows - public void testTransfer_whenCalled_thenAdded() { - final BytesRef bytesRef1 = getByteArrayOfVectors(20); - final BytesRef bytesRef2 = getByteArrayOfVectors(20); - VectorTransferFloat vectorTransfer = new VectorTransferFloat(160); - try { - vectorTransfer.init(2); - - vectorTransfer.transfer(bytesRef1); - // flush is not called - assertEquals(0, vectorTransfer.getVectorAddress()); - - vectorTransfer.transfer(bytesRef2); - // flush should be called - assertNotEquals(0, vectorTransfer.getVectorAddress()); - } finally { - if (vectorTransfer.getVectorAddress() != 0) { - JNICommons.freeVectorData(vectorTransfer.getVectorAddress()); - } - } - } - - @SneakyThrows - public void testSerializationMode_whenCalled_thenReturn() { - final BytesRef bytesRef = getByteArrayOfVectors(20); - VectorTransferFloat vectorTransfer = new VectorTransferFloat(1000); - - // Verify - assertEquals(KNNVectorSerializerFactory.getSerializerModeFromBytesRef(bytesRef), vectorTransfer.getSerializationMode(bytesRef)); - } - - private BytesRef getByteArrayOfVectors(int vectorLength) throws IOException { - float[] vector = new float[vectorLength]; - IntStream.range(0, vectorLength).forEach(index -> vector[index] = new Random().nextFloat()); - - final ByteArrayOutputStream bas = new ByteArrayOutputStream(); - final DataOutputStream ds = new DataOutputStream(bas); - for (float f : vector) { - ds.writeFloat(f); - } - final byte[] vectorAsCollectionOfFloats = bas.toByteArray(); - return new BytesRef(vectorAsCollectionOfFloats); - } -} diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 47dd1dda9..dbea6375b 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -6,54 +6,11 @@ package org.opensearch.knn.index.codec.util; import junit.framework.TestCase; -import lombok.SneakyThrows; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.util.BytesRef; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.transfer.VectorTransfer; -import java.util.Arrays; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; public class KNNCodecUtilTests extends TestCase { - @SneakyThrows - public void testGetPair_whenCalled_thenReturn() { - long liveDocCount = 1l; - int[] docId = { 2 }; - long vectorAddress = 3l; - int dimension = 4; - BytesRef bytesRef = new BytesRef(); - - BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); - when(binaryDocValues.cost()).thenReturn(liveDocCount); - when(binaryDocValues.nextDoc()).thenReturn(docId[0], NO_MORE_DOCS); - when(binaryDocValues.binaryValue()).thenReturn(bytesRef); - - VectorTransfer vectorTransfer = mock(VectorTransfer.class); - when(vectorTransfer.getSerializationMode(any(BytesRef.class))).thenReturn(SerializationMode.COLLECTIONS_OF_BYTES); - when(vectorTransfer.getVectorAddress()).thenReturn(vectorAddress); - when(vectorTransfer.getDimension()).thenReturn(dimension); - - // Run - KNNCodecUtil.Pair pair = KNNCodecUtil.getPair(binaryDocValues, vectorTransfer); - - // Verify - verify(vectorTransfer).init(liveDocCount); - verify(vectorTransfer).getSerializationMode(any(BytesRef.class)); - verify(vectorTransfer).transfer(any(BytesRef.class)); - verify(vectorTransfer).close(); - - assertTrue(Arrays.equals(docId, pair.docs)); - assertEquals(vectorAddress, pair.getVectorAddress()); - assertEquals(dimension, pair.getDimension()); - assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); - } public void testCalculateArraySize() { int numVectors = 4; diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index dc9a97fbf..316582f6c 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -14,6 +14,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.VectorDataType; @@ -56,7 +57,7 @@ public void testIndexAllocation_close() throws InterruptedException { } Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); long vectorMemoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); - JNIService.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); @@ -117,7 +118,7 @@ public void testClose_whenBinaryFiass_thenSuccess() { VectorDataType.BINARY.getValue() ); long vectorMemoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors * dataLength); - JNIService.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 51f95d29a..8a38cadb5 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -15,6 +15,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNICommons; @@ -32,6 +33,8 @@ import java.util.Arrays; import java.util.Map; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.doAnswer; @@ -56,7 +59,7 @@ public void testIndexLoadStrategy_load() throws IOException { } Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); long memoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); - JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); @@ -104,7 +107,7 @@ public void testLoad_whenFaissBinary_thenSuccess() throws IOException { VectorDataType.BINARY.getValue() ); long memoryAddress = JNICommons.storeByteVectorData(0, vectors, numVectors); - JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); + TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); diff --git a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java index f5a1351ae..0b631ab41 100644 --- a/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesTests.java @@ -26,7 +26,7 @@ public void testFloatVectorValues_whenValidInput_thenSuccess() { floatArray ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); - new CompareVectorValues().validateVectorValues(knnVectorValues, floatArray, dimension, true); + new CompareVectorValues().validateVectorValues(knnVectorValues, floatArray, 8, dimension, true); final DocsWithFieldSet docsWithFieldSet = getDocIdSetIterator(floatArray.size()); @@ -36,7 +36,7 @@ public void testFloatVectorValues_whenValidInput_thenSuccess() { docsWithFieldSet, vectorsMap ); - new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, floatArray, dimension, false); + new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, floatArray, 8, dimension, false); final TestVectorValues.PredefinedFloatVectorBinaryDocValues preDefinedFloatVectorValues = new TestVectorValues.PredefinedFloatVectorBinaryDocValues(floatArray); @@ -44,7 +44,7 @@ public void testFloatVectorValues_whenValidInput_thenSuccess() { VectorDataType.FLOAT, preDefinedFloatVectorValues ); - new CompareVectorValues().validateVectorValues(knnFloatVectorValuesBinaryDocValues, floatArray, dimension, false); + new CompareVectorValues().validateVectorValues(knnFloatVectorValuesBinaryDocValues, floatArray, 8, dimension, false); } @SneakyThrows @@ -53,7 +53,7 @@ public void testByteVectorValues_whenValidInput_thenSuccess() { final int dimension = byteArray.get(0).length; final TestVectorValues.PreDefinedByteVectorValues randomVectorValues = new TestVectorValues.PreDefinedByteVectorValues(byteArray); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BYTE, randomVectorValues); - new CompareVectorValues().validateVectorValues(knnVectorValues, byteArray, dimension, true); + new CompareVectorValues().validateVectorValues(knnVectorValues, byteArray, 2, dimension, true); final DocsWithFieldSet docsWithFieldSet = getDocIdSetIterator(byteArray.size()); final Map vectorsMap = Map.of(0, byteArray.get(0), 1, byteArray.get(1)); @@ -62,7 +62,7 @@ public void testByteVectorValues_whenValidInput_thenSuccess() { docsWithFieldSet, vectorsMap ); - new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, byteArray, dimension, false); + new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, byteArray, 2, dimension, false); final TestVectorValues.PredefinedByteVectorBinaryDocValues preDefinedByteVectorValues = new TestVectorValues.PredefinedByteVectorBinaryDocValues(byteArray); @@ -70,7 +70,7 @@ public void testByteVectorValues_whenValidInput_thenSuccess() { VectorDataType.BYTE, preDefinedByteVectorValues ); - new CompareVectorValues().validateVectorValues(knnBinaryVectorValuesBinaryDocValues, byteArray, dimension, false); + new CompareVectorValues().validateVectorValues(knnBinaryVectorValuesBinaryDocValues, byteArray, 2, dimension, false); } @SneakyThrows @@ -81,7 +81,7 @@ public void testBinaryVectorValues_whenValidInput_thenSuccess() { byteArray ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.BINARY, randomVectorValues); - new CompareVectorValues().validateVectorValues(knnVectorValues, byteArray, dimension, true); + new CompareVectorValues().validateVectorValues(knnVectorValues, byteArray, 3, dimension, true); final DocsWithFieldSet docsWithFieldSet = getDocIdSetIterator(byteArray.size()); final Map vectorsMap = Map.of(0, byteArray.get(0), 1, byteArray.get(1)); @@ -90,7 +90,7 @@ public void testBinaryVectorValues_whenValidInput_thenSuccess() { docsWithFieldSet, vectorsMap ); - new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, byteArray, dimension, false); + new CompareVectorValues().validateVectorValues(knnVectorValuesForFieldWriter, byteArray, 3, dimension, false); final TestVectorValues.PredefinedByteVectorBinaryDocValues preDefinedByteVectorValues = new TestVectorValues.PredefinedByteVectorBinaryDocValues(byteArray); @@ -98,7 +98,7 @@ public void testBinaryVectorValues_whenValidInput_thenSuccess() { VectorDataType.BINARY, preDefinedByteVectorValues ); - new CompareVectorValues().validateVectorValues(knnBinaryVectorValuesBinaryDocValues, byteArray, dimension, false); + new CompareVectorValues().validateVectorValues(knnBinaryVectorValuesBinaryDocValues, byteArray, 3, dimension, false); } public void testDocIdsIteratorValues_whenInvalidDisi_thenThrowException() { @@ -117,28 +117,38 @@ private DocsWithFieldSet getDocIdSetIterator(int numberOfDocIds) { } private class CompareVectorValues { - void validateVectorValues(KNNVectorValues vectorValues, List vectors, int dimension, boolean validateAddress) - throws IOException { - Assert.assertEquals(vectorValues.totalLiveDocs(), vectors.size()); + void validateVectorValues( + KNNVectorValues vectorValues, + List vectors, + int bytesPerVector, + int dimension, + boolean validateAddress + ) throws IOException { + assertEquals(vectorValues.totalLiveDocs(), vectors.size()); int docId, i = 0; T oldActual = null; int oldDocId = -1; final KNNVectorValuesIterator iterator = vectorValues.vectorValuesIterator; for (docId = iterator.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS && i < vectors.size(); docId = iterator.nextDoc()) { T actual = vectorValues.getVector(); + T clone = vectorValues.conditionalCloneVector(); T expected = vectors.get(i); - Assert.assertNotEquals(oldDocId, docId); - Assert.assertEquals(dimension, vectorValues.dimension()); + assertNotEquals(oldDocId, docId); + assertEquals(dimension, vectorValues.dimension()); // this will check if reference is correct for the vectors. This is mainly required because for // VectorValues of Lucene when reading vectors put the vector at same reference if (oldActual != null && validateAddress) { - Assert.assertSame(actual, oldActual); + assertSame(actual, oldActual); + assertNotSame(clone, oldActual); } + oldActual = actual; // this will do the deep equals - Assert.assertArrayEquals(new Object[] { actual }, new Object[] { expected }); + assertArrayEquals(new Object[] { actual }, new Object[] { expected }); + assertArrayEquals(new Object[] { clone }, new Object[] { expected }); i++; } + assertEquals(bytesPerVector, vectorValues.bytesPerVector); } } diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 53245cc62..c78478f4d 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -86,7 +86,7 @@ public static void setUpClass() throws IOException { public void testCreateIndex_invalid_engineNotSupported() { expectThrows( IllegalArgumentException.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( new int[] {}, 0, 0, @@ -100,21 +100,14 @@ public void testCreateIndex_invalid_engineNotSupported() { public void testCreateIndex_invalid_engineNull() { expectThrows( Exception.class, - () -> JNIService.createIndex( - new int[] {}, - 0, - 0, - "test", - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - null - ) + () -> TestUtils.createIndex(new int[] {}, 0, 0, "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null) ); } public void testCreateIndex_nmslib_invalid_noSpaceType() { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -133,7 +126,7 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept Path tmpFile1 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors1[0].length, @@ -149,7 +142,7 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept Path tmpFile2 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress2, vectors2[0].length, @@ -168,7 +161,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( null, memoryAddress, 0, @@ -180,7 +173,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, 0, 0, @@ -192,7 +185,7 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, 0, @@ -204,12 +197,12 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) + () -> TestUtils.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) ); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, 0, @@ -228,7 +221,7 @@ public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -254,7 +247,7 @@ public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -274,7 +267,7 @@ public void testCreateIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -286,7 +279,7 @@ public void testCreateIndex_nmslib_valid() throws IOException { tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -310,7 +303,7 @@ public void testCreateIndex_faiss_invalid_noSpaceType() { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -329,7 +322,7 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti Path tmpFile1 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors1[0].length, @@ -344,7 +337,7 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti Path tmpFile2 = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors2[0].length, @@ -364,7 +357,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( null, memoryAddress, 0, @@ -376,7 +369,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, 0, 0, @@ -388,7 +381,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -400,7 +393,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -412,7 +405,7 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -432,7 +425,7 @@ public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -452,7 +445,7 @@ public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOExceptio Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -470,7 +463,7 @@ public void testCreateIndex_faiss_invalid_invalidIndexDescription() throws IOExc Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -493,7 +486,7 @@ public void testCreateIndex_faiss_sqfp16_invalidIndexDescription() { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -517,7 +510,7 @@ public void testLoadIndex_faiss_sqfp16_valid() { String sqfp16IndexDescription = "HNSW16,SQfp16"; long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, @@ -540,7 +533,7 @@ public void testQueryIndex_faiss_sqfp16_valid() { float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); long memoryAddress = JNICommons.storeVectorData(0, truncatedVectors, (long) truncatedVectors.length * truncatedVectors[0].length); Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddress, testData.indexData.getDimension(), @@ -634,7 +627,7 @@ public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOExcept Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.createIndex( + () -> TestUtils.createIndex( docIds, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -660,7 +653,7 @@ public void testCreateIndex_faiss_valid() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile1 = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -677,7 +670,7 @@ public void testCreateIndex_faiss_valid() throws IOException { public void testCreateIndex_binary_faiss_valid() { Path tmpFile1 = createTempFile(); long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddr, testData.indexData.getDimension(), @@ -733,7 +726,7 @@ public void testLoadIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -769,7 +762,7 @@ public void testLoadIndex_faiss_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -799,7 +792,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -829,7 +822,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -862,7 +855,7 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -888,7 +881,7 @@ public void testQueryIndex_faiss_valid() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -949,7 +942,7 @@ public void testQueryIndex_faiss_parentIds() throws IOException { for (String method : methods) { for (SpaceType spaceType : spaces) { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testDataNested.indexData.docs, testData.loadDataToMemoryAddress(), testDataNested.indexData.getDimension(), @@ -992,7 +985,7 @@ public void testQueryBinaryIndex_faiss_valid() { for (String method : methods) { Path tmpFile = createTempFile(); long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, memoryAddr, testData.indexData.getDimension(), @@ -1071,7 +1064,7 @@ public void testFree_nmslib_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -1095,7 +1088,7 @@ public void testFree_faiss_valid() throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), @@ -1235,7 +1228,7 @@ private long transferVectors(int numDuplicates) { return trainPointer1; } - public void testCreateIndexFromTemplate() throws IOException { + public void createIndexFromTemplate() throws IOException { long trainPointer1 = JNIService.transferVectors(0, testData.indexData.vectors); assertNotEquals(0, trainPointer1); @@ -1445,7 +1438,7 @@ private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, Spac private String createFaissHNSWIndex(SpaceType spaceType) throws IOException { Path tmpFile = createTempFile(); - JNIService.createIndex( + TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 6676ee154..6bbbc8a5b 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -19,7 +19,9 @@ import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.jni.JNICommons; +import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.script.KNNScoringUtil; import java.util.Collections; @@ -397,11 +399,11 @@ private void initBinaryData() { } public long loadDataToMemoryAddress() { - return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length); + return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length, true); } public long loadBinaryDataToMemoryAddress() { - return JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length); + return JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length, true); } @AllArgsConstructor @@ -414,4 +416,15 @@ public static class Pair { public float[][] vectors; } } + + public static void createIndex(int[] ids, long address, int dimension, String name, Map parameters, KNNEngine engine) { + if (engine != KNNEngine.FAISS) { + JNIService.createIndex(ids, address, dimension, name, parameters, engine); + } else { + // We can initialize numDocs as 0, this will just not reserve anything. + long indexAddress = JNIService.initIndex(0, dimension, parameters, engine); + JNIService.insertToIndex(ids, address, dimension, parameters, indexAddress, engine); + JNIService.writeIndex(name, indexAddress, engine, parameters); + } + } }